import argparse
import torch
import os
import json
import pandas as pd
from tqdm import tqdm
import shortuuid
from dataclasses import dataclass, field
from typing import Dict, Optional, Sequence, List

from llava.utils import disable_torch_init
from llava.mm_utils import tokenizer_image_token, process_images, load_image_from_base64, get_model_name_from_path
from llava.visualize_utils import show_img_and_mask, show_img

from PIL import Image
import math
import numpy as np
from pycocotools import mask as mask_utils
import torch.nn.functional as F

@dataclass
class ModelArguments:
    """Configuration bundle describing which language model / vision tower to build.

    ``load_model`` in this file reads the language-model, ``region_*``,
    ``shuffle_patches``, ``log_debug`` and ``mm_use_im_*`` fields directly and
    forwards the whole object to ``initialize_vision_modules``.
    """

    # --- language model -------------------------------------------------
    model_name_or_path: Optional[str] = "Qwen/Qwen3-0.6B"
    version: Optional[str] = "qwen3"
    log_debug: bool = False

    # --- patch / region tokenisation -------------------------------------
    shuffle_patches: bool = False  # only forwarded when region_based is False
    region_based: bool = True
    region_sort: str = 'com_patch'
    region_source: str = 'clustering'
    region_cluster_args: str = 't=0.7,m=1'
    region_late_pe: Optional[bool] = False
    region_interpolate: Optional[str] = 'downsample_pad'
    region_pooling_method: Optional[str] = 'average'
    region_attn_args: Optional[str] = None  # path to a JSON file, opened by load_model
    region_extra: Optional[str] = 'none'
    region_filter: Optional[str] = 'none'
    region_expand_mult: Optional[int] = 1

    # --- training switches -----------------------------------------------
    freeze_backbone: bool = False
    tune_mm_mlp_adapter: bool = False

    # --- vision tower / multimodal projector -------------------------------
    vision_tower: Optional[str] = None
    # NOTE(review): the original note said "default to the last layer"; -2 as a
    # negative index would normally mean second-to-last — confirm in the tower.
    mm_vision_select_layer: Optional[int] = -2
    pretrain_mm_mlp_adapter: Optional[str] = None
    mm_projector_type: Optional[str] = 'linear'
    mm_use_im_start_end: bool = False
    mm_use_im_patch_token: bool = False
    mm_patch_merge_type: Optional[str] = 'flat'
    mm_vision_select_feature: Optional[str] = "patch"
    mm_force_imsize: Optional[int] = None
    mm_vision_feature_pe: str = 'none'

def load_model(model_args: "ModelArguments", use_flash_attn=False, image_aspect_ratio="pad"):
    """Instantiate the model described by ``model_args`` and prepare it for inference.

    The model class is resolved through ``llava.model.builder.get_model_class``,
    loaded with ``from_pretrained`` in fp16, its vision modules are initialised
    from ``model_args``, and the result is returned as ``model.half().cuda().eval()``.

    Args:
        model_args: configuration object; ``vision_tower`` must be set.
        use_flash_attn: when true, request the ``flash_attention_2`` attention
            implementation; otherwise pass ``None`` and let the library decide.
        image_aspect_ratio: value stored in ``model.config.image_aspect_ratio``.

    Returns:
        The loaded model, in half precision, on CUDA, in eval mode.

    Raises:
        ValueError: if ``model_args.vision_tower`` is ``None``. Previously this
            case fell through to an ``UnboundLocalError`` on the return line.
    """
    if model_args.vision_tower is None:
        raise ValueError("load_model requires model_args.vision_tower to be set")

    # Imported lazily, as in the original, so the builder is only needed on use.
    from llava.model.builder import get_model_class

    attn_implementation = 'flash_attention_2' if use_flash_attn else None

    model_class = get_model_class(model_args.model_name_or_path, model_args.region_based)
    model_configs = {
        "log_debug": model_args.log_debug
    }
    if model_args.region_based:
        # region_attn_args is a path to a JSON file; read it with a context
        # manager so the handle is closed deterministically.
        region_attn_args = None
        if model_args.region_attn_args is not None:
            with open(model_args.region_attn_args, "r") as f:
                region_attn_args = json.load(f)
        model_configs.update({
            "region_sort": model_args.region_sort,
            "region_source": model_args.region_source,
            "region_cluster_args": model_args.region_cluster_args,
            "region_late_pe": model_args.region_late_pe,
            "region_interpolate": model_args.region_interpolate,
            "region_pooling_method": model_args.region_pooling_method,
            "region_attn_args": region_attn_args,
            "region_extra": model_args.region_extra,
            "region_filter": model_args.region_filter,
            "region_expand_mult": model_args.region_expand_mult,
        })
    else:
        model_configs.update({
            "shuffle_patches": model_args.shuffle_patches,
        })

    model = model_class.from_pretrained(
        model_args.model_name_or_path,
        attn_implementation=attn_implementation,
        torch_dtype=torch.float16,
        **model_configs,
    )

    model.get_model().initialize_vision_modules(
        model_args=model_args,
    )

    model.config.image_aspect_ratio = image_aspect_ratio
    model.config.mm_use_im_start_end = model_args.mm_use_im_start_end
    model.config.mm_use_im_patch_token = model_args.mm_use_im_patch_token

    return model.half().cuda().eval()


def get_robust_pca(features: torch.Tensor, m: float = 2, remove_first_component=False):
    """Fit a 3-component PCA projection plus outlier-robust per-channel bounds.

    Adapted from Denoising-ViT: https://github.com/Jiawei-Yang/Denoising-ViT

    Args:
        features: 2-D tensor of shape (N, C).
        m: per-channel inlier threshold on ``|x - median| / median(|x - median|)``;
            projected values at or above it are ignored when computing bounds.
        remove_first_component: if true, rows whose min-max-normalised first
            component is >= 0.2 are excluded, and the projection is re-fit on
            the remaining rows.

    Returns:
        ``(reduction_mat, rgb_min, rgb_max)``: a (C, 3) projection matrix and two
        length-3 tensors with the dtype/device of ``reduction_mat``.

    Raises:
        ValueError: if ``features`` has fewer than 3 rows or columns.
    """
    assert len(features.shape) == 2, "features should be (N, C)"
    max_rank = min(features.shape)
    if max_rank < 3:
        raise ValueError(
            f"features must have at least 3 rows and 3 columns, got {tuple(features.shape)}"
        )

    # A single decomposition provides both the printed singular values and the
    # projection (first three right singular vectors); the rank is capped so
    # small inputs do not fail on the diagnostic alone.
    _, S, V = torch.pca_lowrank(features, q=min(5, max_rank), niter=20)
    print("Top PCA components: ", S)
    reduction_mat = V[:, :3].contiguous()
    colors = features @ reduction_mat

    if remove_first_component:
        colors_min = colors.min(dim=0).values
        colors_max = colors.max(dim=0).values
        tmp_colors = (colors - colors_min) / (colors_max - colors_min)
        fg_mask = tmp_colors[..., 0] < 0.2
        reduction_mat = torch.pca_lowrank(features[fg_mask], q=3, niter=20)[2]
        colors = features @ reduction_mat
    else:
        fg_mask = torch.ones_like(colors[:, 0]).bool()

    fg_colors = colors[fg_mask]
    d = torch.abs(fg_colors - torch.median(fg_colors, dim=0).values)
    mdev = torch.median(d, dim=0).values
    # mdev can be 0, giving inf/nan here; those compare False below and so
    # count as "no inliers", which routes to the fallback.
    s = d / mdev
    inliers = s < m

    if bool(inliers.any(dim=0).all()):
        rgb_min = torch.stack([fg_colors[inliers[:, c], c].min() for c in range(3)])
        rgb_max = torch.stack([fg_colors[inliers[:, c], c].max() for c in range(3)])
    else:
        # Some channel has no inliers: fall back to the plain per-channel range.
        # (The previous fallback used the global min/max for all three channels.)
        rgb_min = colors.min(dim=0).values
        rgb_max = colors.max(dim=0).values

    return reduction_mat, rgb_min.to(reduction_mat), rgb_max.to(reduction_mat)

def visualize_features(save_dir: str, features: dict, model, name="postproj"):
    """Render each feature map as a PCA false-colour image and save it.

    All feature tensors are stacked, projected to three channels with
    ``get_robust_pca``, rescaled to [0, 1], reshaped to the vision tower's patch
    grid, upsampled by the patch size and zero-padded to the tower's image size.
    One image per key is written to ``<save_dir>/pca-<name>-<key>.png``.

    Args:
        save_dir: directory the PNG files are written into.
        features: mapping from sample id to a feature tensor; all tensors must
            stack, and each is assumed to hold ``num_patches_per_side ** 2``
            rows — TODO confirm against the caller.
        model: object exposing ``get_vision_tower()`` with
            ``num_patches_per_side``, ``config.image_size`` and ``config.patch_size``.
        name: tag embedded in the output file names.

    Returns:
        None. Does nothing when ``features`` is empty.
    """
    if not features:
        # Nothing to draw; torch.stack would raise on an empty list.
        return

    features_id = list(features.keys())
    features_tensor = torch.stack([features[i] for i in features_id], dim=0).double()
    N = features_tensor.shape[0]
    A = features_tensor.view(-1, features_tensor.shape[-1])

    reduct_mat, color_min, color_max = get_robust_pca(A)
    colors = A @ reduct_mat
    # Guard against a constant channel (max == min), which would otherwise
    # divide by zero and fill that channel with NaN.
    color_range = color_max - color_min
    color_range = torch.where(color_range > 0, color_range, torch.ones_like(color_range))
    colors = ((colors - color_min) / color_range).clamp(0, 1)

    tower = model.get_vision_tower()
    pps = tower.num_patches_per_side
    img_size = tower.config.image_size
    patch_size = tower.config.patch_size
    H, W = img_size, img_size

    feature_lowrank = colors[:, :3].view(N, pps, pps, 3)
    visualized = feature_lowrank.permute(0, 3, 1, 2).contiguous()
    visualized = F.interpolate(visualized, scale_factor=patch_size, mode='nearest')
    # F.pad takes (left, right, top, bottom): width padding comes first.
    visualized = F.pad(visualized, (0, W-pps*patch_size, 0, H-pps*patch_size), mode='constant', value=0)
    visualized = visualized.permute(0, 2, 3, 1).contiguous()
    for feat_id, visual in zip(features_id, visualized):
        show_img(visual, save_path=os.path.join(save_dir, f"pca-{name}-{feat_id}.png"))

def _patches_to_pixels(masks: torch.Tensor, patch_size: int, pps: int, h: int, w: int) -> torch.Tensor:
    """Upsample patch-grid masks by ``patch_size`` and zero-pad them to (h, w).

    Assumes ``masks`` is (R, pps, pps) — TODO confirm for masks returned by the model.
    """
    masks = F.interpolate(masks.unsqueeze(0), scale_factor=patch_size, mode='nearest').squeeze(0)
    # F.pad takes (left, right, top, bottom): width padding comes first.
    return F.pad(masks, (0, w - pps * patch_size, 0, h - pps * patch_size), mode='constant', value=0)


def eval_model(model, image_processor, questions:pd.DataFrame, mask_folder: str, save_orig_img=False, save_file: str = "./playground/"):
    """Encode every question image and save region / feature-norm / PCA visualisations.

    For each row of ``questions`` the image is decoded, optional SAM masks are
    loaded from ``<mask_folder>/<index>.json``, the image is run through
    ``model.encode_images`` (to populate ``model._cached_masks``) and through the
    raw vision tower. Written under
    ``<dirname(save_file)>/visualization/<stem(save_file)>/``:
    ``regions-<idx>.png``, ``norm-<idx>.png``, optionally ``img-<idx>.png``, and
    finally the PCA images produced by ``visualize_features``.

    Args:
        model: multimodal model exposing ``encode_images``, ``get_vision_tower``,
            ``_cached_masks`` and, when masks are used, ``region_source`` /
            ``region_extra`` / ``sort_regions`` / ``add_extra_regions`` / ``filter_regions``.
        image_processor: processor handed to ``process_images``.
        questions: table with at least ``index`` and ``image`` (base64) columns.
        mask_folder: directory with one COCO-RLE mask JSON per index, or ``None``.
        save_orig_img: also save the plain image when true.
        save_file: path whose directory and stem determine the output folder.
    """
    save_file = os.path.expanduser(save_file)
    attn_path = os.path.join(
        os.path.dirname(save_file),
        "visualization",
        os.path.splitext(os.path.basename(save_file))[0],
    )
    os.makedirs(attn_path, exist_ok=True)
    all_visual_features_raw = {}

    # Vision-tower properties do not change between samples; read them once.
    # Defining pps here also fixes a NameError in the "passed" region branch,
    # where it used to be referenced without ever being assigned.
    vision_tower = model.get_vision_tower()
    pps = vision_tower.num_patches_per_side
    patch_size = vision_tower.config.patch_size
    select_feature = vision_tower.select_feature
    num_special_tokens = 1*("cls" in select_feature) + 4*("reg" in select_feature)
    if "sum" in select_feature:
        num_special_tokens += vision_tower.config.summary_len

    for _, row in tqdm(questions.iterrows(), total=len(questions)):
        idx = row['index']
        image = load_image_from_base64(row['image'])
        if mask_folder is not None:
            mask_file = f"{idx}.json"
            with open(os.path.join(mask_folder, mask_file), "r") as f:
                sam_masks = json.load(f)
            if len(sam_masks) > 0:
                sam_masks = mask_utils.decode([m["segmentation"] for m in sam_masks])
                sam_masks = np.moveaxis(sam_masks, -1, 0)
            else:
                sam_masks = np.zeros((0, 1, 1), dtype=np.uint8)
            if sam_masks.shape[0] == 0:
                print(f"No mask found for {idx}")
            sam_masks = torch.tensor(sam_masks)
        else:
            # Without a mask folder sam_masks used to be undefined, crashing
            # below. Use the same empty placeholder as the "no masks in file"
            # case. NOTE(review): assumes encode_images accepts an empty mask
            # tensor when no external masks are supplied — confirm.
            sam_masks = torch.zeros((0, 1, 1), dtype=torch.uint8)

        image_tensor = process_images([image], image_processor, model.config)[0]
        h, w = image_tensor.shape[-2:]

        with torch.inference_mode():
            batch = image_tensor.unsqueeze(0).half().cuda()
            # The return value is not used; the call fills model._cached_masks.
            model.encode_images(
                images=batch,
                sam_masks=[sam_masks.to(device='cuda', non_blocking=True)],
                return_masks=True,
            )
            _, _m, nonzero_masks = model._cached_masks
            model._cached_masks = None
            nonzero_masks = nonzero_masks[0]
            visual_features_raw = vision_tower(batch)[0]
            raw_features_norm = visual_features_raw.norm(dim=-1).cpu().numpy()
            raw_features_norm = raw_features_norm / raw_features_norm.max()

        if mask_folder is not None and model.region_source == "passed":
            # Externally supplied masks: replay the model's region pipeline so
            # the drawn masks line up with the regions that were kept.
            img = np.array(image)
            white_img = np.ones_like(img) * 255
            sam_masks = model.sort_regions(sam_masks)
            sam_masks = model.add_extra_regions(sam_masks)
            sam_masks = model.filter_regions(sam_masks)
            assert sam_masks.shape[0] == nonzero_masks.shape[0]
            sam_masks = sam_masks[nonzero_masks.cpu()]
        else:
            # Draw on the preprocessed image, rescaled to 0..255.
            img = image_tensor.permute(1,2,0).float()
            img = (img-img.min())/(img.max()-img.min()) * 255
            img = img.cpu().numpy().astype(np.uint8)
            white_img = np.ones_like(img) * 230
            if mask_folder is not None and (model.region_source == "clustering" or model.region_source.startswith("split")):
                region_grid = _m[0][nonzero_masks].to(device="cpu", dtype=torch.uint8)
            else:
                # One single-patch mask per patch.
                region_grid = torch.eye(pps**2, dtype=torch.uint8).view(-1, pps, pps)
            sam_masks = _patches_to_pixels(region_grid, patch_size, pps, h, w)

        # Skip the first region when an extra 'global' region is configured.
        sid = 1 if mask_folder is not None and 'global' in model.region_extra else 0
        show_img_and_mask(img, sam_masks.bool().cpu().numpy()[sid:], save_path=os.path.join(attn_path, f"regions-{idx}.png"))

        patchs = torch.eye(pps**2, dtype=torch.uint8).view(-1, pps, pps)
        patchs = _patches_to_pixels(patchs, patch_size, pps, h, w)
        show_img_and_mask(white_img, patchs.bool().cpu().numpy(), weights=raw_features_norm[num_special_tokens:],
            divide_weights_by_area=False, save_path=os.path.join(attn_path, f"norm-{idx}.png"))
        all_visual_features_raw[idx] = visual_features_raw[num_special_tokens:]
        if save_orig_img:
            show_img_and_mask(img, [], save_path=os.path.join(attn_path, f"img-{idx}.png"))

    visualize_features(attn_path, all_visual_features_raw, model, name="raw")

# Vision encoders to visualise: short name -> ModelArguments overrides
# (the script expands each entry with ModelArguments(**entry)).
encoders = {
    "clip": dict(
        vision_tower="openai/clip-vit-large-patch14-336",
        mm_force_imsize=None,
    ),
    "siglip2": dict(
        vision_tower="google/siglip2-so400m-patch14-384",
        mm_force_imsize=None,
    ),
    "aimv2L-448x": dict(
        vision_tower="apple/aimv2-large-patch14-448",
        mm_force_imsize=None,
    ),
    "dinov2L-518x": dict(
        vision_tower="timm/vit_large_patch14_dinov2.lvd142m",
        mm_force_imsize=None,
    ),
    "dinov2L-518x-reg": dict(
        vision_tower="timm/vit_large_patch14_reg4_dinov2.lvd142m",
        mm_force_imsize=None,
    ),
    "radio": dict(
        vision_tower="radio_v2.5-l",
        mm_force_imsize=384,
    ),
}

if __name__ == "__main__":
    # Script entry point: filter the question table by category, then run the
    # visualisation once per encoder listed in `encoders`.
    import ast  # only needed here, to parse --category safely

    parser = argparse.ArgumentParser()
    parser.add_argument("--image-folder", type=str, default="")
    parser.add_argument("--mask-folder", type=str, default=None)
    parser.add_argument("--question-file", type=str, default="tables/question.jsonl")
    parser.add_argument("--save-path", type=str, default="playground/")
    parser.add_argument("--category", type=str, default="['attribute_recognition', 'nature_relation', 'physical_relation']")
    args = parser.parse_args()

    questions = pd.read_table(os.path.expanduser(args.question_file))
    if args.category is not None and args.category!="all":
        # The value arrives from the command line: parse it as a Python literal
        # rather than eval()-ing arbitrary code. A value that is not a literal
        # is treated as a comma-separated list of category names.
        try:
            category = ast.literal_eval(args.category)
        except (ValueError, SyntaxError):
            category = [c.strip() for c in args.category.split(",") if c.strip()]
        if isinstance(category, str):
            category = [category]
        questions = questions[questions['category'].isin(category)]
        questions = questions[questions['index'] < 1000000]

    # Model
    disable_torch_init()
    for encoder, model_args in encoders.items():
        model = load_model(ModelArguments(**model_args))
        image_processor = model.get_vision_tower().image_processor
        save_path = os.path.join(args.save_path, f"{encoder}/quick.jsonl")
        eval_model(model, image_processor, questions, args.mask_folder, save_file=save_path)
        # Drop this encoder before the next one is loaded so two models are
        # not resident on the GPU at the same time.
        del model, image_processor
        torch.cuda.empty_cache()
