import math
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn

from vit.generate_data import partition, make_data
from vit.model import Attention, Transformer, Sigma

# Run on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# ---- Problem sizes ---------------------------------------------------------
D = 10
d = D * D
n_batch = 50000  # large number of training samples, so everything below is slow
N = n_batch
N_test = 1000

# ---- Quantities derived from log(d) ----------------------------------------
# NOTE(review): the meaning of C, q and L is defined by the vit package, not
# by this script; q is not referenced again in the visible part of this file.
C = 0.5 * math.log(d)
q = math.log(d) / D
L = 1 + int(D // C)

# Random vector of length d, rescaled in place to unit Euclidean norm.
# It is drawn before partition() is called so the RNG stream keeps its order.
w = torch.randn(d)
w /= torch.sqrt(torch.sum(torch.square(w)))

# S is later iterated as a collection of index pairs (see the plotting code).
S = partition(D, L)

# torch.save does not create missing directories. Create the output folder
# up front so a fresh checkout fails (if at all) before the slow data
# generation rather than after it.
os.makedirs('data', exist_ok=True)

# Draw the N training samples. make_data also returns w and S, which are
# rebound here to whatever it hands back.
X, y, w, S = make_data(N, D, L, d, S, w)

# Persist the data set so the exact same samples can be reloaded later.
torch.save(X, 'data/X.to')
torch.save(y, 'data/y.to')
torch.save(w, 'data/w.to')
torch.save(S, 'data/S.to')



# Read the training data back from disk. X is flagged as requiring gradients
# and the labels y are converted to 64-bit integers.
X = torch.load('data/X.to')
X.requires_grad_(True)
y = torch.load('data/y.to').to(torch.long)
w = torch.load('data/w.to')
S = torch.load('data/S.to')

# Draw a separate test set with the same S and w; the w and S returned by
# make_data are discarded here.
X_test, y_test, _, _ = make_data(N_test, D, L, d, S, w)

# Move everything the training loop touches onto the selected device.
X, y, w = X.to(device), y.to(device), w.to(device)
X_test, y_test = X_test.to(device), y_test.to(device)

# Visualise S as a D x D grey-scale matrix: ones on the diagonal and 1/2 at
# both (a, b) and (b, a) for every index pair in S.
S_matrix = torch.eye(D, D)
for pair in S:
    S_matrix[pair[0], pair[1]] = 1/2
    S_matrix[pair[1], pair[0]] = 1/2

# Draw on an explicitly created 4x4 figure. The previous plt.figure(...)
# followed by plt.matshow(...) opened a second figure for the image, so the
# requested figsize was ignored and an empty figure was left behind.
os.makedirs('figures', exist_ok=True)  # savefig does not create directories
fig, ax = plt.subplots(figsize=(4, 4))
ax.matshow(S_matrix, cmap='Greys')
ax.axis('off')
fig.savefig('figures/S.pdf')
plt.close(fig)  # release the figure once it has been written to disk


# Output non-linearity, applied to the network output in the training loop.
# NOTE(review): the roles of alpha and p are defined by vit.model.Sigma.
alpha, p = 0.03, 5
sigma = Sigma(alpha, p)

# Value placed on the diagonal of Q (and re-imposed after every SGD step).
sigma_Q = math.log(math.log(d))

# Initial parameters: Q starts at sigma_Q * I plus small Gaussian noise, and
# v starts as small Gaussian noise with the same shape as w.
# Q_0 is drawn before v_0 so the RNG stream keeps its order.
Q_0 = sigma_Q * torch.eye(D) + 0.001 * torch.randn(D, D)
v_0 = 0.001 * torch.randn(w.shape)
Q, v = nn.Parameter(Q_0), nn.Parameter(v_0)

net = Attention(Q, v).to(device)

# Per-epoch histories filled in by the training loop.
coss, losss, test_error = [], [], []

n_epochs = 3000
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)

def cosine(a, b):
    """Return the cosine of the angle between the 1-D tensors *a* and *b*.

    The result is a 0-dimensional tensor equal to
    ``<a, b> / sqrt(<a, a> * <b, b>)``. No guard is applied for zero
    vectors, so the result is NaN when either input has zero norm.
    """
    inner = torch.dot(a, b)
    squared_norms = torch.dot(a, a) * torch.dot(b, b)
    return inner / torch.sqrt(squared_norms)

### TRAINING: slow because we use a large number of samples

# np.save does not create missing directories.
os.makedirs('metrics', exist_ok=True)

# Full-batch gradient descent on the logistic loss log(1 + exp(-Y * y)).
# Each epoch records the test error (in percent), the cosine between v and w,
# and the training loss. The histories hold plain Python floats so they carry
# no autograd graph and can be saved with NumPy regardless of the device.
for i in range(n_epochs):
    # Evaluate on the test set before the update. No gradients are needed
    # here, so skip graph construction entirely.
    with torch.no_grad():
        Y_test = sigma(net(X_test))
        # |sgn(Y_test) - y_test| / 2 summed over samples, as a percentage of
        # N_test. NOTE(review): this counts misclassifications only if the
        # labels are +/-1 -- confirm against make_data.
        test_error.append(
            (((torch.sgn(Y_test) - y_test).abs().sum() / (2 * N_test)) * 100).item()
        )

    optimizer.zero_grad()
    Y = sigma(net(X))
    loss = torch.log(1 + torch.exp(- Y*y)).mean()
    coss.append(cosine(v.data, w).item())
    losss.append(loss.item())
    loss.backward()
    optimizer.step()

    # Re-impose the fixed diagonal of Q after every step, so only the
    # off-diagonal entries are effectively trained.
    with torch.no_grad():
        Q.fill_diagonal_(sigma_Q)

    # Checkpoint the metrics every 100 epochs and once more after the final
    # epoch, so the tail of the run is not lost.
    if i % 100 == 0 or i == n_epochs - 1:
        np.save('metrics/test_errors.npy', test_error)
        np.save('metrics/cosines.npy', coss)
        # Copy to host memory first; NumPy cannot read CUDA tensors.
        np.save('metrics/A_realistic.npy', Q.detach().cpu().numpy())
        print(loss)

