import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from models.model import Model
from tqdm import tqdm
import numpy as np
import torch
from evaluator import Evaluator
from torch.nn import CrossEntropyLoss
from timeit import default_timer as timer


class Trainer(object):
    """Train ``model`` on ``data_loader['train']`` and checkpoint on best kappa.

    After every epoch the model is scored on ``data_loader['test']`` through
    ``Evaluator.get_accuracy``; whenever the returned kappa improves on the best
    seen so far, the weights are saved under ``params.model_dir``.

    Args:
        params: Configuration object. Attributes read here: ``label_smoothing``,
            ``optimizer`` ('AdamW' selects AdamW, anything else selects SGD with
            momentum 0.9), ``lr``, ``weight_decay``, ``epochs``, ``clip_value``
            and ``model_dir``. Optional: ``mixup`` (default False) and
            ``mixup_alpha`` (default 1), read with ``getattr``.
        data_loader: Mapping with ``'train'`` and ``'test'`` loaders; the train
            loader yields ``(x, y)`` batches.
        model: Module to train; it is moved to CUDA.
    """

    def __init__(self, params, data_loader, model):
        self.params = params
        self.data_loader = data_loader

        self.evaluator = Evaluator(params, self.data_loader['test'])

        self.model = model.cuda()
        self.criterion = CrossEntropyLoss(label_smoothing=self.params.label_smoothing).cuda()

        if self.params.optimizer == 'AdamW':
            self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.params.lr,
                                               weight_decay=self.params.weight_decay)
        else:
            self.optimizer = torch.optim.SGD(self.model.parameters(), lr=self.params.lr, momentum=0.9,
                                             weight_decay=self.params.weight_decay)

        self.data_length = len(self.data_loader['train'])
        # The scheduler is stepped once per batch, so T_max counts total iterations.
        self.optimizer_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, T_max=self.params.epochs * self.data_length, eta_min=1e-6
        )
        print(self.model)

    def train(self):
        """Run ``params.epochs`` epochs and return the best evaluation results.

        Returns:
            tuple: ``(np.array([acc, kappa, f1]), cm)`` from the epoch with the
            highest kappa. If no epoch reaches a kappa above 0, nothing is saved
            and ``(np.array([0, 0, 0]), None)`` is returned.
        """
        f1_best = 0
        kappa_best = 0
        acc_best = 0
        cm_best = None
        best_epoch = None
        model_path = None
        # Mixup stays disabled unless the config explicitly turns it on.
        mixup = getattr(self.params, 'mixup', False)
        alpha = getattr(self.params, 'mixup_alpha', 1)

        for epoch in range(self.params.epochs):
            start_time = timer()
            mean_loss = self._train_one_epoch(mixup, alpha)
            # Read the current LR directly instead of building optimizer.state_dict().
            current_lr = self.optimizer.param_groups[0]['lr']

            # Disable dropout / use running batch-norm stats for evaluation;
            # harmless if the evaluator also switches modes. The next epoch's
            # _train_one_epoch call restores train mode.
            self.model.eval()
            with torch.no_grad():
                acc, kappa, f1, cm = self.evaluator.get_accuracy(self.model)
            print(
                "Epoch {} : Training Loss: {:.5f}, acc: {:.5f}, kappa: {:.5f}, f1: {:.5f}, LR: {:.5f}, Time elapsed {:.2f} mins".format(
                    epoch + 1,
                    mean_loss,
                    acc,
                    kappa,
                    f1,
                    current_lr,
                    (timer() - start_time) / 60
                )
            )
            print(cm)
            if kappa > kappa_best:
                print("kappa increasing....saving weights !!")
                best_epoch = epoch + 1
                acc_best = acc
                kappa_best = kappa
                f1_best = f1
                cm_best = cm
                model_path = self.params.model_dir + "/epoch{}_acc_{:.5f}_kappa_{:.5f}_f1_{:.5f}.pth".format(epoch + 1, acc, kappa, f1)
                torch.save(self.model.state_dict(), model_path)
                print("model save in " + model_path)

        # Summary: previously printed f1 under a "kappa" label and crashed
        # (unbound names) when no checkpoint had ever been written.
        if best_epoch is not None:
            print("{} epoch get the best kappa {:.5f}".format(best_epoch, kappa_best))
            print("the model is save in " + model_path)
        elif self.params.epochs > 0:
            print("no epoch reached a kappa above 0; no model was saved")

        evaluation_best = np.array([acc_best, kappa_best, f1_best])
        return evaluation_best, cm_best

    def _train_one_epoch(self, mixup, alpha):
        """Run one optimization pass over the train loader; return the mean batch loss."""
        self.model.train()
        losses = []
        for x, y in tqdm(self.data_loader['train'], mininterval=10):
            self.optimizer.zero_grad()
            x = x.cuda()
            y = y.cuda()
            if mixup:
                loss = self._mixup_loss(x, y, alpha)
            else:
                pred = self.model(x)
                loss = self.criterion(pred, y)

            loss.backward()
            losses.append(loss.item())
            if self.params.clip_value > 0:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.params.clip_value)
            self.optimizer.step()
            # Per-batch scheduler step (see T_max in __init__).
            self.optimizer_scheduler.step()
        # Avoid np.mean([]) (RuntimeWarning) when the train loader is empty.
        return float(np.mean(losses)) if losses else float('nan')

    def _mixup_loss(self, x, y, alpha):
        """Return the mixup loss: the model sees a convex blend of the batch with a shuffled copy."""
        lam = float(np.random.beta(alpha, alpha))
        # Build the permutation on the batch's device so indexing stays on-device.
        index = torch.randperm(x.size(0), device=x.device)
        mixed_x = lam * x + (1 - lam) * x[index]
        pred = self.model(mixed_x)
        return lam * self.criterion(pred, y) + (1 - lam) * self.criterion(pred, y[index])
