# Source listing: 27 lines, 838 B, Python.
import pybuda
|
||
|
import torch
|
||
|
|
||
|
|
||
|
# Sample PyTorch module
|
||
|
class PyTorchTestModule(torch.nn.Module):
    """Sample two-branch PyTorch module.

    Holds two trainable 32x32 weight matrices. ``forward`` multiplies each
    activation by its own weight matrix. It returns the sum of the two
    products together with the first product on its own.
    """

    def __init__(self):
        super().__init__()
        weight_shape = (32, 32)
        # Both weights come from torch's global RNG: weights1 is drawn first,
        # then weights2.
        self.weights1 = torch.nn.Parameter(torch.rand(*weight_shape), requires_grad=True)
        self.weights2 = torch.nn.Parameter(torch.rand(*weight_shape), requires_grad=True)

    def forward(self, act1, act2):
        """Return ``(act1·weights1 + act2·weights2, act1·weights1)``."""
        first_branch = torch.matmul(act1, self.weights1)
        second_branch = torch.matmul(act2, self.weights2)
        combined = first_branch + second_branch
        return combined, first_branch
|
||
|
|
||
|
|
||
|
def test_module_direct_pytorch():
    """Run one inference pass of ``PyTorchTestModule`` through PyBUDA and print it.

    The PyTorch module is wrapped with ``pybuda.PyTorchModule`` before it is
    run on two random activations.
    """
    # Two random activation tensors. They are drawn before the module is
    # built, so its weights come afterwards from the RNG.
    activations = [torch.rand(4, 32, 32) for _ in range(2)]

    # Wrap the PyTorch module so PyBUDA handles the conversion, then execute.
    wrapped = pybuda.PyTorchModule("direct_pt", PyTorchTestModule())
    result = wrapped.run(*activations)

    print(result)
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
    # Executing the file directly runs the sample once, without a test runner.
    test_module_direct_pytorch()
|