from transformers import AutoProcessor
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLConfig
from qwen_vl_utils import process_vision_info

import torch
import json
from tqdm import tqdm
from nltk.tokenize import sent_tokenize
from collections import Counter
import numpy as np
import math
from typing import List, Tuple, Optional
import re

import random
# Make runs repeatable: one shared seed for every RNG used by the script,
# and cuDNN switched to its deterministic, non-autotuned configuration.
seed = 17
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(seed)
del _seed_fn
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# helper functions for sentence chunking and alignment
def get_chunk_idx(sentence_chunks, idx):
    """Return the position of the first chunk whose inclusive (start, end) range holds `idx`.

    Falsy entries (e.g. None for a sentence without tokens) are ignored.
    Returns -1 when no chunk contains `idx`.
    """
    return next(
        (
            position
            for position, span in enumerate(sentence_chunks)
            if span and span[0] <= idx <= span[1]
        ),
        -1,
    )

def _sent_spans_via_find(text: str) -> Tuple[List[str], List[Tuple[int, int]]]:
    """Split `text` into sentences and locate each one inside `text`.

    Returns (sentences, spans) where spans[i] is the half-open character
    range (start, end) of sentences[i]. Sentences are searched left to right,
    each search starting where the previous sentence ended.
    """
    sentences = sent_tokenize(text)
    char_spans = []
    search_from = 0
    for sentence in sentences:
        start = text.find(sentence, search_from)
        if start < 0:
            # Sentence text not found verbatim: anchor it at the current
            # position so spans stay ordered.
            start = search_from
        search_from = start + len(sentence)
        char_spans.append((start, search_from))
    return sentences, char_spans

def _kmp_find_subsequence(hay: List[int], needle: List[int]) -> Optional[int]:
    """Find start index where `needle` occurs inside `hay` using KMP. Return None if not found."""
    if not needle:
        return 0
    lps = [0] * len(needle)
    j = 0
    for i in range(1, len(needle)):
        while j > 0 and needle[i] != needle[j]:
            j = lps[j - 1]
        if needle[i] == needle[j]:
            j += 1
            lps[i] = j
    j = 0
    for i, t in enumerate(hay):
        while j > 0 and t != needle[j]:
            j = lps[j - 1]
        if t == needle[j]:
            j += 1
            if j == len(needle):
                return i - j + 1
    return None

def _align_enc_ids_with_target_ids(enc_ids: List[int], target_ids: List[int]) -> Optional[int]:
    """
    Try to align two id sequences. Returns the start index in `target_ids`
    where `enc_ids` occurs (as a contiguous subsequence). Handles padding/EOS.
    """
    # Fast path: exact prefix
    if len(enc_ids) <= len(target_ids) and target_ids[:len(enc_ids)] == enc_ids:
        return 0
    # General path: subsequence
    return _kmp_find_subsequence(target_ids, enc_ids)

def _sentence_index_at(sent_spans, char_pos):
    """Binary-search ordered half-open `sent_spans` for the one containing `char_pos`; None if none does."""
    left, right = 0, len(sent_spans) - 1
    while left <= right:
        middle = (left + right) // 2
        span_start, span_end = sent_spans[middle]
        if char_pos < span_start:
            right = middle - 1
        elif char_pos >= span_end:
            left = middle + 1
        else:
            return middle
    return None


def compute_sentence_chunks(
    *,
    tokenizer,
    base_text: str,            # text that is tokenized and whose character offsets are used
    ids,                       # list or object with .tolist(): the token ids the returned indices refer to
    subtext: Optional[str] = None,  # restrict sentence chunking to the last occurrence of this text in base_text
) -> List[Optional[Tuple[int, int]]]:
    """
    Map each sentence to the inclusive (first, last) token positions it covers in `ids`.

    The sentences are those of `subtext` when given, otherwise of `base_text`.
    `base_text` is tokenized with character offsets (first without special
    tokens, then with them) and the resulting ids are aligned against `ids`;
    each token lying fully inside the chunked text is assigned to the sentence
    containing the token's character midpoint.

    Returns one entry per sentence: a (start_idx, end_idx) tuple, or None when
    no token was assigned to that sentence. If neither tokenization can be
    aligned with `ids`, every entry is None. If `subtext` is given but does not
    occur in `base_text`, an empty list is returned.
    """
    token_ids = ids.tolist() if hasattr(ids, "tolist") else ids

    # 1) Character window to chunk, plus sentence spans in base_text coordinates.
    if subtext is not None:
        window_start = base_text.rfind(subtext)
        if window_start == -1:
            # Mismatch; cannot safely attribute
            return []
        window_end = window_start + len(subtext)
        _, local_spans = _sent_spans_via_find(subtext)
        sent_spans = [(window_start + lo, window_start + hi) for (lo, hi) in local_spans]
    else:
        window_start, window_end = 0, len(base_text)
        _, sent_spans = _sent_spans_via_find(base_text)

    sentence_count = len(sent_spans)
    id_count = len(token_ids)

    # 2) Tokenize with offsets; try without special tokens before trying with them.
    for with_specials in (False, True):
        encoding = tokenizer(
            base_text,
            add_special_tokens=with_specials,
            return_offsets_mapping=True,
            padding=False,
            truncation=False,
        )
        encoded_ids = encoding["input_ids"]
        char_offsets = encoding["offset_mapping"]

        shift = _align_enc_ids_with_target_ids(encoded_ids, token_ids)
        if shift is None:
            continue  # this tokenization does not occur in `ids`

        # 3) Grow a (first, last) token range per sentence, in `ids` coordinates.
        token_ranges = {}
        for token_index, (char_from, char_to) in enumerate(char_offsets):
            # Tokens without a real character extent carry no text; skip them.
            if char_from is None or char_to is None or char_from == char_to:
                continue
            # Only tokens fully inside the chunked window are attributed.
            if char_from < window_start or char_to > window_end:
                continue
            sentence_id = _sentence_index_at(sent_spans, (char_from + char_to - 1) // 2)
            if sentence_id is None:
                continue
            position = shift + token_index
            if not 0 <= position < id_count:
                continue
            first_position, _ = token_ranges.get(sentence_id, (position, position))
            token_ranges[sentence_id] = (first_position, position)

        return [token_ranges.get(sentence_id) for sentence_id in range(sentence_count)]

    # Alignment failed under both settings; avoid false attribution.
    return [None] * sentence_count

# helper functions for finding caption sentence chunk
def _sent_spans_via_find(text: str):
    """Return (sentences, spans): the sentences of `text` and their (start, end) character ranges.

    NOTE: a function with this same name is already defined earlier in this
    file; this later definition is the one bound at module level once the file
    has been executed.
    """
    from nltk.tokenize import sent_tokenize
    sentences = sent_tokenize(text)
    boundaries = []
    offset = 0
    for piece in sentences:
        found = text.find(piece, offset)
        # If the sentence cannot be found verbatim, place it at the current offset.
        begin = found if found != -1 else offset
        offset = begin + len(piece)
        boundaries.append((begin, offset))
    return sentences, boundaries

def find_caption_chunk(
    *,
    source: str,                                  # source text that already contains the caption
    source_sentence_chunks: List[Optional[Tuple[int,int]]],
    caption: Optional[str] = None,                # caption text to look for; if falsy, an <...> span is used
    caption_occurrence: int = 0,                  # which <...> span to use in the fallback (0-based)
) -> Tuple[int, Optional[Tuple[int,int]]]:
    """
    Find the sentence of `source` that holds the caption.

    Returns (caption_sentence_index, caption_token_chunk) where the chunk is
    source_sentence_chunks[caption_sentence_index], or None when that index is
    outside `source_sentence_chunks`.

    Raises ValueError when no caption span can be located in `source`.
    """
    # Sentence character spans of the source text.
    _, sentence_spans = _sent_spans_via_find(source)

    # Locate the caption: bracketed form first, then the bare text.
    caption_span = None
    if caption:
        bracketed = re.search(r"<\s*" + re.escape(caption) + r"\s*>", source)
        if bracketed is not None:
            caption_span = bracketed.span()
        else:
            found_at = source.find(caption)
            if found_at != -1:
                caption_span = (found_at, found_at + len(caption))

    if caption_span is None:
        # Fallback: take the requested <...> span.
        bracket_spans = list(re.finditer(r"<[^>]+>", source))
        if not bracket_spans or caption_occurrence >= len(bracket_spans):
            raise ValueError("Caption span not found in source.")
        caption_span = bracket_spans[caption_occurrence].span()

    cap_start, cap_end = caption_span

    # Sentence containing the caption's midpoint character (binary search).
    probe = (cap_start + cap_end - 1) // 2
    left, right = 0, len(sentence_spans) - 1
    sentence_idx = None
    while sentence_idx is None and left <= right:
        centre = (left + right) // 2
        begin, finish = sentence_spans[centre]
        if probe < begin:
            right = centre - 1
        elif probe >= finish:
            left = centre + 1
        else:
            sentence_idx = centre

    if sentence_idx is None:
        # Midpoint fell between sentences: take the first sentence with the
        # largest character overlap, or sentence 0 when there are none.
        overlaps = [
            max(0, min(cap_end, finish) - max(cap_start, begin))
            for begin, finish in sentence_spans
        ]
        sentence_idx = overlaps.index(max(overlaps)) if overlaps else 0

    in_range = 0 <= sentence_idx < len(source_sentence_chunks)
    return sentence_idx, (source_sentence_chunks[sentence_idx] if in_range else None)

# helper functions for token attribution
def get_token_mappings(attentions, input_len, image_mode, img_start, img_end, text_token_top_k):
    """
    Map every generated token to the prompt text tokens it attends to most.

    NOTE: besides its arguments this function reads the module-level names
    `source_sentence_chunks`, `processor`, `inputs` and `generated_ids`, which
    must describe the same sample as `attentions`.

    Args:
        attentions: one entry per generated token; each entry is a sequence
            (one item per layer) of tensors shaped [batch, num_heads, q_len, k_len].
        input_len: number of prompt tokens (text + image).
        image_mode: image tokens are separated from text tokens only for "raw".
        img_start, img_end: half-open range of image tokens in the prompt, or
            None when there is no image span.
        text_token_top_k: how many top-attended text tokens to record
            (capped at the number of available text tokens).

    Returns:
        A list with one dict per generated token:
        - gen_token: id of the generated token
        - text_hits: list of {"sentence_id", "token", "attention"} for the
          top-attended text tokens, special tokens removed; sentence_id is
          None when the token lies outside every source sentence chunk
        - image_attention: mean attention over the image tokens, or None when
          there are no image tokens to average over
    """
    # Template / vision tokens carry no source content and are never reported.
    skipped_tokens = {'<|im_start|>', '<|im_end|>', '<|vision_start|>', '<|vision_end|>', '<|image_pad|>', '\n'}

    if img_start is not None and image_mode == "raw":
        text_indices = list(range(0, img_start)) + list(range(img_end, input_len))
        img_indices = list(range(img_start, img_end))
    else:
        # No image span to exclude: every prompt position is a text token.
        text_indices = list(range(input_len))
        img_indices = []

    # torch.topk raises when k exceeds the number of candidates.
    top_k = min(text_token_top_k, len(text_indices))

    token_mappings = []

    # For each generated token
    for t, attn_per_layer in enumerate(attentions):
        attn_mean = torch.stack(attn_per_layer).mean(dim=0)  # average layers -> [batch, num_heads, q_len, k_len]
        attn_mean = attn_mean.mean(dim=1)                    # average heads  -> [batch, q_len, k_len]
        # Keep only the last query row: that is the position the token of this
        # step is predicted from. On the first step q_len spans the whole
        # prompt, and averaging over it would mix in prompt self-attention.
        attn_mean = attn_mean[:, -1, :].squeeze(0)           # [k_len]

        # 1) record top text tokens
        text_attn = attn_mean[text_indices]
        t_vals, t_pos = torch.topk(text_attn, k=top_k)
        text_info = []
        for val, local_pos in zip(t_vals.tolist(), t_pos.tolist()):
            idx = text_indices[local_pos]
            tok = processor.tokenizer.decode([inputs.input_ids[0, idx].item()], clean_up_tokenization_spaces=True)
            if tok in skipped_tokens:
                continue # skip special text tokens
            sid = get_chunk_idx(source_sentence_chunks, idx)
            if sid is not None and sid < 0:
                sid = None  # not inside any source sentence (e.g. system prompt)
            text_info.append({
                "sentence_id": sid,
                "token":       tok,
                "attention":   val,
            })

        # 2) mean attention on the image tokens, when there are any
        img_attention = attn_mean[img_indices].mean().item() if img_indices else None

        # Current generated token info:
        gen_token_id = generated_ids[0, input_len + t].item()

        token_mappings.append({
            "gen_token":  gen_token_id,
            "text_hits":  text_info,
            "image_attention": img_attention
        })
    return token_mappings

def _is_source_sentence_id(sid):
    """True when `sid` names a real source sentence (not None and not a negative sentinel)."""
    return sid is not None and sid >= 0


def get_final_token_mappings(token_mappings, mode):
    """
    Reduce each generated token's attended text tokens to a single source sentence.

    Args:
        token_mappings: list of dicts with keys 'gen_token' and 'text_hits';
            each hit is a dict with keys 'sentence_id' and 'attention'.
        mode: "max" picks the sentence of the hit with the highest attention;
            "majority" picks the sentence with the most hits, ties broken by
            the highest single attention among the tied sentences.

    Returns:
        One dict per generated token: {'token': <gen_token>, 'source_label': <sentence id or None>}.
        The label is None when no hit belongs to a source sentence. Hits whose
        sentence_id is None or negative (out of range, e.g. system prompt) are
        ignored in both modes.

    Raises:
        ValueError: if `mode` is neither "max" nor "majority".
    """
    if mode not in ("max", "majority"):
        raise ValueError(f"Unknown attribution mode: {mode!r} (expected 'max' or 'majority')")

    def by_attention(hit):
        return hit['attention']

    final_token_mappings = []
    for token_map in token_mappings:
        hits = [h for h in token_map['text_hits'] if _is_source_sentence_id(h['sentence_id'])]

        if not hits:
            chosen_sid = None
        elif mode == "max":
            # max() keeps the first of equally attended hits.
            chosen_sid = max(hits, key=by_attention)['sentence_id']
        else:
            counts = Counter(h['sentence_id'] for h in hits)
            top_count = max(counts.values())
            tied = {sid for sid, cnt in counts.items() if cnt == top_count}
            # With a single winner this simply returns it; with a tie the
            # sentence owning the most attended hit wins.
            chosen_sid = max(
                (h for h in hits if h['sentence_id'] in tied), key=by_attention
            )['sentence_id']

        final_token_mappings.append({
            'token': token_map['gen_token'],
            'source_label': chosen_sid,
        })
    return final_token_mappings

# ---- Experiment configuration ----
test_split = "test_mm"  # dataset file is "<test_split>.json"
text_token_top_k_list = [3]  # Number of top attended source tokens to consider
token_to_token_attr_mode_list = ["majority"]  # token-level attribution mode: "max" or "majority"
aggregation_threshold_list = [0.2]  # fraction of a sentence's labelled tokens a source sentence needs to be kept

image_mode = "caption" # "raw", "caption", "none"

# ---- Model and processor ----
model_checkpoint = "Qwen2.5-VL-7B-Instruct"
config = Qwen2_5_VLConfig.from_pretrained(model_checkpoint)

model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    model_checkpoint, 
    config=config, 
    torch_dtype="auto", 
    device_map="auto",
    attn_implementation="eager",
)
model.config.output_attentions = True # always output attentions
min_pixels = 256 * 28 * 28
max_pixels = 1280 * 28 * 28
processor = AutoProcessor.from_pretrained(model_checkpoint, 
                                        min_pixels=min_pixels, 
                                        max_pixels=max_pixels,
                                        use_fast=True)

# ---- Data ----
# Read as UTF-8 explicitly so decoding does not depend on the platform locale.
with open(f"{test_split}.json", 'r', encoding="utf-8") as f:
    test_data = json.load(f)

# Main loop: for every sample build the prompt, generate a summary with
# attentions, and attribute each generated sentence to source sentences.
alignment_error = 0
generation_outputs, attribution_outputs = [], {}
for data in tqdm(test_data, total=len(test_data)):
    # ---- 1) prepare input message ----
    source = data["source"]
    caption_text = None
    cap_sid = None
    if "image" in data: # mm, from cliconsummation
        image_path = f"CliConSummation Dataset/image_data/train/{data['image']}"
        # Probe that the image can be opened; skip the sample otherwise.
        try:
            with open(image_path, 'rb'):
                pass
        except OSError:
            print(f"Image file {image_path} not found, skip this sample.")
            continue
        has_image = True
        if image_mode == "raw":
            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "image",
                        "image": image_path,
                        },
                        {"type": "text", 
                        "text": "Summarize the following patient-doctor dialogue. Use the provided image when forming the summary.\n\nConversation:\n" + source},
                    ],
                }
            ]
        else:
            # NOTE(review): every image_mode other than "raw" (including
            # "none") takes this caption path — confirm that is intended.
            # First ask the model for a one-sentence caption of the image.
            caption_messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "image",
                        "image": image_path,
                        },
                        {"type": "text", 
                        "text": "Describe the image using one sentence, starting with 'An image showing'."},
                    ],
                }
            ]
            caption_prompt = processor.apply_chat_template(
                caption_messages, tokenize=False, add_generation_prompt=True
            )
            image_inputs, video_inputs = process_vision_info(caption_messages)
            caption_inputs = processor(
                text=[caption_prompt],
                images=image_inputs,
                videos=video_inputs,
                padding=True,
                return_tensors="pt",
            )
            caption_inputs = caption_inputs.to("cuda")
            # return_dict_in_generate is required for `.sequences` to exist.
            caption_outputs = model.generate(
                **caption_inputs,
                max_new_tokens=128,
                return_dict_in_generate=True,
            )
            caption_ids_trimmed = [
                out_ids[len(in_ids) :]
                for in_ids, out_ids in zip(caption_inputs.input_ids, caption_outputs.sequences)
            ]
            caption_text = processor.batch_decode(
                caption_ids_trimmed, 
                skip_special_tokens=True, 
                clean_up_tokenization_spaces=False
            )[0].strip().replace('\n', ' ')
            # Splice the caption into the dialogue in place of the placeholder.
            if '<image>' not in source:
                raise ValueError("Image placeholder <image> not found in source text.")
            source = source.replace('<image>', f"<{caption_text}>")
            prompt = "Summarize the following patient-doctor dialogue.\n\nConversation:\n" + source
            messages = [
                    {
                        "role": "user",
                        "content": [
                            {"type": "text", 
                            "text": prompt},
                        ],
                    }
                ]
    else: # text-only
        has_image = False
        if data["src"] == "cliconsummation":
            prompt = "Summarize the following patient-doctor dialogue.\n\nConversation:\n" + source
        else: 
            prompt = "Summarize the following FINDINGS section of a chest X-ray report and produce a concise clinical IMPRESSION that summarizes the most relevant abnormalities (if any).\n\nFINDINGS:\n" + source
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", 
                    "text": prompt},
                ],
            }
        ]

    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    # NOTE: `inputs`, `generated_ids` and `source_sentence_chunks` are read as
    # module-level names by get_token_mappings(); keep these names.
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to("cuda")

    # ---- 2) Generate output and get attentions ----
    # Ask explicitly for a dict output with attentions: `.sequences` and
    # `.attentions` below only exist on that output type.
    outputs = model.generate(**inputs, 
                            max_new_tokens=128,
                            do_sample=False, # <--- Deterministic, greedy decoding
                            temperature=1.0,
                            output_attentions=True,
                            return_dict_in_generate=True)
    
    attentions = outputs.attentions # [num_gen_token, [num_layers[28], (batch_num, num_heads[28], seq_len[1 after the first], curr_seq_len) ]]
    generated_ids = outputs.sequences
    generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
    output_text = processor.batch_decode(
        generated_ids_trimmed, 
        skip_special_tokens=True, 
        clean_up_tokenization_spaces=False
    )[0]
    ids = inputs.input_ids[0]

    # ---- 3) calculate span of image tokens ----
    if has_image and image_mode == "raw":
        vision_tokens = ("<|image_pad|>", "<|vision_start|>", "<|vision_end|>")
        vis_mask = [
            tok in vision_tokens
            for tok in processor.tokenizer.convert_ids_to_tokens(inputs.input_ids[0])
        ]
        img_len = sum(vis_mask)
        img_start = vis_mask.index(True) if True in vis_mask else -1
        img_end = img_start + img_len if img_start != -1 else -1
        # Align against the prompt with the vision placeholder (text) and the
        # vision tokens (ids) removed.
        text = text.replace("<|vision_start|><|image_pad|><|vision_end|>", "")
        ids = torch.cat([ids[:img_start], ids[img_end:]], dim=0)
    else:
        img_start = None
        img_end = None
    
    # ---- 4) Attribute input tokens to sentence IDs for the source text ----
    source_sentences = sent_tokenize(source)
    source_sentence_chunks = compute_sentence_chunks(
        tokenizer=processor.tokenizer,
        base_text=text,
        ids=ids,
        subtext=source,
    )
    
    if has_image:
        if image_mode == "raw":
            # Chunks were computed on ids with the image tokens removed. A
            # position p in that sequence is position p in the full prompt when
            # p < img_start, and p + img_len otherwise (p == img_start is the
            # first token after the image).
            source_sentence_chunks = [
                tuple(p + img_len if p >= img_start else p for p in chunk) if chunk else None
                for chunk in source_sentence_chunks
            ]
        else:
            cap_sid, cap_chunk = find_caption_chunk(
                source=source,                       # source after replacing <image> with <caption>
                source_sentence_chunks=source_sentence_chunks,
                caption=caption_text,                # or None to pick the first <...> span
            )

    generated_sentences = sent_tokenize(output_text)
    generated_sentence_chunks = compute_sentence_chunks(
        tokenizer=processor.tokenizer,
        base_text=output_text,
        ids=generated_ids_trimmed[0],
        subtext=None,
    )

    # ---- 5) Record the generation ----
    generation_record = {
        "i" : data['i'],
        "image" : data.get('image', None),
    }
    if has_image and image_mode == "caption":
        generation_record["image_sid"] = cap_sid
    generation_record["model_summary"] = output_text
    generation_record["sent_source"] = source_sentences
    generation_record["sent_summary"] = generated_sentences
    generation_outputs.append(generation_record)

    # ---- 6) Sentence-level attribution for every hyperparameter setting ----
    input_len = inputs.input_ids.shape[1]  # Number of input tokens (text + image)
    for text_token_top_k in text_token_top_k_list:
        for token_to_token_attr_mode in token_to_token_attr_mode_list:
            for agg_thres in aggregation_threshold_list:
                sentence_to_sources = []
                if all(x is None for x in generated_sentence_chunks): # alignment error
                    sentence_to_sources.append([])
                    alignment_error += 1
                else:
                    # Get gen_token to top_k src_token & image mappings
                    token_mappings = get_token_mappings(attentions, input_len, image_mode, img_start, img_end, text_token_top_k)
                    image_attentions = [tm['image_attention'] for tm in token_mappings]
                    final_token_mappings = get_final_token_mappings(token_mappings, token_to_token_attr_mode)
                    # Generated sentence with the highest mean image attention,
                    # tracked by its index in sentence_to_sources so skipped
                    # sentences cannot shift it.
                    best_img_sentence, best_img_attention = None, None
                    for sentence_idx, span in enumerate(generated_sentence_chunks):
                        if not span:
                            sentence_to_sources.append([])
                            continue
                        start, end = span
                        labels = [tm['source_label'] for tm in final_token_mappings[start : end + 1]]
                        valid_labels = [a for a in labels if a is not None]
                        effective_len = len(valid_labels)
                        if effective_len == 0:
                            sentence_to_sources.append([])
                            continue
                        # Keep every source sentence that labels at least
                        # agg_thres of this sentence's labelled tokens.
                        thr = math.ceil(agg_thres * effective_len)
                        counts = Counter(valid_labels)
                        majors = [src for src, cnt in counts.items() if cnt >= thr]
                        sentence_to_sources.append(majors)
                        if has_image and image_mode == "raw":
                            img_attns = image_attentions[start : end + 1]
                            avg_img_attn = sum(img_attns) / len(img_attns) if img_attns else 0.0
                            if best_img_attention is None or avg_img_attn > best_img_attention:
                                best_img_sentence, best_img_attention = sentence_idx, avg_img_attn
                    if has_image:
                        if image_mode == "raw":
                            if best_img_sentence is not None:
                                sentence_to_sources[best_img_sentence].append('IMG')
                        else:
                            # replace caption sentence id with 'IMG'
                            sentence_to_sources = [
                                ['IMG' if label == cap_sid else label for label in sources]
                                for sources in sentence_to_sources
                            ]

                setting_key = f"top_{text_token_top_k}_{token_to_token_attr_mode}_{agg_thres}"
                attribution_outputs.setdefault(data['i'], {})[setting_key] = sentence_to_sources

# Report the share of (sample, setting) runs whose generated text could not be
# aligned, then write the generation and attribution results to disk.
import os

total_runs = len(test_data) * len(text_token_top_k_list) * len(token_to_token_attr_mode_list) * len(aggregation_threshold_list)
# Guard against an empty dataset or empty hyperparameter list.
alignment_error_rate = alignment_error / total_runs if total_runs else 0.0
print(f"{alignment_error_rate} alignment errors")

# Create the output directory first so the run's results are not lost
# to a missing folder.
os.makedirs("responses/qwen", exist_ok=True)

with open(f"responses/qwen/{test_split}_generation_{image_mode}.json", 'w') as f:
    json.dump(generation_outputs, f, indent=4)

with open(f"responses/qwen/{test_split}_attribution_{image_mode}.json", 'w') as f:
    json.dump(attribution_outputs, f, indent=4)

