import argparse
import torch
import os
import json
from tqdm import tqdm
import shortuuid
import numpy as np
import requests

from libra.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from libra.conversation import conv_templates, SeparatorStyle
from libra.model.builder import load_pretrained_model
from libra.utils import disable_torch_init
from libra.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path, KeywordsStoppingCriteria

import math
import pydicom
from typing import Dict
from PIL import Image
from io import BytesIO
from pydicom.pixel_data_handlers.util import apply_voi_lut

from transformers.generation.logits_process import (
    LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper
)

import torch.nn.functional as F

import torchvision
import skimage.io
import torchxrayvision as xrv

# Lazily constructed CheXpert classifier shared by every call (see _get_chexpert_models).
_CHEXPERT = None
# Run the classifier on the GPU when torch can see one, otherwise on the CPU.
_DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def _get_chexpert_models():
    """Return the process-wide CheXpert DenseNet, building it on first use.

    The torchxrayvision ``densenet121-res224-chex`` DenseNet is moved to
    ``_DEVICE`` and put in eval mode the first time this is called; later
    calls return the cached instance instead of reloading the weights.
    """
    global _CHEXPERT
    if _CHEXPERT is not None:
        return _CHEXPERT
    classifier = xrv.models.DenseNet(weights="densenet121-res224-chex")
    _CHEXPERT = classifier.to(_DEVICE).eval()
    return _CHEXPERT
    
def predict_chexpert_labels(image_file, image_folder) -> dict:
    """Score one chest X-ray with the cached torchxrayvision CheXpert DenseNet.

    Args:
        image_file: File name of the image, joined onto ``image_folder``.
        image_folder: Directory that contains ``image_file``.

    Returns:
        dict mapping every non-empty name in the classifier's ``pathologies``
        to the float value of the corresponding model output.

    Notes:
        The file is read with ``skimage.io.imread``; DICOM files and URLs are
        not handled here (unlike ``load_images``).
    """
    image_path = os.path.join(image_folder, image_file)
    img = skimage.io.imread(image_path)

    # xrv.datasets.normalize rejects images whose max exceeds the given bound.
    # Keep the 8-bit bound whenever it fits, and fall back to the integer
    # dtype's full range (e.g. 65535 for uint16) only when it does not.
    max_val = 255
    if img.max() > max_val and np.issubdtype(img.dtype, np.integer):
        max_val = np.iinfo(img.dtype).max
    img = xrv.datasets.normalize(img, max_val)

    # Reduce to the single leading channel the xrv transforms expect: [1, H, W].
    if img.ndim == 2:
        img = img[None, ...]
    elif img.ndim == 3 and img.shape[2] in (3, 4):
        # RGB or RGBA: average the colour channels and ignore any alpha channel.
        img = img[..., :3].mean(2)[None, ...]
    elif img.ndim == 3 and img.shape[2] in (1, 2):
        # Grayscale, optionally with alpha: keep the intensity channel.
        img = img[..., 0][None, ...]

    transform = torchvision.transforms.Compose([
        xrv.datasets.XRayCenterCrop(),
        xrv.datasets.XRayResizer(224),
    ])
    img = transform(img)

    img_tensor = torch.from_numpy(img).float()  # [1, H, W]
    img_tensor = img_tensor.unsqueeze(0).to(_DEVICE)  # [B=1, 1, H, W]

    chex = _get_chexpert_models()
    with torch.inference_mode():
        scores = chex(img_tensor)[0].detach().cpu().numpy()  # one value per model output

    # Outputs whose pathology name is empty are not reported.
    return {label: float(score) for label, score in zip(chex.pathologies, scores) if label}
    
def split_list(lst, n):
    """Split ``lst`` into at most ``n`` contiguous chunks of (roughly) equal size.

    Every chunk except possibly the last holds ``ceil(len(lst) / n)`` items,
    so fewer than ``n`` chunks come back when ``lst`` is short, and an empty
    ``lst`` yields an empty list.

    Raises:
        ValueError: if ``n`` is smaller than 1.
    """
    if n < 1:
        raise ValueError(f"n must be a positive integer, got {n}")
    # Ceiling division; clamp to 1 so an empty list never produces a zero range step.
    chunk_size = max(1, math.ceil(len(lst) / n))
    return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]


def get_chunk(lst, n, k):
    """Return the ``k``-th of ``n`` chunks produced by ``split_list``.

    When ``lst`` is too short to fill all ``n`` chunks, indices in
    ``[len(chunks), n)`` return an empty list, so every worker of an
    ``n``-way split gets a (possibly empty) share instead of an IndexError.
    Indices ``k >= n`` still raise IndexError.
    """
    chunks = split_list(lst, n)
    if len(chunks) <= k < n:
        return []
    return chunks[k]

def load_images(image_file, timeout=30):
    """
    Load an image from a local file, a URL, or a DICOM file.

    Args:
        image_file (str): The path or URL of the image file to load.
        timeout (float): Seconds to wait for an HTTP(S) server before giving up.

    Returns:
        PIL.Image.Image: The loaded image in RGB format.

    Raises:
        ValueError: If the image cannot be fetched, read or decoded (including
            a DICOM file that does not contain image data).
        TypeError: If the input is not a string (file path or URL).
    """
    if not isinstance(image_file, str):
        raise TypeError("image_file must be a string representing a file path or URL")

    # Case 1: Load from URL
    if image_file.startswith(('http://', 'https://')):
        try:
            # Without a timeout a stalled server would block forever.
            response = requests.get(image_file, timeout=timeout)
            response.raise_for_status()
            return Image.open(BytesIO(response.content)).convert('RGB')
        except Exception as e:
            raise ValueError(f"Error loading image from URL: {image_file}\n{e}") from e

    # Case 2: Load from DICOM file
    if image_file.lower().endswith('.dcm'):
        try:
            dicom = pydicom.dcmread(image_file)
            if 'PixelData' not in dicom:
                raise ValueError("DICOM file does not contain image data")
            data = apply_voi_lut(dicom.pixel_array, dicom)

            # Handle MONOCHROME1 images by inverting their intensities
            if dicom.PhotometricInterpretation == "MONOCHROME1":
                data = np.max(data) - data

            # Normalize to [0, 255]. A constant image has zero range; map it to
            # black instead of dividing by zero (which produced NaN pixels).
            data = data - np.min(data)
            data_max = np.max(data)
            if data_max > 0:
                data = data / data_max
            data = (data * 255).astype(np.uint8)

            # Convert to 3-channel RGB if necessary
            if data.ndim == 2:
                data = np.stack([data] * 3, axis=-1)

            return Image.fromarray(data).convert('RGB')
        except Exception as e:
            raise ValueError(f"Error loading DICOM file: {image_file}\n{e}") from e

    # Case 3: Load standard image files (e.g., PNG, JPG)
    try:
        # convert() returns an independent copy, so the file handle can be
        # closed as soon as the context manager exits.
        with Image.open(image_file) as img:
            return img.convert('RGB')
    except Exception as e:
        raise ValueError(f"Error loading standard image file: {image_file}\n{e}") from e

def get_image_tensors(image_file, image_folder, image_processor, model, device='cuda'):
    """Load the current (and optional prior) image and stack them for the model.

    Args:
        image_file: A single file name, or a list/tuple whose first entry is
            the current image and whose second entry (if present) is the prior
            image. Entries after the second are ignored.
        image_folder: Directory joined in front of every file name.
        image_processor: Processor forwarded to ``process_images``.
        model: Model whose ``config`` is forwarded to ``process_images``.
        device: Device the processed tensors are moved to.

    Returns:
        torch.Tensor of shape ``[2, 1, *per_image_shape]``: index 0 is the
        current image, index 1 the prior image. When no prior is given the
        current image is reused as the prior.

    Raises:
        TypeError: If ``image_file`` is not a string, list or tuple.
        ValueError: If ``image_file`` is an empty list/tuple.
    """
    if isinstance(image_file, str):
        file_names = [image_file]
    elif isinstance(image_file, (list, tuple)):
        # Only the current (index 0) and prior (index 1) images are used.
        file_names = list(image_file[:2])
    else:
        raise TypeError("image_file must be a string or a str/list/tuple of strings")
    if not file_names:
        raise ValueError("image_file must contain at least one image path")

    processed_images = []
    for file_name in file_names:
        img = load_images(os.path.join(image_folder, file_name))
        image_temp = process_images([img], image_processor, model.config)[0]
        processed_images.append(image_temp.to(device=device, non_blocking=True))

    cur_image = processed_images[0]
    # Without a prior study, the current image doubles as the prior
    # (reusing the processed tensor instead of processing it a second time).
    prior_image = processed_images[1] if len(processed_images) > 1 else cur_image

    # Outer axis = (current, prior); inner axis = batch of one.
    return torch.stack([torch.stack([cur_image]), torch.stack([prior_image])])


def get_chexpert_token_score_dict(
    label_score_dict: Dict[str, float],
    tokenizer,
    threshold: float = 0.0,
    only_first_token: bool = True,
    add_leading_space: bool = False,
    decay_gamma: float = 0.95,
) -> Dict[int, float]:
    """Project CheXpert label scores onto tokenizer token ids.

    Each label is tokenized in two spellings (as given and lower-cased), with
    underscores replaced by spaces and, if ``add_leading_space``, a leading
    space. The label's score is kept when ``score >= threshold`` and negated
    otherwise. The token at position ``j`` of a tokenization receives that
    value multiplied by ``decay_gamma`` ``j`` times.

    With ``only_first_token`` only word-starting tokens are kept: the first
    token, tokens starting with a ``"Ġ"``/``"▁"`` marker, and tokens that
    directly follow a bare ``"Ġ"``/``"▁"`` token. When several labels or
    positions map to the same token id, the value with the largest magnitude
    wins (the first one seen on ties).

    Returns:
        dict of token id -> float score.
    """
    word_markers = ("Ġ", "▁")
    result: Dict[int, float] = {}

    for label, score in label_score_dict.items():
        signed = score if score >= threshold else -score

        for spelling in {label, label.lower()}:
            text = spelling.replace("_", " ")
            if add_leading_space:
                text = " " + text

            pieces = tokenizer.tokenize(text)
            piece_ids = tokenizer.convert_tokens_to_ids(pieces)
            if not piece_ids:
                continue

            weight = signed
            after_marker = True  # the first piece always counts as a word start
            for piece, piece_id in zip(pieces, piece_ids):
                keep = (not only_first_token) or after_marker or piece.startswith(word_markers)
                if keep:
                    current = result.get(piece_id)
                    if current is None or abs(weight) > abs(current):
                        result[piece_id] = float(weight)
                weight *= decay_gamma
                after_marker = piece in word_markers

    return result

def clinical_guide_generation(label_score_dict, threshold=0.5):
    """Turn CheXpert scores into the textual hint appended to the guided prompt.

    Labels scoring at least ``threshold`` are listed in the dict's iteration
    order, joined by ``"; "``. The hint ends with a trailing space. When no
    label qualifies, a fixed fallback sentence (without a trailing space) is
    returned instead.
    """
    positives = [name for name, value in label_score_dict.items() if value >= threshold]
    if not positives:
        return "No clinical guidance based on current findings."

    hint = "Attention to the following clinical instructions: " + "; ".join(positives) + "."
    return hint + " "

def ccd(label_score_dict, model, tokenizer, image_tensors, noisy_image_tensors,
        input_ids, dist_input_ids, attention_mask, dist_attention_mask,
        keywords, stopping_criteria, alpha=0.5, max_tokens=256, device="cuda",
        stop_on_keywords=True, return_all=False, do_sample=False,
        temperature=None, top_k=None, top_p=None,
        boost_gamma=10.0, beta=0.5, mode="logit", stop_on_eos=True):
    """Decode one answer with clinical-prompt guidance and CheXpert token biasing.

    Despite the historical name, no contrastive (difference) term is used.
    ``noisy_image_tensors``, ``attention_mask``, ``dist_attention_mask``,
    ``stopping_criteria``, ``device`` and ``return_all`` are accepted for
    interface compatibility but are not read. Batch size 1 is assumed.

    At every step:
      1. The model scores the plain prompt (``input_ids``) and the prompt with
         the clinical hint (``dist_input_ids``) on the same ``image_tensors``.
      2. Their log-softmax outputs are mixed: ``(1 - alpha) * orig + alpha * clinical``.
      3. A copy of the mix receives a bias ``log(p / (1 - p))`` for each token
         id from ``get_chexpert_token_score_dict`` (``p`` is the token score
         clamped to ``[1e-6, 1 - 1e-6]``; if ``boost_gamma > 1`` the bias is
         clipped to ``[-log(boost_gamma), log(boost_gamma)]``).
      4. The unbiased mix goes through the temperature/top-k/top-p warpers and
         is blended with the biased copy: ``(1 - beta) * warped + beta * biased``.
         Tokens removed by the warpers stay removed.
      5. The next token is sampled (``do_sample``) or chosen greedily.

    Generation stops after ``max_tokens`` steps, when the decoded next token
    contains one of ``keywords`` (if ``stop_on_keywords``), or when the next
    token is ``tokenizer.eos_token_id`` (if ``stop_on_eos``). The stopping
    token itself is not appended.

    Returns:
        tuple ``("", text)`` where ``text`` is the decoded continuation of
        ``input_ids`` with special tokens skipped.

    Raises:
        ValueError: if ``mode`` is not ``"logit"``.
    """
    if mode != "logit":
        raise ValueError("mode must be 'logit'")

    token_score_dict = get_chexpert_token_score_dict(
        label_score_dict=label_score_dict,
        tokenizer=tokenizer,
        threshold=0.0,
        only_first_token=True,
        add_leading_space=False,
    )

    # The CheXpert bias does not change between steps, so compute it once here
    # instead of looping over every token id in Python at every step.
    clip = math.log(boost_gamma) if boost_gamma > 1 else None
    bias_ids, bias_values = [], []
    for tid, score in token_score_dict.items():
        p = max(min(float(score), 1 - 1e-6), 1e-6)
        bias = math.log(p / (1.0 - p))
        if clip is not None:
            bias = max(-clip, min(clip, bias))
        bias_ids.append(tid)
        bias_values.append(bias)
    # Materialised on the logits' device/dtype at the first step.
    bias_index = None
    bias_tensor = None

    logits_warper = LogitsProcessorList()
    if temperature and temperature != 1.0:
        logits_warper.append(TemperatureLogitsWarper(temperature))
    if top_k and top_k > 0:
        logits_warper.append(TopKLogitsWarper(top_k))
    if top_p and top_p < 1.0:
        logits_warper.append(TopPLogitsWarper(top_p))

    eos_token_id = getattr(tokenizer, "eos_token_id", None)
    generated_ids = input_ids.clone()
    dist_generated_ids = dist_input_ids.clone()

    for _ in range(max_tokens):
        with torch.inference_mode():
            logits_orig = model(input_ids=generated_ids, images=image_tensors,
                                attention_mask=torch.ones_like(generated_ids)).logits[:, -1, :]
            logits_clinical = model(input_ids=dist_generated_ids, images=image_tensors,
                                    attention_mask=torch.ones_like(dist_generated_ids)).logits[:, -1, :]

        # Stage 1: mix the plain and clinically-hinted prompts in log-prob space.
        logits_orig = F.log_softmax(logits_orig, dim=-1)
        logits_clinical = F.log_softmax(logits_clinical, dim=-1)
        logits_soft_guide = (1 - alpha) * logits_orig + alpha * logits_clinical

        # CheXpert-biased copy of the stage-1 mix (token ids are unique dict keys).
        logits_chexpert = logits_soft_guide.clone()
        if bias_ids:
            if bias_tensor is None:
                bias_index = torch.tensor(bias_ids, dtype=torch.long, device=logits_chexpert.device)
                bias_tensor = torch.tensor(bias_values, dtype=logits_chexpert.dtype,
                                           device=logits_chexpert.device)
            logits_chexpert[:, bias_index] += bias_tensor

        # Stage 2: warp the unbiased mix, then blend it with the biased copy.
        logits_soft_guide = logits_warper(generated_ids, logits_soft_guide)
        logits_final = (1 - beta) * logits_soft_guide + beta * logits_chexpert
        # Keep tokens filtered out by top-k/top-p filtered even when (1 - beta) <= 0,
        # where 0 * -inf would give NaN (and a negative weight would give +inf).
        logits_final = logits_final.masked_fill(torch.isneginf(logits_soft_guide), float("-inf"))

        if do_sample:
            probs = F.softmax(logits_final, dim=-1)
            next_token_id = torch.multinomial(probs, num_samples=1)
        else:
            next_token_id = torch.argmax(logits_final, dim=-1, keepdim=True)

        if stop_on_eos and eos_token_id is not None and next_token_id[0, 0].item() == eos_token_id:
            break
        if stop_on_keywords:
            next_token_text = tokenizer.decode(next_token_id[0])
            if any(k in next_token_text for k in keywords):
                break

        generated_ids = torch.cat([generated_ids, next_token_id], dim=1)
        dist_generated_ids = torch.cat([dist_generated_ids, next_token_id], dim=1)

    ccd_output = tokenizer.decode(generated_ids[0][input_ids.shape[1]:], skip_special_tokens=True)

    return "", ccd_output
    
    
def eval_model(args):
    """
    Answer a chunk of questions with CheXpert-guided decoding and write JSONL answers.

    For every question, the current image is scored by the CheXpert classifier.
    The resulting hint from ``clinical_guide_generation`` is appended to a second
    copy of the prompt, and ``ccd`` decodes an answer from both prompts.

    Args:
        args (Namespace): A namespace object containing the following attributes:
            - model_path (str): Path to the pre-trained model.
            - model_base (str): Base model name.
            - question_file (str): JSONL file; each non-blank line needs
              "question_id", "image" (a path or a list whose first entry is
              the current image) and "text".
            - num_chunks (int): Number of chunks to split the questions into.
            - chunk_idx (int): Index of the chunk to process.
            - answers_file (str): Path to the file where answers will be saved.
            - image_folder (str): Folder containing the images.
            - conv_mode (str): Conversation mode to use (key of ``conv_templates``).
            - temperature (float): Sampling temperature; ``0`` selects greedy decoding.
            - top_p (float or None): Top-p sampling parameter; falsy means 1.0.
            - max_new_tokens (int): Maximum number of new tokens to generate.
            - alpha, beta, boost_gamma (float): Guidance strengths forwarded to ``ccd``.
            ``num_beams``, ``length_penalty`` and ``num_return_sequences`` are not used.
    Raises:
        TypeError: If an entry's "image" is neither a string nor a list/tuple of strings.
    Returns:
        None
    """
    # Model
    disable_torch_init()
    model_path = os.path.expanduser(args.model_path)
    model_name = get_model_name_from_path(model_path)
    tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name)

    # Read the JSONL question file (skipping blank lines) and keep this worker's chunk.
    with open(os.path.expanduser(args.question_file), "r", encoding="utf-8") as question_f:
        questions = [json.loads(q) for q in question_f if q.strip()]
    questions = get_chunk(questions, args.num_chunks, args.chunk_idx)

    answers_file = os.path.expanduser(args.answers_file)
    answers_dir = os.path.dirname(answers_file)
    if answers_dir:  # os.makedirs("") raises, e.g. for the default "answer.jsonl"
        os.makedirs(answers_dir, exist_ok=True)

    with open(answers_file, "w") as ans_file:
        for line in tqdm(questions):
            idx = line["question_id"]
            image_file = line["image"]
            qs = line["text"]
            cur_prompt = qs
            # A plain string is one path; indexing it would yield its first character.
            current_image = image_file if isinstance(image_file, str) else image_file[0]

            if model.config.mm_use_im_start_end:
                qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
            else:
                qs = DEFAULT_IMAGE_TOKEN + '\n' + qs

            conv = conv_templates[args.conv_mode].copy()
            conv.append_message(conv.roles[0], qs)
            conv.append_message(conv.roles[1], None)
            prompt = conv.get_prompt()

            # Second prompt: same question plus the CheXpert-derived clinical hint.
            label_score_dict = predict_chexpert_labels(current_image, args.image_folder)
            clinical_guide = clinical_guide_generation(label_score_dict, threshold=0.5)
            dist_qs = qs + " " + clinical_guide if clinical_guide is not None else qs

            conv_dist = conv_templates[args.conv_mode].copy()
            conv_dist.append_message(conv_dist.roles[0], dist_qs)
            conv_dist.append_message(conv_dist.roles[1], None)
            dist_prompt = conv_dist.get_prompt()

            input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
            dist_input_ids = tokenizer_image_token(dist_prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()

            attention_mask = torch.ones(input_ids.shape, dtype=torch.long)
            dist_attention_mask = torch.ones(dist_input_ids.shape, dtype=torch.long)

            image_tensors = get_image_tensors(current_image, args.image_folder, image_processor, model)
            stop_str = conv.sep if conv.sep_style not in {SeparatorStyle.TWO, SeparatorStyle.LLAMA_3, SeparatorStyle.MISTRAL} else conv.sep2
            keywords = [stop_str]
            stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)

            with torch.inference_mode():
                torch.cuda.empty_cache()
                _, ccd_output = ccd(
                    label_score_dict=label_score_dict,
                    model=model,
                    image_tensors=image_tensors,
                    noisy_image_tensors=None,
                    input_ids=input_ids,
                    dist_input_ids=dist_input_ids,
                    attention_mask=attention_mask,
                    dist_attention_mask=dist_attention_mask,
                    tokenizer=tokenizer,
                    stopping_criteria=stopping_criteria,
                    alpha=args.alpha,
                    keywords=keywords,
                    max_tokens=args.max_new_tokens,
                    device="cuda",
                    stop_on_keywords=True,
                    return_all=False,
                    do_sample=args.temperature > 0,
                    temperature=args.temperature,
                    top_k=None,
                    top_p=args.top_p if args.top_p else 1.0,
                    boost_gamma=args.boost_gamma,
                    beta=args.beta,
                )
            torch.cuda.empty_cache()

            ans_file.write(json.dumps({"question_id": idx,
                                       "prompt": cur_prompt,
                                       "text": ccd_output.strip(),
                                       "answer_id": shortuuid.uuid(),
                                       "model_id": model_name,
                                       "metadata": {}}) + "\n")
            # Flush per answer so partial results survive an interrupted run.
            ans_file.flush()

if __name__ == "__main__":
    cli = argparse.ArgumentParser()

    # Model and data locations
    cli.add_argument("--model-path", type=str, default="LLaVA-Med")
    cli.add_argument("--model-base", type=str, default=None)
    cli.add_argument("--image-folder", type=str, default="")
    cli.add_argument("--question-file", type=str, default="question.jsonl")
    cli.add_argument("--answers-file", type=str, default="answer.jsonl")
    cli.add_argument("--conv-mode", type=str, default="v0")

    # Sharding of the question file across workers
    cli.add_argument("--num-chunks", type=int, default=1)
    cli.add_argument("--chunk-idx", type=int, default=0)

    # Generation settings (num_beams, num_return_sequences and length_penalty
    # are parsed but not read by eval_model)
    cli.add_argument("--temperature", type=float, default=0.2)
    cli.add_argument("--top_p", type=float, default=None)
    cli.add_argument("--num_beams", type=int, default=1)
    cli.add_argument("--num_return_sequences", type=int, default=None)
    cli.add_argument("--length_penalty", type=float, default=1.0)
    cli.add_argument("--max_new_tokens", type=int, default=128)

    # Guidance strengths forwarded to ccd()
    cli.add_argument(
        "--alpha", type=float, default=0.5,
        help="alpha for first stage between the original model and the expert model",
    )
    cli.add_argument(
        "--boost-gamma", type=float, default=10.0,
        help="gamma for boosting the expert model",
    )
    cli.add_argument(
        "--beta", type=float, default=0.5,
        help="beta for second stage the balance between the original and expert models",
    )

    eval_model(cli.parse_args())