tt-flake/pkgs/buda-prebuilt/test.py
2024-11-01 00:21:03 +02:00

27 lines
838 B
Python

import pybuda
import torch
# Sample PyTorch module
class PyTorchTestModule(torch.nn.Module):
    """Small two-input PyTorch module used as the model under test.

    Holds two independent 32x32 learnable weight matrices, each initialised
    with values drawn by ``torch.rand``.
    """

    def __init__(self):
        super().__init__()
        self.weights1 = torch.nn.Parameter(torch.rand(32, 32), requires_grad=True)
        self.weights2 = torch.nn.Parameter(torch.rand(32, 32), requires_grad=True)

    def forward(self, act1, act2):
        """Multiply each activation by its own weight matrix and combine them.

        Args:
            act1: Tensor matrix-multiplied with ``weights1``; its last
                dimension must therefore be 32.
            act2: Tensor matrix-multiplied with ``weights2``; its last
                dimension must be 32 and the product must be addable to
                the ``act1`` product.

        Returns:
            A tuple ``(m1 + m2, m1)`` where ``m1 = act1 @ weights1`` and
            ``m2 = act2 @ weights2``.
        """
        m1 = torch.matmul(act1, self.weights1)
        m2 = torch.matmul(act2, self.weights2)
        return m1 + m2, m1
def test_module_direct_pytorch():
    """Run one inference pass of ``PyTorchTestModule`` through PyBUDA.

    Builds two random ``(4, 32, 32)`` input tensors, wraps a fresh
    ``PyTorchTestModule`` in ``pybuda.PyTorchModule`` under the name
    ``"direct_pt"``, calls ``run`` on it and prints whatever is returned.

    There are no assertions on the result: the check passes as long as
    nothing raises.
    """
    input1 = torch.rand(4, 32, 32)
    input2 = torch.rand(4, 32, 32)

    # Run single inference pass on a PyTorch module, using a wrapper to convert to PyBUDA first
    output = pybuda.PyTorchModule("direct_pt", PyTorchTestModule()).run(input1, input2)
    print(output)
if __name__ == "__main__":
    # Only run the check when this file is executed as a script, not on import.
    test_module_direct_pytorch()