import os
import time
import json
from pathlib import Path
from dataclasses import dataclass
from functools import cached_property
from typing import Dict
from copy import deepcopy

import numpy as np
import torch
from dataclasses_json import dataclass_json
from gymnasium.spaces import Box, Discrete, Dict as GymDict
from pettingzoo import ParallelEnv
from torch.utils.tensorboard.writer import SummaryWriter
from tqdm import tqdm

from tame.agents.simple_ppo import Agent as SimplePPO
from tame.hierarchy.base_agent import BaseAgent
from tame.hierarchy.hierarchy import Hierarchy, LevelConfig, AgentConfig
from tame.utils.config import ArgsInterface
from tame.utils.utils import filter_unexpected_fields


@filter_unexpected_fields
@dataclass_json
@dataclass
class Args(ArgsInterface):
    """Hyperparameters for the three-level SimplePPO3 hierarchy.

    The shared PPO settings apply to every consumer. The ``bottom_*``,
    ``middle_*``, ``top_*``, ``phi_*`` and ``psi_*`` fields override them for
    one consumer when set; ``None`` means "use the shared value".
    """

    # === GENERAL EXPERIMENT PARAMETERS ===
    # Default run name: this file's name without the ``.py`` extension.
    # (``str.rstrip(".py")`` would strip any trailing '.', 'p' or 'y'
    # characters, e.g. "happy.py" -> "ha", so ``Path.stem`` is used instead.)
    exp_name: str = Path(__file__).stem
    seed: int | None = 1
    torch_deterministic: bool = True
    verbose: bool = False
    # CUDA device index; a negative value keeps everything on the CPU.
    cuda: int = 0
    save_model: bool = True
    save_all_trace: bool = False
    total_timesteps: int = 500000

    # === HIERARCHY PARAMETERS ===
    # Action frequency of each level.
    freq_bottom: int = 1
    freq_mid: int = 1
    freq_top: int = 1
    learn_comm: bool = True
    learn_proxy: bool = False
    # Size of the communication vector used when ``learn_comm`` is enabled.
    comm_size: int = 16
    ae_epochs: int = 100

    # === SHARED PPO PARAMETERS (defaults for every level) ===
    learning_rate: float = 0.001
    gamma: float = 0.99
    anneal_lr: bool = True
    gae_lambda: float = 0.95
    batch_size: int = 2048
    num_minibatches: int = 8
    update_epochs: int = 4
    norm_adv: bool = True
    clip_coef: float = 0.1
    clip_vloss: bool = True
    ent_coef: float = 0.01
    vf_coef: float = 0.5
    max_grad_norm: float = 0.5
    target_kl: float | None = 0.015

    # === PER-LEVEL OVERRIDES ===
    # Bottom level overrides (None -> use the shared value)
    bottom_learning_rate: float | None = None
    bottom_batch_size: int | None = None
    bottom_clip_coef: float | None = None
    bottom_ent_coef: float | None = None

    # Middle level overrides
    middle_learning_rate: float | None = None
    middle_batch_size: int | None = None
    middle_clip_coef: float | None = None
    middle_ent_coef: float | None = None

    # Top level overrides
    top_learning_rate: float | None = None
    top_batch_size: int | None = None
    top_clip_coef: float | None = None
    top_ent_coef: float | None = None

    # === PHI/PSI PARAMETERS ===
    # Phi function overrides
    phi_learning_rate: float | None = None
    phi_batch_size: int | None = None
    phi_clip_coef: float | None = None

    # Psi function overrides
    psi_learning_rate: float | None = None
    psi_batch_size: int | None = None
    psi_clip_coef: float | None = None

    @cached_property
    def minibatch_size(self) -> int:
        """``batch_size // num_minibatches``, computed once on first access."""
        return self.batch_size // self.num_minibatches

    def _with_overrides(self, overrides: dict) -> 'Args':
        """Return a deep copy of ``self`` with every non-None override applied.

        Args:
            overrides: Mapping of attribute name -> value; ``None`` values are
                skipped so the copied shared value is kept.

        Returns:
            A new ``Args``; ``self`` is left untouched.
        """
        args = deepcopy(self)
        for param, value in overrides.items():
            if value is not None:
                setattr(args, param, value)
        # ``cached_property`` stores its value in the instance ``__dict__``,
        # which ``deepcopy`` copies. Drop it so the copy recomputes
        # ``minibatch_size`` from its own (possibly overridden) batch_size.
        vars(args).pop("minibatch_size", None)
        return args

    def get_args_for_level(self, level: str) -> 'Args':
        """Create ``Args`` for one hierarchy level with its overrides applied.

        Args:
            level: ``"bottom"``, ``"middle"`` or ``"top"``. Any other value
                yields a plain copy with no overrides.

        Returns:
            A new ``Args``. For ``"top"`` communication and proxy learning are
            always disabled, because the top level does not use them.
        """
        if level not in ("bottom", "middle", "top"):
            return self._with_overrides({})

        overrides = {
            param: getattr(self, f"{level}_{param}")
            for param in ("learning_rate", "batch_size", "clip_coef",
                          "ent_coef")
        }
        if level == "top":
            # The top level uses neither comm nor proxy learning.
            overrides.update(learn_comm=False, learn_proxy=False)
        return self._with_overrides(overrides)

    def get_phi_args(self) -> 'Args':
        """Create ``Args`` for the phi functions (``phi_*`` overrides applied)."""
        return self._with_overrides({
            "learning_rate": self.phi_learning_rate,
            "batch_size": self.phi_batch_size,
            "clip_coef": self.phi_clip_coef,
        })

    def get_psi_args(self) -> 'Args':
        """Create ``Args`` for the psi functions (``psi_*`` overrides applied)."""
        return self._with_overrides({
            "learning_rate": self.psi_learning_rate,
            "batch_size": self.psi_batch_size,
            "clip_coef": self.psi_clip_coef,
        })


class SimplePPO3Hierarchy(Hierarchy):
    """
    Three-level (bottom / middle / top) hierarchy built from SimplePPO agents.

    Subclasses the generic ``Hierarchy`` and configures its levels
    automatically:

    * bottom: one agent per environment agent, acting on the environment;
    * middle: two agents, each supervising one half of the bottom agents;
    * top: one agent supervising both middle agents.

    Each level receives its own ``Args`` via ``Args.get_args_for_level`` so
    the ``bottom_*`` / ``middle_*`` / ``top_*`` overrides take effect.
    """

    def __init__(self, env: ParallelEnv, args: Args):
        super().__init__()
        self.env = env
        self.args = args

        # Use the requested CUDA device when one is available, else the CPU.
        self.device = torch.device("cpu")
        if args.cuda >= 0:
            self.device = torch.device(
                f"cuda:{args.cuda}" if torch.cuda.is_available() else "cpu")

        self._build_hierarchy()

        if args.verbose:
            print("Instantiated SimplePPO3 hierarchy:")
            self.print_tree()

    # ------------------------------------------------------------------ #
    # Small construction helpers shared by the level builders
    # ------------------------------------------------------------------ #
    def _trace_type(self) -> str:
        """Trace mode used by every level: full traces only if requested."""
        return "full" if self.args.save_all_trace else "reward"

    def _make_comm_space(self) -> Box | None:
        """Fresh communication space when comm is learned, otherwise None."""
        if not self.args.learn_comm:
            return None
        return Box(-np.inf, np.inf, shape=[self.args.comm_size])

    @staticmethod
    def _make_directive_action_space(links: dict) -> GymDict:
        """Nested Dict space with one ``Discrete(5)`` entry per subordinate.

        Args:
            links: Mapping supervisor name -> list of subordinate names.
        """
        return GymDict({
            supervisor: GymDict({name: Discrete(5) for name in subordinates})
            for supervisor, subordinates in links.items()
        })

    # ------------------------------------------------------------------ #
    # Hierarchy construction
    # ------------------------------------------------------------------ #
    def _build_hierarchy(self):
        """Define agent names and links, then build the levels bottom-up."""
        num_env_agents = len(self.env.possible_agents)
        half = num_env_agents // 2

        # NOTE(review): these generated names are passed straight to
        # env.observation_space()/action_space(), so the environment must
        # accept "agent_0" .. "agent_{n-1}" — confirm for new environments.
        bottom_agent_names = [f"agent_{i}" for i in range(num_env_agents)]
        middle_names = ["middle_ppo_1", "middle_ppo_2"]
        top_agent_name = "top_ppo"

        # Each bottom agent drives the environment agent of the same name.
        bottom_env_links = {name: [name] for name in bottom_agent_names}
        # The first half of the bottom agents report to middle_ppo_1 and the
        # remainder to middle_ppo_2.
        mid_bottom_links = {
            middle_names[0]: bottom_agent_names[:half],
            middle_names[1]: bottom_agent_names[half:],
        }
        top_mid_links = {top_agent_name: middle_names}

        self._build_bottom_level(bottom_agent_names, bottom_env_links,
                                 mid_bottom_links)
        self._build_middle_level(middle_names, mid_bottom_links, top_mid_links)
        self._build_top_level(top_agent_name, top_mid_links, middle_names)

    def _build_bottom_level(self, bottom_agent_names, bottom_env_links,
                            mid_bottom_links):
        """Build the bottom level: one SimplePPO agent per environment agent."""
        bottom_args = self.args.get_args_for_level("bottom")

        bottom_agents = [
            AgentConfig(
                name=name,
                observation_space=self.env.observation_space(name),
                action_space=self.env.action_space(name),
                reward_len=1,
                directives_space=Discrete(5),
                communication_space=self._make_comm_space(),
                agent_class=SimplePPO,
                agent_kwargs={"args": bottom_args},
                device=self.device,
            )
            for name in bottom_agent_names
        ]

        self.add_level_config(
            LevelConfig(
                name="bottom",
                agents=bottom_agents,
                uplinks=mid_bottom_links,
                downlinks=bottom_env_links,
                action_frequency=self.args.freq_bottom,
                trace_type=self._trace_type(),
                concat_obs=True,
                action_space=self._make_directive_action_space(
                    mid_bottom_links),
                env=self.env,
            )
        )

    def _build_middle_level(self, middle_names, mid_bottom_links,
                            top_mid_links):
        """Build the middle level: one agent per group of bottom agents."""
        middle_args = self.args.get_args_for_level("middle")
        middle_agents = []

        for agent_name in middle_names:
            subordinates = mid_bottom_links[agent_name]

            # With learned communication each subordinate contributes a
            # comm_size vector; otherwise its raw env observation size.
            if self.args.learn_comm:
                obs_size = self.args.comm_size * len(subordinates)
            else:
                obs_size = sum(self.env.observation_space(sub).shape[0]
                               for sub in subordinates)

            middle_agents.append(
                AgentConfig(
                    name=agent_name,
                    observation_space=Box(-np.inf, np.inf, shape=[obs_size]),
                    action_space=Discrete(5),
                    reward_len=len(subordinates),
                    directives_space=Discrete(5),
                    communication_space=self._make_comm_space(),
                    agent_class=SimplePPO,
                    agent_kwargs={"args": middle_args},
                    device=self.device,
                )
            )

        self.add_level_config(
            LevelConfig(
                name="middle",
                agents=middle_agents,
                uplinks=top_mid_links,
                downlinks=mid_bottom_links,
                action_frequency=self.args.freq_mid,
                trace_type=self._trace_type(),
                concat_obs=True,
                action_space=self._make_directive_action_space(top_mid_links),
                env=self.levels[-1],
            )
        )

    def _build_top_level(self, top_agent_name, top_mid_links, middle_names):
        """Build the top level: a single agent supervising the middle agents."""
        if self.args.learn_comm:
            obs_size = self.args.comm_size * len(middle_names)
        else:
            # Sum the observation sizes of the middle agents registered on
            # the previous (middle) level.
            obs_size = 0
            for middle_name in middle_names:
                for agent_config in self.levels[-1].agents.values():
                    if (hasattr(agent_config, 'name')
                            and agent_config.name == middle_name):
                        obs_size += agent_config.observation_space.shape[0]
                        break

        # Top-level args: top_* overrides applied, comm/proxy disabled.
        top_args = self.args.get_args_for_level("top")

        top_agent = AgentConfig(
            name=top_agent_name,
            observation_space=Box(-np.inf, np.inf, shape=[obs_size]),
            action_space=Discrete(5),
            reward_len=len(middle_names),
            directives_space=None,
            communication_space=None,
            agent_class=SimplePPO,
            agent_kwargs={"args": top_args},
            device=self.device
        )

        self.add_level_config(
            LevelConfig(
                name="top",
                agents=[top_agent],
                uplinks=None,
                downlinks=top_mid_links,
                action_frequency=self.args.freq_top,
                trace_type=self._trace_type(),
                env=self.levels[-1],
            )
        )

    # ------------------------------------------------------------------ #
    # Training
    # ------------------------------------------------------------------ #
    def _init_logging(self, log_path: Path | str | None,
                      run_name: str | None) -> SummaryWriter:
        """Create the run directory, dump configs and attach a TensorBoard writer.

        Writes ``params.json`` and ``interface_level.json`` into
        ``<log_path>/<run_name>`` (defaults: ``runs`` and
        ``<exp_name>__<seed>__<unix time>``). Registers the writer via
        ``set_logger`` and returns it; the caller is responsible for closing it.
        """
        if run_name is None:
            run_name = f"{self.args.exp_name}__{self.args.seed}__{int(time.time())}"
        base_dir = Path("runs") if log_path is None else Path(log_path)
        run_dir = base_dir / run_name
        # exist_ok avoids the check-then-create race of exists()+makedirs().
        run_dir.mkdir(parents=True, exist_ok=True)

        with open(run_dir / "params.json", "w") as f:
            json.dump(self.args.to_dict(), f, indent=4)

        with open(run_dir / "interface_level.json", "w") as f:
            json.dump({"name": self.interface_level.name}, f, indent=4)

        save_path = run_dir / "training"
        writer = SummaryWriter(save_path / "tboard")
        rows = "\n".join(
            f"|{key}|{value}|" for key, value in vars(self.args).items()
        )
        writer.add_text("hyperparameters", f"|param|value|\n|-|-|\n{rows}")

        self.set_logger(logger=writer, save_path=save_path)
        return writer

    def _run_training_loop(self):
        """Step the hierarchy until the bottom level reaches total_timesteps.

        The hierarchy is reset whenever any environment agent terminates or
        is truncated.
        """
        total = self.args.total_timesteps
        done = False
        self.reset()

        with tqdm(total=total, desc="Training step:") as pbar:
            while self.levels[0].level_ts < total:
                if done:
                    done = False
                    self.reset()

                _, _, terminated, truncated, _ = self.step(action=None)
                done = any(terminated.values()) or any(truncated.values())

                # Advance the bar to the bottom level's current timestep.
                pbar.update(self.levels[0].level_ts - pbar.n)

    def train(
            self,
            env: ParallelEnv | None = None,
            log_path: Path | str | None = None,
            run_name: str | None = None,
    ):
        """Train the hierarchy, logging to TensorBoard.

        Args:
            env: Environment to connect to; defaults to the one given at
                construction.
            log_path: Root directory for run outputs (default ``runs``).
            run_name: Sub-directory name for this run (default derived from
                ``exp_name``, ``seed`` and the current time).
        """
        if env is None:
            env = self.env
        self.connect(env)

        writer = self._init_logging(log_path, run_name)
        try:
            self._run_training_loop()
            # Final cleanup
            self.reset()
        finally:
            # Always flush/close the event files, even if training fails.
            writer.close()


class Agent(BaseAgent):
    """
    SimplePPO3 agent: a thin wrapper around ``SimplePPO3Hierarchy``.

    The hierarchy builds the three-level (bottom / middle / top) structure
    from ``env`` and ``args``. This class seeds the RNGs and forwards
    training, persistence and acting to the hierarchy.
    """

    def __init__(self, env: ParallelEnv, args: None | Args = None) -> None:
        self.args: Args = Args() if args is None else args
        self.env = env
        # Seed before building the hierarchy so that anything random drawn
        # during construction is covered by the seed.
        self.seed(self.args.seed)

        self.hierarchy = SimplePPO3Hierarchy(env, self.args)

    def seed(self, seed):
        """Seed the Python, NumPy and PyTorch RNGs for reproducibility.

        Args:
            seed: Integer seed, or ``None`` to seed every generator from a
                non-deterministic source. ``torch.manual_seed`` rejects
                ``None``, so ``torch.seed()`` is used in that case.
        """
        import random
        random.seed(seed)
        np.random.seed(seed)
        if seed is None:
            torch.seed()
        else:
            torch.manual_seed(seed)
        torch.backends.cudnn.deterministic = self.args.torch_deterministic

    def train(
            self,
            env: ParallelEnv | None = None,
            log_path: Path | str | None = None,
            run_name: str | None = None,
    ):
        """Train the hierarchical agent (delegates to the hierarchy)."""
        self.hierarchy.train(env, log_path, run_name)

    def save_agent(self, save_path: str | Path, name: None | str = None):
        """Save the agent's hierarchy under ``save_path``.

        ``name`` is accepted for interface compatibility but is not used.
        """
        self.hierarchy.save(Path(save_path))

    def load_agent(self, load_path: Path | str,
                   name: str = "trained_model") -> bool:
        """Load a trained hierarchy from ``load_path``.

        ``name`` is accepted for interface compatibility but is not used.

        Returns:
            True on success.

        Raises:
            RuntimeError: if the hierarchy reports that loading failed.
        """
        loaded = self.hierarchy.load(Path(load_path))
        if not loaded:
            raise RuntimeError("Could not load agents!")
        print("Agents loaded successfully")
        return loaded

    def act(self, observation: Dict[str, np.ndarray]) -> dict:
        """Return the action chosen by the hierarchy for ``observation``."""
        return self.hierarchy.act(observation)