"""
DATE-GFN implementation with distillation-aware fitness function.
Implementation based on EGFN: https://github.com/zarifikram/egfn
"""

import copy
import random
import warnings
from collections import defaultdict
from typing import Dict, List, Optional, Tuple, Any

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from .gfn_base import GFlowNetBase, GFlowNetPolicy, DetailedBalance, SubTrajectoryBalance
from ..utils.utils import get_device


class DistillationAwareFitnessFunction:
    """
    Distillation-aware fitness function that balances reward and teachability.
    
    F_DA(ψ_j | θ*) = E[R(s_T)] - λ · E[D_KL(q_j(·|s_{1:t-1}) || P_F(·|s_{1:t-1}; θ*))]
    
    Based on evolutionary GFlowNet principles from EGFN repository.

    Args:
        teachability_weight: λ, the weight applied to the mean per-step KL
            penalty between the critic's and the student's forward policies.
        temperature: Stored as ``self.temperature``; none of the methods in
            this class read it.
    """
    
    def __init__(self, teachability_weight: float = 0.1, temperature: float = 1.0):
        self.teachability_weight = teachability_weight
        self.temperature = temperature
        
    def __call__(self, critic_gfn: 'GFlowNetBase', 
                 student_gfn: 'GFlowNetBase',
                 environment, 
                 num_trajectories: int = 50) -> float:
        """
        Evaluate fitness of a critic GFlowNet given current student capabilities.
        
        Args:
            critic_gfn: Critic GFlowNet being evaluated
            student_gfn: Current student GFlowNet model
            environment: Environment to sample trajectories
            num_trajectories: Number of trajectories to sample for evaluation
            
        Returns:
            Mean terminal reward of the critic's samples minus
            ``teachability_weight`` times the mean per-step KL penalty, or
            ``-inf`` when the critic produced no non-empty trajectory.
        """
        trajectories = critic_gfn.sample(environment, num_trajectories)
        if not trajectories:
            return -float('inf')

        # Reward is read from the last state of every non-empty trajectory.
        rewards = [environment.get_reward(traj[-1]) for traj in trajectories if len(traj) > 0]
        if not rewards:
            return -float('inf')
        reward_component = np.mean(rewards)

        # Teachability component: KL divergence between critic and student.
        teachability_penalty = self._calculate_teachability_penalty(
            trajectories, critic_gfn, student_gfn, environment
        )

        return reward_component - self.teachability_weight * teachability_penalty
    
    def _create_critic_from_params(self, params: torch.Tensor, reference_model: nn.Module) -> nn.Module:
        """
        Build a copy of ``reference_model`` whose parameters are taken from ``params``.

        ``params`` is flattened and consumed in ``parameters()`` order.
        ``reference_model`` itself is not modified.

        Raises:
            ValueError: if ``params`` does not contain exactly as many values
                as the model has parameters.
        """
        critic = copy.deepcopy(reference_model)
        flat = params.reshape(-1)

        expected = sum(p.numel() for p in critic.parameters())
        if flat.numel() != expected:
            raise ValueError(
                f"Expected {expected} parameter values, got {flat.numel()}"
            )

        offset = 0
        with torch.no_grad():
            for param in critic.parameters():
                size = param.numel()
                # copy_ also moves the values onto the parameter's device/dtype.
                param.copy_(flat[offset:offset + size].reshape(param.shape))
                offset += size

        return critic
    
    def _sample_trajectories(self, critic: nn.Module, environment, num_trajectories: int) -> List[List]:
        """Sample trajectories using the critic as guidance (puts ``critic`` in eval mode)."""
        trajectories = []
        
        critic.eval()
        with torch.no_grad():
            for _ in range(num_trajectories):
                trajectory = environment.sample_trajectory_with_critic(critic)
                trajectories.append(trajectory)
        
        return trajectories
    
    def _calculate_teachability_penalty(self, trajectories: List[List[np.ndarray]], 
                                      critic_gfn: 'GFlowNetBase', 
                                      student_gfn: 'GFlowNetBase', 
                                      environment) -> float:
        """
        Compute the mean per-step D_KL(critic || student) over the given trajectories.

        For every state that has a successor in its trajectory and at least one
        valid action, both forward policies are queried with the same
        valid-action mask. The two distributions are smoothed with a small
        epsilon and renormalized before the KL is taken.

        The student's train/eval mode is restored on exit. The critic is left
        in eval mode.

        Returns:
            Average KL over contributing steps, or 0.0 when no step contributed.
        """
        total_kl = 0.0
        num_steps = 0
        eps = 1e-8  # keeps log() finite where a policy assigns zero probability

        # The same student object is later trained by the caller. Restore its
        # mode so this evaluation does not leave it in eval mode, which would
        # change dropout / batch-norm behaviour during training if it has
        # such layers. Critics are not trained in this module, so they are
        # left in eval mode as before.
        student_was_training = student_gfn.training
        critic_gfn.eval()
        student_gfn.eval()

        try:
            with torch.no_grad():
                for trajectory in trajectories:
                    if len(trajectory) < 2:
                        continue

                    for t in range(len(trajectory) - 1):
                        state = torch.tensor(trajectory[t], dtype=torch.float32).to(critic_gfn.device)

                        valid_actions = environment.get_valid_actions(trajectory[t])
                        if not valid_actions:
                            continue

                        mask = torch.zeros(critic_gfn.action_dim, dtype=torch.bool).to(critic_gfn.device)
                        mask[valid_actions] = True

                        critic_probs = critic_gfn.forward_policy.get_action_probs(state.unsqueeze(0), mask).squeeze(0)
                        student_probs = student_gfn.forward_policy.get_action_probs(state.unsqueeze(0), mask).squeeze(0)

                        critic_probs = critic_probs + eps
                        student_probs = student_probs + eps
                        critic_probs = critic_probs / critic_probs.sum()
                        student_probs = student_probs / student_probs.sum()

                        kl_div = torch.sum(critic_probs * torch.log(critic_probs / student_probs))
                        total_kl += kl_div.item()
                        num_steps += 1
        finally:
            student_gfn.train(student_was_training)

        return total_kl / num_steps if num_steps > 0 else 0.0


class EvolutionaryAlgorithm:
    """
    Evolutionary algorithm for evolving GFlowNet critics.
    Based on EGFN evolutionary principles.

    The variation operators only use ``parameters()`` and ``copy.deepcopy``
    on individuals. They assume every individual has the same parameter
    layout.
    """
    
    def __init__(self, population_size: int = 50, 
                 elite_ratio: float = 0.25,
                 mutation_prob: float = 0.1,
                 crossover_prob: float = 0.8,
                 tournament_size: int = 4):
        self.population_size = population_size
        self.elite_ratio = elite_ratio
        self.mutation_prob = mutation_prob
        self.crossover_prob = crossover_prob
        self.tournament_size = tournament_size
        # Number of top-ranked individuals copied unchanged into the next generation.
        self.num_elites = int(population_size * elite_ratio)
        
    def initialize_population(self, state_dim: int, action_dim: int, 
                            hidden_dim: int = 256, device: torch.device = None) -> List['GFlowNetBase']:
        """
        Create ``population_size`` GFlowNet critics with jittered parameters.

        Each critic is built with ``lr=1e-4`` and Gaussian noise (std 0.1)
        added to every parameter for diversity. ``num_layers`` is not passed,
        so critics use ``GFlowNetBase``'s default depth.
        """
        population = []
        critic_lr = 1e-4  # critic learning rate passed to GFlowNetBase
        
        for _ in range(self.population_size):
            gfn = GFlowNetBase(state_dim, action_dim, hidden_dim, lr=critic_lr, device=device)
            
            with torch.no_grad():
                for param in gfn.parameters():
                    param.add_(torch.randn_like(param) * 0.1)
            
            population.append(gfn)
        
        return population
    
    def evaluate_population(self, population: List['GFlowNetBase'],
                          fitness_function: 'DistillationAwareFitnessFunction',
                          student_gfn: 'GFlowNetBase',
                          environment) -> torch.Tensor:
        """
        Score every individual with ``fitness_function``.

        Returns:
            1-D tensor with one score per element of ``population``. Its
            length is that of the list passed in, not ``population_size``.
            Individuals whose evaluation raises get ``-inf``, and a
            ``RuntimeWarning`` is emitted.
        """
        fitness_scores = torch.zeros(len(population))
        
        for i, critic_gfn in enumerate(population):
            try:
                fitness_scores[i] = fitness_function(critic_gfn, student_gfn, environment)
            except Exception as exc:
                # Best effort: keep the individual but rank it last, and
                # surface the failure instead of silently hiding it.
                warnings.warn(
                    f"Fitness evaluation failed for individual {i}: {exc!r}; assigning -inf fitness.",
                    RuntimeWarning,
                )
                fitness_scores[i] = -float('inf')
        
        return fitness_scores
    
    def select_and_reproduce(self, population: List['GFlowNetBase'], 
                           fitness_scores: torch.Tensor) -> List['GFlowNetBase']:
        """
        Build the next generation of ``population_size`` individuals.

        The first entries are deep copies of the ``num_elites`` fittest
        individuals, best first. The rest are offspring of tournament-selected
        parents. The returned list is new; its indices do NOT correspond to
        ``fitness_scores``.
        """
        sorted_indices = torch.argsort(fitness_scores, descending=True)
        
        new_population = [
            copy.deepcopy(population[int(i)]) for i in sorted_indices[:self.num_elites]
        ]
        
        while len(new_population) < self.population_size:
            parent1 = self._tournament_selection(population, fitness_scores)
            parent2 = self._tournament_selection(population, fitness_scores)
            new_population.append(self._crossover_and_mutate(parent1, parent2))
        
        return new_population
    
    def _tournament_selection(self, population: List['GFlowNetBase'], 
                            fitness_scores: torch.Tensor) -> 'GFlowNetBase':
        """
        Return the fittest of ``tournament_size`` distinct random individuals.

        The tournament is capped at the population size.

        Raises:
            ValueError: if ``population`` is empty.
        """
        if not population:
            raise ValueError("Cannot run tournament selection on an empty population")
        # Clamp so small populations do not make random.sample raise.
        k = min(self.tournament_size, len(population))
        tournament_indices = random.sample(range(len(population)), k)
        best_idx = max(tournament_indices, key=lambda i: fitness_scores[i].item())
        return population[best_idx]
    
    def _crossover_and_mutate(self, parent1: 'GFlowNetBase', parent2: 'GFlowNetBase') -> 'GFlowNetBase':
        """
        Create one offspring from two parents.

        Crossover: with probability ``crossover_prob``, the concatenated
        parameter vector is cut at one random point. Values before the cut
        come from parent1, values after it from parent2. Models with fewer
        than two scalar parameters skip crossover.

        Mutation: each parameter tensor is independently perturbed with
        probability ``mutation_prob`` by polynomial-mutation noise
        (eta = 20) scaled by 0.1.

        Neither parent is modified.
        """
        # The offspring starts as a copy of parent1, so it already holds
        # parent1's values on the left of any cut point.
        offspring = copy.deepcopy(parent1)
        
        if random.random() < self.crossover_prob:
            with torch.no_grad():
                offspring_params = list(offspring.parameters())
                offspring_flat = torch.cat([p.flatten() for p in offspring_params])
                parent2_flat = torch.cat([p.flatten() for p in parent2.parameters()])
                total = offspring_flat.numel()
                
                # A cut point in [1, total - 1] needs at least two genes.
                if total >= 2:
                    crossover_point = random.randint(1, total - 1)
                    offspring_flat[crossover_point:] = parent2_flat[crossover_point:]
                    
                    start_idx = 0
                    for p_off in offspring_params:
                        param_size = p_off.numel()
                        p_off.copy_(offspring_flat[start_idx:start_idx + param_size].view_as(p_off))
                        start_idx += param_size
        
        with torch.no_grad():
            for param in offspring.parameters():
                if random.random() < self.mutation_prob:
                    eta = 20.0  # polynomial mutation distribution index
                    u = torch.rand_like(param)
                    
                    delta1 = (2 * u) ** (1 / (eta + 1)) - 1
                    delta2 = 1 - (2 * (1 - u)) ** (1 / (eta + 1))
                    
                    delta = torch.where(u <= 0.5, delta1, delta2)
                    param.add_(delta * 0.1)  # scale factor
        
        return offspring
    
    def get_best_individual(self, population: List['GFlowNetBase'], 
                          fitness_scores: torch.Tensor) -> Tuple['GFlowNetBase', float]:
        """Return the individual with the highest score and that score."""
        best_idx = int(torch.argmax(fitness_scores))
        return population[best_idx], fitness_scores[best_idx].item()


class DATEGFN(nn.Module):
    """
    Distillation-Aware Twisted Evolutionary GFlowNet.
    
    Combines evolutionary algorithm for critic evolution with GFlowNet student training.
    Based on EGFN principles: https://github.com/zarifikram/egfn

    One ``train_step`` runs two phases:

    1. ``evolutionary_phase``: score every critic with the distillation-aware
       fitness, snapshot the top critics, add their samples to the replay
       buffer, then replace the population with the next generation.
    2. ``distillation_phase``: update the student on its own samples plus
       replay-buffer trajectories. The loss is the trajectory-balance loss
       plus a KL term toward the snapshot of top critics.
    """

    # How many top-ranked critics feed the replay buffer and guide distillation.
    NUM_GUIDING_CRITICS = 3
    
    def __init__(self, 
                 state_dim: int,
                 action_dim: int,
                 hidden_dim: int = 256,
                 num_layers: int = 3,
                 population_size: int = 50,
                 elite_ratio: float = 0.25,
                 teachability_weight: float = 0.1,
                 student_updates_per_cycle: int = 100,
                 device: Optional[torch.device] = None):
        super().__init__()
        
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.population_size = population_size
        self.teachability_weight = teachability_weight
        self.student_updates_per_cycle = student_updates_per_cycle
        self.device = device or get_device()
        
        # Student GFlowNet - this is the main model being trained
        self.student_gfn = GFlowNetBase(
            state_dim, action_dim, hidden_dim, num_layers, device=self.device
        )
        
        # Evolutionary algorithm for critics
        self.ea = EvolutionaryAlgorithm(population_size, elite_ratio)
        
        # Distillation-aware fitness function
        self.fitness_function = DistillationAwareFitnessFunction(teachability_weight)
        
        # Population storage. ``population_fitness`` scores the population as it
        # was *evaluated*. After reproduction, ``critic_population`` is a new
        # generation whose indices no longer line up with those scores.
        self.critic_population = None
        self.population_fitness = None
        # Top critics of the last evaluated generation, captured before
        # reproduction so distillation follows the critics that earned the scores.
        self.elite_critics = []
        
        # Training state
        self.generation = 0
        self.total_steps = 0
        
        # Metrics tracking
        self.metrics_history = defaultdict(list)
        
        # Replay buffer for offline training (like EGFN)
        self.replay_buffer = []
        self.buffer_size = 10000
        
        # Initialize optimizers for student GFlowNet
        self.student_optimizers = self._create_optimizers()
    
    def _create_optimizers(self):
        """Create one Adam optimizer (lr 5e-4) per student component."""
        student_lr = 5e-4  # Student learning rate
        return {
            'forward': torch.optim.Adam(self.student_gfn.forward_policy.parameters(), lr=student_lr),
            'backward': torch.optim.Adam(self.student_gfn.backward_policy.parameters(), lr=student_lr),
            'flow': torch.optim.Adam(self.student_gfn.log_flow.parameters(), lr=student_lr)
        }
    
    def initialize_population(self):
        """
        (Re)create the critic population and reset fitness and elite state.

        NOTE(review): ``self.num_layers`` is not forwarded. Critics may have a
        different depth than the student — confirm this is intended.
        """
        self.critic_population = self.ea.initialize_population(
            self.state_dim, self.action_dim, self.hidden_dim, self.device
        )
        self.population_fitness = torch.zeros(self.population_size)
        # A fresh population has not been evaluated yet, so there are no elites.
        self.elite_critics = []
    
    def _select_elite_critics(self, population: List[Any], fitness: Optional[torch.Tensor]) -> List[Any]:
        """
        Return up to ``NUM_GUIDING_CRITICS`` members of ``population``, best first.

        ``fitness[i]`` must score ``population[i]``. Members scored ``-inf``
        (failed evaluation) or NaN are excluded.
        """
        if not population or fitness is None:
            return []
        ranked = torch.argsort(fitness, descending=True)[:self.NUM_GUIDING_CRITICS]
        return [population[int(i)] for i in ranked if fitness[i] > -float('inf')]

    @staticmethod
    def _finite_fitness_stats(fitness: torch.Tensor) -> Tuple[float, float, float]:
        """
        Return ``(max, mean, std)`` over the finite entries of ``fitness``.

        Failed evaluations score ``-inf``; including them would turn the mean
        into ``-inf`` and the std into NaN. With no finite entries the result
        is ``(-inf, -inf, 0.0)``. The std is 0.0 when only one entry is finite.
        """
        finite = fitness[torch.isfinite(fitness)]
        if finite.numel() == 0:
            return -float('inf'), -float('inf'), 0.0
        std = finite.std().item() if finite.numel() > 1 else 0.0
        return finite.max().item(), finite.mean().item(), std
        
    def evolutionary_phase(self, environment):
        """Evaluate, snapshot elites, refill the replay buffer and reproduce the critics."""
        if self.critic_population is None:
            self.initialize_population()
        
        # Evaluate current population using distillation-aware fitness
        self.population_fitness = self.ea.evaluate_population(
            self.critic_population, 
            self.fitness_function,
            self.student_gfn,
            environment
        )
        
        # Snapshot the elites while population and scores are still aligned.
        # select_and_reproduce returns a new, reordered generation.
        self.elite_critics = self._select_elite_critics(self.critic_population, self.population_fitness)
        
        # Store trajectories from best critics in replay buffer
        self._update_replay_buffer(environment)
        
        # Select and reproduce
        self.critic_population = self.ea.select_and_reproduce(
            self.critic_population, 
            self.population_fitness
        )
        
        self.generation += 1
        
        best, mean, std = self._finite_fitness_stats(self.population_fitness)
        self.metrics_history['best_fitness'].append(best)
        self.metrics_history['mean_fitness'].append(mean)
        self.metrics_history['fitness_std'].append(std)
        
    def _update_replay_buffer(self, environment):
        """
        Add 10 samples from each of the current top critics to the replay buffer.

        Must be called while ``critic_population`` is aligned with
        ``population_fitness``, i.e. before reproduction. Trajectories with
        fewer than two states are dropped. Only the newest ``buffer_size``
        entries are kept.
        """
        for critic_gfn in self._select_elite_critics(self.critic_population, self.population_fitness):
            trajectories = critic_gfn.sample(environment, num_samples=10)
            self.replay_buffer.extend(traj for traj in trajectories if len(traj) > 1)
        
        if len(self.replay_buffer) > self.buffer_size:
            self.replay_buffer = self.replay_buffer[-self.buffer_size:]
        
    def distillation_phase(self, environment):
        """
        Run distillation phase to update student GFlowNet.
        Uses both online trajectories and replay buffer (offline) like EGFN.
        """
        if not self.critic_population:
            return
        
        # Evaluation code (e.g. the fitness function) may have switched the
        # student to eval mode; gradient updates must run in training mode.
        self.student_gfn.train()
        
        for step in range(self.student_updates_per_cycle):
            # Mix of online (student) and offline (replay buffer) trajectories.
            online_trajectories = self.student_gfn.sample(environment, num_samples=16)
            offline_trajectories = random.sample(
                self.replay_buffer, min(16, len(self.replay_buffer))
            ) if self.replay_buffer else []
            
            all_trajectories = online_trajectories + offline_trajectories
            if not all_trajectories:
                continue
            
            tb_loss = self.student_gfn.calculate_trajectory_balance_loss(all_trajectories, environment)
            distill_loss = self._calculate_distillation_loss_from_critics(all_trajectories, environment)
            total_loss = tb_loss + self.teachability_weight * distill_loss
            
            for optimizer in self.student_optimizers.values():
                optimizer.zero_grad()
            
            total_loss.backward()
            
            # Gradient clipping for stability
            torch.nn.utils.clip_grad_norm_(self.student_gfn.parameters(), 1.0)
            
            for optimizer in self.student_optimizers.values():
                optimizer.step()
            
            self.total_steps += 1
            
            if step % 20 == 0:
                self.metrics_history['tb_loss'].append(tb_loss.item())
                self.metrics_history['distill_loss'].append(distill_loss.item())
                self.metrics_history['total_loss'].append(total_loss.item())
    
    def _calculate_distillation_loss_from_critics(self, trajectories: List[List[np.ndarray]], 
                                                environment) -> torch.Tensor:
        """
        Compute the mean per-step KL from the averaged elite-critic policy to the student policy.

        Elites come from ``self.elite_critics``, the snapshot taken before
        the last reproduction. Critic outputs are computed without gradient
        tracking, so only the student receives gradients. A critic that raises
        on a state is skipped for that state.

        Returns:
            Scalar tensor; 0.0 when no step contributed.
        """
        if not self.critic_population or len(trajectories) == 0:
            return torch.tensor(0.0, device=self.device)
        
        elite_critics = list(self.elite_critics)
        if not elite_critics:
            return torch.tensor(0.0, device=self.device)
        
        eps = 1e-8  # keeps log() finite where a policy assigns zero probability
        total_loss = 0.0
        num_valid_steps = 0
        
        for trajectory in trajectories:
            if len(trajectory) < 2:
                continue
                
            for t in range(len(trajectory) - 1):
                state = torch.tensor(trajectory[t], dtype=torch.float32).to(self.device)
                
                valid_actions = environment.get_valid_actions(trajectory[t])
                if not valid_actions:
                    continue
                
                mask = torch.zeros(self.action_dim, dtype=torch.bool).to(self.device)
                mask[valid_actions] = True
                
                # Student policy (differentiable).
                student_probs = self.student_gfn.forward_policy.get_action_probs(
                    state.unsqueeze(0), mask
                ).squeeze(0)
                
                # Average critic guidance; critics are targets, not trained here.
                critic_probs_sum = torch.zeros_like(student_probs)
                num_critics = 0
                with torch.no_grad():
                    for critic in elite_critics:
                        try:
                            critic_probs = critic.forward_policy.get_action_probs(
                                state.unsqueeze(0), mask
                            ).squeeze(0)
                            critic_probs_sum += critic_probs
                        except Exception:
                            # Best effort: skip a critic that cannot score this state.
                            continue
                        num_critics += 1
                
                if num_critics == 0:
                    continue
                
                avg_critic_probs = critic_probs_sum / num_critics
                
                # D_KL(critic || student) on epsilon-smoothed, renormalized distributions.
                student_probs = student_probs + eps
                avg_critic_probs = avg_critic_probs + eps
                student_probs = student_probs / student_probs.sum()
                avg_critic_probs = avg_critic_probs / avg_critic_probs.sum()
                
                kl_loss = torch.sum(avg_critic_probs * torch.log(avg_critic_probs / student_probs))
                total_loss += kl_loss
                num_valid_steps += 1
        
        return total_loss / num_valid_steps if num_valid_steps > 0 else torch.tensor(0.0, device=self.device)
    
    def train_step(self, environment, optimizer_forward=None, optimizer_backward=None) -> Dict[str, float]:
        """
        Perform one complete DATE-GFN training step.
        
        This is the main training interface that combines:
        1. Evolutionary phase (evolve critics)
        2. Distillation phase (train student)

        The optimizer arguments are accepted for interface compatibility and
        ignored; the student uses its own optimizers.
        """
        self.evolutionary_phase(environment)
        self.distillation_phase(environment)
        return self.get_metrics()
    
    def sample(self, environment, num_samples: int = 100) -> List[List[np.ndarray]]:
        """Sample trajectories using the trained student GFlowNet."""
        return self.student_gfn.sample(environment, num_samples)
    
    def get_metrics(self) -> Dict[str, Any]:
        """Return the latest value and mean of the last 10 values of every tracked metric, plus counters."""
        metrics = {}
        
        for key, values in self.metrics_history.items():
            if values:
                metrics[f"{key}_latest"] = values[-1]
                metrics[f"{key}_mean"] = np.mean(values[-10:])  # Recent average
        
        metrics['generation'] = self.generation
        metrics['total_steps'] = self.total_steps
        metrics['replay_buffer_size'] = len(self.replay_buffer)
        
        if self.population_fitness is not None:
            metrics['population_diversity'] = self._calculate_population_diversity()
        
        return metrics
    
    def _calculate_population_diversity(self) -> float:
        """
        Average pairwise parameter distance across the critic population.

        For each pair, the L2 norms of the per-tensor parameter differences
        are averaged. Those per-pair values are then averaged over all pairs.
        """
        if not self.critic_population or len(self.critic_population) < 2:
            return 0.0
        
        total_distance = 0.0
        num_pairs = 0
        
        # Pure measurement: no autograd graph needed.
        with torch.no_grad():
            for i in range(len(self.critic_population)):
                for j in range(i + 1, len(self.critic_population)):
                    distance = 0.0
                    param_count = 0
                    
                    for p1, p2 in zip(self.critic_population[i].parameters(), 
                                      self.critic_population[j].parameters()):
                        distance += torch.norm(p1 - p2).item()
                        param_count += 1
                    
                    if param_count > 0:
                        total_distance += distance / param_count
                    num_pairs += 1
        
        return total_distance / num_pairs if num_pairs > 0 else 0.0
    
    def save_checkpoint(self, filepath: str):
        """Save the student, its optimizers, training state, replay buffer and config (not the critics)."""
        checkpoint = {
            'student_gfn_state_dict': self.student_gfn.state_dict(),
            'student_optimizers': {k: v.state_dict() for k, v in self.student_optimizers.items()},
            'population_fitness': self.population_fitness,
            'generation': self.generation,
            'total_steps': self.total_steps,
            'replay_buffer': self.replay_buffer,
            'metrics_history': dict(self.metrics_history),
            'config': {
                'state_dim': self.state_dim,
                'action_dim': self.action_dim,
                'hidden_dim': self.hidden_dim,
                'num_layers': self.num_layers,
                'population_size': self.population_size,
                'teachability_weight': self.teachability_weight
            }
        }
        torch.save(checkpoint, filepath)
    
    def load_checkpoint(self, filepath: str):
        """
        Load a checkpoint written by ``save_checkpoint``.

        Critics are not stored in checkpoints. If the checkpoint had
        population fitness, a fresh random population is created, which also
        resets fitness and elites.

        Security: the checkpoint is fully unpickled (``weights_only=False``).
        Only load files from trusted sources.
        """
        try:
            # The checkpoint holds non-tensor Python objects (replay-buffer
            # trajectories, presumably numpy arrays, plus metric lists) that
            # PyTorch >= 2.6's default weights_only=True loader rejects.
            checkpoint = torch.load(filepath, map_location=self.device, weights_only=False)
        except TypeError:
            # Older PyTorch releases do not accept the weights_only argument.
            checkpoint = torch.load(filepath, map_location=self.device)
        
        self.student_gfn.load_state_dict(checkpoint['student_gfn_state_dict'])
        
        for k, v in checkpoint.get('student_optimizers', {}).items():
            if k in self.student_optimizers:
                self.student_optimizers[k].load_state_dict(v)
        
        self.population_fitness = checkpoint.get('population_fitness')
        self.generation = checkpoint.get('generation', 0)
        self.total_steps = checkpoint.get('total_steps', 0)
        self.replay_buffer = checkpoint.get('replay_buffer', [])
        self.metrics_history = defaultdict(list, checkpoint.get('metrics_history', {}))
        # Elite snapshots refer to critics that are not part of the checkpoint.
        self.elite_critics = []
        
        # Re-initialize population (can't easily save/load GFlowNet instances)
        if self.population_fitness is not None:
            self.initialize_population()


class CriticNetwork(nn.Module):
    """Standalone MLP critic that maps a state vector to one logit per action.

    Layout: Linear(state_dim -> hidden_dim) -> ReLU ->
    Linear(hidden_dim -> hidden_dim) -> ReLU -> Linear(hidden_dim -> action_dim).
    """

    def __init__(self, state_dim: int, action_dim: int, hidden_dim: int = 256):
        super().__init__()

        layers = [
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
        ]
        # Stored as ``network`` so state_dict keys stay "network.<idx>.*".
        self.network = nn.Sequential(*layers)

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        """Return unnormalized action logits for ``state``."""
        logits = self.network(state)
        return logits

    def get_action_logits(self, state: torch.Tensor) -> torch.Tensor:
        """Return action logits; same as calling ``forward`` directly (no module hooks)."""
        return self.forward(state)

    def get_action_probabilities(self, state: torch.Tensor) -> torch.Tensor:
        """Return a softmax distribution over actions along the last dimension."""
        return self.forward(state).softmax(dim=-1)
