from collections import defaultdict
import io
import os
os.environ["TOKENIZERS_PARALLELISM"]="false"
import os.path as op
import numpy as np
from openai import OpenAI
import json
import base64
from tqdm import tqdm
import imageio
import pdb
import argparse
from PIL import Image
import matplotlib.pyplot as plt
import glob
import torch
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration


# Length of the reference timeline used by idx_mapper: incoming frame indices
# are treated as 1-based positions on a DEFAULT_NUM_FRAMES-long clip and are
# rescaled from [1, DEFAULT_NUM_FRAMES] onto the decoded video's index range.
DEFAULT_NUM_FRAMES = 16.0


def save_json_file(data, file_path):
    """Serialize ``data`` as JSON into ``file_path``, replacing any existing file.

    The output is pretty-printed with a 4-space indent and keys sorted, so
    repeated saves of the same content produce identical files.
    """
    with open(file_path, "w") as out_handle:
        json.dump(
            data,
            out_handle,
            indent=4,
            sort_keys=True,
        )


# Function to encode the image
def encode_image_from_path(image_path):
    """Return the raw bytes of the file at ``image_path`` as a base64 string."""
    with open(image_path, "rb") as image_file:
        raw_bytes = image_file.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode('utf-8')


def encode_image(image):
    """Encode an image as a base64 JPEG string.

    Args:
        image: A PIL ``Image.Image`` or a NumPy array. Arrays are cast to
            ``uint8`` and wrapped in a PIL image before encoding.

    Returns:
        The JPEG bytes of the image, base64-encoded and decoded to ``str``.

    Raises:
        TypeError: If ``image`` is neither a PIL image nor a NumPy array.
    """
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image.astype("uint8"))
    elif not isinstance(image, Image.Image):
        raise TypeError("Input must be a PIL Image or a NumPy array")

    buffered = io.BytesIO()
    try:
        image.save(buffered, format="JPEG")
    except OSError:
        # JPEG cannot store every PIL mode (e.g. RGBA or palette images).
        # Retry on a fresh buffer after converting to plain RGB instead of
        # failing the whole encode.
        buffered = io.BytesIO()
        image.convert("RGB").save(buffered, format="JPEG")

    # Encode the buffer to base64
    return base64.b64encode(buffered.getvalue()).decode('utf-8')


def plot_frames(frames, indices, plot_indices, fname="tmp.jpg"):
    """Render the selected frames side by side and save the result to ``fname``.

    Args:
        frames: Indexable collection of frame arrays.
        indices: Positions in ``frames`` to draw.
        plot_indices: Labels shown in each panel title ("Frame <label>"),
            paired with ``indices`` in order.
        fname: Path the rendered image is written to.

    Returns:
        When exactly one index is given, the raw frame itself (it is also
        saved to ``fname`` without a title). Otherwise a PIL image of the
        saved figure.
    """
    if len(indices) == 1:
        Image.fromarray(frames[indices[0]]).save(fname)
        return frames[indices[0]]

    h, w = frames[0].shape[:2]
    fig, axes = plt.subplots(1, len(indices), figsize=((w/h*3)*len(indices), 3))
    try:
        for ax, frame_pos, label in zip(axes, indices, plot_indices):
            ax.imshow(frames[frame_pos])
            ax.axis('off')
            ax.set_title(f"Frame {label}", fontsize=20)

        fig.tight_layout()
        fig.savefig(fname)
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.close(fig)

    pil_img = Image.open(fname)
    # Image.open is lazy: read the pixels now so the result no longer depends
    # on the temp file, which later calls overwrite under the same name.
    pil_img.load()
    return pil_img


def plot_single_frame(frame, plot_idx, fname="tmp.jpg"):
    """Render one frame with a "Frame <plot_idx>" title and save it to ``fname``.

    Args:
        frame: Frame array; its first two dimensions are read as height and
            width to pick the figure's aspect ratio.
        plot_idx: Label shown in the title.
        fname: Path the rendered image is written to.

    Returns:
        A PIL image of the saved figure.
    """
    # int() truncates, so clamp to 1 to avoid a zero-width figure for frames
    # that are more than three times taller than they are wide.
    w = max(1, int(frame.shape[1] / frame.shape[0] * 3.0))
    fig = plt.figure(figsize=(w, 3.5))
    try:
        plt.imshow(frame)
        plt.axis("off")
        plt.title(f"Frame {plot_idx}", fontsize=20)
        plt.tight_layout()
        plt.savefig(fname)
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.close(fig)

    pil_img = Image.open(fname)
    # Image.open is lazy: read the pixels now so the result no longer depends
    # on the temp file, which later calls overwrite under the same name.
    pil_img.load()
    return pil_img


def form_gpt_question(q, frames: list, frame_idx: list, plot_frame_idx: list, video_model: str, single_image=True):
    """Build the multimodal user-message content for a chat-completions request.

    Args:
        q: The question/assertion text to ask about the frames.
        frames: Indexable collection of frame arrays (not a string).
        frame_idx: Positions in ``frames`` to show.
        plot_frame_idx: Frame labels used in the image titles and in the
            text description.
        video_model: Name used to build the temporary image file name
            (``tmp_<video_model>.jpg``).
        single_image: If true, all frames are drawn into one combined image;
            otherwise one image part is produced per frame.

    Returns:
        A list with one ``text`` part followed by the ``image_url`` part(s),
        each image embedded as a base64 JPEG data URL.
    """
    img_fname = f"tmp_{video_model}.jpg"

    def _image_part(image):
        # One "image_url" content part carrying the image inline.
        return {
            "type": "image_url",
            "image_url": {
                "url": f"data:image/jpeg;base64,{encode_image(image)}",
            },
        }

    if single_image:
        vision_content = [
            _image_part(plot_frames(frames, frame_idx, plot_frame_idx, img_fname))
        ]
    else:
        # Each frame is rendered and encoded before the next one is drawn,
        # so reusing the same temp file name is safe here.
        vision_content = [
            _image_part(plot_single_frame(frames[idx], plot_idx, img_fname))
            for idx, plot_idx in zip(frame_idx, plot_frame_idx)
        ]

    vision_description = ", ".join(str(idx) for idx in plot_frame_idx)
    vision_description = f"The images show Frame {vision_description} of a video. "
    text_content = [{"type": "text", "text": vision_description + q + " Answer yes or no and then justify your answer in one sentence.\n"}]
    return text_content + vision_content


def form_question(model, q, frames, frame_idx, plot_frame_idx, video_model, single_image=True, processor=None):
    """Build the model-specific input for one yes/no question about a video.

    Args:
        model: Evaluator name; names containing "gpt" or "llava" are handled.
        q: The question/assertion text.
        frames: Indexable collection of frame arrays.
        frame_idx: Positions in ``frames`` to show.
        plot_frame_idx: Frame labels used in titles and in the prompt text.
        video_model: Name used to build the temporary image file name.
        single_image: Forwarded to ``form_gpt_question`` (GPT path only).
        processor: Callable that turns prompt text and an image into model
            inputs; required on the LLaVA path.

    Returns:
        For GPT, the message content list from ``form_gpt_question``; for
        LLaVA, the processor output moved to ``"cuda"``.

    Raises:
        ValueError: If the LLaVA path is used without a processor.
        NotImplementedError: If ``model`` is neither a GPT nor a LLaVA name.
    """
    if 'gpt' in model:
        return form_gpt_question(q, frames, frame_idx, plot_frame_idx, video_model, single_image=single_image)
    elif "llava" in model:
        if processor is None:
            # Fail early with a clear message instead of "NoneType is not callable".
            raise ValueError("A processor is required to build LLaVA inputs")
        image = plot_frames(frames, frame_idx, plot_frame_idx, f"tmp_{video_model}.jpg")
        vision_description = ", ".join(str(idx) for idx in plot_frame_idx)
        vision_description = f"The images show Frame {vision_description} of a video. "
        prompt = "[INST] <image>\n" + vision_description + q + " Answer yes or no and then justify your answer in one sentence. [/INST]"
        # Pass text/images by keyword: the positional order of these two
        # arguments is not the same across transformers releases.
        inputs = processor(text=prompt, images=image, return_tensors="pt").to("cuda")
        return inputs
    else:
        raise NotImplementedError
    

def eval_forward(model, client, inputs, processor=None):
    """Run one question through the evaluator model.

    Args:
        model: Evaluator name; names containing "gpt" or "llava" are handled.
        client: For GPT, an object exposing ``chat.completions.create``; for
            LLaVA, the model whose ``generate`` method is called.
        inputs: For GPT, the user-message content; for LLaVA, the processor
            output (unpacked into ``generate`` and read for ``input_ids``).
        processor: Used on the LLaVA path to decode generated tokens.

    Returns:
        For GPT, the response object exactly as returned by the client. For
        LLaVA, the decoded text of the newly generated tokens only.

    Raises:
        ValueError: If ``model`` is neither a GPT nor a LLaVA name.
    """
    if 'gpt' in model:
        system_message = {"role": "system", "content": "You are a powerful and concise vision language model that answer questions based on the images shown to you."}
        user_message = {"role": "user", "content": inputs}
        return client.chat.completions.create(
            model=model,
            messages=[system_message, user_message],
            max_tokens=1000,
        )

    if "llava" in model:
        generated = client.generate(**inputs, max_new_tokens=200)
        # Drop the prompt tokens so only the model's continuation is decoded.
        prompt_length = inputs.input_ids.size(1)
        return processor.decode(generated[0, prompt_length:], skip_special_tokens=True)

    raise ValueError("Invalid model")


def idx_mapper(indices, n_frames):
    """Rescale 1-based reference frame indices onto ``[0, n_frames]``.

    Each index is mapped linearly so that 1 becomes 0 and
    ``DEFAULT_NUM_FRAMES`` becomes ``n_frames``, then rounded. Duplicates
    produced by rounding are merged and the result is returned sorted.
    """
    reference_span = DEFAULT_NUM_FRAMES - 1
    mapped = {round((idx - 1) / reference_span * n_frames) for idx in indices}
    return sorted(mapped)


def map_score(score):
    """Map a raw score onto [0, 1] with a linear ramp between 0.9 and 0.98.

    Scores below 0.9 give 0, scores above 0.98 give 1, and anything in
    between is interpolated linearly.
    """
    lower, upper = 0.9, 0.98
    if score < lower:
        return 0
    if score > upper:
        return 1
    # Inside [0.9, 0.98]: position within the band, as a fraction.
    return (score - lower) / (upper - lower)
    

def get_eval_scores(eval_answers, model, model_type, use_frame_consistency=False):
    """Aggregate per-question answers into per-video and per-type scores.

    Args:
        eval_answers: Mapping of video id to a list of answer dicts, each
            with at least ``'correct'`` and ``'type'`` keys. The part of the
            video id before the first ``-`` is the integer prompt id.
        model: Name of the evaluated video model; selects the frame
            consistency file when ``use_frame_consistency`` is true.
        model_type: Prefix of the prompts file, ``./<model_type>_prompts.json``.
        use_frame_consistency: If true, blend a frame-consistency score into
            each video's score when one is available for that video.

    Returns:
        A tuple ``(TCR, TCR_by_type, mean_score, score_by_type)``: the list
        of per-video transition-completed flags, the same flags grouped by
        prompt type, the mean of all per-video scores, and the per-video
        scores grouped by prompt type.
    """
    with open(f"./{model_type}_prompts.json") as prompt_file:
        prompts = json.load(prompt_file)
    id2type = {d['id']: d['type'] for d in prompts}

    # Read the consistency scores once up front rather than once per video.
    # A missing or unreadable file just means no video gets the blended
    # score, matching the previous best-effort fallback.
    consistency_scores = {}
    if use_frame_consistency:
        consistency_path = f"eval_results/i2v_consistency/consecutive_frame_sim/i2v/{model}.json"
        try:
            with open(consistency_path) as consistency_file:
                loaded = json.load(consistency_file)
        except (OSError, ValueError):
            loaded = None
        if isinstance(loaded, dict):
            consistency_scores = loaded

    avg_scores = []
    TCR = []
    TCR_by_type = defaultdict(list)
    score_by_type = defaultdict(list)
    for k, v in eval_answers.items():
        # A transition counts as completed only if every assertion other
        # than the 'other objects' ones was judged correct.
        transition_completed = all(q['correct'] for q in v if q['type'] != 'other objects')

        frame_consistency_score = consistency_scores.get(k)
        if isinstance(frame_consistency_score, (int, float)):
            # The measured frame consistency stands in for the
            # 'transition object consistency' questions: 2/3 question
            # accuracy, 1/3 mapped consistency.
            score = np.mean([q['correct'] for q in v if q['type'] != 'transition object consistency'])
            score = 2/3*score + 1/3*map_score(frame_consistency_score)
        else:
            score = np.mean([q['correct'] for q in v])

        prompt_type = id2type[int(k.split("-")[0])]
        avg_scores.append(score)
        TCR.append(transition_completed)
        TCR_by_type[prompt_type].append(transition_completed)
        score_by_type[prompt_type].append(score)
    return TCR, TCR_by_type, np.mean(avg_scores), score_by_type


def _load_json(path):
    """Read and parse a JSON file, closing the handle afterwards."""
    with open(path) as handle:
        return json.load(handle)


def _parse_yes_no(answer):
    """Score a free-text answer: 1 for yes, 0 for no, 0.5 if neither appears.

    The first whole-word "yes" or "no" decides, so words that merely contain
    those letters ("eyes", "know", "cannot") are not mistaken for a verdict.
    """
    import re  # local import: `re` is not imported at module level

    verdict = re.search(r"\b(yes|no)\b", answer.lower())
    if verdict is None:
        return 0.5
    return 1 if verdict.group(1) == "yes" else 0


def eval_model(args):
    """Ask the evaluator model every assertion for every video and report scores.

    Reads ``<model_type>_assertions.json`` and all ``*.mp4`` files in
    ``args.video_dir``. Answers are saved after each video under
    ``eval_results/<eval_model>_eval/<model_type>/<model>.json`` and videos
    already present in that file are skipped, so interrupted runs resume.
    Videos for which a question failed are recorded in the errors file
    instead of the answers file. Finally prints the per-type and overall
    results computed by ``get_eval_scores``.

    Args:
        args: Parsed CLI namespace with ``model``, ``model_type``,
            ``video_dir``, ``api_key`` and ``eval_model``.

    Raises:
        ValueError: If the evaluator is unknown, or a GPT evaluator is
            requested without an API key.
    """
    modelname = args.model
    meta_folder = f"{args.eval_model}_eval"
    save_dir = op.join(f"eval_results/{meta_folder}/{args.model_type}")
    os.makedirs(save_dir, exist_ok=True)
    os.makedirs(op.join(save_dir, "tmp"), exist_ok=True)
    os.makedirs(f"eval_results/{meta_folder}/errors", exist_ok=True)
    answers_path = op.join(save_dir, f"{modelname}.json")
    errors_path = f"eval_results/{meta_folder}/errors/{args.model_type}_{modelname}_errors.json"

    # Only the LLaVA path has a processor; define the name for every path so
    # the calls below do not hit a NameError with a GPT evaluator.
    processor = None
    if "gpt" in args.eval_model:
        if args.api_key is None:
            raise ValueError("Please provide an API key for OpenAI")
        client = OpenAI(api_key=args.api_key)
    elif args.eval_model == "llava":
        client = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf",
                                                                   torch_dtype=torch.bfloat16,
                                                                   low_cpu_mem_usage=True,
                                                                   use_flash_attention_2=True,
                                                                   device_map="cuda")
        processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
        client.generation_config.pad_token_id = processor.tokenizer.pad_token_id
    else:
        raise ValueError("Invalid eval model")

    eval_questions = _load_json(f"{args.model_type}_assertions.json")
    video_files = sorted(glob.glob(os.path.join(args.video_dir, "*.mp4")))

    answers = {}
    errors = {}
    if op.exists(answers_path):
        answers = _load_json(answers_path)

    price = 0  # no cost accounting is implemented; shown as-is in the progress bar
    pbar = tqdm(video_files)
    for video_file in pbar:
        video_id = os.path.basename(video_file)[:-4]
        prompt_id = str(int(video_id.split("-")[0]))
        if video_id in answers:
            continue

        # Close the reader as soon as the frames are in memory.
        with imageio.get_reader(video_file) as video_reader:
            frames = [f for f in video_reader]
        if not frames:
            # Nothing to show the evaluator; record it instead of crashing
            # later while plotting.
            errors[video_id] = "no frames could be decoded"
            save_json_file(errors, errors_path)
            continue

        sample_answers = []
        for question_meta in eval_questions[prompt_id]:
            frame_idx = idx_mapper(question_meta['frames'], len(frames)-1)

            q2gpt = form_question(args.eval_model, question_meta['assertion'], frames, frame_idx, question_meta['frames'], modelname, single_image=True, processor=processor) # single image gives better results
            try:
                response = eval_forward(args.eval_model, client, q2gpt, processor=processor)
                if not isinstance(response, str):
                    # Chat-completions clients return a response object;
                    # pull the message text out of the first choice.
                    response = response.choices[0].message.content
                gpt_answer = response.strip().strip(".").lower()
            except Exception as e:
                errors[video_id] = str(e)
                print(e)
                # The whole video is discarded once any question fails, so
                # asking the remaining questions would be wasted work.
                break

            sample_answers.append({
                "assertion": question_meta['assertion'],
                "frames": question_meta['frames'],
                "type": question_meta['type'],
                "correct": _parse_yes_no(gpt_answer),
                "gpt_answer": gpt_answer
            })

        if video_id not in errors:
            answers[video_id] = sample_answers
            save_json_file(answers, answers_path)
        else:
            save_json_file(errors, errors_path)
        pbar.set_postfix({"Price": round(price, 5)})
    pbar.close()

    # TODO: add consecutive frame similarity for I2V models
    TCR, TCR_by_type, avg_score, score_by_type = get_eval_scores(answers, modelname, args.model_type, use_frame_consistency=args.model_type=='i2v')
    print(modelname)
    for k, v in TCR_by_type.items():
        print("{}: {:.3f}, {:.4f}".format(k, np.mean(v)*100, np.mean(score_by_type[k])))
    print("overall: {:.3f}, {:.4f}".format(np.mean(TCR)*100, avg_score))


if __name__ == '__main__':
    # Command-line entry point: collect the run configuration and hand it
    # straight to eval_model.
    cli = argparse.ArgumentParser("GPT-4V eval")
    cli.add_argument(
        "--model_type",
        type=str,
        required=True,
        choices=["t2v", "i2v"],
    )
    cli.add_argument(
        "--model",
        type=str,
        required=True,
        help="This should be the model name under ./generated_videos/{i2v, t2v}/",
    )
    cli.add_argument("--video_dir", type=str, required=True)
    cli.add_argument("--api_key", type=str, default=None)
    cli.add_argument(
        "--eval_model",
        type=str,
        choices=['llava'],
        default='llava',
    )
    eval_model(cli.parse_args())