import subprocess as sp
import time
from os.path import join

import torch
import torch.nn.functional as F
from torch.nn import DataParallel
from torch.utils.data import DataLoader
from tqdm import tqdm


class Trainer:
    """Supervised classification training loop with per-epoch evaluation and checkpointing.

    Hyper-parameters are read from ``config``: ``train.batch_size``,
    ``train.num_workers``, ``train.epochs``, ``train.print_every``,
    ``train.save_every_epoch``, ``general.save_model_dir`` and
    ``general.logger.frequency``. The model is moved to one GPU, wrapped in
    ``DataParallel`` for several GPUs, or left on the CPU when none are visible.
    """

    def __init__(
        self,
        config,
        model,
        logger,
        train_set,
        test_set,
        criterion,
        optimizer,
        scheduler=None,
        val_set=None,
    ):
        self.model = model
        self.logger = logger
        self.train_set = train_set
        self.test_set = test_set
        # A falsy val_set (None, or a sized dataset of length 0) is normalised to
        # None so the rest of the class can test `self.val_set is not None`.
        self.val_set = val_set if val_set else None
        self.config = config
        self.cur_iter = 0
        self.batch_size = self.config["train"]["batch_size"]
        self.num_workers = self.config["train"]["num_workers"]
        self.criterion = criterion
        self.optimizer = optimizer
        self.global_iter = 0
        self.epoch = self.config["train"]["epochs"]
        self.log_interval = self.config["train"]["print_every"]
        self.save_every_epoch = config["train"]["save_every_epoch"]
        # TODO: identify the GPU device in the config
        # TODO: Add evaluate every epoch
        num_gpus = torch.cuda.device_count()
        if num_gpus == 1:
            self.model = self.model.cuda()
            self.device = torch.device("cuda:0")
        elif num_gpus > 1:
            self.model = DataParallel(self.model).cuda()
            self.device = torch.device("cuda")
        else:
            # Previously self.device was never set on CPU-only hosts, so
            # prepare_data() failed with AttributeError.
            self.device = torch.device("cpu")

        self.scheduler = scheduler
        self.get_dataloaders()

    def get_gpu_memory(self):
        """Return the free memory reported by ``nvidia-smi`` for each visible GPU.

        Runs ``nvidia-smi --query-gpu=memory.free --format=csv``, drops the CSV
        header line and the trailing empty line, and parses the leading integer
        of each remaining row (presumably MiB -- confirm for your nvidia-smi).

        Raises:
            FileNotFoundError: if ``nvidia-smi`` is not on PATH.
            subprocess.CalledProcessError: if ``nvidia-smi`` exits non-zero.
        """
        command = "nvidia-smi --query-gpu=memory.free --format=csv"
        output = sp.check_output(command.split()).decode("ascii")
        rows = output.split("\n")[:-1][1:]
        return [int(row.split()[0]) for row in rows]

    def save_model(self, file_name):
        """Save the model's ``state_dict`` into ``config["general"]["save_model_dir"]``.

        ``.pth`` is appended unless ``file_name`` already ends in ``.pth`` or
        ``.pt``. If ``__init__`` wrapped the model in ``DataParallel``, the saved
        keys carry DataParallel's ``module.`` prefix.
        """
        # The old test (`".pth" not in name or ".pt" not in name`) was true for
        # every "*.pt" name (".pt" is a substring of ".pth"), giving "x.pt.pth".
        if not file_name.endswith((".pth", ".pt")):
            file_name += ".pth"
        torch.save(
            self.model.state_dict(),
            join(self.config["general"]["save_model_dir"], file_name),
        )
        print(f"Model saved as {file_name}")

    def save_best_model(self):
        """Save the current weights as ``best.pth``."""
        self.save_model("best")

    def save_last_model(self):
        """Save the current weights as ``last.pth``."""
        self.save_model("last")

    def accuracy(self, logit, target, topk=(1,)):
        """Compute precision@k (in percent) for each k in ``topk``.

        Args:
            logit: 2-D scores with the class dimension at dim 1.
            target: integer class labels, one per row of ``logit``.
            topk: the k values to evaluate.

        Returns:
            A list with one 1-element tensor per k: 100 * hits / batch size.
        """
        output = F.softmax(logit, dim=1)
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()  # row i holds every sample's i-th best class
        # reshape (not view) also accepts non-contiguous targets.
        correct = pred.eq(target.reshape(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

    def get_lr(self):
        """Return the learning rate of the optimizer's first param group (None if there are none)."""
        for param_group in self.optimizer.param_groups:
            return param_group["lr"]
        return None

    def prepare_data(self, data):
        """Unpack a DataLoader batch and move its tensors to ``self.device``.

        Accepted batch layouts::

            (inputs, labels)
            (inputs, labels, idx)
            (inputs, labels, attributes, idx)

        Returns:
            ``(inputs, labels, attributes, idx)``: inputs as float and labels
            (plus attributes, when present) as long on ``self.device``;
            ``attributes``/``idx`` are None when absent; ``idx`` is not moved.

        Raises:
            ValueError: if the batch has any other number of elements.
        """
        attributes = None
        idx = None
        if len(data) == 2:
            inputs, labels = data
        elif len(data) == 3:
            inputs, labels, idx = data
        elif len(data) == 4:
            inputs, labels, attributes, idx = data
            attributes = attributes.to(self.device, non_blocking=True).long()
        else:
            raise ValueError(
                f"Expected a batch of 2, 3 or 4 elements, got {len(data)}"
            )
        inputs = inputs.to(self.device, non_blocking=True).float()
        labels = labels.to(self.device, non_blocking=True).long()
        if labels.dim() > 1:
            # Drop singleton label dims (e.g. (N, 1) -> (N,)) but keep the batch
            # dim: a bare squeeze() turned a final batch of one into a 0-d tensor.
            kept = [labels.size(0)] + [s for s in labels.shape[1:] if s != 1]
            labels = labels.reshape(kept)
        return inputs, labels, attributes, idx

    def _train_epoch(self, cur_epoch):
        """Run one pass over ``self.train_loader``; return ``(epoch_loss, epoch_acc)``."""
        self.model.train()
        log_every = self.config["general"]["logger"]["frequency"]
        epoch_loss, epoch_correct, total_num = 0.0, 0.0, 0.0
        with tqdm(self.train_loader, unit="batch") as tepoch:
            tepoch.set_description(f"Epoch {cur_epoch}")
            for data in tepoch:
                inputs, labels, _attributes, _idx = self.prepare_data(data)
                self.optimizer.zero_grad()
                outputs = self.model(inputs)
                loss = self.criterion(outputs, labels)
                loss.backward()
                self.optimizer.step()

                batch_n = inputs.size(0)
                loss_value = loss.item()
                correct = (outputs.argmax(1) == labels).sum().item()
                batch_acc = 100.0 * correct / batch_n
                lr = self.get_lr()
                tepoch.set_postfix(loss=loss_value, accuracy=batch_acc, lr=lr)
                # Accumulate a Python float: summing the loss tensor kept every
                # batch's autograd graph alive until the end of the epoch.
                epoch_loss += loss_value
                epoch_correct += correct
                total_num += batch_n
                self.global_iter += 1
                if self.global_iter % log_every == 0:
                    self.logger.info(
                        f"[{cur_epoch}]/[{self.epoch}], Global Iter: {self.global_iter}, Loss: {loss_value:.4f}, Acc: {batch_acc:.4f}, lr: {lr:.6f}",
                        {
                            "cur_epoch": cur_epoch,
                            "iter": self.global_iter,
                            "loss": loss_value,
                            "Accuracy": batch_acc,
                            "lr": lr,
                        },
                    )
        # NOTE(review): this divides a sum of per-batch criterion values by the
        # number of samples. If the criterion averages over the batch (e.g.
        # reduction="mean"), the result is roughly loss / batch_size -- confirm
        # the intended metric before relying on it.
        epoch_loss /= total_num
        epoch_acc = epoch_correct / total_num * 100.0
        return epoch_loss, epoch_acc

    def run(self):
        """Train for ``self.epoch`` epochs, evaluating and checkpointing after each.

        Each epoch: train, evaluate on the val set (if any) and the test set,
        save ``best`` when test accuracy improves, step the scheduler, save
        ``last``, and save ``<epoch>`` every ``save_every_epoch`` epochs.
        """
        print("==> Start training..")
        best_acc = 0.0
        for cur_epoch in range(self.epoch):
            epoch_loss, epoch_acc = self._train_epoch(cur_epoch)
            if self.val_set is not None:
                _ = self.evaluate(val=True)
            test_acc = self.evaluate(val=False)

            if test_acc > best_acc:
                best_acc = test_acc
                self.save_best_model()
            print(
                f"Epoch: {cur_epoch}, Loss: {epoch_loss:.6f}, Train Acc: {epoch_acc:.4f}, Test Acc: {test_acc:.4f}, Best Test Acc: {best_acc:.4f}"
            )
            self.logger.info(
                f"[{cur_epoch}]/[{self.epoch}], Loss: {epoch_loss:.6f}, Train Acc: {epoch_acc:.4f}, Test Acc: {test_acc:.4f}, Best Test Acc: {best_acc:.4f}",
                {
                    "test_epoch": cur_epoch,
                    "loss": epoch_loss,
                    "Train Acc": epoch_acc,
                    "Test Acc": test_acc,
                    "Best Test Acc": best_acc,
                },
            )

            if self.scheduler:
                self.scheduler.step()
            self.save_last_model()

            # A falsy save_every_epoch (0/None) disables periodic snapshots
            # instead of raising ZeroDivisionError.
            if self.save_every_epoch and cur_epoch % self.save_every_epoch == 0:
                self.save_model(f"{cur_epoch}")

    def evaluate(self, val=True, second_model=False):
        """Return top-1 accuracy (percent) on the validation or test loader.

        Args:
            val: use ``self.val_loader`` when True, else ``self.test_loader``.
            second_model: evaluate ``self.model2`` instead of ``self.model``;
                falls back to ``self.model`` (printing a message) when there is
                no ``model2`` attribute or it is None.

        Raises:
            ValueError: if the requested loader is missing or yields no samples.
        """
        model_test = self.model
        if second_model:
            model_test = getattr(self, "model2", None)
            if model_test is None:
                print("There is no second model. Still testing the first model.")
                model_test = self.model
        loader = self.val_loader if val else self.test_loader
        evaluate_type = "Val" if val else "Test"
        if loader is None:
            raise ValueError(f"No {evaluate_type} loader is available to evaluate")
        model_test.eval()

        correct, total_num = 0, 0
        with torch.no_grad():
            for data in loader:
                inputs, labels, _attributes, _idx = self.prepare_data(data)
                preds = model_test(inputs).argmax(1)
                correct += (preds == labels).sum().item()
                total_num += labels.size(0)
        if total_num == 0:
            raise ValueError(f"{evaluate_type} loader yielded no samples")
        acc = correct / total_num * 100
        print(f"{evaluate_type} Acc: {acc:.4f}")
        return acc

    def get_dataloaders(self):
        """Build ``train_loader``, ``test_loader`` and (if a val set exists) ``val_loader``.

        Only the training loader shuffles: accuracy in ``evaluate`` does not
        depend on order, and a fixed order keeps evaluation reproducible.
        Host memory is pinned only when CUDA is available, since pinning only
        speeds up host-to-GPU copies.
        """
        pin = torch.cuda.is_available()

        def make_loader(dataset, shuffle):
            return DataLoader(
                dataset,
                batch_size=self.batch_size,
                shuffle=shuffle,
                num_workers=self.num_workers,
                pin_memory=pin,
            )

        self.train_loader = make_loader(self.train_set, shuffle=True)
        self.test_loader = make_loader(self.test_set, shuffle=False)
        if self.val_set is not None:
            self.val_loader = make_loader(self.val_set, shuffle=False)
        else:
            self.val_loader = None
