import sys, numpy as np
sys.path.append('..')
from main import *
import torch
from utils import *
# from model import *
from data import *
import argparse
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from model.GPT_DataAnalysis import *
from types import SimpleNamespace

def pca(X0):
    """Project the samples in ``X0`` onto their first two principal components.

    Args:
        X0: Sample matrix accepted by ``sklearn.decomposition.PCA``
            (array-like, one row per sample).

    Returns:
        The 2-component embedding produced by ``PCA.fit_transform``,
        one row per input sample.
    """
    reducer = PCA(n_components=2)
    return reducer.fit_transform(X0)

def tsne(X0):
    """Embed the samples in ``X0`` into two dimensions with t-SNE.

    Args:
        X0: Sample matrix accepted by ``sklearn.manifold.TSNE``
            (array-like, one row per sample).

    Returns:
        The 2-component embedding produced by ``TSNE.fit_transform``,
        one row per input sample. Every setting other than the number of
        components is left at the scikit-learn default.
    """
    embedder = TSNE(n_components=2)
    return embedder.fit_transform(X0)

# def cosine_similarity(v1, v2):
#     return np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2))

def cosine_similarity_array(X):
    """Return the pairwise cosine-similarity matrix of the rows of ``X``.

    Args:
        X: 2-D array-like of shape ``(n_samples, n_features)``.

    Returns:
        ``(n_samples, n_samples)`` array whose ``[i, j]`` entry is the cosine
        similarity between row ``i`` and row ``j``. A row with zero norm has
        similarity 0 to every row (including itself) instead of NaN.
    """
    X = np.asarray(X)
    norms = np.linalg.norm(X, axis=1, keepdims=True)
    # A zero row has no direction: dividing it by 1 keeps it all-zero, so its
    # similarities come out as 0 rather than the NaN produced by 0 / 0.
    norms[norms == 0] = 1
    unit_rows = X / norms
    return np.dot(unit_rows, unit_rows.T)


def load_args(args_path):
    """Load a JSON configuration file as an attribute-style namespace.

    Args:
        args_path: Path handed to ``read_json_data``; the loaded value must be
            a mapping with string keys, since it is unpacked as keyword
            arguments.

    Returns:
        argparse.Namespace: One attribute per top-level key of the loaded
        mapping.
    """
    return argparse.Namespace(**read_json_data(args_path))

# Accuracy of the model's prediction at the last position of each sequence.
def last_word_acc(model, data, batch_size=-1):
    """Return the fraction of sequences whose last target token is predicted.

    The model is switched to eval mode and run without gradient tracking.

    Args:
        model: Module called as ``model(dec_inputs)`` and returning a pair
            whose first element holds the logits. Batches are moved to CUDA
            when available (CPU otherwise); the model itself is NOT moved, so
            it must already live on that device.
        data: Rectangular collection of token sequences (converted with
            ``np.array``), one sequence per row.
        batch_size: Evaluation batch size; ``-1`` (default) evaluates the
            whole dataset as a single batch.

    Returns:
        float: Number of sequences whose last-position prediction equals the
        last target token, divided by the dataset size.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()

    data = np.array(data)
    if batch_size == -1:
        batch_size = np.size(data, 0)
    # Predictions per sequence: one fewer than the sequence length
    # (presumably inputs/targets are the sequence shifted by one -- confirm
    # against MyDataSet.padding_batch).
    n_positions = np.size(data, 1) - 1

    dataset = MyDataSet(data)
    # The correct-count is order-independent, so shuffling is unnecessary and
    # would only consume global RNG state during evaluation.
    data_loader = Data.DataLoader(dataset, shuffle=False, batch_size=batch_size,
                                  drop_last=False, collate_fn=dataset.padding_batch)

    correct = 0
    # No gradients are needed for evaluation; skip building autograd graphs.
    with torch.no_grad():
        for dec_inputs, dec_outputs in data_loader:
            dec_inputs, dec_outputs = dec_inputs.to(device), dec_outputs.to(device)
            outputs, _ = model(dec_inputs)
            # Regroup the per-token argmax into one row per sequence.
            predictions = outputs.argmax(axis=-1).view(-1, n_positions)
            correct += (predictions[:, -1] == dec_outputs[:, -1]).sum().item()

    return correct / len(data_loader.dataset)

if __name__ == "__main__":
    # NOTE(review): these options are declared but ``parser.parse_args()`` is
    # never called, so everything below runs on the hard-coded values.
    parser = argparse.ArgumentParser()
    parser.add_argument("--epoch_list", nargs='+', type=int, default=[800], help="epoch list")
    parser.add_argument("--working_list", type=str, default='./model/model_800.pt', help="model path")
    parser.add_argument("--testing_seq_list", nargs='+', type=int, default=[[3, 4, 25, 36, 50, 33, 43, 92, 87]], help="testing seq list")
    parser.add_argument("--testing_seq_model_part", type=str, default='0_softmaxQK1', help="testing seq model part")
    parser.add_argument("--testing_seq_position", type=int, default=1, help="testing seq position")
    parser.add_argument("--testing_seq_anchor", nargs='*', type=int, default=3)

    for epoch in [800]:
        # ---- Part 1: plot '0_softmaxQK1' for two hand-picked sequences ----
        state_dict = torch.load(f'./model/model_{epoch}.pt', map_location=torch.device('cpu'))
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        args = load_args(f'./config.json')

        model = GPT_DataAnalysis(args, device)
        model.load_state_dict(state_dict)
        model.to(device)
        # Inference only: put the module in eval mode, as last_word_acc does.
        model.eval()

        # The two probe sequences differ only in the order of their first two
        # tokens (3, 4 vs. 4, 3).
        probe_sequences = (
            [3, 4, 25, 36, 50, 33, 43, 92, 87],
            [4, 3, 25, 36, 50, 33, 43, 92, 87],
        )
        for tokens in probe_sequences:
            seq = np.array([tokens])

            # Join the tokens with "_" to tag the output file name.
            seq_str = '_'.join(str(i) for i in seq[0])

            # The return value is not used; the forward pass is kept because
            # load_GPT_data is called right after it (presumably it reads
            # intermediates the model recorded during forward -- confirm).
            with torch.no_grad():
                model(torch.tensor(seq).to(device))

            datas = load_GPT_data(model, state_dict)

            print(seq_str)
            plot_single_show_number(datas['0_softmaxQK1'], save_path=f'./epoch_{epoch}_{seq_str}_0_softmaxQK1.png')

        # ---- Part 2: prediction statistics on random sequences (CPU) ----
        model_path = './model/model_800.pt'
        args = read_json_data("./config.json")
        # Expose the config keys as attributes via SimpleNamespace.
        args = SimpleNamespace(**args)
        model = GPT_DataAnalysis(args, device='cpu')
        state_dict: dict = torch.load(model_path, map_location='cpu')
        model.load_state_dict(state_dict)
        model.eval()

        def get_attention_output(module, inputs, output):
            """Forward hook: stash the module's output on the module itself."""
            module.attention_output = output

        # Capture the output of model.decoder during the next forward pass.
        handle = model.decoder.register_forward_hook(get_attention_output)

        # Random integer tokens drawn from [20, 100).
        data_size = 1000
        seq_len = 9
        data = torch.randint(20, 100, (data_size, seq_len))
        # Positions 2 and 3 are overwritten with the fixed tokens 1 and 2
        # (alternative pair tried previously: 3 and 4).
        data[:, 2] = 1
        data[:, 3] = 2
        # Reference value appended for comparison: token at position 1 plus 6
        # (presumably the target for this token pair -- confirm).
        labels = (data[:, 1] + 6).numpy()

        print(data.size())
        try:
            with torch.no_grad():
                output, _ = model(data)
        finally:
            # Detach the hook so it cannot fire on later forward passes.
            handle.remove()
        # Captured decoder output; kept for interactive inspection.
        attention_output2 = model.decoder.attention_output
        print(output.size())
        # One row per sequence, one predicted token per position.
        output = output.argmax(axis=-1).view(-1, seq_len)
        # Append the reference labels as an extra last column.
        output = torch.cat((output, torch.tensor(labels).view(-1, 1)), dim=1)
        # Print the full tensor without truncation.
        torch.set_printoptions(profile="full")
        print(output)
        # Histogram of predicted tokens at each position (label column excluded).
        for kk in range(seq_len):
            binc = output[:, kk].bincount()
            binc = {token: count for token, count in enumerate(binc.tolist()) if count != 0}
            print(f'kk={kk}')
            print(binc)