import json, os
from dataset_utils import get_clip_prompt_dataloader, get_clip_original_prompt_dataloader
import numpy as np
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import torch
import torch.nn.functional as F
from RKME_utils import MMD

def top_k_accuracy(gts, dists, k=1):
    """Return the fraction of samples whose true label is among the ``k`` top-scoring classes.

    Args:
        gts: Integer class labels, one per sample (any array-like; flattened).
        dists: Per-sample class scores of shape ``(n_samples, n_classes)``.
            Larger scores rank higher (the top ``k`` are the ``k`` largest).
        k: Number of top-scoring classes to consider. ``k <= 0`` yields 0.0, and
            ``k >= n_classes`` yields 1.0 for labels within range.

    Returns:
        Top-k accuracy as a Python float. If there are no samples, returns NaN.
    """
    gts = np.asarray(gts).reshape(-1)
    if gts.size == 0:
        # Accuracy is undefined for zero samples; avoid a 0/0 RuntimeWarning.
        return float("nan")
    dists = np.asarray(dists)
    # Slice from an explicit, clamped start index: ``[:, -k:]`` with k == 0
    # would select *every* column and report perfect accuracy.
    start = max(dists.shape[1] - k, 0)
    top_k_indices = np.argsort(dists, axis=1)[:, start:]
    hits = np.any(top_k_indices == gts[:, None], axis=1)
    return float(np.count_nonzero(hits) / gts.size)

def get_clip_original_prompt_loaders(prompt_path, train=False):
    """Build one original-prompt CLIP dataloader per model in the module-level ``models`` pool.

    Args:
        prompt_path: Path to the prompt file, forwarded to
            ``get_clip_original_prompt_dataloader``.
        train: Forwarded to ``get_clip_original_prompt_dataloader``.

    Returns:
        A list of ``(index, model_name, dataloader)`` tuples, in ``models`` iteration
        order. Each dataloader is built from ``./Images/<model_name>``.
    """
    return [
        (
            i,
            model_name,
            get_clip_original_prompt_dataloader(
                f"./Images/{model_name}", prompt_path, train=train
            ),
        )
        for i, model_name in enumerate(models)
    ]

def get_clip_prompt_loaders(train=False):
    """Build one CLIP prompt dataloader per model in the module-level ``models`` pool.

    Args:
        train: If True, read ``./Images/<model_name>/train`` with batch size 32.
            Otherwise read ``./Images/<model_name>/test`` with batch size 1.

    Returns:
        A list of ``(index, model_name, dataloader)`` tuples, in ``models`` iteration order.
    """
    split, bs = ("train", 32) if train else ("test", 1)
    loaders = []
    for i, model_name in enumerate(models):
        model_path = f"./Images/{model_name}/{split}"
        loaders.append((i, model_name, get_clip_prompt_dataloader(model_path, bs=bs)))
    return loaders

def cosine_similarity(vec1, vec2):
    """Return ``dot(vec1, vec2) / (||vec1|| * ||vec2||)``.

    ``np.linalg.norm`` is taken over each whole argument. If ``vec1`` is a ``(1, D)``
    row and ``vec2`` a ``(D,)`` vector, the result is a shape-``(1,)`` array.

    If either argument has zero norm, the similarity is undefined. In that case the
    function returns 0 (same shape as the dot product) instead of NaN, so a
    degenerate vector cannot poison downstream softmax weights.
    """
    dot_product = np.dot(vec1, vec2)
    denom = np.linalg.norm(vec1) * np.linalg.norm(vec2)
    if denom == 0:
        # A zero-norm operand makes the dot product 0. Multiplying keeps its shape
        # and turns it into a float.
        return dot_product * 0.0
    return dot_product / denom

def _collect_spec(loader):
    """Concatenate every ``(X, p)`` batch from ``loader`` into float32 numpy arrays.

    Returns ``(X, p, p_norms)``. ``p_norms`` holds the L2 norm of each row of ``p``
    and is computed once here because it does not depend on the query sample.
    It assumes ``p`` concatenates to a 2-D ``(N, D)`` array (TODO confirm against
    the dataloader).
    """
    X_batches, p_batches = [], []
    for X, p in loader:
        X_batches.append(X)
        p_batches.append(p)
    X = torch.cat(X_batches).cpu().float().numpy()
    p = torch.cat(p_batches).cpu().float().numpy()
    return X, p, np.linalg.norm(p, axis=1)


def _prompt_weights(p, spec_prompts, spec_norms, delta):
    """Softmax weights over ``spec_prompts`` rows from their cosine similarity to ``p``.

    ``p`` is flattened to one vector, so it must hold exactly one prompt (e.g. a
    ``(1, D)`` batch from a batch-size-1 loader). A multi-row ``p`` raises a shape
    error in the matrix product. Zero-norm rows get similarity 0 instead of NaN.
    """
    query = p.reshape(-1)
    dots = spec_prompts @ query
    denom = spec_norms * np.linalg.norm(query)
    sims = np.divide(dots, denom, out=np.zeros_like(dots), where=denom != 0)
    # NOTE(review): a scalar ``delta`` shifts every logit equally, and softmax is
    # invariant to that, so it only matters if ``delta`` is a per-row tensor.
    return F.softmax(torch.from_numpy(sims) + delta, dim=0)


def evaluate_CLIP(trloaders, teloaders, name, weighted=True, gamma=0.02, delta=0, omit=False):
    """Match each test batch to the closest model spec by (optionally weighted) MMD.

    Args:
        trloaders: Iterable of ``(index, model_name, loader)``. All ``(X, p)``
            batches of each loader are concatenated into that model's spec.
        teloaders: Iterable of ``(index, model_name, loader)`` whose ``(X, p)``
            batches are the queries. ``index`` is recorded as the ground truth for
            every batch of that loader.
        name: Unused; kept for call-site compatibility.
        weighted: If True, weight each spec row by a softmax over the cosine
            similarity between the query ``p`` and that spec's ``p`` rows.
            If False, ``MMD`` receives ``beta1=None``.
        gamma: Forwarded to ``MMD``.
        delta: Added to the similarity logits before the softmax.
        omit: Forwarded to ``MMD`` as ``omit_term1``.

    Returns:
        ``(gts, preds, dists)`` numpy arrays with one entry per test batch:
        - ``gts``: the test loader index.
        - ``preds``: the position (in ``trloaders`` order) of the spec with the
          smallest MMD value.
        - ``dists``: ``exp(-MMD)`` for every spec.
    """
    preds, dists, gts = [], [], []
    specs = []
    for i, model_name, loader in trloaders:
        print(i, model_name)
        specs.append(_collect_spec(loader))

    for i, model_name, loader in teloaders:
        print(i, model_name)
        for X, p in loader:
            X = X.detach().cpu().float().numpy()
            p = p.detach().cpu().float().numpy()
            dist = []
            for spec_X, spec_p, spec_norms in specs:
                weights = _prompt_weights(p, spec_p, spec_norms, delta) if weighted else None
                value = MMD(x1=spec_X, x2=X, beta1=weights, gamma=gamma, omit_term1=omit)
                dist.append(value)
            dist = np.array(dist)
            gts.append(i)
            dists.append(np.exp(-dist))
            preds.append(np.argmin(dist))
    return np.array(gts), np.array(preds), np.array(dists)

# Model pool: read as a module-level global by the get_clip_*_loaders builders.
with open("ModelPool.json", "r", encoding="utf-8") as f:
    models = json.load(f)

# Spec pool: one entry per evaluation configuration (uses "name", "type", "gamma").
with open("SpecPool.json", "r", encoding="utf-8") as f:
    specs = json.load(f)

tr_clip_original_prompt_loaders = get_clip_original_prompt_loaders("./PromptPool.json", train=True)
# NOTE(review): not used by the evaluation loop below; kept so the same loaders are built.
tr_clip_prompt_loaders = get_clip_prompt_loaders(train=True)
te_clip_prompt_loaders = get_clip_prompt_loaders(train=False)

import warnings
# Suppress every warning from here on. filterwarnings() inserts its filter at the
# front of the list, so an extra simplefilter("once") beforehand would never match.
warnings.filterwarnings("ignore")

for spec in specs:
    # Only "CLIP-fixed" specs are evaluated; every other type is skipped.
    if spec["type"] != "CLIP-fixed":
        continue
    gts, preds, dists = evaluate_CLIP(
        tr_clip_original_prompt_loaders,
        te_clip_prompt_loaders,
        spec["name"],
        weighted=True,
        gamma=spec["gamma"],
    )
    # Strip only the final extension. An extension-less name is kept as-is, so its
    # results get their own directory rather than landing in ./results/ itself.
    name = os.path.splitext(spec["name"])[0]
    path = f"./results/{name}"
    accs = [top_k_accuracy(gts, dists, k) for k in range(1, 10)]
    os.makedirs(path, exist_ok=True)
    with open(os.path.join(path, "res.json"), "w", encoding="utf-8") as fw:
        json.dump({
            "accs": accs,
            "acc": accuracy_score(gts, preds)
        }, fw)
    with open(os.path.join(path, "preds.json"), "w", encoding="utf-8") as fw:
        json.dump({
            "gts": gts.tolist(),
            "preds": preds.tolist(),
            "dists": dists.tolist()
        }, fw)
    print(accs)