Compiling a loop on MPS: exceeding the buffer limit

Hi,

I’m trying to compile a loop on an MPS device, but compiling the backward pass fails with “error: number of constant buffers exceeds maximum supported (31)”. The forward pass compiles fine. I realize this is a current limitation of Apple Metal, but since the forward pass compiles, I would expect compiling the backward pass to be feasible as well. I suspect the problem comes from how TorchInductor generates code without taking the buffer limit into account. Is there currently a way to work around this? (On CUDA, the same code compiles and runs fine with up to 1000 iterations.)

import torch


class Model(torch.nn.Module):
    def __init__(self, n_iter):
        super().__init__()
        self.linear = torch.nn.Linear(64, 64)
        self.n_iter = n_iter

    # Dynamo unrolls the Python loop, so the whole n_iter-step chain is traced as one graph
    @torch.compile(fullgraph=False)
    def forward(self, x):
        for i in range(self.n_iter):
            x = torch.sin(self.linear(x))

        return x


def main(device, n_iter):
    model = Model(n_iter=n_iter).to(device)

    x = torch.nn.Parameter(torch.randn(8, 64, device=device))
    target = torch.randn(8, 64, device=device)

    y = model(x)
    loss = (y - target).pow(2).mean()
    loss.backward()  # the buffer-limit error is raised here, while the backward graph is compiled
    print('Success')


if __name__ == '__main__':
    main(device='mps', n_iter=100)
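One workaround that may be worth trying (a sketch, not verified on MPS): instead of compiling the whole unrolled loop, compile only a single iteration and call it `n_iter` times from Python. Each compiled forward and backward graph then covers just one `Linear` + `sin` step, which should keep the number of buffers per generated kernel small, at the cost of losing fusion across iterations and paying per-call overhead. `ChunkedModel`, `_step` and the `backend` parameter are hypothetical names introduced here; `backend` exists only so the sketch can also be exercised with `"aot_eager"`, which runs the AOTAutograd forward/backward graphs without Inductor code generation.

```python
import torch


class ChunkedModel(torch.nn.Module):
    """Hypothetical variant of Model: compile one step and reuse it n_iter times."""

    def __init__(self, n_iter, backend="inductor"):
        super().__init__()
        self.linear = torch.nn.Linear(64, 64)
        self.n_iter = n_iter
        # Compile only a single Linear+sin step. The loop stays in eager Python,
        # so every compiled forward/backward graph spans one iteration only.
        self.step = torch.compile(self._step, backend=backend)

    def _step(self, x):
        return torch.sin(self.linear(x))

    def forward(self, x):
        # The same compiled graph is reused on every iteration (same shapes, same module).
        for _ in range(self.n_iter):
            x = self.step(x)
        return x
```

In the original script this would replace `Model(n_iter=n_iter)` with `ChunkedModel(n_iter=n_iter)`, keeping the rest of `main` unchanged; `torch.backends.mps.is_available()` can be used to fall back to `'cpu'` when testing on other machines.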