import logging
import copy
from dataclasses import dataclass
from typing import Any, Dict, Optional, List, Tuple

from pettingllms.multi_agent_env.base.env import MultiAgentsEnvironment
from pettingllms.multi_agent_env.plan_path.utils import (
    load_plan_path_problem_batch,
)
from pettingllms.multi_agent_env.plan_path.env_state import (
    PlanPathGridEnvState, 
    EightQueensEnvState, 
    BlocksworldEnvState, 
    Sudoku4x4EnvState,
    get_state_class_by_benchmark
)

logger = logging.getLogger(__name__)




# -------------------------
# State 与 Env 定义
# -------------------------

class PlanPathEnv(MultiAgentsEnvironment):
    """
    路径规划（plan-path）单智能体环境：
    - agent 产出方案（例如坐标序列/动作序列/可解析文本）
    - 环境用 state.worker 做可行性与最优性评估
    """

    def __init__(
        self,
        env_idx: int,
        rollout_idx: int,
        max_turns: int,
        config: dict | None = None,
    ):
        super().__init__(env_idx=env_idx, rollout_idx=rollout_idx, max_turns=max_turns, config=config)
        self.state = None  # 将在 PlanPathEnvBatch 中正确初始化
    def reset(self):
        """重置回合内的中间产物与评估标记（不改动题面与worker）。"""
        if self.state is not None:
            # 根据 EnvStateBase 的定义重置基础属性
            self.state.tool_action = []  # List[str] 类型
            self.state.tool_code = ""    # str 类型
            self.state.tool_execution_output = ""  # str 类型
            self.state.plan_action = []  # List[str] 类型
            
            # 重置其他可能的状态属性（如果存在的话）
            if hasattr(self.state, 'reasoning_generated_plan'):
                self.state.reasoning_generated_plan = None
            if hasattr(self.state, 'code_generated_plan'):
                self.state.code_generated_plan = None
            if hasattr(self.state, 'reasoning_extracted_path'):
                self.state.reasoning_extracted_path = None
            if hasattr(self.state, 'code_extracted_path'):
                self.state.code_extracted_path = None
            if hasattr(self.state, 'reasoning_is_feasible'):
                self.state.reasoning_is_feasible = None
            if hasattr(self.state, 'code_is_feasible'):
                self.state.code_is_feasible = None
            if hasattr(self.state, 'reasoning_is_optimal'):
                self.state.reasoning_is_optimal = None
            if hasattr(self.state, 'code_is_optimal'):
                self.state.code_is_optimal = None
            if hasattr(self.state, 'code_reasoning_aligned'):
                self.state.code_reasoning_aligned = None



# -------------------------
# Batch 构建器
# -------------------------
class PlanPathEnvBatch:
    """
    与 MathTestEnvBatch 对齐的批量构建器：
      - 从 load_plan_path_problem_batch 取题
      - 为每个样本建立初始 State（包含 worker 与 worker_text）
      - 复制 State 给多个 sample（samples>1 时）
    """

    def __init__(
        self,
        env_idx_list: List[int],
        env_indices: List[int],
        rollout_idx_list: List[int],
        samples: int,
        max_turns: int,
        config: dict,
        mode: str = "train",
        *,
        env_workers: None = None,  # 可选：外部传入 worker
    ):
        
        
        self.env_list: List[PlanPathEnv] = []

        if mode == "validate":
            env_idx_list = range(100)
            rollout_idx_list = range(100)
            samples = 1

        benchmark_name=getattr(config,"benchmark") if hasattr(config,"benchmark") else "PlanPath"
        
        
        
        for i, prob in enumerate(env_idx_list):

            state_class = get_state_class_by_benchmark(benchmark_name)
            
            # 根据不同的benchmark创建相应的状态实例
            if benchmark_name == "PlanPath":
                # 使用环境索引作为seed，确保相同索引生成相同环境
                seed = env_indices[i] if i < len(env_indices) else i
                state = state_class(seed=seed, config=config)
            elif benchmark_name == "EightQueens":
                state = state_class(N=prob["N"])
            elif benchmark_name == "Blocksworld":
                state = state_class(
                    init_stacks=prob["init_stacks"],
                    goal_stacks=prob["goal_stacks"]
                )
            elif benchmark_name == "sudoku4x4":
                # 使用环境索引作为seed，确保相同索引生成相同环境
                seed = env_indices[i] if i < len(env_indices) else i
                state = state_class(seed=seed, config=config)
            else:
                raise ValueError(f"Unsupported benchmark: {benchmark_name}")
            # 复制为多个 env 实例
            for s in range(samples):
                env = PlanPathEnv(env_idx=i, rollout_idx=rollout_idx_list[i * samples + s], max_turns=max_turns, config=None)
                # 注意：deepcopy 会复制 worker；如果你希望多个样本共享同一个 worker 实例，可改为浅拷贝并手动赋值
                env.state = copy.deepcopy(state)
                self.env_list.append(env)

        if len(self.env_list) != len(rollout_idx_list):
            raise ValueError(
                f"len(self.env_list)!=len(rollout_idx_list), {len(self.env_list)}!={len(rollout_idx_list)}"
            )
