import sys

import argparse
import math
import os
import os.path as osp
import torch # torch >= 1.7.1
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
import matplotlib

matplotlib.use("Agg")

from lib import get_data,makedirs,accuracy,save_obj,set_seed
from models import MLP
from tqdm import tqdm


class Trainer(nn.Module):
    """Train an MLP with SGD, checkpointing, evaluating and plotting curves as it goes.

    Outputs are written under ``<args.save_dir>_<dataset>_<model>/<model_args>/``:
    ``checkpoints/model_<iter>.pt`` every ``save_freq`` steps, ``curve/data.bin`` with
    the recorded metrics, ``curve/curve.png`` with the plots and a final ``model.pt``.
    """

    def __init__(self, args):
        """Build data loaders, the model, optimizer, loss and output directories from ``args``.

        Args:
            args: parsed command-line namespace; reads width, depth, device, iterations,
                batch_size_train, batch_size_eval, model, dataset, lr, mom, wd,
                data_augmentation, save_freq, eval_freq and save_dir.

        Raises:
            NotImplementedError: for a model other than 'MLP' or a dataset other than
                'cifar10' / 'mnist'.
        """
        super(Trainer, self).__init__()
        # training setup
        self.width = args.width
        self.depth = args.depth
        self.device = args.device
        self.iterations = args.iterations
        # prepare dataset; get_data is expected to return
        # (train_loader, test_loader_eval, train_loader_eval, num_classes)
        self.train_loader, self.test_loader_eval, self.train_loader_eval, self.num_classes = get_data(args)
        self.batch_size_train = args.batch_size_train
        self.batch_size_eval = args.batch_size_eval

        # prepare models
        if args.model == 'MLP':
            if args.dataset == 'cifar10':
                input_dim = 3 * 32 * 32
            elif args.dataset == 'mnist':
                input_dim = 28 * 28
            else:
                raise NotImplementedError(f"{args.dataset} not supported!")
            self.net = MLP(width=self.width, depth=self.depth, num_classes=self.num_classes,
                           input_dim=input_dim).to(self.device)

            # Hyper-parameters are encoded in the output directory name so runs do not collide.
            self.model_args = f"{self.width}_{self.depth}_" \
                              f"{self.batch_size_train}_{self.batch_size_eval}_{args.lr}_mom_{args.mom}_" \
                              f"wd_{args.wd}_{self.iterations}"

            if not args.data_augmentation:
                self.model_args += "_wo-da"
        else:
            raise NotImplementedError(f"{args.model} not supported!")

        print(self.net)
        self.optimizer = optim.SGD(self.net.parameters(), lr=args.lr, momentum=args.mom, weight_decay=args.wd)
        self.criterion = nn.CrossEntropyLoss().to(self.device)
        self.save_freq = args.save_freq
        self.eval_freq = args.eval_freq
        self.save_dir = osp.join(args.save_dir + f"_{args.dataset}_{args.model}", self.model_args)
        self.iter = 0  # number of optimizer steps taken so far
        self.epoch = 0

        makedirs(osp.join(self.save_dir, "checkpoints"))
        makedirs(osp.join(self.save_dir, "curve"))
        self.generate_plot_dict()

    def generate_plot_dict(self):
        """Reset ``self.plot_dict`` to empty per-metric histories."""
        self.plot_dict = {
            "training_loss": [],
            "training_acc": [],
            "eval_test_loss": [],
            "eval_test_acc": [],
            "eval_train_loss": [],
            "eval_train_acc": []
        }

    def run(self):
        """Train for exactly ``self.iterations`` optimizer steps, then finalize.

        Before step ``i`` a checkpoint is saved whenever ``i % save_freq == 0`` (so
        ``model_0.pt`` holds the initial weights). After step ``i`` both eval sets are
        evaluated and the curves re-plotted whenever ``i % eval_freq == 0``.

        Raises:
            RuntimeError: if ``train_loader`` yields no batches (training could not
                make progress and would otherwise loop forever).
        """
        while self.iter < self.iterations:
            self.epoch += 1
            print(f"EPOCH {self.epoch}")
            pbar = tqdm(self.train_loader, mininterval=1, ncols=100)
            pbar.set_description("training")
            n_batches = 0
            for image, label in pbar:
                n_batches += 1
                # Save first, so the checkpoint reflects the weights after `self.iter` steps.
                if self.iter % self.save_freq == 0:
                    self.save()
                pbar.set_description(f"training epoch{self.epoch} iter[{self.iter + 1}/{self.iterations}]")
                self._train_step(image, label, pbar)

                if self.iter % self.eval_freq == 0:
                    pbar.set_description(f"eval epoch{self.epoch} iter[{self.iter + 1}/{self.iterations}]")
                    self._record_eval()
                    # Plot only when the eval curves change; rendering every step is very costly.
                    self.plot()

                self.iter += 1
                if self.iter >= self.iterations:
                    break
            pbar.close()
            if n_batches == 0:
                raise RuntimeError("train_loader yielded no batches; cannot make progress")

        self._finalize()

    def _train_step(self, image, label, pbar):
        """Run one SGD step on a batch, logging training loss/accuracy to plot_dict and pbar."""
        self.net.train()
        image, label = image.to(self.device), label.to(self.device)
        self.optimizer.zero_grad()
        out = self.net(image)
        loss = self.criterion(out, label)
        loss.backward()
        acc = accuracy(out, label)
        loss_value, acc_value = loss.item(), acc.item()
        self.plot_dict["training_loss"].append(loss_value)
        self.plot_dict["training_acc"].append(acc_value)
        pbar.set_postfix_str(f"loss={loss_value:.4f} acc={acc_value:.2f}%")
        # take the step
        self.optimizer.step()

    def _record_eval(self):
        """Evaluate on the test and train eval loaders and append the results to plot_dict."""
        test_acc, test_loss = self.eval(test=True)
        self.plot_dict["eval_test_loss"].append(test_loss)
        self.plot_dict["eval_test_acc"].append(test_acc)
        train_acc, train_loss = self.eval(test=False)
        self.plot_dict["eval_train_loss"].append(train_loss)
        self.plot_dict["eval_train_acc"].append(train_acc)

    def _finalize(self):
        """Write the final checkpoint/plot, run a final evaluation and save model.pt + data.bin."""
        # The save-first check inside the loop never runs for the final step count,
        # so write the model_<iterations>.pt checkpoint here when it is due.
        if self.iter % self.save_freq == 0:
            self.save()
        # Plot before appending the final eval: that point is generally off the
        # eval_freq grid that `plot` assumes for the eval x-axis.
        self.plot()
        self._record_eval()
        torch.save({
            'model_state_dict': self.net.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
        }, osp.join(self.save_dir, "model.pt"))
        # save the logs
        save_obj(self.plot_dict, osp.join(self.save_dir, "curve", "data.bin"))

    def plot(self):
        """Render the training and evaluation curves to ``<save_dir>/curve/curve.png``.

        Training values are plotted at steps 1, 2, ...; evaluation values at steps
        1, 1 + eval_freq, 1 + 2*eval_freq, ..., matching the periodic evaluations
        recorded in :meth:`run`.
        """
        def train_steps(values):
            return list(range(1, len(values) + 1))

        def eval_steps(values):
            return list(range(1, len(values) * self.eval_freq + 1, self.eval_freq))

        # One entry per subplot (2x2 grid, row-major): (plot_dict key, legend label, x-axis fn)
        panels = [
            [("training_loss", "training loss", train_steps)],
            [("training_acc", "training acc", train_steps)],
            [("eval_test_loss", "evaluate test loss", eval_steps),
             ("eval_train_loss", "evaluate train loss", eval_steps)],
            [("eval_test_acc", "evaluate test acc", eval_steps),
             ("eval_train_acc", "evaluate train acc", eval_steps)],
        ]
        plt.figure(figsize=(8, 8))
        for position, curves in enumerate(panels, start=1):
            plt.subplot(2, 2, position)
            plt.xlabel("iteration")
            for key, curve_label, x_of in curves:
                values = self.plot_dict[key]
                plt.plot(x_of(values), values, label=curve_label)
            plt.legend()
        plt.savefig(osp.join(self.save_dir, "curve", "curve.png"), dpi=200)
        plt.close("all")

    def eval(self, test=False):
        """Return ``(acc, loss)`` of ``self.net`` over an eval loader, without gradients.

        Both values are per-sample weighted averages of the per-batch results. The
        criterion uses CrossEntropyLoss's default mean reduction; accuracy is assumed
        to be a per-batch mean as well.

        Args:
            test: evaluate on ``test_loader_eval`` if True, else on ``train_loader_eval``.

        Raises:
            ValueError: if the selected loader yields no samples.
        """
        self.net.eval()
        dataloader = self.test_loader_eval if test else self.train_loader_eval
        total_size = 0
        total_loss = 0.0
        total_acc = 0.0
        with torch.no_grad():
            for image, label in dataloader:
                image, label = image.to(self.device), label.to(self.device)
                out = self.net(image)
                loss = self.criterion(out, label)
                prec = accuracy(out, label)
                bs = image.size(0)
                total_size += int(bs)
                total_loss += float(loss) * bs
                total_acc += float(prec) * bs
        if total_size == 0:
            which = "test" if test else "train"
            raise ValueError(f"{which} eval loader yielded no samples")

        return total_acc / total_size, total_loss / total_size

    def save(self):
        """Save model/optimizer state to checkpoints/model_<iter>.pt and the metric log."""
        torch.save({
            'model_state_dict': self.net.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
        }, osp.join(self.save_dir, "checkpoints", f"model_{self.iter}.pt"))
        save_obj(self.plot_dict, osp.join(self.save_dir, "curve", "data.bin"))


if __name__ == '__main__':
    def _str2bool(value):
        """Parse a command-line boolean.

        argparse's ``type=bool`` treats any non-empty string (including "False")
        as True, so accepted spellings are parsed explicitly.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ("1", "true", "t", "yes", "y", "on"):
            return True
        if lowered in ("0", "false", "f", "no", "n", "off"):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser()
    parser.add_argument('--iterations', default=10000, type=int)
    parser.add_argument('--batch_size_train', default=100, type=int)
    parser.add_argument('--batch_size_eval', default=100, type=int,
                        help='must be equal to training batch size')
    parser.add_argument('--lr', default=0.1, type=float)
    parser.add_argument('--mom', default=0, type=float)
    parser.add_argument('--wd', default=0, type=float)
    parser.add_argument('--save_freq', default=100, type=int)
    parser.add_argument('--dataset', default='cifar10', type=str,
                        help='mnist | cifar10')
    parser.add_argument('--path', default='.', type=str)
    parser.add_argument('--seed', default=2020, type=int)
    parser.add_argument('--model', default='MLP', type=str)
    parser.add_argument('--depth', default=9, type=int)
    parser.add_argument('--width', default=512, type=int,
                        help='width of fully connected layers')
    parser.add_argument('--save_dir', default='results', type=str)
    parser.add_argument('--verbose', action='store_true', default=False)
    parser.add_argument('--double', action='store_true', default=False)
    parser.add_argument('--gpu-id', type=int, default=0)
    # `--data-augmentation` alone means True; an explicit true/false value is also accepted.
    parser.add_argument("--data-augmentation", type=_str2bool, nargs="?", const=True, default=False,
                        help="true/false (bare flag means true)")
    parser.add_argument("--eval_freq", default=50, type=int)
    args = parser.parse_args()
    # initial setup
    if args.double:
        torch.set_default_tensor_type('torch.DoubleTensor')
    # torch.device(<int>) always refers to a CUDA device; fall back to CPU when CUDA is absent.
    if torch.cuda.is_available():
        args.device = torch.device(f"cuda:{args.gpu_id}")
    else:
        args.device = torch.device("cpu")
    set_seed(args.seed)
    print(args)
    trainer = Trainer(args)
    trainer.run()
