"""Implements training pipeline for unimodal comparison."""
from sklearn.metrics import accuracy_score, f1_score
import torch
from torch import nn
from tqdm import tqdm
from .utils.performance import eval_affect, eval_regression
from .utils import dict_to_str
# Module-level Softmax layer. It is not referenced by train() or test() below.
# NOTE(review): constructed without an explicit `dim` -- confirm that any
# external user relies on the implicit dimension choice before changing it.
softmax = nn.Softmax()


def train(encoder, head, train_dataloader, valid_dataloader, total_epochs, early_stop=False, patience=7, optimtype=torch.optim.RMSprop, lr=0.001, weight_decay=0.0, clip_value=1.0, criterion=nn.CrossEntropyLoss(), save_encoder='encoder.pt', save_head='head.pt', task='classification', device='cuda'):
    """Train a unimodal encoder + head and checkpoint the best validation epoch.

    Each batch yielded by the dataloaders is expected to be a dict whose last
    key holds the input tensor and whose second-to-last key maps to a dict
    containing the labels under ``'M'``.

    Args:
        encoder (nn.Module): Unimodal encoder for the modality.
        head (nn.Module): Takes in the unimodal encoder output and produces the final prediction.
        train_dataloader (torch.utils.data.DataLoader): Training data dataloader.
        valid_dataloader (torch.utils.data.DataLoader): Validation set dataloader.
        total_epochs (int): Total number of epochs.
        early_stop (bool, optional): Whether to apply early-stopping or not. Defaults to False.
        patience (int, optional): With ``early_stop`` enabled, training stops once the tracked
            validation metric has failed to improve for more than this many consecutive
            epochs. Defaults to 7.
        optimtype (torch.optim.Optimizer, optional): Type of optimizer to use. Defaults to torch.optim.RMSprop.
        lr (float, optional): Learning rate. Defaults to 0.001.
        weight_decay (float, optional): Weight decay of optimizer. Defaults to 0.0.
        clip_value (float, optional): Maximum gradient norm passed to ``clip_grad_norm_``. Defaults to 1.0.
        criterion (nn.Module, optional): Loss module returning a scalar loss. Defaults to nn.CrossEntropyLoss().
        save_encoder (str, optional): Path of file to save the encoder with best validation performance. Defaults to 'encoder.pt'.
        save_head (str, optional): Path of file to save the head with best validation performance. Defaults to 'head.pt'.
        task (str, optional): Type of task to try. Supports "classification" (tracks accuracy),
            "multilabel" (tracks macro F1) or "regression" (tracks validation loss). For any
            other value no metric is tracked and nothing is saved. Defaults to 'classification'.
        device (str): Device type to use ('cuda' or 'cpu').

    Returns:
        None. The best encoder and head are written with ``torch.save`` as whole modules.
    """
    model = nn.Sequential(encoder, head)
    device = torch.device(device)

    model = model.to(device)
    criterion = criterion.to(device)
    # BCEWithLogitsLoss needs floating-point targets; other criteria receive
    # the labels with the dtype they were stored in.
    float_targets = isinstance(criterion, nn.BCEWithLogitsLoss)

    def forward_loss(batch):
        """Run the model on one batch; return (output, raw labels, loss)."""
        keys = list(batch.keys())
        labels = batch[keys[-2]]['M']
        out = model(batch[keys[-1]].float().to(device))
        target = labels.float() if float_targets else labels
        return out, labels, criterion(out, target.to(device))

    op = optimtype(model.parameters(), lr=lr, weight_decay=weight_decay)
    bestvalloss = 10000
    bestacc = 0
    bestf1 = 0
    patience_count = 0
    for epoch in range(total_epochs):
        totalloss = 0.0
        totals = 0
        for batch in train_dataloader:
            op.zero_grad()
            _, labels, loss = forward_loss(batch)
            # .item() keeps the running sum a plain float so the autograd
            # graph of every batch is not kept alive for the whole epoch.
            totalloss += loss.item() * len(labels)
            totals += len(labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip_value)
            op.step()
        print("Epoch "+str(epoch)+" train loss: "+str(totalloss/totals))

        # Validate in eval mode (dropout / batch-norm use inference behaviour),
        # then restore every submodule's previous mode so that modules the
        # caller deliberately froze in eval mode stay that way.
        prior_modes = [(module, module.training) for module in model.modules()]
        model.eval()
        try:
            with torch.no_grad():
                totalloss = 0.0
                pred = []
                true = []
                for batch in valid_dataloader:
                    out, labels, loss = forward_loss(batch)
                    totalloss += loss.item() * len(labels)
                    if task == "classification":
                        pred.append(torch.argmax(out, 1))
                    elif task == "multilabel":
                        pred.append(torch.sigmoid(out).round())
                    true.append(labels)
        finally:
            # modules() yields parents before children, so each child's own
            # flag is re-applied after its parent's recursive train() call.
            for module, was_training in prior_modes:
                module.train(was_training)

        if pred:
            pred = torch.cat(pred, 0).cpu().numpy()
        true = torch.cat(true, 0).cpu().numpy()
        totals = true.shape[0]
        valloss = totalloss/totals

        # `improved` stays None for unrecognised tasks: nothing is tracked,
        # nothing is saved and the patience counter is left untouched.
        improved = None
        if task == "classification":
            acc = accuracy_score(true, pred)
            print("Epoch "+str(epoch)+" valid loss: "+str(valloss) +
                  " acc: "+str(acc))
            improved = acc > bestacc
            if improved:
                bestacc = acc
        elif task == "multilabel":
            f1_micro = f1_score(true, pred, average="micro")
            f1_macro = f1_score(true, pred, average="macro")
            print("Epoch "+str(epoch)+" valid loss: "+str(valloss) +
                  " f1_micro: "+str(f1_micro)+" f1_macro: "+str(f1_macro))
            improved = f1_macro > bestf1
            if improved:
                bestf1 = f1_macro
        elif task == "regression":
            print("Epoch "+str(epoch)+" valid loss: "+str(valloss))
            improved = valloss < bestvalloss
            if improved:
                bestvalloss = valloss

        if improved is not None:
            if improved:
                patience_count = 0
                print("Saving Best")
                torch.save(encoder, save_encoder)
                torch.save(head, save_head)
            else:
                patience_count += 1
        if early_stop and patience_count > patience:
            break


def test(encoder, head, test_dataloader, task='classification', criterion=None, device='cuda', dataset='mosi'):
    """Test unimodal model on one dataloader.

    Each batch is expected to be a dict whose last key holds the input tensor
    and whose second-to-last key maps to a dict containing the labels under
    ``'M'``.

    Args:
        encoder (nn.Module): Unimodal encoder module
        head (nn.Module): Module which takes in encoded unimodal input and predicts output.
        test_dataloader (torch.utils.data.DataLoader): Data Loader for test set.
        task (str, optional): Type of task to try. Supports "classification", "multilabel" and
            "posneg-classification"; any other value is evaluated as regression.
            Defaults to 'classification'.
        criterion (nn.Module, optional): Loss module returning a scalar loss. When given, the
            sample-weighted mean loss is printed. Defaults to None.
        device (str): Device type to use ('cuda' or 'cpu').
        dataset (str, optional): Not used by this function; kept so existing callers
            keep working. Defaults to 'mosi'.
    Returns:
        dict: Dictionary of (metric, value) relations. ``{'Accuracy': ...}`` for
        "classification" and "posneg-classification", micro/macro F1 for "multilabel",
        and the ``eval_regression`` results otherwise (plus ``'MAE'``, the mean
        criterion loss, when ``criterion`` is given).
    """
    model = nn.Sequential(encoder, head)
    device = torch.device(device)
    # Inputs are moved to `device` below, so the model has to live there too.
    model = model.to(device)
    if isinstance(criterion, nn.Module):
        criterion = criterion.to(device)
    # BCEWithLogitsLoss needs floating-point targets (same rule as in train()).
    float_targets = isinstance(criterion, nn.BCEWithLogitsLoss)

    # Evaluate in eval mode, then restore every submodule's previous mode.
    prior_modes = [(module, module.training) for module in model.modules()]
    model.eval()
    pred = []
    true = []
    totalloss = 0.0
    try:
        with torch.no_grad():
            for batch in test_dataloader:
                keys = list(batch.keys())
                labels = batch[keys[-2]]['M']
                out = model(batch[keys[-1]].float().to(device))
                if criterion is not None:
                    target = labels.float() if float_targets else labels
                    loss = criterion(out, target.to(device))
                    totalloss += loss.item() * len(labels)
                if task == "classification":
                    pred.append(torch.argmax(out, 1))
                elif task == "multilabel":
                    pred.append(torch.sigmoid(out).round())
                elif task == "posneg-classification":
                    # Sign of the first output of every sample: 1, -1 or 0.
                    signs = [
                        1 if row[0] > 0 else (-1 if row[0] < 0 else 0)
                        for row in out.cpu().tolist()
                    ]
                    pred.append(torch.LongTensor(signs))
                else:
                    pred.append(out)

                true.append(labels)
    finally:
        # modules() yields parents before children, so each child's own flag
        # is re-applied after its parent's recursive train() call.
        for module, was_training in prior_modes:
            module.train(was_training)

    if pred:
        pred = torch.cat(pred, 0).cpu().numpy()
    true = torch.cat(true, 0).cpu().numpy()
    totals = true.shape[0]
    avg_loss = None
    if criterion is not None:
        avg_loss = totalloss / totals
        print("loss: " + str(avg_loss))

    if task == "classification":
        accuracy = accuracy_score(true, pred)
        f1_micro = f1_score(true, pred, average="micro")
        f1_macro = f1_score(true, pred, average="macro")
        print("acc: "+str(accuracy))
        print("f1_micro: " + str(f1_micro))
        print("f1_macro: " + str(f1_macro))
        return {'Accuracy': accuracy}
    elif task == "multilabel":
        f1_micro = f1_score(true, pred, average="micro")
        f1_macro = f1_score(true, pred, average="macro")
        print(" f1_micro: "+str(f1_micro) +
              " f1_macro: "+str(f1_macro))
        return {'F1 score (micro)': f1_micro, 'F1 score (macro)': f1_macro}
    elif task == "posneg-classification":
        trueposneg = true
        print(trueposneg)
        print(pred)
        accs = eval_affect(trueposneg, pred)  # ignore neutral samples
        acc2 = eval_affect(trueposneg, pred, exclude_zero=False)
        # The whole message must be built inside print(); concatenating onto
        # print()'s None return value raised TypeError.
        print("acc: " + str(accs) + ', ' + str(acc2))
        return {'Accuracy': accs}
    else:
        test_results = eval_regression(true, pred)
        print(f"Regression test results: {dict_to_str(test_results)}")
        if avg_loss is None:
            # No criterion was supplied, so there is no loss to report as 'MAE'.
            return {**test_results}
        return {'MAE': avg_loss, **test_results}
