#!/usr/bin/env python
# coding: utf-8



import torch
import math
from torch import nn, Tensor
#from datasets import load_dataset
from datetime import datetime
#from sklearn.metrics import f1_score
#from transformers import AutoModel, AutoModelForQuestionAnswering, AutoTokenizer

# --- Experiment configuration ----------------------------------------------
cuda_id = 1              # preferred GPU ordinal
q_dim = 20               # output width of the S_to_Q* projections in dyn1_L8_v4
model_set = 'dyn1-L8-v4'

# Use the requested GPU when it exists; otherwise fall back to the first GPU so
# a single-GPU machine does not crash with "invalid device ordinal" on the
# first `.to(device)`. Use the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    device = torch.device(f"cuda:{cuda_id if cuda_id < torch.cuda.device_count() else 0}")
else:
    device = torch.device("cpu")
print(device, q_dim, model_set)

data_size, max_len, w_dim, LR, p_norm, num_labels = 2000, 200, 256, 2e-4, 3, 2
print(data_size, max_len, w_dim, LR, p_norm, num_labels)
# Synthetic inputs: i.i.d. Gaussian token vectors, shape (data_size, max_len, w_dim).
input_dataset = torch.randn(data_size, max_len, w_dim).to(device)
# Synthetic targets, drawn independently of the inputs: for every sample,
# num_labels random positions in [0, max_len) (argmax over the max_len axis of
# Gaussian noise). Shape (data_size, num_labels).
label_dataset = torch.argmax(torch.randn(data_size, num_labels, max_len), dim=2).to(device)
print(input_dataset.shape)
print('Dataset Size:', data_size*max_len*w_dim, data_size*max_len)


    
class dyn1_L8_v4(nn.Module):
    """Position scorer built from four pairs of learned feature-relation matrices.

    For an input of shape ``(max_len, w_dim)`` (or a batch ``(B, max_len, w_dim)``):

    * each of the eight layers L1..L8 owns two projections ``S_to_Q1_Lk`` /
      ``S_to_Q2_Lk`` over the sequence axis (``max_len -> q_dim``), turning each
      of the ``w_dim`` feature columns into a ``q_dim``-dimensional vector;
    * ``torch.cdist`` (``p=p_norm``) between a layer's two projections gives a
      ``(w_dim, w_dim)`` distance matrix;
    * layers are used in pairs (L1/L2, L3/L4, L5/L6, L7/L8). The difference of
      the pair's two distance matrices right-multiplies the input, a sigmoid is
      applied, and the pair's head ``S_to_F_L<ab>`` maps ``w_dim -> num_labels``;
    * the four heads' outputs are summed into ``(..., max_len, num_labels)``.

    NOTE(review): ``__init__`` reads the module-level globals ``device`` and
    ``num_labels``; they must be defined before the class is instantiated.
    """

    # Layer indices that are combined (as a difference of distance matrices)
    # by one output head each; the head for (a, b) is ``S_to_F_L<a><b>``.
    _LAYER_PAIRS = ((1, 2), (3, 4), (5, 6), (7, 8))

    def __init__(self, max_len, q_dim, w_dim, p_norm=1):
        super().__init__()

        self.q_dim = q_dim
        self.p_norm = p_norm  # order of the norm used by torch.cdist

        # Register the projections under their historical names and in the
        # historical order (Q1_L1, Q2_L1, Q1_L2, ...), so state_dict keys,
        # parameter order seen by the optimizer and the RNG-driven weight
        # initialisation are all unchanged.
        for layer in range(1, 9):
            setattr(self, f"S_to_Q1_L{layer}", nn.Linear(max_len, q_dim, device=device))
            setattr(self, f"S_to_Q2_L{layer}", nn.Linear(max_len, q_dim, device=device))

        for first, second in self._LAYER_PAIRS:
            setattr(self, f"S_to_F_L{first}{second}", nn.Linear(w_dim, num_labels, device=device))

        self.relu = nn.ReLU()  # not used in forward; kept so module attributes are unchanged
        self.sigmoid = nn.Sigmoid()

    def _layer_distances(self, layer, feats):
        """Return the p_norm cdist between layer `layer`'s two projections of `feats`."""
        q1 = getattr(self, f"S_to_Q1_L{layer}")(feats)
        q2 = getattr(self, f"S_to_Q2_L{layer}")(feats)
        return torch.cdist(q1, q2, p=self.p_norm)

    def forward(self, input_vecs):
        """Score every position of `input_vecs` for each of the `num_labels` outputs.

        Args:
            input_vecs: tensor of shape (max_len, w_dim), or a batch
                (B, max_len, w_dim).

        Returns:
            Tensor of shape (max_len, num_labels) — (B, max_len, num_labels)
            for batched input — holding the sum of the four pair heads.
        """
        # Swap only the last two axes (computed once instead of 16 times).
        # For 2-D input this is exactly `.T`; unlike `.T`, which reverses *all*
        # axes and is deprecated for non-2-D tensors, it also handles batches.
        feats = input_vecs.transpose(-2, -1)  # (..., w_dim, max_len)
        scores = None
        for first, second in self._LAYER_PAIRS:
            rel_mat = self._layer_distances(first, feats) - self._layer_distances(second, feats)
            head = getattr(self, f"S_to_F_L{first}{second}")
            pair_scores = head(self.sigmoid(input_vecs @ rel_mat))
            # Accumulate left to right, i.e. ((p12 + p34) + p56) + p78.
            scores = pair_scores if scores is None else scores + pair_scores
        return scores
   
    



ce_func = nn.CrossEntropyLoss()
mse_func = nn.MSELoss()  # NOTE(review): not used by the training/eval code shown here


# Build the model selected by `model_set`. Fail fast on an unknown name:
# otherwise `dyn1_model` is simply left undefined and the script dies with an
# unhelpful NameError on the optimizer line below.
if model_set == 'dyn1-L8-v4':
    dyn1_model = dyn1_L8_v4(max_len, q_dim, w_dim, p_norm=p_norm)
else:
    raise ValueError(f"Unknown model_set {model_set!r}; expected 'dyn1-L8-v4'")
dyn1_optim = torch.optim.Adam(dyn1_model.parameters(), lr=LR)


# Report the parameter count and the ratio of dataset tokens
# (data_size * max_len) to trainable parameters.
total_params = sum(p.numel() for p in dyn1_model.parameters())
print('Model Size:', total_params, data_size*max_len/total_params)


def eval_model(_model, _inputs, _labels):
    """Return the (start accuracy, end accuracy) of `_model` over a dataset.

    The model is applied to one sample at a time. For each sample, the argmax
    over positions of score column 0 is compared with the sample's start label,
    and the argmax of column 1 with its end label (the start_id / end_id of
    the answer span).

    Args:
        _model: callable mapping one sample ``_inputs[i]`` to a score tensor
            that is indexed as ``scores[:, 0]`` (start) and ``scores[:, 1]`` (end).
        _inputs: sequence/tensor of samples; ``len(_inputs)`` is the dataset size.
        _labels: per-sample labels with ``_labels[i][0]`` the start position
            and ``_labels[i][1]`` the end position.

    Returns:
        Tuple ``(start_acc, end_acc)`` of floats in [0, 1].

    Raises:
        ValueError: if `_inputs` is empty (accuracy is undefined) or if
            `_inputs` and `_labels` have different lengths.
    """
    n_samples = len(_inputs)
    if n_samples == 0:
        raise ValueError("eval_model() needs at least one sample")
    if len(_labels) != n_samples:
        raise ValueError(f"got {n_samples} inputs but {len(_labels)} labels")

    start_count, end_count = 0, 0
    with torch.no_grad():
        for sample, label in zip(_inputs, _labels):
            res_scores = _model(sample)
            if torch.argmax(res_scores[:, 0]) == label[0]:
                start_count += 1
            if torch.argmax(res_scores[:, 1]) == label[1]:
                end_count += 1

    return start_count / n_samples, end_count / n_samples

# Accuracy of the untrained model, bracketed by wall-clock timestamps.
print(datetime.now().time())
start_count, end_count = eval_model(dyn1_model, input_dataset, label_dataset)
print(datetime.now().time(), start_count, end_count)


# Best start / end accuracy observed over all epochs so far.
max_start, max_end = 0, 0
print(datetime.now().time())
for _ep in range(500):
    # One optimizer step per training sample; cur_loss sums the epoch's losses.
    cur_loss = 0
    for _id in range(data_size):
        dyn1_optim.zero_grad()
        res_scores = dyn1_model(input_dataset[_id])
        # Column 0 of the scores is trained toward the start label, column 1
        # toward the end label.
        loss = (ce_func(res_scores[:, 0], label_dataset[_id][0])
                + ce_func(res_scores[:, 1], label_dataset[_id][1]))
        loss.backward()
        dyn1_optim.step()
        cur_loss += loss.item()

    start_count, end_count = eval_model(dyn1_model, input_dataset, label_dataset)
    max_start = max(max_start, start_count)
    max_end = max(max_end, end_count)

    if _ep % 5 == 0:
        print(datetime.now().time(), _ep, cur_loss, start_count, end_count, max_start, max_end)


