import torch
import  torch.nn as nn
import time
from .utils.performance import f1_score, accuracy, eval_affect
from .utils import MetricsTop, dict_to_str


class MMDL(nn.Module):
    """Multimodal model chaining per-modality encoders, a fusion module and a head."""

    def __init__(self, encoders, fusion, head, has_padding=False):
        """Assemble the encoder -> fusion -> head pipeline.

        Args:
            encoders (List): nn.Module encoders, one per modality, in input order.
            fusion (nn.Module): Module called with the list of encoder outputs.
            head (nn.Module): Module called with the fused representation.
            has_padding (bool, optional): When True, ``forward`` indexes its
                argument as ``inputs[0][i]`` / ``inputs[1][i]`` and hands both
                to encoder ``i``. Defaults to False.
        """
        super().__init__()
        # Sub-modules are registered in the order encoders, fuse, head.
        self.encoders = nn.ModuleList(encoders)
        self.fuse = fusion
        self.head = head
        self.has_padding = has_padding
        # Both caches are overwritten by every forward pass.
        self.reps = []
        self.fuseout = None

    def forward(self, inputs):
        """Encode every modality, fuse the results and apply the head.

        Args:
            inputs: Without padding, a sequence holding one input per encoder.
                With padding, a pair whose first element holds the per-modality
                data and whose second element holds the per-modality companion
                values (presumably sequence lengths - confirm against callers).

        Returns:
            Whatever ``self.head`` returns.
        """
        if self.has_padding:
            encoded = [
                self.encoders[k]([inputs[0][k], inputs[1][k]])
                for k in range(len(inputs[0]))
            ]
        else:
            encoded = [self.encoders[k](inputs[k]) for k in range(len(inputs))]
        self.reps = encoded  # modality-specific representations

        # In padded mode an encoder may return a non-tensor container; only its
        # first element is fused, and the head then also receives inputs[1][0].
        structured = self.has_padding and not isinstance(encoded[0], torch.Tensor)

        fusion_input = [rep[0] for rep in encoded] if structured else encoded
        fused = self.fuse(fusion_input)
        self.fuseout = fused

        # A fusion module returning a plain tuple contributes only its first item.
        result = fused[0] if type(fused) is tuple else fused

        if structured:
            return self.head([result, inputs[1][0]])
        return self.head(result)


def deal_with_objective(objective, pred, truth, args):
    """Call ``objective`` with the target prepared the way that loss expects.

    Args:
        objective: Loss callable. ``nn.CrossEntropyLoss``, ``nn.MSELoss``,
            ``nn.BCEWithLogitsLoss`` and ``nn.L1Loss`` instances are handled
            specially; anything else is treated as a custom objective.
        pred (torch.Tensor): Model output.
        truth (torch.Tensor): Ground-truth labels/targets.
        args: Extra argument forwarded only to custom objectives.

    Returns:
        The value returned by ``objective``.
    """
    # Exact-type dispatch is kept from the original so that subclasses of the
    # built-in losses still take the custom three-argument path.
    kind = type(objective)

    if kind is nn.CrossEntropyLoss:
        # When truth has the same rank as pred, drop its last axis if that axis
        # is a singleton (squeeze is a no-op otherwise).
        target = truth.squeeze(-1) if truth.dim() == pred.dim() else truth
        # Follow the prediction's device rather than guessing from CUDA
        # availability, so the loss works wherever the model actually lives.
        return objective(pred, target.long().to(pred.device))

    if kind in (nn.MSELoss, nn.BCEWithLogitsLoss, nn.L1Loss):
        return objective(pred, truth.float().to(pred.device))

    # Custom objective: receives the untouched truth plus the argument dict.
    return objective(pred, truth, args)
    

def train(
        encoders, fusion, head, train_dataloader, valid_dataloader, total_epochs,
        early_stop=8, patience=7, task="classification", optimtype=torch.optim.RMSprop, lr=0.001, weight_decay=0.0,
        objective=nn.CrossEntropyLoss(), save='best.pt', validtime=False, objective_args_dict=None, input_to_float=True, clip_val=8, num_modal=3):
    """
    Handle running a simple supervised training loop.

    Every batch yielded by the dataloaders is accessed like a dict: its last
    three entries (two when ``num_modal`` is not 3) are fed to the model as the
    modality inputs, and the entry just before them is indexed with ``'M'`` to
    obtain the labels.

    :param encoders: list of modules, unimodal encoders for each input modality in the order of the modality input data.
    :param fusion: fusion module, takes in outputs of encoders in a list and outputs fused representation
    :param head: classification or prediction head, takes in output of fusion module and outputs the classification or prediction results that will be sent to the objective function for loss calculation
    :param train_dataloader: iterable of training batches (see layout above)
    :param valid_dataloader: iterable of validation batches (see layout above)
    :param total_epochs: maximum number of epochs to train
    :param early_stop: stop training once the validation metric has failed to improve for more than this many consecutive epochs; a falsy value disables early stopping
    :param patience: halve the learning rate (ReduceLROnPlateau) if validation loss does not improve for this many epochs
    :param task: type of task, currently support "classification","regression","multilabel"
    :param optimtype: type of optimizer to use
    :param lr: learning rate
    :param weight_decay: weight decay of optimizer
    :param objective: objective function, which is either one of CrossEntropyLoss, MSELoss, L1Loss or BCEWithLogitsLoss or a custom objective function that takes in three arguments: prediction, ground truth, and an argument dictionary.
    :param save: the name of the saved file for the model with current best validation performance
    :param validtime: whether to show valid time in seconds or not
    :param objective_args_dict: the argument dictionary to be passed into objective function. If not None, at every batch the dict's "reps", "fused", "inputs", "training" fields will be updated to the batch's encoder outputs, fusion module output, list of input tensors fed to the model, and boolean of whether this is training or validation, respectively. During training its "model" field is also set to the model.
    :param input_to_float: whether to convert input to float type or not
    :param clip_val: grad clipping limit
    :param num_modal: number of input modalities; 3 selects the last three batch entries, any other value the last two
    :return: None. The best model so far is written to ``save`` with ``torch.save``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = MMDL(encoders, fusion, head).to(device)

    op = optimtype(
        [p for p in model.parameters() if p.requires_grad],
        lr=lr, weight_decay=weight_decay)
    # NOTE(review): `verbose` is deprecated in newer PyTorch releases - confirm
    # against the pinned version before upgrading.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        op,
        mode='min',
        factor=0.5,
        patience=patience,
        verbose=True,
        min_lr=1e-5)

    bestvalloss = 10000
    bestacc = 0
    bestf1 = 0
    patience_count = 0

    # Anything other than exactly 3 modalities falls back to 2, as before.
    n_inputs = 3 if num_modal == 3 else 2

    def _processinput(inp):
        return inp.float() if input_to_float else inp

    def _split_batch(batch):
        """Return (model inputs on `device`, labels) for one dict-like batch."""
        keys = list(batch.keys())
        inputs = [_processinput(batch[k]).to(device) for k in keys[-n_inputs:]]
        labels = batch[keys[-(n_inputs + 1)]]['M']
        return inputs, labels

    def _update_objective_args(inputs, training):
        """Refresh the per-batch fields handed to a custom objective."""
        if objective_args_dict is None:
            return
        objective_args_dict['reps'] = model.reps
        objective_args_dict['fused'] = model.fuseout
        # The batch is dict-like, so it cannot be sliced; expose the tensors
        # that were actually fed to the model instead.
        objective_args_dict['inputs'] = inputs
        objective_args_dict['training'] = training
        if training:
            objective_args_dict['model'] = model

    for epoch in range(total_epochs):
        totalloss = 0.0
        totals = 0
        model.train()

        for j in train_dataloader:
            op.zero_grad()
            # Re-asserted per batch as in the original loop.
            model.train()
            inputs, labels = _split_batch(j)
            out = model(inputs)
            _update_objective_args(inputs, training=True)
            loss = deal_with_objective(
                objective, out, labels, objective_args_dict)

            # Accumulate a plain float: summing the loss tensor itself would
            # keep every batch's autograd history alive for the whole epoch.
            totalloss += loss.item() * len(labels)
            totals += len(labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip_val)
            op.step()
        print("Epoch "+str(epoch)+" train loss: "+str(totalloss/totals))

        validstarttime = time.time()
        if validtime:
            print("train total: "+str(totals))

        model.eval()
        totalloss = 0.0
        pred = []
        true = []
        with torch.no_grad():
            for j in valid_dataloader:
                inputs, labels = _split_batch(j)
                out = model(inputs)
                _update_objective_args(inputs, training=False)
                loss = deal_with_objective(
                    objective, out, labels, objective_args_dict)
                totalloss += loss.item() * len(labels)

                if task == "classification":
                    pred.append(torch.argmax(out, 1))
                elif task == "multilabel":
                    pred.append(torch.sigmoid(out).round())
                true.append(labels)

        # Regression collects no predictions, so `pred` may stay an empty list.
        if pred:
            pred = torch.cat(pred, 0)
        true = torch.cat(true, 0)
        totals = true.shape[0]
        valloss = totalloss/totals
        scheduler.step(valloss)

        # `improved` stays None for an unrecognised task, in which case neither
        # the checkpoint nor the patience counter is touched.
        improved = None
        if task == "classification":
            acc = accuracy(true, pred)
            print("Epoch "+str(epoch)+" valid loss: "+str(valloss) +
                  " acc: "+str(acc))
            improved = acc > bestacc
            if improved:
                bestacc = acc
        elif task == "multilabel":
            f1_micro = f1_score(true, pred, average="micro")
            f1_macro = f1_score(true, pred, average="macro")
            print("Epoch "+str(epoch)+" valid loss: "+str(valloss) +
                  " f1_micro: "+str(f1_micro)+" f1_macro: "+str(f1_macro))
            improved = f1_macro > bestf1
            if improved:
                bestf1 = f1_macro
        elif task == "regression":
            print("Epoch "+str(epoch)+" valid loss: "+str(valloss))
            improved = valloss < bestvalloss
            if improved:
                bestvalloss = valloss

        if improved is not None:
            if improved:
                patience_count = 0
                print("Saving Best")
                torch.save(model, save)
            else:
                patience_count += 1

        if early_stop and patience_count > early_stop:
            break
        validendtime = time.time()
        if validtime:
            print("valid time:  "+str(validendtime-validstarttime))
            print("Valid total: "+str(totals))


def test(model, test_dataloader, criterion=nn.CrossEntropyLoss(), task="classification", dataset="mosi", input_to_float=True, num_modal=3):
    """Run test for multimodal model.

    Every batch is accessed like a dict: its last three entries (two when
    ``num_modal`` is not 3) are fed to the model, and the entry just before
    them is indexed with ``'M'`` to obtain the labels.

    Args:
        model (nn.Module): Model to test
        test_dataloader (torch.utils.data.Dataloader): Test dataloader
        criterion (_type_, optional): Loss function. Defaults to nn.CrossEntropyLoss().
        task (str, optional): Task to evaluate. Choose between "classification", "multilabel", "regression", "posneg-classification". Defaults to "classification".
        dataset (str, optional): Dataset name passed to ``MetricsTop(task).getMetics``. Defaults to "mosi".
        input_to_float (bool, optional): Whether to convert inputs to float before processing. Defaults to True.
        num_modal (int, optional): Number of input modalities; 3 selects the last three batch entries, any other value the last two. Defaults to 3.

    Returns:
        dict: ``{'Accuracy': ...}`` for "classification" and
        "posneg-classification", ``{'micro': ..., 'macro': ...}`` for
        "multilabel", and ``{'Loss': ..., **metrics}`` for "regression".
        Any other task value falls through and returns None.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Anything other than exactly 3 modalities falls back to 2, as before.
    n_inputs = 3 if num_modal == 3 else 2

    def _processinput(inp):
        return inp.float() if input_to_float else inp

    metrics = MetricsTop(task).getMetics(dataset)

    loss_type = type(criterion)
    model.eval()
    with torch.no_grad():
        totalloss = 0.0
        pred = []
        true = []
        for j in test_dataloader:
            keys = list(j.keys())
            out = model([_processinput(j[k]).to(device) for k in keys[-n_inputs:]])
            labels = j[keys[-(n_inputs + 1)]]['M']

            # Targets follow the output's device so the loss never mixes devices.
            # L1Loss is cast to float like MSE/BCE, matching deal_with_objective.
            if loss_type in (nn.BCEWithLogitsLoss, nn.MSELoss, nn.L1Loss):
                loss = criterion(out, labels.float().to(out.device))
            elif loss_type is nn.CrossEntropyLoss:
                # Drop a trailing singleton label axis when ranks match.
                truth1 = labels.squeeze(-1) if labels.dim() == out.dim() else labels
                loss = criterion(out, truth1.long().to(out.device))
            else:
                loss = criterion(out, labels.to(out.device))
            totalloss += loss*len(labels)

            if task == "classification":
                pred.append(torch.argmax(out, 1))
            elif task == "multilabel":
                pred.append(torch.sigmoid(out).round())
            elif task == "posneg-classification":
                # Sign of the first output column: 1, -1, or 0 when it is
                # neither positive nor negative.
                prede = [
                    1 if row[0] > 0 else (-1 if row[0] < 0 else 0)
                    for row in out.tolist()
                ]
                pred.append(torch.LongTensor(prede))
            else:
                pred.append(out)

            true.append(labels)

        if pred:
            pred = torch.cat(pred, 0)
        true = torch.cat(true, 0)
        totals = true.shape[0]
        testloss = totalloss/totals

        if task == "classification":
            acc = accuracy(true, pred)
            print("acc: "+str(acc))
            return {'Accuracy': acc}
        elif task == "multilabel":
            f1_micro = f1_score(true, pred, average="micro")
            f1_macro = f1_score(true, pred, average="macro")
            print(" f1_micro: "+str(f1_micro) +
                  " f1_macro: "+str(f1_macro))
            return {'micro': f1_micro, 'macro': f1_macro}
        elif task == "regression":
            test_results = metrics(true, pred)
            print(f"Regression test results: {dict_to_str(test_results)}")
            return {'Loss': testloss.item(), **test_results}
        elif task == "posneg-classification":
            accs = eval_affect(true, pred)
            acc2 = eval_affect(true, pred, exclude_zero=False)
            print("acc: "+str(accs) + ', ' + str(acc2))
            return {'Accuracy': accs}
