"""
GFlowNet implementation with proper trajectory balance loss.
Based on EGFN repository: https://github.com/zarifikram/egfn
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Dict, List, Optional, Tuple, Any
import math
from collections import defaultdict


class GFlowNetPolicy(nn.Module):
    """
    Feed-forward network mapping a state vector to per-action logits.

    The network is ``num_layers`` repetitions of ``Linear -> LeakyReLU``
    followed by a final ``Linear`` layer with ``action_dim`` outputs. All
    linear layers use Xavier-uniform weights and zero biases.

    Args:
        state_dim: Size of the input state vector.
        action_dim: Number of output logits.
        hidden_dim: Width of every hidden layer.
        num_layers: Number of hidden ``Linear -> LeakyReLU`` pairs. ``0`` gives
            a single linear map from ``state_dim`` to ``action_dim``.
        dropout: Accepted for interface compatibility only. No dropout layer is
            built, so this value currently has no effect.
    """

    def __init__(self, state_dim: int, action_dim: int, hidden_dim: int = 256,
                 num_layers: int = 3, dropout: float = 0.1):
        super().__init__()

        self.state_dim = state_dim
        self.action_dim = action_dim
        self.hidden_dim = hidden_dim

        # Build the hidden stack, tracking the width feeding the next layer.
        layers: List[nn.Module] = []
        input_dim = state_dim
        for _ in range(num_layers):
            layers.append(nn.Linear(input_dim, hidden_dim))
            layers.append(nn.LeakyReLU())
            input_dim = hidden_dim

        # Output layer. Using ``input_dim`` (not ``hidden_dim``) keeps the
        # network well-formed when ``num_layers == 0``.
        layers.append(nn.Linear(input_dim, action_dim))

        self.network = nn.Sequential(*layers)

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module: nn.Module) -> None:
        """Xavier-initialize ``nn.Linear`` weights and zero their biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        """
        Return action logits for ``state``.

        A 1-D input is treated as a single state and gets a leading batch
        dimension, so the result always has at least two dimensions.
        """
        if state.dim() == 1:
            state = state.unsqueeze(0)
        return self.network(state)

    def get_action_probs(self, state: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Return softmax action probabilities, optionally restricted by ``mask``.

        Args:
            state: State tensor accepted by :meth:`forward`.
            mask: Optional tensor broadcastable to the logits; truthy entries
                mark valid actions. Non-boolean masks (e.g. 0/1 ints or
                floats) are converted to bool.

        Returns:
            Probabilities with the same shape as the logits. Masked-out actions
            have their logits replaced by ``-1e9`` before the softmax; if every
            action is masked the result is therefore uniform.
        """
        logits = self.forward(state)

        if mask is not None:
            # masked_fill only accepts boolean masks, and the mask must live on
            # the same device as the logits.
            valid = mask.to(device=logits.device, dtype=torch.bool)
            logits = logits.masked_fill(~valid, -1e9)

        return F.softmax(logits, dim=-1)

    def sample_action(self, state: torch.Tensor, mask: Optional[torch.Tensor] = None) -> int:
        """
        Sample one action index for a single state.

        The ``.item()`` call requires exactly one sampled element, so ``state``
        must describe a single state (1-D, or 2-D with a batch of one).
        """
        probs = self.get_action_probs(state, mask)
        if probs.dim() > 1:
            probs = probs.squeeze(0)

        # Sample from categorical distribution
        return torch.multinomial(probs, 1).item()


class GFlowNetBase(nn.Module):
    """
    Base GFlowNet trained with a trajectory balance (TB) objective.

    Holds three ``GFlowNetPolicy`` networks: a forward policy, a backward
    policy, and a single-output network whose value on the initial state is
    used as ``log Z`` in the TB loss.

    The ``env`` objects passed to the methods are duck-typed. The methods call
    ``reset()``, ``get_valid_actions(state)``, ``step(state, action)``,
    ``is_terminal(state)`` and ``get_reward(state)`` on them.
    """

    def __init__(self, state_dim: int, action_dim: int, hidden_dim: int = 256,
                 num_layers: int = 3, lr: float = 5e-4, device: Optional[torch.device] = None):
        """
        Args:
            state_dim: Size of a state vector.
            action_dim: Number of discrete actions.
            hidden_dim: Hidden width of each network.
            num_layers: Number of hidden layers of each network.
            lr: Stored as ``self.lr`` for callers building optimizers; this
                class does not create optimizers itself.
            device: Device for the networks. Defaults to CUDA when available,
                otherwise CPU.
        """
        super().__init__()

        self.state_dim = state_dim
        self.action_dim = action_dim
        self.lr = lr
        self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Forward and backward policies
        self.forward_policy = GFlowNetPolicy(
            state_dim, action_dim, hidden_dim, num_layers
        ).to(self.device)

        self.backward_policy = GFlowNetPolicy(
            state_dim, action_dim, hidden_dim, num_layers
        ).to(self.device)

        # State flow function (log Z)
        self.log_flow = GFlowNetPolicy(
            state_dim, 1, hidden_dim, num_layers
        ).to(self.device)

        # Training metrics
        self.metrics = defaultdict(list)
        self.step_count = 0

    def _valid_action_mask(self, valid_actions) -> torch.Tensor:
        """Build a boolean mask of length ``action_dim`` that is True at ``valid_actions``."""
        mask = torch.zeros(self.action_dim, dtype=torch.bool, device=self.device)
        mask[valid_actions] = True
        return mask

    def get_action_probs(self, state: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Get action probabilities from forward policy."""
        return self.forward_policy.get_action_probs(state, mask)

    def get_forward_prob(self, state: torch.Tensor, action: int, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Get forward probability P_F(a|s) for a single state."""
        probs = self.forward_policy.get_action_probs(state, mask)
        if probs.dim() > 1:
            return probs[0, action]
        return probs[action]

    def get_backward_prob(self, state: torch.Tensor, parent_state: torch.Tensor) -> torch.Tensor:
        """
        Get the (simplified) backward probability P_B(parent|state).

        ``parent_state`` is currently unused. The value returned is the mean of
        the backward policy's probabilities, which for a softmax row is the
        uniform value ``1 / action_dim``. A real backward policy depends on the
        environment structure and is not implemented here.
        """
        probs = self.backward_policy.get_action_probs(state)
        return probs.mean()  # Simplified

    def get_log_flow(self, state: torch.Tensor) -> torch.Tensor:
        """Get the log-flow network output for ``state`` with the last dim squeezed."""
        return self.log_flow(state).squeeze(-1)

    def sample_trajectory(self, env, max_length: int = 100) -> List[np.ndarray]:
        """
        Sample one trajectory with the forward policy.

        Starts from ``env.reset()`` and repeatedly samples a valid action until
        the environment reports a terminal state, no valid action remains, or
        ``max_length`` steps were taken.

        Returns:
            The list of visited states, starting with the initial state.
        """
        trajectory = [env.reset()]

        # Sampling runs in eval mode; restore the previous mode afterwards so
        # the call has no lasting side effect on the policy.
        was_training = self.forward_policy.training
        self.forward_policy.eval()
        try:
            with torch.no_grad():
                for _ in range(max_length):
                    current = trajectory[-1]

                    valid_actions = env.get_valid_actions(current)
                    if len(valid_actions) == 0:
                        # An all-False mask would make the softmax uniform over
                        # *invalid* actions, so stop instead of sampling one.
                        break

                    state = torch.as_tensor(current, dtype=torch.float32, device=self.device)
                    mask = self._valid_action_mask(valid_actions)

                    # Sample action
                    action = self.forward_policy.sample_action(state, mask)

                    # Take step in environment
                    next_state = env.step(current, action)
                    trajectory.append(next_state)

                    # Check if terminal
                    if env.is_terminal(next_state):
                        break
        finally:
            self.forward_policy.train(was_training)

        return trajectory

    def calculate_trajectory_balance_loss(self, trajectories: List[List[np.ndarray]], env) -> torch.Tensor:
        """
        Calculate trajectory balance (TB) loss.

        TB loss: |log Z(s_0) + log P_F(τ) - log P_B(τ) - log R(s_T)|^2

        Trajectories with fewer than two states are skipped. Transitions whose
        action cannot be recovered from the environment contribute nothing to
        ``log P_F``. The backward term uses the simplified uniform probability
        (mean of the backward policy's softmax output).

        Returns:
            A 0-dim tensor: the mean squared TB error over the trajectories
            that were not skipped, or a zero tensor (without gradient) when
            none contributed.
        """
        losses = []

        for trajectory in trajectories:
            if len(trajectory) < 2:
                continue

            # Convert to tensors
            states = torch.as_tensor(np.array(trajectory), dtype=torch.float32, device=self.device)
            num_steps = len(trajectory) - 1

            # Calculate forward log probabilities
            log_pf = torch.zeros((), device=self.device)
            for t in range(num_steps):
                # Get action from state transition
                action = self._get_action_from_transition(trajectory[t], trajectory[t + 1], env)
                if action is None:
                    continue

                mask = self._valid_action_mask(env.get_valid_actions(trajectory[t]))
                probs = self.forward_policy.get_action_probs(states[t].unsqueeze(0), mask)
                log_pf = log_pf + torch.log(probs[0, action] + 1e-8)

            # Calculate backward log probabilities (simplified, see docstring)
            log_pb = torch.zeros((), device=self.device)
            for t in range(num_steps, 0, -1):
                back_probs = self.backward_policy.get_action_probs(states[t].unsqueeze(0))
                log_pb = log_pb + torch.log(back_probs.mean() + 1e-8)

            # Initial state flow, reduced to a scalar
            log_z_init = self.get_log_flow(states[0].unsqueeze(0)).reshape(-1)[0]

            # Terminal reward; float32 keeps the loss dtype independent of
            # whether the environment returns Python or NumPy floats.
            reward = env.get_reward(trajectory[-1])
            log_reward = torch.log(
                torch.as_tensor(reward, dtype=torch.float32, device=self.device) + 1e-8
            )

            # Trajectory balance loss
            tb_error = log_z_init + log_pf - log_pb - log_reward
            losses.append(tb_error ** 2)

        if not losses:
            return torch.zeros((), device=self.device)
        return torch.stack(losses).mean()

    def _get_action_from_transition(self, state: np.ndarray, next_state: np.ndarray, env) -> Optional[int]:
        """
        Recover the action that leads from ``state`` to ``next_state``.

        Tries every valid action through ``env.step`` and returns the first
        whose result is ``np.allclose`` to ``next_state``, or ``None`` when no
        valid action reproduces the transition.
        """
        for action in env.get_valid_actions(state):
            if np.allclose(env.step(state, action), next_state):
                return action

        return None

    def train_step(self, env, trajectories: List[List[np.ndarray]],
                  optimizer_forward, optimizer_backward, optimizer_flow) -> Dict[str, float]:
        """
        Perform one training step on a batch of trajectories.

        The TB loss is back-propagated once, and then each network's gradients
        are clipped to norm 1.0 and its optimizer stepped. A single backward
        pass is required: stepping one optimizer modifies parameters in place,
        which invalidates the graph for any further ``backward()`` call.

        Returns:
            Dict with the scalar ``'tb_loss'`` and the updated ``'step'`` count.
        """
        self.train()

        # Calculate trajectory balance loss
        tb_loss = self.calculate_trajectory_balance_loss(trajectories, env)

        optimizers = (optimizer_forward, optimizer_backward, optimizer_flow)
        networks = (self.forward_policy, self.backward_policy, self.log_flow)

        # With no usable trajectory the loss is a constant; skip the update
        # rather than calling backward() on a tensor without a graph.
        if tb_loss.requires_grad:
            for optimizer in optimizers:
                optimizer.zero_grad()

            tb_loss.backward()

            for network in networks:
                torch.nn.utils.clip_grad_norm_(network.parameters(), 1.0)

            for optimizer in optimizers:
                optimizer.step()

        self.step_count += 1

        # Track metrics
        loss_value = tb_loss.item()
        self.metrics['tb_loss'].append(loss_value)

        return {
            'tb_loss': loss_value,
            'step': self.step_count
        }

    def sample(self, env, num_samples: int = 100) -> List[List[np.ndarray]]:
        """Sample ``num_samples`` independent trajectories."""
        return [self.sample_trajectory(env) for _ in range(num_samples)]


class DetailedBalance:
    """
    Detailed balance implementation for GFlowNets.
    Ensures flow consistency: sum of incoming flows = sum of outgoing flows.
    """
    
    @staticmethod
    def calculate_flow_consistency_loss(gfn: GFlowNetBase, states: torch.Tensor, 
                                      env, temperature: float = 1.0) -> torch.Tensor:
        """
        Calculate detailed balance loss for flow consistency.
        
        For each state s: F(s) = sum_s' F(s') * P_F(s|s') = sum_a P_F(a|s) * F(s->a)
        """
        total_loss = 0.0
        num_states = len(states)
        
        for state in states:
            state_tensor = state.unsqueeze(0)
            
            # Incoming flow (sum over parent states)
            incoming_flow = gfn.get_log_flow(state_tensor)
            
            # Outgoing flow (sum over actions)
            valid_actions = env.get_valid_actions(state.cpu().numpy())
            outgoing_flow = torch.tensor(0.0, device=state.device)
            
            for action in valid_actions:
                next_state = env.step(state.cpu().numpy(), action)
                next_state_tensor = torch.tensor(next_state, dtype=torch.float32).to(state.device).unsqueeze(0)
                
                # P_F(a|s) * F(s->a)
                action_prob = gfn.get_forward_prob(state_tensor, action)
                next_flow = gfn.get_log_flow(next_state_tensor)
                
                outgoing_flow += action_prob * torch.exp(next_flow)
            
            # Flow consistency loss
            outgoing_log_flow = torch.log(outgoing_flow + 1e-8)
            flow_error = incoming_flow - outgoing_log_flow
            total_loss += flow_error ** 2
        
        return total_loss / num_states if num_states > 0 else torch.tensor(0.0)


class SubTrajectoryBalance:
    """
    Sub-trajectory balance for more stable training.
    Uses partial trajectories instead of full trajectories.
    """
    
    @staticmethod
    def calculate_subtraj_loss(gfn: GFlowNetBase, trajectory: List[np.ndarray], 
                             env, subtraj_length: int = 3) -> torch.Tensor:
        """Calculate sub-trajectory balance loss."""
        if len(trajectory) < subtraj_length + 1:
            return torch.tensor(0.0, device=gfn.device)
        
        total_loss = 0.0
        num_subtrajs = len(trajectory) - subtraj_length
        
        for start_idx in range(num_subtrajs):
            subtraj = trajectory[start_idx:start_idx + subtraj_length + 1]
            
            # Calculate forward probabilities for sub-trajectory
            log_pf = 0.0
            for t in range(len(subtraj) - 1):
                state = torch.tensor(subtraj[t], dtype=torch.float32).to(gfn.device).unsqueeze(0)
                action = gfn._get_action_from_transition(subtraj[t], subtraj[t+1], env)
                if action is not None:
                    prob = gfn.get_forward_prob(state, action)
                    log_pf += torch.log(prob + 1e-8)
            
            # Flow at start and end of sub-trajectory
            start_state = torch.tensor(subtraj[0], dtype=torch.float32).to(gfn.device).unsqueeze(0)
            end_state = torch.tensor(subtraj[-1], dtype=torch.float32).to(gfn.device).unsqueeze(0)
            
            log_flow_start = gfn.get_log_flow(start_state)
            log_flow_end = gfn.get_log_flow(end_state)
            
            # Sub-trajectory balance condition
            subtraj_error = log_flow_start + log_pf - log_flow_end
            total_loss += subtraj_error ** 2
        
        return total_loss / num_subtrajs
