import re
from collections import defaultdict
from typing import Dict, List, Tuple, Optional
from datatypes import ArchiveData
from tasks.base import BaseTask


def analyze_combined_coverage(
    combined_archive: Dict[str, Dict[str, bool]], top_k: int = 5
) -> Dict:
    """Analyze coverage metrics from a combined archive.

    An example counts as covered when at least one model passed it. Coverage
    is reported twice: once over every model in the archive and once over
    only the ``top_k`` models ranked by their number of passed examples.

    Args:
        combined_archive: Maps model path -> {example_id: passed}.
        top_k: Number of top models to analyze.

    Returns:
        Dictionary with the keys ``"all_models"`` and ``"top_models"``. Each
        value holds ``"example_coverage"`` (example_id -> covered),
        ``"total_examples"``, ``"passed_examples"`` and ``"coverage_ratio"``.
        The ratio is ``0.0`` when the archive contains no examples.
    """
    # Union of every example id seen by any model (computed once).
    all_examples = set()
    for results in combined_archive.values():
        all_examples.update(results.keys())

    # Rank models by how many examples they pass and keep the best top_k.
    model_scores = {
        model_path: sum(1 for passed in results.values() if passed)
        for model_path, results in combined_archive.items()
    }
    ranked = sorted(model_scores.items(), key=lambda item: item[1], reverse=True)
    top_model_paths = {model_path for model_path, _ in ranked[:top_k]}

    # Separate maps per group so the two views never alias each other.
    all_coverage = dict.fromkeys(all_examples, False)
    top_coverage = dict.fromkeys(all_examples, False)
    for model_path, results in combined_archive.items():
        is_top = model_path in top_model_paths
        for example_id, passed in results.items():
            if passed:
                all_coverage[example_id] = True
                if is_top:
                    top_coverage[example_id] = True

    total_examples = len(all_examples)
    coverage_stats = {}
    for category, coverage in (
        ("all_models", all_coverage),
        ("top_models", top_coverage),
    ):
        passed_examples = sum(1 for covered in coverage.values() if covered)
        coverage_stats[category] = {
            "example_coverage": coverage,
            "total_examples": total_examples,
            "passed_examples": passed_examples,
            # Guard the empty-archive case instead of dividing by zero.
            "coverage_ratio": (
                passed_examples / total_examples if total_examples > 0 else 0.0
            ),
        }

    return coverage_stats


def analyze_top_models_coverage(
    archive_map: Dict[str, Dict[Tuple[int], ArchiveData]],
    top_k: int,
    tasks: List[BaseTask],
) -> Dict:
    """Analyze example coverage for top models by overall_fitness.

    Each model's ``skill_vector`` is read as the concatenation of per-task
    results, in the order of ``tasks`` and with each task contributing
    ``len(task.get_example_ids("train"))`` entries. An example is covered
    when any of the ``top_k`` models has a truthy entry for it.

    Args:
        archive_map: Archive map containing task data
        top_k: Number of top models to analyze
        tasks: List of task objects

    Returns:
        Dictionary with ``"example_coverage"`` (``"<task_name>_<example_id>"``
        -> covered), ``"total_examples"``, ``"passed_examples"`` and
        ``"coverage_ratio"`` (``0.0`` when there are no examples).
    """
    # Keep the first archive entry seen for each model path; entries without
    # an overall_fitness cannot be ranked and are ignored.
    all_models = {}
    for task_archive in archive_map.values():
        for archive_data in task_archive.values():
            if archive_data.overall_fitness is None:
                continue
            if archive_data.model_path in all_models:
                continue
            all_models[archive_data.model_path] = (
                archive_data.overall_fitness,
                archive_data.skill_vector,
            )

    # Sort models by overall_fitness (stable, descending) and keep the top k.
    top_models = sorted(
        all_models.values(), key=lambda model: model[0], reverse=True
    )[:top_k]

    # Resolve each task's example ids once instead of once per model.
    task_examples = [
        (task.task_name, task.get_example_ids("train")) for task in tasks
    ]

    example_coverage = {}  # Maps combined key to whether any top model passed it
    total_examples = 0
    for task_name, example_ids in task_examples:
        for example_id in example_ids:
            example_coverage[f"{task_name}_{example_id}"] = False
            total_examples += 1

    for _, skill_vector in top_models:
        if skill_vector is None:
            # No recorded results for this model; it cannot add coverage.
            continue
        example_offset = 0
        for task_name, example_ids in task_examples:
            # Slicing clamps to the vector length, so a short vector simply
            # yields fewer (or zero) results for the trailing tasks.
            task_skills = skill_vector[
                example_offset:example_offset + len(example_ids)
            ]
            for example_id, passed in zip(example_ids, task_skills):
                if passed:
                    example_coverage[f"{task_name}_{example_id}"] = True
            example_offset += len(example_ids)

    passed_examples = sum(1 for covered in example_coverage.values() if covered)
    return {
        "example_coverage": example_coverage,
        "total_examples": total_examples,
        "passed_examples": passed_examples,
        "coverage_ratio": (
            passed_examples / total_examples if total_examples > 0 else 0.0
        ),
    }


def _new_coverage_entry(example_ids):
    """Build a fresh coverage record for one model group.

    A new dict is created on every call so that the "all_models" and
    "top_models" records of a task never share state.
    """
    return {
        "example_coverage": {ex_id: False for ex_id in example_ids},
        "total_examples": len(example_ids),
        "passed_examples": 0,
    }


def _collect_task_coverage(example_ids, solutions, top_solutions, offset):
    """Mark which of ``example_ids`` are passed by any / by a top solution.

    Each solution's results are read from
    ``solution.skill_vector[offset:offset + len(example_ids)]``.
    """
    stats = {
        "all_models": _new_coverage_entry(example_ids),
        "top_models": _new_coverage_entry(example_ids),
    }
    all_coverage = stats["all_models"]["example_coverage"]
    top_coverage = stats["top_models"]["example_coverage"]
    end = offset + len(example_ids)

    for solution in solutions:
        is_top = solution in top_solutions
        for example_id, passed in zip(example_ids, solution.skill_vector[offset:end]):
            if passed:
                all_coverage[example_id] = True
                if is_top:
                    top_coverage[example_id] = True
    return stats


def _finalize_coverage_stats(stats):
    """Fill in passed_examples and coverage_ratio for both model groups."""
    for group in ("all_models", "top_models"):
        entry = stats[group]
        entry["passed_examples"] = sum(
            1 for passed in entry["example_coverage"].values() if passed
        )
        entry["coverage_ratio"] = (
            entry["passed_examples"] / entry["total_examples"]
            if entry["total_examples"] > 0
            else 0.0
        )


def _coverage_metric_items(prefix, name, stats):
    """Flatten finalized stats into ``<prefix>/<name>/<group>/<metric>`` keys."""
    metrics = {}
    for group in ("all_models", "top_models"):
        entry = stats[group]
        base = f"{prefix}/{name}/{group}"
        metrics[f"{base}/passed_ratio"] = entry["coverage_ratio"]
        metrics[f"{base}/passed_examples"] = entry["passed_examples"]
        metrics[f"{base}/total_examples"] = entry["total_examples"]
    return metrics


def compute_coverage_metrics(
    archive_data, tasks, run_dns, len_subset_skill_vector=None, max_samples=None
):
    """Compute coverage metrics for wandb logging based on optimization mode.

    In DNS mode (``run_dns`` truthy) the metrics are derived from
    ``archive_data["dns_archive"]``: per-task coverage, optional coverage of
    the first ``len_subset_skill_vector`` examples, and combined coverage
    across all tasks. "Top models" are the 5 solutions with the highest
    ``fitness``. Otherwise (QD mode) the metrics come from
    ``archive_data["archive_map"]`` using the top-5 and top-20 models by
    overall fitness, plus the best model's overall fitness.

    Args:
        archive_data: Archive data
        tasks: List of task objects
        run_dns: Boolean indicating if DNS is being run
        len_subset_skill_vector: Optional length of subset skill vector
        max_samples: Optional maximum number of samples

    Returns:
        Flat dictionary mapping metric names to values.
    """
    coverage_metrics = {}

    if run_dns:
        solutions = archive_data["dns_archive"]

        # Get top k models by fitness for top model analysis
        top_k = 5
        sorted_models = sorted(solutions, key=lambda x: x.fitness, reverse=True)[
            :top_k
        ]

        # Resolve every task's training example ids once and reuse them below.
        task_examples = [
            (task.task_name, task.get_example_ids("train")) for task in tasks
        ]

        # Per-task coverage over the full skill vector.
        coverage_stats = {}
        example_offset = 0
        for task_name, example_ids in task_examples:
            coverage_stats[task_name] = _collect_task_coverage(
                example_ids, solutions, sorted_models, example_offset
            )
            example_offset += len(example_ids)

        for task_name, stats in coverage_stats.items():
            _finalize_coverage_stats(stats)
            coverage_metrics.update(
                _coverage_metric_items("example_coverage", task_name, stats)
            )

        # If adaptive tasks are enabled, analyze subset coverage
        if len_subset_skill_vector is not None:
            subset_coverage_stats = {}
            example_offset = 0
            example_count = 0

            for task_name, example_ids in task_examples:
                remaining_examples = len_subset_skill_vector - example_count
                if remaining_examples <= 0:
                    continue

                subset_ids = example_ids[: min(len(example_ids), remaining_examples)]
                subset_coverage_stats[task_name] = _collect_task_coverage(
                    subset_ids, solutions, sorted_models, example_offset
                )

                example_offset += len(subset_ids)
                example_count += len(example_ids)

            for task_name, stats in subset_coverage_stats.items():
                _finalize_coverage_stats(stats)
                coverage_metrics.update(
                    _coverage_metric_items("subset_example_coverage", task_name, stats)
                )

            coverage_metrics.update(
                {
                    "subset_example_coverage/subset_size": len_subset_skill_vector,
                    "subset_example_coverage/max_size": max_samples,
                }
            )

        # Create and analyze combined archive coverage
        combined_coverage = {solution.model_path: {} for solution in solutions}

        example_offset = 0
        for task_name, example_ids in task_examples:
            end = example_offset + len(example_ids)
            for solution in solutions:
                model_results = combined_coverage[solution.model_path]
                task_skills = solution.skill_vector[example_offset:end]
                for example_id, passed in zip(example_ids, task_skills):
                    model_results[f"{task_name}_{example_id}"] = passed
            example_offset = end

        combined_coverage_stats = analyze_combined_coverage(combined_coverage)
        coverage_metrics.update(
            _coverage_metric_items(
                "example_coverage", "combined", combined_coverage_stats
            )
        )

    else:
        # QD mode coverage metrics
        archive_map = archive_data["archive_map"]
        top5_coverage = analyze_top_models_coverage(archive_map, 5, tasks)
        top20_coverage = analyze_top_models_coverage(archive_map, 20, tasks)

        # In QD mode the "all_models" keys report the top-20 models.
        coverage_metrics = {
            "example_coverage/combined/top_models/passed_ratio": top5_coverage[
                "coverage_ratio"
            ],
            "example_coverage/combined/top_models/passed_examples": top5_coverage[
                "passed_examples"
            ],
            "example_coverage/combined/top_models/total_examples": top5_coverage[
                "total_examples"
            ],
            "example_coverage/combined/all_models/passed_ratio": top20_coverage[
                "coverage_ratio"
            ],
            "example_coverage/combined/all_models/passed_examples": top20_coverage[
                "passed_examples"
            ],
            "example_coverage/combined/all_models/total_examples": top20_coverage[
                "total_examples"
            ],
        }

        # Add best overall fitness model metrics
        best_model_path = find_highest_overall_fitness_model(archive_map)
        if best_model_path:
            best_model_fitness = 0.0
            for task_archive in archive_map.values():
                # Distinct loop name: do not shadow the archive_data parameter.
                for entry in task_archive.values():
                    if (
                        entry.model_path == best_model_path
                        and entry.overall_fitness is not None
                    ):
                        best_model_fitness = entry.overall_fitness
                        break
            coverage_metrics["overall_fitness/best_model"] = best_model_fitness

    return coverage_metrics


def compute_overall_fitness(skill_vector: List[bool]) -> float:
    """Return the fraction of truthy entries in a skill vector.

    Args:
        skill_vector: Binary skill vector representing task success/failure

    Returns:
        ``sum(skill_vector) / len(skill_vector)``, or ``0.0`` when the vector
        is empty or otherwise falsy.
    """
    return sum(skill_vector) / len(skill_vector) if skill_vector else 0.0


def extract_max_quality_map(data):
    """Extract the maximum quality data from the archive map.

    Args:
        data: Maps each key to a dict of entries; every entry exposes a
            ``quality`` attribute.

    Returns:
        Dictionary with the same keys as ``data``. Each value is the subset
        of that key's entries whose ``quality`` equals the maximum quality
        for the key, in their original order. A key with no entries maps to
        an empty dict.
    """
    max_quality_data = {}
    for key, values in data.items():
        if not values:
            # Nothing to rank; keep the key so callers see every task.
            max_quality_data[key] = {}
            continue
        max_quality = max(entry.quality for entry in values.values())
        max_quality_data[key] = {
            k: v for k, v in values.items() if v.quality == max_quality
        }
    return max_quality_data


def get_elite_values(data, data_split):
    """Return one elite value per task for the requested data split.

    For ``'all'`` and ``'train'`` the value is the highest ``quality`` found
    among a task's archive entries. For ``'validation'`` the value is the
    ``validation_quality`` of the max-quality entry with the lowest
    generation number, where the generation is parsed from ``gen_<N>`` in
    the entry's ``model_path`` (entries without a match rank last).

    Args:
        data: Archive map data
        data_split: Data split to use ('all', 'train', or 'validation')

    Returns:
        Dictionary mapping task names to their elite values. For 'all' and
        'train' this is a defaultdict whose missing keys read as ``-inf``.

    Raises:
        NotImplementedError: If ``data_split`` is not a supported split.
    """
    if data_split in ("all", "train"):
        best_train_quality = defaultdict(lambda: float("-inf"))
        for task_name, cells in data.items():
            for cell in cells.values():
                if cell.quality > best_train_quality[task_name]:
                    best_train_quality[task_name] = cell.quality
        return best_train_quality

    if data_split != "validation":
        raise NotImplementedError(f"Data split {data_split} not implemented")

    generation_pattern = re.compile(r"gen_(\d+)")
    earliest_elites = {}
    for task_name, elites in extract_max_quality_map(data).items():
        for cell_key, elite in elites.items():
            # Extract generation number; unmatched paths sort after all others
            found = generation_pattern.search(elite.model_path)
            generation = int(found.group(1)) if found else 100000000

            incumbent = earliest_elites.get(task_name)
            if incumbent is None or generation < incumbent["gen_number"]:
                earliest_elites[task_name] = {
                    "key": cell_key,
                    "model_path": elite.model_path,
                    "gen_number": generation,
                    "validation_quality": elite.validation_quality,
                }

    return {
        task_name: record["validation_quality"]
        for task_name, record in earliest_elites.items()
    }


def find_highest_overall_fitness_model(
    archive_map: Dict[str, Dict[Tuple[int], ArchiveData]]
) -> Optional[str]:
    """Return the path of the model with the highest overall fitness.

    Every entry of every task archive is scanned; entries whose
    ``overall_fitness`` is ``None`` are skipped. On ties the entry
    encountered first is kept.

    Args:
        archive_map: Archive map containing all task data

    Returns:
        Path to the model with highest overall fitness, or None if no models found
    """
    every_entry = (
        entry
        for per_task_archive in archive_map.values()
        for entry in per_task_archive.values()
    )

    leader_path = None
    leader_score = float("-inf")
    for entry in every_entry:
        score = entry.overall_fitness
        if score is None or not score > leader_score:
            continue
        leader_score = score
        leader_path = entry.model_path

    return leader_path


def compute_acd_coverage_metrics(
    archive_data, tasks, cfg=None, threshold: float = 0.5, validation_tasks=None
) -> Dict:
    """Compute coverage metrics for wandb logging for ACD task vectors.

    A task counts as passed by a model when its score in the model's
    ``acd_skill_vector`` is at least ``threshold``; tasks missing from a
    model's vector score ``0.0``. Solutions that are falsy, or that have no
    (or an empty) ``acd_skill_vector``, are ignored for coverage.

    Args:
        archive_data: Archive data containing dns_archive with ACD solutions
        tasks: List of task objects (not used by this function)
        cfg: Optional configuration object (for backward compatibility)
        threshold: Score threshold to consider a task passed (default: 0.5)
        validation_tasks: Optional list of validation task names

    Returns:
        Coverage metrics dictionary for wandb logging. Empty when there is no
        archive or no solution carries any ACD task score.
    """
    coverage_metrics = {}

    # Check if we have a valid archive
    if (
        not archive_data
        or "dns_archive" not in archive_data
        or not archive_data["dns_archive"]
    ):
        return coverage_metrics
    dns_archive = archive_data["dns_archive"]

    # Select usable solutions once so both passes below apply the same
    # filter; getattr tolerates solutions that lack the attribute entirely.
    scored_solutions = [
        solution
        for solution in dns_archive
        if solution and getattr(solution, "acd_skill_vector", None)
    ]

    # Get all task_ids from all models to determine total task coverage
    all_task_ids = set()
    for solution in scored_solutions:
        all_task_ids.update(solution.acd_skill_vector.keys())

    # If no task IDs found, return empty metrics
    if not all_task_ids:
        return coverage_metrics

    # Boolean pass/fail per model and task, based on threshold
    combined_coverage = {
        solution.model_path: {
            task_id: solution.acd_skill_vector.get(task_id, 0.0) >= threshold
            for task_id in all_task_ids
        }
        for solution in scored_solutions
    }

    # Analyze the combined coverage with existing function
    top_k = 5
    combined_coverage_stats = analyze_combined_coverage(combined_coverage, top_k)

    for group in ("all_models", "top_models"):
        group_stats = combined_coverage_stats[group]
        base = f"acd_coverage/combined/{group}"
        coverage_metrics[f"{base}/passed_ratio"] = group_stats["coverage_ratio"]
        coverage_metrics[f"{base}/passed_examples"] = group_stats["passed_examples"]
        coverage_metrics[f"{base}/total_examples"] = group_stats["total_examples"]

    # If validation tasks are present, include validation metrics for top models
    if validation_tasks or (cfg and getattr(cfg, "validation_tasks", None)):
        # Get the top 5 solutions for validation quality analysis
        valid_solutions = [s for s in dns_archive if s and hasattr(s, "fitness")]
        sorted_solutions = sorted(
            valid_solutions, key=lambda x: x.fitness, reverse=True
        )[:5]

        # Track validation qualities for each top model; the rank in the key
        # is the model's position in the fitness ordering.
        validation_models = []
        for idx, solution in enumerate(sorted_solutions):
            if getattr(solution, "validation_quality", None) is not None:
                coverage_metrics[f"validation/top{idx+1}_model/quality"] = (
                    solution.validation_quality
                )
                validation_models.append(solution)

        # If any models have validation_quality, find the best one
        if validation_models:
            best_validation_model = max(
                validation_models, key=lambda x: x.validation_quality
            )
            coverage_metrics["validation/best_validation_model/quality"] = (
                best_validation_model.validation_quality
            )

    return coverage_metrics
