from __future__ import annotations

from dataclasses import dataclass
from typing import Dict, List, Tuple, Optional, Literal
import numpy as np

from ..config import HeuPSROConfig
from .metrics import (
    calculate_code_similarity,
    calculate_behavioral_similarity_solver,
    calculate_behavioral_similarity_generator,
    calculate_combined_similarity,
    config_to_string
)


@dataclass
class DecisionResult:
    """Outcome of a pool-diversity check.

    Built by ``PoolDiversityManager`` to say whether a candidate should be
    added to its pool and, when rejected, why.
    """

    # Either "ADD" or "REJECT".
    action: Literal["ADD", "REJECT"]
    # Index of the most similar existing pool member, when one was identified.
    similar_idx: Optional[int] = None
    # Similarity score that drove the decision, when one was computed.
    similarity: Optional[float] = None
    # Short human-readable explanation of the decision.
    reason: Optional[str] = None


class PoolDiversityManager:
    """Decide whether a candidate solver/generator is diverse enough to join its pool.

    For every existing pool member a single combined similarity is computed
    with ``calculate_combined_similarity`` from a textual similarity and a
    behavioural similarity, weighted by ``cfg.psro_diversity_weight_code`` and
    ``cfg.psro_diversity_weight_behavior``. A candidate is rejected when its
    highest combined similarity exceeds the configured threshold, or when the
    pool has already reached its configured maximum size.
    """

    def __init__(self, cfg: HeuPSROConfig, pools, meta):
        # cfg: supplies the psro_* weights, thresholds and max pool sizes.
        # pools: exposes n_solvers / n_generators and solver_pool / generator_pool.
        # meta: exposes a 2-D ``utilities`` array; the helpers below index its
        #   rows by solver and its columns by generator.
        self.cfg = cfg
        self.pools = pools
        self.meta = meta
        # Decision counters. Not updated by any method in this class;
        # presumably maintained by callers — verify.
        self.stats = {"added": 0, "replaced": 0, "rejected": 0}

    def _combine(self, text_sim: float, behavior_sim: float) -> float:
        """Weight a textual and a behavioural similarity into one score."""
        return calculate_combined_similarity(
            text_sim, behavior_sim,
            self.cfg.psro_diversity_weight_code,
            self.cfg.psro_diversity_weight_behavior
        )

    def check_solver_diversity(
        self,
        code: str,
        algorithm: str,
        metadata: Dict
    ) -> DecisionResult:
        """Decide whether a candidate solver should be added to the solver pool.

        Args:
            code: Source code of the candidate solver.
            algorithm: Not used by this check; kept for interface compatibility.
            metadata: Not used by this check; kept for interface compatibility.

        Returns:
            ``"ADD"`` when the pool is empty, or when the candidate is diverse
            enough and the pool has room. Otherwise ``"REJECT"``, carrying the
            highest combined similarity and, for a similarity rejection, the
            index of the solver that produced it.
        """
        if self.pools.n_solvers == 0:
            return DecisionResult("ADD", reason="First solver")

        # Score each existing solver with one combined value and keep the
        # argmax. The reported index and similarity then describe the same
        # solver, matching check_generator_diversity.
        max_sim = 0.0
        similar_idx = -1
        for idx, solver in enumerate(self.pools.solver_pool):
            code_sim = calculate_code_similarity(code, solver.code)
            # NOTE(review): this measures how similar existing solver ``idx``
            # is to the rest of the pool; the candidate's behaviour is not
            # part of it.
            behavior_sim = self._calculate_solver_behavior_similarity(idx)
            combined_sim = self._combine(code_sim, behavior_sim)
            if combined_sim > max_sim:
                max_sim = combined_sim
                similar_idx = idx

        if max_sim > self.cfg.psro_code_similarity_threshold:
            return DecisionResult("REJECT", similar_idx, max_sim, "Too similar to existing solver")

        max_size = self.cfg.psro_max_pool_size_solver
        if max_size is not None and self.pools.n_solvers >= max_size:
            return DecisionResult("REJECT", None, max_sim, "Pool full")

        return DecisionResult("ADD", reason="Diverse enough")

    def check_generator_diversity(
        self,
        config: Dict,
        algorithm: str,
        metadata: Dict
    ) -> DecisionResult:
        """Decide whether a candidate generator should be added to the generator pool.

        Args:
            config: Candidate generator configuration. It is stringified with
                ``config_to_string`` and compared textually against each
                existing generator's ``{"code": ..., **params}``.
            algorithm: Not used by this check; kept for interface compatibility.
            metadata: Not used by this check; kept for interface compatibility.

        Returns:
            ``"ADD"`` when the pool is empty, or when the candidate is diverse
            enough and the pool has room. Otherwise ``"REJECT"``, carrying the
            highest combined similarity and, for a similarity rejection, the
            index of the generator that produced it.
        """
        if self.pools.n_generators == 0:
            return DecisionResult("ADD", reason="First generator")

        config_str = config_to_string(config)
        max_sim = 0.0
        similar_idx = -1

        for idx, generator in enumerate(self.pools.generator_pool):
            # Rebuild an existing generator's config in the same shape so the
            # textual comparison is like-for-like.
            other_config = {"code": generator.code, **generator.params}
            config_sim = calculate_code_similarity(config_str, config_to_string(other_config))
            # NOTE(review): as for solvers, this behaviour term does not
            # involve the candidate.
            behavior_sim = self._calculate_generator_behavior_similarity(idx)
            combined_sim = self._combine(config_sim, behavior_sim)
            if combined_sim > max_sim:
                max_sim = combined_sim
                similar_idx = idx

        if max_sim > self.cfg.psro_behavioral_diversity_threshold:
            return DecisionResult("REJECT", similar_idx, max_sim, "Too similar to existing generator")

        max_size = self.cfg.psro_max_pool_size_generator
        if max_size is not None and self.pools.n_generators >= max_size:
            return DecisionResult("REJECT", None, max_sim, "Pool full")

        return DecisionResult("ADD", reason="Diverse enough")

    def _calculate_solver_behavior_similarity(self, idx: int) -> float:
        """Mean similarity between solver row ``idx`` of the utility matrix and every other row.

        Returns 0.0 when ``idx`` is beyond the matrix's row count, or when
        there is no other row to compare against.
        """
        utilities = self.meta.utilities
        n_rows = utilities.shape[0]
        if idx >= n_rows:
            return 0.0

        behavior_vec = utilities[idx, :]
        similarities = [
            calculate_behavioral_similarity_solver(behavior_vec, utilities[other_idx, :])
            for other_idx in range(n_rows)
            if other_idx != idx
        ]
        return float(np.mean(similarities)) if similarities else 0.0

    def _calculate_generator_behavior_similarity(self, idx: int) -> float:
        """Mean similarity between generator column ``idx`` of the utility matrix and every other column.

        Returns 0.0 when ``idx`` is beyond the matrix's column count, or when
        there is no other column to compare against.
        """
        utilities = self.meta.utilities
        n_cols = utilities.shape[1]
        if idx >= n_cols:
            return 0.0

        behavior_vec = utilities[:, idx]
        similarities = [
            calculate_behavioral_similarity_generator(behavior_vec, utilities[:, other_idx])
            for other_idx in range(n_cols)
            if other_idx != idx
        ]
        return float(np.mean(similarities)) if similarities else 0.0

