import random
import argparse
import time 

from tqdm import tqdm
from datetime import datetime
import base64 

import torch
import torch.nn.functional as F
import operator

import clip
from utils import *
from transformers import pipeline, AutoProcessor
from PIL import Image    
import requests
import torchvision.transforms as transforms

import openai 
from datasets.utils import read_json

from tda_runner import update_cache, compute_cache_logits

# Class index -> text prompt ("A photo of <class name>").
# NOTE(review): the names look like the EuroSAT land-use classes — confirm.
# Not referenced anywhere in the visible part of this file.
target_lookup = {
    class_idx: "A photo of " + class_name
    for class_idx, class_name in enumerate((
        "Annual Crop Land",
        "Forest",
        "Herbaceous Vegetation Land",
        "Highway or Road",
        "Industrial Buildings",
        "Pasture Land",
        "Permanent Crop Land",
        "Residential Buildings",
        "River",
        "Sea or Lake",
    ))
}


def get_arguments():
    """Parse the command-line options of the test-time adaptation run.

    Returns:
        argparse.Namespace with attributes ``config``, ``datasets``,
        ``data_root``, ``backbone``, ``uncertain_thres`` and
        ``labeling_budget``.
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument

    add('--config', dest='config', required=True,
        help='settings of TDA on specific dataset in yaml format.')
    add('--datasets', dest='datasets', type=str, required=True,
        help="Datasets to process, separated by a slash (/). Example: I/A/V/R/S")
    add('--data-root', dest='data_root', type=str, default='./dataset/',
        help='Path to the datasets directory. Default is ./dataset/')
    add('--backbone', dest='backbone', type=str, choices=['RN50', 'ViT-B/16'], required=True,
        help='CLIP model backbone to use: RN50 or ViT-B/16.')
    # Entropy above which a sample is treated as uncertain.
    add('--uncertain_thres', type=float, default=0.2)
    # Fraction of the test set that may be labeled (0 disables the cap).
    add('--labeling_budget', type=float, default=0.1)

    return parser.parse_args()


def remove_array(T, i):
    """Return a new tensor equal to ``T`` with the row at index ``i`` (dim 0) dropped."""
    head = T[:i]
    tail = T[i + 1:]
    return torch.cat((head, tail))
    
def pairwise_similarity(t1, t2, alpha=5.0):
    """Cosine-similarity matrix between the rows of ``t1`` and the rows of ``t2``.

    Both inputs are L2-normalised along their last dimension, then multiplied
    as ``t1_n @ t2_n.T``.

    ``alpha`` is accepted for interface compatibility only; it is not used.
    """
    t1_unit = torch.nn.functional.normalize(t1, dim=-1)
    t2_unit = torch.nn.functional.normalize(t2, dim=-1)
    return torch.matmul(t1_unit, t2_unit.t())
    
def _one_hot_logits(label, num_classes, device):
    """Return a (1, num_classes) float tensor that is 1.0 at ``label`` and 0.0 elsewhere."""
    logits = torch.zeros(num_classes)
    logits[label] = 1.0
    return logits.unsqueeze(0).to(device)


def _retrieval_vote(image_features, img_feature_db, label_db, k=2, reliability_thres=0.85):
    """Return the label shared by the ``k`` most similar stored features, or ``None``.

    A label is returned only when all ``k`` nearest neighbours (by cosine
    similarity) carry the same label AND their mean similarity exceeds
    ``reliability_thres``. The caller must ensure the database holds at
    least ``k`` entries.
    """
    img_sim = pairwise_similarity(image_features, img_feature_db).squeeze()
    topk_val, topk_idx = torch.topk(img_sim, k)
    neighbour_labels = {label_db[int(i)] for i in topk_idx}
    # Kept as a tensor comparison so dtype promotion of the threshold is the
    # same as in the original inline code.
    reliability = torch.sum(topk_val) / k
    # "All neighbours agree" replaces the former `votes[max]/sum(votes) == 1.0`
    # test. It is equivalent under the reliability threshold, and it cannot
    # divide by zero when the similarities sum to 0.
    if len(neighbour_labels) == 1 and reliability > reliability_thres:
        return neighbour_labels.pop()
    return None


def _nearest_label_logits(image_features, img_feature_db, label_db, clip_logits, num_classes, top_k=5):
    """Snap to the nearest stored label if the model already ranks it in its top-k.

    Returns:
        ``(final_logits, reused)``. ``final_logits`` is a one-hot tensor for the
        nearest neighbour's label when that label is among the ``top_k``
        highest-scoring classes of ``clip_logits``; otherwise it is a copy of
        ``clip_logits``. ``reused`` tells which case applied.
    """
    img_sim = pairwise_similarity(image_features, img_feature_db)
    _, max_idx = img_sim.max(dim=-1)
    nearest_target = label_db[int(max_idx)]
    scores = clip_logits.squeeze()
    # Clamp k so models with fewer than `top_k` classes do not crash topk.
    _, topk_pred = torch.topk(scores, k=min(top_k, scores.numel()))
    if nearest_target in topk_pred:
        return _one_hot_logits(nearest_target, num_classes, clip_logits.device), True
    return clip_logits.clone(), False


def run_test_custom(pos_cfg, neg_cfg, loader, clip_model, clip_weights, **kwargs):
    """Run cache-based test-time adaptation augmented with a small labeled feature database.

    Cache adjustment: for each sample, the CLIP logits are adjusted with
    positive and negative caches (``update_cache`` / ``compute_cache_logits``).

    Uncertain samples: a sample whose entropy after adjustment exceeds
    ``kwargs["uncert_thres"]`` is handled in this order:
      1. Once the database holds at least 5 entries, its two nearest stored
         features are checked. If they agree on a label with mean similarity
         above 0.85, that label is used.
      2. Otherwise, while the labeling budget lasts, the sample's ground-truth
         target is stored in the database and the adjusted logits are used.
      3. Once the budget is spent, the nearest stored label is used, but only
         if it is among the model's top-5 classes.

    Args:
        pos_cfg: positive-cache config. Reads ``'enabled'``, and when enabled
            also ``'shot_capacity'``, ``'alpha'`` and ``'beta'``.
        neg_cfg: negative-cache config. Reads the same keys plus
            ``'entropy_threshold'`` and ``'mask_threshold'`` (each with
            ``'lower'`` and ``'upper'``).
        loader: iterable of ``(images, target)`` batches. It presumably yields
            batch size 1, since ``int(target)`` requires a single element —
            TODO confirm.
        clip_model: model passed to ``get_clip_logits``.
        clip_weights: text classifier weights; ``shape[-1]`` is used as the
            number of classes.
        **kwargs: ``uncert_thres`` (float) and ``labeling_budget`` (fraction of
            ``len(loader)`` that may be labeled; exactly 0 means no cap).

    Returns:
        The mean of the per-sample ``cls_acc`` values.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    uncert_thres = kwargs["uncert_thres"]
    labeling_budget = kwargs["labeling_budget"]
    # Maximum database size. A budget of exactly 0 disables the cap; in that
    # case len(loader) is never needed.
    label_capacity = None if labeling_budget == 0 else int(labeling_budget * len(loader))
    num_classes = clip_weights.shape[-1]

    with torch.no_grad():
        pos_cache, neg_cache = {}, {}
        accuracies = {}  # per-class accuracy lists (collected only; not returned)
        tot_acc = []
        img_feature_db = torch.tensor([]).cuda().half()
        label_db = []

        pos_enabled, neg_enabled = pos_cfg['enabled'], neg_cfg['enabled']
        if pos_enabled:
            pos_params = {k: pos_cfg[k] for k in ['shot_capacity', 'alpha', 'beta']}
        if neg_enabled:
            neg_params = {k: neg_cfg[k] for k in ['shot_capacity', 'alpha', 'beta', 'entropy_threshold', 'mask_threshold']}

        reuse_cnt = 0  # samples whose label came from the database (collected only)
        label_idx = []  # loader indices of samples added to the database (collected only)

        for idx, (images, target) in enumerate(tqdm(loader, desc='Processed test images: ')):
            _, image_features, clip_logits, loss, prob_map, pred = get_clip_logits(images, clip_model, clip_weights)
            target, prop_entropy = target.cuda(), get_entropy(loss, clip_weights)

            if pos_enabled:
                update_cache(pos_cache, pred, [image_features, loss], pos_params['shot_capacity'])

            if neg_enabled and neg_params['entropy_threshold']['lower'] < prop_entropy < neg_params['entropy_threshold']['upper']:
                update_cache(neg_cache, pred, [image_features, loss, prob_map], neg_params['shot_capacity'], True)

            if pos_enabled and pos_cache:
                clip_logits += compute_cache_logits(image_features, pos_cache, pos_params['alpha'], pos_params['beta'], clip_weights)
            if neg_enabled and neg_cache:
                clip_logits -= compute_cache_logits(image_features, neg_cache, neg_params['alpha'], neg_params['beta'], clip_weights, (neg_params['mask_threshold']['lower'], neg_params['mask_threshold']['upper']))

            # The entropy of the cache-adjusted prediction decides whether the sample is uncertain.
            prop_entropy = get_entropy(softmax_entropy(clip_logits), clip_weights)

            final_logits = None
            if prop_entropy > uncert_thres:  # UNCERTAIN
                if len(img_feature_db) >= 5:
                    label = _retrieval_vote(image_features, img_feature_db, label_db)
                    if label is not None:  # Retrieval-augmented correction.
                        final_logits = _one_hot_logits(label, num_classes, clip_logits.device)
                        reuse_cnt += 1

                if final_logits is None:
                    if label_capacity is None or len(label_db) < label_capacity:
                        # Spend labeling budget: store the ground-truth target.
                        img_feature_db = torch.cat([img_feature_db, image_features])
                        label_db.append(int(target))
                        label_idx.append(idx)
                        final_logits = clip_logits.clone()
                    elif label_db:
                        final_logits, reused = _nearest_label_logits(
                            image_features, img_feature_db, label_db, clip_logits, num_classes)
                        reuse_cnt += int(reused)
                    else:
                        # The budget rounded to zero entries, so there is nothing to retrieve from.
                        final_logits = clip_logits.clone()
            else:
                final_logits = clip_logits.clone()

            acc = cls_acc(final_logits, target)
            tot_acc.append(acc)
            accuracies.setdefault(int(target), []).append(acc)

        if not tot_acc:
            raise ValueError("run_test_custom: loader yielded no samples")
        return sum(tot_acc) / len(tot_acc)

def print_acc(tot_acc, acc):
    """Print the overall mean accuracy followed by one mean-accuracy line per class.

    Args:
        tot_acc: sequence of per-sample accuracies.
        acc: mapping of class id -> sequence of that class's per-sample accuracies.
    """
    def _mean(values):
        return sum(values) / len(values)

    print("==== TDA's test accuracy ====")
    print(f"Overall Acc: {_mean(tot_acc):.2f}")
    for class_id, class_accs in acc.items():
        print(f"{class_id}-th class: {_mean(class_accs):.2f}")

def main():
    """Script entry point: run ``run_test_custom`` on every requested dataset and print accuracies."""
    args = get_arguments()

    # The CLIP backbone is loaded once and shared by every dataset.
    clip_model, preprocess = clip.load(args.backbone)
    clip_model.visual.eval()

    def _reset_seeds():
        random.seed(1)
        torch.manual_seed(1)

    _reset_seeds()

    results = []
    for dataset_name in args.datasets.split('/'):
        # Re-seed so each dataset starts from the same RNG state.
        _reset_seeds()
        print(f"Processing {dataset_name} dataset.")

        cfg = get_config_file(args.config, dataset_name)
        print("\nRunning dataset configurations:")
        print(cfg, "\n")

        test_loader, classnames, template = build_test_data_loader(dataset_name, args.data_root, preprocess)
        clip_weights = clip_classifier(classnames, template, clip_model)

        acc = run_test_custom(
            cfg['positive'], cfg['negative'], test_loader, clip_model, clip_weights,
            uncert_thres=args.uncertain_thres,
            labeling_budget=args.labeling_budget,
        )
        print(f"Accuracy: {acc}")
        results.append({"accuracy": acc})

    for result in results:
        print(f'{result["accuracy"]:.2f}')

# Run main() only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()