import os
import sys
import timm
import torch
import wandb
import random
import logging
import argparse
import numpy as np
import torch.nn.functional as F

from copy import deepcopy
from collections import defaultdict

from robustness import datasets
from robustness.tools.imagenet_helpers import ImageNetHierarchy
from torch.utils.data import DataLoader, Subset

# import from parent directory
sys.path.append(os.pardir)
from utils import fix_attention_layer
from utils.model_utils import extract_cls_layer
from utils.subset_utils import get_fewshot_indices, restrict_class_layer
from utils.transform_utils import get_transforms
from utils.model_utils import extract_cls_layer

from fsml.compression import create_pruners_from_yaml

from trainer import evaluate


def main():
    """Parse the command-line options and pass them to ``run``.

    The options are declared in a table of ``(flag, add_argument kwargs)``
    pairs and registered in that order, so ``--help`` lists them in the
    order they appear below.
    """
    option_table = [
        # Data params
        ('--data_dir', dict(type=str, required=True, help="Dataset dir.")),
        ('--data_info_dir', dict(type=str, required=True, help="Dataset info dir.")),
        ('--min_classes', dict(default=5, type=int, help='Minimal number of classes used.')),
        # Few shot params
        ('--samples_per_class', dict(default=10, type=int)),
        # Model params
        ('--model', dict(type=str, required=True, help="Model used")),
        ('--pretrained', dict(action='store_true', help="Whether to use pretrained model")),
        ('--checkpoint_path', dict(default=None, type=str, help='Path to model checkpoint')),
        # Training params
        ('--train_batch_size', dict(default=128, type=int)),
        ('--val_batch_size', dict(default=256, type=int)),
        ('--num_workers', dict(default=4, type=int)),
        # Dataloading params
        ('--train_crop_size', dict(default=224, type=int)),
        ('--val_resize_size', dict(default=256, type=int)),
        ('--val_crop_size', dict(default=224, type=int)),
        # Sparsification parameters
        ('--sparsification_config', dict(
            type=str,
            default=None,
            help="Path to sparsification config.",
        )),
        ('--sparsity', dict(default=None, type=float)),
        ('--create_calibration_loader', dict(
            action='store_true',
            help="Whether to create additional calibration loader.",
        )),
        ('--calibration_dataset_size', dict(
            default=None,
            type=int,
            help="Size of calibration dataset.",
        )),
        ('--calibration_subset', dict(
            default=None,
            type=str,
            help='Which subset is used from calibration (default: imagenet_subset).',
        )),
        ('--calibration_subset_choice', dict(
            default='random',
            choices=['random', 'k-shot'],
            type=str,
            help='How to select samples in calibration set.',
        )),
        # Logging params
        ('--log_wandb', dict(default=False, action="store_true", help="Log to W&B")),
        # Misc params
        ('--seed', dict(default=0, type=int, help="random seed.")),
        ('--output_dir', dict(
            type=str,
            required=True,
            help='Output directory where model checkpoints and results are stored.',
        )),
        ('--amp', dict(
            default=False,
            action="store_true",
            help="Whether to use mixed precision",
        )),
        ('--ptc', dict(
            default=False,
            action="store_true",
            help="Whether to calibrate on ImageNet data.",
        )),
    ]

    cli = argparse.ArgumentParser(description="Single-step pruning on ImageNet hierarchy.")
    for flag, options in option_table:
        cli.add_argument(flag, **options)

    run(cli.parse_args())


# Module-level logger named after this module; its format and level are
# configured by the logging.basicConfig call inside run().
logger = logging.getLogger(__name__)


def fix_seed(seed):
    """Seed the Python, NumPy and PyTorch random generators and set the cuDNN flags.

    Args:
        seed: Value handed unchanged to ``random.seed``, ``np.random.seed``,
            ``torch.manual_seed`` and ``torch.cuda.manual_seed_all``.
    """
    # Every library-level generator receives the same seed.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    # Turn cuDNN benchmark mode off and deterministic mode on.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


def run(args):
    """Run one-shot pruning on every selected ImageNet hierarchy subset.

    For each WordNet node with at least ``args.min_classes`` ImageNet
    descendants, the function:

    1. deep-copies the base model and restricts its classification layer to
       the classes of that subset,
    2. evaluates the copy on the subset validation loader (dense accuracy),
    3. builds calibration batches, either from the few-shot subset loader or,
       when ``args.ptc`` is set, from an ImageNet loader limited to the same
       number of samples,
    4. creates the pruners from ``args.sparsification_config``, optionally
       overrides their target sparsity with ``args.sparsity``, and applies a
       single pruning step,
    5. evaluates again (sparse accuracy) and records both results.

    All per-subset results are saved to
    ``<args.output_dir>/hierarchy_accuracies.pth`` and, when
    ``args.log_wandb`` is set, logged to W&B.

    Args:
        args: Namespace produced by the argument parser in ``main``.
    """
    # fix random seed
    fix_seed(args.seed)
    # get device
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # init W&B logger
    if args.log_wandb:
        wandb.init(config=args)
    # init logger
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )

    # Model
    model = timm.create_model(
        args.model,
        pretrained=args.pretrained,
        checkpoint_path=args.checkpoint_path,
    )
    model = fix_attention_layer(model)
    # put model on device
    model = model.to(device)

    ###########
    # Dataset #
    ###########

    # create imagenet hierarchy
    imagenet_hier = ImageNetHierarchy(args.data_dir, args.data_info_dir)
    # create transforms according to pretrained config
    transform_train, transform_test = get_transforms(model, 'timm', None)
    # select only those with at least args.min_classes
    selected_wnids = [
        entry for entry in imagenet_hier.wnid_sorted if entry[1] >= args.min_classes
    ]
    print(f"Selected {len(selected_wnids)} subsets.")

    loss_fn = F.cross_entropy
    # dict with accuracy for dense and sparse models on subsets
    accuracy_data = defaultdict(dict)
    # make directory to save results
    os.makedirs(args.output_dir, exist_ok=True)

    for (wnid, ndesc_in, _) in selected_wnids:
        # every subset starts from a fresh copy of the dense model
        sparse_model = deepcopy(model)

        subset_name = imagenet_hier.wnid_to_name[wnid]
        print(f"Imagenet subset {subset_name}, #ImageNet descendants: {ndesc_in}")
        class_ranges, _ = imagenet_hier.get_subclasses([wnid])
        # one single-element class range per ImageNet class of the subset
        class_ranges = [[cls] for cls in class_ranges[0]]

        dataset = datasets.CustomImageNet(
            args.data_dir,
            class_ranges,
            transform_train=transform_train,
            transform_test=transform_test
        )
        # create loaders
        train_loader, val_loader = dataset.make_loaders(
            workers=args.num_workers,
            batch_size=args.train_batch_size,
            val_batch_size=args.val_batch_size
        )
        train_dataset = train_loader.dataset
        # create few shot dataset
        fewshot_indices = get_fewshot_indices(train_dataset, args.samples_per_class)
        fewshot_dataset = Subset(train_dataset, fewshot_indices)
        fewshot_loader = DataLoader(
            fewshot_dataset,
            batch_size=args.train_batch_size,
            shuffle=True,
            pin_memory=True,
            num_workers=args.num_workers,
            drop_last=False
        )

        logger.info(f"Dataset size: (train) {len(fewshot_dataset)}, (val) {len(val_loader.dataset)}")
        logger.info(f"Dataset number of classes: {dataset.num_classes}")
        # prepare classification head
        cls_layer = extract_cls_layer(sparse_model)
        restrict_class_layer(cls_layer, [class_ids[0] for class_ids in class_ranges])

        # evaluate before pruning
        val_stats_dense = evaluate(sparse_model, val_loader, loss_fn, device, args.amp)
        logger.info(f"Subtask accuracy dense: {val_stats_dense['acc1']:.3f}")

        ##########
        # Pruner #
        ##########

        # prepare calibration data
        if args.ptc:
            # for generic compression calibrate on ImageNet data,
            # using as many samples as the few-shot subset contains
            imagenet_dataset = datasets.CustomImageNet(
                args.data_dir,
                [list(range(1000))],
                transform_train=transform_train,
                transform_test=transform_test
            )
            calibration_loader, _ = imagenet_dataset.make_loaders(
                workers=args.num_workers,
                subset=len(fewshot_dataset),
                batch_size=args.train_batch_size
            )
        else:
            calibration_loader = fewshot_loader

        # Materialize the calibration batches as (positional args, keyword args)
        # pairs; labels are dropped.
        # NOTE(review): batches are not moved to `device` here; presumably the
        # pruner handles device placement - confirm.
        calibration_data = [([images], {}) for images, _ in calibration_loader]

        # prepare kwargs
        pruner_kwargs = {'data_loader': calibration_data}
        # init pruners
        pruners = create_pruners_from_yaml(
            sparse_model,
            args.sparsification_config,
            pruner_kwargs
        )
        for pruner in pruners:
            # Override the sparsity from the config only when one was given on
            # the command line (an explicit 0.0 counts as given) and the
            # schedule actually exposes a `sparsity` attribute.
            if args.sparsity is not None and hasattr(pruner.sparsity_schedule, 'sparsity'):
                pruner.sparsity_schedule.sparsity = args.sparsity
            # Always apply the pruner; without an override the sparsity from
            # the config is used.
            pruner.step()

        # evaluate after pruning
        val_stats_sparse = evaluate(sparse_model, val_loader, loss_fn, device, args.amp)
        logger.info(f"Subtask accuracy sparse (single step): {val_stats_sparse['acc1']:.3f}")

        log_dict = {
            'dense/acc1': val_stats_dense['acc1'],
            'dense/loss': val_stats_dense['loss'],
            'sparse/acc1': val_stats_sparse['acc1'],
            'sparse/loss': val_stats_sparse['loss'],
            'num_classes': ndesc_in
        }

        if args.log_wandb:
            wandb.log({f'{subset_name}/{k}': v for k, v in log_dict.items()})

        accuracy_data[subset_name] = log_dict

    torch.save(accuracy_data, os.path.join(args.output_dir, 'hierarchy_accuracies.pth'))
    logger.info('Done!')


# Script entry point: run the CLI and exit with whatever main() returns.
if __name__ == "__main__":
    raise SystemExit(main())  # pragma: no cover
