"""
Tamper Detection Task Handler

This module provides a task handler for the tamper detection task.
"""

import copy
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Dict, List, Tuple

import pandas as pd
from tqdm import tqdm

from src.core.registry import task_registry
from src.llm.base import LLMInterface
from src.tasks.base import TaskInterface
from src.tasks.tamper_detection.data_handler import TamperDetectionDataHandler
from src.tasks.tamper_detection.eval_handler import TamperDetectionEvaluator
from src.utils.decorator_utils import with_logger
from src.utils.error_tracking import ErrorTracker


@task_registry.register("tamper_detection")
class TamperDetectionTask(TaskInterface):
    """
    Task handler for the tamper detection task.

    Each data item is a pair of images (an original and a tampered version)
    obtained from ``TamperDetectionDataHandler.get_data``. Both images are
    sent to the LLM with the same prompt template, and the pair of responses
    is scored by ``TamperDetectionEvaluator``.
    """

    @with_logger
    def __init__(
        self,
        font_semantics: str,
        prompt_msg_template: List[Dict[str, Any]],
        num_images: int = 25,
        **kwargs: Any,
    ):
        """
        Initialise the tamper detection task handler.

        Args:
            font_semantics: Whether to use "font" or "semantics" data
            prompt_msg_template: The prompt message template to use. It is
                deep-copied, so later mutation by the caller does not leak in.
            num_images: The maximum number of image pairs to process
            **kwargs: Additional keyword arguments (currently unused)
        """
        logger.info(
            f"Initialising TamperDetectionTask with font_semantics={font_semantics}, num_images={num_images}"
        )
        self.font_semantics = font_semantics
        self.num_images = num_images

        logger.info("Using provided prompt template")
        # Deep copy for consistency with update_prompt_msg_template(): the task
        # owns its template, so callers mutating theirs must not affect it.
        self.prompt_msg_template = copy.deepcopy(prompt_msg_template)

        # initialise the data handler and evaluator
        logger.info("Initialising data handler and evaluator")
        self.data_handler = TamperDetectionDataHandler(font_semantics)
        self.eval_handler = TamperDetectionEvaluator()
        logger.info("TamperDetectionTask initialisation complete")

    @with_logger
    def load_data(self, **kwargs: Any) -> List[int]:
        """
        Load task-specific data.

        The dataset is capped at ``self.num_images`` items.

        Args:
            **kwargs: Additional keyword arguments (currently unused)

        Returns:
            A list of data IDs: ``0 .. min(num_images, dataset_size) - 1``
        """
        logger.info("Loading task data")
        dataset_size = self.data_handler.get_size()
        num_images_to_use = min(self.num_images, dataset_size)
        logger.info(f"Using {num_images_to_use} images out of {dataset_size} available")

        # Return a list of data IDs up to num_images
        data_ids = list(range(num_images_to_use))
        logger.info(f"Data IDs: {data_ids}")
        return data_ids

    @with_logger
    def create_prompt(
        self,
        data_item: Tuple[str, str],
        **kwargs: Any,
    ) -> List[Dict[str, Any]]:
        """
        Create a prompt for a specific data item.

        The prompt template is deep-copied. Every ``image_url`` part in the
        content list of each ``user`` message is then set to a base64 data
        URL for the image. User messages whose content is not a list (e.g. a
        plain text string) are left untouched.

        Args:
            data_item: The data item as ``(image_base64, image_type)``. The
                type is used as the subtype in ``data:image/<type>;base64,...``.
            **kwargs: Additional keyword arguments (currently unused)

        Returns:
            A list of message dictionaries
        """
        logger.debug("Creating prompt for data item")
        img_base64, img_type = data_item
        data_url = f"data:image/{img_type};base64,{img_base64}"

        # Deep copy so the shared template is never mutated.
        prompt_msgs = copy.deepcopy(self.prompt_msg_template)

        for msg in prompt_msgs:
            if msg.get("role") != "user":
                continue
            parts = msg.get("content")
            # Text-only user messages carry a plain string. Iterating one
            # would yield characters and crash on ``content["type"]``.
            if not isinstance(parts, list):
                continue
            for content in parts:
                if isinstance(content, dict) and content.get("type") == "image_url":
                    content.setdefault("image_url", {})["url"] = data_url
                    logger.debug(f"Injected image of type {img_type} into prompt")

        return prompt_msgs

    @with_logger
    def run(
        self,
        llm: LLMInterface,
        **kwargs: Any,
    ) -> Tuple[pd.DataFrame, float, ErrorTracker]:
        """
        Run the task with the given LLM.

        Every (original, tampered) image pair returned by ``load_data`` is
        sent to the LLM concurrently in a thread pool. The responses are
        collected into a dataframe, which is then scored with ``evaluate``.

        Args:
            llm: The language model to use
            **kwargs: Additional keyword arguments (currently unused)

        Returns:
            A tuple of ``(results_dataframe, evaluation_score, error_tracker)``:

            * The dataframe has the columns ``img_id``, ``response_original``
              and ``response_tampered``, in the same order as the loaded IDs.
            * The error tracker is a newly constructed ``ErrorTracker()``.
              Nothing is recorded into it here.

        Raises:
            Any exception raised while processing an image pair (data
            retrieval or an LLM call) propagates out of ``future.result()``.
        """
        logger.info(f"Running tamper detection task with LLM: {llm.__class__.__name__}")
        logger.info(self.prompt_msg_template)

        # Load data IDs
        data_ids = self.load_data()

        # initialise result lists (filled by row position, see below)
        response_original_lst: List[Any] = [None] * len(data_ids)
        response_tampered_lst: List[Any] = [None] * len(data_ids)

        # Define a function to process a single image pair
        @with_logger
        def process_image(img_id: int) -> Tuple[int, str, str]:
            # Get the original and tampered images and their types
            logger.debug(f"Processing image pair with ID: {img_id}")
            img_original, type_original, img_tampered, type_tampered = (
                self.data_handler.get_data(img_id)
            )

            # Create prompts for original and tampered images
            logger.debug(f"Creating prompts for image ID: {img_id}")
            original_prompt = self.create_prompt((img_original, type_original))
            tampered_prompt = self.create_prompt((img_tampered, type_tampered))

            # Generate responses
            logger.debug(f"Generating LLM response for original image {img_id}")
            response_original = llm.generate(original_prompt)
            logger.debug(f"Generating LLM response for tampered image {img_id}")
            response_tampered = llm.generate(tampered_prompt)

            # Clean up responses: drop every "." and surrounding whitespace so
            # answers like "Yes." compare equal to "Yes".
            response_original = response_original.replace(".", "").strip()
            response_tampered = response_tampered.replace(".", "").strip()

            logger.debug(
                f"Image {img_id} processed: original='{response_original}', tampered='{response_tampered}'"
            )
            return img_id, response_original, response_tampered

        # Use ThreadPoolExecutor for parallelization
        cpu_count = os.cpu_count()
        max_workers = min(32, cpu_count * 5) if cpu_count else 32
        logger.info(f"Using ThreadPoolExecutor with {max_workers} workers")
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            # Submit all tasks, remembering each future's row position so
            # results land in the right slot even if IDs are not 0..n-1.
            logger.info(f"Submitting {len(data_ids)} image pairs for processing")
            future_to_row = {
                executor.submit(process_image, img_id): row
                for row, img_id in enumerate(data_ids)
            }

            # Process results as they complete
            with tqdm(total=len(data_ids) * 2, desc="Processing Images") as pbar:
                for future in as_completed(future_to_row):
                    img_id, response_original, response_tampered = future.result()
                    row = future_to_row[future]
                    response_original_lst[row] = response_original
                    response_tampered_lst[row] = response_tampered
                    pbar.update(2)  # Update by 2 since we processed 2 images
                    logger.debug(f"Completed processing image ID: {img_id}")

        # Create a results dataframe
        logger.info("Creating results dataframe")
        result_df = pd.DataFrame(
            {
                "img_id": data_ids,
                "response_original": response_original_lst,
                "response_tampered": response_tampered_lst,
            }
        )

        # Log response distribution
        yes_original = sum(r == "Yes" for r in response_original_lst)
        no_original = sum(r == "No" for r in response_original_lst)
        yes_tampered = sum(r == "Yes" for r in response_tampered_lst)
        no_tampered = sum(r == "No" for r in response_tampered_lst)
        logger.info(
            f"Response distribution - Original: Yes={yes_original}, No={no_original} | Tampered: Yes={yes_tampered}, No={no_tampered}"
        )

        # Evaluate the results
        logger.info("Evaluating results")
        eval_score = self.evaluate(result_df)
        logger.info(f"Evaluation complete. Score: {eval_score:.4f}")

        return result_df, eval_score, ErrorTracker()

    @with_logger
    def evaluate(
        self,
        results: pd.DataFrame,
        **kwargs: Any,
    ) -> float:
        """
        Evaluate the results of the task.

        Args:
            results: The results dataframe produced by ``run``
            **kwargs: Additional keyword arguments (currently unused)

        Returns:
            The evaluation score computed by ``TamperDetectionEvaluator``
        """
        logger.info(f"Evaluating results dataframe with {len(results)} entries")
        # Scoring is fully delegated to the evaluator; its exact metric is
        # defined there, not in this class.
        score = self.eval_handler.get_eval_score(results)
        logger.info(f"Evaluation score: {score:.4f}")
        return score

    @with_logger
    def get_prompt_msg_template(self) -> List[Dict[str, Any]]:
        """
        Get the prompt message template.

        Returns:
            A deep copy of the prompt message template, so callers cannot
            mutate the task's internal state.
        """
        logger.info("Getting prompt message template")
        return copy.deepcopy(self.prompt_msg_template)

    @with_logger
    def update_prompt_msg_template(
        self, new_prompt_msg_template: List[Dict[str, Any]]
    ) -> None:
        """
        Update the prompt message template.

        The new template is deep-copied. Only the first 50 items or characters
        of the first message's content are logged.

        Args:
            new_prompt_msg_template: The new prompt message template
        """
        logger.info("Updating prompt message template")
        if new_prompt_msg_template:
            # .get(...) or "": tolerate a first message with no/None content
            # instead of failing on a purely diagnostic log line.
            first_content = new_prompt_msg_template[0].get("content") or ""
            logger.info(f"New system prompt: {first_content[:50]}...")
        else:
            logger.warning("New prompt message template is empty")
        self.prompt_msg_template = copy.deepcopy(new_prompt_msg_template)
