import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from Utils import load
from copy import deepcopy
import logging
from torch.optim.swa_utils import AveragedModel, SWALR
from logger import set_logger
from Utils import generator
from Utils import metrics
from train import *
from prune import *

def _build_model(args, input_shape, num_classes, device):
    """Instantiate the network selected by ``args`` and move it to ``device``."""
    model_cls = load.model(args.model, args.model_class)
    return model_cls(input_shape,
                     num_classes,
                     args.dense_classifier,
                     args.pretrained).to(device)


def _average_models(models):
    """Return an ``AveragedModel`` holding the running average of all ``models``.

    ``AveragedModel.update_parameters`` *copies* its argument on the first call
    (while ``n_averaged == 0``) instead of averaging. Every model -- the first
    one included -- therefore has to go through ``update_parameters``; feeding
    only ``models[1:]`` silently drops ``models[0]`` from the average.
    """
    swa_model = AveragedModel(models[0])
    for member in models:
        swa_model.update_parameters(member)
    return swa_model


def run(args):
    """Run one iteration (``args.start``) of the soup / iterative-pruning experiment.

    Two modes, selected by ``args.soup_collect``:

    * collect: load the ``args.num_models`` checkpoints of the current
      iteration, average them, refresh BatchNorm statistics on the training
      loader, evaluate the average, pickle the test metrics, prune the
      averaged model with the ``'mag'`` pruner and save the pruned checkpoint
      under iteration index ``start + 1``.
    * train: build one model, initialise it from the matching ticket
      (iteration 0) or from the previously pruned checkpoint (later
      iterations), train it with ``args.soup_train`` in ``{'swa', 'sgd'}`` and
      save the resulting checkpoint.

    Args:
        args: parsed experiment namespace. Exits the process immediately when
            ``args.save`` is falsy.

    Raises:
        SystemExit: when ``args.save`` is falsy.
    """
    # Imported explicitly: this file never imports ``pickle`` itself, so the
    # name was only reachable if one of the star imports happened to leak it.
    import pickle

    if not args.save:
        print("This experiment requires an expid.")
        # ``quit()`` is a ``site``-module helper meant for the interactive
        # shell; raising SystemExit directly gives the same exit status.
        raise SystemExit

    ## Random Seed and Device ##
    # NOTE(review): torch.manual_seed(args.seed) is disabled, so this function
    # does not seed torch -- confirm that is intended.
    device = load.device(args.gpu)

    # Set logger
    set_logger(args.result_dir)
    logger = logging.getLogger()
    logger.info(f"Result folder path: {args.result_dir}")

    ## Data ##
    logger.info('Loading {} dataset.'.format(args.dataset))
    input_shape, num_classes = load.dimension(args.dataset)
    # No pruning data loader is built; presumably the 'mag' pruner needs no
    # data -- TODO confirm against prune_loop.
    prune_loader = None
    trn_loader, val_loader = load.dataloader(args.dataset, args.train_batch_size, True, args.workers)
    tst_loader = load.dataloader(args.dataset, args.test_batch_size, False, args.workers)

    loss = nn.CrossEntropyLoss()
    opt_class, opt_kwargs = load.optimizer(args.optimizer)

    # The range covers exactly one iteration: ``args.start``.
    for it in range(args.start, args.start + 1):
        if args.soup_collect:
            logger.info(f'=====Start collecting Soup_s{args.seed:03d} for {it}/{args.imp_iter} iteration=====')

            ## Load every soup ingredient of this iteration
            models = []
            for i in range(1, args.num_models + 1):
                model = _build_model(args, input_shape, num_classes, device)
                ckpt_path = "{}/P{}_i{}_s{}/ckpt_r0.80_i{:03d}_s{:03d}_p{:03d}.pt".format(
                    args.soup_dir, i, it, args.seed, it, int(f'{args.seed}{i}'), i)
                model.load_state_dict(torch.load(ckpt_path, map_location=device))
                models.append(model)

            ## Average, then recompute BatchNorm statistics for the averaged weights
            swa_model = _average_models(models)
            torch.optim.swa_utils.update_bn(trn_loader, swa_model, device)
            # ``eval`` here is the project function pulled in by the star
            # imports, not the builtin.
            _, tst_metrics = eval(args, logger, swa_model, loss, val_loader, tst_loader, device, 0, True, is_swa=True)
            with open('{}/results_i{:03d}_p{:03d}.pkl'.format(args.result_dir, it, 0), 'wb') as f:
                pickle.dump(tst_metrics, f, pickle.HIGHEST_PROTOCOL)

            # Continue with a plain copy of the averaged network.
            model = deepcopy(swa_model.module)

            ## Prune Model
            pruner = load.pruner('mag')(generator.masked_parameters(model, args.prune_bias, args.prune_batchnorm, args.prune_residual))
            sparsity = args.imp_ratio ** (it + 1)
            prune_loop(args, model, loss, pruner, prune_loader, device, sparsity,
                       args.compression_schedule, args.mask_scope, args.prune_epochs,
                       args.reinitialize, args.prune_train_mode, args.shuffle, args.invert)
            logger.info(f"Pruning sparsity: {sparsity}")

            ## Logging pruning results
            logger.info("(Model prune stats)")
            prune_result = metrics.summary(model,
                                           pruner.scores,
                                           metrics.flop(model, input_shape, device),
                                           lambda p: generator.prunable(p, args.prune_batchnorm, args.prune_residual))
            total_params = int((prune_result['sparsity'] * prune_result['size']).sum())
            possible_params = prune_result['size'].sum()
            total_flops = int((prune_result['sparsity'] * prune_result['flops']).sum())
            possible_flops = prune_result['flops'].sum()
            logger.info("Parameter Sparsity: {}/{} ({:.4f})".format(total_params, possible_params, total_params / possible_params))
            logger.info("FLOP Sparsity: {}/{} ({:.4f})".format(total_flops, possible_flops, total_flops / possible_flops))

            # Saved under index ``it + 1``: this checkpoint is what the next
            # iteration's training branch loads.
            torch.save(model.state_dict(), "{}/ckpt_r0.80_i{:03d}_s{:03d}_p{:03d}.pt".format(args.result_dir, it + 1, args.seed, 0))

        else:
            logger.info(f'=====Start Soup_p{args.soup_id:03d}_s{args.seed:03d} for {it}/{args.imp_iter} iteration=====')
            ## create matching ticket
            model = _build_model(args, input_shape, num_classes, device)
            if it == 0:
                if args.matching_epochs != 0:
                    ticket_path = "{}/matching_ticket_e{:03d}_s001.pt".format(args.ckpt_dir, args.matching_epochs)
                else:
                    ticket_path = "{}/matching_ticket_e{:03d}_s{:03d}.pt".format(args.soup_dir, args.matching_epochs, int(args.seed // 10))
                model.load_state_dict(torch.load(ticket_path, map_location=device))
                ## Save Matching ticket ##
                torch.save(model.state_dict(), "{}/matching_ticket_e{:03d}_s{:03d}_p{:03d}.pt".format(args.result_dir, args.matching_epochs, args.seed, args.soup_id))
                logger.info(f'Saved the matching ticket at epoch {args.matching_epochs}!')
            else:
                # Start from the checkpoint pruned by the previous collect step.
                pruned_path = "{}/P{}_i{}_s{}/ckpt_r0.80_i{:03d}_s{:03d}_p{:03d}.pt".format(
                    args.soup_dir, 0, it - 1, int(args.seed // 10), it, int(args.seed // 10), 0)
                model.load_state_dict(torch.load(pruned_path, map_location=device))
                if args.imp_type == 'weight':
                    # Overwrite every '.weight' / '.bias' entry with the
                    # matching-ticket values; all other state-dict entries
                    # keep the values from the pruned checkpoint.
                    model_dict = model.state_dict()
                    original_dict = torch.load(
                        "{}/P{}_i0_s{}/matching_ticket_e{:03d}_s{:03d}_p{:03d}.pt".format(
                            args.soup_dir, args.soup_id, int(args.seed // 10),
                            args.matching_epochs, args.seed, args.soup_id),
                        map_location=device)
                    original_weights = {
                        name: tensor
                        for name, tensor in original_dict.items()
                        if name.endswith(('.weight', '.bias'))
                    }
                    model_dict.update(original_weights)
                    model.load_state_dict(model_dict)

            ## Train
            if args.soup_train == 'swa':
                logger.info("=====SWA training!=====")
                optimizer = opt_class(generator.parameters(model), lr=args.lr, weight_decay=args.weight_decay, **opt_kwargs)
                scheduler = torch.optim.lr_scheduler.MultiStepLR(
                    optimizer,
                    milestones=[int(args.post_epochs * args.swa_milestone)],
                    gamma=args.lr_drop_rate_swa)

                swa_model = AveragedModel(model)

                train_swa_eval_loop(args, it - 1, logger, model, loss, optimizer, swa_model, scheduler,
                                    trn_loader, val_loader, tst_loader, device, args.post_epochs, args.verbose)

                torch.save(swa_model.module.state_dict(), "{}/ckpt_r0.80_i{:03d}_s{:03d}_p{:03d}.pt".format(args.result_dir, it, args.seed, args.soup_id))
                model.load_state_dict(swa_model.module.state_dict())
            elif args.soup_train == 'sgd':
                logger.info("=====SGD training!=====")
                optimizer = opt_class(generator.parameters(model), lr=args.lr, weight_decay=args.weight_decay, **opt_kwargs)
                scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
                    optimizer,
                    args.post_epochs,
                    T_mult=1,
                    eta_min=args.lr_min)
                # NOTE(review): the scheduler period is args.post_epochs but
                # training runs for args.pre_epochs -- confirm this mismatch
                # is intentional.
                model = train_eval_loop(args, it - 1, logger, model, loss, optimizer, scheduler,
                                        trn_loader, val_loader, tst_loader, device, args.pre_epochs, args.verbose)
                torch.save(model.state_dict(), "{}/ckpt_r0.80_i{:03d}_s{:03d}_p{:03d}.pt".format(args.result_dir, it, args.seed, args.soup_id))
            else:
                # Previously an unknown value skipped training without a trace.
                logger.warning("Unknown soup_train value %r; no training performed.", args.soup_train)
