import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random
from utils import prepare_graph_data, encode_graph_condition

# Reproducibility: seed every random number generator this module draws from
# (PyTorch, NumPy and Python's ``random``) with one shared value.
_SEED = 42
torch.manual_seed(_SEED)
np.random.seed(_SEED)
random.seed(_SEED)

# Global parameters: run on the GPU when one is visible, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class LatentPolicyNetwork(nn.Module):
    """
    Actor-critic network acting on states built in a latent space.

    A shared MLP trunk (``feature_extractor``) feeds two heads:
      * ``policy_mean`` -- mean of a diagonal Gaussian over actions;
      * ``value``       -- scalar state-value estimate.
    The Gaussian's log standard deviation, ``policy_log_std``, is a free,
    state-independent parameter vector initialised to zeros (std = 1).
    """

    def __init__(self, state_dim, action_dim, hidden_dim=256):
        super().__init__()

        self.state_dim = state_dim
        self.action_dim = action_dim
        head_dim = hidden_dim // 2

        # Shared trunk: state -> hidden features (two LeakyReLU layers).
        self.feature_extractor = nn.Sequential(
            nn.Linear(state_dim, hidden_dim), nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, hidden_dim), nn.LeakyReLU(0.2),
        )

        # Actor head: hidden features -> action mean.
        self.policy_mean = nn.Sequential(
            nn.Linear(hidden_dim, head_dim), nn.LeakyReLU(0.2),
            nn.Linear(head_dim, action_dim),
        )

        # One learned log-std per action dimension; it does not depend on the state.
        self.policy_log_std = nn.Parameter(torch.zeros(action_dim))

        # Critic head: hidden features -> scalar value.
        self.value = nn.Sequential(
            nn.Linear(hidden_dim, head_dim), nn.LeakyReLU(0.2),
            nn.Linear(head_dim, 1),
        )

    def forward(self, state):
        """Return ``(action_mean, action_std, value)`` for a batch of states.

        ``action_std`` is ``exp(policy_log_std)`` broadcast to the shape of
        ``action_mean``; ``value`` has a trailing dimension of size 1.
        """
        hidden = self.feature_extractor(state)
        mean = self.policy_mean(hidden)
        std = self.policy_log_std.exp().expand_as(mean)
        return mean, std, self.value(hidden)

    def get_action(self, state, deterministic=False):
        """
        Draw an action from the Gaussian policy.

        Args:
            state: Batch of states.
            deterministic: When True, return the mean action and ``None`` for
                both the log-probability and the entropy.

        Returns:
            (action, log_prob, entropy): ``log_prob`` and ``entropy`` are
            summed over the last (action) dimension.
        """
        mean, std, _ = self.forward(state)
        if deterministic:
            return mean, None, None

        dist = torch.distributions.Normal(mean, std)
        sampled = dist.sample()
        return sampled, dist.log_prob(sampled).sum(dim=-1), dist.entropy().sum(dim=-1)

    def evaluate_action(self, state, action):
        """
        Score given actions under the current policy.

        Args:
            state: Batch of states.
            action: Actions to score, broadcast-compatible with the action mean.

        Returns:
            (log_prob, entropy, value): ``log_prob`` and ``entropy`` are summed
            over the last (action) dimension; ``value`` is the critic output.
        """
        mean, std, value = self.forward(state)
        dist = torch.distributions.Normal(mean, std)
        return dist.log_prob(action).sum(dim=-1), dist.entropy().sum(dim=-1), value

import math  # Add this import at the top of the file

def compute_reward(G, K, T, perturbation, spagan_model, data, cost_penalty=0.05):
    """
    Compute the reward for a perturbation.

    Each (source, target) pair in ``K`` contributes a smooth feasibility score
    ``sigmoid(zeta * (pred_cost - T))``. Here ``pred_cost`` is the SPAGAN
    prediction under the *rounded* perturbation, so the score approaches 1 as
    the predicted cost rises above ``T``. A cost penalty
    ``cost_penalty * log(1 + ||softplus(perturbation)||_1)`` is subtracted
    from the summed score.

    Args:
        G: NetworkX graph (not read here; kept for interface compatibility)
        K: List of (source, target) pairs, as indices accepted by spagan_model
        T: Threshold (Python number or single-element tensor)
        perturbation: Perturbation vector (tensor)
        spagan_model: SPAGAN model, called as
            ``spagan_model(data, s_idx, t_idx, perturbation=...)``
        data: PyTorch Geometric Data object
        cost_penalty: Weight for cost penalty term

    Returns:
        reward: Scalar reward as a Python float
    """
    zeta = 5.0  # Sharpness parameter for the sigmoid
    threshold = float(T)

    # SPAGAN is queried with the rounded perturbation.
    perturbation_rounded = torch.round(perturbation)

    feasibility_score = 0.0
    # The reward leaves this function as a plain float, so no autograd graph
    # is needed for the model queries or the cost term.
    with torch.no_grad():
        for s, t in K:
            s_idx = torch.tensor([s], device=device)
            t_idx = torch.tensor([t], device=device)

            pred_cost = spagan_model(data, s_idx, t_idx, perturbation=perturbation_rounded).item()

            # Numerically stable logistic: only ever exponentiate a non-positive
            # argument, so a large |pred_cost - T| cannot overflow math.exp.
            x = zeta * (pred_cost - threshold)
            if x >= 0:
                feasibility_score += 1.0 / (1.0 + math.exp(-x))
            else:
                e = math.exp(x)
                feasibility_score += e / (1.0 + e)

        # Cost penalty on the softplus of the (unrounded) perturbation.
        total_cost = torch.log1p(F.softplus(perturbation).norm(p=1))

    # Reward is feasibility minus cost penalty
    return feasibility_score - cost_penalty * total_cost.item()



def _map_critical_pairs(K, node_mapping):
    """Translate (source, target) node ids through ``node_mapping``.

    ``K`` is returned unchanged when ``node_mapping`` is empty or None.
    """
    if not node_mapping:
        return K
    return [(node_mapping[s], node_mapping[t]) for s, t in K]


def _compute_gae(rewards, values, gamma, gae_lambda):
    """Return ``(advantages, returns)`` via Generalized Advantage Estimation.

    ``rewards`` and ``values`` are equal-length 1-D tensors holding the per-step
    rewards and the (detached) value estimates of the visited states. The step
    after the last one is treated as terminal (bootstrap value 0). The returned
    ``returns`` are the GAE lambda-returns ``advantages + values``, which is the
    critic's regression target.
    """
    advantages = torch.zeros_like(rewards)
    next_value = 0.0
    running_advantage = 0.0
    for t in reversed(range(rewards.shape[0])):
        # TD error of step t, then the exponentially weighted sum of TD errors.
        delta = rewards[t] + gamma * next_value - values[t]
        running_advantage = delta + gamma * gae_lambda * running_advantage
        advantages[t] = running_advantage
        next_value = values[t]
    return advantages, advantages + values


def refine_phase(ebm, mix_cvae, solution_dataset, spagan_model, num_episodes=5000,
                 gamma=0.99, lr=5e-5, reward_cost_weight=0.05,
                 gae_lambda=0.95, max_trajectory_length=50):
    """
    Implements the Refine phase using PPO in latent space.

    Each episode does the following:
      * Picks a random dataset instance and encodes its perturbation with the
        first CVAE expert.
      * Lets the policy take up to ``max_trajectory_length`` additive steps in
        latent space. Each step is rewarded by ``compute_reward`` on the
        decoded perturbation.
      * Runs several clipped-PPO updates on the collected trajectory.

    Every 100 episodes, and on the last one, a latent drawn from the prior is
    refined with the deterministic policy and decoded. It is kept if SPAGAN
    predicts every critical pair's cost to be at least ``T``.

    Args:
        ebm: Trained Energy-Based Model (not used by this function)
        mix_cvae: Trained Mixture of CVAEs (only ``experts[0]`` is used)
        solution_dataset: Dataset from Forge phase; items are
            (G, K, T, perturbation_dict, perturbation_vector)
        spagan_model: Trained SPAGAN model
        num_episodes: Number of RL episodes
        gamma: Discount factor
        lr: Learning rate
        reward_cost_weight: Weight for cost penalty in reward
        gae_lambda: GAE smoothing factor (lambda)
        max_trajectory_length: Maximum number of latent steps per episode (>= 1)

    Returns:
        policy: Trained RL policy
        refined_solutions: List of (G, K, T, perturbation_dict, total_budget)

    Raises:
        ValueError: If ``max_trajectory_length`` is smaller than 1.
    """
    if max_trajectory_length < 1:
        raise ValueError("max_trajectory_length must be >= 1")

    # Encode a sample condition to get its dimension
    sample_G, sample_K, sample_T, _, _ = solution_dataset[0]
    condition_dim = encode_graph_condition(sample_G, sample_K, sample_T).size(0)

    # Latent operations use the first expert of the mixture (for simplicity)
    cvae = mix_cvae.experts[0]
    latent_dim = cvae.latent_dim

    # State: latent vector + condition. Action: additive step on the latent vector.
    state_dim = latent_dim + condition_dim
    action_dim = latent_dim

    policy = LatentPolicyNetwork(state_dim, action_dim, hidden_dim=256).to(device)
    optimizer = torch.optim.Adam(policy.parameters(), lr=lr)

    refined_solutions = []

    # PPO parameters
    ppo_epochs = 5
    clip_epsilon = 0.2
    value_coef = 0.5
    entropy_coef = 0.01
    max_grad_norm = 0.5

    # Training metrics
    avg_rewards = []
    episode_lengths = []

    for episode in range(num_episodes):
        # Sample an instance from the dataset
        idx = random.randint(0, len(solution_dataset) - 1)
        G, K, T, _, perturbation_vector = solution_dataset[idx]

        data, node_mapping = prepare_graph_data(G)
        data = data.to(device)
        K_mapped = _map_critical_pairs(K, node_mapping)

        condition = encode_graph_condition(G, K, T).unsqueeze(0)

        # ---- Roll out one trajectory ----
        # Collected log-probs and values are fixed targets for the PPO update.
        # Keeping their autograd graphs would backprop the loss through stale
        # graphs, and would fail on the second PPO epoch.
        with torch.no_grad():
            # Start from the mean of the posterior
            mu, _ = cvae.encode(perturbation_vector.unsqueeze(0), condition)
            z_current = mu.clone()

            states, actions, log_probs, values, rewards = [], [], [], [], []
            for trajectory_length in range(1, max_trajectory_length + 1):
                state = torch.cat([z_current, condition], dim=1)
                action, log_prob, _ = policy.get_action(state)

                # Step in latent space and score the decoded perturbation
                z_next = z_current + action
                perturbation_next = cvae.decode(z_next, condition)
                reward = compute_reward(
                    G, K_mapped, T, perturbation_next.squeeze(),
                    spagan_model, data, cost_penalty=reward_cost_weight
                )
                _, _, value = policy(state)

                states.append(state)
                actions.append(action)
                log_probs.append(log_prob)
                values.append(value.squeeze())
                rewards.append(reward)

                z_current = z_next

                # Terminate early once nearly every pair scores as feasible
                if reward > 0.95 * len(K_mapped):
                    break

        # ---- Advantages and returns ----
        # NOTE(review): truncation at max_trajectory_length is treated as
        # terminal (no bootstrap), as in the original implementation.
        rewards_tensor = torch.tensor(rewards, dtype=torch.float, device=device)
        values_tensor = torch.stack(values)
        advantages, returns = _compute_gae(rewards_tensor, values_tensor, gamma, gae_lambda)

        states_batch = torch.cat(states, dim=0)
        actions_batch = torch.cat(actions, dim=0)
        # Each log_prob has shape (1,); cat gives (L,), matching evaluate_action.
        # torch.stack would give (L, 1) and silently broadcast the ratio to (L, L).
        old_log_probs = torch.cat(log_probs)

        # The std of a single element is NaN, so only normalise when there are >= 2 steps
        if advantages.numel() > 1:
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        # ---- PPO update ----
        for _ in range(ppo_epochs):
            new_log_probs, new_entropy, new_values = policy.evaluate_action(states_batch, actions_batch)

            # Probability ratio pi_new / pi_old
            ratio = torch.exp(new_log_probs - old_log_probs)

            surrogate1 = ratio * advantages
            surrogate2 = torch.clamp(ratio, 1 - clip_epsilon, 1 + clip_epsilon) * advantages
            actor_loss = -torch.min(surrogate1, surrogate2).mean()

            critic_loss = F.mse_loss(new_values.squeeze(-1), returns)
            entropy_loss = -new_entropy.mean()

            loss = actor_loss + value_coef * critic_loss + entropy_coef * entropy_loss

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(policy.parameters(), max_grad_norm)
            optimizer.step()

        # ---- Periodically generate and refine a solution ----
        if episode % 100 == 0 or episode == num_episodes - 1:
            eval_G, eval_K, eval_T, _, _ = solution_dataset[random.randint(0, len(solution_dataset) - 1)]

            eval_data, eval_mapping = prepare_graph_data(eval_G)
            eval_data = eval_data.to(device)
            eval_K_mapped = _map_critical_pairs(eval_K, eval_mapping)
            eval_condition = encode_graph_condition(eval_G, eval_K, eval_T).unsqueeze(0)

            with torch.no_grad():
                # Sample from the prior and refine with the deterministic policy
                z = torch.randn(1, latent_dim, device=device)
                for _ in range(10):
                    state = torch.cat([z, eval_condition], dim=1)
                    action, _, _ = policy.get_action(state, deterministic=True)
                    z = z + action

                perturbation = cvae.decode(z, eval_condition).squeeze()

                # Keep positive entries. Simplification: perturbation index i is
                # assumed to correspond to the i-th edge of G.edges().
                edges = list(eval_G.edges())
                perturbation_dict = {
                    edges[i]: val for i, val in enumerate(perturbation.tolist()) if val > 0
                }

                # Feasible iff SPAGAN predicts cost >= T for every critical pair
                perturbation_rounded = torch.round(perturbation)
                feasible = True
                for s, t in eval_K_mapped:
                    s_idx = torch.tensor([s], device=device)
                    t_idx = torch.tensor([t], device=device)
                    pred_cost = spagan_model(eval_data, s_idx, t_idx, perturbation=perturbation_rounded).item()
                    if pred_cost < eval_T:
                        feasible = False
                        break

            if feasible:
                total_budget = sum(perturbation_dict.values())
                refined_solutions.append((eval_G, eval_K, eval_T, perturbation_dict, total_budget))
                print(f"Episode {episode}: Found feasible solution with budget {total_budget:.2f}")
            else:
                print(f"Episode {episode}: Solution not feasible, applying PPS-I")
                # TODO: apply PPS-I (a PPS variant using exact shortest paths) to
                # restore feasibility; not implemented here.

        # Log progress
        avg_rewards.append(sum(rewards) / len(rewards))
        episode_lengths.append(trajectory_length)

        if episode % 10 == 0:
            print(f"Episode {episode}/{num_episodes}")
            print(f"  Average reward: {avg_rewards[-1]:.4f}")
            print(f"  Episode length: {episode_lengths[-1]}")

    return policy, refined_solutions