from typing import Any, Dict, Optional, Tuple, Type
from pydantic import BaseModel, Field

import torch

from langchain_core.callbacks import (
    AsyncCallbackManagerForToolRun,
    CallbackManagerForToolRun,
)
from langchain_core.tools import BaseTool

from PIL import Image

from transformers import (
    BertTokenizer,
    ViTImageProcessor,
    VisionEncoderDecoderModel,
    GenerationConfig,
)


class ChestXRayInput(BaseModel):
    """Input for chest X-ray analysis tools. Only supports JPG or PNG images."""

    # Single required argument: the location of the image to analyze.
    # The Ellipsis default declares the field as required (no fallback value).
    image_path: str = Field(
        default=...,
        description="Path to the radiology image file, only supports JPG or PNG images",
    )


class ChestXRayReportGeneratorTool(BaseTool):
    """Tool that generates comprehensive chest X-ray reports with both findings and impressions.

    This tool uses two Vision-Encoder-Decoder models (ViT-BERT) trained on CheXpert
    and MIMIC-CXR datasets to generate structured radiology reports. It automatically
    generates both detailed findings and impression summaries for each chest X-ray,
    following standard radiological reporting format.

    The tool uses:
    - Findings model: Generates detailed observations of all visible structures
    - Impression model: Provides concise clinical interpretation and key diagnoses

    Both models, their tokenizers and their image processors are loaded eagerly in
    ``__init__``; every call to the tool then runs the same image through both models.
    """

    name: str = "chest_xray_report_generator"
    description: str = (
        "A tool that analyzes chest X-ray images and generates comprehensive radiology reports "
        "containing both detailed findings and impression summaries. Input should be the path "
        "to a chest X-ray image file. Output is a structured report with both detailed "
        "observations and key clinical conclusions."
    )
    # Declared as a string field, but __init__ overwrites it with a torch.device instance.
    device: Optional[str] = "cuda"
    args_schema: Type[BaseModel] = ChestXRayInput
    # The attributes below are placeholders until __init__ populates them.
    findings_model: Optional[VisionEncoderDecoderModel] = None
    impression_model: Optional[VisionEncoderDecoderModel] = None
    findings_tokenizer: Optional[BertTokenizer] = None
    impression_tokenizer: Optional[BertTokenizer] = None
    findings_processor: Optional[ViTImageProcessor] = None
    impression_processor: Optional[ViTImageProcessor] = None
    generation_args: Optional[Dict[str, Any]] = None

    def __init__(self, cache_dir: str = "/model-weights", device: Optional[str] = "cuda"):
        """Initialize the ChestXRayReportGeneratorTool with both findings and impression models.

        Args:
            cache_dir (str): Directory passed to ``from_pretrained`` as ``cache_dir``.
            device (Optional[str]): Device string handed to ``torch.device``. When it is
                ``None`` or empty, CUDA is used if available and the CPU otherwise.
        """
        super().__init__()

        # Always store a torch.device so the attribute has one consistent type.
        if device:
            self.device = torch.device(device)
        else:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Initialize findings model
        (
            self.findings_model,
            self.findings_tokenizer,
            self.findings_processor,
        ) = self._load_components("IAMJB/chexpert-mimic-cxr-findings-baseline", cache_dir)

        # Initialize impression model
        (
            self.impression_model,
            self.impression_tokenizer,
            self.impression_processor,
        ) = self._load_components("IAMJB/chexpert-mimic-cxr-impression-baseline", cache_dir)

        # Move models to device
        self.findings_model = self.findings_model.to(self.device)
        self.impression_model = self.impression_model.to(self.device)

        # Default generation arguments
        self.generation_args = {
            "num_return_sequences": 1,
            "max_length": 128,
            "use_cache": True,
            # Was "beam_width", which GenerationConfig does not recognise, so decoding
            # silently fell back to greedy search. "num_beams" enables beam search.
            "num_beams": 2,
        }

    @staticmethod
    def _load_components(
        repo_id: str, cache_dir: str
    ) -> Tuple[VisionEncoderDecoderModel, BertTokenizer, ViTImageProcessor]:
        """Load the model (in eval mode), tokenizer and image processor for one checkpoint.

        Args:
            repo_id (str): Checkpoint identifier passed to ``from_pretrained``.
            cache_dir (str): Directory passed to ``from_pretrained`` as ``cache_dir``.

        Returns:
            Tuple: ``(model, tokenizer, processor)`` loaded from ``repo_id``.
        """
        model = VisionEncoderDecoderModel.from_pretrained(repo_id, cache_dir=cache_dir).eval()
        tokenizer = BertTokenizer.from_pretrained(repo_id, cache_dir=cache_dir)
        processor = ViTImageProcessor.from_pretrained(repo_id, cache_dir=cache_dir)
        return model, tokenizer, processor

    def _process_image(
        self, image_path: str, processor: ViTImageProcessor, model: VisionEncoderDecoderModel
    ) -> torch.Tensor:
        """Process the input image for a specific model.

        Args:
            image_path (str): Path to the input image.
            processor: Image processor for the specific model.
            model: The model to process the image for.

        Returns:
            torch.Tensor: Processed image tensor ready for model input, on ``self.device``.
        """
        # Close the file handle deterministically; convert() returns a detached copy,
        # so the RGB image stays usable after the context manager exits.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
        pixel_values = processor(image, return_tensors="pt").pixel_values

        expected_size = model.config.encoder.image_size
        # The original code treats image_size as a single int (square input); a
        # (height, width) pair is also tolerated here.
        if isinstance(expected_size, int):
            target_size = (expected_size, expected_size)
        else:
            target_size = tuple(expected_size)

        # Compare height AND width: checking only the last dimension would miss a
        # mismatch that affects the height alone.
        if tuple(pixel_values.shape[-2:]) != target_size:
            pixel_values = torch.nn.functional.interpolate(
                pixel_values,
                size=target_size,
                mode="bilinear",
                align_corners=False,
            )

        return pixel_values.to(self.device)

    def _generate_report_section(
        self, pixel_values: torch.Tensor, model: VisionEncoderDecoderModel, tokenizer: BertTokenizer
    ) -> str:
        """Generate a report section using the specified model.

        Args:
            pixel_values: Processed image tensor.
            model: The model to use for generation.
            tokenizer: The tokenizer for the model.

        Returns:
            str: Generated text for the report section.
        """
        # Shared defaults plus the special-token ids of this particular model; decoding
        # starts from the tokenizer's CLS token.
        generation_config = GenerationConfig(
            **{
                **self.generation_args,
                "bos_token_id": model.config.bos_token_id,
                "eos_token_id": model.config.eos_token_id,
                "pad_token_id": model.config.pad_token_id,
                "decoder_start_token_id": tokenizer.cls_token_id,
            }
        )

        generated_ids = model.generate(pixel_values, generation_config=generation_config)

        # num_return_sequences is 1 by default, so only the first decoded string is used.
        return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

    def _run(
        self,
        image_path: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> Tuple[str, Dict]:
        """Generate a comprehensive chest X-ray report containing both findings and impression.

        Any exception raised while reading the image or running the models is caught and
        reported through the return value instead of propagating to the caller.

        Args:
            image_path (str): The path to the chest X-ray image file.
            run_manager (Optional[CallbackManagerForToolRun]): The callback manager (unused).

        Returns:
            Tuple[str, Dict]: A tuple containing the complete report and metadata. On
            failure the first element is an error message and the metadata carries
            ``analysis_status == "failed"`` plus the error text.
        """
        try:
            # Process image for both models (each has its own processor/config)
            findings_pixels = self._process_image(
                image_path, self.findings_processor, self.findings_model
            )
            impression_pixels = self._process_image(
                image_path, self.impression_processor, self.impression_model
            )

            # Generate both sections
            with torch.inference_mode():
                findings_text = self._generate_report_section(
                    findings_pixels, self.findings_model, self.findings_tokenizer
                )
                impression_text = self._generate_report_section(
                    impression_pixels, self.impression_model, self.impression_tokenizer
                )

            # Combine into formatted report
            report = (
                "CHEST X-RAY REPORT\n\n"
                f"FINDINGS:\n{findings_text}\n\n"
                f"IMPRESSION:\n{impression_text}"
            )

            metadata = {
                "image_path": image_path,
                "analysis_status": "completed",
                "sections_generated": ["findings", "impression"],
            }

            return report, metadata

        except Exception as e:
            # Deliberately broad: failures are handed back as data, not raised.
            return f"Error generating report: {str(e)}", {
                "image_path": image_path,
                "analysis_status": "failed",
                "error": str(e),
            }

    async def _arun(
        self,
        image_path: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> Tuple[str, Dict]:
        """Asynchronously generate a comprehensive chest X-ray report.

        NOTE: this delegates directly to the synchronous ``_run``, so inference executes
        on the calling event loop and blocks it until the report is ready. The
        ``run_manager`` argument is not forwarded.

        Args:
            image_path (str): The path to the chest X-ray image file.
            run_manager (Optional[AsyncCallbackManagerForToolRun]): The callback manager (unused).

        Returns:
            Tuple[str, Dict]: Same result as ``_run``.
        """
        return self._run(image_path)
