Let $H^r(\Omega) = W^{r, 2}(\Omega)$ be the Sobolev space of order $r \in \mathbb{N}_0$ (a non-negative integer) for functions defined on a domain $\Omega$. For a function $f: \Omega \rightarrow \mathbb{R}$ in $H^r(\Omega)$, the Sobolev norm is $$ \|f\|_{H^r (\Omega)} := \left( \sum _{i = 0} ^r \|f^{(i)}\|^2_{L^2(\Omega)} \right) ^{1/2}, $$ where $f^{(i)}$ denotes the $i$-th (weak) derivative of $f$, with $f^{(0)} = f$. Implementing this requires composing up to $r$ derivatives and computing the $L^2$ norm of each. To accomplish this numerically, the following code evaluates the Sobolev norm over $\Omega = [0, 1]$ using Gauss-Legendre quadrature in JAX.

import jax 
import jax.numpy as jnp 
from typing import Callable
from scipy.special import roots_jacobi
import matplotlib.pyplot as plt 
plt.style.use('dark_background')


def gaussLegendre(a : float, b : float, ng : int):
    """Return an ``ng``-point Gauss-Legendre rule mapped onto ``[a, b]``.

    The reference nodes and weights on ``[-1, 1]`` come from
    ``scipy.special.roots_jacobi`` with ``alpha = beta = 0``. For those
    parameters the Jacobi weight is identically 1, i.e. the Legendre case.
    They are then mapped affinely onto ``[a, b]``.

    Args:
        a: Left endpoint of the integration interval.
        b: Right endpoint of the integration interval.
        ng: Number of quadrature points.

    Returns:
        Tuple ``(x, w)`` of JAX arrays, each reshaped to ``(ng, 1)``: the
        mapped nodes and their weights.
    """
    ref_nodes, ref_weights = roots_jacobi(ng, 0, 0)
    column_shape = (ng, 1)

    # Affine map t -> half_width * t + midpoint; the weights pick up the
    # Jacobian factor half_width.
    half_width = 0.5 * (b - a)
    midpoint = 0.5 * (b + a)

    nodes = half_width * jnp.asarray(ref_nodes).reshape(column_shape) + midpoint
    weights = half_width * jnp.asarray(ref_weights).reshape(column_shape)
    return nodes, weights

def composegrad(n : int):
    """Return a transformer that differentiates a scalar function ``n`` times.

    ``composegrad(n)(f)`` applies ``jax.grad`` to ``f`` ``n`` times, giving
    the ``n``-th derivative of ``f``. ``composegrad(0)(f)`` returns ``f``
    itself.

    Args:
        n: Derivative order; must be non-negative.

    Returns:
        A function that maps a callable ``f`` to its ``n``-th derivative.
        Because ``jax.grad`` is used, ``f`` and every intermediate derivative
        must map a scalar input to a scalar output.

    Raises:
        ValueError: If ``n`` is negative. Previously a negative order
            silently returned ``f`` unchanged.
    """
    if n < 0:
        raise ValueError(f"derivative order must be non-negative, got {n}")

    def fn(func):
        # The argument is a *function*, not an evaluation point: each pass
        # wraps the previous derivative in one more jax.grad.
        for _ in range(n):
            func = jax.grad(func)
        return func

    return fn


def Hr_norm(f : Callable, x : jnp.array, w: jnp.array, r : int):
    """Approximate the Sobolev H^r norm of ``f`` with a quadrature rule.

    Computes ``sqrt(sum_{i=0}^{r} sum_k w_k * f^{(i)}(x_k)**2)``. This is the
    H^r norm with each squared L^2 norm of the i-th derivative replaced by
    the quadrature sum over nodes ``x`` and weights ``w``.

    Args:
        f: Function of one real variable written with JAX-traceable
            operations. It must map a scalar to a scalar, because its
            derivatives are taken with ``jax.grad``. Every order, including
            order 0, is evaluated pointwise with ``jax.vmap``.
        x: Quadrature nodes. Any shape is accepted, e.g. the ``(ng, 1)``
            columns returned by ``gaussLegendre``; it is flattened to 1-D.
        w: Quadrature weights, flattened the same way. Must have as many
            entries as ``x``.
        r: Highest derivative order included (inclusive); must be >= 0.

    Returns:
        The approximate H^r norm as a JAX scalar.

    Raises:
        ValueError: If ``r`` is negative, or ``x`` and ``w`` have a different
            number of entries.
    """
    if r < 0:
        raise ValueError(f"Sobolev order r must be non-negative, got {r}")

    # jax.grad only differentiates scalar-in/scalar-out functions, so vmap
    # over a flat node vector. Flattening also accepts (ng, 1) columns.
    x = jnp.ravel(x)
    w = jnp.ravel(w)
    if x.shape != w.shape:
        raise ValueError(
            f"x and w must have the same number of entries, got {x.size} and {w.size}"
        )

    # Build the derivative tower incrementally (one extra jax.grad per order)
    # instead of re-composing all n grads from scratch for every order.
    deriv = f
    norm_sq = jnp.sum(w * jax.vmap(deriv)(x) ** 2)
    for _ in range(r):
        deriv = jax.grad(deriv)
        norm_sq = norm_sq + jnp.sum(w * jax.vmap(deriv)(x) ** 2)

    return jnp.sqrt(norm_sq)

# 128-point Gauss-Legendre rule on [0, 1]; squeeze the (ng, 1) columns into
# flat vectors of nodes and weights.
xgl, w = (arr.squeeze() for arr in gaussLegendre(0, 1, 128))

# Uniform grid used only for plotting.
x = jnp.linspace(0, 1, 100)


def f(x):
    """Test function f(x) = x^3."""
    return x ** 3


# Squared H^r norms for r = 0, 1, 2, 3.
for order in range(4):
    print(Hr_norm(f, xgl, w, order) ** 2)

# Pointwise first, second and third derivatives of f.
dfdx, d2fdx, d3fdx = (jax.vmap(composegrad(k)(f)) for k in (1, 2, 3))

curves = (
    (f(x), r"$f(x) = x^3$"),
    (dfdx(x), r"$f'(x) = 3x^2$"),
    (d2fdx(x), r"$f''(x) = 6x$"),
    (d3fdx(x), r"$f'''(x) = 6$"),
)
for y, label in curves:
    plt.plot(x, y, label=label)

plt.legend()

plt.show()

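The above is a simple example with $f(x) = x^3$ and $r \in \{0, 1, 2, 3\}$. Since $f' = 3x^2$, $f'' = 6x$ and $f''' = 6$, the exact squared norms are $\|f\|^2_{H^r([0,1])} = 1/7,\ 68/35,\ 488/35,\ 1748/35$ (about $0.1429, 1.9429, 13.9429, 49.9429$) for $r = 0, 1, 2, 3$. The printed values should match these up to floating-point precision.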

Running the code produces the following plot, which confirms that the derivative composition works correctly.