-
Notifications
You must be signed in to change notification settings - Fork 63
/
gru_torch.py
57 lines (46 loc) · 1.48 KB
/
gru_torch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import json
import os
from json import JSONEncoder

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import Dataset
class EncodeTensor(JSONEncoder, Dataset):
    """JSON encoder that serialises torch tensors (and NumPy values) as nested lists.

    Used as ``json.dump(model.state_dict(), f, cls=EncodeTensor)``.  The
    ``Dataset`` base has no effect on encoding; it is retained only for
    backward compatibility with code that may check the class hierarchy.
    """

    def default(self, obj):
        """Return a JSON-serialisable stand-in for ``obj``.

        Tensors are detached from autograd, moved to the CPU and converted to
        (nested) Python lists.  NumPy arrays become lists and NumPy scalars
        become plain Python scalars.  Anything else is delegated to
        ``JSONEncoder.default``, which raises ``TypeError``.
        """
        if isinstance(obj, torch.Tensor):
            # detach() so parameters with requires_grad=True convert cleanly.
            return obj.detach().cpu().tolist()
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.generic):
            return obj.item()
        # Zero-arg super() resolves to JSONEncoder.default, giving callers the
        # standard TypeError for unsupported types.
        return super().default(obj)
# Pin both random number generators so every run produces the same data.
torch.manual_seed(0)  # torch RNG: used when Model() initialises its weights below
np.random.seed(1001)  # NumPy global RNG: used by np.random.uniform for the input signal
class Model(torch.nn.Module):
    """Stacked recurrent regressor: GRU(1 -> 8), 3-layer GRU(8 -> 8), Linear(8 -> 1).

    Submodules are registered in the order ``gru``, ``gru2``, ``dense``.  That
    order determines both the sequence of random weight initialisation and the
    key order of ``state_dict()``, so it must not change.
    """

    def __init__(self):
        super().__init__()
        # Keep this registration order: gru, gru2, dense.
        self.gru = torch.nn.GRU(input_size=1, hidden_size=8)
        self.gru2 = torch.nn.GRU(input_size=8, hidden_size=8, num_layers=3)
        self.dense = torch.nn.Linear(in_features=8, out_features=1)

    def forward(self, torch_in):
        """Run ``torch_in`` through both GRUs and the linear head.

        Both GRUs use PyTorch's default ``batch_first=False``, so the input is
        time-major: ``(seq_len, 1)`` unbatched or ``(seq_len, batch, 1)``.  The
        final hidden states returned by each GRU are discarded; only the
        per-step outputs are used.
        """
        first_out = self.gru(torch_in)[0]
        second_out = self.gru2(first_out)[0]
        return self.dense(second_out)
# Generate a reproducible random input signal, run it through the untrained
# model, and export the input, the output and the weights as reference data.
x = np.random.uniform(-1, 1, 1000)
# Shape (seq_len, 1): one feature per time step, no batch dimension.
torch_in = torch.from_numpy(x.astype(np.float32)).reshape(-1, 1)
model = Model()
with torch.no_grad():
    # Call the module rather than .forward() so any registered hooks run;
    # no autograd graph is needed, so the result converts straight to NumPy.
    y = model(torch_in).numpy()
print(np.shape(y))

# Write all outputs before plotting, because plt.show() may block until the
# window is closed.
os.makedirs('test_data', exist_ok=True)
os.makedirs('models', exist_ok=True)
np.savetxt('test_data/gru_torch_x_python.csv', x, delimiter=',')
np.savetxt('test_data/gru_torch_y_python.csv', y, delimiter=',')
with open('models/gru_torch.json', 'w') as json_file:
    json.dump(model.state_dict(), json_file, cls=EncodeTensor)

plt.plot(x)
# y is (seq_len, 1): plot the single output channel over time. The original
# y[0, :] plotted only the first time step.
plt.plot(y[:, 0])
plt.show()
# print(x[:5])
# print(conv.state_dict())
#
# ch_idx = 0
# kernel_test = conv.state_dict()["weight"][ch_idx, 0, :].detach().numpy()
# print(kernel_test)
# y_test = np.correlate(x, kernel_test, mode='full')
# print(y_test[:10])
# print(y[ch_idx,:10])
#
# print(np.sum(kernel_test * x[:5]))