from typing import Dict, Any, Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
import gymnasium as gym
import copy

from omegaconf import OmegaConf

from benchrl.algorithms.base import BaseAlgorithm
from benchrl.policies.sac_policy import DiscreteSACPolicy, ContinuousSACPolicy
from benchrl.utils.buffers.replay_buffer import ReplayBuffer
from benchrl.utils._functions import _build_network

class SAC(BaseAlgorithm):
    """Soft Actor-Critic (SAC) algorithm.
    
    Supports both discrete and continuous action spaces with automatic entropy tuning.
    Uses twin Q-networks for improved stability and extends the unified ActorCritic framework.
    
    Reference: https://arxiv.org/abs/1801.01290
    """
    
    def __init__(
        self,
        env,
        algo_config: Dict[str, Any],
        device: str = "auto",
        writer: Optional[SummaryWriter] = None,
    ):
        """Build the SAC policy, target critics, optimizers and replay buffer.

        The constructor is order-dependent: the action-space flags must be set
        before `build_module()` runs (it reads `self.is_discrete`), the
        action-rescaling buffers are registered before the networks are moved
        to the device, and the optimizers are created from the policy's
        parameters after that move.

        Args:
            env: Vectorized environment. The code reads `env.num_envs`,
                `env.single_observation_space` and `env.single_action_space`.
            algo_config: Algorithm configuration. Read here through `.get(...)`
                for hyper-parameters and (in `build_module`) through
                `.model_bindings`.
            device: Device for computation (forwarded to the base class).
            writer: Tensorboard writer (forwarded to the base class).

        Raises:
            ValueError: If the action space is neither `Discrete` nor `Box`,
                if the `actor_network` / `critic1_network` modules are missing
                from the model bindings, or if `ent_coeff` is a string other
                than `'auto'`.
        """
        super().__init__(env, algo_config, device, writer)
        
        # SAC supports multiple environments
        # Note: SAC typically uses single environment, but we support multiple for consistency
        
        # Determine action space type
        self.is_discrete = isinstance(self.action_space, gym.spaces.Discrete)
        self.is_continuous = isinstance(self.action_space, gym.spaces.Box)
        
        if not (self.is_discrete or self.is_continuous):
            raise ValueError("SAC only supports Discrete or Box action spaces")
        
        # Extract config parameters with defaults
        self.total_timesteps = algo_config.get('total_timesteps', 1000000)
        self.buffer_size = algo_config.get('buffer_size', 1000000)
        self.learning_starts = algo_config.get('learning_starts', 25000)
        self.batch_size = algo_config.get('batch_size', 256)
        self.tau = algo_config.get('tau', 0.005)
        self.gamma = algo_config.get('gamma', 0.99)
        self.train_freq = algo_config.get('train_freq', 1)
        # NOTE(review): `gradient_steps` is stored here but is not read anywhere
        # else in this class.
        self.gradient_steps = algo_config.get('gradient_steps', 1)
        self.target_update_interval = algo_config.get('target_update_interval', 1)
        self.policy_frequency = algo_config.get('policy_frequency', 2)
        
        # Learning rates (CleanRL defaults)
        self.actor_lr = algo_config.get('actor_lr', 3e-4)
        self.critic_lr = algo_config.get('critic_lr', 1e-3)
        
        # Entropy coefficient: 'auto' enables a learnable coefficient, any other
        # value is used as a fixed alpha.
        self.ent_coeff = algo_config.get('ent_coeff', 'auto')
        self.target_entropy = algo_config.get('target_entropy', None)
         
        # Build actor-critic policy
        # Build networks using the new modular architecture
        print("Building networks with modular architecture...")
        modules = self.build_module()
        print(f"Networks: {list(modules.keys())}")
        
        actor_network = modules.get('actor_network', None)
        # NOTE(review): the critic is looked up under the key 'critic1_network',
        # while the error message below names it 'critic_network'.
        critic_network = modules.get('critic1_network', None)
        # shared_backbone = modules.get('shared_backbone', None)
        
        if actor_network is None or critic_network is None:
            raise ValueError("SAC requires actor_network and critic_network")
        
        # Build SAC policy based on action space
        if self.is_discrete:
            self.policy = DiscreteSACPolicy(
                actor_network=actor_network,
                critic_network=critic_network,
                action_dim=self.action_dim
            )
        else:
            self.policy = ContinuousSACPolicy(
                actor_network=actor_network,
                critic_network=critic_network,
                action_dim=self.action_dim
            )
            # Action rescaling: half-range and midpoint of the Box bounds are
            # stored as buffers on the actor so they follow it across devices
            # and into its state dict.
            self.policy.actor_network.register_buffer(
                "action_scale",
                torch.tensor(
                    (env.single_action_space.high - env.single_action_space.low) / 2.0,
                    dtype=torch.float32,
                ),
            )
            self.policy.actor_network.register_buffer(
                "action_bias",
                torch.tensor(
                    (env.single_action_space.high + env.single_action_space.low) / 2.0,
                    dtype=torch.float32,
                ),
            )
        
        # Create target critics (deep copy for stability).
        # NOTE(review): only one critic is passed to the policy above;
        # `critic2_network` is presumably created inside the policy class -
        # confirm in benchrl.policies.sac_policy.
        self.target_critic1 = copy.deepcopy(self.policy.critic_network)
        self.target_critic2 = copy.deepcopy(self.policy.critic2_network)
        
        # Move to device
        self.policy.to(self.device)
        self.target_critic1.to(self.device)
        self.target_critic2.to(self.device)
        
        # Initialize optimizers with separate actor and critic parameters (CleanRL style)
        self.actor_optimizer = optim.Adam(
            self.policy.actor_parameters(),
            lr=self.actor_lr,
            # eps=1e-4
        )
        
        self.critic_optimizer = optim.Adam(
            self.policy.critic_parameters(),
            lr=self.critic_lr,
            # eps=1e-4
        )
        
        # Entropy coefficient setup
        if self.ent_coeff == 'auto':
            if self.target_entropy is None:
                if self.is_discrete:
                    # For discrete: target entropy = -log(1/|A|) * ratio
                    # NOTE(review): this branch leaves `target_entropy` as a
                    # torch tensor, whereas the continuous branch stores a
                    # Python float (via `.item()`).
                    target_entropy_scale = algo_config.get('target_entropy_scale', 0.89)
                    self.target_entropy = -target_entropy_scale * torch.log(1 / torch.tensor(self.action_dim))
                else:
                    # For continuous: target entropy = -dim(A)
                    self.target_entropy = -torch.prod(torch.Tensor(self.action_space.shape).to(self.device)).item()
            
            # Learnable entropy coefficient (CleanRL style): optimise log(alpha),
            # initialised to 0 (alpha = 1). It uses the critic learning rate.
            self.log_ent_coeff = torch.zeros(1, requires_grad=True, device=self.device)
            self.ent_coeff_optimizer = optim.Adam([self.log_ent_coeff], lr=self.critic_lr, eps=1e-4)
        else:
            # Fixed entropy coefficient: nothing to optimise.
            self.log_ent_coeff = None
            self.ent_coeff_optimizer = None
            if isinstance(self.ent_coeff, str):
                raise ValueError(f"Invalid entropy coefficient: {self.ent_coeff}")
        
        # Initialize replay buffer
        self.replay_buffer = ReplayBuffer(
            buffer_size=self.buffer_size,
            observation_space=self.env.single_observation_space,
            action_space=self.env.single_action_space,
            device=self.device,
            n_envs=self.env.num_envs,
            handle_timeout_termination=False,
        )
        
        # Initialize environment state; `collect_rollouts` resets the env lazily
        # when this is still None.
        self.last_obs = None
        
        # Initialize metrics so that `train_step` can report them even before
        # the first (delayed) actor update has run.
        self.qf1_a_values = torch.zeros(1, device=self.device)
        self.qf2_a_values = torch.zeros(1, device=self.device)
        self.qf1_loss = torch.zeros(1, device=self.device)
        self.qf2_loss = torch.zeros(1, device=self.device)
        self.qf_loss = torch.zeros(1, device=self.device)
        self.actor_loss = torch.zeros(1, device=self.device)
        # Pre-compute alpha (avoid repeated exp() calls)
        self.alpha = self.log_ent_coeff.exp().item() if self.ent_coeff == 'auto' else self.ent_coeff
        self.alpha_loss = torch.zeros(1, device=self.device) if self.ent_coeff == 'auto' else None

        # Compile networks if requested
        if algo_config.get('compile', False):
            torch.set_float32_matmul_precision('high')
            # NOTE(review): `deterministic = True` and `benchmark = True` are
            # both set here; confirm that enabling both is intended.
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = True
            self.policy = self.policy._compile(
                mode='default',
                fullgraph=True
            )
    
    def build_module(self):
        """Instantiate one network per role listed in `algo_config.model_bindings`.

        Every binding maps a role name to the name of a module config. Roles
        whose name contains "critic" or "actor" are built with
        `_build_network`; a `shared_backbone` entry is dropped and any other
        role is ignored.

        For continuous action spaces `self.in_dim` / `self.out_dim` are
        overwritten before each build: critics take observation + action and
        emit a single value, actors take the observation and emit two heads of
        size `action_dim`. For discrete action spaces the dimensions already
        stored on the instance are used unchanged.

        Returns:
            Dictionary mapping role names to the built network instances.

        Raises:
            ValueError: If a model spec is not a string.
        """
        bindings = OmegaConf.to_container(self.algo_config.model_bindings, resolve=True)

        # SAC uses separate networks per role, so a shared backbone is discarded.
        bindings.pop('shared_backbone', None)

        continuous = not self.is_discrete
        networks = {}
        for role, model_spec in bindings.items():
            if not isinstance(model_spec, str):
                raise ValueError(f"Model spec for role '{role}' must be a string refering to module config name, got {type(model_spec)}")

            # "critic" takes precedence over "actor", as in an if/elif chain.
            is_critic = "critic" in role
            if not is_critic and "actor" not in role:
                continue

            if continuous:
                if is_critic:
                    # Q(s, a): state and action in, scalar out.
                    self.in_dim = self.flat_in_dim + self.action_dim
                    self.out_dim = 1
                else:
                    # Actor: state in, two action-sized heads out.
                    self.in_dim = self.flat_in_dim
                    self.out_dim = [self.action_dim, self.action_dim]

            networks[role] = _build_network(model_spec, self.in_dim, self.out_dim)

        return networks

    def collect_rollouts(self, num_steps: int = 1) -> Dict[str, Any]:
        """Step the environment and store the transitions in the replay buffer.

        Until `global_step` reaches `learning_starts` actions are sampled
        uniformly at random; afterwards they are sampled from the policy.

        Args:
            num_steps: Number of vectorized environment steps to take.

        Returns:
            An empty dict when no episode finished during the rollout,
            otherwise a dict with:
              * 'rollout/episodic_return': list of returns of finished episodes
              * 'rollout/episodic_length': list of lengths of finished episodes
              * 'rollout/episodes': number of episodes finished over all steps
        """
        if self.last_obs is None:
            self.last_obs, _ = self.env.reset()

        episode_returns = []
        episode_lengths = []
        # Accumulated over the whole rollout (not reset per step), so it always
        # agrees with the lists above.
        rollout_episodes = 0

        # Pre-allocate the observation tensor once to avoid repeated allocation
        if not hasattr(self, '_obs_tensor'):
            self._obs_tensor = torch.zeros(
                (self.env.num_envs, *self.env.single_observation_space.shape),
                dtype=torch.float32,
                device=self.device,
            )

        for _ in range(num_steps):
            self.global_step += self.env.num_envs

            if self.global_step < self.learning_starts:
                # Warm-up: vectorized random action sampling (no policy call,
                # hence no need to move observations to the device).
                if self.is_discrete:
                    actions = np.random.randint(0, self.action_space.n, size=(self.env.num_envs,))
                else:
                    actions = np.random.uniform(
                        self.action_space.low, self.action_space.high,
                        size=(self.env.num_envs, self.action_space.shape[0])
                    )
            else:
                # Copy observations into the pre-allocated tensor (faster than
                # creating a new tensor every step).
                self._obs_tensor.copy_(torch.from_numpy(self.last_obs))
                with torch.no_grad():
                    actions, _, _ = self.policy.get_action(self._obs_tensor, deterministic=False)
                actions = actions.cpu().numpy()

            # Take environment step
            next_obs, rewards, terminations, truncations, infos = self.env.step(actions)

            # Record episode statistics for plotting purposes. The statistics
            # are batched per environment; the optional `_episode` mask marks
            # which entries belong to an environment that actually finished.
            final_info = infos.get("final_info") if "final_info" in infos else None
            if isinstance(final_info, dict) and "episode" in final_info:
                episode_stats = final_info["episode"]
                returns = np.atleast_1d(episode_stats["r"])
                lengths = np.atleast_1d(episode_stats["l"])
                finished = final_info.get("_episode")
                if finished is None:
                    # No mask available: treat every entry as a finished episode.
                    finished = np.ones(len(returns), dtype=bool)
                finished_idx = np.flatnonzero(finished)
                for env_idx in finished_idx:
                    episode_returns.append(returns[env_idx])
                    episode_lengths.append(lengths[env_idx])
                rollout_episodes += len(finished_idx)
                self.episode_count += len(finished_idx)

            # Final observation handling: for truncated episodes `next_obs` is
            # already the first observation of the next episode, so bootstrap
            # from the true last observation instead.
            real_next_obs = next_obs.copy()
            for idx, trunc in enumerate(truncations):
                if trunc:
                    real_next_obs[idx] = infos["final_obs"][idx]

            # Store transitions in replay buffer. Only terminations are stored
            # as `done`, so truncated transitions still bootstrap.
            self.replay_buffer.add(
                obs=self.last_obs,
                next_obs=real_next_obs,
                action=actions,
                reward=rewards,
                done=terminations,
                infos=infos
            )

            # CRUCIAL step easy to overlook: advance the current observation.
            self.last_obs = next_obs

        # Return rollout metrics
        metrics = {}
        if episode_returns:
            metrics.update({
                'rollout/episodic_return': episode_returns,
                'rollout/episodic_length': episode_lengths,
                'rollout/episodes': rollout_episodes
            })

        return metrics
    
    def train_step(self) -> Dict[str, float]:
        """Collect experience, then run one critic update and (delayed) actor/alpha updates.

        Sequence of one call:
          1. Collect `train_freq` environment steps into the replay buffer.
          2. If `global_step` is still below `learning_starts`, stop here.
          3. Sample one batch and take one optimizer step on both critics.
          4. Every `policy_frequency` global steps, take `policy_frequency`
             actor steps (and entropy-coefficient steps when it is learned).
          5. Every `target_update_interval` global steps, soft-update the
             target critics.

        NOTE(review): a single batch / critic step is performed per call; the
        `gradient_steps` value read in `__init__` is not used here.

        Returns:
            The rollout metrics only while warming up; otherwise the rollout
            metrics merged with the `train/*` metrics built at the end of this
            method.
        """
        # Collect experience
        rollout_metrics = self.collect_rollouts(self.train_freq)
        
        # Skip training if not enough samples
        if self.global_step < self.learning_starts:
            return rollout_metrics
        
        # Sample batch from replay buffer
        data = self.replay_buffer.sample(self.batch_size)

        # CRITIC training - compute target Q-values (no gradients flow into the
        # actor or the target critics from the TD target)
        with torch.no_grad():
            if self.is_discrete:
                _, next_state_log_pi, next_state_action_probs = self.policy.get_action(data.next_observations)
                qf1_next_target = self.target_critic1(data.next_observations)
                qf2_next_target = self.target_critic2(data.next_observations)
                min_qf_next_target = torch.min(qf1_next_target, qf2_next_target)
                # Probability-weighted soft value over the action dimension.
                # NOTE(review): `unsqueeze(1)` broadcasts the log-probability
                # across dim 1 - presumably get_action returns one value per
                # sample here; confirm against DiscreteSACPolicy.
                next_q_value = (next_state_action_probs * (min_qf_next_target - self.alpha * next_state_log_pi.unsqueeze(1))).sum(dim=1)
                next_q_value = data.rewards.flatten() + (1 - data.dones.flatten()) * self.gamma * next_q_value
            else:
                next_state_actions, next_state_log_pi, _ = self.policy.get_action(data.next_observations)
                qf1_next_target = self.target_critic1(data.next_observations, next_state_actions).flatten()
                qf2_next_target = self.target_critic2(data.next_observations, next_state_actions).flatten()
                # Clipped double-Q minus the entropy term.
                min_qf_next_target = torch.min(qf1_next_target, qf2_next_target) - self.alpha * next_state_log_pi.flatten()
                next_q_value = data.rewards.flatten() + (1 - data.dones.flatten()) * self.gamma * min_qf_next_target
        
        # Current Q-values (with gradients for critic training)
        if self.is_discrete:
            # Select the Q-value of the action that was actually taken.
            # NOTE(review): gather(1, ...) needs an index with the same number
            # of dims as the critic output - presumably actions are stored as
            # (batch, 1); confirm against ReplayBuffer.
            self.qf1_a_values = self.policy.critic_network(data.observations).gather(1, data.actions.long()).flatten()
            self.qf2_a_values = self.policy.critic2_network(data.observations).gather(1, data.actions.long()).flatten()
        else:
            self.qf1_a_values = self.policy.critic_network(data.observations, data.actions).flatten()
            self.qf2_a_values = self.policy.critic2_network(data.observations, data.actions).flatten()
        
        # Both critics regress onto the same TD target; their losses are summed
        # and optimised with a single optimizer.
        self.qf1_loss = F.mse_loss(self.qf1_a_values, next_q_value)
        self.qf2_loss = F.mse_loss(self.qf2_a_values, next_q_value)
        self.qf_loss = self.qf1_loss + self.qf2_loss
        
        # Optimize critics
        self.critic_optimizer.zero_grad()
        self.qf_loss.backward()
        self.critic_optimizer.step()
        
        # ACTOR training - single actor forward pass with gradients
        if self.global_step % self.policy_frequency == 0:  # TD3-style delayed update support
            for _ in range(
                self.policy_frequency
            ):  # compensate for the delay by doing `policy_frequency` updates instead of 1
                if self.is_discrete:
                    _, log_pi, action_probs = self.policy.get_action(data.observations)
                    # Critic values are treated as constants for the actor loss.
                    with torch.no_grad():
                        qf1_values = self.policy.critic_network(data.observations)
                        qf2_values = self.policy.critic2_network(data.observations)
                        min_qf_values = torch.min(qf1_values, qf2_values)

                    self.actor_loss = (action_probs * ((self.alpha * log_pi.unsqueeze(1)) - min_qf_values)).mean()

                    if self.ent_coeff == 'auto':
                        # Policy outputs are detached: only log_ent_coeff receives gradients.
                        self.alpha_loss = (action_probs.detach() * (-self.log_ent_coeff.exp() * (log_pi.unsqueeze(1).detach() + self.target_entropy))).mean()
                else:
                    pi, log_pi, _ = self.policy.get_action(data.observations)
                    # Critics are evaluated with gradients enabled so the actor
                    # loss can back-propagate through the sampled action `pi`.
                    # Gradients that land on the critic parameters here are
                    # cleared by `critic_optimizer.zero_grad()` before the next
                    # critic step.
                    qf1_pi = self.policy.critic_network(data.observations, pi).flatten()
                    qf2_pi = self.policy.critic2_network(data.observations, pi).flatten()
                    min_qf_pi = torch.min(qf1_pi, qf2_pi)
                    
                    self.actor_loss = (self.alpha * log_pi.flatten() - min_qf_pi).mean()
                    
                    if self.ent_coeff == 'auto':
                        self.alpha_loss = (-self.log_ent_coeff.exp() * (log_pi.detach().flatten() + self.target_entropy)).mean()
                
                # Optimize actor
                self.actor_optimizer.zero_grad()
                self.actor_loss.backward()
                self.actor_optimizer.step()

                # Optimize alpha (entropy coefficient)
                if self.ent_coeff == 'auto':
                    self.ent_coeff_optimizer.zero_grad()
                    self.alpha_loss.backward()
                    self.ent_coeff_optimizer.step()
                    # Cache alpha as a Python float for use in later losses.
                    self.alpha = self.log_ent_coeff.exp().item()
        
        # Update target networks
        if self.global_step % self.target_update_interval == 0:
            self._update_target_networks()
        
        # Training metrics. `actor_loss` / `alpha_loss` hold the values of the
        # most recent actor update (or their initial zeros if none ran yet).
        training_metrics = {
            'train/qf1_values': self.qf1_a_values.mean().item(),
            'train/qf2_values': self.qf2_a_values.mean().item(),
            'train/qf1_loss': self.qf1_loss.item(),
            'train/qf2_loss': self.qf2_loss.item(),
            'train/qf_loss': self.qf_loss.item() / 2.0,
            'train/actor_loss': self.actor_loss.item(),
            'train/alpha': self.alpha,
        }
        
        if self.ent_coeff == 'auto':
            training_metrics['train/alpha_loss'] = self.alpha_loss.item()
        
        # Combine all metrics
        all_metrics = {**rollout_metrics, **training_metrics}
        
        return all_metrics
    
    def _update_target_networks(self):
        """Polyak-average both target critics towards their online critics.

        Each target parameter becomes `tau * online + (1 - tau) * target`.
        Only tensors returned by `parameters()` are updated.
        """
        tau = self.tau
        critic_pairs = (
            (self.policy.critic_network, self.target_critic1),
            (self.policy.critic2_network, self.target_critic2),
        )
        for online_net, target_net in critic_pairs:
            for online_p, target_p in zip(online_net.parameters(), target_net.parameters()):
                blended = tau * online_p.data + (1 - tau) * target_p.data
                target_p.data.copy_(blended)
    
    def save(self, path: str) -> None:
        """Write a training checkpoint to `path` with `torch.save`.

        The checkpoint holds the policy and target-critic state dicts, the
        actor and critic optimizer states, the step / episode counters and the
        algorithm config. When the entropy coefficient is learned, its
        optimizer state and the `log_ent_coeff` tensor are stored as well.

        Args:
            path: Destination file path.
        """
        state = dict(
            policy_state_dict=self.policy.state_dict(),
            target_critic1_state_dict=self.target_critic1.state_dict(),
            target_critic2_state_dict=self.target_critic2.state_dict(),
            actor_optimizer_state_dict=self.actor_optimizer.state_dict(),
            critic_optimizer_state_dict=self.critic_optimizer.state_dict(),
            global_step=self.global_step,
            episode_count=self.episode_count,
            algo_config=self.algo_config,
        )

        # Entropy-tuning state exists only when the coefficient is learned.
        entropy_optimizer = self.ent_coeff_optimizer
        if entropy_optimizer is not None:
            state.update(
                ent_coeff_optimizer_state_dict=entropy_optimizer.state_dict(),
                log_ent_coeff=self.log_ent_coeff,
            )

        torch.save(state, path)
    
    def load(self, path: str) -> None:
        """Restore algorithm state from a checkpoint written by `save`.

        Restores the policy, both target critics, the actor / critic
        optimizers and the step / episode counters. When the entropy
        coefficient is learned and the checkpoint contains its state, the
        optimizer state, `log_ent_coeff` and the cached `alpha` are restored
        too.

        Args:
            path: Path of the checkpoint file.

        Raises:
            KeyError: If a required entry is missing from the checkpoint.
        """
        # The checkpoint also stores `algo_config`, an arbitrary Python object,
        # so it cannot be read by the weights-only unpickler that recent
        # PyTorch releases use by default.
        # SECURITY: full unpickling can execute arbitrary code - only load
        # checkpoints that come from a trusted source.
        checkpoint = torch.load(path, map_location=self.device, weights_only=False)

        self.policy.load_state_dict(checkpoint['policy_state_dict'])
        self.target_critic1.load_state_dict(checkpoint['target_critic1_state_dict'])
        self.target_critic2.load_state_dict(checkpoint['target_critic2_state_dict'])

        self.actor_optimizer.load_state_dict(checkpoint['actor_optimizer_state_dict'])
        self.critic_optimizer.load_state_dict(checkpoint['critic_optimizer_state_dict'])

        if 'ent_coeff_optimizer_state_dict' in checkpoint and self.ent_coeff_optimizer is not None:
            self.ent_coeff_optimizer.load_state_dict(checkpoint['ent_coeff_optimizer_state_dict'])
            # Copy the value in place instead of rebinding the attribute: the
            # optimizer holds a reference to the existing tensor, and a rebound
            # tensor would no longer be updated by `ent_coeff_optimizer.step()`.
            with torch.no_grad():
                self.log_ent_coeff.copy_(checkpoint['log_ent_coeff'])
            # Keep the cached float in sync with the restored coefficient.
            self.alpha = self.log_ent_coeff.exp().item()

        self.global_step = checkpoint['global_step']
        self.episode_count = checkpoint['episode_count']