"""
Baseline methods for comparison with DATE-GFN.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Dict, List, Optional, Tuple, Any
import copy
from collections import defaultdict, deque
import random

from ..core.gfn_base import GFlowNetBase, GFlowNetPolicy
from .utils import get_device


class GFNBaseline(GFlowNetBase):
    """Plain GFlowNet baseline optimised with the Trajectory Balance (TB) loss.

    All sampling and loss computation is delegated to the inherited
    ``sample`` and ``calculate_trajectory_balance_loss`` methods; this class
    only adds a training step and simple metric bookkeeping.
    """

    def __init__(self, state_dim: int, action_dim: int, hidden_dim: int = 256, num_layers: int = 3):
        """Build the underlying GFlowNet and initialise the training counters.

        Args:
            state_dim: Forwarded to ``GFlowNetBase``.
            action_dim: Forwarded to ``GFlowNetBase``.
            hidden_dim: Forwarded to ``GFlowNetBase``.
            num_layers: Forwarded to ``GFlowNetBase``.
        """
        super().__init__(state_dim, action_dim, hidden_dim, num_layers)

        # Number of completed optimisation steps.
        self.total_steps = 0
        # Metric name -> list of recorded values, one entry per train_step.
        self.metrics_history = defaultdict(list)

    def train_step(self, environment, optimizer, batch_size: int = 32) -> Dict[str, float]:
        """Run one TB optimisation step on a freshly sampled batch.

        Args:
            environment: Environment passed through to ``sample`` and to the
                TB loss.
            optimizer: Optimizer whose ``zero_grad``/``step`` are called.
            batch_size: Number of trajectories sampled (one at a time).

        Returns:
            Dict with the scalar ``tb_loss`` and the running ``total_steps``.
        """
        # One sample() call per trajectory; the first element of each result is kept.
        batch = [self.sample(environment, 1)[0] for _ in range(batch_size)]

        loss = self.calculate_trajectory_balance_loss(batch, environment)

        # Standard gradient step.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        self.total_steps += 1

        loss_value = loss.item()
        self.metrics_history['tb_loss'].append(loss_value)

        return {'tb_loss': loss_value, 'total_steps': self.total_steps}


class EGFNBaseline(nn.Module):
    """Evolution Guided GFlowNet (EGFN) baseline.

    Keeps a population of ``GFlowNetBase`` instances that is scored by average
    reward and evolved with elitism, blend crossover and Gaussian mutation,
    plus a separate "student" GFlowNet. The student is trained with the
    trajectory-balance loss on its own samples and on replayed trajectories
    from the best population members, with an added KL term towards the
    averaged action distribution of those members.
    """

    def __init__(self, state_dim: int, action_dim: int, hidden_dim: int = 256,
                 population_size: int = 50, elite_ratio: float = 0.25):
        """Create the student network and empty population/replay state.

        Args:
            state_dim: Forwarded to every ``GFlowNetBase`` that is created.
            action_dim: Number of discrete actions; also the size of the
                action mask built during distillation.
            hidden_dim: Hidden width used for the student AND for every
                population member, so all networks share one architecture.
            population_size: Number of population members.
            elite_ratio: Fraction of the population kept as elites; at least
                one elite is always kept.
        """
        super().__init__()

        self.state_dim = state_dim
        self.action_dim = action_dim
        # Stored so population members/offspring are built like the student.
        self.hidden_dim = hidden_dim
        self.population_size = population_size
        self.elite_ratio = elite_ratio
        self.device = get_device()

        # Student GFlowNet (registered as a submodule of this nn.Module).
        self.student_gfn = GFlowNetBase(state_dim, action_dim, hidden_dim).to(self.device)

        # Population is a plain list, so its members are NOT registered as
        # submodules and do not show up in self.parameters().
        self.population: Optional[List["GFlowNetBase"]] = None
        self.population_fitness: Optional[torch.Tensor] = None
        self.generation = 0

        # Replay buffer of trajectories sampled from elite population members.
        self.replay_buffer = []

        self.metrics_history = defaultdict(list)

    def _new_member(self) -> "GFlowNetBase":
        """Build a fresh population member with the student's architecture."""
        return GFlowNetBase(self.state_dim, self.action_dim, self.hidden_dim).to(self.device)

    def initialize_population(self):
        """Create ``population_size`` members, each perturbed with N(0, 0.1^2) noise."""
        self.population = []
        for _ in range(self.population_size):
            member = self._new_member()

            # Perturb the default initialisation so members start out diverse.
            with torch.no_grad():
                for param in member.parameters():
                    param.add_(torch.randn_like(param) * 0.1)

            self.population.append(member)

        self.population_fitness = torch.zeros(self.population_size)

    def evaluate_population(self, environment, num_trajectories: int = 100):
        """Score every member by its mean reward over sampled trajectories.

        For each member, ``num_trajectories`` trajectories are sampled one at
        a time and ``environment.get_reward`` is applied to the last element
        of each. The mean is written to ``population_fitness[i]``.
        """
        for i, member in enumerate(self.population):
            total_reward = 0.0
            member.eval()

            with torch.no_grad():
                for _ in range(num_trajectories):
                    trajectory = member.sample(environment, 1)[0]
                    total_reward += environment.get_reward(trajectory[-1])

            self.population_fitness[i] = total_reward / num_trajectories

    def evolve_population(self):
        """Replace the population with elites plus mutated offspring.

        The top ``elite_ratio`` fraction (at least one member, at most the
        whole population) is deep-copied into the new population; the rest is
        filled with offspring of two elites drawn uniformly with replacement.

        NOTE: ``population_fitness`` is not modified here, so after this call
        it still describes the PREVIOUS generation until
        ``evaluate_population`` is run again.

        Raises:
            RuntimeError: If the population has not been initialised.
        """
        if self.population is None:
            raise RuntimeError("initialize_population() must be called before evolve_population()")

        # int() truncation can yield 0 (e.g. small populations); clamp so there
        # is always a parent to draw from and never more elites than members.
        num_elites = int(self.population_size * self.elite_ratio)
        num_elites = max(1, min(num_elites, self.population_size))
        elite_indices = torch.argsort(self.population_fitness, descending=True)[:num_elites].tolist()

        # Elites survive unchanged (as independent copies).
        new_population = [copy.deepcopy(self.population[idx]) for idx in elite_indices]

        # Fill the remaining slots with offspring of randomly paired elites.
        while len(new_population) < self.population_size:
            parent1 = self.population[random.choice(elite_indices)]
            parent2 = self.population[random.choice(elite_indices)]
            new_population.append(self._crossover_and_mutate(parent1, parent2))

        self.population = new_population
        self.generation += 1

    def _crossover_and_mutate(self, parent1: "GFlowNetBase", parent2: "GFlowNetBase") -> "GFlowNetBase":
        """Create one offspring by element-wise blending and Gaussian mutation.

        Each parameter element is ``alpha * p1 + (1 - alpha) * p2`` with an
        independent ``alpha ~ U[0, 1)``, followed by additive N(0, 0.01^2) noise.
        """
        offspring = self._new_member()

        with torch.no_grad():
            for p1_param, p2_param, child_param in zip(
                parent1.parameters(), parent2.parameters(), offspring.parameters()
            ):
                # Crossover: random per-element blend of the two parents.
                alpha = torch.rand_like(p1_param)
                child_param.copy_(alpha * p1_param + (1 - alpha) * p2_param)

                # Mutation: small additive noise.
                child_param.add_(torch.randn_like(child_param) * 0.01)

        return offspring

    def distill_from_population(self, environment, optimizer, num_steps: int = 1000,
                                num_teachers: int = 5, batch_size: int = 8):
        """Train the student on TB loss plus a KL term towards the best members.

        ``population_fitness`` must describe the CURRENT ``population`` (i.e.
        call this after ``evaluate_population`` and before
        ``evolve_population``), otherwise the wrong members are used as teachers.

        Args:
            environment: Environment used for sampling and loss computation.
            optimizer: Optimizer that updates the student.
            num_steps: Number of gradient steps.
            num_teachers: How many top-fitness members act as teachers.
            batch_size: Number of online trajectories, and maximum number of
                replayed trajectories, per gradient step.
        """
        best_indices = torch.argsort(self.population_fitness, descending=True)[:num_teachers].tolist()
        best_gfns = [self.population[idx] for idx in best_indices]

        # Refresh the replay buffer with trajectories from the teachers.
        self._update_replay_buffer(best_gfns, environment)

        for step in range(num_steps):
            # Online trajectories from the student itself.
            online_trajectories = list(self.student_gfn.sample(environment, batch_size))

            # Offline trajectories from the replay buffer (empty buffer -> []).
            offline_trajectories = random.sample(
                self.replay_buffer, min(batch_size, len(self.replay_buffer))
            )

            all_trajectories = online_trajectories + offline_trajectories

            tb_loss = self.student_gfn.calculate_trajectory_balance_loss(all_trajectories, environment)
            distill_loss = self._calculate_distillation_loss_from_gfns(all_trajectories, best_gfns, environment)

            # The KL term is a weak regulariser next to the TB objective.
            total_loss = tb_loss + 0.1 * distill_loss

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            if step % 100 == 0:
                self.metrics_history['distill_loss'].append(distill_loss.item())
                self.metrics_history['tb_loss'].append(tb_loss.item())

    def _update_replay_buffer(self, elite_gfns: List["GFlowNetBase"], environment, max_buffer_size: int = 1000):
        """Append 10 trajectories per elite and keep only the newest ``max_buffer_size``."""
        # Teachers are only data sources here; no gradients are needed.
        with torch.no_grad():
            for gfn in elite_gfns:
                self.replay_buffer.extend(gfn.sample(environment, 10))

        # Keep buffer size manageable by dropping the oldest entries.
        if len(self.replay_buffer) > max_buffer_size:
            self.replay_buffer = self.replay_buffer[-max_buffer_size:]

    def _calculate_distillation_loss_from_gfns(self, trajectories: List, teacher_gfns: List["GFlowNetBase"], environment) -> torch.Tensor:
        """Return the summed KL(teacher_avg || student) per trajectory.

        For every non-final state of every trajectory, the teachers' masked
        action probabilities are averaged and compared with the student's via
        ``F.kl_div(..., reduction='sum')``. The total is divided by the number
        of trajectories.

        Always returns a tensor on ``self.device``; it is zero when there are
        no trajectories, no teachers, or no transitions.
        """
        total_loss = torch.zeros((), device=self.device)
        if not trajectories or not teacher_gfns:
            return total_loss

        # Put teachers in eval mode once instead of once per state.
        for teacher_gfn in teacher_gfns:
            teacher_gfn.eval()

        for trajectory in trajectories:
            # The last element has no outgoing action, so it is skipped.
            for raw_state in trajectory[:-1]:
                state = torch.tensor(raw_state, dtype=torch.float32).to(self.device)

                # Boolean mask over actions that are valid in this state.
                valid_actions = environment.get_valid_actions(raw_state)
                action_mask = torch.zeros(self.action_dim, dtype=torch.bool, device=self.device)
                action_mask[valid_actions] = True

                # Student log-probabilities (epsilon avoids log(0) on masked actions).
                student_probs = self.student_gfn.get_action_probs(state, action_mask)
                student_log_probs = torch.log(student_probs + 1e-8)

                # Average teacher action probabilities (no gradient to teachers).
                teacher_probs = torch.zeros_like(student_probs)
                with torch.no_grad():
                    for teacher_gfn in teacher_gfns:
                        teacher_probs += teacher_gfn.get_action_probs(state, action_mask)
                teacher_probs /= len(teacher_gfns)

                # F.kl_div takes log-probs as input and probs as target.
                total_loss = total_loss + F.kl_div(student_log_probs, teacher_probs, reduction='sum')

        return total_loss / len(trajectories)

    def train_step(self, environment, optimizer) -> Dict[str, float]:
        """Run one evaluate -> distill -> evolve cycle and return fitness metrics.

        Distillation runs BEFORE evolution because ``population_fitness``
        indexes the evaluated population; after ``evolve_population`` those
        indices would point at different members.
        """
        if self.population is None:
            self.initialize_population()

        # Score the current generation.
        self.evaluate_population(environment)

        # Distil from the members that were just scored.
        self.distill_from_population(environment, optimizer)

        # Produce the next generation (fitness tensor is left untouched).
        self.evolve_population()

        metrics = {
            'best_fitness': self.population_fitness.max().item(),
            'mean_fitness': self.population_fitness.mean().item(),
            'generation': self.generation
        }

        self.metrics_history['best_fitness'].append(metrics['best_fitness'])
        self.metrics_history['mean_fitness'].append(metrics['mean_fitness'])

        return metrics

    def sample(self, environment, num_samples: int = 100) -> List:
        """Sample ``num_samples`` trajectories from the student GFlowNet."""
        return self.student_gfn.sample(environment, num_samples)


class SACBaseline(nn.Module):
    """Soft Actor-Critic adapted for discrete action spaces.

    The actor maps a state to logits over ``action_dim`` actions. Each critic
    scores a concatenated ``[state, one_hot(action)]`` vector with a single
    output. Target critics are deep copies of the critics that are softly
    updated towards them after every training step.
    """

    def __init__(self, state_dim: int, action_dim: int, hidden_dim: int = 256):
        """Build actor, twin critics, target critics and the replay buffer.

        Args:
            state_dim: Length of a state vector.
            action_dim: Number of discrete actions.
            hidden_dim: Width of the two hidden layers of every network.
        """
        super().__init__()

        self.state_dim = state_dim
        self.action_dim = action_dim
        self.device = get_device()

        # Actor network (policy logits).
        self.actor = self._create_network(state_dim, action_dim, hidden_dim).to(self.device)

        # Critic networks (Q-functions over [state, one-hot action]).
        self.critic1 = self._create_network(state_dim + action_dim, 1, hidden_dim).to(self.device)
        self.critic2 = self._create_network(state_dim + action_dim, 1, hidden_dim).to(self.device)

        # Target critics start as exact copies.
        self.target_critic1 = copy.deepcopy(self.critic1)
        self.target_critic2 = copy.deepcopy(self.critic2)

        # Log-temperature. A plain tensor (not nn.Parameter), so it is not part
        # of self.parameters() and needs its own optimizer.
        self.log_alpha = torch.tensor(0.0, requires_grad=True, device=self.device)

        # Transitions: (state, action, reward, next_state, done).
        self.replay_buffer = deque(maxlen=100000)

        self.metrics_history = defaultdict(list)
        self.total_steps = 0

    def _create_network(self, input_dim: int, output_dim: int, hidden_dim: int) -> nn.Module:
        """Return a 3-layer ReLU MLP ``input_dim -> hidden -> hidden -> output_dim``."""
        return nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )

    def _min_q_all_actions(self, critic_a: nn.Module, critic_b: nn.Module,
                           states: torch.Tensor) -> torch.Tensor:
        """Return ``min(Q_a, Q_b)`` for every action, shape ``(batch, action_dim)``.

        ``states`` must be 2-D ``(batch, state_dim)``. ``squeeze(-1)`` (not a
        bare ``squeeze()``) keeps the batch dimension when ``batch == 1``.
        """
        batch = states.size(0)
        eye = torch.eye(self.action_dim, device=states.device, dtype=states.dtype)

        q_a_cols, q_b_cols = [], []
        for action in range(self.action_dim):
            # Same one-hot action for every state in the batch.
            state_action = torch.cat([states, eye[action].expand(batch, -1)], dim=1)
            q_a_cols.append(critic_a(state_action).squeeze(-1))
            q_b_cols.append(critic_b(state_action).squeeze(-1))

        return torch.min(torch.stack(q_a_cols, dim=1), torch.stack(q_b_cols, dim=1))

    def get_action(self, state: torch.Tensor, deterministic: bool = False) -> int:
        """Pick an action for a single state.

        Args:
            state: State tensor fed to the actor.
            deterministic: If True return the arg-max action, otherwise sample
                from the softmax over the actor logits.

        Returns:
            The chosen action index.
        """
        # Evaluate in eval mode, then restore whatever mode the actor was in so
        # acting never silently switches training off.
        was_training = self.actor.training
        self.actor.eval()
        try:
            with torch.no_grad():
                logits = self.actor(state)
        finally:
            self.actor.train(was_training)

        if deterministic:
            return torch.argmax(logits).item()

        probs = F.softmax(logits, dim=-1)
        return torch.multinomial(probs, 1).item()

    def sample_trajectory(self, environment, max_length: int = 100) -> List:
        """Roll out the current policy and store every transition in the replay buffer.

        The reward of a transition is ``environment.get_reward(next_state)``
        when ``next_state`` is terminal and ``0.0`` otherwise.

        Args:
            environment: Provides ``reset``, ``step(state, action)``,
                ``is_terminal(state)`` and ``get_reward(state)``.
            max_length: Maximum number of transitions.

        Returns:
            List of visited states, starting with ``environment.reset()``.
        """
        trajectory = [environment.reset()]

        for _ in range(max_length):
            current = trajectory[-1]
            state = torch.tensor(current, dtype=torch.float32).to(self.device)
            action = self.get_action(state)
            next_state = environment.step(current, action)

            # Query terminality once and reuse it.
            done = environment.is_terminal(next_state)
            reward = environment.get_reward(next_state) if done else 0.0
            self.replay_buffer.append((current, action, reward, next_state, done))

            trajectory.append(next_state)

            if done:
                break

        return trajectory

    def update_critics(self, batch, target_q_values: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute the MSE losses of both critics against ``target_q_values``.

        Despite the name, no optimizer step is taken here; the caller does it.

        Args:
            batch: Tuple ``(states, actions, rewards, next_states, dones)``;
                only ``states`` and ``actions`` (long indices) are used.
            target_q_values: 1-D tensor of TD targets, one per batch item.

        Returns:
            ``(critic1_loss, critic2_loss)``.
        """
        states, actions, _rewards, _next_states, _dones = batch

        # One-hot encode actions for the discrete case.
        actions_onehot = F.one_hot(actions, self.action_dim).float()
        state_action = torch.cat([states, actions_onehot], dim=1)

        # squeeze(-1) keeps shape (batch,) even for a batch of one.
        q1_values = self.critic1(state_action).squeeze(-1)
        q2_values = self.critic2(state_action).squeeze(-1)

        critic1_loss = F.mse_loss(q1_values, target_q_values)
        critic2_loss = F.mse_loss(q2_values, target_q_values)

        return critic1_loss, critic2_loss

    def update_actor(self, states: torch.Tensor) -> torch.Tensor:
        """Compute the actor loss ``E_s[ sum_a pi(a|s) (alpha*log pi(a|s) - Q(s,a)) ]``.

        Despite the name, no optimizer step is taken here; the caller does it.
        Q-values and alpha are treated as constants, so back-propagating this
        loss only produces gradients for the actor.

        Args:
            states: 2-D tensor ``(batch, state_dim)``.

        Returns:
            Scalar loss tensor.
        """
        logits = self.actor(states)
        log_probs = F.log_softmax(logits, dim=-1)
        probs = log_probs.exp()

        # Critics and temperature are fixed targets for the policy update.
        with torch.no_grad():
            q_all = self._min_q_all_actions(self.critic1, self.critic2, states)
            alpha = torch.exp(self.log_alpha)

        return (probs * (alpha * log_probs - q_all)).sum(dim=1).mean()

    def train_step(self, environment, actor_optimizer, critic1_optimizer, critic2_optimizer,
                  alpha_optimizer, batch_size: int = 32, gamma: float = 0.99,
                  tau: float = 0.005) -> Dict[str, float]:
        """Collect one trajectory, then do one SAC update on a replay batch.

        Args:
            environment: Environment used by ``sample_trajectory``.
            actor_optimizer: Optimizer for the actor.
            critic1_optimizer: Optimizer for ``critic1``.
            critic2_optimizer: Optimizer for ``critic2``.
            alpha_optimizer: Optimizer for ``log_alpha``.
            batch_size: Number of transitions drawn from the replay buffer.
            gamma: Discount factor of the TD target.
            tau: Interpolation factor of the soft target-critic update.

        Returns:
            ``{'buffer_size': ...}`` while the buffer holds fewer than
            ``batch_size`` transitions, otherwise the losses, ``alpha`` and
            ``total_steps``.
        """
        # Sample trajectory and add to buffer.
        self.sample_trajectory(environment)

        if len(self.replay_buffer) < batch_size:
            return {'buffer_size': len(self.replay_buffer)}

        batch_data = random.sample(self.replay_buffer, batch_size)

        # Stack states through NumPy first: building a tensor directly from a
        # list of arrays is very slow.
        states = torch.as_tensor(
            np.asarray([d[0] for d in batch_data], dtype=np.float32), device=self.device)
        actions = torch.tensor([d[1] for d in batch_data], dtype=torch.long).to(self.device)
        rewards = torch.tensor([d[2] for d in batch_data], dtype=torch.float32).to(self.device)
        next_states = torch.as_tensor(
            np.asarray([d[3] for d in batch_data], dtype=np.float32), device=self.device)
        dones = torch.tensor([d[4] for d in batch_data], dtype=torch.bool).to(self.device)

        # Soft TD target: r + gamma * (1 - done) * E_a'[Q_target - alpha * log pi].
        with torch.no_grad():
            next_log_probs = F.log_softmax(self.actor(next_states), dim=-1)
            next_probs = next_log_probs.exp()
            next_q_all = self._min_q_all_actions(
                self.target_critic1, self.target_critic2, next_states)

            alpha = torch.exp(self.log_alpha)
            next_v = (next_probs * (next_q_all - alpha * next_log_probs)).sum(dim=1)

            target_q = rewards + (~dones).float() * gamma * next_v

        # Update critics.
        critic1_loss, critic2_loss = self.update_critics(
            (states, actions, rewards, next_states, dones), target_q)

        critic1_optimizer.zero_grad()
        critic1_loss.backward()
        critic1_optimizer.step()

        critic2_optimizer.zero_grad()
        critic2_loss.backward()
        critic2_optimizer.step()

        # Update actor.
        actor_loss = self.update_actor(states)

        actor_optimizer.zero_grad()
        actor_loss.backward()
        actor_optimizer.step()

        # Update temperature (discrete SAC): alpha rises while the policy
        # entropy on the sampled states is below the target and falls when it
        # is above. The target is a fraction of the maximum entropy
        # log(action_dim), which a discrete policy can actually reach.
        with torch.no_grad():
            log_probs = F.log_softmax(self.actor(states), dim=-1)
            neg_entropy = (log_probs.exp() * log_probs).sum(dim=1)
        target_entropy = 0.98 * float(np.log(self.action_dim))
        alpha_loss = -(self.log_alpha * (neg_entropy + target_entropy)).mean()

        alpha_optimizer.zero_grad()
        alpha_loss.backward()
        alpha_optimizer.step()

        # Soft update: target <- tau * online + (1 - tau) * target.
        with torch.no_grad():
            for target_net, online_net in (
                (self.target_critic1, self.critic1),
                (self.target_critic2, self.critic2),
            ):
                for target_param, param in zip(target_net.parameters(), online_net.parameters()):
                    target_param.mul_(1 - tau).add_(param, alpha=tau)

        self.total_steps += 1

        metrics = {
            'critic1_loss': critic1_loss.item(),
            'critic2_loss': critic2_loss.item(),
            'actor_loss': actor_loss.item(),
            'alpha_loss': alpha_loss.item(),
            'alpha': torch.exp(self.log_alpha).item(),
            'total_steps': self.total_steps
        }

        for key, value in metrics.items():
            if key != 'total_steps':
                self.metrics_history[key].append(value)

        return metrics

    def sample(self, environment, num_samples: int = 100) -> List:
        """Sample ``num_samples`` trajectories with the current policy.

        NOTE: this goes through ``sample_trajectory``, so every sampled
        transition is also appended to the replay buffer.
        """
        return [self.sample_trajectory(environment) for _ in range(num_samples)]


class MARSBaseline:
    """Markov Chain Monte Carlo (MARS) baseline.

    Has no learned parameters: trajectories are random walks, and a
    Metropolis-Hastings style chain accepts or rejects regenerated trajectory
    suffixes based on the ratio of end-state rewards.
    """

    def __init__(self, state_dim: int, action_dim: int):
        """Store the dimensions and initialise the metric bookkeeping.

        Args:
            state_dim: Stored for interface parity; not used by the sampler.
            action_dim: Number of discrete actions; actions are drawn
                uniformly from ``0 .. action_dim - 1``.
        """
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.metrics_history = defaultdict(list)
        self.total_steps = 0

    def sample_trajectory(self, environment, max_length: int = 100) -> List:
        """Sample a trajectory by a uniform random walk.

        Stops at the first terminal state or after ``max_length`` transitions.

        Returns:
            List of visited states, starting with ``environment.reset()``.
        """
        trajectory = [environment.reset()]

        for _ in range(max_length):
            action = random.randint(0, self.action_dim - 1)
            next_state = environment.step(trajectory[-1], action)
            trajectory.append(next_state)

            if environment.is_terminal(next_state):
                break

        return trajectory

    def metropolis_hastings_step(self, current_trajectory: List, environment,
                                 max_length: int = 100) -> List:
        """Propose a trajectory with a regenerated suffix and accept/reject it.

        A prefix ending at a uniformly chosen non-final position is kept and
        the rest is regenerated as a random walk with a fresh uniformly random
        action at every step. The walk stops at the first terminal state or
        once the step budget is used up. The proposal is accepted with
        probability ``min(1, r_new / (r_cur + 1e-8))``, where the rewards are
        ``environment.get_reward`` of the respective last states.

        Args:
            current_trajectory: Current chain state (list of states). It is
                not modified.
            environment: Provides ``step``, ``is_terminal`` and ``get_reward``.
            max_length: Maximum number of transitions of a proposal, matching
                ``sample_trajectory``. If the current trajectory is already
                longer, its own length is used as the budget instead, so a
                proposal can always be as long as the trajectory it replaces.

        Returns:
            The proposed trajectory if accepted, otherwise
            ``current_trajectory`` itself.
        """
        # A single-state trajectory has no step that could be modified.
        if len(current_trajectory) <= 1:
            return current_trajectory

        # Keep states 0..step_idx, regenerate everything after them.
        step_idx = random.randint(0, len(current_trajectory) - 2)
        new_trajectory = list(current_trajectory[:step_idx + 1])
        state = new_trajectory[-1]

        # Budget of transitions for the new suffix. Allowing up to max_length
        # in total lets the chain grow again after accepting a short
        # trajectory instead of only ever shrinking.
        total_budget = max(max_length, len(current_trajectory) - 1)
        for _ in range(total_budget - step_idx):
            # Fresh action per step: the suffix is a random walk, not one
            # action repeated.
            action = random.randint(0, self.action_dim - 1)
            state = environment.step(state, action)
            new_trajectory.append(state)

            if environment.is_terminal(state):
                break

        # Accept/reject based on the reward ratio (epsilon avoids division by zero).
        current_reward = environment.get_reward(current_trajectory[-1])
        proposed_reward = environment.get_reward(new_trajectory[-1])
        acceptance_prob = min(1.0, proposed_reward / (current_reward + 1e-8))

        if random.random() < acceptance_prob:
            return new_trajectory
        return current_trajectory

    def train_step(self, environment, num_mcmc_steps: int = 1000) -> Dict[str, float]:
        """Run a fresh chain for ``num_mcmc_steps`` steps and report reward statistics.

        Returns:
            Dict with ``mean_reward`` (mean end-state reward over the chain
            steps; the initial reward if no step was run), ``best_reward``
            (maximum including the initial trajectory) and the cumulative
            ``total_steps``.
        """
        # Initialise the chain with a random trajectory.
        current_trajectory = self.sample_trajectory(environment)
        best_reward = environment.get_reward(current_trajectory[-1])

        rewards = []

        for _ in range(num_mcmc_steps):
            current_trajectory = self.metropolis_hastings_step(current_trajectory, environment)
            reward = environment.get_reward(current_trajectory[-1])
            rewards.append(reward)

            if reward > best_reward:
                best_reward = reward

            self.total_steps += 1

        metrics = {
            # np.mean([]) would be NaN; fall back to the initial reward.
            'mean_reward': np.mean(rewards) if rewards else best_reward,
            'best_reward': best_reward,
            'total_steps': self.total_steps
        }

        self.metrics_history['mean_reward'].append(metrics['mean_reward'])
        self.metrics_history['best_reward'].append(metrics['best_reward'])

        return metrics

    def sample(self, environment, num_samples: int = 100, thin: int = 10) -> List:
        """Sample trajectories from a single chain.

        Args:
            environment: Environment to sample in.
            num_samples: Number of trajectories to return.
            thin: Number of Metropolis-Hastings steps between two recorded
                trajectories.

        Returns:
            List of ``num_samples`` trajectories (independent list copies).
        """
        trajectories = []

        # Start the chain from a random trajectory.
        current_trajectory = self.sample_trajectory(environment)

        for _ in range(num_samples):
            for _ in range(thin):
                current_trajectory = self.metropolis_hastings_step(current_trajectory, environment)

            trajectories.append(list(current_trajectory))

        return trajectories
