import os
import sys
import math
import random
import logging
import argparse
from copy import deepcopy

import timm
import torch
import wandb
import open_clip
import numpy as np
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from robustness import datasets
from robustness.tools.imagenet_helpers import ImageNetHierarchy

# import from parent directory
sys.path.append(os.pardir)
from utils import fix_attention_layer
from utils.model_utils import extract_cls_layer
from utils.subset_utils import get_fewshot_indices, restrict_class_layer
from utils.optim import build_optimizer, build_lr_scheduler
from utils.transform_utils import get_transforms

from fsml.compression import create_pruners_from_yaml
from fsml.optim import wrap_optimizer

from trainer import Trainer


def main():
    """Parse command-line arguments and launch few-shot pruning via ``run``."""
    parser = argparse.ArgumentParser(description="Pruning with KD on specific ImageNet subset.")
    # Data params
    parser.add_argument(
        '--data_dir',
        type=str,
        required=True,
        help="Dataset dir.",
    )
    parser.add_argument(
        '--data_info_dir',
        type=str,
        required=True,
        help="Dataset info dir.",
    )
    parser.add_argument(
        '--imagenet_subset',
        type=str,
        default="n04524313",
        help="ImageNet subset used (WordNet id of the ancestor synset).",
    )
    # Few shot params
    parser.add_argument(
        '--samples_per_class',
        default=10,
        type=int,
        help="Number of training samples kept per class.",
    )
    # Model params
    parser.add_argument(
        '--model_source',
        type=str,
        default='timm',
        choices=['timm', 'open_clip'],
        help="Model used",
    )
    parser.add_argument(
        '--model',
        type=str,
        required=True,
        help="Model used",
    )
    parser.add_argument(
        '--clip_pretrain',
        type=str,
        default=None,
        help="Pretrained weights tag passed to open_clip (only used with --model_source open_clip).",
    )
    parser.add_argument(
        '--checkpoint_path',
        default=None,
        type=str,
        help='Path to model checkpoint'
    )
    parser.add_argument(
        '--pretrained',
        action='store_true',
        help="Whether to use pretrained model",
    )
    # Training params
    parser.add_argument(
        '--num_train_epochs',
        default=10,
        type=int
    )
    parser.add_argument(
        '--num_train_steps',
        default=None,
        type=int,
        help="If set, overrides --num_train_epochs.",
    )
    parser.add_argument(
        '--train_batch_size',
        default=128,
        type=int,
        help="Batch size used both for training and calibration"
    )
    parser.add_argument(
        '--val_batch_size',
        default=256,
        type=int,
        help="Batch size used for evaluation."
    )
    parser.add_argument(
        '--calibration_batch_size',
        default=None,
        type=int,
        help="Overrides the batch size of the calibration loader (defaults to --train_batch_size)."
    )
    # Evaluation params
    parser.add_argument(
        '--eval_epochs',
        default=1,
        type=int
    )
    parser.add_argument(
        '--eval_steps',
        default=None,
        type=int
    )
    # Optimizer params
    parser.add_argument(
        '--optimizer',
        default='sgd',
        type=str,
        choices=['sgd', 'adam'],
        help="Optimizer used."
    )
    parser.add_argument(
        '--lr',
        default=0.1,
        type=float,
        help="Learning rate."
    )
    parser.add_argument(
        '--momentum',
        default=0.9,
        type=float,
        help="SGD momentum."
    )
    parser.add_argument(
        '--adam_beta1',
        default=0.9,
        type=float,
        help="Adam beta1."
    )
    parser.add_argument(
        '--adam_beta2',
        default=0.999,
        type=float,
        help="Adam beta2."
    )
    parser.add_argument(
        '--weight_decay',
        default=0.0,
        type=float,
        help="Optimizer weight decay."
    )
    parser.add_argument(
        '--lr_scheduler',
        default=None,
        type=str,
        choices=['linear', 'cosine', 'cyclic_linear'],
        help="Learning rate scheduler used."
    )
    parser.add_argument(
        '--cycle_steps',
        default=None,
        type=int
    )
    # Dataloading params
    parser.add_argument(
        '--train_crop_size',
        default=224,
        type=int
    )
    parser.add_argument(
        '--val_resize_size',
        default=256,
        type=int
    )
    parser.add_argument(
        '--val_crop_size',
        default=224,
        type=int
    )
    parser.add_argument(
        '--num_workers',
        default=4,
        type=int
    )
    # Sparsification parameters
    parser.add_argument(
        '--sparsification_config',
        type=str,
        default=None,
        help="Path to sparsification config.",
    )
    parser.add_argument(
        '--sparsity',
        default=None,
        type=float,
        help="If set, overrides the sparsity in the sparsification config.",
    )
    # Loss params
    parser.add_argument(
        '--base_loss',
        default='cross_entropy',
        type=str,
        choices=['cross_entropy', 'mse_loss'],
        help="Base loss."
    )
    parser.add_argument(
        '--output_kd_loss',
        default='kl_div',
        type=str,
        choices=['kl_div', 'mse_loss'],
        help="Output KD loss."
    )
    parser.add_argument(
        '--feat_kd_loss',
        default='mse_loss',
        type=str,
        choices=['mse_loss', 'mse_loss_norm'],
        help="Feature KD loss."
    )
    # Distillation params
    parser.add_argument(
        "--distillation",
        action='store_true',
        help="Whether to use knowledge distillation.",
    )
    parser.add_argument(
        "--temperature",
        default=1.0,
        type=float,
        help="Temperature used for distillation.",
    )
    parser.add_argument(
        '--lambda_base',
        default=1.0,
        type=float,
        help="Weight of the original loss in total loss."
    )
    parser.add_argument(
        '--lambda_kd_output',
        default=1.0,
        type=float,
        help="Weight of the output distillation loss in total loss."
    )
    parser.add_argument(
        '--lambda_kd_feat',
        default=1.0,
        type=float,
        help="Weight of the feature distillation loss in total loss."
    )
    # NOTE(review): run() in this file does not read --teacher_model; the
    # teacher there is a deepcopy of the student model.
    parser.add_argument(
        "--teacher_model",
        default=None,
        type=str,
        help="Path or name of teacher model.",
    )
    parser.add_argument(
        '--feat_names',
        default='',
        type=str,
        help="Regular expression for features used in knowledge distillation."
    )
    # Logging params
    parser.add_argument(
        '--logging_epochs',
        default=1,
        type=int
    )
    parser.add_argument(
        '--logging_steps',
        default=None,
        type=int
    )
    parser.add_argument(
        '--log_wandb',
        default=False,
        action="store_true",
        help="Log to W&B"
    )
    # Misc params
    parser.add_argument(
        '--seed',
        default=0,
        type=int,
        help="random seed."
    )
    parser.add_argument(
        '--output_dir',
        type=str,
        required=True,
        help='Output directory where model checkpoints and results are stored.'
    )
    parser.add_argument(
        '--amp',
        default=False,
        action="store_true",
        help="Whether to use mixed precision"
    )
    parser.add_argument(
        '--eval_before_training',
        default=False,
        action="store_true",
        help="Whether to evaluate model before training."
    )
    parser.add_argument(
        '--ptc',
        default=False,
        action="store_true",
        help="Whether to calibrate on ImageNet data."
    )
    parser.add_argument(
        '--freeze_batch_norm',
        default=False,
        action="store_true",
        help="Whether to freeze batch norm (for convnets)."
    )
    parser.add_argument(
        '--calibrate_with_targets',
        default=False,
        action="store_true",
        help="Whether to include targets in calibration loader (for Fisher pruners)."
    )

    args = parser.parse_args()
    run(args)


# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)


def fix_seed(seed):
    """Make runs reproducible by seeding every RNG used by the script.

    Seeds Python's ``random``, NumPy, PyTorch's CPU generator and the
    generators of all CUDA devices. It also switches cuDNN to deterministic
    mode with autotuning disabled.

    Args:
        seed: integer seed applied to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def run(args):
    """Console script for few-shot pruning (optionally with distillation) and finetuning.

    Steps:

    * Build the model (timm or open_clip) and restrict its classification head
      to the classes of one ImageNet subset.
    * Build a few-shot training loader and prepare calibration data plus
      pruners.
    * Finetune with ``Trainer``.

    Args:
        args: namespace produced by ``main()``. Several fields are filled in
            or overwritten in place before being passed on to the optimizer,
            scheduler and trainer: ``num_train_epochs``, ``num_train_steps``,
            ``eval_steps``, ``logging_steps`` and ``calibration_batch_size``.

    Raises:
        ValueError: if ``args.imagenet_subset`` cannot be resolved, if
            ``args.model_source`` is unknown, or if the few-shot loader would
            produce no training batches.
    """
    # fix random seed
    fix_seed(args.seed)
    # get device
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # init W&B logger
    if args.log_wandb:
        wandb.init(config=args)
    # init logger
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )

    #########
    # Model #
    #########

    # Only open_clip returns an eval transform here. Either way, the value is
    # forwarded to get_transforms below.
    transform_test = None
    if args.model_source == 'timm':
        model = timm.create_model(
            args.model,
            pretrained=args.pretrained,
            checkpoint_path=args.checkpoint_path,
        )
    elif args.model_source == 'open_clip':
        # Read defensively: a namespace without `clip_pretrain` falls back to
        # None, which is open_clip's own default for `pretrained`.
        model, _, transform_test = open_clip.create_model_and_transforms(
            args.model, getattr(args, 'clip_pretrain', None)
        )
    else:
        raise ValueError(f"Unknown model source: {args.model_source}")

    # transform attention layers to make them prunable
    model = fix_attention_layer(model)
    # put model on device
    model = model.to(device)

    ###########
    # Dataset #
    ###########

    # create imagenet hierarchy
    imagenet_hier = ImageNetHierarchy(args.data_dir, args.data_info_dir)
    # Only WordNet ids starting with 'n0' are accepted as subsets.
    if not args.imagenet_subset.startswith('n0'):
        raise ValueError(f"Unknown subset: {args.imagenet_subset}")
    try:
        class_ranges, _ = imagenet_hier.get_subclasses([args.imagenet_subset])
        # one single-class range per descendant class -> one label per class
        class_ranges = [[cls] for cls in class_ranges[0]]
        subset_name = imagenet_hier.wnid_to_name[args.imagenet_subset]
    except Exception as err:
        # Narrower than the former bare `except:` (no longer converts
        # KeyboardInterrupt/SystemExit) and keeps the original cause.
        raise ValueError(f"Unknown subset: {args.imagenet_subset}") from err
    logger.info(f'Few-shot compression on {subset_name} ImageNet subset.')

    # create transforms according to pretrained config
    transform_train, transform_test = get_transforms(model, args.model_source, transform_test)

    dataset = datasets.CustomImageNet(
        args.data_dir,
        class_ranges,
        transform_train=transform_train,
        transform_test=transform_test
    )

    # create loaders
    train_loader, val_loader = dataset.make_loaders(
        workers=args.num_workers, batch_size=args.train_batch_size, val_batch_size=args.val_batch_size
    )
    # create few shot dataset
    train_dataset = train_loader.dataset
    fewshot_indices = get_fewshot_indices(train_dataset, args.samples_per_class)
    fewshot_dataset = Subset(train_dataset, fewshot_indices)
    fewshot_loader = DataLoader(
        fewshot_dataset,
        batch_size=args.train_batch_size,
        shuffle=True,
        pin_memory=True,
        num_workers=args.num_workers,
        drop_last=True
    )

    logger.info(f"Dataset size: (train) {len(fewshot_dataset)}, (val) {len(val_loader.dataset)}")
    logger.info(f"Dataset number of classes: {dataset.num_classes}")

    # With drop_last=True, a few-shot set smaller than one batch yields zero
    # batches. Training would silently do nothing, so fail loudly instead.
    steps_per_epoch = len(fewshot_loader)
    if steps_per_epoch == 0:
        raise ValueError(
            f"Few-shot dataset has {len(fewshot_dataset)} samples, fewer than "
            f"train_batch_size={args.train_batch_size}; no training batches "
            "would be produced. Lower --train_batch_size or raise --samples_per_class."
        )

    # prepare classification head
    cls_layer = extract_cls_layer(model)
    restrict_class_layer(cls_layer, [class_ids[0] for class_ids in class_ranges])

    # postprocess training args
    if args.num_train_steps:
        # steps take precedence: number of epochs needed to cover them
        args.num_train_epochs = math.ceil(args.num_train_steps / steps_per_epoch)
    elif args.num_train_epochs:
        args.num_train_steps = args.num_train_epochs * steps_per_epoch

    # As for training, an explicitly given step count takes precedence over
    # the epoch-based one.
    if args.eval_steps is None and args.eval_epochs:
        args.eval_steps = args.eval_epochs * steps_per_epoch

    if args.logging_steps is None and args.logging_epochs:
        args.logging_steps = args.logging_epochs * steps_per_epoch

    logger.info("-" * 10)
    logger.info(f"Num training steps: {args.num_train_steps}")
    logger.info(f"Eval steps: {args.eval_steps}")
    logger.info(f"Logging steps: {args.logging_steps}")
    logger.info("-" * 10)

    #############
    # Optimizer #
    #############

    optimizer = build_optimizer(model, args)
    lr_scheduler = build_lr_scheduler(optimizer, args)

    ################
    # Distillation #
    ################

    # The teacher is a copy of the model taken after the head restriction
    # and before any pruner is created. NOTE(review): `args.teacher_model` is
    # not consulted here.
    teacher_model = None
    if args.distillation:
        teacher_model = deepcopy(model)

    ##########
    # Pruner #
    ##########

    # prepare calibration data
    args.calibration_batch_size = args.calibration_batch_size or args.train_batch_size

    # for generic compression calibrate on ImageNet data
    if args.ptc:
        # A single range covering all 1000 ImageNet classes. NOTE(review):
        # presumably this maps every class to one label, so targets from this
        # loader may not match the restricted head (matters with
        # --calibrate_with_targets). Verify.
        imagenet_dataset = datasets.CustomImageNet(
            args.data_dir,
            [list(range(1000))],
            transform_train=transform_train,
            transform_test=transform_test
        )
        calibration_loader, _ = imagenet_dataset.make_loaders(
            workers=args.num_workers,
            subset=len(fewshot_dataset),
            batch_size=args.calibration_batch_size,
            shuffle_val=False
        )
    # in case calibration loader has different batch_size
    else:
        calibration_loader = DataLoader(
            fewshot_dataset,
            batch_size=args.calibration_batch_size,
            shuffle=False,
            pin_memory=True,
            num_workers=args.num_workers,
            drop_last=False
        )

    # Materialize every calibration batch in memory as (args, kwargs[, targets]).
    calibration_data = []
    for images, targets in calibration_loader:
        # include target in calibration data
        if args.calibrate_with_targets:
            calibration_data.append(([images], {}, targets))
        else:
            calibration_data.append(([images], {}))
    # prepare kwargs
    pruner_kwargs = {'data_loader': calibration_data, 'loss_fn': F.cross_entropy}

    # init pruners
    pruners = []
    if args.sparsification_config:
        pruners = create_pruners_from_yaml(model, args.sparsification_config, pruner_kwargs)
        # Override the config's sparsity if one was given on the command line
        # (0.0 counts as given), but only for schedules that have the attribute.
        if args.sparsity is not None:
            for pruner in pruners:
                if hasattr(pruner.sparsity_schedule, 'sparsity'):
                    pruner.sparsity_schedule.sparsity = args.sparsity
        # wrap optimizer
        optimizer = wrap_optimizer(optimizer, pruners)

    ###########################
    # Finetune after pruning  #
    ###########################

    # prepare dir
    os.makedirs(args.output_dir, exist_ok=True)

    trainer = Trainer(
        model,
        args,
        fewshot_loader,
        val_loader,
        optimizer=optimizer,
        pruners=pruners,
        lr_scheduler=lr_scheduler,
        teacher_model=teacher_model,
        device=device
    )

    trainer.train()


if __name__ == "__main__":  # pragma: no cover
    # Propagate main()'s return value (None -> exit status 0) to the shell.
    exit_status = main()
    sys.exit(exit_status)