import numpy as np
import pandas as pd
import torch
from train_test.trabank_vl.tokenizer.word_tokenizer import CustomWordTokenizer
from train_test.trabank_vl.model import LlavaGPTForCausalLM, LlavaGPT2Config
from PIL import Image
from .vlm_saliency_analysis import get_saliency_matrix, bbox_mask, extract_all_bbox, bbox_areas, parse_image_id
from pathlib import Path
import gc
import random
import json
from collections import defaultdict
from typing import Tuple


# Question prompts paired with every image in run_inference. The position of
# a prompt in this list is used as `prompt_id` in the names of the saved
# result files, so reordering or editing entries changes which cached files
# match. Despite the name, the list holds questions only (no answers).
# Every prompt starts with a newline character.
qa_pairs = [
    "\nwhat is it ?",
    "\nwhat do you call this ?",
    "\ncan you name this object ?",
    "\nwhat's this called ?",
    "\nwhat this thing is ?",
    "\nwhat would you name this ?",
    "\nwhat's the name of this item ?",
    "\nhow do you identify this ?",
    "\nwhat do we have here ?",
    "\nhow do you call this object ?"
]


def anchor_analysis(saliency, anchor_mask, aggregate_layer: int = -1) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:
    '''Return the word2anchor, anchor2end and word2word saliency flow of every head.

    Args:
        saliency: 4-D tensor unpacked as (L, H, S, S); entry [l, h, i, j] is
            read as the flow at row i / column j of layer l, head h.
        anchor_mask: boolean mask of length S selecting the anchor positions.
        aggregate_layer: when >= 0, additionally pick the head of that layer
            whose anchor2end flow is the largest fraction of the summed
            saliency of the last row.

    Returns:
        (word2anchor, anchor2end, word2word, aggregate_head). The three flow
        tensors have shape (L, H) and hold the mean saliency over the
        corresponding (row, column) mask. aggregate_head is the selected head
        index, or -1 when aggregate_layer < 0.

    Note:
        If anchor_mask selects nothing, anchor2end is the mean of an empty
        selection and therefore NaN; callers are expected to filter such
        inputs beforehand.
    '''
    L, H, S, _ = saliency.shape
    # mask of anchor2end: last row, anchor columns
    mask1 = torch.zeros((S, S), dtype=torch.bool)
    mask1[-1, anchor_mask] = True

    # mask of word2anchor: anchor rows, strictly below the diagonal
    mask2 = torch.zeros((S, S), dtype=torch.bool)
    mask2[anchor_mask, :] = True
    mask2 = torch.tril(mask2, diagonal=-1)

    # mask of word2word: every remaining strictly-lower-triangular entry
    mask3 = torch.tril(torch.ones((S, S), dtype=torch.bool),
                       diagonal=-1) & (~(mask1 | mask2))
    if torch.sum(mask2) > 0:
        word2anchor = torch.mean(saliency[:, :, mask2], dim=-1)
    else:
        # No word2anchor entry exists (e.g. the only anchor is position 0).
        # Return zeros with the same (L, H) shape, dtype and device as the
        # other branch so callers can reduce over the head axis uniformly.
        word2anchor = torch.zeros(
            (L, H), dtype=saliency.dtype, device=saliency.device)
    # interesting thing: if we use torch.sum() then divide with torch.sum(mask1)
    # there will be ~1e-4 error
    anchor2end = torch.mean(saliency[:, :, mask1], dim=-1)
    word2word = torch.mean(saliency[:, :, mask3], dim=-1)
    assert anchor2end.shape == (L, H)

    if aggregate_layer >= 0:
        anchor2end_agg_layer = anchor2end[aggregate_layer, :]
        # total saliency of the last row, per head, used as the normaliser
        saliency_whole_line = torch.sum(
            saliency[aggregate_layer, :, -1, :], dim=-1)
        aggregate_idx = torch.argmax(anchor2end_agg_layer/saliency_whole_line)
        return word2anchor, anchor2end, word2word, int(aggregate_idx.item())
    else:
        return word2anchor, anchor2end, word2word, -1


def token2end_analysis(saliency, mask_inside_bbox, mask_outside_bbox, mask_image):
    '''Average the saliency of the last row over four groups of columns.

    `saliency` is indexed as [:, -1, mask], i.e. only the final row of each
    leading slice is used. The four groups are: columns inside the bbox,
    columns outside the bbox, image columns covered by neither bbox mask,
    and all non-image columns.

    Returns (attn_in, attn_out, attn_other_image, attn_text), each reduced
    over the selected columns. Raises ValueError when any of the first three
    contains NaN (which happens when the corresponding mask is empty).
    '''
    def last_row_mean(column_mask: torch.Tensor):
        '''mean of the selected columns of the last row'''
        return saliency[:, -1, column_mask].mean(dim=-1)

    attn_in = last_row_mean(mask_inside_bbox)
    attn_out = last_row_mean(mask_outside_bbox)

    bbox_covered = mask_outside_bbox | mask_inside_bbox
    if (bbox_covered == mask_image).all():
        # the two bbox masks already account for every image column
        attn_other_image = torch.zeros_like(attn_in)
    else:
        attn_other_image = last_row_mean(mask_image & ~bbox_covered)

    attn_text = last_row_mean(~mask_image)

    image_flows = (attn_in, attn_out, attn_other_image)
    if any(torch.isnan(flow).any() for flow in image_flows):
        raise ValueError('Nan in token2end analysis')

    return attn_in, attn_out, attn_other_image, attn_text


def run_inference(model, tokenizer, ckpt: int, image_dirs: list[Path], maskout='all',
                  random_seed=None, aggregate_layer: int = -1, batch_size=10):
    '''Compute and save saliency-flow statistics for every image/prompt pair.

    For each directory in `image_dirs` the category is the part of the
    directory name before the first underscore. Every ``*.jpg`` in the
    directory is combined with every prompt in `qa_pairs`; the model is run,
    a saliency matrix is obtained from `get_saliency_matrix`, and the
    per-layer flow statistics are written to
    ``saliency_correct/stat/{random_seed}/ckpt_{ckpt}/{category}/``.
    Pairs whose result file already exists are skipped (but still counted).

    Args:
        model: the model to analyse; it is switched to eval mode.
        tokenizer: tokenizer used for both the category word and the prompts.
        ckpt: checkpoint step, only used to build output paths.
        image_dirs: directories containing the images of one category each.
        maskout: when 'all', the bboxes registered for the image in
            `extract_all_bbox(df)` are passed to `bbox_mask` as `maskout`;
            otherwise None is passed.
        random_seed: run name, only used to build output paths.
        aggregate_layer: when >= 0, the saliency of that layer is saved as
            well and the selected head per image/prompt is dumped to
            ``inference/GA_{random_seed}_vlm.json``.
        batch_size: unused; kept for interface compatibility.
    '''
    model.eval()
    # Take the device from the model instead of relying on a module-level
    # global that only exists when this file is executed as a script.
    device = next(model.parameters()).device
    df = pd.read_csv('data/validation-annotations-bbox.csv')
    img2bbox = extract_all_bbox(df)
    total_fig_cnt = 0
    aggregate_head_dict = defaultdict(lambda: defaultdict(int))

    for image_dir in image_dirs:
        print('dealing with image ', image_dir)
        # Path.name is separator-agnostic, unlike splitting on '/'.
        category = Path(image_dir).name.split('_')[0]
        image_ls = list(Path(image_dir).glob("*.jpg"))
        target_token = category
        target_token_id = tokenizer.encode(
            target_token, add_special_tokens=False)
        # The encoding is expected to be exactly three ids; only the middle
        # one is kept (presumably the outer two are boundary tokens - confirm
        # against the tokenizer).
        assert len(target_token_id) == 3
        target_token_id = target_token_id[1:-1]

        # Output directory depends only on the run and the category.
        stat_pth = f'saliency_correct/stat/{random_seed}/ckpt_{ckpt}/{category}'

        for image_path in image_ls:
            image_id, _ = parse_image_id(str(image_path))
            # convert() returns an independent image, so the file can be
            # closed as soon as the conversion is done.
            with Image.open(image_path) as raw_image:
                image = raw_image.convert("RGB")
            image_tensor = model.transformer.get_vision_tower().image_processor(
                image).unsqueeze(0).to(device, dtype=model.dtype)  # [1, 1, 3, 224, 224] per the original note

            maskout_bbox = img2bbox[image_id] if maskout == 'all' else None
            mask_inside, mask_outside = bbox_mask(
                str(image_path), df, maskout=maskout_bbox, transpose=True)
            # Both groups must be non-empty, otherwise their means are NaN.
            if torch.sum(mask_inside).item() == 0 or torch.sum(mask_outside).item() == 0:
                print('Skipped image due to patch', image_path)
                continue

            for prompt_id, text_prompt in enumerate(qa_pairs):
                stat_file = f"{stat_pth}/img_id_{image_id}_text_{prompt_id}.pt"
                if Path(stat_file).exists():
                    print(f"{stat_file} exists, continue")
                    total_fig_cnt += 1
                    continue

                print(prompt_id, text_prompt)
                inputs = tokenizer(text_prompt, return_tensors="pt").to(device)
                inp_id = inputs["input_ids"]
                # Insert -200 right after the first token and drop the last
                # token, keeping the sequence length unchanged (-200 is
                # presumably the image placeholder index - confirm in model).
                inp_ids = torch.cat([inp_id[:, :1], torch.tensor(  # type: ignore
                    [[-200]], device=device), inp_id[:, 1:-1]], dim=1)   # type: ignore

                outputs = model(
                    inp_ids, inputs["attention_mask"], images=image_tensor)
                saliency = get_saliency_matrix(
                    outputs, target_token_id)  # (L, H, S, S)
                del outputs

                # sum over heads -> one (S, S) matrix per layer
                saliency_per_layer = torch.sum(saliency, dim=1)
                S = saliency_per_layer.shape[-1]
                bbox_ls = bbox_areas(df, str(image_path))
                if not len(bbox_ls):
                    # the result depends on the image only, so give up on
                    # the remaining prompts of this image as well
                    print('bbox not found!')
                    break

                # Image patches occupy sequence positions 1 .. patch_total.
                patch_total = torch.numel(mask_inside)
                mask_inside_bbox = torch.zeros(S, dtype=torch.bool)
                mask_inside_bbox[1:1+patch_total] = mask_inside

                mask_outside_bbox = torch.zeros(S, dtype=torch.bool)
                mask_outside_bbox[1:1+patch_total] = mask_outside

                mask_image = torch.zeros(S, dtype=torch.bool)
                mask_image[1:1+patch_total] = True

                try:
                    attn_in, attn_out, attn_other_image, attn_text = token2end_analysis(
                        saliency_per_layer, mask_inside_bbox, mask_outside_bbox, mask_image)
                except ValueError:
                    print('Nan in token2end analysis, image', image_path)
                    break

                word2anchor, anchor2end, word2word, aggregate_head = anchor_analysis(
                    saliency, mask_inside_bbox, aggregate_layer)
                layerwise_word2anchor = torch.sum(word2anchor, dim=1)
                layerwise_anchor2end = torch.sum(anchor2end, dim=1)
                layerwise_word2word = torch.sum(word2word, dim=1)
                # Sanity check: the two ways of computing the in-bbox flow
                # should agree. Written as an explicit test so it still runs
                # under `python -O` and does not hide unrelated exceptions.
                if not torch.norm(attn_in - layerwise_anchor2end) < 2e-4:
                    print('@@@ Inproper threshold! @@@')

                if aggregate_layer >= 0:
                    assert aggregate_head >= 0
                    aggregate_head_dict[category][f'img_id_{image_id}_text_{prompt_id}'] = aggregate_head

                    G_A_pth = f'saliency_correct/aggregate_layer{aggregate_layer}/{random_seed}/ckpt_{ckpt}/{category}'
                    Path(G_A_pth).mkdir(parents=True, exist_ok=True)
                    torch.save(saliency[aggregate_layer],
                               f'{G_A_pth}/img_id_{image_id}_text_{prompt_id}.pt')

                Path(stat_pth).mkdir(parents=True, exist_ok=True)
                torch.save({'img_patch_in_bbox_2_end': layerwise_anchor2end,
                            'img_patch_out_bbox_2_end': attn_out,
                            'other_img_patch_2_end': attn_other_image,
                            'text_token_2_end': attn_text,
                            'word2anchor': layerwise_word2anchor,
                            'word2word': layerwise_word2word},
                           stat_file)

                # free memory between prompts
                torch.cuda.empty_cache()
                gc.collect()
                total_fig_cnt += 1

    print(f'Finished: {total_fig_cnt} figures in all!')
    if aggregate_layer >= 0:
        with open(f'inference/GA_{random_seed}_vlm.json', 'w', encoding='utf-8') as f:
            json.dump(aggregate_head_dict, f, ensure_ascii=False, indent=2)


def main(image_dirs, random_seed='trabank_vl_dino_pretrain2', ckpt=300000,
         maskout='all', aggregate_layer=-1):
    '''Load one checkpoint and run the saliency analysis on `image_dirs`.

    The model and tokenizer are loaded from
    ``model/{random_seed}/checkpoint-{ckpt}/`` with attention outputs and
    dict-style returns enabled, the model is moved to the selected device in
    bfloat16, and everything is handed to `run_inference`.

    Args:
        image_dirs: directories of images to analyse, passed through.
        random_seed: name of the run directory under ``model/``.
        ckpt: checkpoint step to load.
        maskout: forwarded to `run_inference`.
        aggregate_layer: forwarded to `run_inference`.
    '''
    # Resolve the device here (same rule as the script entry point) so that
    # this function also works when the module is imported rather than run
    # as a script.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model_path = f"model/{random_seed}/checkpoint-{ckpt}/"
    print('current model:', model_path)
    config_ = LlavaGPT2Config.from_pretrained(model_path)
    # attentions are required downstream to build the saliency matrix
    config_.output_attentions = True
    config_.return_dict = True
    model = LlavaGPTForCausalLM.from_pretrained(model_path, config=config_).to(
        device, dtype=torch.bfloat16)  # type: ignore

    tokenizer = CustomWordTokenizer.from_pretrained(model_path)

    run_inference(model, tokenizer, ckpt, image_dirs=image_dirs, maskout=maskout,
                  random_seed=random_seed, aggregate_layer=aggregate_layer)


if __name__ == "__main__":
    import argparse

    # Module-level device used by the functions above when run as a script.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Fixed seeds for the Python and torch RNGs.
    random.seed(42)
    torch.manual_seed(42)

    parser = argparse.ArgumentParser(description='analyze attention flow in vlm')
    parser.add_argument(
        '--random_seed',
        type=str,
        default='trabank_vl_dino_pretrain_resize_pure_s42',
        help='name of the item to analyze',
    )
    parser.add_argument(
        '--checkpoint',
        type=int,
        default=300000,
        help='the training step of checkpoints',
    )
    parser.add_argument(
        '--mask_out',
        type=str,
        default='all',
        choices=['all', 'current_category'],
    )
    parser.add_argument('--aggregate_layer', type=int, default=-1)
    args = parser.parse_args()

    print('Running with args: ', args)

    # Every entry of inference/images whose name does not contain 'inpaint'.
    image_dirs = []
    for entry in Path('inference/images').iterdir():
        if 'inpaint' in entry.name:
            continue
        image_dirs.append(entry)

    main(
        image_dirs=image_dirs,
        random_seed=args.random_seed,
        ckpt=args.checkpoint,
        maskout=args.mask_out,
        aggregate_layer=args.aggregate_layer,
    )
