import os
import numpy as np

import torch
import torch.nn as nn
from Utils import load
from Utils import generator
from Utils import metrics
from Utils.logger import *
from train import *
from prune import *

def run(args):
    """Pre-train a model, prune it, and save the resulting buffers to disk.

    Steps, in order: seed torch and pick the device, build the prune/train/test
    data loaders, instantiate the model with its loss, optimizer and LR
    scheduler, run the pre-training loop, run the pruning loop, then write the
    model's non-statistics buffers (presumably the pruning masks -- confirm
    against the layer definitions) to a ``.pt`` file.

    Args:
        args: Parsed experiment configuration (argparse-style namespace). The
            attributes read here are: ``seed``, ``gpu``, ``dataset``,
            ``prune_batch_size``, ``train_batch_size``, ``test_batch_size``,
            ``workers``, ``prune_dataset_ratio``, ``model``, ``model_class``,
            ``dense_classifier``, ``pretrained``, ``n_layers`` and
            ``hidden_dim`` (only when ``model == 'fc'``), ``optimizer``,
            ``delta`` (optional, only for 'ammd'/'hmd'/'ahmd'), ``lr``,
            ``weight_decay``, ``lr_drops``, ``lr_drop_rate``, ``pre_epochs``,
            ``verbose``, ``pruner``, ``prune_bias``, ``prune_batchnorm``,
            ``prune_residual``, ``compression``, ``compression_schedule``,
            ``mask_scope``, ``prune_epochs``, ``reinitialize``,
            ``prune_train_mode``, ``shuffle`` and ``invert``.

    Returns:
        None. The result is the file
        ``<saved_mask_dir>/mask_seed_<seed>.pt`` written under
        ``../pai/saved_masks/`` (relative to the current working directory).
    """
    ## Random Seed and Device ##
    torch.manual_seed(args.seed)
    device = load.device(args.gpu)

    ## Data ##
    print('Loading {} dataset.'.format(args.dataset))
    input_shape, num_classes = load.dimension(args.dataset)
    # The prune loader gets an extra size argument that scales with the number
    # of classes; the train/test loaders use the loader's default.
    prune_loader = load.dataloader(args.dataset, args.prune_batch_size, True,
                                   args.workers,
                                   args.prune_dataset_ratio * num_classes)
    train_loader = load.dataloader(args.dataset, args.train_batch_size, True,
                                   args.workers)
    test_loader = load.dataloader(args.dataset, args.test_batch_size, False,
                                  args.workers)

    ## Model, Loss, Optimizer ##
    print('Creating {}-{} model.'.format(args.model_class, args.model))
    model_factory = load.model(args.model, args.model_class)
    model_args = (input_shape, num_classes, args.dense_classifier,
                  args.pretrained)
    if args.model == 'fc':
        # Only the fully-connected model takes depth/width/nonlinearity.
        model = model_factory(*model_args,
                              L=args.n_layers,
                              N=args.hidden_dim,
                              nonlinearity=nn.ReLU()).to(device)
    else:
        model = model_factory(*model_args).to(device)

    loss = nn.CrossEntropyLoss()
    opt_class, opt_kwargs = load.optimizer(args.optimizer)
    # 'delta' is forwarded only to the optimizers listed here, and only when
    # the option is present on args. This keeps the original best-effort
    # behaviour without a bare ``except`` that would also hide unrelated errors.
    if args.optimizer in ('ammd', 'hmd', 'ahmd') and hasattr(args, 'delta'):
        opt_kwargs['delta'] = args.delta
    optimizer = opt_class(generator.parameters(model), lr=args.lr,
                          weight_decay=args.weight_decay, **opt_kwargs)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=args.lr_drops, gamma=args.lr_drop_rate)

    ## Pre-Train ##
    print('Pre-Train for {} epochs.'.format(args.pre_epochs))
    # The return value of the training loop is not used by this experiment.
    train_eval_loop(model, loss, optimizer, scheduler, train_loader,
                    test_loader, device, args.pre_epochs, args.verbose, args)

    ## Prune ##
    print('Pruning with {} for {} epochs.'.format(args.pruner, args.prune_epochs))
    pruner = load.pruner(args.pruner)(
        generator.masked_parameters(model, args.prune_bias,
                                    args.prune_batchnorm, args.prune_residual))
    # 'compression' is a base-10 exponent: the value passed on is 10^-compression.
    sparsity = 10**(-float(args.compression))
    prune_loop(model, loss, pruner, prune_loader, device, sparsity,
               args.compression_schedule, args.mask_scope, args.prune_epochs,
               args.reinitialize, args.prune_train_mode, args.shuffle,
               args.invert)

    ## Save Results and Model ##
    # NOTE: saving is unconditional; there is no ``args.save`` guard here.
    print('Saving results.')
    # Keep every buffer except those whose name contains 'running' or
    # 'tracked' (presumably BatchNorm statistics -- confirm), moved to CPU and
    # detached so the file can be loaded without a GPU.
    mask_dict = {
        name: buf.cpu().detach()
        for name, buf in model.named_buffers()
        if 'running' not in name and 'tracked' not in name
    }

    saved_mask_dir = f"../pai/saved_masks/{args.dataset}/{args.model_class}/{args.model}/{args.pruner}/compression_{args.compression}/"
    if args.model == 'fc':
        # Fully-connected runs are further keyed by depth and width.
        saved_mask_dir += f"L_{args.n_layers}/N_{args.hidden_dim}/"
    os.makedirs(saved_mask_dir, exist_ok=True)
    # os.path.join avoids the doubled separator that plain concatenation
    # produced, since saved_mask_dir already ends with '/'.
    torch.save(mask_dict, os.path.join(saved_mask_dir, f'mask_seed_{args.seed}.pt'))
    # To load the saved dict (path = the file written above):
    #     masks = torch.load(path, weights_only=True)
    #     for k, v in masks.items():
    #         print(k, v.shape)
        
