import argparse
from tracemalloc import start
from matplotlib import scale
import torch
import os
import json
import pandas as pd
from tqdm import tqdm
import shortuuid

from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import tokenizer_image_token, process_images, load_image_from_base64, get_model_name_from_path
from llava.visualize_utils import show_img_and_mask, show_img

from PIL import Image
import math
import numpy as np
from pycocotools import mask as mask_utils
import torch.nn.functional as F


# Option letters shown in the prompt; also used as the column names that
# get_options() looks up in each question row.
all_options = ['A', 'B', 'C', 'D']


def split_list(lst, n):
    """Split a sequence into at most ``n`` contiguous, roughly equal chunks.

    Every chunk holds ``ceil(len(lst) / n)`` items except possibly the last
    one, which holds the remainder. Because the chunk size is rounded up,
    fewer than ``n`` chunks may be returned (4 items with ``n=3`` give 2).

    Args:
        lst: Any object supporting ``len()`` and slicing.
        n: Requested number of chunks.

    Returns:
        A list of slices of ``lst``; an empty list when ``lst`` is empty.
    """
    total = len(lst)
    if total == 0:
        # A chunk size of 0 would make range() raise "arg 3 must not be zero".
        return []
    chunk_size = math.ceil(total / n)  # ceiling division
    return [lst[start:start + chunk_size] for start in range(0, total, chunk_size)]


def get_chunk(lst, n, k):
    """Return the ``k``-th (0-based) chunk produced by ``split_list(lst, n)``."""
    return split_list(lst, n)[k]


def is_none(value):
    """Return True when ``value`` stands for a missing entry.

    Missing means ``None``, a float NaN, or the strings "nan" / "none"
    compared case-insensitively. Everything else is considered present.
    """
    if value is None:
        return True
    # isinstance rather than an exact type comparison, so that subclasses of
    # float / str (e.g. numpy.float64 NaN) are recognized as well.
    if isinstance(value, float):
        return math.isnan(value)
    if isinstance(value, str):
        return value.lower() in ('nan', 'none')
    return False

def get_options(row, options):
    """Collect the leading run of option values that are present in ``row``.

    Each key in ``options`` is looked up in ``row`` in order; collection
    stops at the first value that ``is_none`` treats as missing, and any
    later keys are ignored.
    """
    present = []
    for key in options:
        candidate = row[key]
        if is_none(candidate):
            return present
        present.append(candidate)
    return present


def get_robust_pca(features: torch.Tensor, m: float = 2, remove_first_component=False):
    """Fit a 3-component PCA projection plus outlier-robust per-channel bounds.

    Adapted from Denoising-ViT: https://github.com/Jiawei-Yang/Denoising-ViT

    Args:
        features: 2-D tensor of shape (N, C).
        m: Outlier threshold. A projected value counts as an inlier of its
            channel when its absolute deviation from the channel median is
            below ``m`` times the channel's median absolute deviation.
        remove_first_component: When True, the projection is refit on only
            the rows whose min-max normalized first component is below 0.2,
            and the bounds are computed from those rows.

    Returns:
        ``(reduction_mat, rgb_min, rgb_max)``: the projection matrix returned
        by ``torch.pca_lowrank`` (its ``V`` output for ``q=3``) and two
        length-3 tensors with the lower/upper bound per projected channel,
        both converted to the dtype and device of ``reduction_mat``.
    """
    assert len(features.shape) == 2, "features should be (N, C)"
    reduction_mat = torch.pca_lowrank(features, q=3, niter=20)[2]
    # This second, wider decomposition is only used to report singular values.
    S = torch.pca_lowrank(features, q=5, niter=20)[1]
    print("Top PCA components: ", S)
    colors = features @ reduction_mat
    if remove_first_component:
        colors_min = colors.min(dim=0).values
        colors_max = colors.max(dim=0).values
        tmp_colors = (colors - colors_min) / (colors_max - colors_min)
        fg_mask = tmp_colors[..., 0] < 0.2
        reduction_mat = torch.pca_lowrank(features[fg_mask], q=3, niter=20)[2]
        colors = features @ reduction_mat
    else:
        fg_mask = torch.ones_like(colors[:, 0]).bool()

    fg_colors = colors[fg_mask]
    # Deviation from the per-channel median, in units of the median deviation.
    d = torch.abs(fg_colors - torch.median(fg_colors, dim=0).values)
    mdev = torch.median(d, dim=0).values
    s = d / mdev
    try:
        rins = fg_colors[s[:, 0] < m, 0]
        gins = fg_colors[s[:, 1] < m, 1]
        bins = fg_colors[s[:, 2] < m, 2]
        rgb_min = torch.tensor([rins.min(), gins.min(), bins.min()])
        rgb_max = torch.tensor([rins.max(), gins.max(), bins.max()])
    except (RuntimeError, IndexError):
        # min()/max() raise when a channel has no inliers (e.g. a zero median
        # deviation turns every ratio into inf/nan). Fall back to the global
        # range of all projected values, shared by the three channels.
        global_min = colors.min()
        global_max = colors.max()
        rgb_min = torch.tensor([global_min, global_min, global_min])
        rgb_max = torch.tensor([global_max, global_max, global_max])

    return reduction_mat, rgb_min.to(reduction_mat), rgb_max.to(reduction_mat)

def visualize_features(save_dir: str, features: dict, model, name="postproj"):
    """Save one PCA color map per entry of ``features``.

    All feature tensors are stacked, projected onto three PCA components
    with ``get_robust_pca``, normalized to [0, 1], arranged on the vision
    tower's patch grid, upsampled by the patch size, zero-padded up to the
    configured image size and written to ``<save_dir>/pca-<name>-<key>.png``.

    Args:
        save_dir: Directory the images are written into.
        features: Mapping from an identifier to a feature tensor. The tensors
            must stack to shape (N, M, C) with ``M == num_patches_per_side ** 2``.
        model: Object whose ``get_vision_tower()`` exposes
            ``num_patches_per_side``, ``config.image_size`` and
            ``config.patch_size``.
        name: Tag inserted into the output file names.

    Returns:
        None. With an empty ``features`` mapping nothing is written.
    """
    if not features:
        # torch.stack raises on an empty list; there is nothing to draw.
        print(f"No features to visualize for {name}, skipping.")
        return

    feature_ids = list(features.keys())
    features_tensor = torch.stack([features[key] for key in feature_ids], dim=0).double()
    num_images, _, feat_dim = features_tensor.shape
    flat = features_tensor.view(-1, feat_dim)
    reduct_mat, color_min, color_max = get_robust_pca(flat)
    colors = flat @ reduct_mat
    colors = ((colors - color_min) / (color_max - color_min)).clamp(0, 1)

    vision_tower = model.get_vision_tower()
    pps = vision_tower.num_patches_per_side
    img_size = vision_tower.config.image_size
    patch_size = vision_tower.config.patch_size
    # The patch grid may not tile the full image; pad the remainder with zeros
    # on the right and bottom.
    pad = img_size - pps * patch_size
    feature_lowrank = colors[:, :3].view(num_images, pps, pps, 3)
    visualized = feature_lowrank.permute(0, 3, 1, 2).contiguous()
    visualized = F.interpolate(visualized, scale_factor=patch_size, mode='nearest')
    visualized = F.pad(visualized, (0, pad, 0, pad), mode='constant', value=0)
    visualized = visualized.permute(0, 2, 3, 1).contiguous()
    for feature_id, visual in zip(feature_ids, visualized):
        show_img(visual, save_path=os.path.join(save_dir, f"pca-{name}-{feature_id}.png"))

def eval_model(args):
    """Answer multiple-choice questions and dump attention visualizations.

    For every question (and, with ``args.all_rounds``, every rotation of its
    options) the model generates an answer, one JSON line is appended to
    ``args.answers_file``, and region / attention / feature-norm overlays are
    saved under ``<answers dir>/visualization/<answers file stem>/``. After
    the loop the averaged top-10 normalized attention values are written to
    ``top10_attention.json`` in that directory and, when no mask folder is
    given, PCA visualizations of the collected visual features are saved.

    Args:
        args: Parsed command-line namespace; the fields read here are the
            ones declared by the argument parser at the bottom of this file.
    """
    # Model
    disable_torch_init()
    model_path = os.path.expanduser(args.model_path)
    model_name = get_model_name_from_path(model_path)
    model_kwargs = json.loads(args.model_kwargs) if args.model_kwargs else {}
    tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name, **model_kwargs)
    vision_tower = model.get_vision_tower()

    questions = pd.read_table(os.path.expanduser(args.question_file))
    if args.category is not None and args.category != "all":
        questions = questions[questions['category'] == args.category]
        questions = questions[questions['index'] < 1000000]
    questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
    answers_file = os.path.expanduser(args.answers_file)
    attn_path = os.path.join(os.path.dirname(answers_file), "visualization", os.path.splitext(os.path.basename(answers_file))[0])
    # Creating the visualization directory also creates the answers file's
    # parent directory, so the open() below cannot fail on a missing folder.
    os.makedirs(attn_path, exist_ok=True)
    attn_top_values = []
    all_visual_features_postproj = {}
    all_visual_features_raw = {}
    save_orig_img = model_path.endswith("radio/llava-v1.5-7b-layer2")

    if 'qwen3' in model_name.lower():
        args.conv_mode = 'qwen3'
        print(f'It seems that this is a qwen3 model, auto switching to {args.conv_mode}.')
    if 'plain' in model_name and 'finetune' not in model_name.lower() and 'mmtag' not in args.conv_mode:
        args.conv_mode = args.conv_mode + '_mmtag'
        print(f'It seems that this is a plain model, but it is not using a mmtag prompt, auto switching to {args.conv_mode}.')

    # The context manager guarantees the answers file is closed even when a
    # sample raises in the middle of the run.
    with open(answers_file, "w") as ans_file:
        for _, row in tqdm(questions.iterrows(), total=len(questions)):
            options = get_options(row, all_options)
            cur_option_char = all_options[:len(options)]

            num_rounds = len(options) if args.all_rounds else 1

            for round_idx in range(num_rounds):
                idx = row['index']
                question = row['question']
                hint = row['hint']
                image = load_image_from_base64(row['image'])
                if args.mask_folder is not None:
                    mask_file = f"{idx}.json"
                    with open(os.path.join(args.mask_folder, mask_file), "r") as f:
                        sam_masks = json.load(f)
                    if len(sam_masks) > 0:
                        sam_masks = mask_utils.decode([m["segmentation"] for m in sam_masks])
                        # Move the mask axis from last to first.
                        sam_masks = np.moveaxis(sam_masks, -1, 0)
                    else:
                        sam_masks = np.zeros((0, 1, 1), dtype=np.uint8)
                    if sam_masks.shape[0] == 0:
                        print(f"No mask found for {idx}")
                    sam_masks = torch.tensor(sam_masks)

                if not is_none(hint):
                    question = hint + '\n' + question
                for option_char, option in zip(all_options[:len(options)], options):
                    question = question + '\n' + option_char + '. ' + option
                qs = cur_prompt = question
                if model.config.mm_use_im_start_end:
                    qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
                else:
                    qs = DEFAULT_IMAGE_TOKEN + '\n' + qs

                if args.single_pred_prompt:
                    if args.lang == 'cn':
                        qs = qs + '\n' + "请直接回答选项字母。"
                    else:
                        qs = qs + '\n' + "Answer with the option's letter from the given choices directly."

                conv = conv_templates[args.conv_mode].copy()
                conv.append_message(conv.roles[0], qs)
                conv.append_message(conv.roles[1], None)
                prompt = conv.get_prompt()

                input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()

                image_tensor = process_images([image], image_processor, model.config)[0]

                with torch.inference_mode():
                    # Convert and upload the image once; it is reused by
                    # generation, the attention pass and feature extraction.
                    images = image_tensor.unsqueeze(0).half().cuda()
                    input_args = dict(
                        images=images,
                        image_sizes=[image.size],
                        do_sample=args.temperature > 0,
                        temperature=args.temperature if args.temperature > 0 else None,
                        top_p=args.top_p,
                        num_beams=args.num_beams,
                        max_new_tokens=1024,
                        use_cache=True
                    )
                    if args.mask_folder is not None:
                        input_args["sam_masks"] = [sam_masks.to(device='cuda', non_blocking=True)]

                    output_ids = model.generate(input_ids, **input_args)

                    # Second pass over prompt + answer to obtain attention maps.
                    all_ids = torch.cat([input_ids, output_ids], dim=1)
                    all_args = dict(
                        output_attentions=True,
                        images=images,
                        image_sizes=[image.size],
                    )
                    if args.mask_folder is not None:
                        all_args["sam_masks"] = [sam_masks.to(device='cuda', non_blocking=True)]
                        all_args["return_masks"] = True
                    all_outputs = model.forward(all_ids, **all_args)
                    if args.mask_folder is not None:
                        visual_features, _m, nonzero_masks = model._cached_masks
                        model._cached_masks = None
                        visual_features = visual_features[0]
                        nonzero_masks = nonzero_masks[0]
                    else:
                        visual_features = model.encode_images(images)[0]
                    visual_features_raw = vision_tower(images)[0]
                    # Per-token feature norm, scaled so the largest is 1.
                    visual_features_norm = visual_features.norm(dim=-1).cpu().numpy()
                    visual_features_norm = visual_features_norm / visual_features_norm.max()

                if args.mask_folder is not None and model.region_source == "passed":
                    # Regions come from the mask files: draw on the original image.
                    img = np.array(image)
                    white_img = np.ones_like(img) * 255
                    sam_masks = model.sort_regions(sam_masks)
                    sam_masks = model.add_extra_regions(sam_masks)
                    sam_masks = model.filter_regions(sam_masks)
                    assert sam_masks.shape[0] == nonzero_masks.shape[0]
                    sam_masks = sam_masks[nonzero_masks.cpu()]
                else:
                    # Regions live on the patch grid: draw on the preprocessed
                    # image, rescaled to the 0..255 range.
                    img = image_tensor.permute(1, 2, 0).float()
                    img = (img - img.min()) / (img.max() - img.min()) * 255
                    img = img.cpu().numpy().astype(np.uint8)
                    white_img = np.ones_like(img) * 230
                    pps = vision_tower.num_patches_per_side
                    if args.mask_folder is not None and (model.region_source == "clustering" or model.region_source.startswith("split")):
                        sam_masks = _m[0][nonzero_masks].to(device="cpu", dtype=torch.uint8)
                    else:
                        # One mask per patch: each row of the identity selects one cell.
                        sam_masks = torch.eye(pps**2, dtype=torch.uint8).view(-1, pps, pps)
                    patch_size = vision_tower.config.patch_size
                    h, w = image_tensor.shape[-2:]
                    sam_masks = F.interpolate(sam_masks.unsqueeze(0), scale_factor=patch_size, mode='nearest').squeeze(0)
                    sam_masks = F.pad(sam_masks, (0, h - pps * patch_size, 0, w - pps * patch_size), mode='constant', value=0)

                # Locate the first region token: the image placeholder position
                # shifted past the special tokens selected by the vision tower.
                img_token_pos = torch.where(all_ids[0] == IMAGE_TOKEN_INDEX)[0][0]
                select_feature = vision_tower.select_feature
                num_special_tokens = 1 * ("cls" in select_feature) + 4 * ("reg" in select_feature)
                if "sum" in select_feature:
                    num_special_tokens += vision_tower.config.summary_len
                img_token_pos += num_special_tokens
                # Average the attention maps over every leading dimension
                # (presumably layers, batch and heads — confirm against the model).
                all_attentions = torch.stack(all_outputs['attentions'], dim=0)
                avg_attentions = all_attentions.view(-1, all_attentions.shape[-2], all_attentions.shape[-1]).mean(dim=0)
                # Rows: the last output_ids.shape[1] positions before the final
                # one; columns: the region tokens.
                answer_mask_attentions = avg_attentions[-output_ids.shape[1] - 1:-1, img_token_pos:img_token_pos + len(sam_masks)]
                # The single placeholder must have expanded into the special
                # tokens plus one token per region.
                assert all_ids.shape[1] - 1 + len(sam_masks) + num_special_tokens == avg_attentions.shape[0]
                sid = 1 if args.mask_folder is not None and 'global' in model.region_extra else 0
                if sid == 1:  # replace the global mask by a small "G" glyph in the top-left corner
                    sam_masks[0] = 0
                    G_pattern = torch.FloatTensor([[1, 1, 1, 1], [1, 0, 0, 0], [1, 0, 1, 1], [1, 0, 0, 1], [1, 1, 1, 1]])
                    G_pattern = F.interpolate(G_pattern.unsqueeze(0).unsqueeze(0), scale_factor=2, mode='nearest').squeeze(0).squeeze(0).bool()
                    sam_masks[0, :G_pattern.shape[0], :G_pattern.shape[1]] = G_pattern
                # Only the first of the selected rows is visualized.
                answer_mask_attentions = answer_mask_attentions[0].cpu().numpy()
                # Scale so the strongest non-global region gets weight 0.5.
                answer_mask_attentions_reweighted = answer_mask_attentions / answer_mask_attentions[sid:].max() * 0.5
                # Top-10 attention values, normalized so uniform attention equals
                # 1, in descending order and zero-padded to length 10.
                attentions_top10 = np.sort(answer_mask_attentions / answer_mask_attentions.sum() * len(answer_mask_attentions))[-10:][::-1]
                if len(attentions_top10) < 10:
                    attentions_top10 = np.pad(attentions_top10, (0, 10 - len(attentions_top10)), constant_values=0)
                attn_top_values.append(attentions_top10)
                # Hide regions whose weight is below the mean of the non-global regions.
                answer_mask_attentions_reweighted[answer_mask_attentions_reweighted < answer_mask_attentions_reweighted[sid:].mean()] = 0
                include_global_mask = True
                sidx = 0 if include_global_mask else sid
                region_masks = sam_masks.bool().cpu().numpy()
                if args.mask_folder is not None:
                    show_img_and_mask(img, region_masks[sid:], save_path=os.path.join(attn_path, f"regions-{idx}.png"))
                show_img_and_mask(img, region_masks[sidx:], weights=answer_mask_attentions_reweighted.clip(0, 1)[sidx:],
                    divide_weights_by_area=False, save_path=os.path.join(attn_path, f"attention-{idx}.png"))
                show_img_and_mask(white_img, region_masks[sidx:], weights=visual_features_norm[sidx + num_special_tokens:],
                    divide_weights_by_area=False, save_path=os.path.join(attn_path, f"norm-{idx}.png"))
                all_visual_features_postproj[idx] = visual_features[sid + num_special_tokens:]
                all_visual_features_raw[idx] = visual_features_raw[num_special_tokens:]
                if save_orig_img:
                    show_img_and_mask(img, [], save_path=os.path.join(attn_path, f"img-{idx}.png"))

                outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()

                ans_id = shortuuid.uuid()
                ans_file.write(json.dumps({"question_id": idx,
                                        "round_id": round_idx,
                                        "prompt": cur_prompt,
                                        "text": outputs,
                                        "options": options,
                                        "option_char": cur_option_char,
                                        "answer_id": ans_id,
                                        "model_id": model_name,
                                        "metadata": {}}) + "\n")
                ans_file.flush()

                # rotate options
                options = options[1:] + options[:1]
                cur_option_char = cur_option_char[1:] + cur_option_char[:1]

    if args.mask_folder is not None:
        print("PCA feature visualization is not implemented for region-based features, skipping.")
    else:
        visualize_features(attn_path, all_visual_features_postproj, model, name="postproj")
        visualize_features(attn_path, all_visual_features_raw, model, name="raw")
    attn_top_values = np.array(attn_top_values)
    mean_top_values = attn_top_values.mean(axis=0)
    print("avg top-10 attention values: ", mean_top_values)
    with open(os.path.join(attn_path, "top10_attention.json"), "w") as f:
        json.dump(mean_top_values.tolist(), f)

if __name__ == "__main__":
    # Command-line interface, declared as (flag, add_argument kwargs) pairs
    # in the order they should appear in --help.
    cli_spec = (
        ("--model-path", dict(type=str, default="facebook/opt-350m")),
        ("--model-base", dict(type=str, default=None)),
        ("--model-kwargs", dict(type=str, default=None)),
        ("--image-folder", dict(type=str, default="")),
        ("--mask-folder", dict(type=str, default=None)),
        ("--question-file", dict(type=str, default="tables/question.jsonl")),
        ("--answers-file", dict(type=str, default="answer.jsonl")),
        ("--category", dict(type=str, default=None)),
        ("--conv-mode", dict(type=str, default="llava_v1")),
        ("--num-chunks", dict(type=int, default=1)),
        ("--chunk-idx", dict(type=int, default=0)),
        ("--temperature", dict(type=float, default=0.2)),
        ("--top_p", dict(type=float, default=None)),
        ("--num_beams", dict(type=int, default=1)),
        ("--all-rounds", dict(action="store_true")),
        ("--single-pred-prompt", dict(action="store_true")),
        ("--lang", dict(type=str, default="en")),
    )
    parser = argparse.ArgumentParser()
    for cli_flag, cli_kwargs in cli_spec:
        parser.add_argument(cli_flag, **cli_kwargs)
    args = parser.parse_args()

    eval_model(args)
