'''Llava inference'''

import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration
from PIL import Image
from .vlm_saliency_analysis import bbox_mask, extract_all_bbox, bbox_areas, parse_image_id, get_saliency_matrix
from pathlib import Path
import gc
import random
import pandas as pd
import numpy as np
import torch
from .patch_inference import qa_pairs
from .G_and_A_case_study import build_image_dict
import re
from matplotlib.colors import ListedColormap
from typing import Union
import matplotlib.pyplot as plt


def visualize_case(bbox_ls, image_tensor, mask_inside, patch_values, figname: str,
                   save_dir: str, num_patches=16, vmin=None, vmax=None,
                   title=None, category=None):
    """Save a side-by-side figure of a per-patch map and its image.

    The left panel shows ``patch_values`` reshaped to a
    ``num_patches x num_patches`` grid using the ``hot`` colormap. The
    right panel shows ``image_tensor``. The figure is written to
    ``{save_dir}/{figname}.png``, and ``save_dir`` is created if missing.

    Args:
        bbox_ls: Bounding boxes of the image. Accepted but currently not drawn.
        image_tensor: Anything ``plt.imshow`` accepts (e.g. a PIL image).
        mask_inside: Per-patch mask. Accepted but currently not drawn.
        patch_values: Tensor with ``num_patches ** 2`` elements. It may live
            on any device; it is detached and copied to CPU before plotting.
        figname: File stem of the saved PNG.
        save_dir: Output directory.
        num_patches: Side length of the square patch grid.
        vmin, vmax: Optional color limits for the heatmap. An omitted
            ``vmin`` defaults to ``1e-4`` and an omitted ``vmax`` defaults
            to the map's maximum. Values below ``vmin`` render white.
        title: Optional figure title.
        category: Optional label, prepended to ``title`` when given.
    """
    # detach().cpu() so CUDA / grad-tracking tensors can be converted to NumPy.
    patch_map = (patch_values.detach().cpu().float()
                 .reshape(num_patches, num_patches).numpy())

    fig, axes = plt.subplots(1, 2, figsize=(14, 6))

    # Values below vmin (and NaNs) are drawn white instead of black.
    cmap = plt.get_cmap("hot").copy()
    cmap.set_under("white")
    cmap.set_bad("white")

    lo = 1e-4 if vmin is None else vmin
    if vmax is None:
        # Keep vmax >= vmin so an all-zero map does not break the normalization.
        hi = max(float(patch_map.max()), lo)
    else:
        hi = vmax
    axes[0].imshow(patch_map, cmap=cmap, vmin=lo, vmax=hi,
                   extent=(0, num_patches, num_patches, 0))

    axes[1].imshow(image_tensor)

    axes[0].axis('off')
    axes[1].axis('off')

    full_title = " ".join(str(part) for part in (category, title) if part)
    if full_title:
        fig.suptitle(full_title)

    fig.tight_layout()
    out_dir = Path(save_dir)
    out_dir.mkdir(exist_ok=True, parents=True)
    fig.savefig(out_dir / f'{figname}.png')
    plt.close(fig)


def overlay_heatmap(orig_img, patch_values, num_patches: int, bbox_ls: list,
                    alpha=0.5, cmap="hot", savepath=None, title=None,
                    vmin: Union[None, float] = None, vmax: Union[None, float] = None):
    """Draw a per-patch heatmap semi-transparently over an image.

    The ``num_patches x num_patches`` map is stretched over the full image
    extent and drawn with a colorbar.

    Args:
        orig_img: Image accepted by ``np.asarray`` and ``plt.imshow``.
            Both RGB and single-channel images are supported.
        patch_values: Tensor with ``num_patches ** 2`` elements. It may live
            on any device.
        num_patches: Side length of the square patch grid.
        bbox_ls: Bounding boxes. Accepted but currently not drawn.
        alpha: Opacity of the heatmap layer.
        cmap: Name of the matplotlib colormap. Its lowest color is replaced
            by white.
        savepath: Where to save the figure. When ``None`` the figure is
            shown with ``plt.show()`` instead. Parent directories are
            created as needed.
        title: Optional axes title.
        vmin, vmax: Optional color limits. Each is applied independently;
            ``None`` lets matplotlib autoscale that end.
    """
    # shape[:2] works for both (H, W) and (H, W, C) images.
    H, W = np.asarray(orig_img).shape[:2]
    patch_map = (patch_values.detach().cpu().float()
                 .reshape(num_patches, num_patches).numpy())

    # 1000-step copy of the colormap whose lowest entry is white, so the
    # lowest-valued patches render white instead of the colormap's first color.
    colors = plt.get_cmap(cmap)(np.linspace(0, 1, 1000))
    colors[0] = [1, 1, 1, 1]
    white_floor_cmap = ListedColormap(colors)

    fig, ax = plt.subplots(figsize=(10, 8))
    ax.imshow(orig_img)
    im = ax.imshow(patch_map, cmap=white_floor_cmap, vmin=vmin, vmax=vmax,
                   alpha=alpha, extent=[0, W, H, 0])

    ax.axis("off")
    if title:
        ax.set_title(title)
    cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    cbar.ax.tick_params(labelsize=14)

    if savepath is None:
        plt.show()
    else:
        Path(savepath).parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(savepath, bbox_inches="tight")
    plt.close(fig)


def return_conversation(prompt):
    """Build a one-turn chat containing *prompt* followed by an image slot.

    The result is the message list handed to
    ``processor.apply_chat_template``. A single ``"user"`` turn is returned
    whose content is a text part (``prompt``) and then an ``"image"``
    placeholder. Fresh objects are created on every call.
    """
    text_part = {"type": "text", "text": prompt}
    image_part = {"type": "image"}
    user_turn = dict(role="user", content=[text_part, image_part])
    return [user_turn]


def processor2img(inputs, name, image_mean=None, image_std=None):
    """Save the first image of a processor batch to ``{name}.jpg``.

    Args:
        inputs: Processor output with a ``"pixel_values"`` tensor. Its first
            element is assumed to be channel-first ``(3, H, W)``, since it is
            transposed to ``(H, W, 3)`` below.
        name: Output path without the ``.jpg`` extension.
        image_mean, image_std: Optional per-channel statistics, e.g.
            ``processor.image_processor.image_mean`` / ``image_std``. When
            given, ``pixel * std + mean`` is applied first to undo the
            processor's mean/std normalization. Without them, values are
            assumed to already lie in ``[0, 1]``, which is the original
            behavior.
    """
    # detach()/float() make the NumPy conversion safe for grad-tracking or
    # half-precision tensors.
    arr = inputs["pixel_values"][0].detach().float().cpu().numpy()
    arr = arr.transpose(1, 2, 0)  # (3, H, W) -> (H, W, 3)
    if image_std is not None:
        arr = arr * np.asarray(image_std, dtype=np.float32)
    if image_mean is not None:
        arr = arr + np.asarray(image_mean, dtype=np.float32)
    arr = (arr * 255).clip(0, 255).astype(np.uint8)
    Image.fromarray(arr).save(f"{name}.jpg")


def llava_case(image_dirs, maskout='all', N=3, device=None,
               save_dir='inference/llava_example'):
    '''Plot LLaVA heads whose image saliency is concentrated inside the bbox.

    For every directory in ``image_dirs``, the part of the directory name
    before the first ``_`` is used as the target token. Up to ``N``
    (prompt, image) pairs are sampled from ``qa_pairs`` and the directory's
    ``*.jpg`` files.

    For each pair, one forward pass is run with ``output_attentions=True``.
    ``get_saliency_matrix`` turns the outputs into a saliency tensor, which
    is indexed as ``[layer, head, -1, image_token_positions]``; the
    ``(layer, head)`` meaning of the first two axes comes from the output
    file names. A (layer, head) is plotted when both of these hold:

    1. Its most salient image patch holds more than half of that head's
       summed image saliency.
    2. That patch is inside the mask returned by ``bbox_mask``.

    Args:
        image_dirs: Iterable of directories containing ``*.jpg`` images.
        maskout: ``'all'`` passes every bbox of the image
            (``img2bbox[image_id]``) to ``bbox_mask``. Any other value
            passes ``None``.
        N: Number of prompt/image pairs per directory. It is capped at the
            number of prompts and images available.
        device: Torch device. Defaults to CUDA when available, else CPU.
        save_dir: Directory that receives the per-head figures.

    Raises:
        ValueError: If the tokenizer yields fewer than two ids for a
            directory's target token.
    '''
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    df = pd.read_csv('data/validation-annotations-bbox.csv')
    img2bbox = extract_all_bbox(df)

    model = LlavaForConditionalGeneration.from_pretrained(
        'model/llava/', torch_dtype=torch.float16, low_cpu_mem_usage=True).to(device)
    processor = AutoProcessor.from_pretrained('model/llava/')
    model.eval()
    # Deliberately no torch.no_grad() around the forward pass:
    # get_saliency_matrix (defined elsewhere) may need the autograd graph.
    # TODO confirm before wrapping it.

    # Loop-invariant: the id of the "<image>" placeholder token.
    img_token_id = processor.tokenizer("<image>")['input_ids'][1]

    for image_dir in image_dirs:
        print('dealing with image ', image_dir)
        target_token = Path(image_dir).name.split('_')[0]
        # Sorted so the seeded random.sample picks the same files on any filesystem.
        image_ls = sorted(Path(image_dir).glob("*.jpg"))
        n_pairs = min(N, len(qa_pairs), len(image_ls))
        if n_pairs == 0:
            print('no images or prompts to sample for ', image_dir)
            continue

        target_token_ids = processor.tokenizer(target_token)['input_ids']
        if len(target_token_ids) < 2:
            raise ValueError(
                f'tokenizer returned fewer than 2 ids for target {target_token!r}')
        # Keep the id at position 1 (position 0 is presumably the BOS token).
        target_token_id = target_token_ids[1:2]

        # Sampling indices makes the same RNG draws as sampling qa_pairs
        # directly, but keeps prompt ids correct even for duplicate prompts.
        prompt_ids = random.sample(range(len(qa_pairs)), n_pairs)
        image_paths = random.sample(image_ls, n_pairs)

        for i, (prompt_id, img_path) in enumerate(zip(prompt_ids, image_paths)):
            question = qa_pairs[prompt_id]
            image_id, _ = parse_image_id(str(img_path))
            with Image.open(img_path) as raw_image:
                image = raw_image.convert("RGB")
            # Downsized copy used only for the right-hand panel of the figure.
            processed_image = image.resize((224, 224))

            chat_prompt = processor.apply_chat_template(
                return_conversation(question.strip()), add_generation_prompt=True)
            inputs = processor(images=image, text=chat_prompt,
                               size=(336, 336), do_center_crop=False,
                               return_tensors="pt").to(device, torch.float16)
            outputs = model(**inputs, output_attentions=True, return_dict=True)
            saliency = get_saliency_matrix(outputs, target_token_id)
            # The attention maps are large; drop them as soon as they are consumed.
            del outputs

            img_mask = (inputs['input_ids'] == img_token_id).squeeze().cpu()
            patch_values = saliency[:, :, -1, img_mask]

            maskout_bbox = img2bbox[image_id] if maskout == 'all' else None
            mask_inside, _ = bbox_mask(
                str(img_path), df, maskout=maskout_bbox, num_patches=24, transpose=True)
            bbox_ls = bbox_areas(df, str(img_path))

            top_vals, top_idx = torch.max(patch_values, dim=-1)
            concentrated = top_vals * 2 > torch.sum(patch_values, dim=-1)
            in_bbox = mask_inside[top_idx]
            selected = concentrated & in_bbox

            print(i, torch.sum(selected))
            if selected.any():
                layers, heads = torch.where(selected)
                for l, h in zip(layers.tolist(), heads.tolist()):
                    visualize_case(bbox_ls, processed_image,
                                   mask_inside, patch_values[l, h].detach().cpu(),
                                   figname=f'{target_token}_img_{image_id}_text{prompt_id}_L{l+1}H{h+1}',
                                   save_dir=save_dir,
                                   num_patches=24)

            # Runs on every iteration (the original `continue` skipped it).
            del saliency, patch_values
            gc.collect()
            torch.cuda.empty_cache()


if __name__ == "__main__":
    import argparse

    # Defined at module level; code in this file may rely on a global `device`.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)
    random.seed(42)
    torch.manual_seed(42)

    parser = argparse.ArgumentParser(
        description='analyze attention flow in vlm')
    parser.add_argument('--mask_out', type=str, default='all',
                        choices=['all', 'current_category'])
    parser.add_argument('--num_samples', type=int, default=3,
                        help='prompt/image pairs sampled per image directory')
    args = parser.parse_args()

    print('Running with args: ', args)
    # Only sub-directories are kept; plain files would yield no images.
    # The list is sorted because iterdir() order is filesystem-dependent,
    # and the seeded random draws are consumed in this order.
    image_dirs = sorted(
        path for path in Path('inference/images').iterdir()
        if path.is_dir() and 'inpaint' not in path.name)

    llava_case(image_dirs=image_dirs, maskout=args.mask_out, N=args.num_samples)
