import os
import sys
import argparse
import pickle
sys.path.append('/home/chois/woo/MMRL/allenact')
from allenact_plugins.clip_plugin import MVPT

import clip
import torch
import torchvision.transforms as T
from PIL import Image
import numpy as np

from bc.features_extractor import freeze_model, convert_models_to_fp32
from bc.MVPT_ConPE import CONPEMultiVisualPromptTuningCLIP, _CONPEMultiVisualPromptTuningCLIP


# Command-line options:
#   --e       evaluate at the embedding level (builds _CONPEMultiVisualPromptTuningCLIP
#             and runs the embedding-level comparison at the end of the script)
#   --bright  fixed brightness factor applied to every frame through ColorJitter
parser = argparse.ArgumentParser()
parser.add_argument("--e", action="store_true", default=False)
parser.add_argument("--bright", type=float, default=1.0)
args = parser.parse_args()

print("========================")
print(f"BRIGHTNESS: {args.bright}")

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Load CLIP ViT-B/32, then cast it to fp32 and pass it through the bc helper
# `freeze_model` (presumably disables gradients — see bc.features_extractor).
clip_model, _ = clip.load("ViT-B/32", device=device)
# A BatchNorm momentum of 0 keeps running statistics from ever being updated.
batchnorm_layers = (
    layer for layer in clip_model.modules()
    if "BatchNorm" in type(layer).__name__
)
for bn_layer in batchnorm_layers:
    bn_layer.momentum = 0.0
convert_models_to_fp32(clip_model)
clip_model = freeze_model(clip_model)

# Build the ConPE prompt-tuning wrapper around CLIP. `--e` selects the variant
# whose output is indexed per prompt in the embedding-level evaluation below.
model_cls = _CONPEMultiVisualPromptTuningCLIP if args.e else CONPEMultiVisualPromptTuningCLIP
model = model_cls(clip_model, device)

# One source-prompt checkpoint per domain factor.
domain_factors = ['BRIGHT']
prompt_paths = [
    f'/home/chois/woo/MMRL/logs/window-open/{df}/checkpoints/contrastive__latest.pth'
    for df in domain_factors
]
model.prompt_init(prompt_paths, multi_p_mode=['COMPOSE', 'UNIFORM', 'AVG'])

# Use the selected device rather than a hard-coded .cuda(), which crashes on
# CPU-only hosts even though `device` falls back to CPU. Switch to eval mode
# last, so any submodules prompt_init may have added are in eval mode too.
model.to(device)
model.eval()

# Turn off gradients for every model parameter whose name does not contain "prompt".
for name, param in model.named_parameters():
    if "prompt" not in name:
        param.requires_grad_(False)

# Freeze every CLIP parameter explicitly. The loop variable must be `param`
# (the tensor being frozen); binding a different name would silently operate
# on whatever `param` was left over from the loop above.
for _, param in clip_model.named_parameters():
    param.requires_grad_(False)
    assert not param.requires_grad

# Source prompt list held by the visual backbone (presumably populated by
# prompt_init above); iterated over in the evaluations below.
prompts_embeddings = model.visual_backbone.source_prompt_list

print(prompts_embeddings)

# Evaluation-time transform: a fixed brightness shift (the jitter range
# collapses to the single value --bright), conversion to a tensor, then
# channel-wise normalisation.
# NOTE(review): the mean's second channel is 0.465, whereas the common
# ImageNet mean is 0.456 — confirm this matches the training pipeline
# before changing it.
brightness_range = (args.bright, args.bright)
to_tensor_aug = T.Compose(
    [
        T.ColorJitter(brightness=brightness_range),
        T.ToTensor(),
        T.Normalize(
            mean=(0.485, 0.465, 0.406),
            std=(0.229, 0.224, 0.225),
        ),
    ]
)


# Load the pickled trajectory dataset. pickle can execute arbitrary code, so
# only load trusted, locally produced files here.
data_path = '/home/chois/woo/MMRL/trajdata/RoboMani/window-close/CAMS/train_dataset.pkl'
with open(data_path, 'rb') as f:
    data = pickle.load(f)

print(data.keys())
# Remove the "classes" and "actions" entries; the remaining integer keys
# appear to index individual trajectories (assumption — verify dataset schema).
for key in ("classes", "actions"):
    del data[key]

# Transform every frame of the selected trajectories into a batch-of-one
# tensor (leading dim added by unsqueeze(0)) placed on `device`.
images = []
for traj_idx in (8,):
    trajectory = data[traj_idx]
    print(trajectory.keys())
    images.extend(
        to_tensor_aug(Image.fromarray(frame)).unsqueeze(0).to(device)
        for frame in trajectory['frame']
    )

# ---- Image-level comparison --------------------------------------------------
# For each source prompt, compare the model output on every frame (averaged
# over dim 1) with the prompt (also averaged over dim 1), and print the mean
# cosine similarity (x100) and mean L2 distance over all frames.
# Assumes dim 1 is the token/sequence axis for both tensors — TODO confirm.
if not args.e:
    # Pure inference: skip autograd bookkeeping.
    with torch.no_grad():
        # The model output does not depend on which prompt is being compared,
        # so compute it once per frame instead of once per (prompt, frame).
        avg_feats = [torch.mean(model(img.to(device)), dim=1) for img in images]
        for prompt in prompts_embeddings:
            avg_p = torch.mean(prompt, dim=1)
            cos_sim = np.array([
                torch.nn.functional.cosine_similarity(avg_x, avg_p, dim=1).item()
                for avg_x in avg_feats
            ])
            dist_sim = np.array([
                torch.cdist(avg_x, avg_p, p=2).item()
                for avg_x in avg_feats
            ])
            print(f"cossim: {np.mean(cos_sim)*100:.4f}")
            print(f"distim: {np.mean(dist_sim):.4f}")


# ---- Embedding-level comparison ----------------------------------------------
# For each prompt index k, compare the plain CLIP image embedding of every
# frame with the k-th entry of the prompt model's output, and print the mean
# cosine similarity (x100) and mean L2 distance over all frames.
if args.e:
    # Pure inference: skip autograd bookkeeping.
    with torch.no_grad():
        # Both embeddings are independent of k, so compute them once per frame
        # instead of once per (k, frame).
        frame_outputs = []
        for img in images:
            img = img.to(device)
            frame_outputs.append((clip_model.visual(img), model(img)))

        for k, _ in enumerate(prompts_embeddings):
            cos_sim, dist_sim = [], []
            for ori_x, input_x in frame_outputs:
                target = input_x[k].unsqueeze(0)
                cos = torch.nn.functional.cosine_similarity(ori_x, target, dim=1)
                dist = torch.cdist(ori_x, target, p=2)
                cos_sim.append(cos.item())
                dist_sim.append(dist.item())

            cos_sim, dist_sim = np.array(cos_sim), np.array(dist_sim)
            print(f"cossim: {np.mean(cos_sim)*100:.4f}")
            print(f"distim: {np.mean(dist_sim):.4f}")
