import os

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch

from self_attention_layer import *
from create_data import *
from cl_loss import *


# Reproducibility: seed the CPU and CUDA generators with one shared value and
# ask cuDNN for deterministic kernels.
SEED = 202
for _seed_fn in (torch.manual_seed, torch.cuda.manual_seed):
    _seed_fn(SEED)
torch.backends.cudnn.deterministic = True


# You may need to train a model first and save both the model and its data (see cross_training.py).
# Here the model is saved as "one_layer_tf_direct_11_1_1200_16_1_16384_100_2023-09-28.pth" and the data as "one_layer_tf_direct_11_1_1200_16_1_16384_100_2023-09-28.pt".
# Alternatively, define the parameters below directly; just make sure they match the values used during training.

# Path of a checkpoint saved by cross_training.py, e.g.
# model_name = "models&datas/one_layer_tf_direct_11_1_1200_16_1_16384_100_2023-09-28.pth"
model_name = 'your path'

# The training hyper-parameters are encoded in the checkpoint's file name as
# "_"-separated fields. Parse only the base name without its extension, so that
# underscores in the directory part cannot shift the fields and `date` does not
# carry the ".pth" suffix.
parameters = os.path.splitext(os.path.basename(model_name))[0].split("_")
if len(parameters) < 12:
    raise ValueError(
        f"cannot parse hyper-parameters from model_name={model_name!r}; expected a file name like "
        "'one_layer_tf_direct_11_1_1200_16_1_16384_100_2023-09-28.pth'"
    )
datatype, readout_type, input_dim, output_dim, num_random, batch_size, task_num, demon_num, epochs, date = (
    parameters[2], parameters[3], int(parameters[4]), int(parameters[5]), int(parameters[6]),
    int(parameters[7]), int(parameters[8]), int(parameters[9]), int(parameters[10]), parameters[11],
)

# A token concatenates the input part (input_dim) and the output part (output_dim).
token_dim = input_dim + output_dim

# Rebuild the architecture used in training and load the checkpoint.
# map_location="cpu" lets a checkpoint saved from a GPU load on a CPU-only
# machine; the rest of the script needs CPU tensors anyway (it calls .numpy()).
fast_model = one_layer_fast_transformer(token_dim, output_dim, token_dim*4, num_random, readout_type)
model_dict = torch.load(model_name, map_location="cpu")
fast_model.load_state_dict(model_dict)
print(" model loaded successfully!  ")

# Keep only the trained attention weights, detached from autograd.
W_q, W_k, W_v = fast_model.sa_layer.get_weight()
W_q, W_k, W_v = W_q.data, W_k.data, W_v.data

# Build a fresh model that carries the trained attention weights.
# num_random must equal the parsed value: the projection matrix of this model is
# later fed to data_transformation together with linear_model(num_random, ...).
# The "direct" readout is kept fixed as in the original analysis
# (NOTE(review): presumably required by the GD-equivalence check — confirm).
fast_model = one_layer_fast_transformer(token_dim, output_dim, token_dim*4, num_random=num_random, readout_type="direct")
fast_model.sa_layer.init_weight(W_q, W_k, W_v)

#%%

# Task type used for this analysis; it overrides the value parsed from the file name.
datatype = 'linear'
# The dataset is stored next to the checkpoint with the same stem and a ".pt"
# suffix (e.g. "xxx.pth" -> "xxx.pt"). Using splitext does not depend on the
# checkpoint extension being exactly ".pth".
data_name = os.path.splitext(model_name)[0] + ".pt"
data = torch.load(data_name, map_location="cpu")
token_test = data['token_test']

tokens = token_test.clone()
tokens = tokens.transpose(0, 1)  # tokens: [token_dim, token_num] -> [token_num, token_dim]
# Zero the output part of the last (query) token so its target is hidden from the model.
tokens[-1, input_dim:] = 0

prediction, attention = fast_model(tokens)
V, attention, attention_exact = fast_model.sa_layer.get_V_attention()
output, output_exact = fast_model.sa_layer.get_output()  # here we use Performer to approximate softmax attention

#%% whether equivalent
# Check whether a single full-batch GD step on the reference linear model
# reproduces the transformer's in-context prediction for the query token.

projectionmatrix = fast_model.sa_layer.projectionmatrix
# NOTE(review): presumably the number of context tokens (all but the query) — confirm.
num_train = batch_size - 1
train_x, train_y_std, test_x, D, W_0 = data_transformation(
    tokens, num_train, W_q, W_k, W_v, projectionmatrix
)
# NOTE(review): the SGD step below appears to update this Parameter in place,
# which is why the next cell recomputes W_0 — confirm in linear_model.
W_0 = Parameter(W_0)

lr = 0.005
ref_linear = linear_model(num_random, token_dim, weight=W_0)
loss = GDloss(constant=D, lr=lr, norm=False)
optim = torch.optim.SGD(ref_linear.parameters(), lr=lr)

# One SGD update over all training examples at once.
optim.zero_grad()
y_pred = ref_linear(train_x)
step_loss = loss(train_y_std, y_pred)
step_loss.backward()
optim.step()

y_test_pred_linear = ref_linear(test_x)

print(f"GD result: {y_test_pred_linear!s}")
print(f"ICL reference: {prediction[-1]!s}")

# Repeat the check with stochastic GD (gd_batch_size examples per step) and
# track, after every step, how far the linear model's test prediction is from
# the transformer's in-context prediction.
gd_batch_size = 1

# Recompute the transformed data and a fresh W_0: the cell above already ran an
# SGD step on its W_0 Parameter.
projectionmatrix = fast_model.sa_layer.projectionmatrix
train_x, train_y_std, test_x, D, W_0 = data_transformation(
    tokens, batch_size - 1, W_q, W_k, W_v, projectionmatrix
)
W_0 = Parameter(W_0)

gd_train_set = GD_Data(train_x, train_y_std)
setup_seed(41)  # fixes the shuffle order of the loader below
gd_train_loader = DataLoader(gd_train_set, batch_size=gd_batch_size, shuffle=True)
gd_batch_num = len(gd_train_loader)

ref_linear = linear_model(num_random, token_dim, weight=W_0)

gd_epochs = 1
loss_lr = 0.005
gd_lr = loss_lr  # here we just need to make sure gd_lr = loss_lr, as the proof of theorem 1 shows.

loss = GDloss(constant=D, lr=loss_lr, norm=False)
optim = torch.optim.SGD(ref_linear.parameters(), lr=gd_lr)

# Reference: the transformer's prediction for the query token.
h_true = prediction[-1].detach().numpy()
res_predict = []       # per epoch: last entry of the query prediction
res_loss = []          # per epoch: mean training loss (float)
res_step_loss = []     # per step: training loss (float)
res_step_predict = []  # per step: last entry of the query prediction
res_step_norm = []     # per step: ||y_test_pred[-1, :] - h_true||_2
res_epoch_norm = []    # per epoch: ||y_test_pred[-1, :] - h_true||_2

for epoch in range(gd_epochs):
    epoch_loss = 0.0
    # Batch-specific names keep the full train_x / train_y_std intact.
    for batch_x, batch_y in gd_train_loader:
        y_pred = ref_linear(batch_x)
        step_loss = loss(batch_y, y_pred)
        optim.zero_grad()
        step_loss.backward()
        optim.step()

        # Store plain floats: accumulating the loss tensor itself would keep
        # every step's autograd graph alive for the whole run.
        step_loss_value = step_loss.item()
        epoch_loss += step_loss_value / gd_batch_num
        res_step_loss.append(step_loss_value)

        with torch.no_grad():
            y_test_pred = ref_linear(test_x)
        res_step_predict.append(y_test_pred[-1, -1].detach().numpy())
        res_step_norm.append(np.linalg.norm(y_test_pred[-1, :].detach().numpy() - h_true))

    res_loss.append(epoch_loss)

    # Evaluate once more at the end of the epoch; this also works when the loader is empty.
    with torch.no_grad():
        y_test_pred = ref_linear(test_x)
    res_epoch_norm.append(np.linalg.norm(y_test_pred[-1, :].detach().numpy() - h_true))
    res_predict.append(y_test_pred[-1, -1].detach().numpy())


#%% plot the norm vs steps
# Plot ||y_test_pred[-1, :] - h_true||_2 after each SGD step (steps shown 1-based),
# with a dashed zero line and a dashed grey vertical line at the end of each epoch.

num_steps = len(res_step_norm)
lengths_steps = np.arange(len(res_step_loss))                  # 0-based step indices
lengths_epoch = (np.arange(gd_epochs) + 1) * gd_batch_num - 1  # 0-based index of each epoch's last step
x_zeros = np.arange(num_steps * 1.1)
y_zeros = np.zeros(len(x_zeros))

plt.figure(num=2, dpi=600, figsize=(6, 5))
# vlines always spans [0, first norm]. Sampling np.arange(0, res_step_norm[0], 0.01)
# would collapse to a single, invisible point whenever the first norm is below 0.01.
l4 = plt.vlines(lengths_epoch + 1, 0, res_step_norm[0], colors='C7', linestyles='--', linewidth=4)
l1, = plt.plot(x_zeros, y_zeros, color='lightsteelblue', linestyle='--', linewidth=4)
l3, = plt.plot(lengths_steps + 1, res_step_norm, color='#4e79a7', linewidth=4, marker='o', ms=1)

plt.grid(which='both', axis='both', linestyle='--', linewidth=0.5)
plt.tick_params(labelsize=15)
ax = plt.gca()

ax.spines['right'].set_color('none')
ax.spines['top'].set_color('none')
ax.spines['bottom'].set_linewidth(1.5)
ax.spines['left'].set_linewidth(1.5)
ax.tick_params(axis='both', which='both', labelsize=15, width=1.5)

# Points are drawn at x = 1..num_steps, so the right limit must exceed num_steps.
# A bound of (num_steps - 1) * 1.1 would clip the last point when num_steps < 11.
ax.set_xlim([1, num_steps * 1.1])
# ax.set_ylim([-0.03 ,0.8])

# Raw string: "\|" and "\h" are invalid escape sequences in a normal literal
# (SyntaxWarning on Python 3.12+). The rendered title text is unchanged.
plt.title(r"$\| \hat{y}_{test} - h^{'}_{T+1}\|_{2}$ at each step", size=22)
plt.xlabel('steps', size=22)
plt.ylabel('Norm', size=22)
plt.show()