from create_data import *
import datetime
import time
from torch.utils.data import DataLoader
from self_attention_layer import *
from cl_loss import *
import seaborn as sns
import matplotlib.pyplot as plt
from one_for_all import *

# Device used for the model and every batch. Training is pinned to the CPU by
# default; set use_cuda = True to use a GPU when one is available.
use_cuda = False
device = "cuda" if use_cuda and torch.cuda.is_available() else "cpu"  # model, loss, dataset

#%%

setup_seed(41)

datatype = "linear"  # linear, cos, exp

# Options forwarded to self_fast_attention_layer_all (see model setup below).
# NOTE(review): key meanings (regularisation / negatives / augmentation?) are
# defined in the project modules, not here — confirm there.
para_dict = {'regu': False, 'alpha': 0.5, 'nega': False, 'beta_0': 0.1, 'nega_k': 3,
             'augu': False, 'augu_k': True, 'augu_v': False}

# pretraining parameters
input_dim, output_dim = 11, 1
demon_num, batch_size = 16384, 16
num_random_ratio = 100
task_num = 1
epochs = 100
readout_type = "direct"  # "direct", "mlp", linear  (not referenced in the code shown here)
lr = 0.005

# This is a count of batches, so use integer division. It assumes
# demon_num % batch_size == 0; otherwise the DataLoader adds one smaller
# final batch that this number does not include.
batchnum_for_onetask = demon_num // batch_size
print("batch number for one task is :" + str(batchnum_for_onetask))

test_num = batch_size * 1  # test set size equals one batch
# Alternative generator (regression data):
# w, token_training, token_test = create_multitask_regression_data(task_num, demon_num, test_num, input_dim, output_dim )
w, token_training, token_test = create_multitask_data(task_num, demon_num, test_num, input_dim, output_dim, datatype = datatype )

pretrain_set = token_Data(w, token_training, input_dim = input_dim, output_dim = output_dim)
pretrain_loader = DataLoader(pretrain_set, batch_size = batch_size)  # assumes demon_num % batch_size == 0
batch_num = len(pretrain_loader)   # expected: task_num * batchnum_for_onetask

pretrain_test_set = token_Data(w, token_test, input_dim, output_dim)
pretrain_test_loader = DataLoader(pretrain_test_set, batch_size = batch_size)  # assumes test_num % batch_size == 0
batch_test_num = len(pretrain_test_loader)

#%%

# Each token is an input vector followed by its output (label) vector.
token_dim = input_dim + output_dim
# NOTE(review): presumably the number of random features used by the "fast"
# attention layer — confirm in self_attention_layer.
num_random = token_dim * num_random_ratio

# The training loop sends every batch to `device`, so the model must live there
# too. Move it before building the optimizer so the optimizer holds the moved
# parameters.
one_layer_tf = self_fast_attention_layer_all(token_dim, output_dim, num_random, para_dict).to(device)

# loss = nn.MSELoss().to(device)
loss = pretrain_loss(weight = 1)
optim = torch.optim.SGD(one_layer_tf.parameters(), lr= lr)
# optim = torch.optim.Adam(one_layer_tf.parameters(),lr = lr)
res_epoch_loss = []  # one averaged training loss per epoch; saved at the end

def _mask_query_label(batch_tokens):
    """Return a copy of *batch_tokens* with the label columns of the last row zeroed.

    The last row of the batch is the query. Its output part (columns from
    ``input_dim`` onward) is hidden from the model. The unmasked tokens are
    still used as the loss target.
    """
    masked = batch_tokens.clone()
    masked[-1, input_dim:] = 0
    return masked


for epoch in range(epochs):

    # The evaluation pass below switches to eval(); switch back every epoch so
    # that training does not silently run in eval mode after epoch 0.
    one_layer_tf.train()
    epoch_loss = 0
    step = 0

    start_time = time.time()
    for datas in pretrain_loader:

        tokens, labels = datas  # data: [batch_size, input_dim + output_dim]   label: [batch_size, output_dim]
        tokens = tokens.to(device)
        labels = labels.to(device)
        fix_tokens = _mask_query_label(tokens)
        predictions, attention = one_layer_tf(fix_tokens)
        step_loss = loss(predictions, tokens)
        optim.zero_grad()   # clean the grad information
        step_loss.backward()    # calculate the grad information
        optim.step()    # update the weights

        # detach(): summing the live loss tensor would keep every step's
        # autograd graph in memory until the end of the epoch.
        epoch_loss += step_loss.detach() / batch_num
        step = step + 1

    end_time = time.time()
    train_info = "epoch{} training loss is ".format(epoch) + str(epoch_loss.item())
    res_epoch_loss.append(epoch_loss.cpu().numpy())

    one_layer_tf.eval()
    with torch.no_grad():

        epoch_loss_test = 0
        step_test = 0
        for datas in pretrain_test_loader:
            tokens, labels = datas
            tokens = tokens.to(device)
            labels = labels.to(device)

            fix_tokens = _mask_query_label(tokens)
            predictions, attention = one_layer_tf(fix_tokens)
            step_loss = loss(predictions, tokens)
            epoch_loss_test += step_loss / batch_test_num
            step_test += 1

        # Last column of the predictions and the labels for the final test
        # batch, for a quick visual comparison.
        print(predictions[:,-1].reshape(-1))
        print(labels.reshape(-1))

    print(train_info)
    print("epoch{} test loss is {} ({:.2f}s training)".format(
        epoch, float(epoch_loss_test), end_time - start_time))
    
#%%
# Persist the run. The parameter dict, the generated data and the per-epoch
# training losses go to cross_res/; the model weights go to cross_models/.
import os

name = 'model_name'
model_name = './cross_data/cross_models/'+ name + '.pt'
data_name = './cross_data/cross_res/'+ name + '.pt'
# torch.save opens the target path for writing and does not create missing
# parent directories, so create them up front.
os.makedirs(os.path.dirname(model_name), exist_ok=True)
os.makedirs(os.path.dirname(data_name), exist_ok=True)
data = {'para':para_dict, 'w': w, 'token_training': token_training, 'token_test': token_test, 'res_epoch_loss': res_epoch_loss}
torch.save(data, data_name)
torch.save(one_layer_tf.state_dict(), model_name)



