wasmati: You should write your WebAssembly in TypeScript

Written by Gregor Mitscha-Baude

If you know about WebAssembly or Wasm, you might have heard a few different pitches for why you need it:

1. It’s a low-level language for the web. So, it’s faster than JS. Use it to speed things up.

2. It’s a build target. As such, it enables you to use languages other than JS on the web. You could write a web library in Rust or Go!

3. You actually don’t use it on the web, but on the backend! It’s the perfect VM for untrusted code since it can only access the outside world through capabilities provided by the host runtime.

I sense that #2 and #3 are the main reasons people use Wasm, while #1 is the one that excites me personally. I found that Wasm can give you vast speed-ups over JS, especially when doing heavy cryptography stuff in the browser, which we at zkSecurity love to do. See real-world examples near the end of this post.

So, if you’re like me and want to use Wasm to speed up your JS, you will face the question of how to write that Wasm, and how to integrate it in your stack. I have been dancing around that question for two years now, both professionally and in side projects, and I finally came up with a solution that satisfies me.

The solution is wasmati, a TypeScript library to write Wasm at the instruction level. The wasmati API closely mirrors the WebAssembly text format (WAT). Here’s a taste:

import { i64, func, local } from "wasmati";

const myMultiply = func({ in: [i64, i64], out: [i64] }, ([x, y]) => {
  local.get(x); // put input x on the stack
  local.get(y); // put input y on the stack
  i64.mul(); // pop the last two values from the stack, multiply them, put the result on the stack
});

For reference, this would be an equivalent WAT code snippet:

(func $myMultiply (param $x i64) (param $y i64) (result i64)
  local.get $x
  local.get $y
  i64.mul)

Before you run away screaming at the idea of messing with low-level details like stack variables, I have to emphasize one thing: writing Wasm by hand is feasible because, if you use it for performance, you don’t need a lot of it. JS is fine 99% of the time, and to make things as fast as they need to be, you only need a sprinkling of Wasm in key places of your library (but there, you might need it badly).

In stark contrast to writing .wat files, using WAT-embedded-in-TS is something I find very enjoyable. You have the raw instructions at your fingertips, with no abstractions preventing you from writing the most optimized code. At the same time, having a high-level language to compose instructions unlocks compile-time programming, more on that later.

The rest of this post is my story of how I came to create wasmati, the trade-offs considered in designing it, and examples for how you can use it. At the end, we implement finite field addition (a performance-critical operation in a lot of cryptography) and benchmark our wasmati code against a pure JS implementation using BigInt.

Why I created wasmati

At my day job, I maintain SnarkyJS, a TypeScript framework that lets you create zero-knowledge proofs of arbitrary statements in a web browser. The heavy machinery for creating zk proofs is located in a Rust library. Rust is compiled to Wasm to bring that machinery to JS runtimes. As an aside, we also use a fair amount of OCaml compiled to JS.

The multi-language architecture of SnarkyJS is rooted in solid engineering choices. Rust is where all the crypto talent is, and it’s the dominant language in the crypto ecosystem I’m part of. You can reuse a Rust codebase across native and JS runtimes. Nonetheless, as a developer responsible for the TypeScript side of things, I developed a dislike of the impedance mismatch, bad debugging experience and general complexity such an architecture creates. I yearn for the simplicity of having everything in one language, of running just one build step, maintaining just one dependencies file, you get the idea. When I ctrl-click on a function signature, I want to see the source code of that function, not a .d.ts file which stands in place for an opaque blob of compiled Wasm. I don’t want to context-switch into a Rust codebase, where I first have to wade through an ugly layer of wasm-bindgen glue until I get to actual implementations.

It’s fun to compare the different approaches taken by ZPrize participants:

The winning team, Yrrid led by Niall Emmart, wrote their code from scratch in C and compiled it to Wasm with clang. They auto-generated the C code for finite field arithmetic using Java.

My submission used mostly JS, plus raw Wasm generated using a hacked-together, badly-typed JS lib for dumping WAT syntax into a string.

The third-placing team used a fork of wasmcurves by Jordi Baylina, which generates Wasm using a JS package called wasmbuilder. This is similar in spirit to wasmati, without the API polish and types.

The fourth-placing team used Go (gnark) compiled to Wasm.

The fifth-placing team used Rust (arkworks) compiled to Wasm.

The sixth-placing team used C (MIRACL Core) compiled to Wasm.

The reference implementation is Rust (arkworks).

It might be a coincidence, but it’s funny how the best solutions are all somewhat hacky compared to “cleaner” solutions like compiling Wasm from Go or Rust and building on an off-the-shelf library.

Another takeaway is that for most of the code, it doesn’t hurt performance to use JS instead of Wasm. My submission was the only one that did so, using Wasm only for low-level arithmetic while the others were 100% Wasm.

A scalable way to write raw WebAssembly

When I started working on the ZPrize, I knew that I wanted to use mostly JS with some sprinkles of Wasm. I had already immersed myself in writing Wasm directly for a side project the year before, where I started developing a fully-fledged bundler for WAT, called watever. My thesis at the time was based on WAT as the dev environment.

After I started experimenting with different variants of finite field multiplication for the ZPrize, it became obvious that I had been wrong: the WAT approach didn’t scale. Why? Well, finite field algorithms consist of hundreds of instructions, with dozens of baked-in constants, each of which has to change if you’re changing some parameter of your algorithm. It was clear that writing it all by hand would be madness. What I needed was the ability to use meta-programming or compile-time programming: I needed code to generate all those instructions for me.

Thus, the idea of a TS package to write Wasm was born. For the prize, I accomplished this the quick-and-dirty way, but a few months later I decided to redo it, this time properly: A well-typed, well-maintained library, supporting the full WebAssembly spec. It would give developers a clean API mirroring WAT syntax, and create Wasm bytecode without a compile step. The Wasm encoding would be implemented from scratch in TS using parser-printer combinator patterns without heavy dependencies like wabt or binaryen.
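For a sense of what such an encoder ultimately has to emit, here is the myMultiply function from the introduction assembled by hand into Wasm bytecode and run with the standard `WebAssembly` JS API. This only illustrates the binary format; it is not wasmati's encoder, and wasmati's output need not match it byte for byte. The export name `mul` is our own choice.

```typescript
// Hand-assembled Wasm binary equivalent to the myMultiply example:
// (func (export "mul") (param i64 i64) (result i64)
//   local.get 0  local.get 1  i64.mul)
const bytes = new Uint8Array([
  // header: magic "\0asm" and version 1
  0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00,
  // type section (id 1, 7 bytes): one func type (i64, i64) -> i64
  0x01, 0x07, 0x01, 0x60, 0x02, 0x7e, 0x7e, 0x01, 0x7e,
  // function section (id 3, 2 bytes): one function with type index 0
  0x03, 0x02, 0x01, 0x00,
  // export section (id 7, 7 bytes): export function 0 under the name "mul"
  0x07, 0x07, 0x01, 0x03, 0x6d, 0x75, 0x6c, 0x00, 0x00,
  // code section (id 10, 9 bytes): one 7-byte body with no locals:
  // local.get 0 (0x20 0x00), local.get 1 (0x20 0x01), i64.mul (0x7e), end (0x0b)
  0x0a, 0x09, 0x01, 0x07, 0x00, 0x20, 0x00, 0x20, 0x01, 0x7e, 0x0b,
]);

// The synchronous constructors are fine for a module this small.
const wasmModule = new WebAssembly.Module(bytes);
const instance = new WebAssembly.Instance(wasmModule);

// i64 parameters and results cross the JS boundary as BigInt.
const mul = instance.exports.mul as (x: bigint, y: bigint) => bigint;

console.log(mul(6n, 7n)); // 42n
```

Like any i64 operation, the multiplication wraps modulo $2^{64}$, and results come back as signed bigints.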

I’m proud to announce that this library, called wasmati, is ready to be used! It went through enough dogfooding cycles while I reimplemented all my ZPrize code with it.

A real-world example: finite field operations

To get a feel for wasmati, and understand the idea of compile-time programming, let’s do an end-to-end example: we’re going to implement some finite field operations. It doesn’t matter if you don’t know what a finite field is, because we’ll only do basic stuff and explain everything.

A field element, for our purposes, is one of the numbers $0, …, p-1$ where $p$ is a prime. In real-world applications, $p$ is usually too large to fit in 64 bits. In elliptic curve cryptography, it’s often about 256 bits. Therefore, in software, we have to represent a field element $x$ using an array of native integers $x[0],…,x[n-1]$. Each integer $x[i]$ is called a “word” or “limb” and is understood to be of a certain maximum number of bits, which we call $w$ (the “word length” or “limb size”). The size $n$ is fixed (it depends on $p$ and the limb size $w$), so field elements will be stored as static-size arrays of integers. Here’s the mathematical formula for how a field element $x$ is represented by its limbs:

$$x = \sum_{i=0}^{n-1} x[i] \cdot 2^{iw}, \qquad 0 \le x[i] < 2^w$$

Operations on field elements will usually involve a loop over the n limbs.
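As a sketch of this representation, the following converts between a bigint and its limbs, lowest limb first, using the parameters $w = 29$ and $n = 9$ that appear later in the post. The names toLimbs() and fromLimbs() are illustrative; the post's own version of the first one is called bigintToLimbs().

```typescript
// x = x[0] + x[1]·2^w + ... + x[n-1]·2^((n-1)w), lowest limb first
const w = 29; // limb size in bits
const n = 9; // number of limbs
const wordMax = (1n << BigInt(w)) - 1n; // mask for the low w bits

function toLimbs(x: bigint): Uint32Array {
  const limbs = new Uint32Array(n);
  for (let i = 0; i < n; i++) {
    limbs[i] = Number(x & wordMax); // the low w bits become limb i
    x >>= BigInt(w); // move on to the next w bits
  }
  return limbs;
}

function fromLimbs(limbs: Uint32Array): bigint {
  let x = 0n;
  // Horner-style: start from the highest limb and shift it up
  for (let i = n - 1; i >= 0; i--) {
    x = (x << BigInt(w)) | BigInt(limbs[i]);
  }
  return x;
}

const p = 0x40000000000000000000000000000000224698fc094cf91b992d30ed00000001n;
const limbs = toLimbs(p);
console.log(fromLimbs(limbs) === p); // true
```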

As a first code example, let’s try to implement the operation $x > y$. We’ll be coding these in TypeScript with wasmati; Wasm instructions are intuitive enough that you can pick them up on the fly. If you want to code along, create a new npm project in an empty folder with npm init -y and install the dependencies with npm i wasmati typescript. Create a minimal tsconfig:

Also, change the package.json to be "type": "module" because we’re not barbarians.

Finally, create a file greater.ts with the content below. You can execute that file anytime with ts-node-esm greater.ts (assuming you have ts-node installed globally).

You can also find the full example code on our GitHub.

wasmati basics

You declare Wasm functions with func. The func constructor takes an object describing the function signature, and a callback where you’ll write the function’s code. Our function takes two i32 as inputs – those are pointers to x and y in memory. The output is another i32, which is either 1 or 0 (representing true or false), depending on whether x is greater or not. Wasm doesn’t have booleans.

The algorithm for x > y starts like this: If the highest limb of x is greater than the highest limb of y, then for sure x > y and we can return true.

The limbs in our implementation are all 32-bit integers, and the Wasm instruction to load an i32 is i32.load():

local.get(x); // push memory address x on the stack
i32.load({ offset: 0 }); // load an i32 from memory at address x, offset 0

Note that Wasm is a stack-based language: Each operation pops its arguments from an implicit stack and then pushes its result to that stack again. In this example, we’re first pushing x (the memory address) on the stack by calling local.get(x). Then, i32.load() pops x and pushes the i32 it loaded from memory.
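To make the implicit stack concrete, here is a toy interpreter for just the three instructions used in this section. It is a hypothetical model for illustration only (memory is an array of 32-bit words, and addresses are assumed to be multiples of 4), not how wasmati or a real Wasm engine works.

```typescript
type Instr =
  | { op: "local.get"; index: number }
  | { op: "i32.load"; offset: number }
  | { op: "i32.gt_s" };

// Runs the instructions and returns whatever is left on the stack.
function run(code: Instr[], locals: number[], memory: Int32Array): number[] {
  const stack: number[] = [];
  for (const instr of code) {
    switch (instr.op) {
      case "local.get":
        stack.push(locals[instr.index]); // push a local's value
        break;
      case "i32.load": {
        const address = stack.pop()! + instr.offset; // pop the address
        stack.push(memory[address / 4]); // push the 32-bit word stored there
        break;
      }
      case "i32.gt_s": {
        const right = stack.pop()!; // the value pushed last is the right operand
        const left = stack.pop()!;
        stack.push(left > right ? 1 : 0);
        break;
      }
    }
  }
  return stack;
}

// x > y for two one-limb numbers stored at byte addresses 0 and 4
const mem = new Int32Array([5, 3]);
const program: Instr[] = [
  { op: "local.get", index: 0 }, // push address x
  { op: "i32.load", offset: 0 }, // pop x, push x[0]
  { op: "local.get", index: 1 }, // push address y
  { op: "i32.load", offset: 0 }, // pop y, push y[0]
  { op: "i32.gt_s" }, // pop both, push (x[0] > y[0])
];
console.log(run(program, [0, 4], mem)); // [ 1 ]
```

Note the operand order of i32.gt_s: the value pushed last is the right-hand side, so pushing x[0] before y[0] computes x[0] > y[0].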

The wasmati library lets you imperatively “call” Wasm instructions like local.get() and i32.load(). Calling instructions adds them to the instruction list of the Wasm function that is currently created.

Like many instructions, i32.load() has a parameter that determines its behaviour at compile time: It takes an optional offset (in bytes) which is added to the stack argument x before loading from that location.

To make instruction sequences more readable, WAT has a form of syntax sugar that lets you group them as S-expressions so that stack arguments of an instruction look like function arguments:

(i32.load offset=0 (local.get $x))
;; is the same as
local.get $x
i32.load offset=0

We let you do a similar thing in wasmati. You can write:

i32.load({ offset: 0 }, local.get(x));

This works pretty simply: local.get(x) is an argument of i32.load(), and JS evaluates arguments before the call, so local.get() is executed first and comes first in the instruction list. As a side effect, this gives you type safety. local.get() uses the type of x (Local<i32>) to infer that its return is of type StackVar<i32>; i32.load(), on the other hand, only accepts i32-typed inputs, so if you accidentally made x a different type, you would get a type error.

Since we need lots of local.get() when writing Wasm, wasmati gives you another bit of convenience: You can just pass in x directly, and local.get() will be added automatically. Again, this is type-safe.

i32.load({ offset: 0 }, x);
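The mechanics described above can be modeled in a few lines of TS. The sketch below is a hypothetical toy, not wasmati's actual internals or types: calling an "instruction" appends it to a global instruction list, a Local argument triggers an implicit local.get, and the TS types reject an i64 local where an i32 is expected.

```typescript
type ValueType = "i32" | "i64";
type Local<T extends ValueType> = { kind: "local"; type: T; index: number };
type StackVar<T extends ValueType> = { kind: "stack"; type: T };

const body: string[] = []; // instruction list of the function being built

function localGet<T extends ValueType>(x: Local<T>): StackVar<T> {
  body.push(`local.get ${x.index}`);
  return { kind: "stack", type: x.type };
}

// Accepts an i32 already on the stack, or an i32 local.
function i32Load(
  { offset }: { offset: number },
  address?: StackVar<"i32"> | Local<"i32">
): StackVar<"i32"> {
  // a Local argument gets its local.get emitted first
  if (address?.kind === "local") localGet(address);
  body.push(`i32.load offset=${offset}`);
  return { kind: "stack", type: "i32" };
}

const x: Local<"i32"> = { kind: "local", type: "i32", index: 0 };
const big: Local<"i64"> = { kind: "local", type: "i64", index: 1 };

i32Load({ offset: 0 }, localGet(x)); // explicit: the argument runs first
i32Load({ offset: 4 }, x); // implicit local.get
// @ts-expect-error an i64 local is rejected at compile time
if (false) i32Load({ offset: 0 }, big);

console.log(body.join("; ")); // local.get 0; i32.load offset=0; local.get 0; i32.load offset=4
```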

Remember that we wanted to load limbs of a field element from memory. The code above loads x[0], because that’s the first i32 we store at the memory location x. The next limb x[1] would be another i32 stored at a 4-byte offset (because 32 bits = 4 * 8 bits = 4 bytes):

i32.load({ offset: 4 }, x); // load x[1]

To load the final limb x[n-1], we need some pointer arithmetic. Let’s assume for the rest of this post that $n = 9$. Then x[n-1] lives in memory at a $4 \cdot (n-1) = 4 \cdot 8 = 32$ byte offset from the address x.

i32.load({ offset: 32 }, x); // load x[n-1] where n=9

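At this point you probably protest and say: This is crazy, why do we hard-code those byte offsets? Let’s make a helper function to load the i-th limb, so we never have to look at this pointer arithmetic again! You’re right, let’s do this:

// load the i-th limb of the field element stored at memory address x
function loadLimb(x: Local<i32>, i: number) {
  return i32.load({ offset: 4 * i }, x);
}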

At this point, we’re doing something that is impossible when writing a raw WAT file: We create an abstraction that helps us write instructions more easily, without affecting the Wasm runtime in any way.

Our first Wasm function

Armed with a convenient way to load x[i] from memory, let’s start coding isGreater(): If x[n-1] > y[n-1], return true.

import { i32, func, local, Local, if_, return_ } from "wasmati";

const n = 9; // number of limbs

const isGreater = func({ in: [i32, i32], out: [i32] }, ([x, y]) => {
  loadLimb(x, n - 1); // put x[n-1] on the stack
  loadLimb(y, n - 1); // put y[n-1] on the stack
  i32.gt_s(); // put (x[n-1] > y[n-1]) on the stack
  // execute this if-block if (x[n-1] > y[n-1])
  if_(null, () => {
    i32.const(1); // put 1 (true) on the stack
    return_(); // return early
  });
  // TODO: rest of the logic
});

I put comments on this code so you see what each line does. Note again the stack-based nature of Wasm: an if_ block expects that the condition (an i32) is already on the stack; it pops that i32 and executes the if branch if the i32 is non-zero.

When you call func(), wasmati will eagerly execute its callback and construct an object representing the Wasm function. If you run your file right now, you’ll see it throw an error:

Error: expected i32 on the stack, got nothing
...
at file:///<path>/greater.ts:5:19

The stack trace points us to the line where we define our function. That’s because isGreater has a return signature of [i32], but after the conditional if block we don’t return anything: the stack when reaching the end of the function is empty. To match the function signature, let’s put 0 (false) on the stack as the last line:

// TODO: rest of the logic
i32.const(0);
});

Moving on. So far, our function returns true if the highest x limb is greater than the highest y limb. Conversely, if the highest y limb is greater, then we know x < y and we can immediately return false. To be able to do a second operation on x[n-1] and y[n-1] without loading them from memory again, we introduce two locals where we store those values. Locals are declared with a locals field in the signature object passed to func, and they are passed to the callback as its second argument.

We set those locals to what loadLimb() returns, using local.set(). Then, we pass them to i32.gt_s(), again avoiding having to write out local.get() instructions:

const isGreater = func(
  { in: [i32, i32], locals: [i32, i32], out: [i32] },
  ([x, y], [xi, yi]) => {
    // set xi = x[n-1] and yi = y[n-1]
    local.set(xi, loadLimb(x, n - 1));
    local.set(yi, loadLimb(y, n - 1));
    i32.gt_s(xi, yi); // put (xi > yi) on the stack
    // ... as before ...
  }
);

Now, we can reuse xi and yi to test for x[n-1] < y[n-1], in which case we can immediately return false:

const isGreater = func(
  { in: [i32, i32], locals: [i32, i32], out: [i32] },
  ([x, y], [xi, yi]) => {
    // set xi = x[n-1] and yi = y[n-1]
    local.set(xi, loadLimb(x, n - 1));
    local.set(yi, loadLimb(y, n - 1));
    // return true if (xi > yi)
    i32.gt_s(xi, yi);
    if_(null, () => {
      i32.const(1);
      return_();
    });
    // return false if (xi != yi)
    i32.ne(xi, yi);
    if_(null, () => {
      i32.const(0);
      return_();
    });
    // TODO more logic before returning
    i32.const(0); // return 0 (false)
  }
);

Note that we can use x[n-1] $\ne$ y[n-1] instead of x[n-1] < y[n-1], because we already tested for x[n-1] > y[n-1] before.

If x[n-1] = y[n-1], neither of the if conditions is true, and we have to look at the second-highest limb to determine whether x > y.

At this point, it’s clear that we could copy the same code for all limbs, until the final case: if x[0] = y[0] as well, then x = y and we return false.

For performance reasons, repeating all the instructions for each limb is exactly what we want. It will almost certainly be more performant than coding up a runtime loop in Wasm using the loop instruction. However, we don’t need to repeat those instructions one by one – after all, we’re using TS to put together the Wasm code, so we can just use a TS for loop to repeat instructions:

const isGreater = func(
  { in: [i32, i32], locals: [i32, i32], out: [i32] },
  ([x, y], [xi, yi]) => {
    for (let i = n - 1; i >= 0; i--) {
      // set xi = x[i] and yi = y[i]
      local.set(xi, loadLimb(x, i));
      local.set(yi, loadLimb(y, i));
      // return true if (xi > yi)
      i32.gt_s(xi, yi);
      if_(null, () => {
        i32.const(1);
        return_();
      });
      // return false if (xi != yi)
      i32.ne(xi, yi);
      if_(null, () => {
        i32.const(0);
        return_();
      });
    }
    // fall-through case: return false if x = y
    i32.const(0);
  }
);
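When generating Wasm like this, a plain TS reference implementation is handy for testing the generated code against. The sketch below (isGreaterRef() is a hypothetical name) follows the same highest-limb-first logic on a Uint32Array of limbs, lowest limb first.

```typescript
// Reference for isGreater: compare from the highest limb down.
function isGreaterRef(x: Uint32Array, y: Uint32Array): boolean {
  for (let i = x.length - 1; i >= 0; i--) {
    if (x[i] > y[i]) return true; // the highest differing limb decides
    if (x[i] !== y[i]) return false; // here x[i] < y[i]
  }
  return false; // all limbs equal, so x = y
}

const a = new Uint32Array([0, 0, 0, 0, 0, 0, 0, 1, 5]);
const b = new Uint32Array([7, 7, 7, 7, 7, 7, 7, 7, 5]);
console.log(isGreaterRef(a, b)); // false: tie on the top limb, then 1 < 7
console.log(isGreaterRef(b, a)); // true
console.log(isGreaterRef(a, a)); // false
```

One subtlety: the Wasm version compares limbs with the signed i32.gt_s, which agrees with this unsigned comparison only while all limbs are below $2^{31}$, as with the 29-bit limbs used later in the post. For limbs that use all 32 bits, the unsigned Wasm instruction i32.gt_u would be the right choice.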

Something interesting is happening here: It feels like we’re coding in two languages at once. The runtime code is Wasm, but there’s also compile-time logic, like the for loop, or the factoring out of loadLimb() into a helper, which is written in TS.

Low-level programming is compile-time programming

I argue that if you look at any implementation of low-level, performance-critical stuff, the tension between runtime and compile-time programming becomes obvious. In performance-critical code, it wastes time if you just evaluate if statements as needed or load a statically known constant from memory. You need to hard-code stuff as much as you can. A statically known constant should be baked into your code, as should the resolution of branches that depend on statically known parameters. Loops often should be unrolled.

As a consequence, code at this level almost never reads nicely. You’ll see a proliferation of compiler annotations, macros, and inline assembly, mixed in with normal runtime code. Check out the finite field code in arkworks or barretenberg to see what I mean. In these examples, there is a proper runtime language on one hand (Rust and C++), and then there’s a “compile-time language” layered on top which consists of annotations and macros, and tends to be rather ugly.

The flip side is code-generation schemes, where the nice part is the compile-time language, which could be some scripting language, while the runtime part is just code snippets represented as strings, stitched together by the scripting language, untyped and unloved.
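For illustration, here is a minimal sketch of that string-based style (genIsLower() is a hypothetical name). TS unrolls the loop and bakes constant limbs and byte offsets into WAT text; the logic follows the isLower() helper shown later in this post, but as a standalone function that uses return instead of an inline block. Nothing checks the emitted code until a separate WAT assembler parses it.

```typescript
// Compile-time language: TS. Runtime language: WAT text in strings.
function genIsLower(y: Uint32Array): string {
  const lines: string[] = [
    "(func $isLower (param $x i32) (result i32) (local $xi i32)",
  ];
  for (let i = y.length - 1; i >= 0; i--) {
    lines.push(
      // byte offset 4 * i and the constant limb y[i] are baked into the text
      `  (local.set $xi (i32.load offset=${4 * i} (local.get $x)))`,
      `  (if (i32.lt_s (local.get $xi) (i32.const ${y[i]}))`,
      `    (then (return (i32.const 1))))`,
      `  (if (i32.ne (local.get $xi) (i32.const ${y[i]}))`,
      `    (then (return (i32.const 0))))`
    );
  }
  lines.push("  (i32.const 0))"); // fall-through: x = y, so not lower
  return lines.join("\n");
}

console.log(genIsLower(new Uint32Array([1, 2])));
```

A typo inside one of those template strings would only surface when the WAT is assembled, which is exactly the weakness described above.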

It’s rare to find an example which combines both layers of programming in a readable and idiomatic way. I think the wasmati approach comes pretty close to that ideal. The compile-time language is TS. The runtime language is the library that you write instructions in, and it’s only as good as that library’s design. In the case of wasmati, I tried hard to make TS types reflect the runtime types that Wasm has: i64, i32 and so on. The goal is that types in the compile-time language (TS) prevent you from writing invalid code in the runtime language (Wasm). This is sometimes possible, but not always, because sadly you can’t encode the validity of operations on an implicit stack in types in a way that preserves a simple imperative API. Where types can’t help, good error messages, reported as close as possible to the offending code, have to step in.

There are, of course, trade-offs. Eschewing a compiler means that you’re taking on a lot of complexity that the compiler would’ve taken care of, like memory management.

How fast is Wasm compared to JS?

I still have to prove that we gain anything by moving to Wasm. JS has native bigints – are we really going to improve over them?

To find out, we’ll implement finite field addition, and compare the performance against JS bigints. Here’s the baseline that we want to beat:

// add two numbers modulo p
function addBigint(x: bigint, y: bigint, p: bigint) {
  let z = x + y;
  return z < p ? z : z - p;
}

We’re adding two numbers, which are assumed to be non-negative and smaller than the modulus $p$. If the result is still less than $p$, it’s a valid field element and we return it. Otherwise, we know that the result must be less than $2p$, so we subtract $p$ once to bring it back within the range $0,…,p-1$.

Now the same logic in Wasm. I’ll just dump the full implementation here and tell you what it does at a high level:

import { i32, func, local, Local, Input, if_, return_, block, br } from "wasmati";

// define constants
// the modulus
let p = 0x40000000000000000000000000000000224698fc094cf91b992d30ed00000001n;
let w = 29; // limb size
let wordMax = (1 << w) - 1;
let n = 9; // number of limbs, so that 2^(n*w) > p
let P: Uint32Array = bigintToLimbs(p); // p as array of n limbs

/**
 * add two field elements stored in memory
 */
const add = func(
  {
    in: [i32, i32, i32],
    locals: [i32],
    out: [],
  },
  ([z, x, y], [zi]) => {
    // z = x + y
    for (let i = 0; i < n; i++) {
      // zi = x[i] + y[i] + carry
      i32.add(loadLimb(x, i), loadLimb(y, i));
      if (i > 0) i32.add(); // add carry
      local.set(zi);
      // perform carry on zi and store in z[i];
      // carry bit is left on the stack for next i
      if (i < n - 1) i32.shr_s(zi, w); // carry bit
      storeLimb(z, i, i32.and(zi, wordMax));
    }
    // if (z < p) return;
    isLower(z, zi, P);
    if_(null, () => return_());
    // z -= p
    for (let i = 0; i < n; i++) {
      // zi = z[i] - p[i] + carry
      i32.sub(loadLimb(z, i), P[i]);
      if (i > 0) i32.add(); // add carry
      local.set(zi);
      // perform carry on zi and store in z[i];
      // carry "bit" (0 or -1) is left on the stack for next i
      if (i < n - 1) i32.shr_s(zi, w); // carry "bit"
      storeLimb(z, i, i32.and(zi, wordMax));
    }
  }
);

/**
 * helper for checking that x < y, where x is stored in memory and y is
 * a constant stored as an array of n limbs
 */
function isLower(x: Local<i32>, xi: Local<i32>, y: Uint32Array) {
  block({ out: [i32] }, (block) => {
    for (let i = n - 1; i >= 0; i--) {
      // set xi = x[i]
      local.set(xi, loadLimb(x, i));
      // return true if (xi < yi)
      i32.lt_s(xi, y[i]);
      if_(null, () => {
        i32.const(1);
        br(block);
      });
      // return false if (xi != yi)
      i32.ne(xi, y[i]);
      if_(null, () => {
        i32.const(0);
        br(block);
      });
    }
    // fall-through case: return false if x = y
    i32.const(0);
  });
}

function loadLimb(x: Local<i32>, i: number) {
  return i32.load({ offset: 4 * i }, x);
}

function storeLimb(x: Local<i32>, i: number, xi: Input<i32>) {
  return i32.store({ offset: 4 * i }, x, xi);
}

function bigintToLimbs(x: bigint, limbs = new Uint32Array(n)) {
  for (let i = 0; i < n; i++) {
    limbs[i] = Number(x & BigInt(wordMax));
    x >>= BigInt(w);
  }
  return limbs;
}

At the beginning, we declare our prime number $p$ (as a bigint), and parameters w and n which determine the limb representation. We also call a helper function bigintToLimbs(p) to store $p$ as a Uint32Array of length $n$. Its limbs are used as hard-coded constants in our code.

The functions at the bottom are helpers. storeLimb() is similar to loadLimb() to store a field element limb back to memory.

isLower() uses almost the same logic as isGreater() before, except that one of its arguments is now a constant field element stored as a limb array, like $p$. Also, isLower() is not its own func: it’s a plain TS helper that, when called, adds its instructions inline to whichever function is being built. Since return_() would exit that enclosing function, isLower() wraps its early-exit logic in a block instruction with a result type of [i32], and uses br(block) to jump out of the block with 0 or 1 on the stack.

The add() function works roughly like this:

1. There’s an initial loop that computes z = x + y.

For each limb, we do z[i] = x[i] + y[i].

At each step, we perform a carry on the result to bring it back to a size of w bits. The carry bit is found by a right shift z[i] >> w, and is simply left on the stack and added in the next loop iteration.

The remainder (z[i] without the carry bit) is found by a bitwise AND with the bit mask $2^w - 1$, and stored in memory.

2. We test if z < p, and if so, return.

3. If z >= p, we subtract p to get z - p.

This loop is similar to the loop for addition, except that this time, one of the inputs is a constant: i32.sub(loadLimb(z, i), P[i]). Here, P[i] is a number. When passing plain numbers to instructions, wasmati fills in an i32.const(number) instruction.

Carrying works the same for negative values, as long as we use a signed right shift i32.shr_s, and the limb size $w < 32$. The carry “bit” in this case is either 0 or -1. (The addition loop also uses i32.shr_s, so there x[i] + y[i] + carry must stay below $2^{31}$, which holds for $w \le 30$; our $w = 29$ satisfies both conditions.)
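To check the carry logic independently of Wasm, here is a plain TS mirror of add(), compared against addBigint(). It is a sketch under our own assumptions: the helper names are hypothetical, JS's `>>` stands in for i32.shr_s (both are arithmetic shifts on 32-bit integers), and `| 0` truncates results to int32 the way i32 arithmetic wraps (which never actually triggers with 29-bit limbs).

```typescript
const p = 0x40000000000000000000000000000000224698fc094cf91b992d30ed00000001n;
const w = 29; // limb size
const n = 9; // number of limbs
const wordMax = (1 << w) - 1; // 2^w - 1 as a regular number

function toLimbs(x: bigint): Uint32Array {
  const limbs = new Uint32Array(n);
  for (let i = 0; i < n; i++) {
    limbs[i] = Number(x & BigInt(wordMax));
    x >>= BigInt(w);
  }
  return limbs;
}

function fromLimbs(limbs: Uint32Array): bigint {
  let x = 0n;
  for (let i = n - 1; i >= 0; i--) x = (x << BigInt(w)) | BigInt(limbs[i]);
  return x;
}

const P = toLimbs(p);

// x < y, comparing from the highest limb down
function isLower(x: Uint32Array, y: Uint32Array): boolean {
  for (let i = n - 1; i >= 0; i--) {
    if (x[i] < y[i]) return true;
    if (x[i] !== y[i]) return false;
  }
  return false; // x = y
}

// z = x + y mod p, with the same two carry loops as the Wasm add()
function addLimbs(z: Uint32Array, x: Uint32Array, y: Uint32Array): void {
  let carry = 0;
  for (let i = 0; i < n; i++) {
    const zi = (x[i] + y[i] + carry) | 0;
    carry = zi >> w; // carry bit: 0 or 1
    z[i] = zi & wordMax; // keep the low w bits
  }
  if (isLower(z, P)) return;
  carry = 0;
  for (let i = 0; i < n; i++) {
    const zi = (z[i] - P[i] + carry) | 0;
    carry = zi >> w; // carry "bit": 0 or -1
    z[i] = zi & wordMax; // low w bits of the two's complement value
  }
}

function addBigint(x: bigint, y: bigint, p: bigint) {
  const z = x + y;
  return z < p ? z : z - p;
}

const z = new Uint32Array(n);
addLimbs(z, toLimbs(p - 1n), toLimbs(5n));
console.log(fromLimbs(z) === addBigint(p - 1n, 5n, p)); // true
```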

Using and running a Wasm function

The only thing left is to encode our function as a Wasm module, and run it. This is another step where using wasmati is much more convenient compared to writing raw WAT:

However, to run our add() function on bigints, we first have to write those bigints into Wasm memory to get a pointer (the number that our function expects). At the end, we want to read the result back out into a bigint. We have helpers fromBigint() and toBigint() to perform that IO; you can see them in the full code example.
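Since the post's fromBigint() and toBigint() aren't shown here, the following are hypothetical stand-ins with our own names, writeField() and readField(). They assume the module's memory is available as a WebAssembly.Memory, that each field element occupies $4n = 36$ bytes at a caller-chosen address that is a multiple of 4, and a little-endian host (true of practically all current hardware; Wasm memory itself is always little-endian). Limb i lands at byte ptr + 4i, matching the offsets used by loadLimb().

```typescript
const w = 29; // limb size
const n = 9; // limbs per field element
const mask = (1n << BigInt(w)) - 1n;

function writeField(memory: WebAssembly.Memory, ptr: number, x: bigint): void {
  // fresh view on every call: memory.grow() detaches the old buffer
  const limbs = new Uint32Array(memory.buffer, ptr, n);
  for (let i = 0; i < n; i++) {
    limbs[i] = Number(x & mask); // limb i ends up at byte ptr + 4 * i
    x >>= BigInt(w);
  }
}

function readField(memory: WebAssembly.Memory, ptr: number): bigint {
  const limbs = new Uint32Array(memory.buffer, ptr, n);
  let x = 0n;
  for (let i = n - 1; i >= 0; i--) x = (x << BigInt(w)) | BigInt(limbs[i]);
  return x;
}

const memory = new WebAssembly.Memory({ initial: 1 }); // one 64 KiB page
writeField(memory, 0, 123456789012345678901234567890n);
console.log(readField(memory, 0)); // 123456789012345678901234567890n
```

Creating the Uint32Array view inside each call, rather than caching it, keeps the helpers correct even if the memory grows between calls.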

Which one is faster?

Here are the two implementations that are benchmarked. We just add a number to itself in a loop:

let x = randomField();
let N = 1e7;

// warm up
for (let i = 0; i < 5; i++) {
  benchWasm(x, N);
  benchBigint(x, N);
}

// benchmark
console.time("wasm");
let z0 = benchWasm(x, N);
console.timeEnd("wasm");

console.time("bigint");
let z1 = benchBigint(x, N);
console.timeEnd("bigint");

// make sure that results are the same
if (z0 !== z1) throw Error("wrong result");
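The post doesn't show randomField() or the benchmark bodies. As one possible reading of the bigint side, here is a hypothetical sketch in which benchBigint() doubles the running value modulo p N times, and randomField() draws a uniform element of $[0, p)$ by rejection sampling with Node's crypto.randomBytes().

```typescript
import { randomBytes } from "node:crypto";

const p = 0x40000000000000000000000000000000224698fc094cf91b992d30ed00000001n;

function addBigint(x: bigint, y: bigint, p: bigint) {
  const z = x + y;
  return z < p ? z : z - p;
}

// uniform element of [0, p) by rejection sampling
function randomField(): bigint {
  while (true) {
    const bytes = randomBytes(32);
    bytes[0] &= 0x7f; // clear the top bit: the candidate is below 2^255
    const x = BigInt("0x" + bytes.toString("hex"));
    if (x < p) return x; // accepted about half the time, since p is just above 2^254
  }
}

// one possible benchmark body: add the running value to itself N times mod p
function benchBigint(x: bigint, N: number): bigint {
  let z = x;
  for (let i = 0; i < N; i++) z = addBigint(z, z, p);
  return z;
}

console.log(benchBigint(1n, 10)); // 1024n
```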

The output pretty much settles the Wasm vs bigint question:

wasm: 103.538ms
bigint: 459.061ms

As you can see, the bigint version is more than 4x slower. Putting in the extra work to write optimized Wasm is definitely worth it if performance matters. For field multiplication, which is much more performance-critical than addition, I found the difference to be even larger because there are more optimization vectors.

That’s it!

I hope that you enjoyed this post and that it gave you some appetite for adding low-level Wasm to a TS code base. At zkSecurity, we will continue to improve wasmati and build on it to provide fast and correct Wasm implementations of cryptographic algorithms, as a public good for the zk community.

Contact us if you want to leverage our expertise for an audit of your crypto stack!