from __future__ import print_function
from __future__ import absolute_import

import os
import csv
import copy
import time
import data
import torch
import models
import logging
import importlib
import numpy as np
import torch.nn as nn
from pathlib import Path
from args import parse_args
from utils.s2ap import TradesS2AP  
from utils.utils import extract_masks
from utils.hw import hw_loss, hw_flops_loss
from torch.utils.tensorboard import SummaryWriter
from utils.schedules import get_lr_policy, get_optimizer

from utils.logging import (
    parse_prune_stg,
    save_checkpoint,
    create_subdirs,
    clone_results_to_latest_subdir,
    parse_configs_file,
)
from utils.model import (
    get_layers,
    prepare_model,
    initialize_scaled_score,
    initialize_stg_rate,
    display_loadrate,
    show_gradients,
    current_model_pruned_fraction,
    sanity_check_paramter_updates,
    map_shortcut_rate
)
# Parse the command line once at import time. `args` is a module-level global
# that main() reads directly instead of taking it as a parameter.
args = parse_args()

# Pass the parsed namespace to the config-file parser from utils.logging.
# NOTE(review): presumably this overlays values from a config file onto
# `args` in place -- confirm against utils.logging.parse_configs_file.
parse_configs_file(args)

def _remap_checkpoint_keys(checkpoint_dict, dataset, gpu):
    """Filter and rename checkpoint keys so they match the current model.

    Entries whose key contains ``popup_scores`` or ``sub_block`` are dropped.
    The remaining keys are renamed depending on the dataset and on whether
    ``gpu`` lists several devices (i.e. contains a comma):

    * imagenet on a single GPU: every ``"module."`` occurrence is removed;
    * non-imagenet on multiple GPUs: ``"module."`` is prepended;
    * otherwise: every ``"module.basic_model."`` occurrence is removed.

    Args:
        checkpoint_dict: mapping of parameter name -> value from a checkpoint.
        dataset: dataset name (compared against ``'imagenet'``).
        gpu: the raw ``--gpu`` string, e.g. ``"0"`` or ``"0,1"``.

    Returns:
        A new dict with the filtered, renamed keys.
    """
    multi_gpu = ',' in gpu
    kept = {
        key: value for key, value in checkpoint_dict.items()
        if 'popup_scores' not in key and 'sub_block' not in key
    }
    if dataset == 'imagenet' and not multi_gpu:
        return {key.replace("module.", ""): value for key, value in kept.items()}
    if dataset != 'imagenet' and multi_gpu:
        return {f"module.{key}": value for key, value in kept.items()}
    return {key.replace("module.basic_model.", ""): value for key, value in kept.items()}


def main():
    """Run one pretrain / prune / fine-tune experiment end to end.

    Reads the module-level ``args`` namespace and then:

    1. creates a numbered result sub-directory, the log file and the
       tensorboard writer;
    2. seeds the RNGs, selects the device, builds the data loaders and model;
    3. loads ``args.source_net`` (weights only) or ``args.resume`` (weights,
       optimizer, epoch and best accuracy);
    4. initialises scores / rates, reports the starting prune rate and
       optionally evaluates the loaded model (returning early when
       ``args.evaluate`` is set);
    5. runs the train / validate / checkpoint loop for
       ``args.epochs + args.warmup_epochs`` epochs.

    Raises:
        AssertionError: if a prune/fine-tune mode is requested without
            ``args.source_net`` or ``args.resume``.
        SystemExit: if a pretrain ``--resume`` checkpoint is missing, or after
            dumping the mask checkpoint when ``args.save_mask`` is set.
    """
    # sanity checks
    if args.exp_mode in ["score_prune", "score_finetune", "rate_prune", "harp_prune", "harp_finetune"] and not args.resume:
        assert args.source_net, "Provide checkpoint to prune/finetune"

    # create results dir (for logs, checkpoints, etc.)
    result_main_dir = os.path.join(Path(args.result_dir), args.exp_name, args.exp_mode)

    if os.path.exists(result_main_dir):
        # number of existing sub-directories == previous experiments with same name
        run_index = len(next(os.walk(result_main_dir))[-2]) + 1
    else:
        os.makedirs(result_main_dir, exist_ok=True)
        run_index = 1
    result_sub_dir = os.path.join(
        result_main_dir,
        "{}--k-{:.2f}_trainer-{}_lr-{}_epochs-{}_warmuplr-{}_warmupepochs-{}".format(
            run_index,
            args.k,
            args.trainer,
            args.lr,
            args.epochs,
            args.warmup_lr,
            args.warmup_epochs,
        ),
    )
    create_subdirs(result_sub_dir)

    # create csv for eigenvalues (header only; rows are not written here)
    # NOTE(review): the file lives in result_main_dir, so it is shared by and
    # truncated on every run of the same experiment -- confirm this is intended.
    if args.eigenvalues:
        eigen_csv_path = os.path.join(result_main_dir, "eigenvalues.csv")
        with open(eigen_csv_path, mode='w', newline='') as csv_file:
            eigen_writer = csv.writer(csv_file)
            eigen_writer.writerow(["epoch", "iteration", "top_eigenvalue"])  # Customize headers

    # pass strategy if sparsity is non uniform
    if args.exp_mode in ["rate_prune", "harp_prune"]:
        parse_prune_stg(args)

    # add logger (console via basicConfig + setup.log in the run directory)
    logging.basicConfig(level=logging.INFO, format="%(message)s")
    logger = logging.getLogger()
    logger.addHandler(
        logging.FileHandler(os.path.join(result_sub_dir, "setup.log"), "a")
    )
    logger.info(args)

    # seed torch (CPU + all CUDA devices) and numpy
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)

    # select GPUs
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    num_gpus = len(args.gpu.strip().split(","))
    gpu_list = list(range(num_gpus))
    multi_gpu = num_gpus > 1
    if not use_cuda:
        device = torch.device("cpu")
    elif multi_gpu:
        # "cuda:0,1" is not a valid device string. With DataParallel the
        # module has to live on device_ids[0], so use the first listed id.
        device = torch.device(f"cuda:{gpu_list[0]}")
    else:
        device = torch.device(f"cuda:{args.gpu.strip()}")

    # dataloader
    D = data.__dict__[args.dataset](args)
    train_loader, test_loader = D.data_loaders()

    # Create model
    cl, ll = get_layers(args.layer_type)
    print(gpu_list)
    if multi_gpu:
        print("Using multiple GPUs")
    model = models.__dict__[args.arch](
        cl, ll, args.init_type, mean=D.mean.to(device), std=D.std.to(device), num_classes=args.num_classes,
        prune_reg=args.prune_reg, task_mode=args.exp_mode, normalize=args.normalize
    )
    if multi_gpu:
        model = nn.DataParallel(model, device_ids=gpu_list, output_device=device)
    model = model.to(device)
    logger.info(model)

    # Customize models for training/pruning/fine-tuning
    prepare_model(model, args, device)

    # Setup tensorboard writer
    writer = SummaryWriter(os.path.join(result_sub_dir, "tensorboard"))

    logger.info(
        f"Dataset: {args.dataset}, D: {D}, num_train: {len(train_loader.dataset)}, num_test:{len(test_loader.dataset)}")

    # autograd
    criterion = nn.CrossEntropyLoss()
    optimizer = get_optimizer(model, args)
    lr_policy = get_lr_policy(args.lr_schedule)(optimizer, args)
    logger.info([criterion, optimizer, lr_policy])

    # train & val method (resolved dynamically from their names in args)
    trainer = importlib.import_module(f"trainer.{args.trainer}").train
    val = getattr(importlib.import_module("utils.eval"), args.val_method)

    # Default for every path that does not restore a value from a checkpoint.
    # This also covers a missing --resume file outside pretrain mode, which
    # previously left best_prec1 unbound.
    best_prec1 = 0

    if args.exp_mode in ['pretrain'] and args.resume != '':
        if os.path.isfile(args.resume):
            logger.info("=> loading checkpoint '{}'".format(args.resume))
            model_dict = model.state_dict()
            checkpoint = torch.load(args.resume, map_location=device)
            checkpoint_dict = _remap_checkpoint_keys(checkpoint['state_dict'], args.dataset, args.gpu)

            args.start_epoch = checkpoint["epoch"]
            best_prec1 = checkpoint["best_prec1"]
            # overlay the checkpoint on the current state so keys missing from
            # the checkpoint keep their freshly initialised values
            model_dict.update(checkpoint_dict)
            model.load_state_dict(model_dict)
            optimizer.load_state_dict(checkpoint["optimizer"])
            logger.info(
                f"=> loaded checkpoint '{args.resume}' (epoch {checkpoint['epoch']}, best_acc1 {best_prec1:.2f}%)"
            )
        else:
            logger.info("=> no checkpoint found at '{}'".format(args.resume))
            exit(0)

    elif args.exp_mode not in ['pretrain']:

        assert args.source_net != '' or args.resume != '', (
            "Incorrect setup: "
            "resume => required to resume a previous experiment (loads all parameters)|| "
            "source_net => required to start pruning/fine-tuning from a source model (only load state_dict)"
        )

        # Load source_net (if checkpoint provided). Only load the state_dict (required for pruning and fine-tuning)
        if args.source_net and args.resume == '':
            if os.path.isfile(args.source_net):
                logger.info("=> loading source model from '{}'".format(args.source_net))
                checkpoint = torch.load(args.source_net, map_location=device)
                model_dict = model.state_dict()
                checkpoint_dict = checkpoint['state_dict']
                if args.exp_mode in ['score_prune', 'rate_prune', 'harp_prune']:
                    checkpoint_dict = _remap_checkpoint_keys(checkpoint_dict, args.dataset, args.gpu)
                    model_dict.update(checkpoint_dict)
                    model.load_state_dict(model_dict)
                else:
                    # fine-tuning: the checkpoint must match the model exactly
                    model.load_state_dict(checkpoint_dict)
                logger.info("=> loaded checkpoint '{}'".format(args.source_net))
            else:
                logger.info("=> no checkpoint found at '{}'".format(args.source_net))

        # resume (if checkpoint provided). Continue training with previous settings.
        else:
            if os.path.isfile(args.resume):
                logger.info("=> loading checkpoint '{}'".format(args.resume))
                checkpoint = torch.load(args.resume, map_location=device)
                args.start_epoch = checkpoint["epoch"]
                best_prec1 = checkpoint["best_prec1"]
                model.load_state_dict(checkpoint["state_dict"])
                optimizer.load_state_dict(checkpoint["optimizer"])
                logger.info(
                    f"=> loaded checkpoint '{args.resume}' (epoch {checkpoint['epoch']}, best_acc1 {best_prec1:.2f}%)"
                )
            else:
                logger.info("=> no checkpoint found at '{}'".format(args.resume))

    # init scores once source net is loaded.
    if args.scaled_score_init or args.exp_mode == "harp_finetune_lwm":
        initialize_scaled_score(model, args.prune_reg)

    if args.rate_stg_init:
        initialize_stg_rate(model, args, device, logger)
    elif args.exp_mode != 'pretrain':
        display_loadrate(model, logger, args)

    if args.prune_reg == 'channel':
        map_shortcut_rate(model, args, verbose=True)

    show_gradients(model, logger)

    # channel pruning is measured with the FLOPs-based loss, everything else
    # with hw_loss; the same function is reused inside the training loop
    loss_hw_func = hw_flops_loss if args.prune_reg == 'channel' else hw_loss
    _, _, start_rate = loss_hw_func(model, device, optimizer, args, print_target=True)

    logger.info(f'\nStarting from: Prune-rate = {start_rate}\n')

    # Evaluate
    if args.evaluate or args.exp_mode != 'pretrain':
        if args.dataset == 'imagenet' and not args.evaluate:
            print('>> Skip initial evaluation!')
        else:
            p1_bn, _, p1, _, loss, adv_loss = val(model, device, test_loader, criterion, args, writer)
            logger.info(
                f"Benign validation accuracy {args.val_method} for source-net: {p1_bn}, Adversarial validation accuracy {args.val_method} for source-net: {p1}")
        if args.evaluate:
            return

    # Snapshot of the starting state_dict; the per-epoch sanity check below
    # compares the live model against this snapshot.
    last_ckpt = copy.deepcopy(model.state_dict())

    # create s2ap proxy (a deep copy of the model with its own optimizer / LR policy)
    if 's2ap' in args.trainer:
        proxy = copy.deepcopy(model)
        proxy_optim = get_optimizer(proxy, args)
        proxy_lr_policy = get_lr_policy(args.lr_schedule)(proxy_optim, args)
        s2ap_adversary = TradesS2AP(
            model=model,
            proxy=proxy,
            proxy_optim=proxy_optim,
            gamma=args.s2ap_gamma,
            exp_mode=args.exp_mode,
            perturb_weights=args.perturb_weights,
            freeze_bn=args.freeze_bn,
            sparse_s2ap=args.sparse_s2ap,
            k_s2ap=args.k_s2ap,
            misalign_pert=args.misalign_pert,
        )
    else:
        s2ap_adversary = None

    # Capture Loss, Adv Loss, Benign Acc & Adv Acc
    losses = []
    adv_losses = []
    acc_ben = []
    acc_adv = []

    # Start training

    frozen_gamma = None  # set once the hw-loss rounds to zero
    reach_hw = False     # True once the hw-loss has rounded to zero

    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs + args.warmup_epochs):
        # save mask for lwm: dump the current state once and stop
        if args.save_mask:
            m_dict = model.state_dict()
            is_best = True
            best_prec1 = 100.0
            save_checkpoint(
                {
                    "epoch": epoch + 1,
                    "arch": args.arch,
                    "state_dict": m_dict,
                    "best_prec1": best_prec1,
                    "optimizer": optimizer.state_dict(),
                },
                is_best,
                args,
                result_dir=os.path.join(args.save_mask_path, "checkpoint"),
                save_dense=args.save_dense,
            )
            exit(0)

        # adjust learning rate
        lr_policy(epoch)

        # adjust s2ap learning rate whenever a proxy exists; keyed on the
        # adversary so it matches the `'s2ap' in args.trainer` creation check
        if s2ap_adversary is not None:
            proxy_lr_policy(epoch)

        # train
        trainer(
            model,
            device,
            train_loader,
            criterion,
            optimizer,
            epoch,
            args,
            writer,
            frozen_gamma=frozen_gamma,
            s2ap_adversary=s2ap_adversary,
        )

        # evaluate on test set
        prec1_benign, _, prec1, _, loss, adv_loss = val(model, device, test_loader, criterion, args, writer, epoch)
        losses.append(loss)
        adv_losses.append(adv_loss)
        acc_ben.append(prec1_benign.item())
        acc_adv.append(prec1.item())

        # Check current compression rate
        hw_info = ''
        if args.exp_mode != 'pretrain':
            gamma, loss_hw, current_rate = loss_hw_func(
                model, device, optimizer, args, epoch=epoch, frozen_gamma=frozen_gamma
            )

            if np.round(loss_hw.cpu().data, 4) == 0.0:
                frozen_gamma = gamma
                reach_hw = True

            hw_info = f"gamma {gamma:.4f}, hw-loss {loss_hw:.4f}, compress-rate {current_rate:.4f}, "

        # remember best prec@1 and save checkpoint.
        # Track adversarial accuracy unless training with the 'nat' loss.
        tracked_acc = prec1 if args.adv_loss != 'nat' else prec1_benign
        if args.exp_mode in ['rate_prune', 'harp_prune'] and not reach_hw:
            # rate-pruning runs only compete for "best" once the hw-loss is zero
            is_best = False
        else:
            is_best = tracked_acc > best_prec1
            best_prec1 = max(tracked_acc, best_prec1)

        # save mask for hamming index computation
        if args.hamming:
            # extract the flat mask
            flat_mask = extract_masks(model, args)
            # save as a .pth (binary)
            mask_path = os.path.join(
                result_sub_dir,
                f"masks_epoch_{epoch+1}.pth"
            )
            torch.save(flat_mask, mask_path)

        save_checkpoint(
            {
                "epoch": epoch + 1,
                "arch": args.arch,
                # was `m_dict`, which is only bound in the save_mask branch above
                "state_dict": model.state_dict(),
                "best_prec1": best_prec1,
                "optimizer": optimizer.state_dict(),
            },
            is_best,
            args,
            result_dir=os.path.join(result_sub_dir, "checkpoint"),
            save_dense=args.save_dense,
        )

        best_acc_name = 'best_adv'  # if args.adv_loss != 'nat' else 'best_benign'
        if args.dataset == 'imagenet':
            acc_info = f"adversarial valid-acc {prec1:.4f}, {best_acc_name} {best_prec1:.4f}"
        else:
            acc_info = f"benign valid-acc {prec1_benign:.4f}, adversarial valid-acc {prec1:.4f}, {best_acc_name} {best_prec1:.4f}"

        epoch_info = f"Epoch {epoch}, val-method {args.val_method}, " + hw_info + acc_info
        if is_best:
            epoch_info += " [update BEST]"

        logger.info(epoch_info)

        if args.exp_mode in ['rate_prune', 'harp_prune']:
            display_loadrate(model, logger, args)

        if args.exp_mode in ["score_prune", "score_finetune"]:
            logger.info(
                "Pruned model: {:.2f}%".format(
                    current_model_pruned_fraction(
                        model, os.path.join(result_sub_dir, "checkpoint"), verbose=False
                    )
                )
            )
        # clone results to latest subdir (sync after every epoch)
        # Latest_subdir: stores results from latest run of an experiment.
        clone_results_to_latest_subdir(
            result_sub_dir, os.path.join(result_main_dir, "latest_exp")
        )

        # Check what parameters got updated relative to the starting snapshot.
        sw, ss, sr = sanity_check_paramter_updates(model, last_ckpt)
        logger.info(
            f"Sanity check (exp-mode: {args.exp_mode}): Weight update - {sw}, Scores update - {ss}, Rates update - {sr}"
        )

        print(f"Time since start of training: {float(time.time() - start_time) / 60} minutes")

    end_time = time.time()
    print(
        f"Total training time: {end_time - start_time} seconds. These are {float((end_time - start_time) / 3600)} hours")


# Script entry point: start the experiment only when this file is executed
# directly. Argument parsing above still runs at import time.
if __name__ == "__main__":
    main()
