import os
import time
import json
from pathlib import Path
from dataclasses import dataclass
from functools import cached_property
from typing import Dict
from copy import deepcopy

import numpy as np
import torch
from dataclasses_json import dataclass_json
from gymnasium.spaces import Box, Discrete, Dict as GymDict
from pettingzoo import ParallelEnv
from torch.utils.tensorboard.writer import SummaryWriter
from tqdm import tqdm

from tame.agents.simple_ppo import Agent as SimplePPO
from tame.hierarchy.base_agent import BaseAgent
from tame.hierarchy.hierarchy import Hierarchy, LevelConfig, AgentConfig
from tame.utils.config import ArgsInterface
from tame.utils.utils import filter_unexpected_fields


@filter_unexpected_fields
@dataclass_json
@dataclass
class Args(ArgsInterface):
    """Experiment configuration for the Simple PPO3 Strategic hierarchy.

    Holds general experiment settings, hierarchy settings, PPO defaults that
    apply to every level, and optional per-level / phi / psi overrides.
    An override left at ``None`` means "use the shared default".
    """

    # === GENERAL EXPERIMENT PARAMETERS ===
    # File name without its extension. ``splitext`` is used instead of
    # ``str.rstrip(".py")`` because rstrip removes a *set of characters*
    # and would also eat trailing 'p', 'y' or '.' characters of the stem.
    exp_name: str = os.path.splitext(os.path.basename(__file__))[0]
    seed: int | None = 1
    torch_deterministic: bool = True
    verbose: bool = False
    cuda: int = 0
    save_model: bool = True
    save_all_trace: bool = False
    total_timesteps: int = 500000

    # === HIERARCHY PARAMETERS ===
    freq_bottom: int = 1
    freq_mid: int = 3
    freq_top: int = 3
    learn_comm: bool = True
    learn_proxy: bool = False
    comm_size: int = 16
    ae_epochs: int = 100

    # === SHARED PPO PARAMETERS (defaults for every level) ===
    learning_rate: float = 0.001
    gamma: float = 0.99
    anneal_lr: bool = True
    gae_lambda: float = 0.95
    batch_size: int = 2048
    num_minibatches: int = 8
    update_epochs: int = 4
    norm_adv: bool = True
    clip_coef: float = 0.1
    clip_vloss: bool = True
    ent_coef: float = 0.01
    vf_coef: float = 0.5
    max_grad_norm: float = 0.5
    target_kl: float | None = 0.015

    # === LEVEL-SPECIFIC PARAMETERS ===
    # Bottom level overrides (None -> use the shared value)
    bottom_learning_rate: float | None = None
    bottom_batch_size: int | None = None
    bottom_clip_coef: float | None = None
    bottom_ent_coef: float | None = None

    # Middle level overrides
    middle_learning_rate: float | None = None
    middle_batch_size: int | None = None
    middle_clip_coef: float | None = None
    middle_ent_coef: float | None = None

    # Top level overrides
    top_learning_rate: float | None = None
    top_batch_size: int | None = None
    top_clip_coef: float | None = None
    top_ent_coef: float | None = None

    # === PHI/PSI PARAMETERS ===
    # Phi function overrides
    phi_learning_rate: float | None = None
    phi_batch_size: int | None = None
    phi_clip_coef: float | None = None

    # Psi function overrides
    psi_learning_rate: float | None = None
    psi_batch_size: int | None = None
    psi_clip_coef: float | None = None

    @cached_property
    def minibatch_size(self) -> int:
        """Return ``batch_size // num_minibatches`` (cached after first use)."""
        return self.batch_size // self.num_minibatches

    def _copy_with_overrides(self, overrides: dict) -> 'Args':
        """Return a deep copy of ``self`` with non-None overrides applied.

        Args:
            overrides: Mapping of attribute name to replacement value;
                entries whose value is ``None`` are skipped.

        Returns:
            A new, independent ``Args`` instance.
        """
        args = deepcopy(self)
        for param, value in overrides.items():
            if value is not None:
                setattr(args, param, value)

        # deepcopy also copies an already-computed ``minibatch_size`` out of
        # the instance dict. Drop it so the copy recomputes the value from
        # its own (possibly overridden) ``batch_size``.
        vars(args).pop('minibatch_size', None)
        return args

    def get_args_for_level(self, level: str) -> 'Args':
        """Build the ``Args`` for one hierarchy level, applying its overrides.

        Args:
            level: ``'bottom'``, ``'middle'`` or ``'top'``. Any other value
                yields an unmodified deep copy.

        Returns:
            A deep copy of ``self`` with that level's non-None overrides set.
        """
        level_overrides = {
            'bottom': {
                'learning_rate': self.bottom_learning_rate,
                'batch_size': self.bottom_batch_size,
                'clip_coef': self.bottom_clip_coef,
                'ent_coef': self.bottom_ent_coef,
            },
            'middle': {
                'learning_rate': self.middle_learning_rate,
                'batch_size': self.middle_batch_size,
                'clip_coef': self.middle_clip_coef,
                'ent_coef': self.middle_ent_coef,
            },
            'top': {
                'learning_rate': self.top_learning_rate,
                'batch_size': self.top_batch_size,
                'clip_coef': self.top_clip_coef,
                'ent_coef': self.top_ent_coef,
                # The top level does not use comm/proxy
                'learn_comm': False,
                'learn_proxy': False,
            },
        }
        return self._copy_with_overrides(level_overrides.get(level, {}))

    def get_phi_args(self) -> 'Args':
        """Build the ``Args`` for phi functions, applying the phi overrides."""
        return self._copy_with_overrides({
            'learning_rate': self.phi_learning_rate,
            'batch_size': self.phi_batch_size,
            'clip_coef': self.phi_clip_coef,
        })

    def get_psi_args(self) -> 'Args':
        """Build the ``Args`` for psi functions, applying the psi overrides."""
        return self._copy_with_overrides({
            'learning_rate': self.psi_learning_rate,
            'batch_size': self.psi_batch_size,
            'clip_coef': self.psi_clip_coef,
        })


class PassthroughAgent:
    """Простой агент-прокси, который передает данные без изменений."""

    def __init__(self, name: str, observation_space, action_space,
                 reward_len: int,
                 directives_space=None, communication_space=None, device=None,
                 **kwargs):
        self.name = name
        self.observation_space = observation_space
        self.action_space = action_space
        self.reward_len = reward_len
        self.directives_space = directives_space
        self.communication_space = communication_space
        self.device = device

    def act(self, observation: np.ndarray,
            directive: np.ndarray | int | None = None) -> np.ndarray:
        """Просто возвращаем директиву как есть."""
        if directive is not None:
            if isinstance(directive, (int, np.integer)):
                return np.array([float(directive)])
            return np.atleast_1d(directive)
        return np.array([0])  # Default action

    def act_train(self, observation: np.ndarray,
                  directive: np.ndarray | int | None,
                  global_step: int) -> np.ndarray:
        return self.act(observation, directive)

    def store(self, state, action=None, reward=None, done=None):
        """Ничего не сохраняем - это прокси."""
        pass

    def update_step(self, global_step: int, writer):
        """Ничего не обновляем."""
        pass

    def target_reward(self, reward: np.ndarray) -> float:
        """Просто суммируем награды."""
        if len(reward) == 0:
            return 0.0
        return float(reward.sum())

    def comm(self, observation: np.ndarray,
             reward_vector: np.ndarray | None = None) -> np.ndarray:
        """Передаем observation без изменений."""
        return observation

    def comm_train(self, observation: np.ndarray,
                   reward_vector: np.ndarray | None = None,
                   global_step: int = 0) -> np.ndarray:
        return self.comm(observation, reward_vector)

    def proxy_reward(self, observation: np.ndarray,
                     reward: np.ndarray) -> float:
        """Передаем сумму наград без изменений."""
        if len(reward) == 0:
            return 0.0
        return float(reward.sum())

    def proxy_reward_train(self, observation: np.ndarray, reward: np.ndarray,
                           global_step: int = 0) -> float:
        return self.proxy_reward(observation, reward)

    def seed(self, seed):
        pass

    def save_agent(self, save_path, name="trained_model"):
        pass

    def load_agent(self, load_path, name="trained_model"):
        return True


class SimplePPO3StrategicHierarchy(Hierarchy):
    """
    Specialised hierarchy for the Simple PPO3 Strategic architecture.

    Builds a hybrid structure in which the top agent is linked both to one
    coordinating middle agent and to one pass-through proxy agent per
    environment agent.

    Structure (shown for four environment agents; the actual number is
    ``len(env.possible_agents)``):
        top_ppo
        ├── middle_ppo (linked to every bottom agent)
        │   ├── agent_0 → agent_0 (env)
        │   ├── agent_1 → agent_1 (env)
        │   ├── agent_2 → agent_2 (env)
        │   └── agent_3 → agent_3 (env)
        ├── proxy_agent_0 → agent_0 (direct link)
        ├── proxy_agent_1 → agent_1 (direct link)
        ├── proxy_agent_2 → agent_2 (direct link)
        └── proxy_agent_3 → agent_3 (direct link)
    """

    def __init__(self, env: ParallelEnv, args: Args):
        """Store ``env``/``args``, pick the torch device and build all levels.

        Args:
            env: Environment whose ``possible_agents`` and per-agent spaces
                define the bottom level.
            args: Experiment configuration.
        """
        super().__init__()
        self.env = env
        self.args = args

        # Setup device: CPU unless a non-negative CUDA index was requested
        # and CUDA is actually available.
        self.device = torch.device("cpu")
        if args.cuda >= 0:
            self.device = torch.device(
                f"cuda:{args.cuda}" if torch.cuda.is_available() else "cpu")

        # Build the hierarchy structure
        self._build_hierarchy()

        if args.verbose:
            print("Instantiated SimplePPO3 Strategic hierarchy:")
            self.print_tree()

    def _build_hierarchy(self):
        """Define agent names and links, then build the three levels."""
        num_env_agents = len(self.env.possible_agents)

        # Define structure
        bottom_agent_names = [f"agent_{i}" for i in range(num_env_agents)]
        middle_names = ["middle_ppo"]  # A single learning middle agent
        proxy_names = [f"proxy_agent_{i}" for i in range(num_env_agents)]
        top_agent_name = "top_ppo"

        # Define links
        bottom_env_links = {agent_name: [agent_name] for agent_name in
                            bottom_agent_names}

        # The middle agent is linked to every bottom agent
        mid_bottom_links = {
            middle_names[0]: bottom_agent_names,
        }
        # Each proxy is linked to exactly one bottom agent
        for i, proxy_name in enumerate(proxy_names):
            mid_bottom_links[proxy_name] = [bottom_agent_names[i]]

        # The top agent is linked to all middle-level agents (proxies too)
        top_mid_links = {top_agent_name: middle_names + proxy_names}

        # Build levels (order matters: each level uses ``self.levels[-1]``)
        self._build_bottom_level(bottom_agent_names, bottom_env_links,
                                 mid_bottom_links)
        self._build_middle_level(middle_names, proxy_names, mid_bottom_links,
                                 top_mid_links)
        self._build_top_level(top_agent_name, top_mid_links,
                              middle_names + proxy_names)

    def _build_bottom_level(self, bottom_agent_names, bottom_env_links,
                            mid_bottom_links):
        """Build the bottom level: one SimplePPO agent per environment agent.

        Args:
            bottom_agent_names: Names of the bottom agents.
            bottom_env_links: Mapping bottom agent -> list of env agents.
            mid_bottom_links: Mapping middle-level agent -> bottom agents.
        """
        bottom_agents = []
        bottom_args = self.args.get_args_for_level('bottom')

        for name in bottom_agent_names:
            env_obs_space = self.env.observation_space(name)
            env_action_space = self.env.action_space(name)

            comm_space = None
            if self.args.learn_comm:
                comm_space = Box(-np.inf, np.inf, shape=[self.args.comm_size])

            # Keyword arguments forwarded to SimplePPO
            agent_kwargs = {
                "args": bottom_args,
                "phi_args": self.args.get_phi_args() if self.args.learn_comm else None,
                "psi_args": self.args.get_psi_args() if self.args.learn_proxy else None,
            }

            bottom_agents.append(
                AgentConfig(
                    name=name,
                    observation_space=env_obs_space,
                    action_space=env_action_space,
                    reward_len=1,
                    directives_space=Discrete(5),
                    communication_space=comm_space,
                    agent_class=SimplePPO,
                    agent_kwargs=agent_kwargs,
                    device=self.device,
                )
            )

        # Action space configuration: one Discrete(5) per
        # (higher-level agent, bottom agent) link.
        bottom_lev_action_space = {}
        for hl_agent, agent_names in mid_bottom_links.items():
            agent_spaces = {agent_name: Discrete(5) for agent_name in
                            agent_names}
            bottom_lev_action_space[hl_agent] = GymDict(agent_spaces)
        bottom_lev_action_space = GymDict(bottom_lev_action_space)

        self.add_level_config(
            LevelConfig(
                name="bottom",
                agents=bottom_agents,
                uplinks=mid_bottom_links,
                downlinks=bottom_env_links,
                action_frequency=self.args.freq_bottom,
                trace_type="full" if self.args.save_all_trace else "reward",
                concat_obs=True,
                action_space=bottom_lev_action_space,
                env=self.env,
            )
        )

    def _build_middle_level(self, middle_names, proxy_names, mid_bottom_links,
                            top_mid_links):
        """Build the middle level (learning middle agent plus proxy agents).

        Args:
            middle_names: Names of the learning (SimplePPO) middle agents.
            proxy_names: Names of the PassthroughAgent proxies.
            mid_bottom_links: Mapping middle-level agent -> bottom agents.
            top_mid_links: Mapping top agent -> middle-level agents.
        """
        middle_agents = []
        middle_args = self.args.get_args_for_level('middle')

        # Add the learning middle agent(s)
        for agent_name in middle_names:
            subordinate_agents = mid_bottom_links[agent_name]

            # Observation size: comm vectors of the subordinates when
            # communication is learned, otherwise their raw env observations.
            if self.args.learn_comm:
                obs_size = self.args.comm_size * len(subordinate_agents)
            else:
                obs_size = sum(self.env.observation_space(sub_agent).shape[0]
                               for sub_agent in subordinate_agents)

            obs_space = Box(-np.inf, np.inf, shape=[obs_size])

            comm_space = None
            if self.args.learn_comm:
                comm_space = Box(-np.inf, np.inf, shape=[self.args.comm_size])

            # Keyword arguments forwarded to SimplePPO
            agent_kwargs = {
                "args": middle_args,
                "phi_args": self.args.get_phi_args() if self.args.learn_comm else None,
                "psi_args": self.args.get_psi_args() if self.args.learn_proxy else None,
            }

            middle_agents.append(
                AgentConfig(
                    name=agent_name,
                    observation_space=obs_space,
                    action_space=Discrete(5),
                    reward_len=len(subordinate_agents),
                    directives_space=Discrete(5),
                    communication_space=comm_space,
                    agent_class=SimplePPO,
                    agent_kwargs=agent_kwargs,
                    device=self.device,
                )
            )

        # Add the proxy agents
        for proxy_name in proxy_names:
            # _build_hierarchy links each proxy to exactly one bottom agent
            subordinate_agents = mid_bottom_links[proxy_name]

            if self.args.learn_comm:
                obs_size = self.args.comm_size * len(subordinate_agents)
            else:
                obs_size = sum(self.env.observation_space(sub_agent).shape[0]
                               for sub_agent in subordinate_agents)

            obs_space = Box(-np.inf, np.inf, shape=[obs_size])

            # For proxies the communication space has the same shape as the
            # observation space
            comm_space = obs_space if not self.args.learn_comm else Box(
                -np.inf, np.inf, shape=[obs_size])

            middle_agents.append(
                AgentConfig(
                    name=proxy_name,
                    observation_space=obs_space,
                    action_space=Discrete(5),
                    reward_len=len(subordinate_agents),
                    directives_space=Discrete(5),
                    communication_space=comm_space,
                    agent_class=PassthroughAgent,  # pass-through proxy
                    agent_kwargs={},
                    device=self.device,
                )
            )

        # Action space configuration
        mid_lev_action_space = {}
        for hl_agent, agent_names in top_mid_links.items():
            agent_spaces = {agent_name: Discrete(5) for agent_name in
                            agent_names}
            mid_lev_action_space[hl_agent] = GymDict(agent_spaces)
        mid_lev_action_space = GymDict(mid_lev_action_space)

        self.add_level_config(
            LevelConfig(
                name="middle",
                agents=middle_agents,
                uplinks=top_mid_links,
                downlinks=mid_bottom_links,
                action_frequency=self.args.freq_mid,
                trace_type="full" if self.args.save_all_trace else "reward",
                concat_obs=True,
                action_space=mid_lev_action_space,
                env=self.levels[-1],
            )
        )

    def _build_top_level(self, top_agent_name, top_mid_links,
                         all_middle_agents):
        """Build the top level consisting of a single SimplePPO agent.

        Args:
            top_agent_name: Name of the top agent.
            top_mid_links: Mapping top agent -> middle-level agents.
            all_middle_agents: Names of every middle-level agent
                (learning agents and proxies).
        """
        if self.args.learn_comm:
            obs_size = self.args.comm_size * len(all_middle_agents)
        else:
            # Sum the observation sizes of all middle-level agents
            # (proxies included)
            obs_size = 0
            for middle_name in all_middle_agents:
                # Look the agent up by name in the previously built level;
                # a name that is not found contributes nothing.
                for agent_config in self.levels[-1].agents.values():
                    if hasattr(agent_config,
                               'name') and agent_config.name == middle_name:
                        obs_size += agent_config.observation_space.shape[0]
                        break

        # Arguments specific to the top agent
        top_args = self.args.get_args_for_level('top')

        top_agent = AgentConfig(
            name=top_agent_name,
            observation_space=Box(-np.inf, np.inf, shape=[obs_size]),
            action_space=Discrete(5),
            # One reward entry per middle-level agent (proxies included)
            reward_len=len(all_middle_agents),
            directives_space=None,
            communication_space=None,
            agent_class=SimplePPO,
            agent_kwargs={"args": top_args},
            device=self.device
        )

        self.add_level_config(
            LevelConfig(
                name="top",
                agents=[top_agent],
                uplinks=None,
                downlinks=top_mid_links,
                action_frequency=self.args.freq_top,
                trace_type="full" if self.args.save_all_trace else "reward",
                env=self.levels[-1],
            )
        )

    def train(
            self,
            env: ParallelEnv | None = None,
            log_path: Path | str | None = None,
            run_name: str | None = None,
    ):
        """Train the hierarchy and log the run.

        Writes ``params.json`` and ``interface_level.json`` into
        ``<log_path>/<run_name>`` and TensorBoard data into
        ``<log_path>/<run_name>/training/tboard``.

        Args:
            env: Environment to connect to; defaults to ``self.env``.
            log_path: Root directory for runs; defaults to ``"runs"``.
            run_name: Run directory name; defaults to
                ``"<exp_name>__<seed>__<unix time>"``.
        """
        if env is None:
            env = self.env
        self.connect(env)

        # Prepare for logging
        if run_name is None:
            run_name = f"{self.args.exp_name}__{self.args.seed}__{int(time.time())}"

        log_path = Path("runs") if log_path is None else Path(log_path)
        run_dir = log_path / run_name

        # exist_ok avoids the check-then-create race of a separate
        # ``exists()`` test.
        os.makedirs(run_dir, exist_ok=True)

        args_dict = self.args.to_dict()
        with open(run_dir / "params.json", "w", encoding="utf-8") as f:
            json.dump(args_dict, f, indent=4)

        interface_level_name = {"name": self.interface_level.name}
        with open(run_dir / "interface_level.json", "w",
                  encoding="utf-8") as f:
            json.dump(interface_level_name, f, indent=4)

        save_path = run_dir / "training"

        writer = SummaryWriter(save_path / "tboard")
        try:
            writer.add_text(
                "hyperparameters",
                "|param|value|\n|-|-|\n%s"
                % (
                    "\n".join(
                        [f"|{key}|{value}|" for key, value in
                         vars(self.args).items()]
                    )
                ),
            )

            self.set_logger(logger=writer, save_path=save_path)

            # Training loop
            done = False
            self.reset()

            with tqdm(total=self.args.total_timesteps,
                      desc="Training step:") as pbar:
                while self.levels[0].level_ts < self.args.total_timesteps:
                    # Start a new episode if the previous step ended one
                    if done:
                        self.reset()

                    _, _, terminated, truncated, _ = self.step(action=None)

                    done = (any(terminated.values())
                            or any(truncated.values()))

                    # Advance the bar to the bottom level's timestep counter
                    pbar.update(self.levels[0].level_ts - pbar.n)

            # Final cleanup
            self.reset()
        finally:
            # Close the writer even when training is interrupted or fails,
            # so the event file handle is not leaked.
            writer.close()


class Agent(BaseAgent):
    """
    Simple PPO3 Strategic agent implemented via a specialised Hierarchy.

    Wraps a SimplePPO3StrategicHierarchy, which combines one coordinating
    middle agent with direct pass-through proxy links.
    """

    def __init__(self, env: ParallelEnv, args: None | Args = None) -> None:
        """Seed the RNGs and build the hierarchy.

        Args:
            env: Environment the hierarchy is built for.
            args: Experiment configuration; a default ``Args()`` is used
                when omitted.
        """
        self.args: Args = Args() if args is None else args

        self.env = env
        self.seed(self.args.seed)

        # Create the specialised hierarchy
        self.hierarchy = SimplePPO3StrategicHierarchy(env, self.args)

    def seed(self, seed):
        """Set random seeds for reproducibility.

        Args:
            seed: Integer seed, or ``None`` to leave seeding
                non-deterministic.
        """
        import random
        random.seed(seed)
        np.random.seed(seed)
        # ``random`` and NumPy accept None, but ``torch.manual_seed``
        # requires an int and raises TypeError on None, so skip it then.
        if seed is not None:
            torch.manual_seed(seed)
        torch.backends.cudnn.deterministic = self.args.torch_deterministic

    def train(
            self,
            env: ParallelEnv | None = None,
            log_path: Path | str | None = None,
            run_name: str | None = None,
    ):
        """Train the hierarchical agent (delegates to the hierarchy)."""
        self.hierarchy.train(env, log_path, run_name)

    def save_agent(self, save_path: str | Path, name: None | str = None):
        """Save the agent's hierarchy to ``save_path``; ``name`` is unused."""
        self.hierarchy.save(Path(save_path))

    def load_agent(self, load_path: Path | str,
                   name: str = "trained_model") -> bool:
        """Load a trained hierarchy from ``load_path``; ``name`` is unused.

        Returns:
            The truthy result of ``hierarchy.load``.

        Raises:
            RuntimeError: If ``hierarchy.load`` returns a falsy value.
        """
        loaded = self.hierarchy.load(Path(load_path))
        if not loaded:
            raise RuntimeError("Could not load agents!")
        print("Agents loaded successfully")
        return loaded

    def act(self, observation: Dict[str, np.ndarray]) -> dict:
        """Return the action the hierarchy chooses for ``observation``."""
        return self.hierarchy.act(observation)