import os
import time
import argparse
import datetime
import numpy as np
import wandb

import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.nn.functional as F

from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from timm.utils import accuracy, AverageMeter

from config import get_config
from models import build_model, build_CoMM_model
from data import build_loader
from lr_scheduler import build_scheduler
from optimizer import build_optimizer
from logger import create_logger
from util import load_checkpoint, load_pretrained, save_checkpoint, get_grad_norm, auto_resume_helper, reduce_tensor
import pdb
try:
    # noinspection PyUnresolvedReferences
    from apex import amp
except ImportError:
    amp = None

import sys
sys.path.insert(0, '../src/models')


def parse_option():
    """Parse the command line and build the experiment config.

    Every option is registered on a parser created with ``add_help=False``,
    the process arguments are parsed, and the resulting namespace is handed
    to ``get_config``.

    Returns:
        tuple: ``(args, config)`` -- the parsed ``argparse`` namespace and the
        object returned by ``get_config(args)``.
    """
    cli = argparse.ArgumentParser('Swin Transformer training and evaluation script', add_help=False)

    # One (flag, add_argument keywords) entry per option, in registration order.
    option_table = [
        ('--cfg', dict(type=str, required=True, metavar="FILE", help='path to config file')),
        ("--opts", dict(help="Modify config options by adding 'KEY VALUE' pairs. ", default=None, nargs='+')),
        # easy config modification
        ('--task-id', dict(type=int, help="task_id")),
        ('--batch-size', dict(type=int, help="batch size for single GPU")),
        ('--weight-decay', dict(type=float, help="")),
        ('--data-path', dict(type=str, help='path to dataset')),
        ('--pretrained', dict(type=str, help='path to pre-trained model')),
        ('--resume', dict(help='resume from checkpoint')),
        ('--accumulation-steps', dict(type=int, help="gradient accumulation steps")),
        ('--use-checkpoint', dict(action='store_true',
                                  help="whether to use gradient checkpointing to save memory")),
        ('--amp-opt-level', dict(type=str, default='O0', choices=['O0', 'O1', 'O2'],
                                 help='mixed precision opt level, if O0, no amp is used')),
        ('--output', dict(default='output', type=str, metavar='PATH',
                          help='root of output folder, the full path is <output>/<model_name>/<tag> (default: output)')),
        ('--tag', dict(help='tag of experiment')),
        ('--eval', dict(action='store_true', help='Perform evaluation only')),
        ('--throughput', dict(action='store_true', help='Test throughput only')),
        ('--linear-probe', dict(action='store_true', help='linear probe for validation')),
        ("--lr-multiplier", dict(type=float, default=1, help='multiplier for base/warmup/min_lr')),
        # distributed training
        ("--local_rank", dict(type=int, required=True, help='local rank for DistributedDataParallel')),
    ]
    for flag, keywords in option_table:
        cli.add_argument(flag, **keywords)

    parsed = cli.parse_args()
    return parsed, get_config(parsed)


def main(config, _info):
    """Build the data loaders and model, load weights, and extract attention.

    Args:
        config: Experiment config passed through to the builders and loaders.
        _info: Opaque value forwarded unchanged to ``pass_one_epoch``.
    """
    # Only the validation loader is consumed below; the other four values
    # returned by build_loader are unpacked and left unused.
    loader_bundle = build_loader(config, logger, is_pretrain=False)
    _train_set, _val_set, _train_loader, val_loader, _mixup = loader_bundle

    model_type = config.MODEL.TYPE
    model_name = config.MODEL.NAME
    logger.info(f"Creating model:{model_type}/{model_name}")

    # The CoMM builder is used here instead of the plain build_model.
    net = build_CoMM_model(config, is_pretrain=False)
    net.cuda()
    load_pretrained(config, net, logger)

    logger.info("Start to extract attn features")
    pass_one_epoch(config, net, val_loader, _info)
                    
@torch.no_grad()
def pass_one_epoch(config, model, data_loader, _info):
    """Run one inference pass and save the accumulated attention per layer.

    For every batch the model is called with ``return_attns=True`` and its
    result is unpacked as ``(output, full_attns, base_attns)``. Each of the
    two attention iterables is walked by index and every entry is added to the
    ``AverageMeter`` at the same index. After the pass, the meters' ``sum``
    values are collected into two dicts keyed by that index and written to
    ``config.OUTPUT`` with ``torch.save``.

    Args:
        config: Experiment config; reads ``MODEL.TYPE``, ``MODEL.VIT.DEPTH`` or
            ``MODEL.SWIN.DEPTH``, ``OUTPUT`` and ``DATA.TASK_ID``.
        model: Model to run; it is switched to eval mode first.
        data_loader: Sized iterable yielding ``(samples, target)`` pairs; the
            target is ignored.
        _info: Unused here; kept so existing callers keep working.
    """
    model.eval()
    if hasattr(config.MODEL, 'VIT'):
        depth = config.MODEL.VIT.DEPTH
        logger.info(f'VIT DEPTH {depth}')
    else:
        depth = sum(config.MODEL.SWIN.DEPTH)
        logger.info(f'SWIN DEPTH {config.MODEL.SWIN.DEPTH}, {depth} in total.')
    full_attn_meter = [AverageMeter() for _ in range(depth)]
    base_attn_meter = [AverageMeter() for _ in range(depth)]

    num_batches = len(data_loader)
    # Log about ten times per pass. The interval is clamped to 1 because a
    # loader with fewer than ten batches would otherwise give an interval of
    # 0 and the modulo below would raise ZeroDivisionError.
    log_every = max(1, num_batches // 10)
    # For every model type other than 'vit' the attention is summed over
    # dim 0 (keeping that dim) before it is accumulated.
    sum_first_dim = config.MODEL.TYPE != 'vit'

    for didx, (samples, _) in enumerate(data_loader):
        samples = samples.cuda(non_blocking=True)
        _, full_attns, base_attns = model(samples, return_attns=True)

        # Same accumulation for both attention sets: full first, then base.
        for meters, attns in ((full_attn_meter, full_attns), (base_attn_meter, base_attns)):
            for idx, attn in enumerate(attns):
                if sum_first_dim:
                    attn = attn.sum(0, keepdim=True)
                meters[idx].update(attn.cpu().data)

        if (didx + 1) % log_every == 0:
            logger.info(f'attn progress {didx+1}/{num_batches}')

    full_dict = {faidx: fatt.sum for faidx, fatt in enumerate(full_attn_meter)}
    base_dict = {baidx: batt.sum for baidx, batt in enumerate(base_attn_meter)}

    task_id = config.DATA.TASK_ID
    outputs = (
        (full_dict, f'prtt_full_attt{task_id}.pt', f'noprtt_full_attt{task_id}.pt'),
        (base_dict, f'prtt_base_attt{task_id}.pt', f'noprtt_base_attt{task_id}.pt'),
    )
    for sums, primary_name, fallback_name in outputs:
        try:
            torch.save(sums, os.path.join(config.OUTPUT, primary_name))
        except Exception:
            # Keep the original retry-under-another-name behaviour, but do not
            # swallow KeyboardInterrupt/SystemExit and do not stay silent.
            logger.info(f'saving {primary_name} failed, retrying as {fallback_name}')
            torch.save(sums, os.path.join(config.OUTPUT, fallback_name))
        
if __name__ == '__main__':
    # Script entry point: parse options, initialise the NCCL process group,
    # derive the output directory from the pretrained checkpoint path, create
    # the module-level `logger` used by main()/pass_one_epoch(), then run.
    _, config = parse_option()

    if config.AMP_OPT_LEVEL != "O0" and amp is None:
        # Explicit raise rather than `assert`, which is stripped under `python -O`.
        raise RuntimeError("amp not installed!")

    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        rank = int(os.environ["RANK"])
        world_size = int(os.environ['WORLD_SIZE'])
        print(f"RANK and WORLD_SIZE in environ: {rank}/{world_size}")
    else:
        # -1 is the default value of both arguments of init_process_group.
        rank = -1
        world_size = -1
    torch.cuda.set_device(config.LOCAL_RANK)
    torch.distributed.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank)
    torch.distributed.barrier()

    # Offset the seed by the process rank so each rank seeds differently.
    seed = config.SEED + dist.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    cudnn.benchmark = True

    # Build a run identifier from path components 0, 1, 2 and 5 of the
    # pretrained checkpoint path (after removing the cache-root prefix).
    # This raises IndexError when the path has fewer than six components.
    _info_split = config.PRETRAINED.replace('/d1/jaehong/laps_finetune_cache/simmim_finetune/', '').split('/')
    _info = os.path.join(_info_split[0], _info_split[1], _info_split[2], _info_split[5])
    _info = _info.replace('.pth', '')
    config.defrost()
    config.OUTPUT = f'../understand_mim/anal0928_mae_conft_adapts/{_info}'
    config.freeze()

    os.makedirs(config.OUTPUT, exist_ok=True)
    logger = create_logger(output_dir=config.OUTPUT, dist_rank=dist.get_rank(), name=f"{config.MODEL.NAME}")
    try:
        # NOTE(review): `_info` is a string at this point, so `_info[1]` is its
        # second character; `_info_split[1]` may have been intended -- confirm.
        logger.info(f'pretrained pth from task{_info[1]} and target dataset for attn from task{config.DATA.TASK_ID}')
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit propagate.
        logger.info(f'Not pretrained and target dataset for attn from task{config.DATA.TASK_ID}')
    main(config, _info)
