import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from typing import Dict, List, Tuple, Optional, Any
from dataclasses import dataclass
from collections import defaultdict, deque

from ...agents.actor import LLMActor, MotifAgent
from ...agents.critic import CentralizedCritic
from ...agents.coordinator import CentralizedCoordinator
from ...environment import AssemblyEnvironment
from ...rewards import RewardSystem


@dataclass(init=True, repr=True, eq=True)
class PPOBatch:
    """Flattened rollout data consumed by a single PPO update.

    Every list is indexed by environment step, and index ``i`` refers to
    the same step in every field.
    """

    # Environment states observed at each step.
    states: List[Any]
    # Actions chosen at each step.
    actions: List[Any]
    # Log-probabilities of the chosen actions under the rollout policy.
    log_probs: List[torch.Tensor]
    # Scalar rewards returned by the environment.
    rewards: List[float]
    # Critic value estimates recorded during the rollout.
    values: List[torch.Tensor]
    # Episode-end flags for each step.
    dones: List[bool]
    # Per-step advantage estimates.
    advantages: List[torch.Tensor]
    # Per-step value targets.
    returns: List[torch.Tensor]


class MAPPOTrainer:
    """Multi-agent PPO trainer for the assembly environment.

    Rollouts are produced by ``coordinator.coordinate_actions`` using one
    ``MotifAgent`` per motif available in the initial state (each constructed
    with the shared actor). Advantages are estimated with GAE(gamma, lambda).
    The actor and the centralized critic are updated with separate Adam
    optimizers using the clipped PPO surrogate objective.
    """

    def __init__(self,
                 actor: LLMActor,
                 critic: CentralizedCritic,
                 coordinator: CentralizedCoordinator,
                 reward_system: RewardSystem,
                 learning_rate: float = 3e-4,
                 clip_epsilon: float = 0.2,
                 entropy_coef: float = 0.01,
                 value_loss_coef: float = 0.5,
                 max_grad_norm: float = 0.5,
                 ppo_epochs: int = 4,
                 mini_batch_size: int = 32,
                 gae_lambda: float = 0.95,
                 gamma: float = 0.99,
                 set_bc_weight: float = 1.0,
                 kl_coef: float = 0.01):
        """Store the models and hyperparameters and build the optimizers.

        Args:
            actor: Policy model. This class uses its ``parameters()``,
                ``get_action_log_prob(state, action)`` and a forward pass
                returning a dict with ``'action_logits'``.
            critic: Centralized value model whose output dict contains
                ``'main_value'``.
            coordinator: Produces each step's action via ``coordinate_actions``.
            reward_system: Stored on the instance; not otherwise used here.
            learning_rate: Adam learning rate for both actor and critic.
            clip_epsilon: PPO ratio clipping range.
            entropy_coef: Weight of the (negated) entropy bonus in the actor loss.
            value_loss_coef: Weight of the value loss in the reported total loss.
            max_grad_norm: Gradient-norm clipping threshold for each optimizer.
            ppo_epochs: Passes over the batch per call to ``train_ppo_step``.
            mini_batch_size: Number of steps per gradient update.
            gae_lambda: GAE lambda.
            gamma: Discount factor.
            set_bc_weight: Weight of the Set-BC term (currently a zero placeholder).
            kl_coef: Stored; not currently used by the update.
        """
        self.actor = actor
        self.critic = critic
        self.coordinator = coordinator
        self.reward_system = reward_system

        # Training hyperparameters
        self.clip_epsilon = clip_epsilon
        self.entropy_coef = entropy_coef
        self.value_loss_coef = value_loss_coef
        self.max_grad_norm = max_grad_norm
        self.ppo_epochs = ppo_epochs
        self.mini_batch_size = mini_batch_size
        self.gae_lambda = gae_lambda
        self.gamma = gamma
        self.set_bc_weight = set_bc_weight
        self.kl_coef = kl_coef

        # Optimizers
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=learning_rate)
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=learning_rate)

        # Training state
        self.training_step = 0
        self.episode_count = 0

        # Experience buffer
        self.experience_buffer = []

        # Training statistics (rolling windows of the last 100 entries)
        self.training_stats = {
            'policy_loss': deque(maxlen=100),
            'value_loss': deque(maxlen=100),
            'entropy_loss': deque(maxlen=100),
            'total_loss': deque(maxlen=100),
            'advantages_mean': deque(maxlen=100),
            'rewards_mean': deque(maxlen=100),
            'episode_lengths': deque(maxlen=100),
            'success_rate': deque(maxlen=100)
        }

    def collect_trajectories(self, env: AssemblyEnvironment, num_episodes: int = 10,
                             max_steps: int = 100) -> List[Dict]:
        """Collect ``num_episodes`` rollouts of at most ``max_steps`` steps each."""
        return [self._collect_single_trajectory(env, max_steps) for _ in range(num_episodes)]

    def _collect_single_trajectory(self, env: AssemblyEnvironment, max_steps: int) -> Dict:
        """Roll out one episode of at most ``max_steps`` steps.

        The environment is not reset here. The rollout starts from
        ``env.current_state``.

        Returns:
            Dict of per-step lists (``states``, ``actions``, ``rewards``,
            ``log_probs``, ``values``, ``dones``, ``terminateds``, ``infos``)
            plus ``bootstrap_value``. That value is the critic's estimate for
            the state reached after the last step, or 0 when the episode
            terminated or no step was taken. ``log_probs``, ``values`` and
            ``bootstrap_value`` are computed without autograd.
        """
        trajectory = {
            'states': [],
            'actions': [],
            'rewards': [],
            'log_probs': [],
            'values': [],
            'dones': [],
            'terminateds': [],
            'infos': []
        }

        # NOTE(review): no env.reset() here. Callers must put the environment
        # into a fresh initial state before each episode.
        state = env.current_state

        # One agent per motif available at the start of the episode.
        motif_agents = {motif_id: MotifAgent(motif_id, self.actor)
                        for motif_id in state.available_motifs}

        terminated = False
        for _ in range(max_steps):
            action, _coordination_info = self.coordinator.coordinate_actions(
                state, motif_agents, temperature=1.0
            )

            # The rollout log-prob and value are fixed targets for the PPO
            # update. Recording them without autograd keeps the update from
            # backpropagating into stale rollout graphs and avoids holding
            # every rollout graph in memory.
            with torch.no_grad():
                log_prob = self.actor.get_action_log_prob(state, action)
                value = self.critic(state, env.get_action_masks())['main_value']

            next_state, reward, terminated, truncated, info = env.step(action)

            trajectory['states'].append(state)
            trajectory['actions'].append(action)
            trajectory['rewards'].append(reward)
            trajectory['log_probs'].append(log_prob)
            trajectory['values'].append(value)
            trajectory['dones'].append(terminated or truncated)
            # Kept separately from `dones`: truncation must still bootstrap.
            trajectory['terminateds'].append(terminated)
            trajectory['infos'].append(info)

            state = next_state

            if terminated or truncated:
                break

        # Value of the state after the final step, used by GAE to bootstrap
        # episodes cut short by truncation or by max_steps.
        # NOTE(review): assumes env.get_action_masks() now reflects `state`.
        bootstrap_value = torch.tensor(0.0)
        if trajectory['rewards'] and not terminated:
            with torch.no_grad():
                bootstrap_value = self.critic(state, env.get_action_masks())['main_value']
        trajectory['bootstrap_value'] = bootstrap_value

        self.episode_count += 1
        return trajectory

    def compute_advantages_and_returns(self, trajectory: Dict) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        """Compute GAE(gamma, lambda) advantages and value targets for one trajectory.

        When ``trajectory['bootstrap_value']`` is present, the last step
        bootstraps from it and is masked by ``trajectory['terminateds']``, so
        truncated episodes still bootstrap. Trajectories without these keys
        fall back to bootstrapping from the last recorded value unless the
        final step is done.

        Returns:
            ``(advantages, returns)``: per-step detached tensors with
            ``returns[i] == advantages[i] + values[i]``. Both lists are empty
            for an empty trajectory.
        """
        rewards = trajectory['rewards']
        if not rewards:
            return [], []

        # Targets must be constants, so detach the recorded values.
        values = [torch.as_tensor(v).detach().squeeze() for v in trajectory['values']]
        dones = trajectory['dones']
        terminateds = trajectory.get('terminateds', dones)
        num_steps = len(rewards)

        if 'bootstrap_value' in trajectory:
            next_value = torch.as_tensor(trajectory['bootstrap_value']).detach().squeeze()
        elif not dones[-1]:
            # Legacy approximation: V(s_{T-1}) stands in for V(s_T).
            next_value = values[-1]
        else:
            next_value = torch.zeros_like(values[-1])

        advantages: List[torch.Tensor] = []
        returns: List[torch.Tensor] = []
        gae = torch.zeros_like(values[-1])
        for i in reversed(range(num_steps)):
            if i == num_steps - 1:
                next_value_i = next_value
                # Only true termination zeroes the bootstrap; truncation does not.
                bootstrap_mask = 1.0 - float(terminateds[i])
            else:
                next_value_i = values[i + 1]
                bootstrap_mask = 1.0 - float(dones[i])
            # A done step ends the episode, so GAE must not accumulate across it.
            continue_mask = 1.0 - float(dones[i])

            delta = rewards[i] + self.gamma * next_value_i * bootstrap_mask - values[i]
            gae = delta + self.gamma * self.gae_lambda * continue_mask * gae

            advantages.append(gae)
            returns.append(gae + values[i])

        advantages.reverse()
        returns.reverse()
        return advantages, returns

    def create_ppo_batch(self, trajectories: List[Dict]) -> PPOBatch:
        """Flatten trajectories into one ``PPOBatch`` with normalized advantages.

        Raises:
            RuntimeError: if the trajectories contain no steps at all.
        """
        all_states = []
        all_actions = []
        all_log_probs = []
        all_rewards = []
        all_values = []
        all_dones = []
        all_advantages = []
        all_returns = []

        for trajectory in trajectories:
            advantages, returns = self.compute_advantages_and_returns(trajectory)

            all_states.extend(trajectory['states'])
            all_actions.extend(trajectory['actions'])
            all_log_probs.extend(trajectory['log_probs'])
            all_rewards.extend(trajectory['rewards'])
            all_values.extend(trajectory['values'])
            all_dones.extend(trajectory['dones'])
            all_advantages.extend(advantages)
            all_returns.extend(returns)

        if not all_advantages:
            raise RuntimeError("create_ppo_batch: trajectories contain no steps")

        # Normalize advantages. std() of a single element is NaN (Bessel
        # correction), so a one-step batch is only centered.
        advantages_tensor = torch.stack(all_advantages)
        if advantages_tensor.numel() > 1:
            scale = advantages_tensor.std() + 1e-8
        else:
            scale = 1.0
        normalized_advantages = (advantages_tensor - advantages_tensor.mean()) / scale
        all_advantages = list(normalized_advantages.unbind(0))

        return PPOBatch(
            states=all_states,
            actions=all_actions,
            log_probs=all_log_probs,
            rewards=all_rewards,
            values=all_values,
            dones=all_dones,
            advantages=all_advantages,
            returns=all_returns
        )

    def train_ppo_step(self, batch: PPOBatch) -> Dict[str, float]:
        """Run ``ppo_epochs`` epochs of mini-batch PPO updates on ``batch``.

        Returns:
            Losses averaged over all mini-batch updates. ``total_loss`` is the
            coefficient-weighted objective
            (policy + entropy_coef*entropy + set_bc_weight*bc + value_loss_coef*value).
        """
        total_policy_loss = 0.0
        total_value_loss = 0.0
        total_entropy_loss = 0.0
        total_bc_loss = 0.0
        total_combined_loss = 0.0
        num_updates = 0

        batch_size = len(batch.states)

        for _ in range(self.ppo_epochs):
            # Reshuffle every epoch so each epoch sees different mini-batches.
            indices = np.random.permutation(batch_size)
            for start in range(0, batch_size, self.mini_batch_size):
                mini_batch_indices = indices[start:start + self.mini_batch_size]

                mb_states = [batch.states[i] for i in mini_batch_indices]
                mb_actions = [batch.actions[i] for i in mini_batch_indices]
                # Rollout-time quantities are constants of the objective. Detach
                # them, and flatten to (B,) so a (B, 1) shape cannot broadcast
                # against the (B,) advantages.
                mb_old_log_probs = torch.stack(
                    [batch.log_probs[i] for i in mini_batch_indices]).detach().reshape(-1)
                mb_advantages = torch.stack(
                    [batch.advantages[i] for i in mini_batch_indices]).detach().reshape(-1)
                mb_returns = torch.stack(
                    [batch.returns[i] for i in mini_batch_indices]).detach().reshape(-1)

                # Compute current policy outputs
                mb_log_probs = []
                mb_values = []
                mb_entropies = []

                for state, action in zip(mb_states, mb_actions):
                    mb_log_probs.append(self.actor.get_action_log_prob(state, action))

                    # NOTE(review): rollout collection calls the critic with
                    # env.get_action_masks(), but masks are not stored in the
                    # batch. Confirm the value does not depend on them.
                    critic_outputs = self.critic(state)
                    mb_values.append(critic_outputs['main_value'].squeeze())

                    # Entropy calculation (simplified)
                    actor_outputs = self.actor(state)
                    mb_entropies.append(self._compute_entropy(actor_outputs['action_logits']))

                mb_log_probs = torch.stack(mb_log_probs).reshape(-1)
                mb_values = torch.stack(mb_values).reshape(-1)
                mb_entropies = torch.stack(mb_entropies)

                # PPO clipped surrogate policy loss
                ratio = torch.exp(mb_log_probs - mb_old_log_probs)
                surr1 = ratio * mb_advantages
                surr2 = torch.clamp(ratio, 1.0 - self.clip_epsilon, 1.0 + self.clip_epsilon) * mb_advantages
                policy_loss = -torch.min(surr1, surr2).mean()

                # Value loss
                value_loss = nn.functional.mse_loss(mb_values, mb_returns)

                # Entropy loss (negated so minimizing it maximizes entropy)
                entropy_loss = -mb_entropies.mean()

                # Set-BC loss placeholder (not implemented yet), created on the
                # same device as the other losses.
                bc_loss = torch.zeros((), device=policy_loss.device)

                actor_loss = (policy_loss +
                              self.entropy_coef * entropy_loss +
                              self.set_bc_weight * bc_loss)

                # Update actor. retain_graph in case the actor and critic
                # graphs share nodes.
                self.actor_optimizer.zero_grad()
                actor_loss.backward(retain_graph=True)
                torch.nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm)
                self.actor_optimizer.step()

                # Update critic (unweighted value loss; it has its own optimizer)
                self.critic_optimizer.zero_grad()
                value_loss.backward()
                torch.nn.utils.clip_grad_norm_(self.critic.parameters(), self.max_grad_norm)
                self.critic_optimizer.step()

                # Accumulate losses
                total_policy_loss += policy_loss.item()
                total_value_loss += value_loss.item()
                total_entropy_loss += entropy_loss.item()
                total_bc_loss += bc_loss.item()
                total_combined_loss += actor_loss.item() + self.value_loss_coef * value_loss.item()
                num_updates += 1

        self.training_step += 1

        # Guard against ppo_epochs == 0 or an empty batch.
        denom = max(num_updates, 1)
        return {
            'policy_loss': total_policy_loss / denom,
            'value_loss': total_value_loss / denom,
            'entropy_loss': total_entropy_loss / denom,
            'bc_loss': total_bc_loss / denom,
            'total_loss': total_combined_loss / denom
        }

    def _compute_entropy(self, action_logits: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Sum over action components of the mean categorical entropy.

        For each non-empty logits tensor, entropy is taken over the last
        dimension and averaged over the remaining ones. The per-component
        values are summed. Returns a zero tensor if every component is empty.
        """
        entropies = []
        for logits in action_logits.values():
            if logits.numel() == 0:
                continue
            log_probs = torch.log_softmax(logits, dim=-1)
            probs = log_probs.exp()
            # Clamp -inf log-probs (masked logits) to the most negative finite
            # value so 0 * log(0) evaluates to 0 instead of NaN.
            log_probs = log_probs.clamp(min=torch.finfo(log_probs.dtype).min)
            entropies.append(-(probs * log_probs).sum(dim=-1).mean())

        if not entropies:
            return torch.tensor(0.0)
        # Stacking keeps the result on the logits' device rather than a CPU accumulator.
        return torch.stack(entropies).sum()

    def update_training_stats(self, batch: PPOBatch, loss_dict: Dict[str, float]):
        """Append the latest losses and batch means to the rolling statistics."""
        self.training_stats['policy_loss'].append(loss_dict['policy_loss'])
        self.training_stats['value_loss'].append(loss_dict['value_loss'])
        self.training_stats['entropy_loss'].append(loss_dict['entropy_loss'])
        self.training_stats['total_loss'].append(loss_dict['total_loss'])

        advantages_tensor = torch.stack(batch.advantages)
        # Explicit float dtype: mean() raises on integer tensors (e.g. int rewards).
        rewards_tensor = torch.tensor(batch.rewards, dtype=torch.float32)

        self.training_stats['advantages_mean'].append(advantages_tensor.mean().item())
        self.training_stats['rewards_mean'].append(rewards_tensor.mean().item())

    def train(self, env: AssemblyEnvironment, num_iterations: int = 1000,
              episodes_per_iteration: int = 10):
        """Alternate rollout collection and PPO updates for ``num_iterations`` iterations."""
        for iteration in range(num_iterations):
            trajectories = self.collect_trajectories(env, episodes_per_iteration)
            batch = self.create_ppo_batch(trajectories)
            loss_dict = self.train_ppo_step(batch)
            self.update_training_stats(batch, loss_dict)

            # Print progress every 10 iterations
            if iteration % 10 == 0:
                avg_reward = np.mean(self.training_stats['rewards_mean'])
                print(f"Iteration {iteration}, Avg Reward: {avg_reward:.3f}, "
                      f"Policy Loss: {loss_dict['policy_loss']:.4f}")

    def get_training_stats(self) -> Dict[str, Any]:
        """Summarize each non-empty statistic window (mean/std/min/max) plus counters."""
        stats = {}
        for key, values in self.training_stats.items():
            if values:
                stats[key] = {
                    'mean': np.mean(values),
                    'std': np.std(values),
                    'min': np.min(values),
                    'max': np.max(values)
                }

        stats['training_step'] = self.training_step
        stats['episode_count'] = self.episode_count

        return stats

    def save_checkpoint(self, filepath: str):
        """Save models, optimizers, counters and statistics to ``filepath``."""
        checkpoint = {
            'actor_state_dict': self.actor.state_dict(),
            'critic_state_dict': self.critic.state_dict(),
            'actor_optimizer_state_dict': self.actor_optimizer.state_dict(),
            'critic_optimizer_state_dict': self.critic_optimizer.state_dict(),
            'training_step': self.training_step,
            'episode_count': self.episode_count,
            # Plain lists rather than deques, so the checkpoint contains only
            # primitive containers.
            'training_stats': {key: list(values) for key, values in self.training_stats.items()}
        }

        torch.save(checkpoint, filepath)

    def load_checkpoint(self, filepath: str, map_location: Optional[Any] = None):
        """Restore models, optimizers, counters and statistics from ``filepath``.

        Args:
            filepath: Path written by ``save_checkpoint``.
            map_location: Forwarded to ``torch.load`` (e.g. ``'cpu'`` to load a
                GPU checkpoint on a CPU-only machine). Defaults to ``None``,
                which is ``torch.load``'s own default.

        Security: ``torch.load`` may unpickle arbitrary objects depending on the
        PyTorch version and ``weights_only`` setting. Only load trusted files.
        """
        checkpoint = torch.load(filepath, map_location=map_location)

        self.actor.load_state_dict(checkpoint['actor_state_dict'])
        self.critic.load_state_dict(checkpoint['critic_state_dict'])
        self.actor_optimizer.load_state_dict(checkpoint['actor_optimizer_state_dict'])
        self.critic_optimizer.load_state_dict(checkpoint['critic_optimizer_state_dict'])
        self.training_step = checkpoint['training_step']
        self.episode_count = checkpoint['episode_count']

        # Restore training stats (accepts lists or deques from older checkpoints)
        for key, values in checkpoint['training_stats'].items():
            self.training_stats[key] = deque(values, maxlen=100)