import os
import sys
import math
import random
import logging
import argparse
from copy import deepcopy

import timm
import wandb
import torch
import numpy as np
import torchvision.transforms as T
from torchvision.datasets import INaturalist
from torch.utils.data import DataLoader, Subset
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

# import from parent directory
sys.path.append(os.pardir)
from utils import fix_attention_layer
from utils.model_utils import extract_cls_layer
from utils.subset_utils import get_fewshot_indices, restrict_class_layer
from utils.optim import build_optimizer, build_lr_scheduler

from trainer import Trainer
from inaturalist_subset import iNaturalistSubset

sys.path.append('fsml')
from fsml.compression import create_pruners_from_yaml
from fsml.optim import wrap_optimizer


def main(argv=None):
    """Parse command-line arguments and launch the pruning run.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is the behaviour of the
            script entry point.

    Returns:
        None. The parsed namespace is passed straight to ``run``.
    """
    parser = argparse.ArgumentParser(description="One-shot pruning on an iNaturalist subset.")
    # Data params
    parser.add_argument(
        '--subset_category_type',
        type=str,
        default=None,
        help="INaturalist category used for subset.",
    )
    parser.add_argument(
        '--subset_category_id',
        type=int,
        default=0,
        help="id inside INaturalist category used for subset.",
    )
    parser.add_argument(
        '--data_dir',
        type=str,
        required=True,
        help="Dataset dir.",
    )
    # Few shot params
    parser.add_argument(
        '--samples_per_class',
        default=10,
        type=int
    )
    # Model params
    parser.add_argument(
        '--model',
        type=str,
        required=True,
        help="Model used",
    )
    parser.add_argument(
        '--checkpoint_path',
        default=None,
        type=str,
        help='Path to model checkpoint'
    )
    # Training params
    parser.add_argument(
        '--num_train_epochs',
        default=10,
        type=int
    )
    parser.add_argument(
        '--num_train_steps',
        default=None,
        type=int
    )
    parser.add_argument(
        '--train_batch_size',
        default=128,
        type=int
    )
    parser.add_argument(
        '--val_batch_size',
        default=256,
        type=int
    )
    # Evaluation params
    parser.add_argument(
        '--eval_epochs',
        default=1,
        type=int
    )
    parser.add_argument(
        '--eval_steps',
        default=None,
        type=int
    )
    # Optimizer params
    parser.add_argument(
        '--optimizer',
        default='sgd',
        type=str,
        choices=['sgd', 'adam'],
        help="Optimizer used."
    )
    parser.add_argument(
        '--lr',
        default=0.1,
        type=float,
        help="Learning rate."
    )
    parser.add_argument(
        '--momentum',
        default=0.9,
        type=float,
        help="SGD momentum."
    )
    # The two Adam options previously carried a copy-pasted "SGD momentum."
    # help text.
    parser.add_argument(
        '--adam_beta1',
        default=0.9,
        type=float,
        help="Adam beta1 coefficient."
    )
    parser.add_argument(
        '--adam_beta2',
        default=0.999,
        type=float,
        help="Adam beta2 coefficient."
    )
    parser.add_argument(
        '--weight_decay',
        default=0.0,
        type=float,
        help="Optimizer weight decay."
    )
    parser.add_argument(
        '--lr_scheduler',
        default=None,
        type=str,
        choices=['linear', 'cosine', 'cyclic_linear'],
        help="Learning rate scheduler used."
    )
    parser.add_argument(
        '--cycle_steps',
        default=None,
        type=int
    )
    # Dataloading params
    parser.add_argument(
        '--train_crop_size',
        default=224,
        type=int
    )
    parser.add_argument(
        '--val_resize_size',
        default=256,
        type=int
    )
    parser.add_argument(
        '--val_crop_size',
        default=224,
        type=int
    )
    parser.add_argument(
        '--num_workers',
        default=4,
        type=int
    )
    # Sparsification parameters
    parser.add_argument(
        '--sparsification_config',
        type=str,
        default=None,
        help="Path to sparsification config.",
    )
    parser.add_argument(
        '--sparsity',
        default=None,
        type=float
    )
    # Loss params
    parser.add_argument(
        '--base_loss',
        default='cross_entropy',
        type=str,
        choices=['cross_entropy', 'mse_loss'],
        help="Base loss."
    )
    parser.add_argument(
        '--output_kd_loss',
        default='kl_div',
        type=str,
        choices=['kl_div', 'mse_loss'],
        help="Output KD loss."
    )
    parser.add_argument(
        '--feat_kd_loss',
        default='mse_loss',
        type=str,
        choices=['mse_loss', 'mse_loss_norm'],
        help="Output feat loss."
    )
    # Distillation params
    parser.add_argument(
        "--distillation",
        action='store_true',
    )
    parser.add_argument(
        "--temperature",
        default=1.0,
        type=float
    )
    parser.add_argument(
        '--lambda_base',
        default=1.0,
        type=float,
        help="Weight of the original loss in total loss."
    )
    parser.add_argument(
        '--lambda_kd_output',
        default=1.0,
        type=float,
        help="Weight of the output distillation loss in total loss."
    )
    parser.add_argument(
        '--lambda_kd_feat',
        default=1.0,
        type=float,
        help="Weight of the feature distillation loss in total loss."
    )
    parser.add_argument(
        "--teacher_model",
        default=None,
        type=str,
        help="Path or name of teacher model.",
    )
    parser.add_argument(
        '--feat_names',
        default='',
        type=str,
        help="Regular expression for features used in knowledge distillation."
    )
    # Logging params
    parser.add_argument(
        '--logging_epochs',
        default=1,
        type=int
    )
    parser.add_argument(
        '--logging_steps',
        default=None,
        type=int
    )
    parser.add_argument(
        '--log_wandb',
        default=False,
        action="store_true",
        help="Log to W&B"
    )
    # Misc params
    parser.add_argument(
        '--seed',
        default=0,
        type=int,
        help="random seed."
    )
    parser.add_argument(
        '--output_dir',
        type=str,
        required=True,
        help='Output directory where model checkpoints and results are stored.'
    )
    parser.add_argument(
        '--amp',
        default=False,
        action="store_true",
        help="Whether to use mixed precision"
    )
    parser.add_argument(
        '--eval_before_training',
        default=False,
        action="store_true",
        help="Whether to evaluate model before training."
    )
    # With this flag, run() draws calibration samples from the full
    # iNaturalist train set rather than from the training subset.
    parser.add_argument(
        '--ptc',
        default=False,
        action="store_true",
        help="Whether to calibrate on random samples from the full "
             "iNaturalist train set instead of the training subset."
    )
    parser.add_argument(
        '--calibration_with_labels',
        default=False,
        action="store_true",
        help="Whether to calibrate with labels (for Fisher-based pruners)."
    )
    parser.add_argument(
        '--calibration_batch_size',
        type=int,
        default=None,
        help="Whether to override batch size for calibration."
    )
    parser.add_argument(
        '--freeze_batch_norm',
        default=False,
        action="store_true",
        help="Whether to freeze batch norm."
    )
    args = parser.parse_args(argv)
    run(args)


# Module-level logger; its format and level are set by logging.basicConfig() inside run().
logger = logging.getLogger(__name__)


def fix_seed(seed):
    """Seed every random number generator used by the script.

    Seeds Python's ``random``, NumPy, and torch (the default generator and
    all CUDA devices), then sets the cuDNN ``deterministic`` flag to True
    and the ``benchmark`` flag to False.

    Args:
        seed: Integer seed passed unchanged to each generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_rng in seeders:
        seed_rng(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def run(args):
    """Console script for one-shot pruning.

    Runs the whole pipeline in order: seed the RNGs, build the timm model,
    build the iNaturalist train/val datasets (optionally restricted to a
    category subset and to a few samples per class), derive step counts,
    build the optimizer and LR scheduler, optionally snapshot a teacher
    model, collect calibration batches, create the pruners and finally hand
    everything to ``Trainer.train()``.

    Args:
        args: Namespace produced by the argument parser in ``main``. The
            fields ``num_train_epochs``, ``num_train_steps``, ``eval_steps``
            and ``logging_steps`` are updated in place before the optimizer,
            scheduler and trainer receive the namespace.
    """
    # set seed.
    fix_seed(args.seed)
    # get device
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # init W&B logger
    if args.log_wandb:
        wandb.init(config=args)
    # init logger
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )

    # Model
    model = timm.create_model(
        args.model,
        pretrained=False,
        num_classes=10000,
        checkpoint_path=args.checkpoint_path,
    )
    model = fix_attention_layer(model)
    # put model on device
    model = model.to(device)

    ###########
    # Dataset #
    ###########

    print(f'Pruning on {args.subset_category_type}:{args.subset_category_id} INaturalist subset.')

    # use imagenet normalization
    mean = IMAGENET_DEFAULT_MEAN
    std = IMAGENET_DEFAULT_STD

    # create transform
    transform_train = T.Compose([
        T.RandomResizedCrop(args.train_crop_size, interpolation=T.InterpolationMode.BICUBIC),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize(mean, std),
    ])

    transform_test = T.Compose([
        T.Resize(args.val_resize_size, interpolation=T.InterpolationMode.BICUBIC),
        T.CenterCrop(args.val_crop_size),
        T.ToTensor(),
        T.Normalize(mean, std),
    ])

    train_dataset = INaturalist(
        root=args.data_dir,
        version='2021_train',
        target_type='full',
        transform=transform_train,
    )

    val_dataset = INaturalist(
        root=args.data_dir,
        version='2021_valid',
        target_type='full',
        transform=transform_test,
    )

    # get subsets
    use_subset = args.subset_category_type is not None
    if use_subset:
        train_subset = iNaturalistSubset(train_dataset, args.subset_category_type, args.subset_category_id)
        if args.samples_per_class:
            fewshot_indices = get_fewshot_indices(train_subset, args.samples_per_class)
            train_subset = Subset(train_subset, fewshot_indices)
        val_subset = iNaturalistSubset(val_dataset, args.subset_category_type, args.subset_category_id)
    else:
        train_subset = train_dataset
        val_subset = val_dataset

    train_loader = DataLoader(
        train_subset,
        batch_size=args.train_batch_size,
        num_workers=args.num_workers,
        shuffle=True,
        pin_memory=True,
        drop_last=True
    )

    val_loader = DataLoader(
        val_subset,
        batch_size=args.val_batch_size,
        num_workers=args.num_workers,
        shuffle=False,
        pin_memory=True,
        drop_last=False
    )

    print(f"Dataset size: (train) {len(train_subset)}, (val) {len(val_subset)}")

    # prepare classification head
    # Only the subset wrapper provides `orig_class_ids`; without a subset the
    # validation set is the plain INaturalist dataset, so the head is left
    # untouched instead of failing on the missing attribute.
    if use_subset:
        cls_layer = extract_cls_layer(model)
        restrict_class_layer(cls_layer, val_subset.orig_class_ids)

    # postprocess training args
    steps_per_epoch = len(train_loader)
    if steps_per_epoch == 0:
        # drop_last=True yields no batches when the train set is smaller than
        # the batch size.
        logger.warning(
            "Train loader has no batches: %d samples with batch size %d and drop_last=True.",
            len(train_subset),
            args.train_batch_size,
        )

    if args.num_train_steps:
        # override number of train epochs: epochs needed to cover the steps
        args.num_train_epochs = math.ceil(args.num_train_steps / max(steps_per_epoch, 1))
    elif args.num_train_epochs:
        args.num_train_steps = args.num_train_epochs * steps_per_epoch

    # An explicitly given step interval takes precedence over the epoch-based
    # one, mirroring the handling of --num_train_steps above.
    if args.eval_steps is None and args.eval_epochs:
        args.eval_steps = args.eval_epochs * steps_per_epoch

    if args.logging_steps is None and args.logging_epochs:
        args.logging_steps = args.logging_epochs * steps_per_epoch

    logger.info("-" * 10)
    logger.info(f"Num training steps: {args.num_train_steps}")
    logger.info(f"Eval steps: {args.eval_steps}")
    logger.info(f"Logging steps: {args.logging_steps}")
    logger.info("-" * 10)

    #############
    # Optimizer #
    #############

    optimizer = build_optimizer(model, args)
    lr_scheduler = build_lr_scheduler(optimizer, args)

    ################
    # Distillation #
    ################

    # The teacher is a copy of the model taken before any pruner is created.
    teacher_model = None
    if args.distillation:
        teacher_model = deepcopy(model)

    ##########
    # Pruner #
    ##########

    # prepare calibration data
    calibration_batch_size = args.calibration_batch_size or args.train_batch_size

    if args.ptc:
        # for generic compression calibrate on iNaturalist data:
        # sample indices at random (with replacement) from the full train
        # set, as many as there are samples in the training subset
        calibration_indices = np.random.randint(
            low=0,
            high=len(train_dataset),
            size=len(train_subset)
        )
        calibration_dataset = Subset(train_dataset, calibration_indices)
        calibration_loader = DataLoader(
            calibration_dataset,
            batch_size=calibration_batch_size,
            num_workers=args.num_workers,
            shuffle=False,
            pin_memory=True,
            drop_last=True
        )
    elif calibration_batch_size != args.train_batch_size:
        # same data and settings as the train loader, only the batch size differs
        calibration_loader = DataLoader(
            train_subset,
            batch_size=calibration_batch_size,
            num_workers=args.num_workers,
            shuffle=True,
            pin_memory=True,
            drop_last=True
        )
    else:
        calibration_loader = train_loader

    # Materialize the calibration batches as (positional args, keyword args).
    # NOTE(review): labels are discarded here and --calibration_with_labels is
    # not consulted in this function; confirm whether the pruners need targets.
    calibration_data = [([images], {}) for images, _ in calibration_loader]
    # prepare kwargs
    pruner_kwargs = {'data_loader': calibration_data}

    # init pruners
    pruners = []
    if args.sparsification_config:
        pruners = create_pruners_from_yaml(model, args.sparsification_config, pruner_kwargs)
        # override sparsity in the config, if specified
        if args.sparsity is not None:
            for pruner in pruners:
                # only schedules that expose a `sparsity` attribute are overridden
                if hasattr(pruner.sparsity_schedule, 'sparsity'):
                    pruner.sparsity_schedule.sparsity = args.sparsity
        # wrap optimizer
        optimizer = wrap_optimizer(optimizer, pruners)

    ###########################
    # Finetune after pruning  #
    ###########################

    # prepare dir
    os.makedirs(args.output_dir, exist_ok=True)

    trainer = Trainer(
        model,
        args,
        train_loader,
        val_loader,
        optimizer=optimizer,
        pruners=pruners,
        lr_scheduler=lr_scheduler,
        teacher_model=teacher_model,
        device=device
    )

    trainer.train()

# Script entry point: main() has no return value, so sys.exit() receives None
# and a run that completes exits with status 0.
if __name__ == "__main__":
    sys.exit(main())  # pragma: no cover

