• Docs >
  • Extending torch.func with autograd.Function
Shortcuts

Extending torch.func with autograd.Function

So you’d like to use torch.autograd.Function with the torch.func transforms like torch.vmap(), torch.func.grad(), etc.

There are two main use cases:

  • you wish to call code that does not contain PyTorch operations and have it work with function transforms. That is, the torch.autograd.Function’s forward/backward/etc calls into functions from other systems like C++, CUDA, numpy.

  • you wish to specify custom gradient rules, like JAX’s custom_vjp/custom_jvp

PyTorch combines both of these concepts into torch.autograd.Function.

Basic Usage

This guide assumes you are familiar with Extending torch.autograd, which explains how to use torch.autograd.Function.

torch.autograd.Function can either have a forward() that accepts a ctx object, or it can have separate forward() (that does not accept ctx) and a setup_context() staticmethod that modifies the ctx object.

Only the latter is supported with function transforms:

  • forward() is the code that performs the operation and it should not accept a ctx object.

  • setup_context(ctx, inputs, output) is the code where you can call methods on ctx. Here is where you should save Tensors for backward (by calling ctx.save_for_backward(*tensors)), or save non-Tensors (by assigning them to the ctx object).

Any intermediates that need to be saved must be returned as an output from forward().

Depending on the transform,

In order for the torch.autograd.Function to be arbitrarily composable with function transforms, we recommend that all other staticmethods other than forward() and setup_context() must be transformable: that is, they must consist of only PyTorch operators or call other torch.autograd.Function (that may call into C++/CUDA/etc).

Let’s go over some examples of common use cases.

Example 1: autograd.Function calls into another system

A common case is a torch.autograd.Function with both forward() and backward() calling into another system (like C++, CUDA, numpy, triton).

import torch
import numpy as np

def to_numpy(tensor):
    return tensor.cpu().numpy()

class NumpySort(torch.autograd.Function):
    # Note that forward does not take ctx
    @staticmethod
    def forward(x, dim):
        device = x.device
        x = to_numpy(x)
        ind = np.argsort(x, axis=dim)
        ind_inv = np.argsort(ind, axis=dim)
        result = np.take_along_axis(x, ind, axis=dim)
        # Any intermediates to be saved in backward must be returned as
        # outputs.
        return (
            # The desired output
            torch.tensor(result, device=device),
            # intermediate to save for backward
            torch.tensor(ind, device=device),
            # intermediate to save for backward
            torch.tensor(ind_inv, device=device),
        )

    # setup_context is responsible for calling methods and/or assigning to
    # the ctx object. Please do not do additional compute (e.g. add
    # Tensors together) in setup_context.
    @staticmethod
    def setup_context(ctx, inputs, output):
        x, dim = inputs
        # Note that output is whatever you returned from forward.
        # If you returned multiple values, then output is a Tuple of multiple values.
        # If you returned a single Tensor, then output is a Tensor.
        # If you returned a Tuple with a single Tensor, then output is a
        # Tuple with a single Tensor.
        _, ind, ind_inv = output
        ctx.mark_non_differentiable(ind, ind_inv)
        # Tensors must be saved via ctx.save_for_backward. Please do not
        # assign them directly onto the ctx object.
        ctx.save_for_backward(ind, ind_inv)
        # Non-tensors may be saved by assigning them as attributes on the ctx object.
        ctx.dim = dim

    @staticmethod
    def backward(ctx, grad_output, _0, _1):
        # For the autograd.Function to be arbitrarily composable with function
        # transforms, all staticmethod other than forward and setup_context
        # must be implemented in a "transformable" way; that is, they must
        # only consist of PyTorch operations or autograd.Function.
        #
        # For example, this allows us to do double backwards and/or compute
        # second order gradients.
        #
        # We've written the backward pass of NumpySort in terms of another
        # autograd.Function, NumpyTake.
        ind, ind_inv = ctx.saved_tensors
        return NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim), None

class NumpyTake(torch.autograd.Function):
    @staticmethod
    def forward(x, ind, ind_inv, dim):
        device = x.device
        x = to_numpy(x)
        ind = to_numpy(ind)
        return torch.tensor(np.take_along_axis(x, ind, dim), device=device)

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, ind, ind_inv, dim = inputs
        ctx.save_for_backward(ind, ind_inv)
        ctx.dim = dim

    @staticmethod
    def backward(ctx, grad_output):
        ind, ind_inv = ctx.saved_tensors
        result = NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim)
        return result, None, None, None

Now, to make it easier to use NumpySort (to hide away the intermediates we returned as outputs, as well as allow default args and kwargs), we create a new function that invokes it:

def numpy_sort(x, dim=-1):
    result, _, _ = NumpySort.apply(x, dim)
    return result

And here’s a sanity check:

x = torch.randn(2, 3)
grad_x = torch.func.grad(lambda x: numpy_sort(x).sum())(x)
assert torch.allclose(grad_x, torch.ones_like(x))

Example 2: autograd.Function specifies custom gradient rules

Another common case is an torch.autograd.Function that is implemented with PyTorch operations. PyTorch is able to compute gradients for PyTorch operations automatically, but perhaps we wish to customize how the gradients are computed. Some reasons why we may want a custom backward different from the one PyTorch gives us are:

  • improving numeric stability

  • changing the performance characteristics of the backward

  • changing how edge cases are handled (e.g. nans, inf)

  • modifying the gradient (e.g. gradient clipping)

Here’s an example of an torch.autograd.Function for the function y = x ** 3 where we change the performance characteristics (some computation that would normally happen during the backward pass, computing dx, happens in the forward pass).

class MyCube(torch.autograd.Function):
    @staticmethod
    def forward(x):
        result = x ** 3
        # In regular PyTorch, if we had just run y = x ** 3, then the backward
        # pass computes dx = 3 * x ** 2. In this autograd.Function, we've done
        # that computation here in the forward pass instead.
        dx = 3 * x ** 2
        return result, dx

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, = inputs
        result, dx = output
        ctx.save_for_backward(x, dx)

    @staticmethod
    def backward(ctx, grad_output, grad_dx):
        x, dx = ctx.saved_tensors
        # In order for the autograd.Function to work with higher-order
        # gradients, we must add the gradient contribution of `dx`.
        result = grad_output * dx + grad_dx * 6 * x
        return result

Now, to make it easier to use NumpySort (and hide away the intermediates we returned as outputs) we create a new function that invokes it:

def my_cube(x):
    result, _ = MyCube.apply(x)
    return result

Here’s a sanity check computing the second-order gradients:

x = torch.randn([])
ggx = torch.func.grad(torch.func.grad(my_cube))(x)
assert torch.allclose(ggx, 6 * x)

Limitations and gotchas

Warning

Please read these limitations of torch.autograd.Function with torch.func transforms carefully. We are not able to catch many of these situations and error out gracefully so they will lead to undefined behavior.

Please do not capture Tensors that are being transformed over, have requires_grad=True, or are dual tensors, into the methods of the torch.autograd.Function. The way to be completely safe is to ensure that the only Tensors being used inside any method of the torch.autograd.Function must be directly passed as inputs (or via the ctx object) rather than come from outside the torch.autograd.Function.

torch.autograd.Function does not handle Tensors in pytrees (arbitrary nested Python data structures that may or may not contain Tensors). For those Tensors to be tracked by autograd, they must be passed directly as an argument to torch.autograd.Function. This is in contrast to jax.{custom_vjp, custom_jvp}, which do accept pytrees.

Please only use save_for_backward() or save_for_forward() to save Tensors. Please do not assign Tensors or collections of Tensors directly onto the ctx object - these Tensors will not get tracked

torch.vmap() Support

To use an torch.autograd.Function with torch.vmap(), you must either:

Automatically generate a vmap rule

If your torch.autograd.Function fulfills the following additional constraints, then we are able to generate a vmap rule for it. If it doesn’t fulfill the constraints or if you want custom behavior under vmap, please manually define a vmap staticmethod (see next section).

Warning

We are not easily able to check for the following constraints and error out gracefully. Violation of the constraints may lead to undefined behavior.

Example:

class MyCube(torch.autograd.Function):
    # Set generate_vmap_rule to True to ask PyTorch to automatically generate
    # a vmap rule.
    generate_vmap_rule = True

    @staticmethod
    def forward(x):
        result = x ** 3
        dx = 3 * x ** 2
        return result, dx

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, = inputs
        result, dx = output
        ctx.save_for_backward(x, dx)

    @staticmethod
    def backward(ctx, grad_output, grad_dx):
        x, dx = ctx.saved_tensors
        result = grad_output * dx + grad_dx * 6 * x
        return result

def my_cube(x):
    result, dx = MyCube.apply(x)
    return result

x = torch.randn(3)
result = torch.vmap(my_cube)(x)
assert torch.allclose(result, x ** 3)

Defining the vmap staticmethod

If your torch.autograd.Function calls into another system (like NumPy, C++, CUDA, triton), then to get it to work with torch.vmap() or transforms that use it, you’ll need to manually define a vmap() staticmethod.

Depending on what transforms you want to use and your use case, you may not need to add a vmap() staticmethod to all of your torch.autograd.Function:

We do recommend ensuring all of your torch.autograd.Function have support for torch.vmap() though, especially if you are writing a third-party library and you want your torch.autograd.Function to work with all combinations of torch.func() transforms.

Conceptually, the vmap staticmethod is responsible for defining how the forward() should behave under torch.vmap(). That is, it defines how to transform the forward() to run over inputs with an additional dimension (the dimension being vmapped over). This is similar to how torch.vmap() is implemented over PyTorch operations: for each operation, we define a vmap rule (sometimes also referred to as a “batching rule”).

Here’s how to define the vmap() staticmethod:

  • the signature is vmap(info, in_dims: Tuple[Optional[int]], *args), where *args is the same as the args to forward().

  • The vmap staticmethod is responsible for defining how the forward() should behave under torch.vmap(). That is, given inputs with an additional dimension (specified by in_dims), how do we compute the batched version of forward()?

  • For each arg in args, in_dims has a corresponding Optional[int]. It is None if the arg is not a Tensor or if the arg is not being vmapped over, otherwise, it is an integer specifying what dimension of the Tensor is being vmapped over.

  • info is a collection of additional metadata that may be helpful: info.batch_size specifies the size of the dimension being vmapped over, while info.randomness is the randomness option that was passed to torch.vmap().

  • The return of the vmap staticmethod is a tuple of (output, out_dims). Similar to in_dims, out_dims should be of the same structure as output and contain one out_dim per output that specifies if the output has the vmapped dimension and what index it is in.

Example:

def to_numpy(tensor):
    return tensor.cpu().numpy()

class NumpySort(torch.autograd.Function):
    @staticmethod
    def forward(x, dim):
        device = x.device
        x = to_numpy(x)
        ind = np.argsort(x, axis=dim)
        ind_inv = np.argsort(ind, axis=dim)
        result = np.take_along_axis(x, ind, axis=dim)
        return (
            torch.tensor(result, device=device),
            torch.tensor(ind, device=device),
            torch.tensor(ind_inv, device=device),
        )

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, dim = inputs
        _, ind, ind_inv = output
        ctx.mark_non_differentiable(ind, ind_inv)
        ctx.save_for_backward(ind, ind_inv)
        ctx.dim = dim

    @staticmethod
    def backward(ctx, grad_output, _0, _1):
        return NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim), None

    # The signature of the vmap staticmethod is:
    # vmap(info, in_dims: Tuple[Optional[int]], *args)
    # where *args is the same as the arguments to `forward`.
    @staticmethod
    def vmap(info, in_dims, x, dim):
        # For every input (x and dim), in_dims stores an Optional[int]
        # that is:
        # - None if the input is not being vmapped over or if the input
        #   is not a Tensor
        # - an integer if the input is being vmapped over that represents
        #   the index of the dimension being vmapped over.
        x_bdim, _ = in_dims

        # A "vmap rule" is the logic of how to perform the operation given
        # inputs with one additional dimension. In NumpySort, x has an
        # additional dimension (x_bdim). The vmap rule is simply
        # to call NumpySort again but pass it a different `dim`.
        x = x.movedim(x_bdim, 0)
        # Handle negative dims correctly
        dim = dim if dim >= 0 else dim + x.dim() - 1
        result = NumpySort.apply(x, dim + 1)

        # The vmap rule must return a tuple of two things
        # 1. the output. Should be the same amount of things
        #    as returned by the forward().
        # 2. one Optional[int] for each output specifying if each output
        # is being vmapped over, and if so, the index of the
        # dimension being vmapped over.
        #
        # NumpySort.forward returns a Tuple of 3 Tensors. Since we moved the
        # dimension being vmapped over to the front of `x`, that appears at
        # dimension 0 of all outputs.
        # The return is (output, out_dims) -- output is a tuple of 3 Tensors
        # and out_dims is a Tuple of 3 Optional[int]
        return NumpySort.apply(x, dim + 1), (0, 0, 0)

class NumpyTake(torch.autograd.Function):
    @staticmethod
    def forward(x, ind, ind_inv, dim):
        device = x.device
        x = to_numpy(x)
        ind = to_numpy(ind)
        return torch.tensor(np.take_along_axis(x, ind, dim), device=device)

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, ind, ind_inv, dim = inputs
        ctx.save_for_backward(ind, ind_inv)
        ctx.dim = dim

    @staticmethod
    def backward(ctx, grad_output):
        ind, ind_inv = ctx.saved_tensors
        result = NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim)
        return result, None, None, None

    @staticmethod
    def vmap(info, in_dims, x, ind, ind_inv, dim):
        x_bdim, ind_bdim, ind_inv_bdim, _ = in_dims

        # The strategy is: expand {x, ind, ind_inv} to all have the dimension
        # being vmapped over.
        # Then, call back into NumpyTake(expanded_x, expanded_ind, expanded_ind_inv, new_dim).

        # Handle negative dims by wrapping them to be positive
        logical_dim = x.dim() if x_bdim is None else x_bdim - 1
        dim = dim if dim >= 0 else dim + logical_dim

        def maybe_expand_bdim_at_front(x, x_bdim):
            if x_bdim is None:
                return x.expand(info.batch_size, *x.shape)
            return x.movedim(x_bdim, 0)

        # If the Tensor doesn't have the dimension being vmapped over,
        # expand it out. Otherwise, move it to the front of the Tensor
        x = maybe_expand_bdim_at_front(x, x_bdim)
        ind = maybe_expand_bdim_at_front(ind, ind_bdim)
        ind_inv = maybe_expand_bdim_at_front(ind_inv, ind_inv_bdim)

        # The return is a tuple (output, out_dims). Since output is a Tensor,
        # then out_dims is an Optional[int] (instead of being a Tuple).
        return NumpyTake.apply(x, ind, ind_inv, dim + 1), 0

def numpy_sort(x, dim=-1):
    result, _, _ = NumpySort.apply(x, dim)
    return result

x = torch.randn(2, 3)
result = torch.vmap(numpy_sort)(x)
assert torch.allclose(result, numpy_sort(result, 1))

Note

The vmap staticmethod should aim to preserve the semantics of the entire Function. That is, (pseudocode) grad(vmap(MyFunc)) should be replaceable with a grad(map(MyFunc)).

If your autograd.Function has any custom behavior in the backward pass, please keep this in mind.

Note

It is a legitimate use case to write a custom vmap staticmethod for a Function that PyTorch is able to generate a vmap rule for via generate_vmap_rule=True. You may wish to do this if the generated vmap rule doesn’t have the semantics you’re looking for.

torch.func.jvp() Support

To support forward-mode AD, a torch.autograd.Function must have a jvp() staticmethod. Please see Forward mode AD for details.

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources