import os
import sys
import argparse
import json 
import torch
import yaml
from easydict import EasyDict as edict
from PIL import Image
from torchvision import transforms
import copy
from tqdm import tqdm

sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from models.get_model import load_model
from constants import images_normalize

from models.clip_model import clip


# Compute device: the default CUDA device when one is visible, else the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

parser = argparse.ArgumentParser(
    description=(
        "Rank each caption group's candidate images by CLIP text-image cosine "
        "similarity and write the top-1 .. top-5 choices to separate JSON files."
    )
)
parser.add_argument(
    "--config",
    default="./configs/Retrieval_flickr_train.yaml",
    help="YAML config providing image_root and image_res (also passed to load_model).",
)
parser.add_argument(
    "--train_config",
    default="",
    help="YAML training config passed to load_model (must be set).",
)

parser.add_argument(
    "--orig_train_file",
    type=str,
    required=True,
    help="Input annotation JSON: a list of dicts with image_id, image and caption.",
)
parser.add_argument(
    "--out_train_file",
    type=str,
    required=True,
    help="Output JSON path; results are written to <name>_top1.json .. <name>_top5.json.",
)

parser.add_argument(
    "--model",
    type=str,
    default="CLIP_ViT-B-16",
    help="Model name passed to load_model.",
)
parser.add_argument(
    "--text_encoder",
    default="bert-base-uncased",
    type=str,
    help="Text encoder name passed to load_model.",
)

parser.add_argument(
    "--ckpt_path",
    type=str,
    help="Optional checkpoint whose 'model' state dict is loaded into the model.",
)
args = parser.parse_args()

########################
# Ensure the directory that will receive the *_topK.json outputs exists.
# dirname() is "" when --out_train_file is a bare filename. os.makedirs("")
# raises FileNotFoundError, and the current directory already exists, so that
# case is skipped. exist_ok avoids a check-then-create race.
out_train_file_dir = os.path.dirname(args.out_train_file)
if out_train_file_dir:
    os.makedirs(out_train_file_dir, exist_ok=True)

########################
###### load config ######
########################
def _load_yaml_config(path):
    """Read the YAML file at *path* with ``yaml.safe_load`` and wrap it in an EasyDict."""
    with open(path, "r") as f:
        return edict(yaml.safe_load(f))


config_path = args.config
config = _load_yaml_config(config_path)

train_config_path = args.train_config
if not train_config_path:
    # --train_config defaults to "", which open() cannot read. Fail with a
    # clear CLI error instead of a bare FileNotFoundError.
    parser.error("--train_config is required (path to the training YAML config)")
train_config = _load_yaml_config(train_config_path)


########################
###### load model ######
########################
print("Loading model")
# load_model yields (model, ref_model, tokenizer); only `model` is used in this script.
model, ref_model, tokenizer = load_model(
    config,
    args.model,
    None,
    args.text_encoder,
    device=device,
    train_config=train_config,
)

# Optionally replace the weights with the "model" state dict of a saved checkpoint.
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
if args.ckpt_path is not None:
    state_dict = torch.load(args.ckpt_path, map_location="cpu")["model"]
    model.load_state_dict(state_dict)

# Place the model on the compute device and switch it to evaluation mode.
model.to(device)
model.eval()


########################
####### load ann #######
########################
with open(args.orig_train_file, 'r') as ann_fp:
    ann = json.load(ann_fp)

# Group annotations by the original image they were derived from. The key is
# the integer after "img" in the part of image_id before the first "/"
# (current ID layout). Older layouts such as "img0_cap0_0" used "_" as the
# separator: the COCO variant parsed split("_")[1], the Flickr variant split("_")[0].
orig_image_id2ann = {}
for i, a in tqdm(enumerate(ann), total=len(ann)):
    prefix = a['image_id'].partition("/")[0]
    key = int(prefix.split("img")[-1])
    orig_image_id2ann.setdefault(key, []).append(a)

########################
# Resize to a square image_res x image_res image with bicubic interpolation,
# then convert to a float tensor. Normalization is applied separately with
# images_normalize after the batch dimension is added.
# InterpolationMode is torchvision's documented interpolation argument; the
# integer PIL constant (Image.BICUBIC) is deprecated there.
test_transform = transforms.Compose([
    transforms.Resize(
        (config['image_res'], config['image_res']),
        interpolation=transforms.InterpolationMode.BICUBIC,
    ),
    transforms.ToTensor(),
])

# Number of ranked variants exported per caption group (one JSON file per rank).
TOP_K = 5
# Flush partial results to disk after this many caption groups.
CHECKPOINT_EVERY = 500


def _top_k_out_path(k):
    """Return the output path for the rank-``k`` (1-based) annotation file.

    ``out.json`` -> ``out_top{k}.json``. Only the final extension is split
    off, so a ".json" elsewhere in the path is left alone. Ranks also get
    distinct files when the path has no extension (".json" is appended).
    """
    root, ext = os.path.splitext(args.out_train_file)
    return "{}_top{}{}".format(root, k, ext or ".json")


def _dump_top_k(per_rank_anns):
    """Write ``per_rank_anns[k - 1]`` to the rank-``k`` output file for every rank."""
    for k, rank_anns in enumerate(per_rank_anns, start=1):
        with open(_top_k_out_path(k), 'w') as f:
            json.dump(rank_anns, f, indent=4)


def _load_image_tensor(rel_path):
    """Load ``config['image_root']/rel_path`` as a normalized (1, C, H, W) tensor on ``device``."""
    path = os.path.join(config['image_root'], rel_path)
    with Image.open(path) as img:
        rgb = img.convert("RGB")
    tensor = test_transform(rgb).unsqueeze(0).to(device)
    return images_normalize(tensor)


# top_k_anns[k] collects one entry per caption group: a copy of the group's
# annotation pointing at the image ranked k-th (0-based) by cosine similarity.
top_k_anns = [[] for _ in range(TOP_K)]

# Pure inference: no_grad stops autograd from recording graphs for every batch.
with torch.no_grad():
    groups = tqdm(orig_image_id2ann.items(), total=len(orig_image_id2ann))
    for n_done, (image_id, this_anns) in enumerate(groups, start=1):
        base_ann = this_anns[0]
        caption = base_ann['caption']
        # Every variant in a group must share one caption. This is an explicit
        # check because `assert` is stripped under -O, and indexing this_anns[1]
        # would crash on single-image groups.
        if any(a['caption'] != caption for a in this_anns):
            raise ValueError(
                "Caption mismatch within group {}: {}".format(image_id, this_anns)
            )

        image_batch = torch.cat(
            [_load_image_tensor(a['image']) for a in this_anns], dim=0
        )
        image_feats = model.encode_image(image_batch)

        text_input_ids = clip.tokenize(caption, truncate=True).to(device)
        text_feat = model.encode_text(text_input_ids)

        # One score per image in the group. text_feat presumably has a single
        # row that broadcasts against image_feats; verify for other models.
        cos_sim = torch.nn.functional.cosine_similarity(text_feat, image_feats, dim=1)
        rank = torch.argsort(cos_sim, descending=True)

        # Groups with fewer than TOP_K images only populate the leading ranks.
        for k, idx in enumerate(rank[:TOP_K].tolist()):
            ann_copy = copy.deepcopy(base_ann)
            ann_copy['image'] = this_anns[idx]['image']
            ann_copy['image_id'] = this_anns[idx]['image_id']
            ann_copy["cos_sim"] = cos_sim[idx].item()
            top_k_anns[k].append(ann_copy)

        # Periodic checkpoint keyed on the number of groups processed, so it
        # fires regularly regardless of how image IDs are numbered.
        if n_done % CHECKPOINT_EVERY == 0:
            print("Processed {} images".format(n_done))
            _dump_top_k(top_k_anns)

print("Processed {} images".format(len(orig_image_id2ann)))
_dump_top_k(top_k_anns)