import copy

import torch
import numpy as np
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score, accuracy_score
from sklearn.metrics import mean_squared_error
from torch_scatter import scatter_add

from utils import Timer, Maximizer, Averager, Batcher
from model import KTModel

def train(model, seqs, qst_clr, batch_size, criterion, optimizer, device):
    """Run one training epoch over ``seqs`` and return the mean batch loss.

    Each batch is permuted so its last axis becomes three planes
    (question ids, responses, mask). The model sees the "current" steps
    (every step but the last) and predicts the "next" steps (every step but
    the first). Only next-step positions whose mask equals 1 contribute to
    ``criterion``.

    Args:
        model: model called as ``model(cur_clusters, cur_responses, nxt_clusters)``.
        seqs: training sequences, consumed by ``Batcher``.
        qst_clr: tensor mapping question id -> cluster, indexed by question ids.
        batch_size: number of sequences per batch.
        criterion: loss applied to (masked predictions, masked float targets).
        optimizer: optimizer stepped once per batch.
        device: torch device the batches are moved to.

    Returns:
        Whatever ``Averager.get()`` reports for the per-batch loss values.
    """
    model.train()

    batches = Batcher(seqs, batch_size)
    running_loss = Averager()

    for raw in batches:
        planes = torch.tensor(raw, device = device).permute(2, 0, 1)
        qst, rst, msk = planes

        # Shift by one step: inputs are steps [0, T-1), targets are steps [1, T).
        src_qst, src_rst = qst[:, :-1], rst[:, :-1]
        tgt_qst, tgt_rst, tgt_msk = qst[:, 1:], rst[:, 1:], msk[:, 1:]

        out = model(qst_clr[src_qst], src_rst, qst_clr[tgt_qst])

        keep = tgt_msk == 1
        batch_loss = criterion(out[keep], tgt_rst[keep].float())
        running_loss.join(batch_loss.item())

        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

    return running_loss.get()

def evaluate(model, seqs, qst_clr, batch_size, device):
    """Score ``model`` on ``seqs`` and return ``(auc, acc, rmse)``.

    Each batch is split into "current" steps (all but the last), which are
    fed to the model, and "next" steps (all but the first), which are the
    targets. Only next-step positions whose mask equals 1 are scored.
    Predictions and targets from all batches are pooled before the metrics
    are computed.

    The model is left in eval mode on return.

    Args:
        model: model called as ``model(cur_clusters, cur_responses, nxt_clusters)``.
        seqs: sequences to evaluate, consumed by ``Batcher``.
        qst_clr: tensor mapping question id -> cluster, indexed by question ids.
        batch_size: number of sequences per batch.
        device: torch device the batches are moved to.

    Returns:
        Tuple ``(auc, acc, rmse)``:
            auc: ROC-AUC of the pooled predictions.
            acc: accuracy with predictions thresholded at 0.5.
            rmse: root mean squared error between predictions and targets.
    """
    model.eval()

    batcher = Batcher(seqs, batch_size)

    pred_list = []
    true_list = []

    # Inference only: no_grad stops autograd from recording (and holding
    # memory for) a computation graph for every batch.
    with torch.no_grad():
        for batch in batcher:
            batch = torch.tensor(batch, device = device)

            qst, rst, msk = batch.permute(2, 0, 1)

            cur_qst, cur_rst = qst[:, :-1], rst[:, :-1]
            nxt_qst, nxt_rst, nxt_msk = qst[:, 1:], rst[:, 1:], msk[:, 1:]

            pred = model(qst_clr[cur_qst], cur_rst, qst_clr[nxt_qst])

            valid = nxt_msk == 1
            pred_list.append(pred[valid].cpu().numpy())
            true_list.append(nxt_rst[valid].cpu().numpy())

    pred = np.concatenate(pred_list, axis = 0)
    true = np.concatenate(true_list, axis = 0)

    auc = roc_auc_score(true, pred)
    acc = accuracy_score(true, (pred > 0.5).astype(float))
    rmse = np.sqrt(mean_squared_error(true, pred))

    return auc, acc, rmse

def experiment(
    qst_num, clr_num, qst_clr, hidden_dim, dropout, wd, lr, batch_size, quit_epoch, trn_seqs, evl_seqs, device,
    evl_trn_seq = False, each_round_fn = None
):
    """Train a ``KTModel`` with early stopping on evaluation AUC.

    Every epoch trains once on ``trn_seqs`` and evaluates on ``evl_seqs``
    (and optionally on ``trn_seqs``). The weights from the epoch with the
    best evaluation AUC, as judged by ``Maximizer.join``, are snapshotted.
    Training ends once ``quit_epoch`` consecutive epochs pass without a new
    best. The snapshot is then loaded back into the model.

    Args:
        qst_num: number of questions (not used by this function).
        clr_num: number of clusters, passed to ``KTModel``.
        qst_clr: question -> cluster mapping, converted to a tensor on ``device``.
        hidden_dim: hidden size, passed to ``KTModel``.
        dropout: dropout rate, passed to ``KTModel``.
        wd: Adam weight decay.
        lr: Adam learning rate.
        batch_size: batch size for training and evaluation.
        quit_epoch: early-stopping patience in epochs.
        trn_seqs: training sequences.
        evl_seqs: evaluation sequences used for model selection.
        device: torch device.
        evl_trn_seq: if True, also evaluate on ``trn_seqs`` every epoch.
        each_round_fn: optional per-epoch callback. It is called as
            ``fn(epoch, loss, trn_res, evl_res, max_res, time)`` when
            ``evl_trn_seq`` is True, else as
            ``fn(epoch, loss, evl_res, max_res, time)``.

    Returns:
        ``(model, max_res)``: the model (with the best weights restored, if
        any epoch registered as best) and the ``(auc, acc, rmse)`` tuple of
        that best epoch, or ``None`` if no epoch did.
    """
    qst_clr = torch.tensor(qst_clr, device = device)

    model = KTModel(clr_num, hidden_dim, dropout).to(device)
    model_params = {'params': model.parameters(), 'weight_decay': wd}

    parameters = [model_params]
    optimizer = torch.optim.Adam(parameters, lr = lr)
    criterion = torch.nn.BCELoss()

    epoch = 1
    quit_count = 0

    auc_max = Maximizer()
    max_res = None
    # Stays None if no epoch ever registers as a new maximum; guarded below.
    best_model_params = None

    while quit_count <= quit_epoch:
        timer = Timer()
        loss = timer(train, model, trn_seqs, qst_clr, batch_size, criterion, optimizer, device)
        trn_res = None
        if evl_trn_seq:
            trn_res = timer(evaluate, model, trn_seqs, qst_clr, batch_size, device)
        evl_res = timer(evaluate, model, evl_seqs, qst_clr, batch_size, device)

        if auc_max.join(evl_res[0]):
            max_res = evl_res
            # state_dict() returns references to the live parameter tensors,
            # which later optimizer steps update in place; deep-copy so the
            # snapshot really holds this epoch's weights.
            best_model_params = copy.deepcopy(model.state_dict())
            quit_count = 0

        if each_round_fn is not None:
            if evl_trn_seq:
                each_round_fn(epoch, loss, trn_res, evl_res, max_res, timer.get())
            else:
                each_round_fn(epoch, loss, evl_res, max_res, timer.get())

        epoch += 1
        quit_count += 1

    if best_model_params is not None:
        model.load_state_dict(best_model_params)

    return model, max_res

def comp_entropy(model, seqs, qst_num, clr_num, qst_clr, batch_size, device, normalize = False):
    """Accumulate per-question, per-cluster cross-entropy of the model's predictions.

    For every batch (``Batcher`` with ``shuffle = False``):

    - ``model.get_pred`` runs on the "current" steps (all but the last).
    - Its output is reshaped to ``(-1, clr_num)`` and treated as one logit
      per cluster for the matching "next" step.
    - Each next step whose mask is not 0 adds
      ``-log sigmoid(sign(response - 0.5) * logit)`` to the row of that
      step's question id.

    The pass runs under ``torch.no_grad()`` with the model in eval mode, so
    train-only layers such as dropout (if any) do not randomise the result.
    The model's previous train/eval mode is restored afterwards.

    Args:
        model: module exposing ``get_pred(clusters, responses)``.
        seqs: sequences consumed by ``Batcher``.
        qst_num: number of questions (rows of the result).
        clr_num: number of clusters (columns of the result).
        qst_clr: question -> cluster mapping, converted to a tensor and
            indexed by question ids.
        batch_size: number of sequences per batch.
        device: torch device for all tensors.
        normalize: if True, divide each row by the number of unmasked
            next-step occurrences of that question (rows never seen stay 0).
            Default False returns the raw sums.

    Returns:
        ``(qst_num, clr_num)`` tensor of summed (or, with ``normalize``,
        averaged) losses.
    """
    entropy = torch.zeros((qst_num, clr_num), device = device)
    qst_count = torch.zeros((qst_num, ), device = device)

    qst_clr = torch.tensor(qst_clr, device = device)

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for batch in Batcher(seqs, batch_size, shuffle = False):
                batch = torch.tensor(batch, device = device)

                qst, rst, msk = batch.permute(2, 0, 1)

                cur_qst, cur_rst = qst[:, :-1], rst[:, :-1]
                nxt_qst, nxt_rst, nxt_msk = qst[:, 1:], rst[:, 1:], msk[:, 1:]

                pred = model.get_pred(qst_clr[cur_qst], cur_rst)

                qst_flat = nxt_qst.reshape(-1)
                rst_flat = nxt_rst.reshape(-1)
                pad = nxt_msk.reshape(-1) == 0
                pred = pred.reshape(-1, clr_num)

                # sign(rst - 0.5) is +1 for responses above 0.5 and -1 below,
                # so this is the binary cross-entropy of each cluster's logit.
                ety = -F.logsigmoid(pred * torch.sign(rst_flat - 0.5).unsqueeze(-1))
                ety[pad] = 0

                qst_one = torch.ones_like(ety[:, 0])
                qst_one[pad] = 0
                qst_count += scatter_add(qst_one, qst_flat, dim = 0, dim_size = qst_num)

                entropy += scatter_add(ety, qst_flat, dim = 0, dim_size = qst_num)
    finally:
        model.train(was_training)

    if normalize:
        # clamp avoids 0/0 for questions that never appear as a next step.
        entropy = entropy / qst_count.clamp(min = 1).unsqueeze(-1)

    return entropy

def sinkhorn(entropy, lamb, itr_num = 10, eps = 1e-10, row_marginal = None, col_marginal = None):
    """Sinkhorn-scale the kernel ``exp(-cost / lamb)`` into a transport plan.

    The kernel is ``K = exp(-C / lamb) / m`` for an ``(n, m)`` cost matrix
    ``C``. Scalings ``u`` and ``v`` are updated alternately ``itr_num``
    times, and the plan ``P = diag(u) K diag(v)`` is returned.

    Args:
        entropy: ``(n, m)`` cost tensor. Non-finite entries are sanitised:
            NaN -> 0, +inf -> 1e6, -inf -> -1e6.
        lamb: temperature; smaller values make the plan sharper.
        itr_num: number of (v, u) update rounds.
        eps: added to every denominator to avoid division by zero.
        row_marginal: optional length-``n`` target row sums (default: all 1).
        col_marginal: optional length-``m`` target column sums (default: all 1).

    Returns:
        ``(n, m)`` tensor in the dtype of the kernel. Because ``u`` is
        updated last, its rows sum to ``row_marginal`` (up to ``eps``) when
        ``itr_num >= 1``. Column sums only approach ``col_marginal`` as the
        iteration converges, which requires both marginals to have the same
        total mass.
    """
    # nan_to_num is out-of-place, so the caller's tensor is never modified.
    C = torch.nan_to_num(entropy, nan = 0.0, posinf = 1e6, neginf = -1e6)
    n, m = C.shape

    R = torch.ones_like(C) / m
    K = torch.exp(-C / lamb) * R

    # Build everything in K's dtype: matmul does not promote mixed dtypes,
    # so default-float32 scalings would break float64 costs.
    opts = {'dtype': K.dtype, 'device': K.device}
    a = 1.0 if row_marginal is None else torch.as_tensor(row_marginal, **opts).reshape(n, 1)
    b = 1.0 if col_marginal is None else torch.as_tensor(col_marginal, **opts).reshape(m, 1)

    u = torch.ones(n, 1, **opts)
    v = torch.ones(m, 1, **opts)

    for _ in range(itr_num):
        v = b / (K.T @ u + eps)
        u = a / (K @ v + eps)

    return K * u * v.T
    