import os
import sys
import yaml
import random
import logging
import argparse
from copy import deepcopy
from contextlib import suppress

import mlflow
import numpy as np
from rich import print as pp

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch import optim

from utils.sam import SAM
from dataset import DataModule
from models.network import get_network
from utils.loss import SoftTargetCrossEntropy, LabelSmoothingCrossEntropy
from timm.optim import create_optimizer_v2, optimizer_kwargs
from utils.dir_maker import DirectroyMaker
from utils.get_scheduler import get_scheduler
from utils.mixup import Mixup
from utils.transmix import Mixup_transmix
from arch import Model
import utils.etc as etc

import torchvision
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, TQDMProgressBar
from pytorch_lightning.strategies import DDPStrategy
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.distributed import rank_zero_only
from torchmetrics.functional import accuracy

# Make the parent of this file's directory importable.
# NOTE(review): this runs after the project imports above, so it cannot
# influence how those were resolved -- confirm what it is still needed for.
sys.path.append(os.path.dirname(os.path.abspath(os.path.dirname(__file__))))
# Disable the 'alembic.runtime.migration' logger (presumably to quiet
# migration output triggered through mlflow -- TODO confirm).
logging.getLogger('alembic.runtime.migration').disabled = True
# Route uncaught exceptions through the project's handler in utils.etc.
sys.excepthook = etc.handle_exception


def _str2bool(value):
    """Parse a command-line boolean such as ``True``/``false``/``1``/``0``.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is
    ``True`` because every non-empty string is truthy, so such an option
    could never be switched off from the command line.

    Args:
        value: The raw option value (a string from the command line, or an
            actual ``bool``).

    Returns:
        The parsed boolean. The empty string maps to ``False``, matching what
        ``bool("")`` used to return.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', 'on', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', 'off', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


# `config_parser` and `parser` are the same object: `_parse_args` first uses it
# to discover `--config`, then re-parses after the YAML values were installed
# as defaults.
config_parser = parser = argparse.ArgumentParser(description='GTA: Guided Transfer of Spatial Attention from Self-supervised Models')

# --- General / method selection -------------------------------------------
parser.add_argument('--seed', default=234, type=int, help='seed')
parser.add_argument('--config', default=None, type=str,
                    help='Config Yaml File to load')
parser.add_argument('--guide', default=False, type=_str2bool,
                    help='Use GTA: Guided Transfer of Spatial Attention from Self-supervised Models')
parser.add_argument('--glambda', default=1., type=float,
                    help='Guide Reg Lambda')
parser.add_argument('--l2sp', action='store_true', default=False,
                    help='L2-SP')
parser.add_argument('--trans_fg', action='store_true', default=False,
                    help='transfg')
parser.add_argument('--bss',  action='store_true', default=False,
                    help='BatchSpectralShrinkage')
parser.add_argument('--attn_only', action='store_true', default=False,
                    help='Train with Attention only')
parser.add_argument('--mlp_only', action='store_true', default=False,
                    help='Feed forward Network only')
parser.add_argument('--feature_kd', action='store_true', default=False,
                    help='Feature based KD')
parser.add_argument('--msa_kd', action='store_true', default=False,
                    help='MSA based KD')

# --- Optimizer / learning-rate schedule -----------------------------------
parser.add_argument('--opt', default='sgd', type=str, help='optimizer')
parser.add_argument('--sched', default='step', type=str, help='scheduler')
parser.add_argument('--max_lr', default=1e-3, type=float,
                    help='onecycle LR maxLR')
parser.add_argument('--min_lr', default=1e-5, type=float,
                    help='onecycle and cosine_timm LR minLR')
parser.add_argument('--warmup_lr', default=1e-6, type=float,
                    help='warmup lr initializer for cosine_timm')
parser.add_argument('--pct_start', default=1e-2, type=float,
                    help='percent to warmup')
parser.add_argument('--decay_t', default=4, type=int,
                    help="it considers the step lr decrease points "
                    "100 epoch with 5 decay_t "
                    "it will decrease at 20, 40, 60, 80, 100")
parser.add_argument('--lr', default=0.1, type=float,
                    help="initial learning rate")
parser.add_argument('--lr_cycle_decay', default=0.5, type=float,
                    help="lr cycle decay for Cosine")
parser.add_argument('--lr_decay_rate', default=0.1, type=float,
                    help='learning rate decay rate')
parser.add_argument('--weight_decay', default=5e-4, type=float,
                    help='weight_decay')
parser.add_argument('--momentum', default=0.9, type=float,
                    help='momentum')

# --- Epochs / iterations / batch size -------------------------------------
parser.add_argument('--fixed_RA_iter', default=False, type=_str2bool,
                    help='fixed iteration mode')
parser.add_argument('--fixed_iter', default=10, type=int,
                    help='how many iters you need for one epoch')
parser.add_argument('--start_epoch', default=0, type=int,
                    help='manual epoch number')
parser.add_argument('--end_epoch', default=300, type=int,
                    help='number of training epoch to run')
parser.add_argument('--cooldown_epochs', default=10, type=int,
                    help='cooldown epochs')
parser.add_argument('--warmup_epochs', default=20, type=int,
                    help='warmup epochs for cosine timm')
parser.add_argument('--batch_size', type=int, default=128,
                    help='mini-batch size (default: 128), this is the total '
                    'batch size of all GPUs on the current node when '
                    'using Data Parallel or Distributed Data Parallel')

# --- Model ------------------------------------------------------------------
parser.add_argument('--amp', type=_str2bool, default=False,
                    help='set amp')
parser.add_argument('--classifier_type', type=str,
                    default='deit_small_patch16_224_return_total_attn')
parser.add_argument('--initial_checkpoint', default=None,
                    type=str, metavar='PATH',
                    help='Initialize model from this checkpoint default: none')
parser.add_argument('--teacher_checkpoint', default=None,
                    type=str, metavar='PATH',
                    help='Initialize Teacher model from this checkpoint, default: none')
parser.add_argument('--is_pretrained_imagenet', action='store_true',
                    help='Initialize model from imagenet pretrained weight')
parser.add_argument('--patch_size', default=16, type=int, help=
                    """Size in pixels
                    of input square patches - default 16 (for 16x16 patches).
                    Using smaller values leads to better performance
                    but requires more memory. Applies only
                    for ViTs (vit_tiny, vit_small and vit_base).""")
parser.add_argument('--drop_path_rate', type=float, default=0.1,
                    help="stochastic depth rate")
parser.add_argument('--num_classes', type=int, default=1000, metavar='N',
                    help='number of label classes')

# --- Experiment bookkeeping -------------------------------------------------
parser.add_argument('--tag', type=str)
parser.add_argument('--experiments_dir', type=str,
                    default='output',
                    help='Directory name to save the model, log, config')
parser.add_argument('--experiments_name', type=str,
                    default=None, help='Experiment name under experiments_dir')
parser.add_argument('--experiments_subname', type=str, default='',
                    help='this is for Nested Run')

# --- Dataset ----------------------------------------------------------------
parser.add_argument('--data_path', type=str, default=None,
                    help='download dataset path')
parser.add_argument('--dataset_name', type=str, default=None,
                    help='name of dataset')
parser.add_argument('--sample_rate', type=int, default=100,
                    help='StanforCars, CUB200, Aircraft Sampling rate in int ex 15, 30')
parser.add_argument('--train_split', metavar='NAME', default='train',
                    help='dataset train split (default: train)')
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
                    help='Override mean pixel value of dataset')
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
                    help='Override std deviation of of dataset')

# --- Augmentation / regularisation -------------------------------------------
parser.add_argument('--aug', type=str, default=None,
                    help='Augmentation Schedule')
parser.add_argument('--repeated_aug', type=_str2bool, default=False,
                    help='Use Repeated Augmentation')
parser.add_argument('--color_jitter', type=float, default=None,
                    help='Color Jittering')
parser.add_argument('--transmix', action='store_true', default=False,
                    help='Transmix if set (default: False')
parser.add_argument('--mixup', type=float, default=0.0,
                    help='mixup alpha, mixup enabled if > 0. (default: 0.)')
parser.add_argument('--cutmix', type=float, default=0.0,
                    help='cutmix alpha, cutmix enabled if > 0. (default: 0.)')
parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
                    help='cutmix min/max ratio, overrides alpha and enables '
                    'cutmix if set (default: None)')
parser.add_argument('--mixup-prob', type=float, default=1.0,
                    help='Probability of performing mixup '
                    'or cutmix when either/both is enabled')
parser.add_argument('--mixup-switch-prob', type=float, default=0.5,
                    help='Probability of switching to cutmix '
                    'when both mixup and cutmix enabled')
parser.add_argument('--mixup-mode', type=str, default='batch',
                    help='How to apply mixup/cutmix params. '
                    'Per "batch", "pair", or "elem"')
parser.add_argument('--smoothing', type=float, default=0.1,
                    help='Label smoothing (default: 0.1)')
parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
                    help='Random erase prob (default: 0.25)')
parser.add_argument('--remode', type=str, default='pixel',
                    help='Random erase mode (default: "pixel")')
parser.add_argument('--recount', type=int, default=1,
                    help='Random erase count (default: 1)')
parser.add_argument('--resplit', action='store_true', default=False,
                    help='Do not random erase first (clean) augmentation split')

# --- Checkpointing ------------------------------------------------------------
parser.add_argument('--saveckp_freq', default=299, type=int,
                    help='Save checkpoint every x epochs. Last model saving set to 299')
parser.add_argument('--resume', type=str, default=None, help='load model path')

# --- Runtime resources ----------------------------------------------------------
parser.add_argument('--workers', default=16, type=int,
                    help='number of workers for dataloader')
# `args.gpus` is read by the `__main__` block (lr scaling, DDP strategy and
# Trainer devices); declaring it here guarantees the attribute exists even
# when the YAML config does not provide it.
parser.add_argument('--gpus', default=1, type=int,
                    help='number of GPUs (devices) to train on')


def fix_seed(random_seed):
    """Seed every RNG used by the script and make cuDNN deterministic.

    Announces the seed, calls Lightning's ``seed_everything`` with
    ``workers=True``, then explicitly seeds PyTorch (CPU, current CUDA device
    and all CUDA devices), NumPy and Python's ``random`` module with the same
    value.

    Args:
        random_seed: Seed value passed unchanged to every seeding function.
    """
    pp(f"Fix Seed : {random_seed}")
    seed_everything(random_seed, workers=True)

    # Each seeder controls its own generator, so one pass with the same seed
    # is enough.
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for seed_fn in seeders:
        seed_fn(random_seed)

    # Deterministic cuDNN kernels, no auto-tuner benchmarking.
    cudnn.deterministic, cudnn.benchmark = True, False


def _parse_args():
    """Parse command-line arguments, optionally seeded from a YAML config.

    A first pass reads ``--config`` (unknown options are tolerated). If a
    config file is given, its top-level keys are installed as parser defaults,
    so explicit command-line options still override the file. A second, strict
    pass then produces the final namespace.

    Returns:
        argparse.Namespace: The parsed arguments.

    Raises:
        TypeError: If the YAML document is not a mapping at its top level.
    """
    args_config, _ = config_parser.parse_known_args()
    if args_config.config:
        with open(args_config.config, 'r') as file:
            cfg = yaml.safe_load(file)
        # An empty YAML file loads as None; treat it as "no overrides" rather
        # than failing inside set_defaults(**None).
        if cfg is None:
            cfg = {}
        if not isinstance(cfg, dict):
            raise TypeError(
                f"Config file {args_config.config!r} must contain a mapping "
                f"of option names to values, got {type(cfg).__name__}")
        parser.set_defaults(**cfg)
    return parser.parse_args()


@rank_zero_only
def save(net):
    """Log ``net`` to MLflow as the ``model_last_epoch`` artifact.

    The model is registered under ``args.experiments_name``; ``args`` is the
    module-level namespace created in the ``__main__`` block, so this function
    is only usable after the script has parsed its arguments. The
    ``rank_zero_only`` decorator is what restricts the call to the rank-0
    process.

    Args:
        net: The model object handed to ``mlflow.pytorch.log_model``.
    """
    artifact_name = "model_last_epoch"
    registry_name = args.experiments_name
    mlflow.pytorch.log_model(
        net,
        artifact_name,
        registered_model_name=registry_name,
    )


if __name__ == "__main__":
    # NOTE: `args` must stay a module-level global -- `save()` reads
    # `args.experiments_name` from it.
    args = _parse_args()

    # `gpus` may be missing from the namespace (e.g. when neither the command
    # line nor the YAML config supplies it). Fall back to a single device
    # instead of dying with AttributeError on the first use below.
    if getattr(args, 'gpus', None) is None:
        args.gpus = 1

    # NOTE(review): the `--cutmix-minmax` help says it enables cutmix when set,
    # but it is not taken into account here -- confirm whether that is intended.
    args.mixup_active = args.mixup > 0 or args.cutmix > 0.

    # Default normalisation statistics (presumably the usual ImageNet values).
    if args.mean is None:
        args.mean = (0.485, 0.456, 0.406)
    if args.std is None:
        args.std = (0.229, 0.224, 0.225)

    # Linear learning-rate scaling with the total batch size, reference 512.
    args.lr = args.lr * args.batch_size * args.gpus / 512.0

    # DDP only when more than one device is requested.
    strategy = DDPStrategy(find_unused_parameters=True) if args.gpus > 1 else None
    tqdm_callback = TQDMProgressBar()
    lr_checker = LearningRateMonitor(logging_interval='epoch')
    trainer = pl.Trainer(
        max_epochs=(args.end_epoch + args.cooldown_epochs),
        accelerator='gpu',
        devices=args.gpus,
        strategy=strategy,
        callbacks=[tqdm_callback, lr_checker],
        log_every_n_steps=1,
        deterministic=True,
        # Leave the sampler untouched when repeated augmentation is requested.
        replace_sampler_ddp=not args.repeated_aug,
        amp_backend='native',
        precision=16 if args.amp else 32
    )

    rank = trainer.global_rank
    is_rank_zero = rank == 0
    if is_rank_zero:
        pp(f"GPU : {args.gpus}, DDP : {strategy is not None}")
        pp(f"[blue] learning rate will be {args.lr} [/blue]")
    else:
        pp(f"[blue] {rank} is ready [/blue]")

    # Only rank 0 opens an MLflow run; other ranks use `suppress()` purely as
    # a do-nothing context manager (its `as` target is None).
    run_ctx = mlflow.start_run(run_name=args.tag) if is_rank_zero else suppress()
    with run_ctx as run:
        if is_rank_zero:
            mlf_logger = pl.loggers.MLFlowLogger(
                run_id=run.info.run_id,
                run_name=args.tag)
            trainer.logger = mlf_logger
            trainer.logger.log_hyperparams(args)
        # Per-rank seed: once before building the model/data module, and again
        # right before fitting.
        fix_seed(args.seed + rank)
        model = Model(args)
        dm = DataModule(args)
        fix_seed(args.seed + rank)
        trainer.fit(model, datamodule=dm)
        save(model.model)
    torch.cuda.empty_cache()
