import torch
from timm.models import create_model
from timm.scheduler import create_scheduler
from timm.optim import create_optimizer
import time, datetime, os, sys, random, numpy as np
from datasets import build_continual_dataloader
import vits.hide_prompt_vision_transformer as hide_prompt_vision_transformer
import sys
sys.path.append('/public/home/lichang/projects/cl/cl/ACL')
import optimgrad

# def create_optimizer_2(args, model, is_first_epoch=False):
#     high_lr_params = []
#     low_lr_params = []

#     for name, param in model.named_parameters():
#         if not param.requires_grad:
#             continue
#         if 'cls_head' in name or 'layer_norm' in name:
#             high_lr_params.append(param)
#         else:
#             low_lr_params.append(param)

#     optimizer = torch.optim.AdamW([
#         {'params': high_lr_params, 'lr': args.lr, 'weight_decay': args.weight_decay},
#         {'params': low_lr_params,  'lr': args.lr , 'weight_decay': args.weight_decay},
#     ], betas=args.opt_betas, eps=args.opt_eps)

#     return optimizer

def create_optimizer_2(args, model, is_first_epoch=True):
    """Build an ``optimgrad.Adam`` optimizer over the trainable parameters of ``model``.

    Parameters with ``requires_grad == False`` are skipped.  The remaining
    ones are split into two parameter groups by name:

    * names containing ``'cls_head'`` or ``'layer_norm'`` use the base learning
      rate multiplied by ``args.learning_rate_ratio_for_cls_head``;
    * all other trainable parameters use the base learning rate unchanged and
      carry the extra group options ``'svd': True`` and
      ``'thres': args.svd_ratio`` (presumably interpreted by
      ``optimgrad.Adam`` -- confirm against that module).

    Both groups use ``args.weight_decay``.

    Args:
        args: Namespace providing ``lr``, ``lr_for_other_tasks``,
            ``learning_rate_ratio_for_cls_head``, ``weight_decay``,
            ``svd_ratio``, ``opt_betas`` and ``opt_eps``.
        model: Module whose ``named_parameters()`` are grouped.
        is_first_epoch: When truthy the base learning rate is ``args.lr``;
            otherwise it is ``args.lr_for_other_tasks``.

    Returns:
        The ``optimgrad.Adam`` instance built from the two parameter groups.
    """
    trainable = [
        (name, param)
        for name, param in model.named_parameters()
        if param.requires_grad
    ]

    def uses_scaled_lr(name):
        return 'cls_head' in name or 'layer_norm' in name

    scaled_lr_params = [param for name, param in trainable if uses_scaled_lr(name)]
    base_lr_params = [param for name, param in trainable if not uses_scaled_lr(name)]

    base_lr = args.lr if is_first_epoch else args.lr_for_other_tasks

    param_groups = [
        {
            'params': scaled_lr_params,
            'lr': base_lr * args.learning_rate_ratio_for_cls_head,
            'weight_decay': args.weight_decay,
        },
        {
            'params': base_lr_params,
            'lr': base_lr,
            'weight_decay': args.weight_decay,
            'svd': True,
            'thres': args.svd_ratio,
        },
    ]
    return optimgrad.Adam(
        params=param_groups,
        betas=args.opt_betas,
        eps=args.opt_eps,
    )

# Substrings that mark a parameter as trainable; every other parameter is frozen.
_TRAINABLE_NAME_KEYWORDS = ('layer_norm', 'cls_head', 'lora', 'adapter', 'prompt')


def _mark_trainable_parameters(model, keywords=_TRAINABLE_NAME_KEYWORDS):
    """Enable gradients only for parameters whose name contains one of ``keywords``."""
    for name, param in model.named_parameters():
        param.requires_grad = any(keyword in name for keyword in keywords)


def _report_trainable_parameters(model):
    """Print every trainable parameter, then their total memory (MB) and element count."""
    bytes_per_mb = 1024 ** 2
    n_parameters = 0
    total_memory_bytes = 0
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        param_bytes = param.numel() * param.element_size()
        n_parameters += param.numel()
        total_memory_bytes += param_bytes
        print(f'Parameter name: {name}, Shape: {param.shape}, Memory: {param_bytes / bytes_per_mb:.2f} MB')
    print('total memory (MB):', total_memory_bytes / bytes_per_mb)
    print('number of params:', n_parameters)


def train(args):
    """Load the pretrained EAT checkpoint into a LoRA model and run continual training.

    Steps, in order:

    1. Build the dataloaders with ``build_continual_dataloader(args)``.
    2. Load the checkpoint from a hard-coded path, rebuild the model config
       from ``ckpt["cfg"]["model"]`` and instantiate
       ``Data2VecMultiModel_2_ranpac_lora``; weights are loaded non-strictly.
    3. Freeze every parameter whose name contains none of
       ``_TRAINABLE_NAME_KEYWORDS``.
    4. Optionally wrap the model in ``DistributedDataParallel``.
    5. Rescale ``args.lr`` by the global batch size, build the optimizer and
       (unless ``args.sched == 'constant'``) the timm scheduler.
    6. Call ``train_and_evaluate`` and print the elapsed wall-clock time.

    Args:
        args: Namespace read for ``device``, ``lora_depth``, ``lora_rank``,
            ``rp_dim``, ``dataset``, ``distributed``, ``gpu``, ``unscale_lr``,
            ``batch_size``, ``world_size``, ``lr``, ``sched`` and ``epochs``.
            It is also forwarded whole to the dataloader builder, the model,
            the optimizer/scheduler factories and the training engine.

    Side effects:
        ``args.lr`` is overwritten in place with the batch-size-scaled value.
        The EAT model directory is added to ``sys.path`` if not already there.
    """
    device = torch.device(args.device)
    data_loader, data_loader_per_cls, class_mask, target_task_map = build_continual_dataloader(args)

    print('loading pretrained fairseq model')
    # EAT_pretraining is imported from this directory, so it has to be on
    # sys.path first.  Guarded so repeated calls do not add duplicates.
    eat_models_dir = '/public/home/lichang/projects/cl/cl/ACL/EAT/models'
    if eat_models_dir not in sys.path:
        sys.path.append(eat_models_dir)
    from EAT_pretraining import Data2VecMultiModel_2_ranpac_lora, Modality
    from omegaconf import OmegaConf

    # NOTE: torch.load unpickles the file; only point this at trusted checkpoints.
    pretrained_ssl_dir = '/public/home/lichang/projects/cl/cl/ACL/EAT/EAT-base_epoch30_pt.pt'
    ckpt = torch.load(pretrained_ssl_dir, map_location="cpu")
    cfg = ckpt["cfg"]
    state_dict = ckpt["model"]

    # The checkpoint stores the modality as a name; the model expects the
    # corresponding Modality member.
    modality = Modality[cfg['model']['supported_modality']]
    model_cfg = OmegaConf.create(cfg['model'])
    model_cfg.type = modality
    model_cfg.max_length = 768
    model_cfg.lora_depth = args.lora_depth
    model_cfg.lora_rank = args.lora_rank
    model = Data2VecMultiModel_2_ranpac_lora(
        model_cfg,
        [modality],
        skip_ema=True,
        task=None,
        rp_dim=args.rp_dim,
        dataset=args.dataset,
        args=args,
    )

    # strict=False tolerates key mismatches silently, so surface them in the log.
    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
    print(f'loaded pretrained fairseq model '
          f'({len(missing_keys)} missing keys, {len(unexpected_keys)} unexpected keys)')
    if missing_keys:
        print('missing keys:', missing_keys)
    if unexpected_keys:
        print('unexpected keys:', unexpected_keys)

    model.to(device)
    from engines.ranpac_engine import train_and_evaluate

    _mark_trainable_parameters(model)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
        model_without_ddp = model.module

    _report_trainable_parameters(model)

    # Learning rate is scaled relative to a reference batch size of 256.
    # NOTE(review): only args.lr is rescaled here, args.lr_for_other_tasks is
    # left as given -- confirm that this is intended.
    if args.unscale_lr:
        global_batch_size = args.batch_size
    else:
        global_batch_size = args.batch_size * args.world_size
    args.lr = args.lr * global_batch_size / 256.0

    optimizer = create_optimizer_2(args, model_without_ddp, is_first_epoch=True)
    if args.sched == 'constant':
        lr_scheduler = None
    else:
        lr_scheduler, _ = create_scheduler(args, optimizer)

    criterion = torch.nn.CrossEntropyLoss().to(device)

    print(f"Start training for {args.epochs} epochs")
    start_time = time.time()

    train_and_evaluate(model, model_without_ddp, criterion, data_loader, data_loader_per_cls, optimizer, lr_scheduler, device, class_mask, target_task_map, args)
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print(f"Total training time: {total_time_str}")
