#!/usr/bin/env python
# coding: utf-8



import torch
import math
import json
from torch import nn, Tensor
from datetime import datetime
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForQuestionAnswering

# Use the GPU selected by cuda_id when CUDA is available, otherwise the CPU.
cuda_id = 0
device = torch.device("cuda:"+str(cuda_id) if torch.cuda.is_available() else "cpu")
print(device)

# Pre-exported SQuAD splits: JSON lists of example dicts read by position
# (fields 'question', 'context' and 'answers' are used below).
with open('SQuAD_resources/SQuAD_valSet.json', 'r', encoding='utf-8') as openfile:
    SQuAD_valSet = json.load(openfile)
with open('SQuAD_resources/SQuAD_trainSet.json', 'r', encoding='utf-8') as openfile:
    SQuAD_trainSet = json.load(openfile)

# Model / training hyper-parameters.
# generate_label produces one binary class id per example (0 = no answer,
# 1 = has an answer), so the model must output Y_len = 1 row of Y_dim = 2
# class logits. Y_len has to equal the label length, otherwise
# CrossEntropyLoss fails with an input/target batch-size mismatch.
X_len, X_dim, Y_len, Y_dim, p_norm, q_dim, ntokens, LR = 500, 192, 1, 2, 3, 200, 50265, 2e-5
num_D, num_G = 4, 8
#data_size = 20000#len(SQuAD_valSet)
train_size, val_size = 20000, 1000

def generate_label(_dataset, _size):
    """Build binary "answerable" labels for the first ``_size`` examples.

    An example is labelled 1 when its ``answers['text']`` list is non-empty
    and 0 when it is empty.

    Returns:
        A ``torch.long`` tensor of shape ``(_size, 1)``.
    """
    flags = [int(_dataset[idx]['answers']['text'] != []) for idx in range(_size)]
    return torch.tensor(flags, dtype=torch.long).unsqueeze(1)

# generate_label builds CPU tensors; move them to the model's device so the
# loss and the accuracy comparison never mix CPU and CUDA tensors.
label_trainSet = generate_label(SQuAD_trainSet, train_size).to(device)
label_valSet = generate_label(SQuAD_valSet, val_size).to(device)
#label_dataset = torch.argmax(torch.randn(data_size, Y_len, Y_dim), axis=2).to(device)
print(label_trainSet.shape, label_valSet.shape)
print(num_D, num_G)
print(X_len, X_dim, Y_len, Y_dim, p_norm, q_dim, ntokens, LR)


class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding, added at half strength, followed by dropout.

    Args:
        d_model: channel size of the inputs. Odd sizes are supported; the last
            (even-indexed) sine channel then has no matching cosine channel.
        dropout: dropout probability applied after adding the encoding.
        max_len: number of positions in the precomputed table; inputs may not
            be longer than this.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 500):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd-indexed channel than
        # frequencies, so the cosine half uses only the first d_model // 2.
        pe[:, 0, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor) -> Tensor:
        """
        Arguments:
            x: Tensor, shape ``[seq_len, batch_size, embedding_dim]`` with
               ``seq_len <= max_len``; position ``i`` along dim 0 receives
               ``0.5 * pe[i]``.
        """
        x = x + 0.5*self.pe[:x.size(0)]
        return self.dropout(x)
    


class Q_group(nn.Module):
    """Map an ``(X_len, X_dim)`` input to a ``(Y_len, Y_dim)`` output via learned distances.

    ``num_D`` pairs of linear projections (along the X_len axis) yield pairs of
    query matrices whose pairwise p-norm distances are combined with
    alternating signs into one ``(X_dim, X_dim)`` relation matrix. A separate
    projection to ``Y_len`` rows is multiplied by that matrix, and the result
    is mapped from ``X_dim`` to ``Y_dim`` channels.
    """

    def __init__(self, X_len, X_dim, Y_len, Y_dim, q_dim, p_norm, num_D):
        super().__init__()

        self.num_D = num_D
        self.p_norm = p_norm
        # Submodules are created in a fixed order (Qins, Qouts, Sin, Y) so
        # parameter initialisation draws from the RNG in that sequence.
        self.X_to_Qins = nn.ModuleList(nn.Linear(X_len, q_dim) for _ in range(num_D))
        self.X_to_Qouts = nn.ModuleList(nn.Linear(X_len, q_dim) for _ in range(num_D))
        self.X_to_Sin = nn.Linear(X_len, Y_len)
        self.Sout_to_Y = nn.Linear(X_dim, Y_dim)

        # Registered activations; forward() does not currently apply them.
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_X):
        feats = input_X.T  # (X_dim, X_len): every projection acts on the X_len axis

        def pair_distance(idx):
            q_in = self.X_to_Qins[idx](feats)
            q_out = self.X_to_Qouts[idx](feats)
            return torch.cdist(q_in, q_out, p=self.p_norm)

        # Signed sum of distance matrices: even-indexed pairs add, odd subtract.
        rel_mat = pair_distance(0)
        for idx in range(1, self.num_D):
            rel_mat = rel_mat + (-1) ** idx * pair_distance(idx)

        seq_scores = self.X_to_Sin(feats)  # (X_dim, Y_len)
        mixed = seq_scores.T @ rel_mat     # (Y_len, X_dim)
        return self.Sout_to_Y(mixed)
        
        
class dyn1_block(nn.Module):
    """Sequence model that sums the outputs of ``num_G`` Q_group heads.

    Token ids are embedded into ``X_dim`` channels, given a sinusoidal
    positional encoding, and passed to every Q_group; the group outputs are
    added together to form the prediction.
    """

    def __init__(self, X_len, X_dim, Y_len, Y_dim, p_norm=3, q_dim=20, num_D=4, num_G=8, ntokens=10000):
        super(dyn1_block, self).__init__()

        self.num_G = num_G

        self.PE = PositionalEncoding(X_dim, 0.05, X_len)
        self.embedding = nn.Embedding(ntokens, X_dim)

        self.Q_groups = nn.ModuleList(
            Q_group(X_len, X_dim, Y_len, Y_dim, q_dim, p_norm, num_D) for _ in range(num_G)
        )
        # Parameter count of the Q_groups only; the embedding table is excluded.
        self.total_params = sum(param.numel() for param in self.Q_groups.parameters())

    def forward_ids(self, input_Xids):
        """Embed a 1-D sequence of token ids and run ``forward_vecs`` on it."""
        input_X = self.embedding(input_Xids)
        return self.forward_vecs(input_X)

    def forward_vecs(self, input_X):
        """Run the model on an embedded ``(seq_len, X_dim)`` sequence (``seq_len <= X_len``)."""
        # PositionalEncoding expects (seq_len, batch, dim): insert the batch
        # axis at dim 1 so each position gets its own encoding. Unsqueezing at
        # dim 0 would make seq_len == 1 and give every token the position-0
        # encoding, i.e. no positional information at all.
        input_X = self.PE(input_X.unsqueeze(1)).squeeze(1)

        cur_Y = self.Q_groups[0](input_X)
        for G_id in range(1, self.num_G):
            cur_Y = cur_Y + self.Q_groups[G_id](input_X)

        return cur_Y

    

def process_input(raw_input):
    """Tokenize one SQuAD example into a fixed-length vector of ``X_len`` token ids.

    The question and context are encoded as a pair by the module-level
    ``tokenizer``. The ids are truncated to ``X_len`` and right-padded with
    the tokenizer's pad id.

    Returns:
        A ``torch.long`` tensor of shape ``(X_len,)`` on ``device``.
    """
    inputs = tokenizer(raw_input['question'], raw_input['context'], return_tensors="pt")
    token_ids = inputs.input_ids.squeeze(0)[:X_len]
    # Pad with the tokenizer's own pad id (fall back to 0 if it defines none).
    # Id 0 can be a real token (for RoBERTa it is <s>), which would make
    # padding indistinguishable from it.
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
    _input_ids = torch.full((X_len,), pad_id, dtype=torch.long)
    _input_ids[:token_ids.shape[0]] = token_ids

    return _input_ids.to(device)


def eval_model(_model, _inputs, _labels, _size):
    """Return the accuracy of ``_model`` on the first ``_size`` examples.

    Args:
        _model: module exposing ``forward_ids``. It is put in eval mode (so
            dropout is disabled) for the pass and restored to its previous
            mode afterwards.
        _inputs: indexable raw examples accepted by ``process_input``.
        _labels: per-example label tensors; ``_labels[i]`` is compared
            element-wise with ``argmax(pred, dim=1)``.
        _size: number of examples to evaluate.

    Returns:
        Fraction of matching label entries rounded to 4 decimals, or 0.0 when
        nothing was evaluated.
    """
    was_training = _model.training
    _model.eval()
    corr_count, total_count = 0, 0
    try:
        with torch.no_grad():
            for data_id in range(_size):

                _input_ids = process_input(_inputs[data_id])
                pred = _model.forward_ids(_input_ids)
                # Align labels with the predictions' device so CPU-resident
                # labels also work with a CUDA model.
                cur_labels = _labels[data_id].to(pred.device)
                corr_count += torch.sum(torch.argmax(pred, dim=1) == cur_labels)
                total_count += cur_labels.shape[0]
    finally:
        _model.train(was_training)

    if total_count == 0:
        return 0.0
    return round((corr_count/total_count).item(), 4)



#SQuAD_trainSet = load_dataset('squad_v2', split='train')
#SQuAD_valSet = load_dataset('squad_v2', split='validation')

#with open('SQuAD_resources/SQuAD_valSet.json', 'r') as openfile: SQuAD_valSet = json.load(openfile)
#with open('SQuAD_resources/SQuAD_trainSet.json', 'r') as openfile: SQuAD_trainSet = json.load(openfile)

#X_len, X_dim, Y_len, Y_dim, p_norm, q_dim, ntokens, LR = 500, 192, 2, 500, 3, 100, 50265, 2e-5
#num_D, num_G = 4, 16
#data_size = len(SQuAD_valSet)
#label_dataset = torch.argmax(torch.randn(data_size, Y_len, Y_dim), axis=2).to(device)
#print(data_size, label_dataset.shape)
#print(num_D, num_G)
#print(X_len, X_dim, Y_len, Y_dim, p_norm, q_dim, ntokens, LR)


dyn1_module = dyn1_block(X_len, X_dim, Y_len, Y_dim, 
                         p_norm=p_norm, q_dim=q_dim, 
                         num_D=num_D, num_G=num_G, 
                         ntokens=ntokens).to(device)

dyn1_optim = torch.optim.Adam(dyn1_module.parameters(), lr=LR)
total_params = dyn1_module.total_params

print('Model Size:', total_params, train_size*(X_len+Y_len)/total_params)


#model_name = "deepset/deberta-v3-large-squad2"
model_name = "deepset/roberta-base-squad2"

# Only the tokenizer of the pretrained checkpoint is used (by process_input);
# the pretrained QA model itself is not loaded.
#model = AutoModelForQuestionAnswering.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

ce_func = nn.CrossEntropyLoss()
mse_func = nn.MSELoss()


print(datetime.now().time())
acc = eval_model(dyn1_module, SQuAD_valSet, label_valSet, val_size)
print(acc)

loss_record = []
acc_record = []
print(datetime.now().time())

for _ep in range(100):
    # Explicitly (re-)enter training mode each epoch so dropout is active
    # during training regardless of what the previous evaluation did.
    dyn1_module.train()
    cur_loss = 0
    for data_id in range(train_size):

        _input_ids = process_input(SQuAD_trainSet[data_id])
        pred = dyn1_module.forward_ids(_input_ids)
        # CrossEntropyLoss requires input and target on the same device.
        loss = ce_func(pred, label_trainSet[data_id].to(pred.device))
        
        dyn1_optim.zero_grad()
        loss.backward()
        dyn1_optim.step()

        cur_loss += loss.item()
        
        #if data_id%(data_size//10) == 0: print(_ep, data_id, datetime.now().time(), cur_loss)
        
    acc = eval_model(dyn1_module, SQuAD_valSet, label_valSet, val_size)
    print(_ep, datetime.now().time(), cur_loss, acc)
    loss_record.append(cur_loss)
    acc_record.append(acc)
    # Stop early once validation accuracy is perfect and the epoch loss is small.
    if acc == 1 and cur_loss < 10: break

print(loss_record)
print(acc_record)




