Let $H^r(\Omega) = W^{r, 2}(\Omega)$ be the Sobolev space of order $r \in \mathbb{Z}_{\geq 0}$ for functions defined on a domain $\Omega \subset \mathbb{R}$. Then for a function $f: \Omega \rightarrow \mathbb{R}$ with $f \in H^r(\Omega)$, the Sobolev norm is $$ \|f\|_{H^r (\Omega)} := \left( \sum _{i = 0} ^r \|f^{(i)}\|^2_{L^2(\Omega)} \right) ^{1/2}, $$ where $f^{(i)}$ denotes the $i$-th (weak) derivative of $f$ and $f^{(0)} = f$. Implementing this requires composing up to $r$ derivatives and computing the $L^2$ norm of each. To accomplish this numerically, the following code evaluates the Sobolev norm over $\Omega = [0, 1]$ using Gauss-Legendre quadrature, with the derivatives obtained by automatic differentiation in JAX.
import jax
import jax.numpy as jnp
from typing import Callable
from scipy.special import roots_jacobi
import matplotlib.pyplot as plt
# Apply matplotlib's 'dark_background' style to the figures drawn below.
plt.style.use('dark_background')
def gaussLegendre(a: float, b: float, ng: int):
    """Return ``ng``-point Gauss-Legendre nodes and weights mapped to [a, b].

    The reference rule on [-1, 1] is obtained from ``roots_jacobi`` with
    alpha = beta = 0, for which the Jacobi weight is identically 1 (the
    Legendre case). Nodes and weights are returned as ``(ng, 1)`` arrays.
    """
    nodes, weights = roots_jacobi(ng, 0, 0)
    half = 0.5 * (b - a)
    mid = 0.5 * (b + a)
    # Affine change of variables t -> mid + half * t sends [-1, 1] onto
    # [a, b]; the Jacobian ``half`` rescales the weights accordingly.
    nodes = half * jnp.reshape(nodes, [ng, 1]) + mid
    weights = half * jnp.reshape(weights, [ng, 1])
    return nodes, weights
def composegrad(n: int):
    """Return an operator that differentiates a function ``n`` times.

    ``composegrad(n)(g)`` is ``jax.grad`` applied ``n`` times to ``g``
    (``n == 0`` gives ``g`` back unchanged). Because every step uses
    ``jax.grad``, each intermediate function must return a scalar.
    """
    def fn(func):
        derivative = func
        for _ in range(n):
            derivative = jax.grad(derivative)
        return derivative
    return fn
def Hr_norm(f: Callable, x: jnp.ndarray, w: jnp.ndarray, r: int):
    """Approximate the Sobolev H^r norm of ``f`` with a quadrature rule.

    Computes ``sqrt(sum_{i=0}^{r} sum_k w_k * f^{(i)}(x_k)**2)``, the
    discrete analogue of ``(sum_i ||f^{(i)}||_{L^2}^2)^{1/2}``.

    Args:
        f: scalar-to-scalar function of one variable, differentiable by JAX
            up to order ``r``. It is evaluated pointwise through ``jax.vmap``
            for every order (including order 0), so it need not be vectorized.
        x: quadrature nodes. Flattened internally, so both ``(ng,)`` and
            ``(ng, 1)`` layouts are accepted.
        w: quadrature weights with the same number of entries as ``x``.
        r: highest derivative order included (inclusive); must be >= 0.

    Returns:
        The approximated norm (not squared) as a 0-d JAX array.

    Raises:
        ValueError: if ``r`` is negative or ``x`` and ``w`` differ in size.
    """
    if r < 0:
        raise ValueError(f"r must be a non-negative integer, got {r}")
    # Flatten so vmap maps over scalar nodes (jax.grad needs scalar outputs),
    # and so ``w * values`` cannot silently broadcast (ng,) against (ng, 1)
    # into an (ng, ng) matrix.
    x = jnp.ravel(x)
    w = jnp.ravel(w)
    if x.shape != w.shape:
        raise ValueError(
            f"x and w must have the same number of entries, got {x.size} and {w.size}"
        )
    deriv = f
    total = jnp.sum(w * jax.vmap(deriv)(x) ** 2)  # order 0: ||f||_2^2
    for _ in range(r):  # orders 1..r inclusive
        # Build the next derivative from the previous one instead of
        # recomposing jax.grad from scratch for every order.
        deriv = jax.grad(deriv)
        total = total + jnp.sum(w * jax.vmap(deriv)(x) ** 2)
    return jnp.sqrt(total)
# Quadrature rule on [0, 1]; squeeze the (128, 1) columns to 1-D so that
# jax.vmap inside Hr_norm maps over scalar nodes.
xgl, w = map(jnp.squeeze, gaussLegendre(0, 1, 128))
# Uniform grid used only for plotting.
x = jnp.linspace(0, 1, 100)


def f(x):
    """Test function f(x) = x**3."""
    return x ** 3


# Squared H^r([0, 1]) norms of f for r = 0, 1, 2, 3.
for order in range(4):
    print(Hr_norm(f, xgl, w, order)**2)

# Vectorized first, second and third derivatives of f for plotting.
dfdx, d2fdx, d3fdx = (jax.vmap(composegrad(n)(f)) for n in (1, 2, 3))

for curve, label in (
    (f, r"$f(x) = x^3$"),
    (dfdx, r"$f'(x) = 3x^2$"),
    (d2fdx, r"$f''(x) = 6x$"),
    (d3fdx, r"$f'''(x) = 6$"),
):
    plt.plot(x, curve(x), label=label)
plt.legend()
plt.show()
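The above is a simple example with $f(x) = x^3$ and $r \in \{0, 1, 2, 3\}$. The printed values are the squared norms $\|f\|^2_{H^r([0,1])}$, which should match the exact values $1/7$, $68/35$, $488/35$, and $1748/35$ (approximately $0.142857$, $1.942857$, $13.942857$, and $49.942857$).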
Running the code produces the following plot, in which each curve matches its analytic derivative, confirming that the derivative composition works correctly.