import os
import sys
import argparse
import json 
import torch
import yaml
from easydict import EasyDict as edict
from PIL import Image
from torchvision import transforms
import copy
from tqdm import tqdm

sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from models.get_model import load_model
from constants import images_normalize

from models.clip_model import clip


# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Command-line interface of the script.
parser = argparse.ArgumentParser()

# YAML configuration files (both are read right after argument parsing).
parser.add_argument(
    "--config",
    default="./configs/Retrieval_flickr_train.yaml",
)
parser.add_argument(
    "--train_config",
    default="",
)

# Input annotation JSON and the output path the top-k file names derive from.
parser.add_argument("--orig_train_file", required=True, type=str)
parser.add_argument("--out_train_file", required=True, type=str)

# Model selection; both values are forwarded to load_model.
parser.add_argument("--model", default="CLIP_ViT-B-16", type=str)
parser.add_argument("--text_encoder", type=str, default="bert-base-uncased")

# Optional checkpoint; when omitted the weights built by load_model are used.
parser.add_argument("--ckpt_path", type=str)

args = parser.parse_args()

########################
# Make sure the directory that will hold the output files exists.
# os.path.dirname() returns "" when --out_train_file is a bare filename
# (i.e. the current directory); os.makedirs("") would raise, so skip it then.
# exist_ok=True avoids failing if the directory already exists or is created
# concurrently between a check and the call.
out_train_file_dir = os.path.dirname(args.out_train_file)
if out_train_file_dir:
    os.makedirs(out_train_file_dir, exist_ok=True)

########################
###### load config ######
########################
def _load_yaml_config(path):
    """Read the YAML file at *path* and return its contents as an EasyDict."""
    with open(path, "r") as f:
        return edict(yaml.safe_load(f))


config_path = args.config
config = _load_yaml_config(config_path)

train_config_path = args.train_config
# --train_config defaults to "" and the file is always read, so an omitted
# flag would otherwise surface as an opaque FileNotFoundError on ''.
# Fail early with a usage message instead.
if not train_config_path:
    parser.error("--train_config must point to a YAML file (got an empty path)")
train_config = _load_yaml_config(train_config_path)


########################
###### load model ######
########################
# Build the model via the project's load_model factory from the two parsed
# configs and the CLI model/text-encoder names. ref_model and tokenizer are
# returned as well but are not used in the visible part of this script.
print("Loading model")
model, ref_model, tokenizer = load_model(
    config, args.model, None, args.text_encoder, device=device,
    train_config=train_config
)
# Optionally replace the freshly built weights with a saved checkpoint.
# NOTE(review): torch.load unpickles the file, so only trusted checkpoints
# should be passed via --ckpt_path.
if args.ckpt_path is not None:
    # Load tensors onto the CPU first; the model is moved to `device` below.
    checkpoint = torch.load(args.ckpt_path, map_location="cpu")
    # The state dict is read from the checkpoint's "model" entry.
    model.load_state_dict(checkpoint["model"])
model.to(device)
# Presumably an nn.Module: eval() switches it to inference behaviour — TODO confirm.
model.eval()


########################
####### load ann #######
########################
# Read the original annotation list from JSON.
with open(args.orig_train_file, 'r') as f:
    ann = json.load(f)

# Group annotations by their 'image_id', preserving the order in which they
# appear in the file. `i` stays bound to the last annotation index after the
# loop, as in the original enumerate-based loop.
image_id2ann = {}
for i, a in tqdm(enumerate(ann), total=len(ann)):
    image_id2ann.setdefault(a['image_id'], []).append(a)


########################
# Eval-time preprocessing: resize to the configured square resolution and
# convert to a tensor. Normalization is applied afterwards by images_normalize.
test_transform = transforms.Compose([
    transforms.Resize(
        (config['image_res'], config['image_res']),
        interpolation=Image.BICUBIC,
    ),
    transforms.ToTensor(),
])

# Number of ranked output files; slot k collects, for every image, the
# annotation carrying its (k+1)-th most similar caption.
NUM_TOP = 4
# Intermediate results are flushed to disk every SAVE_EVERY processed images.
SAVE_EVERY = 500

top_k_anns = [[] for _ in range(NUM_TOP)]
# Keep the historical per-rank names bound to the same list objects.
top_1_anns, top_2_anns, top_3_anns, top_4_anns = top_k_anns

# Derive "<root>_topK<ext>" from the output path. splitext only touches the
# final extension, so a ".json" elsewhere in the path is left alone and the
# output files never collide when the path has no ".json" suffix.
out_root, out_ext = os.path.splitext(args.out_train_file)
top_k_paths = [
    "{}_top{}{}".format(out_root, k, out_ext or ".json")
    for k in range(1, NUM_TOP + 1)
]


def _dump_top_k():
    """Write every top-k annotation list to its corresponding JSON file."""
    for path, anns in zip(top_k_paths, top_k_anns):
        with open(path, 'w') as f:
            json.dump(anns, f, indent=4)


num_processed = 0
num_skipped = 0

# Inference only: no_grad avoids building an autograd graph for every image.
with torch.no_grad():
    for this_anns in tqdm(image_id2ann.values(), total=len(image_id2ann)):
        # Rank all annotations except the first one: the first is the default
        # caption that is always used, so it is not a ranking candidate.
        candidates = this_anns[1:]
        captions_list = [
            a["caption"] for a in candidates if len(a["caption"]) > 0
        ]
        if not captions_list:
            # Nothing to rank (single annotation or only empty captions).
            num_skipped += 1
            continue

        # The first candidate supplies the image path and serves as the
        # template that every ranked output annotation is copied from.
        template_ann = candidates[0]
        image_path = os.path.join(config['image_root'], template_ann['image'])
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")

        image = test_transform(image).unsqueeze(0).to(device)
        image = images_normalize(image)
        image_feat = model.encode_image(image)

        text_input_ids = clip.tokenize(captions_list, truncate=True).to(device)
        text_feats = model.encode_text(text_input_ids)

        # One similarity score per caption (assumes the encoders return
        # (1, D) and (N, D) features — TODO confirm).
        cos_sim = torch.nn.functional.cosine_similarity(
            image_feat, text_feats, dim=1
        )
        scores = cos_sim.tolist()
        order = torch.argsort(cos_sim, descending=True).tolist()

        # Fill as many rank slots as there are captions, best first.
        for slot, cap_idx in enumerate(order[:NUM_TOP]):
            ann_copy = copy.deepcopy(template_ann)
            ann_copy['caption'] = captions_list[cap_idx]
            ann_copy["cos_sim"] = scores[cap_idx]
            top_k_anns[slot].append(ann_copy)

        num_processed += 1
        # Checkpoint on the number of images handled, not on the value of
        # image_id, which need not be an integer or hit multiples of 500.
        if num_processed % SAVE_EVERY == 0:
            print("Processed {} images".format(num_processed))
            _dump_top_k()

print("Processed {} images".format(num_processed))
if num_skipped:
    print("Skipped {} images without candidate captions".format(num_skipped))
_dump_top_k()