import os
import sys
import argparse
import json 
import torch
import yaml
from easydict import EasyDict as edict
from PIL import Image
from torchvision import transforms
import copy
from tqdm import tqdm

sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from models.get_model import load_model
from constants import images_normalize

from models.clip_model import clip


# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def _build_arg_parser():
    """Create the command-line parser for this script.

    Options:
      --config          dataset/retrieval YAML (supplies ``image_res`` and ``image_root``).
      --train_config    training YAML forwarded to ``load_model``.
      --aug_modal       which modality was augmented: "image" or "text".
      --orig_train_file original JSON annotation file.
      --aug_train_file  augmented JSON annotation file to score.
      --model / --text_encoder  model selection passed to ``load_model``.
      --ckpt_path       optional checkpoint whose "model" entry is loaded.
    """
    ap = argparse.ArgumentParser()

    # Config files.
    ap.add_argument("--config", default="../configs/Retrieval_flickr_train.yaml")
    ap.add_argument("--train_config", default="")

    # Which side of each (image, caption) record was augmented.
    ap.add_argument("--aug_modal", type=str, required=True, choices=["image", "text"])

    # Annotation files to pair up.
    ap.add_argument("--orig_train_file", type=str, required=True)
    ap.add_argument("--aug_train_file", type=str, required=True)

    # Model selection.
    ap.add_argument("--model", type=str, default="CLIP_ViT-B-16")
    ap.add_argument("--text_encoder", default="bert-base-uncased", type=str)
    ap.add_argument("--ckpt_path", type=str)
    return ap


parser = _build_arg_parser()
args = parser.parse_args()

########################
###### load config ######
########################
def _load_yaml_config(path):
    """Parse the YAML file at ``path`` with ``yaml.safe_load`` and wrap it in an EasyDict."""
    with open(path, "r") as f:
        return edict(yaml.safe_load(f))


config_path = args.config
config = _load_yaml_config(config_path)

# ``--train_config`` defaults to "" but is always passed on to ``load_model``.
# Opening "" would fail with an opaque ``FileNotFoundError: ''``, so report
# the missing flag explicitly through argparse (prints usage, exits 2).
train_config_path = args.train_config
if not train_config_path:
    parser.error("--train_config must point to a training YAML file")
train_config = _load_yaml_config(train_config_path)


########################
###### load model ######
########################
print("Loading model")
# load_model returns (model, ref_model, tokenizer); only ``model`` is used below.
model, ref_model, tokenizer = load_model(
    config,
    args.model,
    None,
    args.text_encoder,
    device=device,
    train_config=train_config,
)

# Optionally replace the weights with a saved checkpoint. It is loaded onto the
# CPU first and is expected to hold the state dict under the "model" key.
if args.ckpt_path is not None:
    checkpoint = torch.load(args.ckpt_path, map_location="cpu")
    model_state = checkpoint["model"]
    model.load_state_dict(model_state)

# Move to the target device and switch to evaluation mode.
model.to(device)
model.eval()


########################
####### load ann #######
########################
def _read_json(path):
    """Return the parsed contents of the JSON file at ``path``."""
    with open(path, 'r') as f:
        return json.load(f)


orig_ann = _read_json(args.orig_train_file)
aug_ann = _read_json(args.aug_train_file)

# Square resize to ``image_res`` x ``image_res`` with bicubic resampling, then
# convert to a tensor. Normalization is applied separately (images_normalize).
_image_side = config['image_res']
test_transform = transforms.Compose([
    transforms.Resize((_image_side, _image_side), interpolation=Image.BICUBIC),
    transforms.ToTensor(),
])

# Pair every augmented record with its original counterpart. For image
# augmentation the two records share the caption; for text augmentation they
# share the image.
# NOTE(review): both sides are indexed with dict comprehensions, so records that
# share a key collapse to the last one seen (presumably e.g. several captions of
# one image in "text" mode). The number of pairs can therefore be smaller than
# len(aug_ann); confirm this is intended.
_PAIR_KEY_BY_MODAL = {"image": "caption", "text": "image"}


def _pair_records(orig_records, aug_records, key):
    """Return ``(orig, aug)`` tuples joined on ``record[key]``.

    Keys are visited in first-occurrence order of the augmented records, using
    the last record seen per key. Raises ``KeyError`` when an augmented key has
    no original counterpart.
    """
    orig_by_key = {rec[key]: rec for rec in orig_records}
    aug_by_key = {rec[key]: rec for rec in aug_records}
    return [(orig_by_key[k], rec) for k, rec in aug_by_key.items()]


orig_aug_pairs = _pair_records(orig_ann, aug_ann, _PAIR_KEY_BY_MODAL[args.aug_modal])
print("Length of pairs: ", len(orig_aug_pairs))


def preprocess_image(image_path):
    """Load an image from disk and turn it into a normalized model input.

    Args:
        image_path: path to the image file.

    Returns:
        The RGB image after ``test_transform``, with a leading batch dimension
        of 1, moved to ``device`` and passed through ``images_normalize``.
    """
    # Use a context manager so the file handle is always released. Pillow keeps
    # multi-frame files open after loading unless they are closed explicitly.
    with Image.open(image_path) as img:
        rgb = img.convert("RGB")
    batch = test_transform(rgb).unsqueeze(0).to(device)
    return images_normalize(batch)


print("Calculating metric...")

# Results go next to the augmented file as "<stem>_metric.json". splitext is
# used instead of str.replace(".json", ...) for two reasons:
#   - replace returns the path unchanged when the name has no ".json", which
#     would overwrite the input file;
#   - replace also rewrites ".json" occurring elsewhere in the path.
metric_out_path = os.path.splitext(args.aug_train_file)[0] + "_metric.json"


def _dump_metrics(records):
    """Write the records collected so far to ``metric_out_path`` as indented JSON."""
    with open(metric_out_path, 'w') as f:
        json.dump(records, f, indent=4)


aug_anns_new = []
# Pure inference: no_grad avoids building an autograd graph on every encoder call.
with torch.no_grad():
    # total must match what is iterated (the pairs), not the raw annotation list.
    for i, (orig_a, aug_a) in tqdm(enumerate(orig_aug_pairs), total=len(orig_aug_pairs)):
        orig_cap = orig_a['caption']
        aug_cap = aug_a['caption']
        orig_image_path = os.path.join(config['image_root'], orig_a['image'])
        aug_image_path = os.path.join(config['image_root'], aug_a['image'])

        if args.aug_modal == "image":
            # Pairs were joined on the caption, so both records share it.
            assert orig_cap == aug_cap
            cap_feat = model.encode_text(clip.tokenize([orig_cap], truncate=True).to(device))
            orig_image_feat = model.encode_image(preprocess_image(orig_image_path))
            aug_image_feat = model.encode_image(preprocess_image(aug_image_path))

            # diff / bef_aft_l2dist: original image vs. augmented image.
            # alignment: augmented image vs. the shared caption.
            diff = torch.nn.functional.cosine_similarity(orig_image_feat, aug_image_feat, dim=1)
            bef_aft_l2dist = torch.nn.functional.pairwise_distance(orig_image_feat, aug_image_feat, p=2)
            alignment = torch.nn.functional.cosine_similarity(aug_image_feat, cap_feat, dim=1)

        elif args.aug_modal == "text":
            # Pairs were joined on the image, so both records point at it.
            assert orig_image_path == aug_image_path
            image_feat = model.encode_image(preprocess_image(orig_image_path))
            # Encode original and augmented captions in one batch: row 0 is the
            # original, row 1 the augmented caption.
            text_input_ids = clip.tokenize([orig_cap, aug_cap], truncate=True).to(device)
            text_feats = model.encode_text(text_input_ids)

            # diff / bef_aft_l2dist: original caption vs. augmented caption.
            # alignment: shared image vs. the augmented caption.
            diff = torch.nn.functional.cosine_similarity(text_feats[0], text_feats[1], dim=0)
            bef_aft_l2dist = torch.nn.functional.pairwise_distance(text_feats[0], text_feats[1], p=2)
            alignment = torch.nn.functional.cosine_similarity(image_feat, text_feats[1], dim=1)

        # The augmented record is annotated in place and collected for output.
        aug_a["diff"] = diff.item()
        aug_a["bef_aft_l2dist"] = bef_aft_l2dist.item()
        aug_a["alignment"] = alignment.item()
        aug_anns_new.append(aug_a)

        # Periodic checkpoint so partial results survive an interruption.
        processed = i + 1
        if processed % 500 == 0:
            print("Processed {} pairs".format(processed))
            _dump_metrics(aug_anns_new)

_dump_metrics(aug_anns_new)
