import argparse
import torch
import os
import os.path as osp
import json
from tqdm import tqdm
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.custom_conversation import conv_templates
from llava.model.custom_builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path

from PIL import Image
import math
import yaml
from types import SimpleNamespace

from my_utils.save_method import MetadataSaver, ValueMonitor
from my_utils import FeatureSaver
from logic import LogicEngine, DimProspector, HeadFork, VARProc

def split_list(lst, n):
    """Split ``lst`` into at most ``n`` contiguous chunks of (roughly) equal size.

    Every chunk except possibly the last holds ``ceil(len(lst) / n)`` items, so
    fewer than ``n`` chunks can come back (e.g. 9 items over 4 chunks gives 3
    chunks of 3). An empty ``lst`` yields ``[]``.
    """
    if not lst:
        # A chunk size of 0 would make range() raise "arg 3 must not be zero".
        return []
    chunk_size = -(-len(lst) // n)  # ceiling division, exact on integers
    return [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)]


def get_chunk(lst, n, k):
    """Return the ``k``-th chunk of ``split_list(lst, n)``.

    Returns ``[]`` when ``lst`` is empty or when ``k`` is past the last chunk.
    ``split_list`` can produce fewer than ``n`` chunks (e.g. 9 items over 4
    chunks yields 3), so a worker holding such an index gets no work instead
    of crashing with ``IndexError``.
    """
    if not lst:
        return []
    chunks = split_list(lst, n)
    return chunks[k] if k < len(chunks) else []


def _load_exp_config(path):
    """Load the YAML experiment config at ``path`` into a ``SimpleNamespace``.

    Raises:
        ValueError: if ``path`` is None or the file does not hold a YAML mapping.
    """
    if path is None:
        raise ValueError("No experiment config given: pass --exp-config <path-to-yaml>.")
    with open(path, "r") as file:
        config_dict = yaml.safe_load(file)
    if not isinstance(config_dict, dict):
        raise ValueError(f"Experiment config {path!r} must contain a YAML mapping at the top level.")
    return SimpleNamespace(**config_dict)


def _resolve_device(gpu):
    """Return ``(device, cuda_index)``; ``cuda_index`` is None when running on CPU.

    An omitted ``--gpu`` (None) falls back to CUDA device 0 instead of the
    invalid device string ``"cuda:None"``.
    """
    if not torch.cuda.is_available():
        return "cpu", None
    cuda_index = 0 if gpu is None else gpu
    return f"cuda:{cuda_index}", cuda_index


def _print_settings(cfgs, args):
    """Print the experiment / logic settings banner."""
    label_width = 20
    print(
            "="*30 + "\n"
            + "Experiment Settings".center(30) + "\n"
            + "-"*30 + "\n"
            + f"{'Model Path:'.ljust(label_width)} {cfgs.model_path.rjust(5)}\n"
            + f"{'Answer File:'.ljust(label_width)} {str(args.answer_file_ver).rjust(5)}\n"
            + "="*30 + "\n"
            + "Logic Settings".center(30) + "\n"
            + "-"*30 + "\n"
            + f"{'VAR:'.ljust(label_width)} {str(cfgs.var).rjust(5)}\n"
            + f"{'Logic:'.ljust(label_width)} {str(cfgs.logic).rjust(5)}\n"
            + f"{'- tau:'.ljust(label_width)} {str(cfgs.tau).rjust(5)}\n"
            + f"{'Head Select:'.ljust(label_width)} {str(cfgs.head_fork).rjust(5)}\n"
            + f"{'- rho:'.ljust(label_width)} {str(cfgs.rho).rjust(5)}\n"
            + f"{'- summ:'.ljust(label_width)} {str(cfgs.summ).rjust(5)}\n"
            + "="*30
        )


def _configure_logic(cfgs, model, model_name):
    """Activate LogicEngine and its optional components when ``cfgs.logic == 1``."""
    if cfgs.logic != 1:
        return
    LogicEngine.export_model_config(model.config)
    LogicEngine.set_llm_name(model_name)
    LogicEngine.rule_setting({"sink_rule": cfgs.sink_rule, "head_rule": cfgs.head_rule})
    LogicEngine.set_sink_select_layers(cfgs.sink_select_layers)
    LogicEngine.set_flag(True)

    if cfgs.var >= 1:
        VARProc.activate(cfgs.p)
        VARProc.config_last_layer(cfgs.except_last_layer)
        VARProc.STEP = cfgs.var
    if cfgs.dim_prospector:
        DimProspector.activate(tau=cfgs.tau)
    if cfgs.head_fork:
        HeadFork.activate(rho=cfgs.rho, summ=cfgs.summ)


def _load_questions(args):
    """Read the JSONL question file (blank lines skipped) and return this worker's chunk."""
    question_file_name = f"{args.category}.questions.jsonl" if args.category != "" else "questions.jsonl"
    question_file_path = os.path.expanduser(osp.join(args.question_file, question_file_name))
    with open(question_file_path, "r") as question_file:
        questions = [json.loads(row) for row in question_file if row.strip()]
    return get_chunk(questions, args.num_chunks, args.chunk_idx)


def _resolve_answers_file(args, model_name, questions, suffix_len=13):
    """Pick the answers file and open mode, resuming a matching earlier run if present.

    An existing ``.jsonl`` file in the target directory matches when its stem,
    minus the last ``suffix_len`` characters, equals the basename of
    ``args.answer_file_ver`` minus the same suffix. If that stripped name is
    empty, the stems must match exactly, so short names no longer match
    arbitrary files. On a match, the questions already answered (one line each)
    are skipped and the file is appended to.

    Returns:
        ``(answers_file, file_mode, remaining_questions)``.
    """
    # NOTE(review): suffix_len=13 presumably strips a run-specific suffix
    # (e.g. a timestamp) from version names — confirm against the launch scripts.
    answers_file = osp.join("answers", model_name, f"{args.answer_file_ver}.jsonl")
    answers_dir = osp.dirname(answers_file)
    version_name = osp.basename(args.answer_file_ver)
    target_prefix = version_name[:-suffix_len]

    if osp.isdir(answers_dir):
        # Sorted so the choice among several matches is deterministic.
        for entry in sorted(os.listdir(answers_dir)):
            candidate = osp.join(answers_dir, entry)
            stem, ext = osp.splitext(entry)
            if ext != ".jsonl" or not osp.isfile(candidate):
                continue
            matched = stem[:-suffix_len] == target_prefix if target_prefix else stem == version_name
            if not matched:
                continue
            with open(candidate, "r") as existing:
                n_done = existing.read().count("\n")
            print(f"Answer file already exists: {candidate}")
            print(f"Resuming: skipping {n_done} answered question(s) and appending new answers.")
            return candidate, "a", questions[n_done:]
    return answers_file, "w", questions


def _first_present(record, keys):
    """Return the value of the first key in ``keys`` whose value is not None.

    Unlike chaining ``or``, falsy but valid values such as ``0`` are kept.
    """
    for key in keys:
        value = record.get(key)
        if value is not None:
            return value
    return None


def eval_model(args):
    """Generate answers for a question file and write them as JSONL.

    Loads the YAML experiment config at ``args.exp_config`` and the model, then
    configures the logic hooks. For each question it builds the conversation
    prompt and generates up to 1024 new tokens with ``do_sample=False``. One JSON
    line per question goes to ``answers/<model_name>/<answer_file_ver>.jsonl``,
    or to a matching earlier file that is resumed. FeatureSaver / MetadataSaver
    pickles are also written when those savers report themselves enabled.
    """
    cfgs = _load_exp_config(args.exp_config)

    device, cuda_index = _resolve_device(args.gpu)
    if cuda_index is None:
        print(f"Using device: {device}")
    else:
        print(f"Using device: {torch.cuda.get_device_name(cuda_index)}-{cuda_index}")
    _print_settings(cfgs, args)

    disable_torch_init()
    model_path = os.path.expanduser(cfgs.model_path)
    model_name = get_model_name_from_path(model_path)
    tokenizer, model, image_processor, _context_len = load_pretrained_model(
        model_path, args.model_base, model_name, attn_implementation="eager", device_map=device
    )

    MetadataSaver.set_vis_len(model.get_model().get_num_vis_patches())
    _configure_logic(cfgs, model, model_name)

    questions = _load_questions(args)
    answers_file, file_mode, questions = _resolve_answers_file(args, model_name, questions)
    os.makedirs(osp.dirname(answers_file), exist_ok=True)
    print(f"Answer file path: {answers_file}")
    answers_file = os.path.expanduser(answers_file)

    with open(answers_file, file_mode) as ans_file:
        for line in tqdm(questions):
            qid = _first_present(line, ("qid", "question_id"))
            ValueMonitor.remember_qid(qid)

            gt_label = _first_present(line, ("label", "answer", "gt-label"))

            image_file = line["image"]
            if not osp.splitext(image_file)[1]:
                # Records without an extension are assumed to be JPEGs.
                image_file = f"{image_file}.jpg"

            qs = line.get("text") or line.get("question")
            if qs is None:
                raise ValueError(f"Question {qid!r} has neither a 'text' nor a 'question' field.")
            cur_prompt = qs
            if model.config.mm_use_im_start_end:
                qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + "\n" + qs
            else:
                qs = DEFAULT_IMAGE_TOKEN + "\n" + qs

            conv = conv_templates[args.conv_mode].copy()
            conv.append_message(conv.roles[0], qs)
            conv.append_message(conv.roles[1], None)
            prompt = conv.get_prompt()

            input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(device=device)
            image_path = osp.join(args.image_folder, image_file)
            # convert() loads the pixels into a new image, so the file can be closed here.
            with Image.open(image_path) as raw_image:
                image = raw_image.convert("RGB")
            image_tensor = process_images([image], image_processor, model.config)[0]

            # Set per question, right before generate, as the original loop did.
            model.tokenizer = tokenizer
            with torch.inference_mode():
                outputs = model.generate(
                    input_ids,
                    # NOTE(review): images are cast to fp16 regardless of model dtype — confirm.
                    images=image_tensor.unsqueeze(0).half().to(device),
                    image_sizes=[image.size],
                    return_dict_in_generate=True,
                    output_attentions=True,
                    output_hidden_states=True,
                    do_sample=False,
                    max_new_tokens=1024,
                    use_cache=True,
                )
                generated_texts = tokenizer.batch_decode(outputs.sequences)

            ans_file.write(json.dumps({"question_id": qid, "prompt": cur_prompt, "label": gt_label, "response": generated_texts[0], "image": image_file, "model_id": model_name}) + "\n")
            # Flush per answer so an interrupted run can be resumed by line count.
            ans_file.flush()

            if FeatureSaver._flag():
                FeatureSaver.save_hs(osp.join(FeatureSaver.get_save_to_path(), "hs.pkl"))
            if MetadataSaver._flag():
                MetadataSaver.set_gt_label(gt_label)
                MetadataSaver.set_image_path(image_path)
                answer_token_ids = outputs.sequences.detach().cpu().clone()
                answer_tokens = [tokenizer.decode(token_id) for token_id in answer_token_ids[0]]
                MetadataSaver.set_answer(answer_token_ids, answer_tokens)
                MetadataSaver.set_prompt(prompt)
                MetadataSaver.save_metadata(osp.join(MetadataSaver.get_save_to_path(), "metadata.pkl"))

            # Reset LogicEngine class-level state before the next question.
            LogicEngine.clear()

def parse_ranges(range_string):
    """Turn a dash-separated list of ints into consecutive ``[start, end]`` pairs.

    ``"0-8-16"`` becomes ``[[0, 8], [8, 16]]``; a single number yields ``[]``.
    Used as an argparse ``type=`` converter, so a non-integer token raises
    ``ValueError``, which argparse reports as an invalid argument.
    """
    bounds = list(map(int, range_string.split("-")))
    return [[lo, hi] for lo, hi in zip(bounds, bounds[1:])]


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-base", type=str, default=None)
    parser.add_argument("--image-folder", type=str, default="")
    parser.add_argument("--question-file", type=str, default="tables/question.jsonl")
    parser.add_argument("--answers-file", type=str, default="answer.jsonl")
    parser.add_argument("--answer-file-ver", type=str, default="0")
    parser.add_argument("--conv-mode", type=str, default="llava_v1")
    parser.add_argument("--num-chunks", type=int, default=1)
    parser.add_argument("--chunk-idx", type=int, default=0)

    # dataset / device
    parser.add_argument("--daset", type=str, default=None, required=True)
    parser.add_argument("--gpu", type=int, default=None)

    # question category ("none" selects the plain questions.jsonl)
    parser.add_argument("--category", type=str, default="")
    # store-true switches
    parser.add_argument("--save-attention-weights", action="store_true")
    parser.add_argument("--save-metadata", action="store_true")
    parser.add_argument("--save-vision-data", action="store_true")
    parser.add_argument("--save-feat", action="store_true")
    parser.add_argument("--save-attn", action="store_true")

    # logic
    # NOTE(review): eval_model reads the logic settings from the YAML config (cfgs), not from these flags.
    parser.add_argument("--logic", type=str, default=None)
    parser.add_argument("--dim-prospector", action="store_true")
    parser.add_argument("--head-fork", action="store_true")
    parser.add_argument("--var", type=int, default=0)
    parser.add_argument("--sink-rule", type=str, default="ours")
    parser.add_argument("--head-rule", type=str, default="ours")

    # ablation
    parser.add_argument("--abl-sink", action="store_true")
    parser.add_argument("--abl-head", action="store_true")

    # tweak
    parser.add_argument(
        "--layerwise-attn-tweak",
        type=parse_ranges,
        default=None,
        help="Dash-separated layer boundaries 'i-j-k-...', parsed into consecutive [start, end] pairs",
    )
    parser.add_argument(
        "--layerwise-attn-tweak-value",
        type=float,
        default=1.0,
        help="Float value associated with --layerwise-attn-tweak",
    )

    # exp config (eval_model cannot run without it)
    parser.add_argument("--exp-config", type=str, default=None, required=True, help="Path to the YAML experiment config")
    args = parser.parse_args()

    if args.category.lower() == "none":
        args.category = ""
    eval_model(args)
