from __future__ import print_function
import warnings
import argparse
import torch
import torch.nn as nn
import util.utils as utils
import torch.distributed as dist

from util.logger import Logger
from tqdm import tqdm
from models.VGG import VGG
from models.MobileNet import MobileNet
from models.PreActResNet import PreActResNet
from optimizer.decayoffSGD import decay_off_SGD
from torch.optim.lr_scheduler import MultiStepLR
from torch.nn import DataParallel
import torch.nn.functional as F

# Suppress every warning of category UserWarning for the whole process.
warnings.simplefilter("ignore", UserWarning)

class Solver(object):
    """Build a model from parsed arguments and run training or evaluation.

    The constructor creates the data loaders, model, optimizer, loss,
    learning-rate scheduler and logger. ``main`` is the entry point: it
    optionally loads pretrained weights, wraps the model in ``DataParallel``
    and dispatches to ``solve``.

    The parsed arguments are stored on ``self.args`` so that no method depends
    on a module-level ``args`` variable being defined.
    """

    # Keys used to index ``self.test_loader`` in ``test_C``.
    CORRUPTIONS = (
        'gaussian_noise', 'shot_noise', 'impulse_noise', 'defocus_blur',
        'glass_blur', 'motion_blur', 'zoom_blur', 'snow', 'frost', 'fog',
        'brightness', 'contrast', 'elastic_transform', 'pixelate',
        'jpeg_compression',
    )

    def __init__(self, args):
        """Create all training components from ``args``.

        Args:
            args: Namespace-like object carrying the options read by this
                class (``cuda``, ``dataset``, ``arch``, ``lr``, ...).
        """
        self.args = args

        self.model = None
        self.criterion = None
        self.optimizer = None
        self.scheduler = None
        self.train_loader = None
        self.test_loader = None
        self.num_classes = None

        self.best_acc = 0.0
        self.best_epoch = 0

        self.cuda = args.cuda
        self.device = "cuda" if args.cuda else "cpu"

        self._train_config_init(args=args)
        self.logger = Logger(args)

    def _train_config_init(self, args):
        """Create loaders, model, optimizer, loss and scheduler, then print the model."""
        self.train_loader, self.test_loader, self.num_classes = utils.get_data_loader(
            args.dataset, args.batch_size, args.test_batch_size, args.num_workers)
        self.model = self._model_init()
        self.optimizer = self._optimizer_init()
        self.criterion = nn.CrossEntropyLoss()
        self.scheduler = MultiStepLR(self.optimizer, milestones=args.schedule, gamma=args.gamma)

        print(self.model)

    def _model_init(self):
        """Instantiate the architecture selected by ``args.arch``.

        Returns:
            The constructed model, moved to the GPU when ``self.cuda`` is set.

        Raises:
            ValueError: If ``args.arch`` is not one of ``VGG``, ``MobileNet``
                or ``PreActResNet``.
        """
        args = self.args

        if args.arch == 'VGG':
            model = VGG(activation_type=args.activation_type, num_classes=self.num_classes,
                        depth=args.depth_wide, oper_order=args.operation_order,
                        dataset=args.dataset, cut_block=args.vgg_cut_block, tau=args.tau)
        elif args.arch == 'MobileNet':
            model = MobileNet(activation_type=args.activation_type, num_classes=self.num_classes,
                              oper_order=args.operation_order, dataset=args.dataset, tau=args.tau)
        elif args.arch == 'PreActResNet':
            model = PreActResNet(activation_type=args.activation_type, depth=int(args.depth_wide),
                                 num_classes=self.num_classes, oper_order=args.operation_order,
                                 dataset=args.dataset, tau=args.tau)
        else:
            # Fail loudly instead of exiting with a success status.
            raise ValueError('Invalid model architecture: {!r}'.format(args.arch))

        if self.cuda:
            model.cuda()

        return model

    def _optimizer_init(self):
        """Create the optimizer for ``self.model``.

        Plain ``torch.optim.SGD`` is used when ``weight_decay``, ``gamma_decay``
        and ``beta_decay`` are all equal; otherwise ``decay_off_SGD`` is used
        and configured with the model.

        Returns:
            The constructed optimizer.
        """
        args = self.args
        uniform_decay = (args.weight_decay == args.gamma_decay
                         and args.weight_decay == args.beta_decay)

        if uniform_decay:
            print("Naive SGD")
            optimizer = torch.optim.SGD(self.model.parameters(), lr=args.lr,
                                        momentum=args.momentum, weight_decay=args.weight_decay)
        else:
            print("Decay off SGD")
            optimizer = decay_off_SGD(self.model.parameters(), operation_order=args.operation_order,
                                      lr=args.lr, momentum=args.momentum,
                                      weight_decay=args.weight_decay, gamma_decay=args.gamma_decay,
                                      beta_decay=args.beta_decay)
            optimizer.weight_decay_config_init(self.model)

        return optimizer

    def _to_device(self, data, target):
        """Return the batch moved to the GPU when CUDA is enabled, else unchanged."""
        if self.cuda:
            return data.cuda(), target.cuda()
        return data, target

    def train(self, epoch):
        """Run one training epoch, step the scheduler and log the summed batch loss.

        Args:
            epoch: Epoch number, used for the progress-bar label and logging.
        """
        accumulated_batch_loss = 0.0
        self.model.train()

        with tqdm(self.train_loader, unit='batchs') as pbar:
            pbar.set_description(f"Epoch {epoch}")

            for data, target in pbar:
                data, target = self._to_device(data, target)

                self.optimizer.zero_grad()
                output = self.model(data)
                loss = self.criterion(output, target)
                loss.backward()
                self.optimizer.step()

                batch_loss = loss.item()
                pbar.set_postfix(loss=batch_loss)
                accumulated_batch_loss += batch_loss

        # The scheduler is stepped once per epoch, not per batch.
        self.scheduler.step()

        self.logger.tensor_board('train_loss', {'loss': accumulated_batch_loss}, epoch)

    def test_C(self, epoch=None):
        """Evaluate accuracy on every corruption in ``CORRUPTIONS``.

        ``self.test_loader`` is indexed by corruption name here, so it must be
        a mapping from each name to a loader exposing ``.dataset``.

        Args:
            epoch: Accepted for call-compatibility with ``test``; unused.

        Returns:
            The mean accuracy (a fraction in ``[0, 1]``) over all corruptions,
            or ``None`` when ``args.local_rank`` is set and is not 0.
        """
        # ``local_rank`` may be absent from ``args``; treat that as rank 0.
        if getattr(self.args, 'local_rank', 0) != 0:
            return None

        self.model.eval()
        corruption_accs = []

        with torch.no_grad():
            for cor_type in self.CORRUPTIONS:
                loader = self.test_loader[cor_type]
                correct = 0

                for data, target in loader:
                    data, target = self._to_device(data, target)
                    output = self.model(data)
                    pred = output.max(1)[1]
                    # ``.item()`` keeps the running count a plain Python int.
                    correct += pred.eq(target.view_as(pred)).sum().item()

                corruption_accs.append(correct / len(loader.dataset))

        mean_acc = sum(corruption_accs) / len(corruption_accs)
        print("total acc : ", mean_acc)
        return mean_acc

    def test(self, epoch, evaluate=False):
        """Evaluate on ``self.test_loader`` and track the best accuracy.

        Args:
            epoch: Epoch number recorded as ``best_epoch`` and used for logging.
            evaluate: When true, the best accuracy/epoch are always overwritten
                and nothing is saved or logged through ``self.logger``.
        """
        self.model.eval()
        accumulated_batch_loss = 0.0
        correct = 0

        with torch.no_grad():
            for data, target in self.test_loader:
                data, target = self._to_device(data, target)
                output = self.model(data)

                accumulated_batch_loss += self.criterion(output, target).item()

                pred = output.max(1, keepdim=True)[1]
                correct += pred.eq(target.view_as(pred)).sum().item()

        num_samples = len(self.test_loader.dataset)
        test_acc = 100. * float(correct) / num_samples

        if test_acc > self.best_acc or evaluate:
            self.best_acc = test_acc
            self.best_epoch = epoch

            if not evaluate:
                self.logger.state_save(self.model, self.best_acc)

        print('\nTest set: Loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)'.format(
            accumulated_batch_loss, correct, num_samples, test_acc))
        print('Best Accuracy: {:.2f}%, Best Epoch: {}\n'.format(self.best_acc, self.best_epoch))

        if not evaluate:
            self.logger.tensor_board('test', {'loss': accumulated_batch_loss}, epoch)
            self.logger.tensor_board('test', {'accuracy': test_acc}, epoch)

    def weight_init(self, decomposed_weight_list):
        """Copy the given tensors into the model's state dict, in order.

        Args:
            decomposed_weight_list: Sequence of tensors, one per state-dict
                entry in ``state_dict()`` order. It is not modified.

        Raises:
            IndexError: If fewer tensors are supplied than state-dict entries.
        """
        state = self.model.state_dict()
        if len(decomposed_weight_list) < len(state):
            raise IndexError('expected {} weights, got {}'.format(
                len(state), len(decomposed_weight_list)))

        for tensor, decomposed_weight in zip(state.values(), decomposed_weight_list):
            tensor.copy_(decomposed_weight)

    def solve(self):
        """Run training, plain evaluation, or the analysis passes.

        * ``evaluate`` unset: train and test for ``epochs`` epochs, then save
          the best accuracy.
        * ``evaluate`` set and no analysis flag: evaluate once (``test_C`` for
          the ``cifar100-c`` dataset, ``test`` otherwise).
        * ``evaluate`` set with an analysis flag: run each requested analysis.
        """
        args = self.args

        if not args.evaluate:
            for epoch in range(1, args.epochs + 1):
                self.train(epoch)
                self.test(epoch)

            self.logger.accuracy_save(self.best_acc)
            return

        if not (args.skewness or args.saturation or args.sparsity):
            if args.dataset == 'cifar100-c':
                self.test_C()
            else:
                self.test(0, evaluate=True)
            return

        if args.skewness:
            self.logger.skewness(self.model, self.test_loader)
        if args.saturation:
            self.logger.saturation(self.model, self.test_loader, empirical=False)
        if args.sparsity:
            # Sparsity goes through the same logger call with ``empirical=True``.
            self.logger.saturation(self.model, self.test_loader, empirical=True)

    def main(self):
        """Load optional pretrained weights, wrap the model and run ``solve``."""
        if self.args.pretrained:
            utils.get_pretrained_weight(self.model, self.args)

        self.model = DataParallel(self.model)

        utils.print_model_parameters(self.model)

        self.solve()

        self.logger.terminate_logging()

if __name__ == '__main__':
    # Command-line entry point: parse options, seed the RNGs, then build and
    # run a Solver. ``args`` is deliberately left at module scope.
    parser = argparse.ArgumentParser(description='CTB Model')

    # Optimisation settings.
    parser.add_argument('--batch_size', type=int, default=128, metavar='N',
                        help='input batch size for training (default: 128)')
    parser.add_argument('--test_batch_size', type=int, default=256, metavar='N',
                        help='input batch size for testing (default: 256)')
    parser.add_argument('--epochs', type=int, default=200, metavar='N',
                        help='number of epochs to train (default: 200)')
    parser.add_argument('--log_interval', type=int, default=100, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--lr', type=float, default=0.1, metavar='LR',
                        help='learning rate (default: 0.1)')
    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                        help='SGD momentum (default: 0.9)')

    # Weight decay; gamma/beta decay fall back to --weight_decay when omitted.
    parser.add_argument('--weight_decay', '--wd', default=5e-4, type=float,
                        metavar='W', help='weight decay (default: 5e-4)')
    parser.add_argument('--gamma_decay', default=None, type=float)
    parser.add_argument('--beta_decay', default=None, type=float)

    parser.add_argument('--tau', default=-1.0, type=float, help="shifted parameter to generate asymmetry in Tanh")

    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--arch', action='store', default='VGG',
                        help='network structure: VGG | MobileNet | PreActResNet')
    parser.add_argument('--pretrained', action='store', default=None,
                        help='pretrained model')
    parser.add_argument('--evaluate', action='store_true', default=False,
                        help='whether to run evaluation')
    parser.add_argument('--dataset', action='store', default='cifar10',
                        help='dataset: cifar10 | cifar100 | tinyImageNet | cifar100-c | ImageNet')
    parser.add_argument('--schedule', type=int, nargs='+', default=[100, 150],
                        help='schedule : (default: [100, 150])')
    parser.add_argument('--gamma', type=float, default=0.1,
                        help='gamma : (default: 0.1)')
    parser.add_argument('--bn_momentum', type=float, default=0.1,
                        help='BatchNorm momentum factor')
    parser.add_argument('--depth_wide', action='store', default=None,
                        help='depth and wide (default: None)')
    parser.add_argument('--vgg_cut_block', type=int, default=0)

    parser.add_argument('--activation_type', type=str, default="relu",
                        help='activation function type: relu | sigmoid | tanh | shifted_tanh')
    parser.add_argument('--operation_order', type=str, default='cba',
                        help='set the convolution operation order')

    parser.add_argument('--additional_info', default=None, type=str,
                        help='For temp training, add the additional info on logging file name')

    # Analysis passes, consulted only together with --evaluate.
    parser.add_argument('--skewness', action='store_true', default=False,
                        help='measuring the degree of asymmetry')
    parser.add_argument('--saturation', action='store_true', default=False,
                        help='measuring the degree of saturation')
    parser.add_argument('--sparsity', action='store_true', default=False,
                        help='measuring the degree of sparsity')

    parser.add_argument('--num_workers', type=int, default=4,
                        help='number of workers for data loading')

    # Solver.test_C reads args.local_rank; without this option that lookup
    # raises AttributeError. Rank 0 (the default) performs the evaluation.
    parser.add_argument('--local_rank', type=int, default=0,
                        help='process rank; only rank 0 runs the corruption evaluation (default: 0)')

    args = parser.parse_args()
    print(args)
    args.cuda = torch.cuda.is_available()
    utils.fix_randomness(args.seed, args.cuda)

    # Unspecified per-parameter decays inherit the global weight decay.
    if args.beta_decay is None:
        args.beta_decay = args.weight_decay
    if args.gamma_decay is None:
        args.gamma_decay = args.weight_decay

    solver = Solver(args)
    solver.main()
