import argparse

import time

from copy import deepcopy
import random
from PIL import Image
import numpy as np
import operator
import matplotlib.pyplot as plt
import torch
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
from tqdm import tqdm

try:
    from torchvision.transforms import InterpolationMode
    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    BICUBIC = Image.BICUBIC
import torchvision.models as models

from clip.custom_clip import get_coop
from clip.cocoop import get_cocoop
from data.imagenet_prompts import imagenet_classes
from data.datautils import AugMixAugmenter, build_dataset, MaskImgAugmenter
from utils.tools import Summary, AverageMeter, ProgressMeter, accuracy, load_model_weight, set_random_seed
from data.cls_to_names import *
from data.fewshot_datasets import fewshot_datasets
from data.imagenet_variants import thousand_k_to_200, imagenet_a_mask, imagenet_r_mask, imagenet_v_mask


# Sorted names of the public, lower-case callables exported by torchvision.models
# (i.e. the model constructor functions, excluding dunder attributes and classes
# with capitalised names).
model_names = sorted(
    attr_name
    for attr_name, attr_value in models.__dict__.items()
    if attr_name.islower() and not attr_name.startswith("__") and callable(attr_value)
)

def select_confident_samples(logits, top):
    """Keep the lowest-entropy fraction ``top`` of the rows of ``logits``.

    Args:
        logits: 2-D tensor; each row is a sample, softmax is taken over dim 1.
        top: fraction of rows to keep. ``int(N * top)`` rows are kept, but never
            fewer than one when ``top > 0`` and the input is non-empty, so that a
            small batch does not yield an empty selection (which would make
            ``avg_entropy`` return NaN). ``top <= 0`` still selects nothing.

    Returns:
        ``(logits[idx], idx)`` where ``idx`` holds the kept row indices in
        ascending-entropy order.
    """
    batch_entropy = -(logits.softmax(1) * logits.log_softmax(1)).sum(1)
    num_keep = int(batch_entropy.size(0) * top)
    if top > 0:
        # Guard against int() truncating to zero for small N; slicing below
        # already copes with N == 0.
        num_keep = max(num_keep, 1)
    idx = torch.argsort(batch_entropy, descending=False)[:num_keep]
    return logits[idx], idx

def avg_entropy(outputs):
    """Entropy of the probability distribution averaged over the rows of ``outputs``.

    Every row is turned into log-probabilities, the probabilities are averaged
    across rows in log space (logsumexp minus log N), and the Shannon entropy of
    that mean distribution is returned, reduced over the last dimension.
    """
    # Row-wise log-softmax written via logsumexp.
    log_probs = outputs - outputs.logsumexp(dim=-1, keepdim=True)
    num_rows = log_probs.shape[0]
    # log(mean(p)) computed stably in log space.
    mean_log_prob = log_probs.logsumexp(dim=0) - np.log(num_rows)
    # Replace -inf by the most negative finite value so that 0 * -inf cannot give NaN.
    floor_value = torch.finfo(mean_log_prob.dtype).min
    mean_log_prob = mean_log_prob.clamp(min=floor_value)
    weighted = mean_log_prob * mean_log_prob.exp()
    return -weighted.sum(dim=-1)

def softmax_entropy(x):
    """Per-row Shannon entropy of ``softmax(x)`` taken over dimension 1."""
    probs = x.softmax(1)
    log_probs = x.log_softmax(1)
    return -(probs * log_probs).sum(1)

def spatial_IoU_v2(attention_map, logit_scale, args):
    """Return scaled attention rows of the most salient, most confident patches.

    The rows of ``attention_map`` (one per spatial position, as used by the
    caller) are ranked by their sum over dim 1. The top half is kept and scaled
    by ``logit_scale``. That set is then narrowed to the lowest-entropy
    ``args.local_views`` fraction via ``select_confident_samples``.
    Despite the name, no IoU is computed here.
    """
    half = attention_map.size(0) // 2
    # Saliency score per row: sum over dim 1 -> [P, 1].
    saliency = attention_map.sum(dim=1, keepdim=True)
    salient_rows = saliency.topk(half, dim=0).indices.squeeze(1)
    candidates = logit_scale * attention_map[salient_rows]  # [half, C]
    confident, _ = select_confident_samples(candidates, args.local_views)
    return confident

def _augmented_text_logits(attention_map, spatial_features, text_features,
                           global_features, logit_scale, args):
    """Return logits for ``args.language_views`` randomly patch-masked text augmentations.

    For each view, a random 10%-50% of the rows of ``attention_map`` (one row per
    spatial position) are multiplied by zero. A softmax over the class axis then
    pools ``spatial_features`` into one visual feature per class. That feature is
    added to ``text_features`` (scaled by ``args.visual_prior_factor``) and
    scored against ``global_features``.
    """
    num_masks = args.language_views
    num_patches = spatial_features.size(0)
    repeat_attention_maps = attention_map.unsqueeze(0).repeat(num_masks, 1, 1)  # [num, P, C]
    # Build the mask on the CPU (cheap scalar indexing), then move it once to the
    # attention map's device instead of assuming the current CUDA device.
    mask = torch.ones(num_masks, num_patches)
    min_masked_patches = int(num_patches * 0.1)
    max_masked_patches = int(num_patches * 0.5)
    for i in range(num_masks):
        num_masked_patches = random.randint(min_masked_patches, max_masked_patches)
        masked_indices = random.sample(range(num_patches), num_masked_patches)
        mask[i, masked_indices] = 0
    mask = mask.unsqueeze(-1).expand_as(repeat_attention_maps).to(attention_map.device)

    aug_attention_maps = repeat_attention_maps * mask  # [num, P, C]
    aug_spatial_attention = F.softmax(aug_attention_maps, dim=2)  # softmax over classes
    aug_text_features = aug_spatial_attention.transpose(1, 2) @ spatial_features  # [num, C, D]
    aug_text_features = text_features + args.visual_prior_factor * aug_text_features
    aug_logits = logit_scale * global_features.unsqueeze(0) @ aug_text_features.transpose(1, 2)
    return aug_logits.squeeze(1)


def test_time_tuning(model, inputs, cache, optimizer, scaler, args):
    """Run ``args.tta_steps`` optimisation steps of test-time prompt tuning on one sample.

    Each step minimises ``avg_entropy`` over the most confident logits
    (``select_confident_samples`` with ``args.selection_p``). For CoOp/MaPLe the
    candidate logits are the model output, the local-view logits from
    ``spatial_IoU_v2`` and the description-augmented logits. For CoCoOp they are
    the model output for ``(image_feature, pgen_ctx)``, and ``pgen_ctx`` is
    optimised by a fresh AdamW.

    Args:
        model: the prompt-tuned CLIP model.
        inputs: image batch, or ``(image_feature, pgen_ctx)`` when ``args.cocoop``.
        cache: currently unused; kept for interface compatibility.
        optimizer: optimizer over the tunable prompt (replaced for CoCoOp).
        scaler: AMP ``GradScaler`` used for the backward/step.
        args: parsed CLI namespace.

    Returns:
        The tuned ``pgen_ctx`` for CoCoOp. Otherwise the ``text_features`` from the
        last step, or ``None`` when ``args.tta_steps`` is 0 (the original raised
        UnboundLocalError there).
    """
    if args.cocoop:
        image_feature, pgen_ctx = inputs
        pgen_ctx.requires_grad = True
        optimizer = torch.optim.AdamW([pgen_ctx], args.lr)

    text_features = None
    for _ in range(args.tta_steps):
        with torch.cuda.amp.autocast():
            if args.cocoop:
                output = model((image_feature, pgen_ctx))
                # The original defined no loss on this path (NameError at backward);
                # apply the same confidence selection + average entropy objective.
                output, _ = select_confident_samples(output, args.selection_p)
                loss = avg_entropy(output)
            else:
                output, global_features, text_features, spatial_features, logit_scale = model(inputs, args.arch)
                attention_map = spatial_features @ text_features.t()  # [P, C]

                # Class-specific descriptions augmentation
                aug_logits = _augmented_text_logits(attention_map, spatial_features, text_features,
                                                    global_features, logit_scale, args)

                # Local visual perception
                spatial_output = spatial_IoU_v2(attention_map, logit_scale, args)

                # Test-time tuning objective over all candidate logits
                logits = torch.cat((output, spatial_output, aug_logits), dim=0)
                logits, _ = select_confident_samples(logits, args.selection_p)
                loss = avg_entropy(logits)

        optimizer.zero_grad()
        # compute gradient and take an (AMP-scaled) optimizer step
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

    if args.cocoop:
        return pgen_ctx
    return text_features


def main():
    """Entry point: parse the command line, seed the RNGs and evaluate on one GPU.

    Reads the module-level ``parser`` that the ``__main__`` guard creates.
    """
    cli_args = parser.parse_args()
    set_random_seed(cli_args.seed)

    # Only the single-GPU setting has been tested, so a GPU id is mandatory.
    assert cli_args.gpu is not None
    main_worker(cli_args.gpu, cli_args)

def _lookup_module_name(name):
    """Return the module-level object named ``name``.

    Used instead of ``eval`` on strings derived from ``--test_sets``: a plain
    ``globals()`` lookup cannot execute code embedded in a crafted CLI value.

    Raises:
        NameError: if no such module-level name exists (the same exception type
            ``eval`` raised for an unknown name).
    """
    try:
        return globals()[name]
    except KeyError:
        raise NameError("name {!r} is not defined".format(name)) from None


def _classnames_for_set(set_id):
    """Return the class-name list used when evaluating dataset ``set_id``."""
    if len(set_id) > 1:
        # fine-grained classification datasets: ``<name>_classes`` comes from the
        # star import at the top of the file (presumably data.cls_to_names)
        return _lookup_module_name("{}_classes".format(set_id.lower()))
    assert set_id in ['A', 'R', 'K', 'V', 'I']
    if set_id == 'R':
        # imagenet_r_mask is used as a per-class truthy flag over imagenet_classes
        return [imagenet_classes[i] for i, keep in enumerate(imagenet_r_mask) if keep]
    if set_id in ['A', 'V']:
        # imagenet_{a,v}_mask hold the imagenet_classes indices to keep
        label_mask = imagenet_a_mask if set_id == 'A' else imagenet_v_mask
        return [imagenet_classes[i] for i in label_mask]
    return imagenet_classes


def _build_data_transform(set_id, normalize, args):
    """Return ``(data_transform, loader_batch_size)`` for evaluation set ``set_id``."""
    if not args.tpt:
        data_transform = transforms.Compose([
            transforms.Resize(args.resolution, interpolation=BICUBIC),
            transforms.CenterCrop(args.resolution),
            transforms.ToTensor(),
            normalize,
        ])
        return data_transform, args.batch_size

    # The augmenters' ``augmix`` flag is set only for multi-character set ids
    # (the fine-grained datasets); single letters are the ImageNet variants.
    use_augmix = len(set_id) > 1
    if args.mask:
        base_transform = transforms.Compose([
            transforms.Resize((args.resolution, args.resolution), interpolation=BICUBIC)])
        preprocess = transforms.Compose([
            transforms.ToTensor(),
            normalize])
        data_transform = MaskImgAugmenter(base_transform, preprocess, n_views=args.batch_size - 1,
                                          augmix=use_augmix, mask_ratio=args.mask_ratio)
    else:
        base_transform = transforms.Compose([
            transforms.Resize(args.resolution, interpolation=BICUBIC),
            transforms.CenterCrop(args.resolution)])
        preprocess = transforms.Compose([
            transforms.ToTensor(),
            normalize])
        data_transform = AugMixAugmenter(base_transform, preprocess, n_views=args.batch_size - 1,
                                         augmix=use_augmix)
    # With TPT the loader yields one test image at a time; the augmenter is
    # asked for batch_size - 1 views of it.
    return data_transform, 1


def _print_summary(results, args):
    """Print the final per-dataset top-1 accuracy table."""
    print("======== Result Summary ========")
    print("params: nstep	lr	bs")
    print("params: {}	{}	{}".format(args.tta_steps, args.lr, args.batch_size))
    print("\t\t [set_id] \t\t Top-1 acc. \t\t Top-5 acc.")
    for set_name in results:
        print("{}".format(set_name), end="	")
    print("\n")
    for set_name in results:
        print("{:.2f}".format(results[set_name][0]), end="	")
    print("\n")


def main_worker(gpu, args):
    """Build the CoOp/CoCoOp CLIP model and evaluate it on every set in ``args.test_sets``.

    Args:
        gpu: CUDA device index to run on (stored back into ``args.gpu``).
        args: parsed command-line namespace; ``args.test_sets`` is a
            slash-separated list of dataset ids.
    """
    args.gpu = gpu
    set_random_seed(args.seed)
    print("Use GPU: {} for training".format(args.gpu))

    # create model (zero-shot CLIP model with prompt tuning)
    if args.cocoop:
        model = get_cocoop(args.arch, args.test_sets, 'cpu', args.n_ctx)
        assert args.load is not None
        load_model_weight(args.load, model, 'cpu', args) # to load to cuda: device="cuda:{}".format(args.gpu)
        model_state = deepcopy(model.state_dict())
    else:
        model = get_coop(args.arch, args.test_sets, args.gpu, args.n_ctx, args.ctx_init)
        if args.load is not None:
            print("Use pre-trained soft prompt (CoOp) as initialization")
            # NOTE(review): torch.load unpickles the checkpoint -- only load trusted files.
            pretrained_ctx = torch.load(args.load)['state_dict']['ctx']
            assert pretrained_ctx.size()[0] == args.n_ctx
            with torch.no_grad():
                # copy the pre-trained prompt into the prompt learner's context vectors
                model.prompt_learner.ctx.copy_(pretrained_ctx)
                model.prompt_learner.ctx_init_state = pretrained_ctx
        model_state = None

    # Freeze everything except the tunable part: the prompt learner for
    # CoOp/MaPLe, the text encoder for CoCoOp.
    trainable_key = "text_encoder" if args.cocoop else "prompt_learner"
    for name, param in model.named_parameters():
        if trainable_key not in name:
            param.requires_grad_(False)

    print("=> Model created: visual backbone {}".format(args.arch))

    if not torch.cuda.is_available():
        print('using CPU, this will be slow')
    else:
        assert args.gpu is not None
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)

    # define optimizer (CoCoOp builds its own inside test_time_tuning)
    if args.cocoop:
        optimizer = None
        optim_state = None
    else:
        trainable_param = model.prompt_learner.parameters()
        optimizer = torch.optim.AdamW(trainable_param, args.lr)
        optim_state = deepcopy(optimizer.state_dict())

    # setup automatic mixed-precision (Amp) loss scaling
    scaler = torch.cuda.amp.GradScaler(init_scale=1000)

    print('=> Using native Torch AMP. Training in mixed precision.')

    cudnn.benchmark = True

    # norm stats from clip.load()
    normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                     std=[0.26862954, 0.26130258, 0.27577711])

    # iterating through eval datasets
    datasets = args.test_sets.split("/")
    results = {}
    for set_id in datasets:
        data_transform, batchsize = _build_data_transform(set_id, normalize, args)

        print("evaluating: {}".format(set_id))
        # Reset classnames of the custom CLIP model for this dataset
        classnames = _classnames_for_set(set_id)
        if args.cocoop:
            model.prompt_generator.reset_classnames(classnames, args.arch)
            model = model.cpu()
            model_state = model.state_dict()
            model = model.cuda(args.gpu)
        elif args.maple:
            model.reset_classnames(classnames, args)
        else:
            model.reset_classnames(classnames, args.arch)

        val_dataset = build_dataset(set_id, data_transform, args.data, mode=args.dataset_mode)
        print("number of test samples: {}".format(len(val_dataset)))
        val_loader = torch.utils.data.DataLoader(
                    val_dataset,
                    batch_size=batchsize, shuffle=True,
                    num_workers=args.workers, pin_memory=True)

        results[set_id] = test_time_adapt_eval(val_loader, model, model_state, optimizer, optim_state, scaler, args)
        del val_dataset, val_loader
        try:
            print("=> Acc. on testset [{}]: @1 {}/ @5 {}".format(set_id, results[set_id][0], results[set_id][1]))
        except (TypeError, IndexError, KeyError):
            # result is not an indexable [top1, top5] pair; print it as-is
            print("=> Acc. on testset [{}]: {}".format(set_id, results[set_id]))

    _print_summary(results, args)


def test_time_adapt_eval(val_loader, model, model_state, optimizer, optim_state, scaler, args):
    """Evaluate ``model`` on ``val_loader`` with per-sample test-time prompt tuning.

    For every batch (non-CoCoOp, ``args.tta_steps > 0``), the prompt is reset, the
    optimizer state is restored from ``optim_state``, and ``test_time_tuning`` is
    run. The prediction for the first view is then computed with text features
    refined by a softmax-attention pooling of the spatial features.

    Returns:
        ``[top1.avg, top5.avg]`` accuracy averages over the loader.
    """
    batch_time = AverageMeter('Time', ':6.3f', Summary.NONE)
    top1 = AverageMeter('Acc@1', ':6.2f', Summary.AVERAGE)
    top5 = AverageMeter('Acc@5', ':6.2f', Summary.AVERAGE)

    progress = ProgressMeter(
        len(val_loader),
        [batch_time, top1, top5],
        prefix='Test: ')

    # reset model and switch to evaluate mode
    model.eval()
    cache = {}  # passed through to test_time_tuning
    if not args.cocoop:  # no need to reset cocoop because it's fixed
        with torch.no_grad():
            model.reset()
    end = time.time()
    t_s = time.time()
    for i, (images, target) in enumerate(tqdm(val_loader)):
        assert args.gpu is not None
        if isinstance(images, list):
            images = [view.cuda(args.gpu, non_blocking=True) for view in images]
            image = images[0]  # the first entry is the one evaluated below
        else:
            if len(images.size()) > 4:
                # when using ImageNet Sampler as the dataset
                assert images.size()[0] == 1
                images = images.squeeze(0)
            images = images.cuda(args.gpu, non_blocking=True)
            image = images
        target = target.cuda(args.gpu, non_blocking=True)
        if args.tpt:
            images = torch.cat(images, dim=0)

        # Reset the tunable prompt to its initial state and tune it on this sample.
        # Skipped when tta_steps == 0: there is nothing to tune, and the original
        # unconditional call crashed in that case.
        if not args.cocoop and args.tta_steps > 0:  # cocoop is not tuned here
            with torch.no_grad():
                model.reset()
            optimizer.load_state_dict(optim_state)
            test_time_tuning(model, images, cache, optimizer, scaler, args)

        with torch.no_grad():
            with torch.cuda.amp.autocast():
                output, global_features, text_features, spatial_features, logit_scale = model(image, args.arch)
                # Test-time inference: add an attention-weighted pooling of the
                # spatial features to each class text feature.
                attention_map = spatial_features @ text_features.t()  # [P, C]
                softmax_attention_map = F.softmax(attention_map, dim=1)
                spatial_text = softmax_attention_map.t() @ spatial_features  # [C, D]
                spatial_text = text_features + args.inference_factor * spatial_text
                output = logit_scale * global_features @ spatial_text.t()

        # measure accuracy
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        top1.update(acc1[0], image.size(0))
        top5.update(acc5[0], image.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if (i+1) % args.print_freq == 0:
            progress.display(i)
    t_e = time.time()
    print("Time:{}".format((t_e - t_s)))

    progress.display_summary()

    return [top1.avg, top5.avg]


def build_arg_parser():
    """Return the command-line parser for the test-time prompt tuning evaluation."""
    parser = argparse.ArgumentParser(description='Test-time Prompt Tuning')
    parser.add_argument('data', metavar='DIR', help='path to dataset root')
    parser.add_argument('--test_sets', type=str, default='A/R/V/K/I', help='test dataset (multiple datasets split by slash)')
    parser.add_argument('--dataset_mode', type=str, default='test', help='which split to use: train/val/test')
    parser.add_argument('-a', '--arch', metavar='ARCH', default='RN50')
    parser.add_argument('--resolution', default=224, type=int, help='CLIP image resolution')
    parser.add_argument('-j', '--workers', default=8, type=int, metavar='N',
                        help='number of data loading workers (default: %(default)s)')
    parser.add_argument('-b', '--batch-size', default=64, type=int, metavar='N',
                        help='batch size (with --tpt, the augmenter is built with n_views = batch_size - 1)')
    parser.add_argument('--local_views', default=0.1, type=float,
                        help='fraction of the most-attended half of the patches kept (lowest entropy) as local views')
    parser.add_argument('--language_views', default=4, type=int,
                        help='number of randomly patch-masked text augmentations (language views)')
    parser.add_argument('--visual_prior_factor', default=0.1, type=float, help='visual-prior factor')
    parser.add_argument('--inference_factor', default=0.1, type=float, help='inference factor')
    parser.add_argument('--lr', '--learning-rate', default=5e-3, type=float,
                        metavar='LR', help='initial learning rate', dest='lr')
    parser.add_argument('-p', '--print-freq', default=200, type=int,
                        metavar='N', help='print frequency (default: %(default)s)')
    parser.add_argument('--gpu', default=0, type=int,
                        help='GPU id to use.')
    parser.add_argument('--tpt', action='store_true', default=False, help='run test-time prompt tuning')
    parser.add_argument('--mask', action='store_true', default=False, help='Perform masking augmentation')
    parser.add_argument('--mask_ratio', default=0.3, type=float, help='masking ratio')
    parser.add_argument('--selection_p', default=0.1, type=float, help='confidence selection percentile')
    parser.add_argument('--tta_steps', default=1, type=int, help='test-time-adapt steps')
    parser.add_argument('--n_ctx', default=4, type=int, help='number of tunable tokens')
    parser.add_argument('--maple_depth', default=3, type=int, help='Depth of MaPLe prompting')
    parser.add_argument('--ctx_init', default=None, type=str, help='init tunable prompts')
    parser.add_argument('--cocoop', action='store_true', default=False, help="use cocoop's output as prompt initialization")
    parser.add_argument('--maple', action='store_true', default=False, help="use MaPLe's output as prompt initialization")
    parser.add_argument('--load', default=None, type=str, help='path to a pre-trained coop/cocoop')
    parser.add_argument('--seed', type=int, default=0)
    return parser


if __name__ == '__main__':
    # main() reads this module-level ``parser``.
    parser = build_arg_parser()
    main()