import logging
import shutil
import tempfile
import uuid
from pathlib import Path
from typing import Callable, Dict, List, Optional, Union

from ..llm.prompt_builder import PromptBuilder
from .agents_ensemble import AgentsEnsemble
from .round_n import run_debate_round_n
from .round_zero import run_debate_round_zero

# Module-level logger named after this module; used by the functions below.
logger = logging.getLogger(__name__)


def debate(
    max_rounds: int,
    prompt_builder: PromptBuilder,
    agents_ensemble: AgentsEnsemble,
    extract_func: Callable,
    output_dir: Optional[Union[str, Path]] = None,
    json_mode: bool = False,
    max_retries: int = 3,
    temperature: float = 1.0,
    max_tokens: int = 6400,
    batch: bool = False,
    batch_size: int = 11,
    quality_pruning_func: Optional[Callable] = None,
    quality_pruning_amount: int = 5,
    diversity_pruning_func: Optional[Callable] = None,
    diversity_pruning_amount: int = 5,
) -> List[List[dict]]:
    """Run a full multi-round debate between the agents of an ensemble.

    Round 0 uses the prompt from ``prompt_builder.build_round_zero()``. Every
    later round first checks the previous round's responses for convergence
    (ending the debate early when all extracted answers agree), optionally
    prunes those responses, and then builds its prompt with
    ``prompt_builder.build_round_n()``.

    All files produced while the debate runs are written to a temporary
    directory created inside ``output_dir``. Its contents are copied into
    ``output_dir`` only after the last round has completed, and the temporary
    directory is always removed afterwards, so a debate that fails before the
    copy step leaves none of its files in ``output_dir``.

    Args:
        max_rounds: Maximum number of debate rounds to run.
        prompt_builder: PromptBuilder instance to generate prompts for each
            round. Its ``images`` attribute is forwarded to every round and
            its ``query`` attribute is passed to ``quality_pruning_func``.
        agents_ensemble: Collection of LLM agents participating in the debate.
        extract_func: Function to process answers from responses. Required.
        output_dir: Directory path where debate responses will be saved.
            Defaults to None, in which case the system temporary directory
            (``tempfile.gettempdir()``) is used.
        json_mode: Whether to use JSON mode for responses.
        max_retries: Maximum retry attempts for each round. Defaults to 3.
        temperature: Sampling temperature for the model. Defaults to 1.0.
        max_tokens: Maximum number of tokens in the response. Defaults to 6400.
        batch: Whether to run in batch mode.
        batch_size: Size of the batch.
        quality_pruning_func: Optional function for quality pruning. Called
            with the keyword arguments ``responses``, ``input``,
            ``selected_amount``, ``round_number`` and ``output_dir``.
        quality_pruning_amount: Value passed to ``quality_pruning_func`` as
            ``selected_amount``. Defaults to 5.
        diversity_pruning_func: Optional function for diversity pruning.
            Called with the keyword arguments ``responses``,
            ``selected_amount``, ``extract_func``, ``round_number`` and
            ``output_dir``. When quality pruning is also configured, it
            receives the quality-pruned responses.
        diversity_pruning_amount: Value passed to ``diversity_pruning_func``
            as ``selected_amount``. Defaults to 5.

    Returns:
        List[List[dict]]: List of responses from each round, where each round's
            responses is a list of dictionaries containing agent responses.

    Raises:
        ValueError: If ``extract_func`` is None.
        Exception: If any error occurs during the debate process.
            Original exception is logged and re-raised.
    """
    # extract_func has no fallback: fail fast with a clear error.
    if extract_func is None:
        logger.error("No extract_func function provided")
        raise ValueError("extract_func function must be provided")

    logger.info(f"Starting debate with max_rounds={max_rounds}, json_mode={json_mode}")
    logger.info(f"Using agents ensemble: {agents_ensemble}")

    if output_dir is None:
        output_dir = Path(tempfile.gettempdir())
    else:
        output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Stage intermediate files in a private directory inside output_dir so
    # that nothing is published unless the whole debate succeeds.
    temp_dir = Path(
        tempfile.mkdtemp(prefix=f"debate_temp_{uuid.uuid4().hex}_", dir=output_dir)
    )
    logger.debug(f"Created temporary directory for debate: {temp_dir}")

    all_responses = []
    images = prompt_builder.images  # This can contain paths, URLs, or byte data

    try:
        for i in range(max_rounds):
            if i == 0:
                logger.info("Running round 0 (initial statements)")
                prompt = prompt_builder.build_round_zero()
            else:
                extracted_responses = [
                    response["response"] for response in all_responses[-1]
                ]
                logger.info(
                    f"Running debate round {i} with {len(extracted_responses)} previous responses"
                )
                logger.debug(
                    f"Extracted responses for round {i}: {extracted_responses}"
                )
                # Errors raised here are logged by check_convergence itself
                # and again by the handler at the bottom of this function.
                if check_convergence(extracted_responses, extract_func):
                    logger.info(
                        f"Convergence detected after round {i-1}, ending debate early"
                    )
                    break

                pruned_responses = extracted_responses

                # Apply quality pruning if specified
                if quality_pruning_func:
                    logger.info(
                        f"Applying quality pruning for round {i} with amount={quality_pruning_amount}"
                    )
                    pruned_dir = temp_dir / "quality_pruned"
                    pruned_dir.mkdir(parents=True, exist_ok=True)

                    pruned_responses = quality_pruning_func(
                        responses=pruned_responses,
                        input=prompt_builder.query,
                        selected_amount=quality_pruning_amount,
                        round_number=i,
                        output_dir=pruned_dir,
                    )

                # Apply diversity pruning if specified. It works on the
                # output of the previous pruning stage so that configuring
                # both stages does not silently discard the quality pruning.
                if diversity_pruning_func:
                    logger.info(
                        f"Applying diversity pruning for round {i} with amount={diversity_pruning_amount}"
                    )
                    pruned_dir = temp_dir / "diversity_pruned"
                    pruned_dir.mkdir(parents=True, exist_ok=True)

                    pruned_responses = diversity_pruning_func(
                        responses=pruned_responses,
                        selected_amount=diversity_pruning_amount,
                        extract_func=extract_func,
                        round_number=i,
                        output_dir=pruned_dir,
                    )

                # Build the next prompt from the (possibly pruned) responses
                prompt = prompt_builder.build_round_n(pruned_responses)

            logger.debug(f"Round {i} prompt built: {prompt[:100]}...")
            round_responses = run_debate_with_retry(
                max_rounds=max_rounds,
                prompt=prompt,
                images=images,
                agents_ensemble=agents_ensemble,
                output_dir=temp_dir,
                round_num=i,
                extract_func=extract_func,
                json_mode=json_mode,
                max_retries=max_retries,
                temperature=temperature,
                max_tokens=max_tokens,
                batch=batch,
                batch_size=batch_size,
            )
            all_responses.append(round_responses)
            logger.info(
                f"Completed debate round {i} with {len(round_responses)} agent responses"
            )

        # Count regular files only, so the log does not report directories
        # as saved files.
        file_count = sum(1 for path in temp_dir.glob("**/*") if path.is_file())
        logger.info(
            f"Debate completed successfully after {len(all_responses)} rounds, saving {file_count} files"
        )
        _save_debate_files(temp_dir, output_dir)

        return all_responses
    except Exception as e:
        logger.error(f"Error during debate: {str(e)}", exc_info=True)
        raise
    finally:
        # Clean up the temporary directory and its contents
        if temp_dir.exists():
            logger.debug(f"Cleaning up temporary directory: {temp_dir}")
            shutil.rmtree(temp_dir)


def _save_debate_files(temp_dir: Path, output_dir: Path) -> None:
    """Copy every top-level entry of ``temp_dir`` into ``output_dir``.

    Directories are copied as whole trees and replace any existing entry of
    the same name in ``output_dir``; plain files are copied with ``copy2``.
    """
    for source_path in temp_dir.glob("*"):
        target_path = output_dir / source_path.name

        if source_path.is_dir():
            # copytree requires that the destination does not exist yet.
            if target_path.is_dir():
                shutil.rmtree(target_path)
            elif target_path.exists():
                target_path.unlink()
            shutil.copytree(source_path, target_path)
            logger.debug(f"Saved debate directory: {target_path}")
        else:
            shutil.copy2(source_path, target_path)
            logger.debug(f"Saved debate file: {target_path}")


def run_debate_with_retry(
    max_rounds: int,
    prompt: str,
    agents_ensemble: AgentsEnsemble,
    output_dir: Union[str, Path],
    round_num: int,
    extract_func: Callable,
    images: Union[str, Path, List[str], List[Path], bytes, List[bytes], None] = None,
    json_mode: bool = False,
    batch: bool = False,
    batch_size: int = 11,
    max_retries: int = 3,
    temperature: float = 1.0,
    max_tokens: int = 6400,
) -> List[Dict]:
    """Execute a single debate round, re-running it when an attempt fails.

    An attempt fails when the round function raises, or when ``extract_func``
    raises for the ``"response"`` entry of any returned item. Failed attempts
    are repeated until ``max_retries`` attempts have been made.

    Args:
        max_rounds: Maximum number of debate rounds; used to validate
            ``round_num``.
        prompt: The debate prompt including previous context.
        agents_ensemble: Collection of LLM agents for the debate.
        output_dir: Directory path for saving debate responses. Created if it
            does not exist.
        round_num: The current round number. Round 0 is run with
            ``run_debate_round_zero``, any other round with
            ``run_debate_round_n``.
        extract_func: Function applied to each response to validate it.
        images: Optional images for the debate round. Can be file paths, URLs,
            or raw image bytes.
        json_mode: Whether to expect JSON responses from agents.
        batch: Whether to run in batch mode.
        batch_size: Size of the batch.
        max_retries: Maximum number of attempts for the round.
        temperature: Sampling temperature for the model.
        max_tokens: Maximum number of tokens in the response.

    Returns:
        List[Dict]: List of agent responses from the first successful attempt.

    Raises:
        RuntimeError: If the final attempt fails; chained to the last error.
        ValueError: If ``round_num`` is out of range, ``max_retries`` is below
            1, ``output_dir`` is not a directory, or ``extract_func`` is None.
    """
    # --- argument validation -------------------------------------------
    if round_num >= max_rounds or round_num < 0:
        logger.error(
            f"Invalid round number: {round_num}. Must be between 0 and {max_rounds - 1}."
        )
        raise ValueError(f"Round number must be between 0 and {max_rounds - 1}.")
    if max_retries < 1:
        logger.error("max_retries must be at least 1")
        raise ValueError("max_retries must be at least 1")

    # --- output directory ----------------------------------------------
    target_dir = Path(output_dir)
    if target_dir.exists():
        logger.debug(f"Output directory already exists: {target_dir}")
    else:
        logger.debug(f"Creating output directory: {target_dir}")
        target_dir.mkdir(parents=True, exist_ok=True)

    if not target_dir.is_dir():
        logger.error(f"Output path {target_dir} is not a directory.")
        raise ValueError(f"Output path {target_dir} must be a directory.")

    if extract_func is None:
        logger.error("No extract_func function provided for debate round")
        raise ValueError("extract_func function must be provided")

    logger.info(f"Starting debate round {round_num} with max_retries={max_retries}")

    # Keyword arguments common to both round functions.
    round_kwargs = {
        "prompt": prompt,
        "images": images,
        "agents_ensemble": agents_ensemble,
        "output_dir": target_dir,
        "json_mode": json_mode,
        "temperature": temperature,
        "max_tokens": max_tokens,
        "batch": batch,
        "batch_size": batch_size,
    }

    for attempt_no in range(1, max_retries + 1):
        try:
            # Round 0 has its own entry point; later rounds need round_num.
            if round_num == 0:
                round_output = run_debate_round_zero(**round_kwargs)
            else:
                round_output = run_debate_round_n(round_num=round_num, **round_kwargs)

            # Every response must be parseable, otherwise the attempt fails.
            try:
                for item in round_output:
                    extract_func(item["response"])
            except Exception as parse_err:
                logger.warning(
                    f"Error processing response with extract_func: {parse_err!s}"
                )
                raise  # Handled by the outer handler, which drives the retry

            logger.info(
                f"Debate round {round_num} completed successfully on attempt {attempt_no}"
            )
            return round_output

        except Exception as err:
            if attempt_no >= max_retries:
                logger.error(
                    f"Maximum retries ({max_retries}) exceeded for debate round {round_num}: {err!s}"
                )
                raise RuntimeError(
                    f"Failed to complete debate round {round_num} after {max_retries} attempts"
                ) from err
            logger.warning(
                f"Error in debate round {round_num}, attempt {attempt_no}/{max_retries}: {err!s}. "
            )


def check_convergence(
    responses: List[Dict], extract_func: Optional[Callable] = None
) -> bool:
    """Check if the responses from all agents have converged to the same answer.

    Each item of ``responses`` is passed unchanged to ``extract_func`` and the
    extracted answers are compared by equality, so unhashable answers (for
    example dicts or nested lists) are supported. An answer that is a list is
    converted to a tuple first, so a list and a tuple holding equal items
    count as the same answer.

    Args:
        responses: Agent responses from the most recent round of debate.
        extract_func: Function to process answers from responses. Required:
            there is no fallback, and passing None raises ``ValueError``.

    Returns:
        bool: True if there is at least one response and all extracted
            answers are equal, False otherwise (including for an empty list).

    Raises:
        ValueError: If ``extract_func`` is None.
        Exception: Any error raised while extracting or comparing answers is
            logged and re-raised.
    """
    # extract_func has no fallback: fail fast with a clear error.
    if extract_func is None:
        logger.error("No extract_func function provided for convergence check")
        raise ValueError("extract_func function must be provided")

    try:
        answers = [extract_func(response) for response in responses]
        logger.debug(f"Processed answers for convergence check: {answers}")
        # Normalize list answers to tuples so [1, 2] and (1, 2) compare equal.
        normalized_answers = [
            tuple(ans) if isinstance(ans, list) else ans for ans in answers
        ]
        # Compare against the first answer instead of building a set, so that
        # unhashable answers do not raise TypeError.
        is_converged = bool(normalized_answers) and all(
            ans == normalized_answers[0] for ans in normalized_answers[1:]
        )
        if is_converged:
            logger.info(f"Debate has converged on answer: {normalized_answers[0]}")
        return is_converged
    except Exception as e:
        logger.error(f"Error checking convergence: {str(e)}", exc_info=False)
        raise
