import os
import jsonlines
import argparse
import numpy as np
from clip import clip
import torch

# Command-line interface for the evaluation script.
parser = argparse.ArgumentParser()

# Ground-truth captions and their shot lists (read as JSON Lines below).
parser.add_argument('--target_file', type=str, default='dataset/test_captions.jsonl')
# Generated captions/shots to evaluate (JSON Lines); mandatory.
parser.add_argument('--generated_file', type=str, required=True)
# CLIP backbone name passed to clip.load.
parser.add_argument('--clip_model', type=str, default="ViT-B/32", help='Name of CLIP')
# Root directory holding one <shot>/fea.npy frame-feature file per shot.
parser.add_argument('--fea_root', type=str, default="dataset/frame_fea", help='path of features')
# NOTE(review): parsed but not read anywhere in this script.
parser.add_argument('--topk', type=int, default=1)

args = parser.parse_args()


frame_fea_root = args.fea_root

# Ground truth: caption -> shot list. Later duplicates of a caption overwrite
# earlier ones. The list order matters (the temporal score compares by index).
with open(args.target_file, "r", encoding="utf8") as f:
    gt = {record['caption']: record['shots'] for record in jsonlines.Reader(f)}

# Generated predictions, one JSON object per line, kept in file order.
with open(args.generated_file, "r", encoding="utf8") as f:
    results = list(jsonlines.Reader(f))

# Load the (non-TorchScript) CLIP model. The preprocessing transform is not
# needed because image-side features are precomputed on disk. Switch to eval
# mode and move the model to the GPU.
clip_model = clip.load(args.clip_model, jit=False)[0]
clip_model = clip_model.eval().cuda()


# Running totals filled in by the evaluation loop.
final_score = temporal_score = residual_sim = 0
false_num = 0
R_recall = R_recall_gt = 0
avg_len = 0

print("len(results) : ", len(results))

# Preload the frame features of every chosen shot. The outer index follows
# `results` and the inner index follows result['shots'], so
# shot_feas[i][n] belongs to results[i]['shots'][n].
shot_feas = [
    [
        torch.from_numpy(np.load(os.path.join(frame_fea_root, shot_id, 'fea.npy'))).cpu()
        for shot_id in result['shots']
    ]
    for result in results
]


max_score = max_index = 0


counter = 0


def _split_sentences(text):
    """Split *text* on '.' into sentences, re-appending the '.' to each one.

    Blank fragments (such as the one after a trailing period) are dropped.
    This builds a new list instead of calling ``list.remove`` while iterating
    the same list, because that would skip the element after each removal.
    """
    return [part + '.' for part in text.strip().split('.') if part.strip()]


def _iou(chosen, target):
    """Return len(set(chosen) & set(target)) / len(set(chosen) | set(target)).

    Set intersection keeps the score in [0, 1] even if *chosen* repeats a
    ground-truth shot. Caller must ensure *target* is non-empty.
    """
    chosen_set, target_set = set(chosen), set(target)
    return len(chosen_set & target_set) / len(chosen_set | target_set)


def _temporal_score(chosen, target):
    """Return the fraction of *target* positions matched by *chosen* at the same index."""
    hits = sum(1 for picked, expected in zip(chosen, target) if picked == expected)
    return hits / len(target)


def _longest_consecutive_run(orders):
    """Return the number of +1 steps in the longest run of consecutive values in *orders*."""
    best = run = 0
    for prev, nxt in zip(orders, orders[1:]):
        if nxt == prev + 1:
            run += 1
        else:
            best = max(best, run)
            run = 0
    return max(best, run)


def _max_text_similarity(text_embeds, frame_feats):
    """Return the highest cosine similarity between any sentence and one shot.

    text_embeds: L2-normalised sentence embeddings, one row per sentence.
    frame_feats: stored features for one shot. They are averaged over dim 0
    (assumed to be frames; TODO confirm) and normalised. They are cast to
    float32 so float16 features on disk do not clash with the float32 text
    embeddings in the matmul.
    """
    shot_embed = torch.mean(frame_feats.float(), dim=0).unsqueeze(0)
    shot_embed = torch.nn.functional.normalize(shot_embed, dim=-1)
    return (text_embeds @ shot_embed.T).max().item()


for idx, result in enumerate(results):
    text = result['caption']
    chosen_shots = result['shots']
    avg_len += len(chosen_shots)
    gt_shots = gt[text]
    if not gt_shots:
        # IoU and the temporal score are undefined without ground truth and
        # would otherwise fail with ZeroDivisionError.
        raise ValueError(f"empty ground-truth shot list for caption {text!r}")

    # Encode every sentence of the caption with CLIP. This is pure inference,
    # so skip autograd bookkeeping to avoid holding activations in GPU memory.
    sentences = _split_sentences(text)
    if not sentences:
        raise ValueError(f"caption contains no sentences: {text!r}")
    with torch.no_grad():
        tokens = clip.tokenize(sentences).cuda()
        text_embeds = clip_model.encode_text(tokens).float().cpu()
    text_embeds = torch.nn.functional.normalize(text_embeds, dim=-1)

    # Chosen shots that are in the ground truth record their gt position.
    # Every other shot adds its CLIP distance (1 - best sentence similarity).
    orders = []
    for n, shot in enumerate(chosen_shots):
        if shot in gt_shots:
            orders.append(gt_shots.index(shot))
        else:
            residual_sim += 1 - _max_text_similarity(text_embeds, shot_feas[idx][n])
            false_num += 1

    final_score += _iou(chosen_shots, gt_shots)
    temporal_score += _temporal_score(chosen_shots, gt_shots)

    # NOTE(review): computed for each caption but never reported.
    max_num = _longest_consecutive_run(orders)

    counter += 1

if counter == 0:
    # Every metric below is a per-result average; with no results they would
    # all divide by zero.
    raise SystemExit("no generated results to evaluate")

final_score = final_score / counter
temporal_score = temporal_score / counter
# float() turns a one-element tensor accumulator into a plain number for
# printing. NOTE(review): this is normalised per caption, not per unmatched
# shot (false_num is tracked but unused). Confirm this is the intended metric.
residual_sim = float(residual_sim) / counter

print("counter : ", counter)
print("IoU : ", final_score)
print("temporal_score : ", temporal_score)
print("residual_sim : ", residual_sim)
print("avg_len : ", avg_len / len(results))