import os
import sys
import argparse
import json 
import torch
import yaml
from easydict import EasyDict as edict
from PIL import Image
from torchvision import transforms
import copy
from tqdm import tqdm

sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from models.get_model import load_model
from constants import images_normalize

from models.clip_model import clip


# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Command-line interface, declared as (flag, add_argument keyword arguments)
# pairs and registered in this order.
_CLI_OPTIONS = (
    # YAML configuration files.
    ("--config", {"default": "./configs/Retrieval_flickr_train.yaml"}),
    ("--train_config", {"default": ""}),
    # Input caption lines, input annotation JSON, and output path template.
    ("--lines_file", {"type": str, "required": True}),
    ("--orig_train_file", {"type": str, "required": True}),
    ("--out_train_file", {"type": str, "required": True}),
    # Model selection, forwarded to load_model.
    ("--model", {"type": str, "default": "CLIP_ViT-B-16"}),
    ("--text_encoder", {"type": str, "default": "bert-base-uncased"}),
    # Optional checkpoint; stays None when the flag is not given.
    ("--ckpt_path", {"type": str}),
)

parser = argparse.ArgumentParser()
for _flag, _flag_kwargs in _CLI_OPTIONS:
    parser.add_argument(_flag, **_flag_kwargs)
args = parser.parse_args()

########################
###### load config ######
########################
def _load_yaml_config(path):
    """Parse the YAML file at ``path`` and return its content as an EasyDict.

    The file is read as UTF-8 and parsed with ``yaml.safe_load``, so only
    plain YAML data (no arbitrary Python objects) is constructed.
    """
    with open(path, "r", encoding="utf-8") as f:
        return edict(yaml.safe_load(f))


# ``--train_config`` defaults to the empty string, and open("") can only fail
# with an opaque FileNotFoundError. Report the missing flag through argparse
# instead so the user sees the usage text and a clear message.
if not args.train_config:
    parser.error("--train_config must be set to the path of a training YAML config")

config_path = args.config
config = _load_yaml_config(config_path)

train_config_path = args.train_config
train_config = _load_yaml_config(train_config_path)


########################
###### load model ######
########################
# Build the model through the project's load_model helper. It returns three
# objects; only `model` is touched again in this section (`ref_model` and
# `tokenizer` are bound but not used by the visible code).
# NOTE(review): the active caption-splitting code below never calls `model`;
# only the commented-out "sort by score" section does. Confirm whether the
# model still needs to be loaded for this script.
print("Loading model")
model, ref_model, tokenizer = load_model(
    config, args.model, None, args.text_encoder, device=device,
    train_config=train_config
)
# Optionally restore weights. The checkpoint is deserialized onto the CPU
# first and its "model" entry is passed to load_state_dict.
# NOTE(review): torch.load unpickles the file, so --ckpt_path should only
# point at trusted checkpoints.
if args.ckpt_path is not None:
    checkpoint = torch.load(args.ckpt_path, map_location="cpu")
    model.load_state_dict(checkpoint["model"])
# Move the model to the selected device and switch it to evaluation mode.
model.to(device)
model.eval()

########################
###### load lines ######
########################
# Read the generated-caption file, one stripped entry per physical line.
with open(args.lines_file, 'r') as f:
    lines = [raw_line.strip() for raw_line in f]

# Within a line, the individual candidate captions are separated by "==".
generated_texts = []
for entry in lines:
    generated_texts.append(entry.split("=="))

########################
####### load ann #######
########################
with open(args.orig_train_file, 'r', encoding="utf-8") as f:
    ann = json.load(f)

# Generated captions are matched to annotations purely by position, so a
# lines file shorter than the annotation list cannot be processed. Fail up
# front with a clear message instead of an IndexError part-way through.
if len(generated_texts) < len(ann):
    raise ValueError(
        "lines_file provides {} lines but orig_train_file has {} annotations; "
        "every annotation needs a matching line".format(
            len(generated_texts), len(ann)
        )
    )

#### idx
# For every annotation, slot k (1-based) receives a copy of the annotation
# whose 'caption' is the k-th non-empty generated caption, or "" when fewer
# than k captions exist. Annotations with no caption at all are skipped in
# every slot, so the four output lists always have the same length.
NUM_CAPTION_SLOTS = 4
idx_anns = [[] for _ in range(NUM_CAPTION_SLOTS)]
# Keep the original per-slot names available as aliases of the same lists.
idx_1_anns, idx_2_anns, idx_3_anns, idx_4_anns = idx_anns

for i, a in tqdm(enumerate(ann), total=len(ann)):
    captions_list = [cap for cap in generated_texts[i] if cap]
    if not captions_list:
        print("No captions found for image {}".format(i))
        continue

    for slot, slot_anns in enumerate(idx_anns):
        # Each slot needs its own deep copy because the caption differs.
        ann_copy = copy.deepcopy(a)
        ann_copy['caption'] = captions_list[slot] if slot < len(captions_list) else ""
        slot_anns.append(ann_copy)

    # Progress message only. The output files are written once after the
    # loop; re-dumping all four growing lists every 500 items just repeated
    # work that the final write below redoes in full.
    if i % 500 == 0:
        print("Processed {} images".format(i))

# Derive "<root>_idx<k><ext>" from the requested output path. Splitting on the
# real extension keeps the four files distinct even when the path does not end
# in ".json", and never rewrites a ".json" that appears in a directory name.
out_root, out_ext = os.path.splitext(args.out_train_file)
for slot, slot_anns in enumerate(idx_anns, start=1):
    save_path = "{}_idx{}{}".format(out_root, slot, out_ext)
    with open(save_path, 'w') as f:
        json.dump(slot_anns, f, indent=4)


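#### sort by score (disabled): rank each image's generated captions by image-text cosine similarity and write the *_top1.._top4.json files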
# test_transform = transforms.Compose([
#     transforms.Resize((config['image_res'],config['image_res']),interpolation=Image.BICUBIC),
#     transforms.ToTensor(),
# ])   

# top_1_anns = []
# top_2_anns = []
# top_3_anns = []
# top_4_anns = []

# for i, a in tqdm(enumerate(ann), total=len(ann)):

#     captions_list = generated_texts[i]
#     captions_list = [cap for cap in captions_list if len(cap) > 0]
    
#     image_path = os.path.join(config['image_root'], a['image'])
#     image = Image.open(image_path).convert("RGB")
#     orig_caption = ann[i]['caption']

#     image = test_transform(image).unsqueeze(0).to(device)
#     image = images_normalize(image)

#     # print("Image: ", image_path)
#     # print("orig_caption: ", orig_caption)
#     # print("Generated captions: ")
#     # for caption in captions_list:
#     #     print("- ", caption)

#     image_feat = model.encode_image(image)

#     text_input_ids = clip.tokenize(captions_list + [orig_caption], truncate=True).to(device)
#     text_feats = model.encode_text(text_input_ids)

#     cos_sim = torch.nn.functional.cosine_similarity(image_feat, text_feats, dim=1)

#     # print("Cosine Similarity: ", cos_sim)

#     rank = torch.argsort(cos_sim[:-1], descending=True)

#     # print("Orig caption: {:.2f}: {}".format(cos_sim[-1], orig_caption))

#     if len(rank) == 0:
#         print("No captions found for image {}".format(i))
#         continue
#     top_1_cap = captions_list[rank[0]]
#     ann_copy = copy.deepcopy(ann[i])
#     ann_copy['caption'] = top_1_cap
#     ann_copy["cos_sim"] = cos_sim[rank[0]].item()
#     top_1_anns.append(ann_copy)
#     # print("Top 1 caption {:.2f}: {}".format(cos_sim[rank[0]], top_1_cap))

#     if len(rank) > 1:
#         top_2_cap = captions_list[rank[1]]
#         ann_copy = copy.deepcopy(ann[i])
#         ann_copy['caption'] = top_2_cap
#         ann_copy["cos_sim"] = cos_sim[rank[1]].item()
#         top_2_anns.append(ann_copy)
#         # print("Top 2 caption {:.2f}: {}".format(cos_sim[rank[1]], top_2_cap))
    
#     if len(rank) > 2:
#         top_3_cap = captions_list[rank[2]]
#         ann_copy = copy.deepcopy(ann[i])
#         ann_copy['caption'] = top_3_cap
#         ann_copy["cos_sim"] = cos_sim[rank[2]].item()
#         top_3_anns.append(ann_copy)
#         # print("Top 3 caption {:.2f}: {}".format(cos_sim[rank[2]], top_3_cap))

#     if len(rank) > 3:
#         top_4_cap = captions_list[rank[3]]
#         ann_copy = copy.deepcopy(ann[i])
#         ann_copy['caption'] = top_4_cap
#         ann_copy["cos_sim"] = cos_sim[rank[3]].item()
#         top_4_anns.append(ann_copy)
#         # print("Top 4 caption {:.2f}: {}".format(cos_sim[rank[3]], top_4_cap))
    
#     # print("\n\n")
        
#     if i % 500 == 0:
#         print("Processed {} images".format(i))
#         with open(args.out_train_file.replace(".json", "_top1.json"), 'w') as f:
#             json.dump(top_1_anns, f, indent=4)

#         with open(args.out_train_file.replace(".json", "_top2.json"), 'w') as f:
#             json.dump(top_2_anns, f, indent=4)

#         with open(args.out_train_file.replace(".json", "_top3.json"), 'w') as f:
#             json.dump(top_3_anns, f, indent=4)

#         with open(args.out_train_file.replace(".json", "_top4.json"), 'w') as f:
#             json.dump(top_4_anns, f, indent=4)