from __future__ import print_function
from json import load
import warnings
import argparse
import numpy as np
import torch
import torch.nn as nn
import os
import time
import util.utils as utils
import matplotlib.pyplot as plt


from util.logger import Logger
from tqdm import tqdm
from models.VGG import VGG
from models.MobileNet import MobileNetv1
from optimizer.decayoffSGD import decay_off_SGD
from torch.optim.lr_scheduler import MultiStepLR

# Suppress every UserWarning for the remainder of the process.
warnings.simplefilter(action="ignore", category=UserWarning)

class Solver(object):
    """Train / evaluate a classification model configured from parsed CLI args.

    The constructor builds the data loaders, model, optimizer, loss and LR
    scheduler from ``args`` and creates a ``Logger``. ``main()`` is the entry
    point: it optionally loads pretrained weights, wraps the model in
    ``nn.DataParallel`` and dispatches to ``solve()``.

    All configuration is read from ``self.args`` (the namespace given to the
    constructor) rather than a module-level ``args`` global, so the class also
    works when imported instead of run as a script.
    """

    def __init__(self, args):
        # Configuration used by every method below.
        self.args = args

        self.model = None
        self.criterion = None
        self.optimizer = None
        self.scheduler = None
        self.train_loader = None
        self.test_loader = None
        self.num_classes = None

        # Best test accuracy (percent) seen so far and the epoch it was reached.
        self.best_acc = 0.0
        self.best_epoch = 0

        self.cuda = args.cuda
        self.device = "cuda" if args.cuda else "cpu"

        self._train_config_init(args=args)
        self.logger = Logger(args)

    def _train_config_init(self, args):
        """Build data loaders, model, optimizer, loss and LR scheduler from ``args``."""
        self.train_loader, self.test_loader, self.num_classes = utils.get_data_loader(
            args.dataset, args.batch_size, args.test_batch_size)
        self.model = self._model_init()
        self.optimizer = self._optimizer_init()
        self.criterion = nn.CrossEntropyLoss()
        # train() steps the scheduler once per epoch, so milestones are epoch numbers.
        self.scheduler = MultiStepLR(self.optimizer, milestones=args.schedule, gamma=args.gamma)

    def _model_init(self):
        """Instantiate the network selected by ``args.arch``; move it to GPU if enabled.

        Returns:
            The constructed model (also printed to stdout).

        Raises:
            ValueError: if ``args.arch`` is neither ``'VGG'`` nor ``'MobileNetv1'``.
        """
        args = self.args
        if args.arch == 'VGG':
            model = VGG(activation_type=args.activation_type, num_classes=self.num_classes, depth=args.depth_wide,
                        oper_order=args.operation_order, dataset=args.dataset)
        elif args.arch == 'MobileNetv1':
            model = MobileNetv1(activation_type=args.activation_type, num_classes=self.num_classes,
                                oper_order=args.operation_order, dataset=args.dataset)
        else:
            # Fail loudly (non-zero exit when uncaught) rather than exiting with status 0.
            raise ValueError('Invalid model architecture: {!r}'.format(args.arch))

        if self.cuda:
            model.cuda()

        print(model)
        return model

    def _optimizer_init(self):
        """Create the ``decay_off_SGD`` optimizer for ``self.model``.

        Each ``--no-<name>-decay`` flag that is set adds ``<name>`` ('gamma',
        'beta', 'weight') to the ``decay_off_list`` handed to the optimizer.
        """
        args = self.args
        decay_flags = (
            ('gamma', args.no_gamma_decay),
            ('beta', args.no_beta_decay),
            ('weight', args.no_weight_decay),
        )
        decay_off_list = [name for name, turned_off in decay_flags if turned_off]

        optimizer = decay_off_SGD(self.model.parameters(), operation_order=args.operation_order,
                                  lr=args.lr, momentum=args.momentum,
                                  weight_decay=args.weight_decay, decay_off_list=decay_off_list,
                                  gamma_decay=args.gamma_decay, beta_decay=args.beta_decay)
        optimizer.weight_decay_config_init(self.model)
        return optimizer

    def train(self, epoch):
        """Run one training epoch, step the LR scheduler and log the loss.

        Args:
            epoch: epoch index, used for the progress-bar label and for logging.
        """
        self.model.train()
        accumulated_batch_loss = 0.0

        # Set the progress-bar label once instead of on every batch.
        with tqdm(self.train_loader, unit='batchs', desc=f"Epoch {epoch}") as pbar:
            for data, target in pbar:
                if self.cuda:
                    data, target = data.cuda(), target.cuda()

                self.optimizer.zero_grad()
                output = self.model(data)
                loss = self.criterion(output, target)
                loss.backward()
                self.optimizer.step()

                batch_loss = loss.item()
                pbar.set_postfix(loss=batch_loss)
                accumulated_batch_loss += batch_loss

        self.scheduler.step()
        # NOTE: sum of per-batch mean losses (CrossEntropyLoss default reduction), not an epoch mean.
        self.logger.tensor_board('train_loss', {'loss': accumulated_batch_loss}, epoch)

    def test(self, epoch, evaluate=False):
        """Evaluate on the test set, update the best accuracy and report results.

        Args:
            epoch: epoch index; stored as ``best_epoch`` when recorded as best and used for logging.
            evaluate: evaluation-only mode. The current accuracy is always recorded as the
                best, and no checkpoint or TensorBoard scalar is written.
        """
        self.model.eval()
        accumulated_batch_loss = 0.0
        correct = 0

        with torch.no_grad():
            for data, target in self.test_loader:
                if self.cuda:
                    data, target = data.cuda(), target.cuda()

                output = self.model(data)
                # Accumulate plain Python numbers rather than (possibly GPU) tensors.
                accumulated_batch_loss += self.criterion(output, target).item()

                pred = output.argmax(dim=1, keepdim=True)
                correct += pred.eq(target.view_as(pred)).sum().item()

        num_samples = len(self.test_loader.dataset)
        test_acc = 100. * correct / num_samples

        if test_acc > self.best_acc or evaluate:
            self.best_acc = test_acc
            self.best_epoch = epoch
            if not evaluate:
                self.logger.state_save(self.model, self.best_acc)

        print('\nTest set: Loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)'.format(
            accumulated_batch_loss, correct, num_samples, test_acc))
        print('Best Accuracy: {:.2f}%, Best Epoch: {}\n'.format(self.best_acc, self.best_epoch))

        if not evaluate:
            self.logger.tensor_board('test', {'loss': accumulated_batch_loss}, epoch)
            self.logger.tensor_board('test', {'accuracy': test_acc}, epoch)

    def weight_init(self, decomposed_weight_list):
        """Overwrite every ``state_dict`` entry of the model, in key order.

        ``decomposed_weight_list`` supplies one tensor per ``state_dict`` key and is
        consumed from the front (the caller's list is emptied by this call).
        """
        # Fetch the state dict once; its tensors alias the model's own storage.
        state = self.model.state_dict()
        for key in state:
            state[key].copy_(decomposed_weight_list.pop(0))

    def solve(self):
        """Dispatch to training, plain evaluation, or the selected analyses.

        - without ``--evaluate``: train and test for ``args.epochs`` epochs, with
          optional channel logging, then save the best accuracy;
        - ``--evaluate`` with no analysis flag: a single evaluation pass;
        - ``--evaluate`` with analysis flags: run each selected logger analysis.
        """
        args = self.args
        if not args.evaluate:
            if args.channel_logging:
                self.logger.channel_saturation_skewness_logging_init(self.model)

            for epoch in range(1, args.epochs + 1):
                if args.channel_logging:
                    self.logger.channel_saturation_skewness_logging(self.model, self.test_loader)

                self.train(epoch)
                self.test(epoch)

            if args.channel_logging:
                self.logger.channel_saturation_skewness_plot()

            self.logger.accuracy_save(self.best_acc)
            return

        analysis_requested = (args.skewness or args.saturation or args.empirical_saturation
                              or args.feature_visualization or args.activation_distribution)
        if not analysis_requested:
            self.test(0, evaluate=True)
            return

        if args.skewness:
            self.logger.skewness(self.model, self.test_loader)
        if args.saturation:
            self.logger.saturation(self.model, self.test_loader, empirical=False)
        if args.empirical_saturation:
            self.logger.saturation(self.model, self.test_loader, empirical=True)
        if args.feature_visualization:
            self.logger.feature_visualization(self.model, self.test_loader)
        if args.activation_distribution:
            self.logger.activation_distribution(self.model, self.test_loader)
        # TODO: the cosine-similarity analysis (logger.cosine_similarity(args.diff_class))
        # is disabled; the CLI defines no --cosine-similarity / --diff-class options.

    def main(self):
        """Load optional pretrained weights, wrap the model in ``nn.DataParallel`` and run ``solve``."""
        if self.args.pretrained:
            utils.get_pretrained_weight(self.model, self.args)

        self.model = nn.DataParallel(self.model)
        utils.print_model_parameters(self.model)

        self.solve()
        self.logger.terminate_logging()

if __name__ == '__main__':
    # settings
    parser = argparse.ArgumentParser(description='CTB Model')
    parser.add_argument('--batch-size', type=int, default=128, metavar='N',
                        help='input batch size for training (default: 128)')
    parser.add_argument('--test-batch-size', type=int, default=256, metavar='N',
                        help='input batch size for testing (default: 256)')
    parser.add_argument('--epochs', type=int, default=200, metavar='N',
                        help='number of epochs to train (default: 200)')
    parser.add_argument('--log-interval', type=int, default=100, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--lr', type=float, default=0.1, metavar='LR',
                        help='learning rate (default: 0.1)')
    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                        help='SGD momentum (default: 0.9)')

    # Weight-decay strengths; beta/gamma fall back to --weight-decay when unset (see below).
    parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float,
                        metavar='W', help='weight decay (default: 5e-4)')
    parser.add_argument('--beta-decay', default=None, type=float)
    parser.add_argument('--gamma-decay', default=None, type=float)

    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    # Restrict to the architectures Solver._model_init can build, so a typo is
    # rejected before any data loading happens.
    parser.add_argument('--arch', action='store', default='VGG', choices=['VGG', 'MobileNetv1'],
                        help='network structure: VGG | MobileNetv1')
    parser.add_argument('--pretrained', action='store', default=None,
                        help='pretrained model')
    parser.add_argument('--evaluate', action='store_true', default=False,
                        help='whether to run evaluation')
    parser.add_argument('--dataset', action='store', default='cifar10',
                        help='dataset: cifar10 | cifar100 | FashionMNIST | tinyImageNet')
    parser.add_argument('--schedule', type=int, nargs='+', default=[100, 200],
                        help='schedule : (default: [100,200])')
    parser.add_argument('--gamma', type=float, default=0.1,
                        help='gamma : (default: 0.1)')
    parser.add_argument('--bn-momentum', type=float, default=0.1,
                        help='BatchNorm momentum factor')
    parser.add_argument('--depth-wide', action='store', default=None,
                        help='depth and wide (default: None)')

    parser.add_argument('--activation-type', type=str, default="relu",
                        help='activation function type: relu | sigmoid | tanh')
    parser.add_argument('--operation-order', type=str, default='cba',
                        help='set the convolution operation order')
    parser.add_argument('--get-activation', action='store_true', default=False,
                        help='get the activation value')

    parser.add_argument('--get-decompose-activation', action='store_true', default=False,
                        help='get the activation value')

    parser.add_argument('--no-gamma-decay', action='store_true', default=False,
                        help='turn off the weight decay on gamma')
    parser.add_argument('--no-beta-decay', action='store_true', default=False,
                        help='turn off the weight decay on beta')
    parser.add_argument('--no-weight-decay', action='store_true', default=False,
                        help='turn off the weight decay on convolution and linear')

    parser.add_argument('--xavier-init', action='store_true', default=False,
                        help='Xavier initialization')

    parser.add_argument('--additional-info', default=None, type=str,
                        help='For temp training, add the additional info on logging file name')

    parser.add_argument('--skewness', action='store_true', default=False,
                        help='Calculate the skewness')
    parser.add_argument('--saturation', action='store_true', default=False,
                        help='measuring the degree of saturation')
    parser.add_argument('--empirical-saturation', action='store_true', default=False,
                        help='measuring the degree of saturation')

    parser.add_argument('--activation-distribution', action='store_true', default=False)
    parser.add_argument('--channel-logging', action='store_true', default=False,
                        help='Logging the filter running mean, var')
    parser.add_argument('--feature-visualization', action='store_true', default=False)

    # NOTE: ``args`` stays a module-level name on purpose: code elsewhere in this
    # file may read it as a global.
    args = parser.parse_args()
    print(args)
    args.cuda = torch.cuda.is_available()
    utils.fix_randomness(args.seed, args.cuda)

    # Validate flag combinations with parser.error (a real CLI error) instead of
    # `assert`, which is stripped when Python runs with -O.
    decay_turned_off = args.no_gamma_decay or args.no_beta_decay or args.no_weight_decay
    if args.arch == 'VGG' and decay_turned_off and 'cifar' not in args.dataset:
        parser.error('--no-*-decay options with --arch VGG require a cifar dataset')
    if args.operation_order == 'cbab' and decay_turned_off:
        parser.error('--no-*-decay options are not supported with --operation-order cbab')

    if args.beta_decay is None:
        args.beta_decay = args.weight_decay
    if args.gamma_decay is None:
        args.gamma_decay = args.weight_decay

    solver = Solver(args)
    solver.main()
