"""Simplified unified configuration with shared parameters and timestamp-based result organization."""

import copy
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import List, Dict, Optional, Any


@dataclass
class ModelConfig:
    """Minimal model configuration for compatibility with existing code.

    ``name`` and ``provider`` are required. The defaults of ``temperature``,
    ``max_tokens`` and ``random_seed`` are the same values that
    ``EvaluationParams`` declares for its fields of the same names.
    """
    name: str  # Model identifier; EvaluationTask.__hash__ reads this attribute
    provider: str  # Provider label; no default and no inference happens in this class
    temperature: float = 0.7
    max_tokens: int = 2048
    random_seed: Optional[int] = None  # None when no seed has been specified


@dataclass
class EvaluationTask:
    """A single unit of evaluation work: one problem paired with one model config.

    Instances stay mutable (``attempt`` and ``metadata`` can be updated in
    place) but are still hashable through the explicit ``__hash__`` below.
    """
    task_id: str
    problem_index: int
    model_config: ModelConfig
    problem_data: Dict[str, Any]
    attempt: int = 0
    max_attempts: int = 3
    # default_factory gives every task its own dict instead of a shared one
    metadata: Dict[str, Any] = field(default_factory=dict)

    def __hash__(self):
        """Hash on task id, problem index and model name so tasks can live in sets.

        Only these three values contribute, so the hash does not change when
        ``attempt``, ``metadata`` or ``problem_data`` are modified.
        """
        identity = "{}_{}_{}".format(
            self.task_id,
            self.problem_index,
            self.model_config.name,
        )
        return hash(identity)


@dataclass
class EvaluationResult:
    """Result of a single evaluation.

    ``task_id``, ``problem_index``, ``model_name`` and ``success`` are
    required; every other field defaults to ``None``, an empty dict, or zero.
    """
    task_id: str
    problem_index: int
    model_name: str
    # NOTE(review): presumably "the run completed", as distinct from
    # is_correct ("the answer matched") — confirm against the code that fills this in.
    success: bool
    answer: Optional[str] = None
    ground_truth: Optional[str] = None
    is_correct: Optional[bool] = None
    error: Optional[str] = None
    # default_factory gives every result its own dict instead of a shared one
    metrics: Dict[str, Any] = field(default_factory=dict)
    # Captured when the instance is created: naive local time as an ISO 8601 string
    timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
    duration: float = 0.0  # NOTE(review): unit is not shown here (presumably seconds) — confirm
    attempt: int = 0


@dataclass
class EvaluationParams:
    """Complete evaluation parameters for a single setting.

    Only ``setting_id`` is required. After construction, ``__post_init__``
    fills in ``model_provider`` (when it was left as ``None``) and folds the
    legacy ``checkpoint_interval`` into ``chunk_size``.

    Raises:
        ValueError: if ``model_provider`` is ``None`` and no known provider
            keyword occurs in ``model_name``.
    """
    # Required unique identifier
    setting_id: str

    # Agent configuration
    agent_type: str = "react"  # react, deep_research, etc.
    agent_stop_type: str = "default"  # default, interaction_scaling, etc.
    agent_stop_kwargs: Dict[str, Any] = field(default_factory=dict)  # Additional stopping configuration

    # Model parameters
    model_name: str = "claude-3-sonnet-20240229"
    model_provider: Optional[str] = None  # Auto-inferred from model_name if None
    temperature: float = 0.7
    max_tokens: int = 2048
    random_seed: Optional[int] = None  # Random seed for reproducible inference

    # Dataset parameters
    dataset_type: str = "full"  # full, medium, small
    dataset_dir: str = "./data"
    num_samples: Optional[int] = None  # None means use full dataset
    start_index: Optional[int] = None  # Starting index for dataset slice
    end_index: Optional[int] = None    # Ending index for dataset slice

    # Search parameters
    search_engine_type: str = "chromadb"
    chromadb_base_path: str = "./databases/chroma_db"
    collection_name: Optional[str] = None  # The default is None, so the type must admit it
    embedding_model: str = "default"
    results_per_page: int = 5
    max_documents: int = 100

    # Evaluation parameters
    num_workers: int = 4
    retry_attempts: int = 3
    retry_delay: float = 1.0
    chunk_size: int = 10

    # Chunked checkpoint parameters
    enable_continual_evaluation: bool = True  # Enable timestamp and checkpoint-level continual eval

    # Deep research specific configuration
    deep_research_config: Dict[str, Any] = field(default_factory=dict)

    # Legacy parameter support (mapped to chunk_size)
    checkpoint_interval: Optional[int] = None

    # Additional metadata
    metadata: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        """Auto-infer model provider if not set and handle legacy parameters."""
        if self.model_provider is None:
            self.model_provider = self._infer_provider()

        # A legacy checkpoint_interval, when given, takes precedence over chunk_size
        if self.checkpoint_interval is not None:
            self.chunk_size = self.checkpoint_interval

    def _infer_provider(self) -> str:
        """Infer the provider from a keyword contained in the model name.

        Matching is case-insensitive and the first keyword that occurs wins,
        so the order of the table below is significant.

        Returns:
            The provider label associated with the first matching keyword.

        Raises:
            ValueError: if none of the keywords occurs in ``model_name``.
        """
        keyword_to_provider = (
            ('claude', 'anthropic'),
            ('gpt', 'openai'),
            ('gemini', 'google'),
            ('grok', 'grok'),
            ('llama', 'meta'),
            ('deepseek', 'deepseek'),
            ('qwen', 'damo'),
        )
        lowered = self.model_name.lower()
        for keyword, provider in keyword_to_provider:
            if keyword in lowered:
                return provider
        raise ValueError(f"Unknown model provider for {self.model_name}")

    @property
    def dataset_path(self) -> str:
        """Get dataset path based on type and samples.

        A falsy ``num_samples`` (``None`` or ``0``) selects the full database
        file; any other value selects the subsampled evaluation file.
        """
        if self.num_samples:
            # Subsampled dataset
            return f"{self.dataset_dir}/12_GSM_eval/final_eval_data_{self.dataset_type}_{self.num_samples}.json"
        # Full dataset
        return f"{self.dataset_dir}/11_GSM_final_database/final_database_{self.dataset_type}.json"

    @property
    def chromadb_path(self) -> str:
        """Get chromadb path based on dataset type."""
        return f"{self.chromadb_base_path}/{self.dataset_type}"

    def get_display_name(self) -> str:
        """Get human-readable name for this setting."""
        return f"{self.model_name} on {self.dataset_type}"


@dataclass
class UnifiedBatchConfig:
    """Simplified unified batch configuration with shared parameters and timestamp-based results.

    Results are laid out as ``base_output_dir/batch_run_name/setting_id/timestamp/``
    where ``timestamp`` has the form ``YYYYMMDD_HHMMSS``.
    """

    # Batch run name (used as folder name)
    batch_run_name: str
    base_output_dir: str = "results"
    num_workers: int = 4
    # List of evaluation parameters (fully resolved)
    evaluation_params: List[EvaluationParams] = field(default_factory=list)

    # Evaluation control parameters
    test_mode: bool = False
    test_samples: int = 3

    # System prompt
    system_prompt: str = """You are a problem-solving agent. Your mission is to solve math word problems by finding all necessary facts (premises) using search tools. Your reasoning process is the most important part.

1. How to Answer
Provisional Answer: During your reasoning steps, if you calculate a temporary answer based on incomplete information, start that line with ****.

**** 42

Final Answer: Your final, conclusive answer must begin with #### and contain only the numerical solution.

#### 42

If you're unsure: #### I don't know

2. Available Tools
You have two tools to find premises:

search(query: string)
Searches for premises using keywords. Returns up to 5 relevant premises. IMPORTANT: Do not include names (e.g., "Nancy") in your query. The search tool doesn't use them. Instead, check the Metadata in the search results to see who is speaking.

next_page()
Gets the next 5 results for your last search query. Use this if you suspect there's more similar information.

3. Core Strategy: The Research Loop
Think of your work as a loop: Analyze -> Search -> Attempt to Solve. Repeat until you have enough information.

Step 1: Analyze the Gaps
Quickly ask yourself:

- What is the main question?
- What facts do I have?
- What crucial information is missing?
- What may be the whole story?

Step 2: Plan and Execute Your Search
Decide your next action based on the quality of your last search results: If results are relevant and promising......but seem incomplete (e.g., you found a partial list), use next_page() to get more. If you find some information and could do a better query search, then start a new search with search(query). Use keywords related to your missing information. Think associatively (e.g., if you see "top shelf," search for "bottom shelf").

Key Idea: Carefully balance "next_page" and "search(query)".

Step 3: Attempt to Solve
Using only the premises you've gathered, try to calculate the answer. Show your work. State your provisional answer on a new line, starting with ****. If you can solve it confidently, proceed to the final answer format. If you're stuck, identify what's still missing and go back to Step 1.

4. Example Walkthrough
Problem: "How many books did Nancy have on the book cart when she started?"

Iteration 1
- Analyze: Need the total number of books on the cart.
- Search: search("book cart")
- Results: Got 5 premises listing book types (history, romance, western, biographies) and a new clue: "Half the books on the bottom section were mystery books."
- Attempt: I can add the books I know: 12 + 8 + 5 + 6 = 31. But the "half mystery books" clue is a dangling premise. I can't solve the full problem yet.
**** 31

Next Action: The results were relevant, so I'll check for more with next_page().

Iteration 2
- Search: next_page() -> Results: "We have iterated through all the pages."
Analyze: That query is done. The clue mentioned a "bottom section," which implies a "top section." I need to investigate that.
- Search: search("top section")
- Results: Got a new premise: "4 poetry books from the top section."
- Attempt to Solve: Now I can piece everything together.
- Calculation:
Top section: 12 history + 8 romance + 4 poetry = 24 books.
Bottom section: The known books are 5 western + 6 biographies = 11 books. This must be one half, so the other half is 11 mystery books. Total = 11 + 11 = 22 books.
Total Cart: 24 (top) + 22 (bottom) = 46 books.
**** 46

Conclusion: All premises are used and the logic holds. I'm ready to answer.

5. Concluding and Final Answer
When you are confident, format your final response like this. Remember to use #### only for the final answer, not for intermediate calculations.

My calculation:
Top section: 12 history + 8 romance + 4 poetry = 24 books
Bottom section: 5 western + 6 biographies + 11 mystery = 22 books
Total: 24 + 22 = 46 books

Confidence: High
#### 46"""

    @classmethod
    def from_yaml(cls, yaml_data: dict) -> 'UnifiedBatchConfig':
        """Create config from YAML data.

        Each entry of ``evaluation_settings`` is resolved by starting from
        ``shared_parameters`` and applying the entry's own keys on top.
        ``metadata`` is merged the same way (setting-level keys win), and a
        ``setting_id`` placed in ``shared_parameters`` is ignored because the
        identifier must come from the setting itself.

        Args:
            yaml_data: Dictionary from parsed YAML. ``None`` (an empty
                document) and sections whose value is ``None`` (a key
                written with no value) are treated as empty.

        Returns:
            UnifiedBatchConfig instance

        Raises:
            ValueError: if an evaluation setting has no ``setting_id``.
        """
        # Normalise missing/empty sections up front so the code below can
        # assume real containers.
        yaml_data = yaml_data or {}
        shared_params = yaml_data.get('shared_parameters') or {}
        eval_control = yaml_data.get('evaluation_control') or {}
        prompts = yaml_data.get('prompts') or {}
        settings_data = yaml_data.get('evaluation_settings') or []

        # These two keys are handled explicitly and never passed through as
        # ordinary parameters.
        reserved_keys = ('setting_id', 'metadata')
        shared_metadata = shared_params.get('metadata') or {}

        evaluation_params = []
        for setting_dict in settings_data:
            setting_id = setting_dict.get('setting_id', '')
            if not setting_id:
                raise ValueError("Each evaluation setting must have a 'setting_id'")

            # Shared parameters first, then the setting's own overrides
            params = {
                key: value
                for key, value in shared_params.items()
                if key not in reserved_keys
            }
            params.update(
                (key, value)
                for key, value in setting_dict.items()
                if key not in reserved_keys
            )
            params['metadata'] = {
                **shared_metadata,
                **(setting_dict.get('metadata') or {}),
            }

            # Deep-copy so mutable values coming from shared_parameters
            # (dicts, lists) are not aliased between settings or with
            # yaml_data itself.
            evaluation_params.append(
                EvaluationParams(setting_id=setting_id, **copy.deepcopy(params))
            )

        return cls(
            batch_run_name=yaml_data.get('batch_run_name', 'evaluation'),
            base_output_dir=yaml_data.get('base_output_dir', 'results'),
            evaluation_params=evaluation_params,
            test_mode=eval_control.get('test_mode', False),
            test_samples=eval_control.get('test_samples', 3),
            system_prompt=prompts.get('system_prompt', cls.system_prompt),
            num_workers=yaml_data.get('num_workers', 4),
        )

    def get_output_dir(self) -> Path:
        """Get the output directory for this batch run.

        Returns:
            Path to output directory: base_output_dir/batch_run_name/
            (a leading ``~`` in base_output_dir is expanded)
        """
        return Path(self.base_output_dir).expanduser() / self.batch_run_name

    def get_setting_output_dir(self, setting_id: str, timestamp: Optional[str] = None) -> Path:
        """Get output directory for a specific setting.

        Args:
            setting_id: ID of the evaluation setting
            timestamp: Optional timestamp string (defaults to current time,
                formatted as YYYYMMDD_HHMMSS)

        Returns:
            Path to setting output: base_output_dir/batch_run_name/setting_id/timestamp/
        """
        if timestamp is None:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

        return self.get_output_dir() / setting_id / timestamp

    def get_existing_timestamps(self, setting_id: str) -> List[str]:
        """Get list of existing timestamp directories for a setting.

        Args:
            setting_id: ID of the evaluation setting

        Returns:
            List of timestamp strings (sorted, newest first); empty when the
            setting directory does not exist
        """
        setting_dir = self.get_output_dir() / setting_id
        if not setting_dir.exists():
            return []

        # Reverse lexicographic order puts the largest (newest) names first
        return sorted(
            (
                entry.name
                for entry in setting_dir.iterdir()
                if entry.is_dir() and self._is_timestamp_format(entry.name)
            ),
            reverse=True,
        )

    def _is_timestamp_format(self, name: str) -> bool:
        """Check if a directory name matches timestamp format.

        Args:
            name: Directory name to check

        Returns:
            True if ``name`` parses with the YYYYMMDD_HHMMSS format
        """
        try:
            datetime.strptime(name, "%Y%m%d_%H%M%S")
        except ValueError:
            return False
        return True

    def group_params_by_model(self) -> Dict[str, List[EvaluationParams]]:
        """Group evaluation parameters by model name.

        Returns:
            Dictionary mapping model names to lists of evaluation parameters,
            each list in the order the settings appear in evaluation_params
        """
        groups: Dict[str, List[EvaluationParams]] = {}
        for params in self.evaluation_params:
            groups.setdefault(params.model_name, []).append(params)
        return groups

    def get_params_by_id(self, setting_id: str) -> Optional[EvaluationParams]:
        """Get evaluation parameters by setting ID.

        Args:
            setting_id: ID of the evaluation setting

        Returns:
            The first EvaluationParams whose setting_id matches, or None if
            not found
        """
        return next(
            (params for params in self.evaluation_params if params.setting_id == setting_id),
            None,
        )

    @property
    def experiment_name(self) -> str:
        """Get experiment name for compatibility."""
        return self.batch_run_name

    @property
    def evaluation(self):
        """Get evaluation config for compatibility.

        Returns a small object whose only attribute is ``output_dir``, the
        batch output directory as a string.
        """
        class EvaluationConfig:
            def __init__(self, output_dir: str):
                self.output_dir = output_dir
        return EvaluationConfig(str(self.get_output_dir()))

    @property
    def prompts(self):
        """Get prompts config for compatibility.

        Returns a small object whose only attribute is ``system_prompt``.
        """
        class PromptsConfig:
            def __init__(self, system_prompt: str):
                self.system_prompt = system_prompt
        return PromptsConfig(self.system_prompt)

    def to_dict(self) -> dict:
        """Convert config to dictionary for serialization.

        The per-setting dictionaries hold the resolved values.
        ``checkpoint_interval`` is not emitted because EvaluationParams has
        already folded it into ``chunk_size``. The result references the
        parameter objects' own dicts rather than copies.
        """
        param_keys = (
            'setting_id',
            'agent_type',
            'agent_stop_type',
            'agent_stop_kwargs',
            'model_name',
            'model_provider',
            'temperature',
            'max_tokens',
            'random_seed',
            'dataset_type',
            'dataset_dir',
            'num_samples',
            'start_index',
            'end_index',
            'search_engine_type',
            'chromadb_base_path',
            'collection_name',
            'embedding_model',
            'results_per_page',
            'max_documents',
            'num_workers',
            'retry_attempts',
            'retry_delay',
            'chunk_size',
            'enable_continual_evaluation',
            'deep_research_config',
            'metadata',
        )
        return {
            'batch_run_name': self.batch_run_name,
            'base_output_dir': self.base_output_dir,
            # from_yaml reads this top-level key, so it is written out too
            'num_workers': self.num_workers,
            'evaluation_control': {
                'test_mode': self.test_mode,
                'test_samples': self.test_samples,
            },
            'prompts': {
                'system_prompt': self.system_prompt,
            },
            'evaluation_params': [
                {key: getattr(params, key) for key in param_keys}
                for params in self.evaluation_params
            ],
        }
