import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from data_loader import data_loaders
from lls import LearnableLabelSmoothing
from models import models_small, models_big
from sklearn.metrics import accuracy_score, confusion_matrix


class Solver(object):
    """Train and evaluate a classifier, optionally with learnable label smoothing.

    All configuration is read from ``args``. Attributes referenced by this
    class: ``image_size``, ``model``, ``n_classes``, ``n_channels``,
    ``method``, ``margin``, ``lr``, ``momentum``, ``weight_decay``,
    ``warmup``, ``lr_drop_epochs``, ``lr_drop``, ``epochs``, ``model_path``,
    ``model_name``, ``q_matrix`` and ``cm``.
    """

    def __init__(self, args):
        """Build the data loaders, the network (moved to CUDA) and the loss.

        Args:
            args: configuration namespace (see the class docstring).
        """
        self.args = args
        self.train_loader, self.test_loader = data_loaders(self.args)

        # Inputs below 224 px use the ``models_small`` family, which also
        # takes the number of input channels; larger inputs use ``models_big``.
        if self.args.image_size < 224:
            self.net = getattr(models_small, self.args.model)(n_classes=self.args.n_classes, n_channels=self.args.n_channels).cuda()
        else:
            self.net = getattr(models_big, self.args.model)(n_classes=self.args.n_classes).cuda()

        # Plain cross-entropy, used only to report evaluation loss; it is
        # independent of the training objective chosen below.
        self.ce_loss = nn.CrossEntropyLoss()

        # Training objective: learnable label smoothing or plain cross-entropy.
        if self.args.method == 'lls':
            self.loss_fn = LearnableLabelSmoothing(self.args.n_classes, self.args.margin).cuda()
        else:
            self.loss_fn = nn.CrossEntropyLoss()

    def train(self):
        """Run the full training loop for ``args.epochs`` epochs.

        Each epoch advances the LR schedulers and trains on ``train_loader``.
        It then saves the network weights (plus the loss module's state when
        ``args.q_matrix`` is set) and evaluates on the test set. Train-set
        metrics are additionally reported every 25th epoch.
        """
        iter_per_epoch = len(self.train_loader)
        print(f"Iters per epoch: {iter_per_epoch}")

        # Only network parameters that are not frozen are optimised. The loss
        # module's parameters are optimised too, so a learnable loss (LLS) is
        # trained jointly with the network.
        trainable_params = [p for p in self.net.parameters() if p.requires_grad]
        optimizer = torch.optim.SGD(trainable_params + list(self.loss_fn.parameters()), self.args.lr, momentum=self.args.momentum, weight_decay=self.args.weight_decay)

        # Treat warmup <= 0 as "no warmup". Clamping to 1 keeps start_factor
        # finite (1/0 would raise) and total_iters non-negative; any warmup
        # >= 1 behaves exactly as configured.
        warmup = max(1, self.args.warmup)

        # schedulers for linear warmup of lr and then decay; both act on the
        # same optimizer and are stepped once per epoch in the loop below
        linear_warmup = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1/warmup, end_factor=1.0, total_iters=warmup-1)
        step_decay = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=self.args.lr_drop_epochs, gamma=self.args.lr_drop)

        best_acc = 0
        for epoch in range(self.args.epochs):
            # Evaluation switches the net to eval mode, so restore train mode
            # at the start of every epoch.
            self.net.train()

            # Epoch 0 runs with the LRs set at scheduler construction; later
            # epochs advance the decay schedule, and the warmup schedule only
            # while still inside the warmup window.
            if epoch > 0:
                step_decay.step()
            if 0 < epoch < warmup:
                linear_warmup.step()
            print(f"\nEp:[{epoch + 1}/{self.args.epochs}]\tlr:{optimizer.param_groups[0]['lr']:.6f}")

            for i, (x, y) in enumerate(self.train_loader):
                x, y = x.cuda(), y.cuda()

                logits = self.net(x)
                loss = self.loss_fn(logits, y)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                if i % 50 == 0 or i == (iter_per_epoch - 1):
                    print(f'It: {i + 1}/{iter_per_epoch}\tloss:{loss.item():.4f}')

            # Checkpoint the latest weights every epoch (overwritten in place).
            torch.save(self.net.state_dict(), os.path.join(self.args.model_path, self.args.model_name))
            if self.args.q_matrix:
                torch.save(self.loss_fn.state_dict(), os.path.join(self.args.model_path, "q_matrix.pt"))

            test_acc = self.test(train=(epoch + 1) % 25 == 0)
            best_acc = max(best_acc, test_acc)
            print(f"Best test acc: {best_acc:.2%}")

    def compute_test_metric(self, loader):
        """Evaluate the network on every batch of ``loader``.

        Args:
            loader: iterable of ``(x, y)`` batches; ``x`` is moved to CUDA and
                ``y`` is kept on the CPU.

        Returns:
            ``(acc, loss, cm)``: sklearn accuracy of the argmax predictions;
            ``self.ce_loss`` applied to all collected logits (a tensor); and
            the confusion matrix over labels ``0 .. n_classes - 1``.
        """
        self.net.eval()

        actual = []
        all_logits = []

        for data in loader:
            x, y = data
            x = x.cuda()

            with torch.no_grad():
                logits = self.net(x)

            actual.append(y)
            # Move logits back to the CPU so they can be concatenated with
            # the CPU labels and passed to sklearn.
            all_logits.append(logits.cpu())

        actual = torch.cat(actual)
        all_logits = torch.cat(all_logits)
        # Predicted class = index of the largest logit along the last dim.
        predictions = torch.max(all_logits, dim=-1)[1]

        acc = accuracy_score(y_true=actual, y_pred=predictions)
        loss = self.ce_loss(all_logits, actual)
        # Fix the label set so the matrix is n_classes x n_classes even when
        # some classes never appear in this split.
        cm = confusion_matrix(y_true=actual, y_pred=predictions, labels=range(self.args.n_classes))

        return acc, loss, cm

    def test(self, train=False):
        """Report test-set (and optionally train-set) metrics.

        Args:
            train: if True, also evaluate on ``train_loader`` first.

        Returns:
            Test-set accuracy as returned by ``compute_test_metric``.
        """
        if train:
            acc, loss, cm = self.compute_test_metric(self.train_loader)
            print(f"Train Accuracy: {acc:.2%}\tLoss: {loss:.2f}")
            if self.args.cm:
                print(cm)

        acc, loss, cm = self.compute_test_metric(self.test_loader)
        print(f"Test Accuracy: {acc:.2%}\tLoss: {loss:.2f}")
        if self.args.cm:
            print(cm)

        # Only loss modules that expose ``q_matrix`` (the LLS loss) can be
        # printed. Plain cross-entropy has no such attribute and would raise
        # AttributeError here.
        if self.args.q_matrix and self.args.cm and hasattr(self.loss_fn, 'q_matrix'):
            q_matrix = F.softmax(self.loss_fn.q_matrix, -1)
            print("\nQ-Matrix:")
            print(q_matrix.detach().cpu().numpy().round(2))

        return acc
