import torch.nn as nn
import torch
import numpy as np

from .utility import *
from .metric import *
from tqdm import tqdm

class MASH_Trainer:
    """Train and evaluate a network on one data fold.

    The trainer optimizes cross-entropy plus a penalty on the network's
    ``scales`` (see :meth:`t_loss`), writes a checkpoint under
    ``./checkpoints/<args.model>/`` and reloads it for testing.

    Each batch produced by the loaders is unpacked as
    ``(A, X, Y, EIGVEC, EIGVAL)``; ``A`` is not used by this class.
    ``valid_loader`` is stored but not used by any method here.
    """

    def __init__(self, args, network, optimizer, train_loader, valid_loader, test_loader):
        """Store the collaborators and derive the checkpoint directory.

        Args:
            args: Run configuration; this class reads ``model``, ``epoch``,
                ``data``, ``feat`` and ``lab`` from it.
            network: Model exposing ``n_scale``, ``scales``, ``n_hyper``,
                ``K`` and ``d_hyper``; called as ``network(X, EIGVEC, EIGVAL)``.
            optimizer: Optimizer stepping the network's parameters.
            train_loader: Iterable of training batches.
            valid_loader: Iterable of validation batches (currently unused).
            test_loader: Iterable of test batches.
        """
        self.args = args
        self.network = network
        self.optimizer = optimizer
        self.criterion = nn.CrossEntropyLoss()
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.test_loader = test_loader
        self.checkpoint_path = "./checkpoints/" + str(args.model) + "/"

    def _checkpoint_file(self, i):
        """Return the checkpoint file path for fold ``i``."""
        parts = (
            self.args.data,
            self.args.feat,
            self.args.lab,
            i,
            self.network.n_scale,
            self.network.n_hyper,
            self.network.K,
            self.network.d_hyper,
        )
        return self.checkpoint_path + "_".join(str(p) for p in parts) + "_best.pth"

    def t_loss(self):
        """Return the penalty on the network's scales.

        For each of the first ``n_scale`` entries of ``network.scales``,
        elements below ``1e-2`` (including all negative ones) contribute
        their absolute value and every other element contributes zero.
        The per-scale sums are averaged and multiplied by ``alpha``.

        Returns:
            A scalar tensor holding the penalty.
        """
        alpha = 1
        penalties = []
        for k in range(self.network.n_scale):
            scale = self.network.scales[k]
            masked = torch.where(scale < 1e-2, torch.abs(scale), torch.zeros_like(scale))
            penalties.append(torch.sum(masked))  # λ|s|

        return alpha * torch.mean(torch.stack(penalties))

    def train(self, i):
        """Train the network for ``args.epoch`` epochs and checkpoint it.

        The checkpoint is written once at the end of every epoch, so the
        file always holds the most recent weights. No validation-based
        model selection is performed, despite the ``best`` file suffix.

        Args:
            i: Fold identifier; shown in the progress bar and embedded in
                the checkpoint file name.
        """
        import os

        # torch.save does not create missing parent directories.
        os.makedirs(self.checkpoint_path, exist_ok=True)
        checkpoint_file = self._checkpoint_file(i)

        pbar = tqdm(range(1, self.args.epoch + 1))
        pbar.set_description(f'FOLD-{i}')
        for epoch in pbar:
            self.network.train()

            for A, X, Y, EIGVEC, EIGVAL in self.train_loader:
                self.optimizer.zero_grad()  # clear gradients from the previous step
                Y_hat = self.network(X, EIGVEC, EIGVAL)  # (sample, label)
                loss_train = self.criterion(Y_hat, Y) + self.t_loss()
                acc_train = macro_accuracy(Y.detach().cpu(), Y_hat.max(1)[1].detach().cpu())
                loss_train.backward()
                self.optimizer.step()

                pbar.set_postfix({
                    'loss': loss_train.item(),
                    'acc': acc_train.item(),
                    's': torch.stack(list(self.network.scales)).detach().cpu().tolist(),
                })

            # Saving per epoch instead of per batch avoids rewriting the file
            # after every optimizer step while leaving the same final weights.
            torch.save(self.network.state_dict(), checkpoint_file)

    def test(self, i):
        """Reload the checkpoint for fold ``i`` and evaluate on the test set.

        Args:
            i: Fold identifier used to locate the checkpoint file.

        Returns:
            A 5-tuple of the per-batch means of the metrics returned by
            ``classification_metrics``, in the order: accuracy, precision,
            specificity, sensitivity (printed as recall), F1.
        """
        tac, tpr, tsp, tse, f1s = [[] for _ in range(5)]

        # Load onto CPU first so a checkpoint saved on GPU also works on a
        # CPU-only host; load_state_dict copies values onto the parameters'
        # own device.
        state_dict = torch.load(self._checkpoint_file(i), map_location="cpu")
        self.network.load_state_dict(state_dict)

        self.network.eval()
        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            for A, X, Y, EIGVEC, EIGVAL in self.test_loader:
                Y_hat = self.network(X, EIGVEC, EIGVAL)
                ac, pr, sp, se, f1 = classification_metrics(Y_hat, Y)

                print(f"Acc: {ac:.3f} Pre: {pr:.3f} Rec: {se:.3f} Spe: {sp:.3f} F1s: {f1:.3f}")

                tac.append(ac.item())
                tpr.append(pr.item())
                tsp.append(sp.item())
                tse.append(se.item())
                f1s.append(f1.item())

        return np.mean(tac), np.mean(tpr), np.mean(tsp), np.mean(tse), np.mean(f1s)