import math
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn

from vit.generate_data import partition, make_data
from vit.model import Attention, Transformer, Sigma

# ---------------------------------------------------------------------------
# Experiment hyper-parameters
# ---------------------------------------------------------------------------
s = 4                 # side length of the s x s spatial grid of every sample
n_channels = 10       # number of channels per sample
n_batch = 5000        # number of training samples (used as one full batch)
n_batch_test = 1000   # number of held-out test samples
N = n_batch
alpha = 0.1           # first argument handed to Sigma(alpha, p) further down
p = 3                 # second argument handed to Sigma(alpha, p) further down


# ---------------------------------------------------------------------------
# Teacher model: 2x2 / stride-2 convolution down to one channel + batch norm.
# The convolution weights are overwritten with ones; its bias keeps the
# default random initialisation.
# ---------------------------------------------------------------------------
conv1 = nn.Conv2d(
    in_channels=n_channels,
    out_channels=1,
    kernel_size=2,
    stride=2,
    padding=0,
    dilation=1,
    groups=1,
    bias=True,
    padding_mode='zeros',
    device=None,
    dtype=None,
)
bn1 = nn.BatchNorm2d(num_features=1)

conv_net = nn.Sequential(conv1, bn1)
nn.init.constant_(conv1.weight, 1.)


def _teacher_pass(samples):
    """Run the teacher on ``samples`` and derive one sign label per sample.

    Returns a triple ``(leaf, out, labels)``:
    ``leaf`` is a grad-enabled copy of ``samples``, ``out`` is
    ``conv_net(leaf)`` and ``labels`` is the detached sign of the per-sample
    mean of ``out ** 3`` (taken over the channel and both spatial axes).
    """
    leaf = samples.clone().detach().requires_grad_(True)
    out = conv_net(leaf)
    labels = torch.sgn((out ** 3).mean(dim=(1, 2, 3))).detach()
    return leaf, out, labels


# NOTE(review): conv_net is never switched to eval() in the visible code, so
# BatchNorm2d normalises each call with that batch's own statistics; the
# training and test labels are therefore produced with slightly different
# normalisation -- confirm this is intended.

# Training set: standard-normal inputs labelled by the teacher.
X = torch.randn(n_batch, n_channels, s, s)
U, Z, y = _teacher_pass(X)

# Test set: drawn and labelled the same way.
X_test = torch.randn(n_batch_test, n_channels, s, s)
U_test, Z_test, y_test = _teacher_pass(X_test)


# ---------------------------------------------------------------------------
# Student model: Attention(Q, v) from vit.model, followed by Sigma(alpha, p).
# ---------------------------------------------------------------------------
D = s ** 2        # number of spatial positions per sample; Q is D x D
d = n_channels    # length of the vector v

sigma = Sigma(alpha, p)

# Q starts as a scaled identity with scale log(log(d)); v starts as small
# Gaussian noise (standard deviation 0.2).
sigma_Q = math.log(math.log(d))

Q_0 = sigma_Q * torch.eye(D)
v_0 = 0.2 * torch.randn(d)

Q = nn.Parameter(Q_0)
v = nn.Parameter(v_0)

net = Attention(Q, v)

# Reinterpret every sample as D rows of d values for the attention model.
# NOTE(review): X was created as (n_batch, n_channels, s, s); .view() only
# re-reads that memory as (n_batch, s**2, n_channels) and does NOT move the
# channel axis last, so a row is not "all channels of one pixel". If that is
# what Attention expects, use X.permute(0, 2, 3, 1).reshape(...) instead --
# confirm against vit.model before changing, as it alters the experiment.
X = X.view(n_batch, D, d)
X_test = X_test.view(n_batch_test, D, d)

print('shape of X: ', X.shape)

# ---------------------------------------------------------------------------
# Full-batch SGD training of the student, with per-epoch test-error tracking.
# ---------------------------------------------------------------------------
losss = []   # training loss after each epoch (plain Python floats)
error = []   # test error in percent after each epoch (plain Python floats)

optimizer = torch.optim.SGD(net.parameters(), lr=3e-1, weight_decay=1e-4)

# np.save does not create missing directories, so make sure the target exists.
os.makedirs('metrics_conv', exist_ok=True)

n_epochs = 4000
print('Training for ', n_epochs, ' epochs')
for i in range(n_epochs):
    optimizer.zero_grad()
    Y = sigma(net(X))
    # Logistic loss log(1 + exp(-Y * y)). softplus evaluates the same
    # expression without overflowing to inf/NaN when -Y * y is large.
    loss = nn.functional.softplus(-Y * y).mean()
    loss.backward()
    optimizer.step()
    losss.append(loss.item())

    # Evaluation only: run outside autograd and keep a plain float, so no
    # computation graph is retained per epoch and the history list can be
    # serialised directly by np.save.
    with torch.no_grad():
        Y_test = sigma(net(X_test))
        # |sgn(Y_test) - y_test| is 2 for an opposite-sign prediction, hence
        # the division by 2 * n_batch_test; the result is a percentage.
        test_error = (torch.sgn(Y_test) - y_test).abs().sum() / (2 * n_batch_test) * 100
    error.append(test_error.item())

    # Periodic checkpoint of the metrics and the current Q.
    if i % 1000 == 0:
        np.save('metrics_conv/error.npy', error)
        np.save('metrics_conv/Q.npy', Q.detach().cpu().numpy())
        print('loss: ', loss.item())

# Final results.
np.save('metrics_conv/error_conv.npy', error)
np.save('metrics_conv/Q_conv.npy', Q.detach().cpu().numpy())


        