import os
import math
import copy
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from tqdm import tqdm
from collections import OrderedDict

from utils import load_model, load_dataset, get_save_name, print_pruned, get_pruning_ratio
from pruning import get_mask_saliency_AAS, get_mask_by_saliency, prune_saliency
from pruning import get_saliency_weight_magnitude, get_saliency_snip, get_saliency_lap_global, get_saliency_random
from pruning import get_alive_idx, prune_by_mask
from config import get_config


def accuracy(output, target, topk=(1,)):
    """Count correct predictions among the top-k scores for each k in ``topk``.

    Args:
        output: 2-D score tensor; the top-k is taken along dim 1.
        target: Tensor of target class indices, one per row of ``output``.
        topk: Iterable of k values to evaluate.

    Returns:
        A list with one Python float per k: the *number* (not the fraction)
        of rows whose target is among the k highest-scoring classes.
    """
    with torch.no_grad():
        # Asking for more predictions than there are classes is the same as
        # asking for all of them; clamping avoids an out-of-range topk error.
        maxk = min(max(topk), output.size(1))

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        # ``reshape`` instead of ``view``: ``correct`` follows the transposed
        # (non-contiguous) layout of ``pred``, so a multi-row slice cannot be
        # flattened as a view.
        return [correct[:k].reshape(-1).float().sum().item() for k in topk]


def train_hard(args, model, device, train_loader, optimizer, lr, epoch, max_epoch, mask):
    """Train ``model`` for one epoch while keeping masked connections frozen.

    Args:
        args: Parsed command-line options (not read by this function).
        model: Network to optimize; switched to training mode.
        device: Device the batches are moved to.
        train_loader: Iterable of ``(data, target)`` batches supporting ``len``.
        optimizer: Optimizer stepping the model parameters.
        lr: Current learning rate, used only for the progress-bar label.
        epoch: Current epoch number, used only for the progress-bar label.
        max_epoch: Total number of epochs, used only for the progress-bar label.
        mask: ``None`` or a mapping from parameter name to a tensor that is
            multiplied into that parameter's gradient before each step.

    Returns:
        The mean of the per-batch cross-entropy losses over the epoch.
    """
    model.train()

    train_loss = 0

    progress = tqdm(
        train_loader, total=len(train_loader),
        desc='Train epoch %d/%d (LR: %.4f)' % (epoch, max_epoch, lr), ncols=100, leave=True)

    for data, target in progress:
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        loss.backward()

        train_loss += loss.item()

        # Freeze masked connections: zero their gradients so the optimizer
        # step leaves them untouched.
        if mask is not None:
            for name, param in model.named_parameters():
                # A parameter that received no gradient has ``grad is None``.
                if name in mask and param.grad is not None:
                    param.grad.mul_(mask[name])

        optimizer.step()

    train_loss /= len(train_loader)

    return train_loss


# Evaluation
def test(args, model, device, data_loader):
    """Evaluate ``model`` on every batch of ``data_loader``.

    Args:
        args: Parsed command-line options (not read by this function).
        model: Network to evaluate; switched to evaluation mode.
        device: Device the batches are moved to.
        data_loader: Iterable of ``(data, target)`` batches supporting ``len``.

    Returns:
        ``(test_loss, acc_1, acc_5)``: the mean of the per-batch cross-entropy
        losses, and the top-1 / top-5 accuracies as fractions in ``[0, 1]``.
    """
    model.eval()

    test_loss, acc_1, acc_5 = 0, 0, 0
    num_samples = 0

    with torch.no_grad():
        for data, target in tqdm(
                data_loader, total=len(data_loader), ncols=100, leave=True):
            data, target = data.to(device), target.to(device)
            output = model(data)

            test_loss += F.cross_entropy(output, target).item()

            # One accuracy call per batch yields both top-1 and top-5 counts.
            correct_1, correct_5 = accuracy(output, target, topk=(1, 5))
            acc_1 += correct_1
            acc_5 += correct_5
            num_samples += target.size(0)

    test_loss /= len(data_loader)
    # Normalize by the samples actually evaluated so the result stays correct
    # when the loader iterates over only part of its dataset (e.g. a sampler).
    acc_1 /= num_samples
    acc_5 /= num_samples

    return test_loss, acc_1, acc_5


def _build_arg_parser():
    """Create the command-line parser for the train / prune / eval script."""
    parser = argparse.ArgumentParser(description='All-Alive Pruning')

    # Training settings
    parser.add_argument('--dataset', type=str, default='mnist',
                        help='Dataset')
    parser.add_argument('--data-dir', type=str, default='/data/dataset',
                        help='Directory for dataset')
    parser.add_argument('--arch', type=str, default='fc',
                        help='Model architecture')
    parser.add_argument('--mode', type=str, default='train',
                        help='train / prune / eval')
    parser.add_argument('--batch-size', type=int, default=64,
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=1000,
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--valid-ratio', type=float, default=0.1,
                        help='Rates of validation set')
    parser.add_argument('--epochs', type=int, default=14,
                        help='number of epochs to train (default: 14)')
    parser.add_argument('--lr', type=float, default=2e-3,
                        help='learning rate (default: 2e-3)')
    parser.add_argument('--scheduler', type=str, default=None,
                        help='learning rate scheduler')
    parser.add_argument('--optim', type=str, default='sgd',
                        help='Optimizer')
    parser.add_argument('--gamma', type=float, default=0.1,
                        help='Learning rate step gamma (default: 0.1)')
    parser.add_argument('--step-size', type=float, default=30,
                        help='Scheduler Step size')
    parser.add_argument('--num-workers', type=int, default=6,
                        help='Num workers of data loader')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--single-gpu', action='store_true', default=False,
                        help='disables DataParallel')
    parser.add_argument('--device', type=str, default=None,
                        help='Select device')
    parser.add_argument('--seed', type=int, default=42,
                        help='random seed (default: 42)')
    parser.add_argument('--save-dir', type=str, default='save',
                        help='For Saving the current Model')

    # Evaluation option
    parser.add_argument('--model-path', type=str, default=None,
                        help='Load model for evaluation')

    # Pruning options
    parser.add_argument('--prune-percent', type=float, default=20,
                        help='Set Final Pruning Percentage')
    parser.add_argument('--prune-method', type=str, default='mp',
                        help='Pruning method (mp / snip / lap / random)')
    parser.add_argument('--pruning-step', type=int, default=1,
                        help='Get final pruning rates with (1 - (1-rate)^step)')
    parser.add_argument('--rewinding', type=str, default='lr',
                        help='Rewinding method (lr / weight)')

    # Lottery Ticket Hypothesis options (get the pruning mask from the best
    # model, and initialize with the init model)
    parser.add_argument('--origin-best-path', type=str, default=None,
                        help='Load original model')
    parser.add_argument('--origin-init-path', type=str, default=None,
                        help='Load original model')

    # IMP options
    parser.add_argument('--imp', action='store_true', default=False,
                        help='Enable IMP')
    parser.add_argument('--imp-metric', type=str, default='acc',
                        help='(valid) loss / acc')
    parser.add_argument('--imp-from', type=str, default='init',
                        help='init / best_loss')
    parser.add_argument('--imp-index', type=int, default=0,
                        help='Index for IMP')

    # Alive connection
    parser.add_argument('--all-alive-pruning', action='store_true', default=False,
                        help='Find dead connection using weights/gradients')

    return parser


def main():
    """Command-line entry point: evaluate, train from scratch, or prune and retrain.

    Behavior by ``--mode``:

    * ``eval``  - load ``--model-path``, report loss / top-1 / top-5 on the
      test loader and return.
    * ``train`` - build a fresh model and train it without a mask.
    * ``prune`` - compute a saliency-based mask from an "origin" model, apply
      it to the model to be trained, then train with masked gradients.

    Several hyper-parameters (optimizer, learning rate, scheduler, batch size,
    epochs, ...) are overwritten by the values returned from ``get_config``.
    During training, the initial weights, the best top-1 / best validation-loss
    checkpoints and the final weights are written into ``--save-dir``.

    Raises:
        ValueError: If ``--mode`` is not one of train / prune / eval.
        NotImplementedError: If the pruning method or the configured optimizer
            is not supported.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    # The IMP save name needs ``imp_percent``, which is only computed in prune
    # mode; fail early with a clear message instead of crashing later.
    if args.imp and args.mode == 'train':
        parser.error("--imp is only supported together with --mode prune")

    use_cuda = not args.no_cuda and torch.cuda.is_available()

    # Set seeds
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    print("Set seed to ", args.seed)

    # Get config: these values replace the corresponding command-line options.
    config = get_config(args)
    args.num_workers = config['num-workers']
    args.optim = config['optim']
    args.lr = config['lr']
    args.scheduler = config['scheduler']
    args.step_size = config['step_size']
    args.gamma = config['gamma']
    args.batch_size = config['batch-size']
    args.epochs = config['epochs']
    args.lr_rewinding = config['lr-rewinding']
    args.weight_rewinding = config['weight-rewinding']

    args.start_epoch = 0

    if args.rewinding == 'lr':
        args.start_epoch = args.lr_rewinding
        print(f"Use Learning Rate Rewinding (start from {args.lr_rewinding} epoch)")
    elif args.rewinding == 'weight':
        args.start_epoch = args.weight_rewinding
        print(f"Use Weight Rewinding (start from {args.weight_rewinding} epoch)")

    device = torch.device("cuda" if use_cuda else "cpu")

    # Load Dataset
    kwargs = {'num_workers': args.num_workers, 'pin_memory': True} if use_cuda else {}
    train_loader, valid_loader, test_loader, data_shape, num_classes = load_dataset(args, kwargs)

    # Evaluation
    if args.mode == 'eval':
        assert args.model_path, "Model path is required for evaluation"

        model = load_model(args, device, args.model_path)

        test_loss, acc_1, acc_5 = test(args, model, device, test_loader)
        _, _ = print_pruned(model)

        print('Test set - Average loss: {:.8f}, Top-1/5 Accuracy: {:.2f}%/{:.2f}%\n'.format(
                test_loss, 100 * acc_1, 100 * acc_5))
        return

    if args.mode == 'train':
        # Train from scratch
        model = load_model(args, device)
        mask = None

    elif args.mode == 'prune':
        assert args.prune_percent, "Pruning percentage is required for Pruning mode"
        assert args.prune_method, "Pruning method is required"

        # Set Pruning rates
        args.prune_percent = get_pruning_ratio(args.prune_percent, args.pruning_step)
        print("Pruning ratio: ", args.prune_percent)

        # Iterative Pruning Option (IMP)
        if args.imp:
            assert args.imp_metric and args.imp_from and args.imp_index, "Options for IMP are required."

            # Test for same pruning ratio
            imp_percent = args.prune_percent
            args.prune_percent = get_pruning_ratio(args.prune_percent, args.imp_index)
            print(f"Pruning rato for IMP {args.imp_index}: {args.prune_percent}")

            # Not first time: continue from the checkpoints of the previous round
            if args.imp_index != 1:
                imp_name = get_save_name(args, prev_name=True, imp_percent=imp_percent)

                # In weight rewinding, get init values only from the original model
                if args.rewinding != 'weight':
                    args.origin_init_path = os.path.join(args.save_dir, imp_name + f'_{args.imp_from}.pt')
                # NOTE(review): the checkpoints written below end in
                # '_best_acc_1.pt' / '_best_loss.pt'; the default
                # --imp-metric 'acc' yields '_acc.pt' - confirm the expected value.
                args.origin_best_path = os.path.join(args.save_dir, imp_name + f'_{args.imp_metric}.pt')

        # Per pruning method: load the origin model used to compute the
        # saliency, pick the saliency function, then load the model to train.
        if args.prune_method == 'mp':
            origin_model = load_model(args, device, args.origin_best_path, 'original')
            get_saliency = get_saliency_weight_magnitude
            print(f"Magnitude Pruning - FROM:{args.origin_best_path}, INIT:{args.origin_init_path}")

            model = load_model(args, device, args.origin_init_path)

        elif args.prune_method == 'snip':
            origin_model = load_model(args, device, single_gpu=True)
            get_saliency = get_saliency_snip
            print("SNIP for pruning")

            # The origin model is loaded with single_gpu=True; add the
            # 'module.' prefix to its keys unless --single-gpu is set so the
            # state dict can be loaded into the training model.
            origin_state_dict = OrderedDict()
            for k, v in origin_model.state_dict().items():
                if 'module.' not in k and not args.single_gpu:
                    name = 'module.' + k
                else:
                    name = k
                origin_state_dict[name] = v

            model = load_model(args, device)
            model.load_state_dict(origin_state_dict)

        elif args.prune_method == 'lap':
            origin_model = load_model(args, device, args.origin_best_path, 'original')
            get_saliency = get_saliency_lap_global
            print(f"Lookahead - FROM:{args.origin_best_path}, INIT:{args.origin_init_path}")

            model = load_model(args, device, args.origin_init_path)

        elif args.prune_method == 'random':
            origin_model = load_model(args, device, name='original')
            get_saliency = get_saliency_random
            print("RANDOM Pruning")

            model = load_model(args, device)

        else:
            raise NotImplementedError(f"Unsupported pruning method: {args.prune_method!r}")

        # Mask arguments
        batch = copy.deepcopy(next(iter(train_loader)))

        mask_kwargs = {
            'batch': batch,
            'data_shape': data_shape,
            'num_classes': num_classes,
            'device': device
        }

        mask_kwargs['num_alive'], mask_kwargs['alive_idx'], mask_kwargs['alive_mask'] = get_alive_idx(args, origin_model)

        # Get saliency
        saliency = get_saliency(args, origin_model, **mask_kwargs)

        # Alive-connection
        if args.all_alive_pruning:
            print("Apply All-Alive Pruning")
            # NOTE(review): origin_mask is never read afterwards; the call is
            # kept in case get_mask_by_saliency has side effects - confirm and remove.
            origin_mask = get_mask_by_saliency(args, saliency, **mask_kwargs)
            mask = get_mask_saliency_AAS(args, saliency, origin_model, **mask_kwargs)
        else:
            mask = get_mask_by_saliency(args, saliency, **mask_kwargs)

        # Delete origin model
        origin_model = origin_model.cpu()
        del origin_model

        prune_by_mask(model, mask)
    else:
        raise ValueError(f"Unknown mode: {args.mode!r} (expected train / prune / eval)")

    best_acc_1, best_acc_5 = 0., 0.
    best_loss = 10000000.
    # Initialized up front so the final summary works even when the training
    # loop does not run or a metric never improves.
    best_acc_1_epoch, best_acc_5_epoch, best_loss_epoch = 0, 0, 0

    # Select optimizer
    if args.optim == 'sgd_m':
        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=1e-4)
    elif args.optim == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=args.lr)
    else:
        raise NotImplementedError(f"Unsupported optimizer: {args.optim!r}")

    args.num_alive, args.num_pruned = print_pruned(model)

    if args.imp:
        save_name = get_save_name(args, imp_percent=imp_percent)
    else:
        save_name = get_save_name(args)

    # Make sure the checkpoint directory exists before the first save.
    os.makedirs(args.save_dir, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(args.save_dir, save_name + '_init.pt'))

    if mask is not None:
        # Convert mask to torch Tensor
        for name in mask:
            assert ((mask[name]==0) | (mask[name]==1)).all(), "Mask should be binary format (zero or one)"
            mask[name] = torch.from_numpy(mask[name]).float().to(device)

        # For SNIP without --single-gpu, prefix the mask keys with 'module.'
        # so they line up with the parameter names of the training model.
        if args.prune_method == 'snip' and not args.single_gpu:
            print("Change mask's name for SNIP pruning method")
            snip_mask = {}
            for name in mask:
                snip_mask['module.' + name] = copy.deepcopy(mask[name])
            mask = snip_mask

    # Train model
    for epoch in range(args.start_epoch + 1, args.epochs + 1):
        lr = adjust_learning_rate(args, optimizer, epoch)

        train_loss = train_hard(args, model, device, train_loader, optimizer, lr, epoch, args.epochs, mask)
        test_loss, acc_1, acc_5 = test(args, model, device, valid_loader)

        # Save Best Acc_1 model
        if best_acc_1 < acc_1:
            torch.save(model.state_dict(), os.path.join(args.save_dir, save_name + '_best_acc_1.pt'))
            best_acc_1 = acc_1
            best_acc_1_epoch = epoch

        # Track best Acc_5 (no checkpoint is written for it)
        if best_acc_5 < acc_5:
            best_acc_5 = acc_5
            best_acc_5_epoch = epoch

        # Save Best loss model
        if best_loss > test_loss:
            torch.save(model.state_dict(), os.path.join(args.save_dir, save_name + '_best_loss.pt'))
            best_loss = test_loss
            best_loss_epoch = epoch

        print('--- [Loss] Train/Test : {:.4f}/{:.4f} [Acc] Top-1/5: {:.2f}%/{:.2f}% [Best] Top-1/5 Acc: {:.2f}%/{:.2f}%, Loss: {:.4f}'.format(
        train_loss, test_loss, 100. * acc_1, 100. * acc_5, 100. * best_acc_1, 100. * best_acc_5, best_loss))

    # Save final model
    torch.save(model.state_dict(), os.path.join(args.save_dir, save_name + f'_{args.epochs}epoch.pt'))
    _, _ = print_pruned(model)
    print(f"--- Train Done. Best Top-1/5 Accuracy : {best_acc_1* 100}%({best_acc_1_epoch} ep)/{best_acc_5* 100}%({best_acc_5_epoch} ep) / Best Loss: {best_loss} (in {best_loss_epoch} epoch) " )


def adjust_learning_rate(args, optimizer, epoch):
    """Set the learning rate for ``epoch`` on every parameter group.

    ``args.scheduler`` selects the schedule:

    * ``'steplr'``   - ``args.lr`` times ``args.gamma`` raised to
      ``epoch // args.step_size``.
    * ``'imagenet'`` - linear warm-up (``epoch / 5``) for epochs 1-4, the base
      rate up to epoch 29, then factors 0.1 / 0.01 / 0.001 starting at
      epochs 30 / 60 / 80.
    * anything else  - the base rate ``args.lr``.

    Returns:
        The learning rate that was written into ``optimizer.param_groups``.
    """
    # ``None`` means "use the base learning rate unchanged".
    factor = None

    if args.scheduler == 'steplr':
        factor = args.gamma ** (epoch // args.step_size)
    elif args.scheduler == 'imagenet':
        if 1 <= epoch < 5:
            factor = epoch / 5
        elif 30 <= epoch < 60:
            factor = 0.1
        elif 60 <= epoch < 80:
            factor = 0.01
        elif epoch >= 80:
            factor = 0.001

    new_lr = args.lr if factor is None else args.lr * factor

    for group in optimizer.param_groups:
        group['lr'] = new_lr

    return new_lr


# Run the command-line entry point only when this file is executed as a script.
if __name__ == '__main__':
    main()