import os
import random
from dataclasses import dataclass
from functools import cached_property
from pathlib import Path
from typing import Any

import gymnasium
import numpy as np
import torch
from dataclasses_json import dataclass_json
from gymnasium.spaces import Box
from torch.utils.tensorboard.writer import SummaryWriter

from tame.agents.base_ppo import PPO
from tame.hierarchy.base_agent import LevelAgent
from tame.utils.config import ArgsInterface
from tame.utils.utils import filter_unexpected_fields


@filter_unexpected_fields
@dataclass_json
@dataclass
class Args(ArgsInterface):
    """Configuration for one PPO learner created by ``Agent``.

    ``Agent`` builds up to three learners (main policy, phi, psi) and hands
    each of them its own ``Args`` instance.  Only ``seed``,
    ``torch_deterministic``, ``verbose``, ``learn_comm`` and ``learn_proxy``
    are read in this file; every other field is forwarded to ``PPO`` through
    the ``args`` argument.
    NOTE(review): the meaning of the forwarded fields is defined by
    ``tame.agents.base_ppo.PPO`` - confirm there before relying on it.
    """

    # Defaults to this file's name without the ".py" suffix.  removesuffix()
    # is used instead of rstrip(".py") because rstrip removes any trailing
    # '.', 'p' or 'y' characters (e.g. "happy.py" would become "ha").
    exp_name: str = os.path.basename(__file__).removesuffix(".py")
    # Seed passed to Agent.seed() and to the learner's own seed(); None
    # skips the seeding done in Agent.seed().
    seed: int | None = 1
    # Written to torch.backends.cudnn.deterministic by Agent.seed().
    torch_deterministic: bool = True
    cuda: int = 0
    save_model: bool = True
    total_timesteps: int = 500000
    learning_rate: float = 2.5e-4
    gamma: float = 0.99
    anneal_lr: bool = True
    gae_lambda: float = 0.95
    batch_size: int = 2048 * 1  # It's num_steps * num_envs
    num_minibatches: int = 4
    update_epochs: int = 4
    norm_adv: bool = True
    clip_coef: float = 0.2
    clip_vloss: bool = True
    ent_coef: float = 0.0
    vf_coef: float = 0.5
    save_all_trace: bool = False
    max_grad_norm: float = 0.5
    target_kl: float | None = None
    # When true, Agent prints a confirmation after saving a component.
    verbose: bool = True
    # When true (and a communication space is given) Agent creates phi.
    learn_comm: bool = False
    # When true Agent creates psi.
    learn_proxy: bool = False

    @cached_property
    def minibatch_size(self) -> int:
        """Return ``batch_size // num_minibatches``.

        The value is cached on first access, so later changes to
        ``batch_size`` or ``num_minibatches`` are not reflected.
        """
        return self.batch_size // self.num_minibatches


class Agent(LevelAgent):
    """A simple PPO agent that can handle hierarchical environments.

    It wraps around a basic PPO implementation for single agent control and supports
    learnable communication (phi) and proxy reward (psi) functions.
    """

    def __init__(
            self,
            observation_space: gymnasium.spaces.Box,
            action_space: gymnasium.spaces.Box | gymnasium.spaces.Discrete,
            reward_len: int,
            device: torch.device,
            directives_space: gymnasium.spaces.Box | gymnasium.spaces.Discrete | None = None,
            communication_space: gymnasium.spaces.Box | None = None,
            name: str = "simple_ppo",
            args: None | Args = None,
            phi_args: None | Args = None,
            psi_args: None | Args = None,
            torch_compile: bool = False
    ) -> None:
        """Build the main PPO policy and, when enabled, the phi/psi learners.

        Args:
            observation_space: Space of the raw observation; only its first
                dimension is used to size the network inputs.
            action_space: Action space of the main policy.
            reward_len: Length of the reward vector appended to the
                observation for phi and psi.
            device: Torch device handed to every learner.
            directives_space: Optional space of the directive that is
                concatenated to the observation for the main policy.
            communication_space: Optional action space of phi.
            name: Name of the main learner; phi/psi get a "phi_"/"psi_" prefix.
            args: Settings of the main policy (default ``Args()``).
            phi_args: Settings of phi (default ``Args()``).
            psi_args: Settings of psi (default ``Args()``).
            torch_compile: Forwarded unchanged to every ``PPO`` constructor.
        """
        # Fall back to default settings for every learner that got none.
        self.args: Args = Args() if args is None else args
        self.phi_args: Args = Args() if phi_args is None else phi_args
        self.psi_args: Args = Args() if psi_args is None else psi_args

        # Seed the global RNGs before any network is created.
        self.seed(self.args.seed)

        self.name = name
        self.device = device
        self.reward_len = reward_len
        self.observation_space = observation_space
        self.action_space = action_space
        self.directives_space = directives_space
        self.communication_space = communication_space

        # Input width of the main policy: raw observation plus directive.
        raw_dim = self.observation_space.shape[0]
        if self.directives_space is None:
            directive_dim = 0
        elif isinstance(self.directives_space, Box):
            directive_dim = self.directives_space.shape[0]
        else:
            # A discrete directive is appended as one scalar entry
            # (see _prepare_observation), not as a one-hot vector.
            directive_dim = 1
        policy_dim = raw_dim + directive_dim

        self.agent = PPO(
            observation_space=Box(-np.inf, np.inf, shape=[policy_dim]),
            action_space=self.action_space,
            args=self.args,
            torch_compile=torch_compile,
            device=self.device,
            name=self.name,
        )
        self.agent.seed(seed=self.args.seed)

        # phi and psi both consume the raw observation plus the reward vector.
        augmented_dim = raw_dim + self.reward_len

        # phi: learnable communication function, only when a communication
        # space exists and the main args enable it.
        self.phi = None
        phi_enabled = getattr(self.args, "learn_comm", False)
        if self.communication_space is not None and phi_enabled:
            self.phi = PPO(
                observation_space=Box(-np.inf, np.inf, shape=[augmented_dim]),
                action_space=self.communication_space,
                args=self.phi_args,
                torch_compile=torch_compile,
                device=self.device,
                name="phi_" + self.name
            )
            self.phi.seed(seed=self.phi_args.seed)

        # psi: learnable proxy reward with a one-dimensional unbounded output.
        self.psi = None
        if getattr(self.args, "learn_proxy", False):
            self.psi = PPO(
                observation_space=Box(-np.inf, np.inf, shape=[augmented_dim]),
                action_space=Box(-np.inf, np.inf, shape=[1]),
                args=self.psi_args,
                torch_compile=torch_compile,
                device=self.device,
                name="psi_" + self.name
            )
            self.psi.seed(seed=self.psi_args.seed)

    def seed(self, seed):
        """Sets random seeds for reproducibility."""
        if seed is not None:
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)
            torch.backends.cudnn.deterministic = self.args.torch_deterministic

    def act(self, observation: np.ndarray,
            directive: np.ndarray | int | None = None) -> np.ndarray:
        """Select actions based on observation and directive."""
        flat_obs = self._prepare_observation(observation, directive)
        return self.agent.act(flat_obs)

    def act_train(
            self, observation: np.ndarray, directive: np.ndarray | int | None,
            global_step: int
    ) -> np.ndarray:
        """Select actions during training phase."""
        flat_obs = self._prepare_observation(observation, directive)
        action = self.agent.act_train(flat_obs, global_step=global_step)
        return action

    def _prepare_observation(self, observation: np.ndarray,
                             directive: np.ndarray | int | None) -> np.ndarray:
        """Prepare observation by concatenating with directive if present."""
        if directive is not None:
            if isinstance(directive, (int, np.integer)):
                # Convert discrete directive to one-hot or scalar
                directive_array = np.array([float(directive)])
            else:
                directive_array = np.atleast_1d(directive)
            return np.concatenate([observation, directive_array])
        else:
            return observation

    def update_step(self, global_step: int, writer: None | SummaryWriter):
        """Update all components (phi, psi, main agent)."""
        if self.phi is not None:
            self.phi.update_step(global_step=global_step, writer=writer)
        if self.psi is not None:
            self.psi.update_step(global_step=global_step, writer=writer)
        self.agent.update_step(global_step=global_step, writer=writer)

    def target_reward(self, reward: np.ndarray) -> float:
        """Aggregate per-child rewards to a single scalar used for training."""
        if len(reward) == 0:
            return 0.0
        return float(reward.sum())

    def store(
            self,
            state: np.ndarray | torch.Tensor,
            action: None | np.ndarray | torch.Tensor = None,
            reward: float | None = None,
            done: bool | None = None,
    ):
        """Store transition in agent's memory."""
        self.agent.store(
            state=state,
            action=np.atleast_1d(action) if action is not None else None,
            reward=reward,
            done=done
        )

    def phi_store(self, state, action, reward, done):
        """Store transition for phi training."""
        if self.phi is None:
            return

        if reward is not None:
            reward = float(reward)

        self.phi.store(
            state=state,
            action=np.atleast_1d(action) if action is not None else None,
            reward=reward,
            done=done,
        )

    def psi_store(self, state, action, reward, done):
        """Store transition for psi training."""
        if self.psi is None:
            return

        if reward is not None:
            reward = float(reward)

        self.psi.store(
            state=state,
            action=np.atleast_1d(action) if action is not None else None,
            reward=reward,
            done=done
        )

    def _save_component(self, component, save_path: Path | str,
                        name: str = "trained_model"):
        """Save a component (agent, phi, or psi)."""
        save_path = Path(save_path)
        model_save_path = save_path / "models" / f"{name}.pth"
        if not model_save_path.parent.exists():
            os.makedirs(model_save_path.parent)
        torch.save(component.actor_critic.state_dict(), model_save_path)
        if self._load_component(component, save_path, name=name):
            if hasattr(self.args, 'verbose') and self.args.verbose:
                print(f"{self.name}: {name} saved to {model_save_path}")
        else:
            print(f"{self.name}: Could not save {name}!")

    def _load_component(self, component, load_path: Path | str,
                        name: str = "trained_model") -> bool:
        """Load a component from disk."""
        load_path = Path(load_path) / "models" / f"{name}.pth"
        if load_path.exists():
            try:
                component.actor_critic.load_state_dict(
                    torch.load(load_path, map_location=self.device))
                return True
            except Exception as e:
                print(f"{self.name}: Could not load {name} from {load_path}")
                print(f"{self.name}: {e}")
                return False
        else:
            print(f"{self.name}: Path {load_path} does not exist.")
            return False

    def save_agent(self, save_path: Path | str, name: str = "trained_model"):
        """Save the main agent model."""
        self._save_component(self.agent, save_path, name)

        # Save phi and psi if they exist
        if self.phi is not None:
            self._save_component(self.phi, save_path, f"phi_{name}")
        if self.psi is not None:
            self._save_component(self.psi, save_path, f"psi_{name}")

    def load_agent(self, load_path: Path | str,
                   name: str = "trained_model") -> bool:
        """Load the agent model(s) from disk."""
        success = self._load_component(self.agent, load_path, name)

        # Load phi and psi if they exist
        if self.phi is not None:
            phi_success = self._load_component(self.phi, load_path,
                                               f"phi_{name}")
            success = success and phi_success

        if self.psi is not None:
            psi_success = self._load_component(self.psi, load_path,
                                               f"psi_{name}")
            success = success and psi_success

        return success

    def comm(self, observation: np.ndarray,
             reward_vector: np.ndarray | None = None) -> np.ndarray:
        """Generate communication representation for higher-level agent (inference mode)."""
        # if self.communication_space is None:
        #     raise ValueError(
        #         f"Agent {self.name} has no communication space defined."
        #     )

        if self.phi is not None:
            if reward_vector is not None:
                phi_input = np.concatenate([observation, reward_vector])
            else:
                phi_input = np.concatenate(
                    [observation, np.zeros(self.reward_len)])

            message = self.phi.act(phi_input)
            return message
        else:
            return observation

    def comm_train(self, observation: np.ndarray,
                   reward_vector: np.ndarray | None = None,
                   global_step: int = 0) -> np.ndarray:
        """Generate communication representation during training (uses act_train for phi)."""
        # if self.communication_space is None:
        #     raise ValueError(
        #         f"Agent {self.name} has no communication space defined."
        #     )

        if self.phi is not None:
            if reward_vector is not None:
                phi_input = np.concatenate([observation, reward_vector])
            else:
                phi_input = np.concatenate(
                    [observation, np.zeros(self.reward_len)])

            message = self.phi.act_train(phi_input, global_step=global_step)
            return message
        else:
            return observation

    def proxy_reward(self, observation: np.ndarray,
                     reward: np.ndarray) -> float:
        """Project lower-level reward vector to a scalar proxy (inference mode)."""
        if len(reward) == 0:
            return 0.0

        if self.psi is not None:
            psi_input = np.concatenate([observation, reward])
            proxy = self.psi.act(psi_input)
            return float(proxy)
        else:
            return float(reward.sum())

    def proxy_reward_train(self, observation: np.ndarray, reward: np.ndarray,
                           global_step: int = 0) -> float:
        """Project lower-level reward vector to a scalar proxy during training (uses act_train for psi)."""
        if len(reward) == 0:
            return 0.0

        if self.psi is not None:
            psi_input = np.concatenate([observation, reward])
            proxy = self.psi.act_train(psi_input, global_step=global_step)
            return float(proxy)
        else:
            return float(reward.sum())

    def train(self, env: Any, log_path: Path | str | None = None,
              run_name: str | None = None):
        """
        Level agent does not implement standalone training.
        Training happens through update_step() calls from hierarchy.

        This method is implemented only for BaseAgent interface compatibility.
        For actual training, use the agent within a hierarchy where training occurs through:
        - act_train() calls for main policy
        - comm_train() calls for phi function
        - proxy_reward_train() calls for psi function
        - update_step() calls for all components
        """
        raise NotImplementedError(
            f"{self.name} (LevelAgent) does not implement standalone training. "
            "Use this agent within a hierarchy for training. "
            "Training happens through act_train(), comm_train(), proxy_reward_train(), "
            "and update_step() calls from LevelEnv."
        )
