import os
import sys
import math
import random
import logging
import argparse
from copy import deepcopy

import timm
import torch
import wandb
import numpy as np
from robustness.datasets import CustomImageNet
from torch.utils.data import DataLoader, Subset

# import from parent directory
sys.path.append(os.pardir)
from utils import fix_attention_layer
from utils.subset_utils import get_fewshot_indices
from utils.optim import build_optimizer, build_lr_scheduler
from utils.transform_utils import get_transforms

from fsml.compression import create_pruners_from_yaml
from fsml.optim import wrap_optimizer

from trainer import Trainer
from transfer_datasets import DATASET_NUM_CLASSES, get_datasets


def main():
    """Parse command-line arguments and launch the pruning / finetuning run.

    Builds the argument parser, parses ``sys.argv`` and forwards the resulting
    namespace to ``run``. Returns ``None``.
    """
    parser = argparse.ArgumentParser(description="Pruning with KD on specific ImageNet subset.")
    # Data params
    parser.add_argument(
        '--dataset',
        type=str,
        required=True,
        help="Dataset for transfer learning.",
    )
    parser.add_argument(
        '--data_dir',
        type=str,
        required=True,
        help="Dataset dir.",
    )
    parser.add_argument(
        '--imagenet_data_dir',
        type=str,
        default=None,
        help="ImageNet Dataset dir (required for PTC).",
    )
    parser.add_argument(
        '--download',
        action='store_true',
        help="Whether to download the dataset.",
    )
    parser.add_argument(
        '--augment',
        action='store_true',
        help="Whether to apply data augmentation.",
    )
    # Few shot params
    parser.add_argument(
        '--samples_per_class',
        default=None,
        type=int,
        help="Training samples per class for the few-shot subset (default: use the full train set).",
    )
    # Model params
    parser.add_argument(
        '--model',
        type=str,
        required=True,
        help="Model used",
    )
    parser.add_argument(
        '--checkpoint_path',
        default=None,
        type=str,
        help='Path to model checkpoint',
    )
    parser.add_argument(
        '--pretrained',
        action='store_true',
        help="Whether to use pretrained model",
    )
    # Training params
    parser.add_argument(
        '--num_train_epochs',
        default=10,
        type=int,
        help="Number of training epochs.",
    )
    parser.add_argument(
        '--num_train_steps',
        default=None,
        type=int,
        help="Number of training steps. When set, takes precedence over --num_train_epochs.",
    )
    parser.add_argument(
        '--train_batch_size',
        default=128,
        type=int,
        help="Batch size used both for training and calibration",
    )
    parser.add_argument(
        '--val_batch_size',
        default=256,
        type=int,
        help="Batch size used for evaluation.",
    )
    # Evaluation params
    parser.add_argument(
        '--eval_epochs',
        default=1,
        type=int,
    )
    parser.add_argument(
        '--eval_steps',
        default=None,
        type=int,
    )
    # Optimizer params
    parser.add_argument(
        '--optimizer',
        default='sgd',
        type=str,
        choices=['sgd', 'adam'],
        help="Optimizer used.",
    )
    parser.add_argument(
        '--lr',
        default=0.1,
        type=float,
        help="Learning rate.",
    )
    parser.add_argument(
        '--momentum',
        default=0.9,
        type=float,
        help="SGD momentum.",
    )
    # The two Adam options previously carried a copy-pasted "SGD momentum." help text.
    parser.add_argument(
        '--adam_beta1',
        default=0.9,
        type=float,
        help="Adam beta1 coefficient.",
    )
    parser.add_argument(
        '--adam_beta2',
        default=0.999,
        type=float,
        help="Adam beta2 coefficient.",
    )
    parser.add_argument(
        '--weight_decay',
        default=0.0,
        type=float,
        help="Optimizer weight decay.",
    )
    parser.add_argument(
        '--lr_scheduler',
        default=None,
        type=str,
        choices=['linear', 'cosine', 'cyclic_linear'],
        help="Learning rate scheduler used.",
    )
    parser.add_argument(
        '--cycle_steps',
        default=None,
        type=int,
    )
    # Dataloading params
    parser.add_argument(
        '--train_crop_size',
        default=224,
        type=int,
    )
    parser.add_argument(
        '--val_resize_size',
        default=256,
        type=int,
    )
    parser.add_argument(
        '--val_crop_size',
        default=224,
        type=int,
    )
    parser.add_argument(
        '--num_workers',
        default=4,
        type=int,
    )
    # Sparsification parameters
    parser.add_argument(
        '--sparsification_config',
        type=str,
        default=None,
        help="Path to sparsification config.",
    )
    parser.add_argument(
        '--sparsity',
        default=None,
        type=float,
    )
    # Loss params
    parser.add_argument(
        '--base_loss',
        default='cross_entropy',
        type=str,
        choices=['cross_entropy', 'mse_loss'],
        help="Base loss.",
    )
    parser.add_argument(
        '--output_kd_loss',
        default='kl_div',
        type=str,
        choices=['kl_div', 'mse_loss'],
        help="Output KD loss.",
    )
    parser.add_argument(
        '--feat_kd_loss',
        default='mse_loss',
        type=str,
        choices=['mse_loss', 'mse_loss_norm'],
        help="Feature KD loss.",
    )
    # Distillation params
    parser.add_argument(
        "--distillation",
        action='store_true',
    )
    parser.add_argument(
        "--temperature",
        default=1.0,
        type=float,
    )
    parser.add_argument(
        '--lambda_base',
        default=1.0,
        type=float,
        help="Weight of the original loss in total loss.",
    )
    parser.add_argument(
        '--lambda_kd_output',
        default=1.0,
        type=float,
        help="Weight of the output distillation loss in total loss.",
    )
    parser.add_argument(
        '--lambda_kd_feat',
        default=1.0,
        type=float,
        help="Weight of the feature distillation loss in total loss.",
    )
    parser.add_argument(
        "--teacher_model",
        default=None,
        type=str,
        help="Path or name of teacher model.",
    )
    parser.add_argument(
        '--feat_names',
        default='',
        type=str,
        help="Regular expression for features used in knowledge distillation.",
    )
    # Logging params
    parser.add_argument(
        '--logging_epochs',
        default=1,
        type=int,
    )
    parser.add_argument(
        '--logging_steps',
        default=None,
        type=int,
    )
    parser.add_argument(
        '--log_wandb',
        default=False,
        action="store_true",
        help="Log to W&B",
    )
    # Misc params
    parser.add_argument(
        '--seed',
        default=0,
        type=int,
        help="random seed.",
    )
    parser.add_argument(
        '--output_dir',
        type=str,
        required=True,
        help='Output directory where model checkpoints and results are stored.',
    )
    parser.add_argument(
        '--amp',
        default=False,
        action="store_true",
        help="Whether to use mixed precision",
    )
    parser.add_argument(
        '--eval_before_training',
        default=False,
        action="store_true",
        help="Whether to evaluate model before training.",
    )
    parser.add_argument(
        '--ptc',
        default=False,
        action="store_true",
        help="Whether to calibrate on ImageNet data.",
    )
    parser.add_argument(
        '--freeze_batch_norm',
        default=False,
        action="store_true",
        help="Whether to freeze batch norm (for convnets).",
    )
    args = parser.parse_args()
    run(args)


# Module-level logger; its format and level are configured in run() via logging.basicConfig.
logger = logging.getLogger(__name__)


def fix_seed(seed):
    """Seed Python, NumPy and PyTorch RNGs and set the cuDNN determinism flags.

    Args:
        seed: Integer seed passed unchanged to every seeding function.
    """
    # Same seeding order as before: Python, NumPy, torch (CPU), torch (all CUDA devices).
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Turn on cuDNN's deterministic flag and turn off its benchmark mode.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def run(args):
    """Build the model, data and pruners, then train with ``Trainer``.

    Steps, in order: seed RNGs, create the timm model (optionally loading a
    checkpoint), build train/val loaders (optionally a few-shot subset),
    derive step counts, build optimizer and LR scheduler, optionally copy the
    model as a distillation teacher, collect calibration batches, create
    pruners from the YAML config and hand everything to ``Trainer.train()``.

    Args:
        args: Namespace produced by the parser in ``main``. It is mutated in
            place: ``num_train_epochs`` / ``num_train_steps``, ``eval_steps``
            and ``logging_steps`` are filled in from the loader length.

    Raises:
        ValueError: If ``--ptc`` is set without ``--imagenet_data_dir``.
    """
    # fix random seed
    fix_seed(args.seed)
    # get device
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # init W&B logger
    if args.log_wandb:
        wandb.init(config=args)
    # init logger
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )

    #########
    # Model #
    #########

    num_classes = DATASET_NUM_CLASSES[args.dataset]

    model = timm.create_model(
        args.model,
        pretrained=args.pretrained,
        num_classes=num_classes
    )

    if args.checkpoint_path:
        # Load onto CPU so checkpoints saved from a GPU also work on CPU-only
        # hosts; the model is moved to the target device further below.
        state_dict = torch.load(args.checkpoint_path, map_location="cpu")
        # Checkpoints may wrap the weights under "model_state_dict".
        if state_dict.get("model_state_dict"):
            state_dict = state_dict["model_state_dict"]
        model.load_state_dict(state_dict)

    # transform attention layers to make them prunable
    model = fix_attention_layer(model)
    # put model on device
    model = model.to(device)

    ###########
    # Dataset #
    ###########

    # infer transform params
    if hasattr(model, 'pretrained_cfg'):
        crop_size = model.pretrained_cfg['input_size'][-1]
        resize_size = round(crop_size / model.pretrained_cfg['crop_pct'])
    else:
        crop_size = 224
        resize_size = 256

    # create transforms according to pretrained config
    # (only consumed by the ImageNet calibration dataset in the PTC branch)
    transform_train, transform_test = get_transforms(model, 'timm')

    # create loaders
    train_dataset, val_dataset = get_datasets(
        dataset=args.dataset,
        data_dir=args.data_dir,
        train_crop_size=crop_size,
        test_resize_size=resize_size,
        test_crop_size=crop_size,
        download=args.download,
        augment=args.augment
    )
    # create few shot dataset
    if args.samples_per_class:
        fewshot_indices = get_fewshot_indices(train_dataset, args.samples_per_class)
        fewshot_dataset = Subset(train_dataset, fewshot_indices)
    else:
        fewshot_dataset = train_dataset

    fewshot_loader = DataLoader(
        fewshot_dataset,
        batch_size=args.train_batch_size,
        shuffle=True,
        pin_memory=True,
        num_workers=args.num_workers,
        drop_last=True
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=args.val_batch_size,
        shuffle=False,
        pin_memory=True,
        num_workers=args.num_workers,
        drop_last=False
    )

    logger.info(f"Dataset size: (train) {len(fewshot_dataset)}, (val) {len(val_dataset)}")

    # postprocess training args
    steps_per_epoch = len(fewshot_loader)
    if steps_per_epoch == 0:
        # drop_last=True yields no batches when the dataset is smaller than one batch
        logger.warning(
            "Train loader has no batches: dataset size %d is smaller than "
            "train_batch_size %d and drop_last=True.",
            len(fewshot_dataset), args.train_batch_size,
        )

    if args.num_train_steps:
        # override number of train epochs: epochs = ceil(steps / steps per epoch)
        # (guarded so an empty loader does not cause a division by zero)
        args.num_train_epochs = (
            math.ceil(args.num_train_steps / steps_per_epoch) if steps_per_epoch else 0
        )
    elif args.num_train_epochs:
        args.num_train_steps = args.num_train_epochs * steps_per_epoch

    # An explicit --eval_steps / --logging_steps wins; otherwise derive the
    # step interval from the epoch-based setting (whose default is 1).
    if not args.eval_steps and args.eval_epochs:
        args.eval_steps = args.eval_epochs * steps_per_epoch

    if not args.logging_steps and args.logging_epochs:
        args.logging_steps = args.logging_epochs * steps_per_epoch

    logger.info("-" * 10)
    logger.info(f"Num training steps: {args.num_train_steps}")
    logger.info(f"Eval steps: {args.eval_steps}")
    logger.info(f"Logging steps: {args.logging_steps}")
    logger.info("-" * 10)

    #############
    # Optimizer #
    #############

    optimizer = build_optimizer(model, args)
    lr_scheduler = build_lr_scheduler(optimizer, args)

    ################
    # Distillation #
    ################

    # The teacher is a copy of the model taken before any pruner is created.
    teacher_model = None
    if args.distillation:
        teacher_model = deepcopy(model)

    ##########
    # Pruner #
    ##########

    # prepare calibration data

    # for generic compression calibrate on ImageNet data
    if args.ptc:
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if args.imagenet_data_dir is None:
            raise ValueError("--imagenet_data_dir is required when --ptc is set.")
        imagenet_dataset = CustomImageNet(
            args.imagenet_data_dir,
            [list(range(1000))],
            transform_train=transform_train,
            transform_test=transform_test
        )
        calibration_loader, _ = imagenet_dataset.make_loaders(
            workers=args.num_workers,
            subset=len(fewshot_dataset),
            batch_size=args.train_batch_size
        )
    else:
        calibration_loader = fewshot_loader

    # Materialize every calibration batch as an (args, kwargs) pair; labels are dropped.
    calibration_data = [([images], {}) for images, _ in calibration_loader]
    # prepare kwargs
    pruner_kwargs = {'data_loader': calibration_data}

    # init pruners
    pruners = []
    if args.sparsification_config:
        pruners = create_pruners_from_yaml(model, args.sparsification_config, pruner_kwargs)
        # override sparsity in the config, if specified
        if args.sparsity:
            for pruner in pruners:
                # The default keeps schedules without a `sparsity` attribute
                # from raising AttributeError; they are simply left untouched.
                if getattr(pruner.sparsity_schedule, 'sparsity', None):
                    pruner.sparsity_schedule.sparsity = args.sparsity
        # wrap optimizer
        optimizer = wrap_optimizer(optimizer, pruners)

    ###########################
    # Finetune after pruning  #
    ###########################

    # prepare dir
    os.makedirs(args.output_dir, exist_ok=True)

    trainer = Trainer(
        model,
        args,
        fewshot_loader,
        val_loader,
        optimizer=optimizer,
        pruners=pruners,
        lr_scheduler=lr_scheduler,
        teacher_model=teacher_model,
        device=device
    )

    trainer.train()

    


if __name__ == "__main__":
    # main() has no return statement, so sys.exit receives None (exit status 0).
    sys.exit(main())  # pragma: no cover
