from __future__ import print_function
import copy
import os
import pickle
from re import A

import cv2
import json

#from pkg_resources import evaluate_marker
from timm.models.dla import dla34
from tqdm import tqdm
import time
import random
import wandb
import numpy as np
import termplotlib as tpl

from sklearn.metrics import average_precision_score
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.cuda.amp import GradScaler 
from torch.utils.data import DataLoader

import clip.clip as clip
from open_clip.zero_shot_classifier import get_tokens
from open_clip.zero_shot_metadata import IMAGENET_CLASSNAMES, OPENAI_IMAGENET_TEMPLATES

from utils.utils import save_checkpoint, cosine_lr, convert_models_to_fp32, refine_classname, get_model, set_seed
from utils.lora import convert2lora, mark_only_lora_as_trainable, mark_only_lora_as_not_trainable, Conv2d
from utils.data_utils import get_dataset

from lib.dataset.imagenetr_utils import imagenet_r_transform, imagenet_r_mask, reverse_imagenet_r_mask, imagenet_a_mask
from lib.dataset.objectnet_dataset import objectnet_mask
from lib.regularization import LearningWithoutForgetting, ewc_penalty
from lib.argument import parse_option
from lib.experiment import train, validate
from lib.evaluate import evaluations
from lib.prompter import PadPrompter
from lib.cka import calculate_cka
from open_clip.zero_shot_classifier import build_zero_shot_classifier


# Best validation top-1 accuracy seen so far; main() rebinds it via `global`.
best_acc1 = 0
# Compute device used throughout this script.
device = "cuda" if torch.cuda.is_available() else "cpu"


def _load_checkpoint(path):
    """Load ``path`` with ``torch.load``, retrying with ``weights_only=False``.

    Newer PyTorch releases default to ``weights_only=True`` and reject files
    holding arbitrary pickled objects. The fallback performs a full unpickle,
    which can execute arbitrary code, so only use it on trusted checkpoints.
    """
    try:
        return torch.load(path)
    except Exception:
        return torch.load(path, weights_only=False)


def get_soup(args, model):
    """Load a weight-space "soup" of the models named in ``args.regularization``.

    ``args.regularization`` looks like ``'Soup-<A>-<B>-...'``. Each ingredient
    ``<X>`` is resolved as follows:

    * ``'PRE'``: the weights currently held by ``model``.
    * Anything else: the run directory ``args.model_dir/<name>``. ``<name>`` is
      ``args.filename`` with the soup tag replaced by ``<X>``, truncated before
      any ``'_wise'`` suffix.

    For ``'Soup-PRE-FT'`` with ``args.wise_ratio != 0.5`` the result is
    ``wise_ratio * PRE + (1 - wise_ratio) * FT``. Otherwise every ingredient is
    averaged uniformly.

    Args:
        args: parsed options. Reads ``regularization``, ``filename``,
            ``model_dir``, ``epochs``, ``eval_best`` and, for ``'Soup-PRE-FT'``,
            ``wise_ratio``.
        model: module whose current weights serve as ``'PRE'``. The merged
            weights are loaded into it.

    Returns:
        ``model`` with the merged weights loaded, in eval mode.

    Raises:
        ValueError: ``args.regularization`` names no ingredients.
        FileNotFoundError: an ingredient's ``checkpoint.pth.tar`` is missing,
            or its ``model_best.pth.tar`` is missing when ``args.eval_best``
            is set.
        RuntimeError: an ingredient's final checkpoint has ``epoch`` different
            from ``args.epochs`` and ``args.eval_best`` is off.
    """
    # PRE, FT, EWC, LwF
    reg = args.regularization
    model_names = reg.split('-')[1:]
    params = []

    for model_name in model_names:
        if 'PRE' in model_name:
            params.append({'state_dict': model.state_dict()})
            continue

        _path = args.filename.replace(reg, model_name)
        if 'wise' in _path:
            _path = _path.split('_wise')[0]
        run_dir = os.path.join(args.model_dir, _path)
        path = os.path.join(run_dir, 'checkpoint.pth.tar')
        # The final checkpoint must exist even when the best one is used.
        if not os.path.exists(path):
            print(path, "Not Exists")
            raise FileNotFoundError(path)

        if args.eval_best:
            # Use the best-validation weights; skip reading the final checkpoint.
            path = os.path.join(run_dir, 'model_best.pth.tar')
            if not os.path.exists(path):
                print(path, "Not Exists")
                raise FileNotFoundError(path)
            print("Load ", path)
            params.append(_load_checkpoint(path))
        else:
            print("Load ", path)
            weight = _load_checkpoint(path)
            if weight['epoch'] != args.epochs:  # check fully-trained?
                print("Need to run more")
                raise RuntimeError(
                    f"{path} was saved at epoch {weight['epoch']}, expected {args.epochs}")
            params.append(weight)

    if not params:
        raise ValueError(f"No soup ingredients in regularization {reg!r}")

    merged = params[0]['state_dict']
    if reg == 'Soup-PRE-FT' and args.wise_ratio != 0.5:
        ratio = args.wise_ratio
        other = params[1]['state_dict']
        for key in merged:
            merged[key] = ratio * merged[key].to(device) + (1 - ratio) * other[key].to(device)
    else:
        # Out-of-place accumulation: with a 'PRE' ingredient the tensors
        # returned by model.state_dict() share storage with the live model,
        # so in-place `+=` would corrupt the model (and tied weights).
        for key in merged:
            total = merged[key].to(device)
            for ingredient in params[1:]:
                total = total + ingredient['state_dict'][key].to(device)
            merged[key] = total / len(params)
    model.load_state_dict(merged)

    model.eval()
    return model

def main():
    """Parse CLI options, build the model, then train and/or evaluate it.

    Flow:
      1. Build the model with ``get_model`` and apply the parameter-freezing
         scheme chosen by ``args.regularization`` (LoRA, HeadOnly, ...).
      2. Optionally load weights (``args.load``) and resume from a checkpoint
         under ``args.model_dir/args.filename``.
      3. Build data loaders, the optimizer and the LR scheduler.
      4. Evaluation-only paths call ``evaluations`` and then ``exit()``:
         PRE, Soup, ``args.state_dict``, and ``args.evaluate`` on a run that
         has finished all ``args.epochs``.
      5. Otherwise train from ``args.start_epoch`` to ``args.epochs``. Checkpoint
         periodically and on validation improvement, and stop early after
         ``args.patience`` epochs without improvement. Finally call
         ``evaluations`` on the final model, or on the best one if
         ``args.eval_best`` is set.

    Side effects: creates ``args.model_folder``, saves checkpoints through
    ``save_checkpoint``, optionally logs to wandb, and may end the process
    via ``exit()``.
    """
    global best_acc1, device

    def _torch_load(path, **kwargs):
        """``torch.load`` that retries with ``weights_only=False`` on failure.

        The fallback fully unpickles the file; only use it on trusted,
        locally produced checkpoints.
        """
        try:
            return torch.load(path, **kwargs)
        except Exception:
            return torch.load(path, weights_only=False, **kwargs)

    args = parse_option()
    print(args)

    if args.seed is not None:
        set_seed(args.seed)
    num_classes = 200
    model, preprocess, tokenizer = get_model(args.model, num_classes, args.patch_size, device, arch=args.arch, d_pre=args.d_pre, pretrained=True, reg=args.regularization, mode=args.mode)
    if args.dataset == 'imagenet' and False:  # disabled toggle
        if args.regularization == 'PRE':
            # read zeroshot weights.
            print("Load Text Embedding for Zeroshot")
        else:
            model.fc = nn.Linear(model.head.in_features, 1000, bias=True)
            model = model.cuda()
    # Non-contrastive, non-FLYP runs with these d_pre values get a fresh
    # 512 -> 1000 linear head as `model.fc`.
    if (args.mode != 'contrastive'
            and args.d_pre in ('400m', '100m', 'datacomp', 'none')
            and args.regularization != 'FLYP'):
        model.fc = nn.Linear(512, 1000, bias=True)
        # is_available must be *called*; the bare function object is always truthy.
        if torch.cuda.is_available():
            model = model.cuda()
    if args.d_pre == 'in1k_orig':
        model = model.cuda()

    # Select which parameters stay trainable.
    if args.regularization == 'lora':
        if 'siglip' in args.model:
            ATTN_SUBSTRS = ["q_proj","k_proj","v_proj","o_proj","query","key","value","out","out_proj"]
            for p in model.parameters():
                p.requires_grad = False
            from peft import LoraConfig, get_peft_model
            lora_text = LoraConfig(
                r=8, lora_alpha=1, lora_dropout=0.00, bias="none",
                task_type="FEATURE_EXTRACTION",
                target_modules=ATTN_SUBSTRS
            )
            # NOTE(review): the returned wrapper is never used; this relies on
            # get_peft_model injecting the adapters into `model` in place — confirm.
            model_lora = get_peft_model(model, lora_text)
        elif args.d_pre == 'in1k_orig':
            mark_only_lora_as_trainable(model)
        else:
            model = convert2lora(model)
            if hasattr(model, 'blocks'):
                mark_only_lora_as_trainable(model.blocks)
                for param in model.patch_embed.parameters():
                    param.requires_grad = False
                for param in model.head.parameters():
                    param.requires_grad = False
            elif hasattr(model, 'visual'):
                mark_only_lora_as_trainable(model)
    elif args.regularization == 'HeadOnly':
        if args.model == 'vit':
            if args.d_pre == 'in1k_orig':
                for param in model.parameters():
                    param.requires_grad = False
                for param in model.heads.parameters():
                    param.requires_grad = True
            else:
                for param in model.patch_embed.parameters():
                    param.requires_grad = False
                for param in model.blocks.parameters():
                    param.requires_grad = False
        elif args.model == 'dinov2':
            for param in model.parameters():
                param.requires_grad = False
            for param in model.classifier.parameters():
                param.requires_grad = True
        else: # resnet
            for param in model.parameters():
                param.requires_grad = False
            for param in model.fc.parameters():
                param.requires_grad = True
    elif args.regularization == 'ConvHeadOnly':
        for param in model.blocks.parameters():
            param.requires_grad = False
    elif args.regularization == 'Prompter' or args.regularization == 'PrompterV2':
        model.prompter = PadPrompter(args).to(device)

    if False: # disabled toggle: fix norm
        for name, param in model.named_parameters():
            if 'norm' in name:
                param.requires_grad = False
    print(model)

    if args.load is not None:
        state_dict = torch.load(args.load, weights_only=False)
        if 'state_dict' in state_dict:
            state_dict = state_dict['state_dict']
        model.load_state_dict(state_dict)
        del state_dict

    convert_models_to_fp32(model)
    model.eval()

    # Choose the parameters (or param groups) handed to the optimizer.
    # discard_lr_target keeps its original precedence. The remaining options
    # form one if/elif chain; previously a trailing `else` overwrote every
    # branch except discard_lr_target.
    if args.discard_lr_target is not None:
        # Drop every parameter up to and including the last-positioned listed name.
        named = dict(model.named_parameters())
        keys = list(named)
        max_index = max((keys.index(n) for n in args.discard_lr_target if n in keys), default=-1)
        target_param = [named[k] for k in keys[max_index + 1:]]
    elif args.regularization == 'Prompter':
        target_param = model.prompter.parameters()
    elif args.regularization == 'PrompterV2':
        target_param = list(model.prompter.parameters()) + list(model.head.parameters())
    elif args.fix_norm_pre:
        target_param = dict(model.named_parameters())
        target_param.pop('norm_pre.weight')
        target_param.pop('norm_pre.bias')
        target_param = list(target_param.values())
    elif args.change_lr_target is not None:
        # Two param groups: listed names get lr * change_lr, the rest the base lr.
        target_param1 = [p for n, p in model.named_parameters() if n not in args.change_lr_target]
        target_param2 = [p for n, p in model.named_parameters() if n in args.change_lr_target]
        target_param = [
            {'params': target_param1, 'lr': args.learning_rate},
            {'params': target_param2, 'lr': args.learning_rate * args.change_lr},
        ]
    else:
        target_param = model.parameters()

    if False:  # disabled toggle: fix norm
        target_param = [e[1] for e in list(model.named_parameters()) if 'norm' not in e[0]]
    if args.optim == 'sgd':
        optimizer = torch.optim.SGD(target_param,
                                    lr=args.learning_rate,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
    elif args.optim == 'adamw':
        optimizer = torch.optim.AdamW(target_param, lr=args.learning_rate, weight_decay=args.weight_decay)
    else:
        raise ValueError(f"Unsupported optimizer: {args.optim!r}")
    model.drop_option = args.drop_option

    # optionally resume from a checkpoint
    args.model_folder = os.path.join(args.model_dir, args.filename)
    args.start_epoch = 0
    if not args.resume:
        args.resume = os.path.join(args.model_folder, 'checkpoint.pth.tar')
    if args.resume:
        # Map model to be loaded to specified single gpu.
        load_kwargs = {} if args.gpu is None else {'map_location': 'cuda:{}'.format(args.gpu)}
        if os.path.isfile(args.resume):
            checkpoint = _torch_load(args.resume, **load_kwargs)
            if args.gpu is None:
                print("load", args.resume)
            out = model.load_state_dict(checkpoint['state_dict'], strict=False)
            print(out)
        else:
            # Fall back to a shorter run (E<e> folder) or a per-epoch checkpoint.
            args.resume = os.path.join(args.model_folder, 'checkpoint.pth.tar')
            for e in reversed(range(1, args.epochs)):
                args.resume = os.path.join(args.model_folder.replace(f'E{args.epochs}', f'E{e}'), 'checkpoint.pth.tar')
                if os.path.isfile(args.resume):
                    break
            for e in reversed(range(1, args.epochs+1)):
                resume = os.path.join(args.model_folder, f'checkpoint_{e}.pth.tar')
                if os.path.isfile(resume):
                    args.resume = resume
                    break
            if os.path.isfile(args.resume):
                print("=> loading checkpoint '{}'".format(args.resume))
                print("Load", args.resume)
                checkpoint = _torch_load(args.resume, **load_kwargs)

                try:
                    model.load_state_dict(checkpoint['state_dict'])
                    optimizer.load_state_dict(checkpoint['optimizer'])

                    args.start_epoch = checkpoint['epoch']
                    best_acc1 = checkpoint.get('best_acc1', 0)
                    if args.gpu is not None and torch.is_tensor(best_acc1):
                        # best_acc1 may be from a checkpoint from a different GPU
                        best_acc1 = best_acc1.to(args.gpu)
                    print("=> loaded checkpoint '{}' (epoch {})"
                          .format(args.resume, checkpoint['epoch']))
                except Exception as err:
                    print("=> FAILED TO loaded checkpoint '{}' (epoch {})"
                          .format(args.resume, checkpoint.get('epoch')))
                    print(err)
            else:
                print("=> no checkpoint found at '{}'".format(args.resume))

    # create data
    template = 'This is a photo of a {}'
    print(f'template: {template}')
    train_dataset, val_dataset, test_dataset = get_dataset(args, preprocess)

    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size, pin_memory=True,
                              num_workers=args.num_workers, shuffle=True)

    val_loader = DataLoader(val_dataset,
                            batch_size=min(1024, 2*args.batch_size), pin_memory=True,
                            num_workers=args.num_workers, shuffle=False)

    if args.mode == 'contrastive' or args.regularization == 'FLYP':
        if 'siglip' in args.model:
            templates = ['a photo of a {}']
            texts = [template.format(c) for c in IMAGENET_CLASSNAMES for template in templates]
            texts = tokenizer(texts, padding='max_length', max_length=64, return_tensors='pt').to(device)['input_ids']
        else:
            texts = get_tokens(tokenizer, IMAGENET_CLASSNAMES, device=device)
    else:
        texts = None

    criterion = torch.nn.CrossEntropyLoss().to(device)
    scaler = GradScaler()
    total_steps = len(train_loader) * args.epochs

    if args.scheduler == 'cosine':
        if args.no_cosine:
            # Previously this assignment was immediately overwritten.
            # NOTE(review): train() must accept scheduler=None — confirm.
            scheduler = None
        else:
            scheduler = cosine_lr(optimizer, args.learning_rate, args.warmup, total_steps)
    else:
        # e.g. 'step_30_60' -> decay by 0.1 at epochs 30 and 60.
        milestones = [int(e) for e in args.scheduler.split('_')[1:]]
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.1)

    cudnn.benchmark = True

    if not os.path.isdir(args.model_folder):
        os.makedirs(args.model_folder, exist_ok=True)

    # wandb
    if args.use_wandb:
        wandb.init(project=args.project_name,
                   config=args,
                   name=args.filename,
                   dir=args.model_folder)

    if 'Soup' in args.regularization:
        model = get_soup(args, model)
    if args.regularization == 'LPFT':
        # Start from the matching, fully trained HeadOnly (linear-probe) run.
        reg = args.regularization
        _path = args.filename.replace(reg, 'HeadOnly')
        if args.eval_best:
            path = os.path.join(args.model_dir, _path, 'model_best.pth.tar')
        else:
            path = os.path.join(args.model_dir, _path, 'checkpoint.pth.tar')
        if not os.path.exists(path):
            print(path, "Not Exists")
            raise Exception
        print("Load ", path)
        data = _torch_load(path)
        if not (data['epoch'] == args.epochs or args.eval_best): # check fully-trained?
            raise Exception("Not fully trained model")
        model.load_state_dict(data['state_dict'])

    if args.regularization == 'PRE' or 'Soup' in args.regularization:
        evaluations(args, model, preprocess, texts, tokenizer, criterion)
        if args.use_wandb:
            wandb.run.finish()
        exit()

    # Class masks for datasets evaluated on a subset of the ImageNet classes;
    # for 'a+b' datasets the masks are OR-ed together.
    imagenets = ['imagenet-r', 'imagenet-a', 'objectnet-v2']
    masks = [imagenet_r_mask, imagenet_a_mask, objectnet_mask]
    dataset2mask = {imagenets[i]: masks[i] for i in range(len(imagenets))}
    mask = dataset2mask.get(args.dataset)
    if '+' in args.dataset:
        masks = []
        for dataset in args.dataset.split('+'):
            _mask = dataset2mask.get(dataset)
            if _mask is None:
                masks = []
                break
            masks.append(np.asarray(_mask))
        if len(masks) > 1:
            mask = np.logical_or.reduce(masks).tolist()
        else:
            mask = None

    old_model = copy.deepcopy(model)  # evaluate the difference between shared tokens and unique tokens using the old model.

    # Get regularizer
    if args.regularization== 'lwf':
        lwf = LearningWithoutForgetting()
        lwf.prev_model = old_model
        lwf.mask = mask
        regularizer = lambda images, outputs, model, texts=None: lwf(images, outputs, mb_text=texts)
    elif args.regularization == 'ewc':
        regularizer = lambda images, outputs, model, texts=None: ewc_penalty(model, old_model.state_dict())
    else:
        regularizer = None

    if args.state_dict:
        # For testing weight changes in every step.
        for i in [14,15]:
            checkpoint_path = os.path.join(args.model_folder, 'epoch0_iter{}.pth'.format(i))
            if not os.path.exists(checkpoint_path):
                break
            print(checkpoint_path)
            state_dict = _torch_load(checkpoint_path)
            model.load_state_dict(state_dict)
            model.eval()
            evaluations(args, model, preprocess, texts, tokenizer, criterion)

        state_dict = _torch_load(args.state_dict)
        model.load_state_dict(state_dict)
        model.eval()
        evaluations(args, model, preprocess, texts, tokenizer, criterion)

        if args.use_wandb:
            wandb.run.finish()
        exit()
    else:
        checkpoint_path = os.path.join(args.model_folder, 'checkpoint.pth.tar')
        bestfile = os.path.join(args.model_folder, 'model_best.pth.tar')
        if args.evaluate and os.path.exists(checkpoint_path):
            state_dict = _torch_load(checkpoint_path)
            # Only evaluate runs that reached the final epoch; otherwise keep training.
            if state_dict['epoch'] == args.epochs:
                if args.eval_best and os.path.exists(bestfile):
                    print(bestfile)
                    state_dict = _torch_load(bestfile)
                model.load_state_dict(state_dict['state_dict'])
                evaluations(args, model, preprocess, texts, tokenizer, criterion)

                if args.use_wandb:
                    wandb.run.finish()
                exit()

    epochs_since_improvement = 0

    train_accs = []
    val_accs = []
    test_accs = []
    return_attention = False

    distances = []

    if args.dataset in ['caltech256', 'cars', 'cifar100', 'cub200']:
        # Swap in a zero-shot classifier built from this dataset's class names;
        # the original weights are restored after training.
        if args.dataset == 'caltech256':
            caltech_class = [e.replace('-101','').split('.')[1] for e in train_dataset.categories]
        else:
            caltech_class = train_dataset.classes
        caltech_zero_shot_weights = build_zero_shot_classifier(model, tokenizer, caltech_class, OPENAI_IMAGENET_TEMPLATES, 50, device=device)
        original_zero_shot_weights = copy.deepcopy(model.zero_shot_weights)
        model.zero_shot_weights = caltech_zero_shot_weights

    # Keep `epoch` bound for the final save even if no epoch runs.
    epoch = args.start_epoch - 1
    # Fast-forward epoch-stepped schedulers when resuming.
    for epoch in range(args.start_epoch):
        if 'step' in args.scheduler:
            scheduler.step()
    for epoch in range(args.start_epoch, args.epochs):
        print(epoch)
        # train for one epoch
        train_loss, train_acc = train(train_loader, texts, model, tokenizer, optimizer, scheduler, criterion, scaler, epoch, args, return_attention=return_attention, mask=mask, old_model=old_model, regularizer=regularizer)[:2]
        if 'step' in args.scheduler:
            scheduler.step()
        distance = 0
        distances.append(distance)
        print(distance)
        train_accs.append(train_acc)
        if args.use_wandb:
            wandb.log({
                'train_loss': train_loss,
                'train_acc': train_acc,
                 }, step=epoch)

        if epoch % args.save_freq == 0:
            if args.dataset in ('imagenet', 'imagenet-c') or args.learning_rate != 0.001:
                checkpoint_name = f'checkpoint_{epoch+1}.pth.tar'
            else:
                checkpoint_name = f'checkpoint.pth.tar'

            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, args, filename=checkpoint_name)

        if (not args.no_split or epoch == args.epochs - 1 or args.dataset in ('imagenet', 'cars', 'cub200')) and args.dataset != 'imagenet-c':
            if args.regularization == 'FLYP': # Load weight
                # Rebuild the zero-shot head from the current text encoder.
                if 'siglip' in args.model:
                    from lib.siglip2.zero_shot_classifier import build_zero_shot_classifier as build_siglip2_zero_shot_classifier
                    zero_shot_weights = build_siglip2_zero_shot_classifier(model, tokenizer, IMAGENET_CLASSNAMES, OPENAI_IMAGENET_TEMPLATES, 50, device=device)
                else:
                    zero_shot_weights = build_zero_shot_classifier(model, tokenizer, IMAGENET_CLASSNAMES, OPENAI_IMAGENET_TEMPLATES, 50, device=device)
                model.zero_shot_weights = zero_shot_weights

            val_acc1, val_loss = validate(val_loader, texts, model, tokenizer, criterion, args, return_attention=return_attention, mask=mask)[:2]
            val_accs.append(val_acc1)
            if args.use_wandb:
                wandb.log({
                    'val_loss': val_loss,
                    'val_acc': val_acc1,
                }, step=epoch)

            # remember best acc@1 and save checkpoint
            is_best = val_acc1 > best_acc1
            best_acc1 = max(val_acc1, best_acc1)
            if is_best:
                save_checkpoint({
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'best_acc1': best_acc1,
                    'optimizer': optimizer.state_dict(),
                }, args, is_best=is_best)
                epochs_since_improvement = 0
            else:
                epochs_since_improvement += 1
                print(f"There's no improvement for {epochs_since_improvement} epochs.")

                if epochs_since_improvement >= args.patience:
                    print("The training halted by early stopping criterion.")
                    break

    if args.dataset in ['caltech256', 'cars', 'cifar100', 'cub200']:
        model.zero_shot_weights = original_zero_shot_weights
        print("Load original zero-shot weights")
    print(train_accs)
    print(val_accs)
    print(test_accs)
    save_checkpoint({
        'epoch': epoch + 1,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
    }, args)

    # Optionally switch to the best-validation weights before final evaluation.
    if args.eval_best:
        bestfile = os.path.join(args.model_folder, 'model_best.pth.tar')
        if os.path.exists(bestfile):
            state_dict = _torch_load(bestfile)
            model.load_state_dict(state_dict['state_dict'])
        else:
            raise Exception("Something Wrong")

    evaluations(args, model, preprocess, texts, tokenizer, criterion)

    use_tpl =False
    if use_tpl:
        fig = tpl.figure()
        fig.plot(np.arange(len(train_accs)), train_accs, label='Train', height=10)
        fig.plot(np.arange(len(val_accs)), val_accs, label='Val', height=10)
        fig.plot(np.arange(len(test_accs)), test_accs, label='Test', height=10)
        fig.show()

    if args.use_wandb:
        wandb.run.finish()

if __name__ == '__main__':
    main()
