import os
import torch
import torch.nn.functional as F
import numpy as np
import random
import json
import math
import sys
from typing import Iterable
import argparse
import time
import datetime
from util import dist
import torch
from torch.utils.data import DataLoader, DistributedSampler
from collections import namedtuple
from functools import reduce
import openai

from datasets import build_videoqa_dataset, videoqa_collate_fn
from model import build_model, get_tokenizer
from args import get_args_parser
from util.misc import get_mask, adjust_learning_rate
from util.metrics import MetricLogger
from model.deberta import DebertaV2ForMaskedLM
from transformers import (
    BertTokenizer,
    DebertaV2Tokenizer,
    DebertaV2Config,
    BertConfig,
    GPT2Tokenizer
)
import json
import re
from tqdm import tqdm
import pandas as pd
# Device that tokenized inputs are moved to before being passed to the model
# (see get_bilm).
device = torch.device("cuda")

# Root folder of the generation outputs. The script entry point reads
# `{PROMPTING_FOLDER}/results/<dataset>_results/<experiment>/all_results.json`,
# so with the empty default the path resolves to `/results/...`; set this
# before running.
PROMPTING_FOLDER = ""  # folder where generation was done

def parse_args():
    """Parse this script's command-line options from ``sys.argv``.

    Recognized flags and the namespace attribute each one fills:

    * ``--e`` -> ``experiment`` (str, default ``""``)
    * ``--w`` -> ``optimize`` (str, default ``""``)
    * ``--d`` -> ``dataset`` (str, default ``""``)
    * ``--f`` -> ``frames`` (int, default ``10``)

    Returns:
        The populated ``argparse.Namespace``.
    """
    # One row per option: (flag, dest, help, default, type).
    option_table = (
        ("--e", "experiment", "experiment name", "", str),
        ("--w", "optimize", "stage 2 run", "", str),
        ("--d", "dataset", "which json file to use", "", str),
        ("--f", "frames", "how much frames", 10, int),
    )

    cli = argparse.ArgumentParser(description="Running experiments")
    for flag, attribute, help_text, fallback, caster in option_table:
        cli.add_argument(
            flag,
            dest=attribute,
            help=help_text,
            default=fallback,
            type=caster,
        )
    return cli.parse_args()


def bilm(clip_id, question, answer, wrong_answers, clip_result, **kwargs):
    """Run the question-aware BiLM check for one QA item.

    Delegates to ``get_bilm`` with the question included in the prompt and
    the prediction stored under the ``"full_bilm_pred"`` key. Any extra
    keyword arguments are accepted and ignored.

    Returns:
        Whatever ``get_bilm`` returns: a ``(flag, extra_fields)`` tuple.
    """
    return get_bilm(
        clip_id=clip_id,
        question=question,
        answer=answer,
        wrong_answers=wrong_answers,
        with_question=True,
        key="full_bilm_pred",
        clip_result=clip_result,
    )


def get_bilm(clip_id, question, answer, wrong_answers, with_question, key, clip_result):
    """Check whether the model fails to pick the correct answer for one QA item.

    Every candidate answer is placed into a "Question: ... Is it '...'? [MASK]."
    prompt, the model is run on the clip features plus that prompt, and the
    candidate with the highest score in column 0 of the softmaxed mask logits
    (named "yes" scores here; presumably the "Yes" answer embedding registered
    at start-up) becomes the prediction.

    Relies on module-level globals prepared by the script entry point:
    ``features_folder``, ``frame_count``, ``tokenizer``, ``model`` and
    ``device``.

    Args:
        clip_id: Clip identifier; features are read from
            ``{features_folder}/{clip_id}.npy``.
        question: Question text. A trailing "?" is appended when missing.
            Replaced by a bare "?" when ``with_question`` is false.
        answer: The correct answer text.
        wrong_answers: Sequence of distractor answer texts.
        with_question: Whether the question text is included in the prompt.
        key: Name under which the predicted choice index is cached/returned.
        clip_result: Previously stored fields for this QA item. If it already
            contains ``key`` the cached prediction is reused and the model is
            not run.

    Returns:
        A ``(model_was_wrong, extra)`` tuple. ``model_was_wrong`` is True only
        when a valid prediction exists and it is not the correct answer.
        ``extra`` is ``{key: prediction}`` for a fresh prediction,
        ``{key: -1}`` when the item could not be scored (missing inputs,
        unreadable features, wrong frame count, malformed logits), and ``{}``
        when a cached prediction was reused.
    """
    has_choices = answer is not None and wrong_answers is not None
    # The correct answer is appended after the distractors below, so its index
    # in the choice list equals the number of distractors (this used to be
    # hard-coded as 4, which is only right for exactly four distractors).
    correct_index = len(wrong_answers) if has_choices else None

    if key in clip_result:
        cached = clip_result[key]
        # -1 marks an item that could not be scored; report it the same way
        # the fresh-computation failure paths do instead of as a "wrong" guess.
        return has_choices and cached not in (-1, correct_index), {}

    if not has_choices or (with_question and question is None):
        # Nothing to score: use the same sentinel as the other failure paths.
        return False, {key: -1}

    with torch.no_grad():
        try:
            video = torch.from_numpy(
                np.load(f"{features_folder}/{clip_id}.npy").astype("float32")
            )
        except (OSError, ValueError, EOFError):
            # Missing, unreadable or malformed feature file.
            return False, {key: -1}

        if video.shape[0] != frame_count:
            return False, {key: -1}
        video = video.unsqueeze(0).cuda()
        video_mask = get_mask(
            torch.tensor(frame_count, dtype=torch.long).unsqueeze(0), video.size(1)
        ).cuda()

        if not with_question:
            question = "?"
        elif not question.endswith("?"):
            # endswith (rather than question[-1]) also copes with an empty string.
            question = question + "?"

        multiple_choice = [w.lower().strip() for w in wrong_answers]
        multiple_choice.append(answer.lower().strip())

        logits_list = []
        for choice in multiple_choice:
            text = f"Question: {question.capitalize()} Is it '{choice.capitalize()}'? {tokenizer.mask_token}."

            encoded = tokenizer(
                [text],
                add_special_tokens=True,
                max_length=300,
                padding="longest",
                truncation=True,
                return_tensors="pt",
            )

            output = model(
                video=video,
                video_mask=video_mask,
                input_ids=encoded["input_ids"].to(device),
                attention_mask=encoded["attention_mask"].to(device),
            )

            logits = output["logits"]
            # Skip the first frame_count positions of the output, keep the
            # text positions, then select the mask-token position(s).
            logits = logits[:, frame_count : encoded["input_ids"].size(1) + frame_count][
                encoded["input_ids"] == tokenizer.mask_token_id
            ]
            logits_list.append(logits.softmax(-1)[:, 0].cpu())

        try:
            yes_scores = torch.stack(logits_list, 1)[0]
        except (RuntimeError, IndexError):
            # Mismatched shapes (stack) or no mask position found (index 0),
            # e.g. when truncation dropped the mask token.
            print(logits_list)
            return False, {key: -1}

        prediction = int(torch.argmax(yes_scores).cpu())

        # Release GPU memory before the next item.
        del video
        del video_mask
        del logits
        torch.cuda.empty_cache()

        return prediction != correct_index, {key: prediction}

def main():
    """Run the BiLM check over every QA item in ``qa_data`` and save the results.

    Uses module-level globals prepared by the script entry point:
    ``qa_data`` (clip id -> clip dict), ``result_folder`` and ``frame_count``.

    For each clip, the three QA items stored under the clip's "q"/"a"/"w"
    lists are processed. If ``{result_folder}/{frame_count}_bilm_accuracies.json``
    already exists, its entries are matched by clip id and question index and
    passed to ``bilm`` so previously stored predictions can be reused. The
    accumulated list of per-question result dicts is rewritten to that file
    after every question, so an interrupted run keeps its progress.
    """
    result_path = f"{result_folder}/{frame_count}_bilm_accuracies.json"

    previous_results = []
    if os.path.isfile(result_path):
        with open(result_path) as f:
            previous_results = json.load(f)
    print(len(previous_results))
    clip_id_to_res = {
        f"{clip['clip_id']}_{clip['qa_i']}": clip for clip in previous_results
    }

    result = []
    for clip_id in tqdm(qa_data):
        clip = qa_data[clip_id]
        clip_q = clip.get("q", [None, None, None])
        clip_a = clip.get("a", [None, None, None])
        clip_w = clip.get("w", [None, None, None])
        # NOTE(review): the original read an undefined name `good` here and
        # raised NameError. Assumed to be per-question quality labels stored
        # on the clip under "good" -- confirm against the data schema.
        good_labels = clip.get("good")

        # Iterate through all three questions per clip
        for i in range(3):
            clip_result = clip_id_to_res.get(f"{clip_id}_{i}", {})

            question = clip_q[i]
            answer = clip_a[i]
            wrong_answers = clip_w[i]

            clip_result["clip_id"] = clip_id
            if "clip_url" in clip:
                clip_result["clip_url"] = clip["clip_url"]
            clip_result["qa_i"] = i
            clip_result["q"] = question
            clip_result["a"] = answer
            clip_result["w"] = wrong_answers
            if good_labels is not None:
                clip_result["good"] = (good_labels[i] == "good")

            # bilm returns (flag, extra_fields); only the extra fields (the
            # stored prediction) are merged into the saved record.
            _, filter_extra = bilm(
                clip_id, question, answer, wrong_answers, clip_result, qa_id=i
            )
            clip_result.update(filter_extra)
            result.append(clip_result)

            # Checkpoint after every question.
            with open(result_path, "w") as f:
                json.dump(result, f)
    
if __name__ == "__main__":
    # Script entry point: parse the CLI, load the model and tokenizer, locate
    # the experiment's QA file, then run main(). The names assigned here
    # (features_folder, frame_count, tokenizer, model, result_folder, qa_data)
    # are module-level globals read by get_bilm() and main().
    args = parse_args()
    if "batch" not in args.dataset:
        print("wrong dataset")
        # sys.exit rather than the interactive-only quit() helper; a non-zero
        # status signals the invalid argument to the caller.
        sys.exit(1)

    features_folder = f"./features/{args.dataset}_{args.frames}"
    frame_count = args.frames

    tokenizer = DebertaV2Tokenizer.from_pretrained("microsoft/deberta-v2-xlarge")
    model = DebertaV2ForMaskedLM.from_pretrained(
        features_dim=768,
        max_feats=frame_count,
        freeze_lm=False,
        freeze_mlm=False,
        ft_ln=False,
        ds_factor_attn=8,
        ds_factor_ff=8,
        dropout=0.1,
        n_ans=2,
        freeze_last=False,
        pretrained_model_name_or_path="microsoft/deberta-v2-xlarge",
    )

    checkpoint = torch.load("../../frozenbilm_how2qa.pth", map_location="cpu")
    model.load_state_dict(checkpoint["model"], strict=False)

    model.cuda()
    model.eval()

    # Single-token ids for the two answers; "Yes" is stacked first, so it is
    # answer index 0 (get_bilm reads column 0 of the softmaxed logits).
    tok_yes = torch.tensor(tokenizer("Yes",
                                     add_special_tokens=False,
                                     max_length=1,
                                     truncation=True,
                                     padding="max_length",)["input_ids"],
                           dtype=torch.long,)

    tok_no = torch.tensor(tokenizer("No",
                                    add_special_tokens=False,
                                    max_length=1,
                                    truncation=True,
                                    padding="max_length",)["input_ids"],
                          dtype=torch.long,)

    a2tok = torch.stack([tok_yes, tok_no])
    model.set_answer_embeddings(
        a2tok.to(model.device), freeze_last=False
    )

    if args.optimize != "":
        experiment_path = f"{PROMPTING_FOLDER}/results/{args.dataset}_results/{args.experiment}/stage_2/{args.optimize}"
    else:
        experiment_path = f"{PROMPTING_FOLDER}/results/{args.dataset}_results/{args.experiment}"

    qa_data_path = f"{experiment_path}/all_results.json"
    result_folder = f"{experiment_path}/filter_results"

    # Read the input first so a wrong experiment path fails before any output
    # folder is created; the context manager closes the file handle.
    with open(qa_data_path) as qa_data_f:
        qa_data = json.load(qa_data_f)

    # exist_ok avoids the check-then-create race of os.path.exists + os.mkdir.
    os.makedirs(result_folder, exist_ok=True)

    main()
