import os
import json
from datetime import datetime
from typing import List, Dict, Any
import numpy as np
from lead.core.multi_objective.individual import MultiObjectiveIndividual

class MultiObjectiveGenerationStorage:
    """Persist each generation of a multi-objective search and aggregate Pareto fronts.

    On construction a fresh directory
    ``<problem_path>/multi_objective_history/<YYYYmmdd_HHMMSS>`` is created.
    Every call to :meth:`save_generation` writes ``generation_<n>.json`` into it
    and appends a one-line summary to
    ``<problem_path>/multi_objective_evolution.log``.

    NOTE(review): the run directory name has one-second resolution, so two
    instances created within the same second share (and mix) one directory.
    """

    def __init__(self, problem_path: str):
        """Create this run's snapshot directory and load the objective names.

        Args:
            problem_path: Root directory of the problem. It must contain
                ``problem_config.json``.

        Raises:
            FileNotFoundError: If ``problem_config.json`` does not exist.
                The snapshot directory has already been created at that point.
            json.JSONDecodeError: If the config file is not valid JSON.
        """
        self.problem_path = problem_path
        self.storage_dir = os.path.join(
            problem_path,
            "multi_objective_history",
            datetime.now().strftime("%Y%m%d_%H%M%S"),
        )
        os.makedirs(self.storage_dir, exist_ok=True)
        self.log_file = os.path.join(problem_path, "multi_objective_evolution.log")
        self.objective_names = self._load_objective_names()

    def save_generation(self, generation: int, population: List["MultiObjectiveIndividual"]):
        """Write one generation's snapshot to disk and log a summary line.

        The snapshot ``generation_<generation>.json`` holds the serialized
        population, its non-dominated subset, and the objective names. An
        existing snapshot with the same generation number is overwritten.

        Args:
            generation: Generation index, used in the file name and the log.
            population: Individuals to persist. Each must provide ``to_dict()``
                and ``dominates(other)``.
        """
        pareto_front = self._calculate_pareto_front(population)
        data = {
            "generation": generation,
            "population": [ind.to_dict() for ind in population],
            "pareto_front": [ind.to_dict() for ind in pareto_front],
            "objective_names": self.objective_names
        }
        file_path = os.path.join(self.storage_dir, f"generation_{generation}.json")
        with open(file_path, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2, ensure_ascii=False)
        self._log_generation(generation, pareto_front)

    def get_pareto_front(self) -> List[Dict[str, Any]]:
        """Return the non-dominated individuals across all snapshots of this run.

        Per-generation fronts are read in ascending generation order. Entries
        sharing the same ``"code"`` are collapsed to their first (earliest
        generation) occurrence. The global front is then recomputed over the
        survivors with :meth:`_calculate_global_pareto_front`.

        Returns:
            Serialized individuals (dicts read back from the snapshot files),
            in first-seen order.
        """
        all_fronts: List[Dict[str, Any]] = []
        for file_name in self._generation_files():
            with open(os.path.join(self.storage_dir, file_name), "r", encoding="utf-8") as f:
                data = json.load(f)
            all_fronts.extend(data["pareto_front"])

        # The scan order above is deterministic, so "first occurrence" is
        # well-defined: the earliest generation that recorded this code.
        unique_front = []
        seen_codes = set()
        for ind in all_fronts:
            if ind["code"] not in seen_codes:
                unique_front.append(ind)
                seen_codes.add(ind["code"])

        return self._calculate_global_pareto_front(unique_front)

    def _generation_files(self) -> List[str]:
        """Return this run's ``generation_*.json`` file names in ascending generation order."""
        prefix, suffix = "generation_", ".json"

        def sort_key(name: str):
            stem = name[len(prefix):-len(suffix)]
            try:
                # Sort numerically so generation_10 comes after generation_2
                # (a lexicographic sort would place it first).
                return (0, int(stem), name)
            except ValueError:
                # Names with a non-numeric part are never written by
                # save_generation. Keep them, but order them last, by name.
                return (1, 0, name)

        names = [
            name for name in os.listdir(self.storage_dir)
            if name.startswith(prefix) and name.endswith(suffix)
        ]
        return sorted(names, key=sort_key)

    def _calculate_pareto_front(self, population: List["MultiObjectiveIndividual"]) -> List["MultiObjectiveIndividual"]:
        """Return the members of *population* that no other member dominates.

        Dominance is delegated to ``MultiObjectiveIndividual.dominates``. Input
        order is preserved, and individuals with identical objectives are all
        kept.
        """
        return [
            ind for ind in population
            if not any(other.dominates(ind) for other in population)
        ]

    def _calculate_global_pareto_front(self, individuals: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Return the non-dominated subset of serialized individuals.

        An entry ``a`` dominates ``b`` when, for every name in
        ``self.objective_names``, ``a["fitnesses"][name]`` is not greater than
        ``b``'s, and it is strictly smaller for at least one name. In other
        words, lower values are better.

        NOTE(review): this hard-codes minimization. Confirm it matches the
        convention of ``MultiObjectiveIndividual.dominates``, which builds the
        per-generation fronts.

        Returns:
            The surviving entries in input order. An empty input gives ``[]``.

        Raises:
            KeyError: If an entry lacks ``"fitnesses"`` or one of the objectives.
        """
        names = self.objective_names

        def dominates(a: Dict[str, Any], b: Dict[str, Any]) -> bool:
            fa, fb = a["fitnesses"], b["fitnesses"]
            # "not worse" is written as `not >` rather than `<=` on purpose:
            # it keeps the original handling of incomparable values (e.g. NaN).
            return (not any(fa[o] > fb[o] for o in names)
                    and any(fa[o] < fb[o] for o in names))

        # An entry never dominates itself (no strict improvement), so it is
        # safe to compare each entry against the whole list, itself included.
        return [
            ind for ind in individuals
            if not any(dominates(other, ind) for other in individuals)
        ]

    def _load_objective_names(self) -> List[str]:
        """Read ``objective_names`` from ``<problem_path>/problem_config.json``.

        Falls back to ``["f1", "f2"]`` when the key is absent. A missing or
        malformed config file raises.
        """
        config_path = os.path.join(self.problem_path, "problem_config.json")
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)
            return config.get("objective_names", ["f1", "f2"])

    def _log_generation(self, generation: int, pareto_front: List["MultiObjectiveIndividual"]):
        """Append a one-line summary of a saved generation to the shared log file."""
        log_entry = (
            f"Generation {generation:04d} | "
            f"Pareto Front Size: {len(pareto_front)} | "
            f"Timestamp: {datetime.now().isoformat()}\n"
        )
        with open(self.log_file, "a", encoding="utf-8") as f:
            f.write(log_entry)