import os
import timm
import torch
import wandb
import pickle
import argparse
import numpy as np
import torch.nn as nn
import torchvision.transforms as T

from copy import deepcopy
from functools import partial
from timm.utils import AverageMeter
from torchvision.datasets import ImageFolder
from torch.utils.data import Subset, DataLoader
from sparseml.pytorch.optim import ScheduledModifierManager

from utils.model_specific_utils import split_qkv


def parse_args(argv=None):
    """Parse the command-line options of the one-shot pruning script.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default),
            argparse reads the arguments from ``sys.argv[1:]``.

    Returns:
        The ``argparse.Namespace`` holding the parsed options.
    """
    # The automatic -h/--help flag is left enabled: this is the only parser
    # and several options carry help strings that would otherwise be
    # unreachable from the command line.
    parser = argparse.ArgumentParser('One shot pruning.')
    # Model
    parser.add_argument('--model', default='deit_small_patch16_224', type=str)
    parser.add_argument('--checkpoint-path', default='', type=str,
                        help='Path to model checkpoint')
    # Experiment
    parser.add_argument('--experiment', default='', type=str)
    parser.add_argument('--seed', default=42, type=int)
    # Path to data
    parser.add_argument('--data-dir', required=True, type=str)
    # Path to recipe
    parser.add_argument('--sparseml-recipe', required=True, type=str)
    # Loader params
    parser.add_argument('-b', '--batch_size', default=128, type=int)
    parser.add_argument('-vb', '--val_batch_size', default=128, type=int)
    parser.add_argument('--workers', default=4, type=int)
    parser.add_argument('--prefetch', default=2, type=int)
    # Sparsities
    parser.add_argument('--sparsities', nargs='+', required=True, type=float)
    # OBS loader params
    parser.add_argument('--gs-loader', action='store_true',
                        help='Whether to create additional loader for grad sampling.')
    # Calibration loader params
    parser.add_argument('--calibration-loader', action='store_true',
                        help='Whether to create additional loader for model calibration.')
    parser.add_argument('--num-calibration-samples', default=1024, type=int)
    parser.add_argument('-cb', '--calib-batch-size', default=None, type=int)
    # Save arguments
    parser.add_argument('--save-dir', default='./output/one-shot', type=str,
                        help='dir to save results')
    parser.add_argument('--save-model', action='store_true',
                        help='Whether to save pruned model')
    # Logging
    parser.add_argument('--log-wandb', action='store_true')
    # Misc params
    parser.add_argument('--split-qkv', action='store_true')

    return parser.parse_args(argv)


def accuracy(logits, labels):
    """Return the fraction of samples whose arg-max class matches the label.

    The arg-max is taken over dimension 1 of ``logits``. The result is the
    number of matching entries divided by ``len(labels)`` and is returned as
    a tensor, not a Python float.
    """
    predicted = torch.argmax(logits, dim=1)
    num_correct = (predicted == labels).sum()
    return num_correct / len(labels)


@torch.no_grad()
def val_epoch(
    model : nn.Module, 
    data_loader : DataLoader,
    criterion : nn.Module,
    device : torch.device
):
    """Evaluate ``model`` on every batch of ``data_loader``.

    Args:
        model: Network to evaluate; it is switched to eval mode.
        data_loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss applied to ``(logits, labels)``.
        device: Device the batches are moved to before the forward pass.

    Returns:
        A dict with the averaged ``"loss"`` and ``"acc"`` over the loader.
    """
    model.eval()
    loss_m = AverageMeter()
    acc_m = AverageMeter()

    for images, labels in data_loader:
        images, labels = images.to(device), labels.to(device)
        # get model output
        logits = model(images)
        # workaround for distilled models, which return a tuple of outputs
        if isinstance(logits, tuple):
            logits = logits[0]
        # compute loss
        loss = criterion(logits, labels)
        # Weight each batch by its number of samples so that a smaller final
        # batch does not count as much as a full one.
        # NOTE(review): this assumes `criterion` returns a per-batch mean
        # (the default reduction of nn.CrossEntropyLoss) - confirm for other losses.
        batch_size = labels.size(0)
        loss_m.update(loss.item(), batch_size)
        acc_m.update(accuracy(logits, labels).item(), batch_size)

    return {"loss" : loss_m.avg, "acc" : acc_m.avg}


# Script entry point: evaluate the dense model, then for every requested
# sparsity prune a fresh copy with the SparseML recipe, evaluate it, and
# finally dump the collected accuracies to `<save_dir>/experiment_data.pkl`.
if __name__ == '__main__':
    # parse args
    args = parse_args()
    # seed all
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # set num threads
    torch.set_num_threads(args.workers + 1)
    # init wandb
    if args.log_wandb:
        wandb.init(config=args)

    # Data
    normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    train_transforms = T.Compose([
        T.RandomResizedCrop(224),
        T.RandomHorizontalFlip(p=0.5),
        T.ToTensor(),
        normalize,
    ])

    val_transforms = T.Compose([
        T.Resize(256),
        T.CenterCrop(224),
        T.ToTensor(),
        normalize,
    ])

    train_dataset = ImageFolder(root=f'{args.data_dir}/train', transform=train_transforms)
    val_dataset = ImageFolder(root=f'{args.data_dir}/val', transform=val_transforms)

    # Options shared by every loader. `prefetch_factor` is only passed when
    # worker processes are used: DataLoader rejects it for num_workers == 0
    # in recent PyTorch releases.
    loader_kwargs = dict(num_workers=args.workers, pin_memory=True)
    if args.workers > 0:
        loader_kwargs['prefetch_factor'] = args.prefetch

    # create dataloaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        **loader_kwargs,
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=args.val_batch_size,
        shuffle=False,
        **loader_kwargs,
    )

    # dummy loss
    loss_fn = nn.CrossEntropyLoss()
    # model
    model_kw = dict(distillation=False) if 'efficientformer' in args.model else {}
    model = timm.create_model(
        args.model,
        pretrained=True,
        checkpoint_path=args.checkpoint_path,
        **model_kw,
    )
    if args.split_qkv:
        model = split_qkv(model)

    model = model.to(device)
    # first evaluation
    val_acc = val_epoch(model, val_loader, loss_fn, device=device)['acc']
    print(f'Accuracy dense: {val_acc:.3f}')
    # make dir (if needed)
    os.makedirs(args.save_dir, exist_ok=True)
    experiment_data = {
        'sparsity': args.sparsities, 'val/acc' : []
    }

    manager_kwargs = {}
    # define for OBS/M-FAC pruner
    if args.gs_loader:
        def data_loader_builder(device=device, **kwargs):
            # Cycle over the training loader forever; the consumer decides
            # how many batches to draw.
            while True:
                for images, targets in train_loader:
                    images, targets = images.to(device), targets.to(device)
                    yield [images], {}, targets

        manager_kwargs['grad_sampler'] = {
            'data_loader_builder' : data_loader_builder,
            'loss_fn' : loss_fn,
        }
    # define for AdaPrune/OBC pruner
    elif args.calibration_loader:
        # NOTE: np.random.choice samples with replacement by default, so the
        # calibration subset may contain duplicate indices.
        calibration_dataset = Subset(
            train_dataset,
            np.random.choice(len(train_dataset), args.num_calibration_samples)
        )

        # `--calib-batch-size` defaults to None; passing None to DataLoader
        # disables batching altogether (targets arrive as plain ints and the
        # `.to(device)` below fails), so fall back to the training batch size.
        calib_batch_size = args.calib_batch_size
        if calib_batch_size is None:
            calib_batch_size = args.batch_size

        calibration_loader = DataLoader(
            calibration_dataset,
            batch_size=calib_batch_size,
            shuffle=True,
            **loader_kwargs,
        )

        def data_loader_builder(device=device, **kwargs):
            # Single pass over the calibration subset.
            for images, targets in calibration_loader:
                images, targets = images.to(device), targets.to(device)
                yield [images], {}, targets

        manager_kwargs['calibration_sampler'] = {
            'data_loader_builder' : data_loader_builder,
            'loss_fn' : loss_fn,
        }

    for sparsity in args.sparsities:
        print(f'Sparsity {sparsity:.3f}')
        # prune a fresh copy so every sparsity level starts from the dense model
        model_sparse = deepcopy(model)
        # create sparseml manager
        manager = ScheduledModifierManager.from_yaml(args.sparseml_recipe)
        # update manager: override the sparsity of the first modifier in the recipe
        manager.modifiers[0].init_sparsity = sparsity
        manager.modifiers[0].final_sparsity = sparsity
        # apply recipe
        manager.apply(
            model_sparse,
            **manager_kwargs,
            finalize=True
        )
        # evaluate
        val_acc = val_epoch(model_sparse, val_loader, loss_fn, device=device)['acc']
        # update experiment data
        experiment_data['val/acc'].append(val_acc)
        print(f'Test accuracy: {val_acc:.3f}')
        if args.log_wandb:
            wandb.log({'sparsity' : sparsity, 'val/acc': val_acc})

        if args.save_model:
            torch.save(
                model_sparse.state_dict(),
                os.path.join(args.save_dir, f'{args.model}_sparsity={sparsity}.pth')
            )

    with open(f'{args.save_dir}/experiment_data.pkl', 'wb') as fout:
        pickle.dump(experiment_data, fout, protocol=pickle.HIGHEST_PROTOCOL)

    print('Finished!')
