"""
RQ1: Computational Efficiency - Amortized DATE-GFN Implementation

This experiment addresses the computational overhead weakness by implementing
A-DATE-GFN with strategic critic re-evaluation and periodic distillation.

Research Question: Can amortized updates reduce computational cost by 2-3x 
while maintaining 95%+ performance of standard DATE-GFN?
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import wandb
import time
import random
from collections import deque, defaultdict
from typing import List, Tuple, Dict, Optional
import argparse
import os

# Reproducibility: seed PyTorch, NumPy and the stdlib `random` module (in that
# order) once, at import time, with the same fixed value.
_GLOBAL_SEED = 42
for _seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    _seed_fn(_GLOBAL_SEED)

class HypergridEnvironment:
    """Hypergrid benchmark environment.

    The state is an ``ndim``-dimensional list of integer coordinates, each in
    ``[0, height - 1]``, starting at the origin. Actions ``0..ndim-1`` increment
    one coordinate; action ``ndim`` stops the episode and yields the terminal
    reward.
    """

    def __init__(self, ndim: int = 4, height: int = 8, R0: float = 0.01, R1: float = 0.5, R2: float = 2.0):
        self.ndim = ndim
        self.height = height
        # Terminal reward tiers, see compute_reward().
        self.R0, self.R1, self.R2 = R0, R1, R2
        self.reset()

    def reset(self):
        """Return to the origin and return the state as a float tensor."""
        self.state = [0] * self.ndim
        return self.get_state_tensor()

    def get_state_tensor(self):
        """Return the current coordinates as a 1-D float32 tensor of length ``ndim``."""
        return torch.tensor(self.state, dtype=torch.float32)

    def get_valid_actions(self):
        """Return the legal action indices from the current state.

        Increment actions are legal only for coordinates below ``height - 1``.
        The stop action (index ``ndim``) is always legal and is listed last.
        """
        actions = [dim for dim in range(self.ndim) if self.state[dim] < self.height - 1]
        actions.append(self.ndim)  # Stop action
        return actions

    def step(self, action: int):
        """Apply ``action`` and return ``(state_tensor, reward, done, info)``.

        The stop action ends the episode with ``compute_reward()``. Any other
        action that is out of range (including negative) or would push a
        coordinate past ``height - 1`` is ignored: the state is unchanged and
        the step yields reward 0.0 with ``done=False``.
        """
        if action == self.ndim:  # Stop action
            reward = self.compute_reward()
            return self.get_state_tensor(), reward, True, {}

        # `0 <=` matters: without it, action=-1 would increment the last
        # coordinate through Python's negative indexing.
        if 0 <= action < self.ndim and self.state[action] < self.height - 1:
            self.state[action] += 1

        return self.get_state_tensor(), 0.0, False, {}

    def compute_reward(self):
        """Return the terminal reward for the current state.

        ``R2`` if every coordinate is at ``height - 1`` (the far corner),
        otherwise ``R1`` if any coordinate is at least ``height // 2``,
        otherwise ``R0``.
        """
        top = self.height - 1
        if all(coord >= top for coord in self.state):
            return self.R2
        if any(coord >= self.height // 2 for coord in self.state):
            return self.R1
        return self.R0

class PolicyNetwork(nn.Module):
    """Four-layer ReLU MLP producing action log-probabilities for the hypergrid.

    ``forward`` accepts an optional 0/1 mask; entries with mask 0 have
    ``1e9`` subtracted from their logit before the log-softmax.
    """

    def __init__(self, state_dim: int, action_dim: int, hidden_dim: int = 256):
        super().__init__()
        self.state_dim = state_dim
        self.action_dim = action_dim

        # Three hidden Linear+ReLU stages, then a linear head to action logits.
        widths = [state_dim, hidden_dim, hidden_dim, hidden_dim]
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        layers.append(nn.Linear(hidden_dim, action_dim))
        self.network = nn.Sequential(*layers)

    def forward(self, state, action_mask=None):
        """Return log-probabilities over actions (last dimension)."""
        logits = self.network(state)
        if action_mask is None:
            return F.log_softmax(logits, dim=-1)
        penalty = (action_mask - 1) * 1e9
        return F.log_softmax(logits + penalty, dim=-1)

class CriticNetwork(nn.Module):
    """Four-layer ReLU MLP mapping each state to one scalar value."""

    def __init__(self, state_dim: int, hidden_dim: int = 256):
        super().__init__()
        sizes = (state_dim, hidden_dim, hidden_dim, hidden_dim)
        modules = []
        for in_features, out_features in zip(sizes, sizes[1:]):
            modules.append(nn.Linear(in_features, out_features))
            modules.append(nn.ReLU())
        # Scalar value head.
        modules.append(nn.Linear(hidden_dim, 1))
        self.network = nn.Sequential(*modules)

    def forward(self, state):
        """Return the value for each input row, with a trailing size-1 dimension."""
        value = self.network(state)
        return value

class AmortizedDATEGFN:
    """Amortized DATE-GFN: an evolving critic population distilled into a student policy.

    Compute is amortized in two ways:

    * ``rho``: each generation only ``max(1, int(rho * population_size))``
      critics have their fitness re-measured (stalest estimates first). The
      others keep the estimate inherited from their parent.
    * ``M``: student update frequency. It is only stored here; the caller's
      training loop decides when to call :meth:`update_student`.
    """

    # Rollouts are truncated once they exceed this many steps.
    max_trajectory_length = 50

    def __init__(self,
                 state_dim: int,
                 action_dim: int,
                 critic_population_size: int = 20,
                 rho: float = 0.5,  # Critic re-evaluation fraction
                 M: int = 5,        # Student update frequency
                 lambda_param: float = 0.1,
                 lr: float = 1e-3):
        """
        Args:
            state_dim: Length of the state vector fed to the student and critics.
            action_dim: Number of discrete actions (student output size).
            critic_population_size: Number of critics evolved each generation.
            rho: Fraction of critics re-evaluated each generation.
            M: Student update frequency (consumed by the training loop).
            lambda_param: Weight of the teachability (KL) penalty in critic fitness.
            lr: Adam learning rate for the student.
        """
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.rho = rho
        self.M = M
        self.lambda_param = lambda_param

        # Initialize student policy
        self.student = PolicyNetwork(state_dim, action_dim)
        self.student_optimizer = torch.optim.Adam(self.student.parameters(), lr=lr)

        # Initialize critic population.
        self.critic_population = [CriticNetwork(state_dim) for _ in range(critic_population_size)]
        # -inf marks "never evaluated". It loses every comparison against a
        # measured fitness, which can be negative because of the KL penalty.
        self.critic_fitness = [float('-inf')] * critic_population_size
        # Generation at which each fitness estimate was measured. -1 ("never")
        # makes unevaluated lineages the stalest, so they are re-measured first.
        self.critic_last_evaluated = [-1] * critic_population_size

        # Tracking metrics
        self.generation = 0
        self.total_critic_evaluations = 0
        self.total_student_updates = 0
        self.computational_cost = 0.0

        # Performance buffers
        self.reward_buffer = deque(maxlen=100)
        self.l1_distance_buffer = deque(maxlen=100)
        self.modes_discovered = set()

    def _action_mask(self, valid_actions: List[int]) -> torch.Tensor:
        """Return a ``(1, action_dim)`` float mask with 1.0 at each valid action."""
        mask = torch.zeros(self.action_dim)
        mask[valid_actions] = 1.0
        return mask.unsqueeze(0)

    @staticmethod
    def _critic_guidance(critic_value: torch.Tensor, action_mask: torch.Tensor) -> torch.Tensor:
        """Turn a critic output into the target action distribution for distillation.

        NOTE(review): ``CriticNetwork`` emits one scalar per state (shape
        ``(1, 1)`` here), so the softmax logits are ``v`` on every valid action
        and ``0`` on every invalid one. The result is uniform over the valid
        actions for any ``v``. It also puts nonzero mass on actions that the
        student masks with a -1e9 logit, which makes KL terms against it roughly
        ``p_invalid * 1e9`` in states with an invalid action. This is kept
        unchanged to preserve the experiment's definition; the guidance design
        should be revisited.
        """
        return F.softmax(critic_value * action_mask, dim=1)

    def evaluate_critic_fitness(self, critic_idx: int, env: HypergridEnvironment,
                               num_trajectories: int = 50) -> float:
        """Estimate the distillation-aware fitness of one critic.

        Rolls out ``num_trajectories`` episodes in ``env``. Actions are sampled
        from the *student*. At every step, the KL divergence between the critic's
        guidance and the student's action distribution is accumulated.

        Args:
            critic_idx: Index into ``self.critic_population``.
            env: Environment for the rollouts; reset before each episode.
            num_trajectories: Number of rollouts to average over.

        Returns:
            ``avg_reward - lambda_param * avg_teachability``, both averaged per
            trajectory. Teachability is the per-trajectory sum of step KLs. The
            call also adds its elapsed time to ``computational_cost`` and
            increments ``total_critic_evaluations``.
        """
        critic = self.critic_population[critic_idx]
        total_reward = 0.0
        teachability_loss = 0.0

        start_time = time.time()

        # Evaluation only. Skipping autograd avoids building a graph per step,
        # which would inflate exactly the cost this experiment measures.
        with torch.no_grad():
            for _ in range(num_trajectories):
                env.reset()
                trajectory_length = 0

                while True:
                    state = env.get_state_tensor().unsqueeze(0)
                    action_mask = self._action_mask(env.get_valid_actions())

                    critic_guidance = self._critic_guidance(critic(state), action_mask)
                    student_log_probs = self.student(state, action_mask)
                    teachability_loss += F.kl_div(
                        student_log_probs, critic_guidance, reduction='batchmean'
                    ).item()

                    # NOTE(review): the critic never selects actions, so the
                    # reward term does not depend on the critic being scored.
                    action = torch.multinomial(torch.exp(student_log_probs), 1).item()
                    _, reward, done, _ = env.step(action)
                    total_reward += reward
                    trajectory_length += 1

                    if done or trajectory_length > self.max_trajectory_length:
                        break

        # Compute distillation-aware fitness
        avg_reward = total_reward / num_trajectories
        avg_teachability = teachability_loss / num_trajectories
        fitness = avg_reward - self.lambda_param * avg_teachability

        # Track computational cost
        self.computational_cost += time.time() - start_time
        self.total_critic_evaluations += 1

        return fitness

    def evolutionary_step(self, env: Optional[HypergridEnvironment] = None):
        """Run one generation: re-measure the stalest critics, then evolve.

        Args:
            env: Environment for the fitness rollouts. Defaults to a fresh
                ``HypergridEnvironment(ndim=self.state_dim)``, so the state size
                matches the networks. The previous default always used ndim=4.
                NOTE(review): the default still uses the environment's default
                height and rewards; pass ``env`` to match the training setup.
        """
        population_size = len(self.critic_population)
        num_to_evaluate = max(1, int(self.rho * population_size))

        # Stalest first: largest (generation - last evaluated). sorted() is
        # stable, so ties keep index order.
        critics_to_evaluate = sorted(
            range(population_size),
            key=lambda i: self.generation - self.critic_last_evaluated[i],
            reverse=True,
        )[:num_to_evaluate]

        if env is None:
            env = HypergridEnvironment(ndim=self.state_dim)
        for critic_idx in critics_to_evaluate:
            self.critic_fitness[critic_idx] = self.evaluate_critic_fitness(critic_idx, env)
            self.critic_last_evaluated[critic_idx] = self.generation

        # Evolutionary operations (selection + mutation)
        self.evolve_population()

        self.generation += 1

    def evolve_population(self):
        """Replace the population with mutated copies of tournament winners.

        Each offspring is its parent's parameters plus Gaussian noise with
        std 0.01. Offspring inherit the parent's fitness estimate and its
        last-evaluated generation. Critics that are not re-measured are then
        ranked by a (possibly stale) real estimate instead of a placeholder.
        """
        population_size = len(self.critic_population)
        # Tournaments of 3, shrunk for tiny populations (random.sample would raise).
        tournament_size = min(3, population_size)

        new_population = []
        new_fitness = []
        new_last_evaluated = []

        for _ in range(population_size):
            tournament_indices = random.sample(range(population_size), tournament_size)
            parent_idx = max(tournament_indices, key=lambda i: self.critic_fitness[i])
            parent = self.critic_population[parent_idx]

            offspring = CriticNetwork(self.state_dim)
            with torch.no_grad():
                for child_param, parent_param in zip(offspring.parameters(), parent.parameters()):
                    noise = torch.randn_like(parent_param) * 0.01
                    child_param.copy_(parent_param + noise)

            new_population.append(offspring)
            # Previously every offspring got 0.0 here. Right after evolution the
            # whole population then tied, and update_student always picked index 0.
            new_fitness.append(self.critic_fitness[parent_idx])
            new_last_evaluated.append(self.critic_last_evaluated[parent_idx])

        self.critic_population = new_population
        self.critic_fitness = new_fitness
        self.critic_last_evaluated = new_last_evaluated

    def update_student(self, env: HypergridEnvironment, num_trajectories: int = 100):
        """Take one Adam step distilling the highest-fitness critic into the student.

        The KL loss is summed over every step of ``num_trajectories`` on-policy
        rollouts, averaged over trajectories, and back-propagated once.

        Args:
            env: Environment for the rollouts; reset before each episode.
            num_trajectories: Number of rollouts contributing to the loss.

        Returns:
            The averaged distillation loss as a Python float.

        Raises:
            ValueError: If ``num_trajectories`` is not positive.
        """
        if num_trajectories <= 0:
            raise ValueError("num_trajectories must be positive")

        best_critic_idx = max(range(len(self.critic_fitness)),
                              key=lambda i: self.critic_fitness[i])
        best_critic = self.critic_population[best_critic_idx]

        total_loss = 0.0
        start_time = time.time()

        for _ in range(num_trajectories):
            env.reset()
            trajectory_loss = 0.0
            trajectory_length = 0

            while True:
                state = env.get_state_tensor().unsqueeze(0)
                action_mask = self._action_mask(env.get_valid_actions())

                # The critic is a fixed target; no gradients flow into it.
                with torch.no_grad():
                    critic_guidance = self._critic_guidance(best_critic(state), action_mask)

                student_log_probs = self.student(state, action_mask)
                trajectory_loss += F.kl_div(student_log_probs, critic_guidance, reduction='batchmean')

                action = torch.multinomial(torch.exp(student_log_probs), 1).item()
                _, _, done, _ = env.step(action)
                trajectory_length += 1

                # Cap on steps taken. The old check, `len(env.state) > 50`,
                # compared the grid's dimensionality and so could never fire.
                if done or trajectory_length > self.max_trajectory_length:
                    break

            total_loss += trajectory_loss

        avg_loss = total_loss / num_trajectories
        self.student_optimizer.zero_grad()
        avg_loss.backward()
        self.student_optimizer.step()

        self.computational_cost += time.time() - start_time
        self.total_student_updates += 1

        return avg_loss.item()

    def evaluate_performance(self, env: HypergridEnvironment, num_episodes: int = 100):
        """Sample episodes from the student (no gradients) and summarize them.

        Args:
            env: Environment to evaluate in; reset before each episode.
            num_episodes: Number of episodes to sample.

        Returns:
            A dict with these keys:

            * ``avg_reward``
            * ``reward_std``: population std via ``np.std``
            * ``l1_distance``: see the note in the body
            * ``modes_discovered``: count of distinct modes
            * ``modes_found``: set of terminal coordinate tuples whose reward
              reached ``env.R2``
        """
        episode_rewards = []
        modes_found = set()

        for _ in range(num_episodes):
            env.reset()
            episode_reward = 0.0

            while True:
                state = env.get_state_tensor().unsqueeze(0)
                action_mask = self._action_mask(env.get_valid_actions())

                with torch.no_grad():
                    student_log_probs = self.student(state, action_mask)
                    action = torch.multinomial(torch.exp(student_log_probs), 1).item()

                _, reward, done, _ = env.step(action)
                episode_reward += reward

                if done:
                    # Generalizes the old hard-coded `reward > 1.5` (R2 = 2.0
                    # by default) to the environment's own top reward.
                    if reward >= env.R2:
                        modes_found.add(tuple(env.state))
                    break

            episode_rewards.append(episode_reward)

        avg_reward = sum(episode_rewards) / num_episodes
        reward_std = np.std(episode_rewards)

        # NOTE(review): despite its name, this is the gap between the average
        # reward and the best terminal reward env.R2 (previously hard-coded
        # as 2.0). It is not an L1 distance between distributions; the key is
        # kept for logging compatibility.
        l1_distance = max(0.0, env.R2 - avg_reward)

        return {
            'avg_reward': avg_reward,
            'reward_std': reward_std,
            'l1_distance': l1_distance,
            'modes_discovered': len(modes_found),
            'modes_found': modes_found
        }

def run_amortized_experiment(config: Dict):
    """Train and evaluate one A-DATE-GFN configuration, logging to Weights & Biases.

    Args:
        config: Experiment settings. Keys read here are ``ndim``, ``height``,
            ``population_size``, ``rho``, ``M``, ``lambda_param``,
            ``num_generations`` and ``eval_frequency``. The whole dict is also
            passed to ``wandb.init`` as the run config.

    Returns:
        ``(final_performance, total_time)``: the metrics dict from a
        500-episode ``evaluate_performance`` call, and the run's wall-clock
        time in seconds.
    """
    run_name = f"A-DATE-GFN (ρ={config['rho']}, M={config['M']})"
    wandb.init(
        project="DATE_GFN_Computational_Efficiency",
        name=run_name,
        config=config
    )

    # Always close the wandb run, even if training raises.
    try:
        env = HypergridEnvironment(ndim=config['ndim'], height=config['height'])
        agent = AmortizedDATEGFN(
            state_dim=config['ndim'],
            action_dim=config['ndim'] + 1,  # +1 for stop action
            critic_population_size=config['population_size'],
            rho=config['rho'],
            M=config['M'],
            lambda_param=config['lambda_param']
        )

        print(f"🚀 Starting {run_name}")
        print(f"   Configuration: {config}")

        start_time = time.time()
        # Most recent distillation loss. It is logged at evaluation points even
        # when the student was not updated in that exact generation.
        student_loss = float('nan')

        for generation in range(config['num_generations']):
            agent.evolutionary_step()

            if generation % agent.M == 0:
                student_loss = agent.update_student(env)

            # Evaluate on a schedule independent of M. Previously this was
            # nested under the student-update branch, so evaluation only ran
            # when `generation` was a multiple of both M and eval_frequency
            # (e.g. every 30 generations for M=3 vs every 10 for M=5). That
            # also changed how much evaluation time each configuration's wall
            # clock included.
            if generation % config['eval_frequency'] == 0:
                performance = agent.evaluate_performance(env)

                wall_clock_time = time.time() - start_time
                performance_per_hour = performance['avg_reward'] / (wall_clock_time / 3600)

                wandb.log({
                    'generation': generation,
                    'avg_reward': performance['avg_reward'],
                    'reward_std': performance['reward_std'],
                    'l1_distance': performance['l1_distance'],
                    'modes_discovered': performance['modes_discovered'],
                    'student_loss': student_loss,
                    'computational_cost': agent.computational_cost,
                    'wall_clock_time': wall_clock_time,
                    'performance_per_hour': performance_per_hour,
                    'total_critic_evaluations': agent.total_critic_evaluations,
                    'total_student_updates': agent.total_student_updates,
                    'critic_evaluation_rate': agent.rho,
                    'student_update_frequency': agent.M
                })

                print(f"  Gen {generation:4d}: Reward={performance['avg_reward']:.3f}, "
                      f"L1={performance['l1_distance']:.3f}, Modes={performance['modes_discovered']}, "
                      f"Time={wall_clock_time:.1f}s, Perf/Hr={performance_per_hour:.2f}")

        # Final evaluation
        final_performance = agent.evaluate_performance(env, num_episodes=500)
        total_time = time.time() - start_time

        # With zero generations no cost is accrued; avoid ZeroDivisionError.
        cost = agent.computational_cost
        computational_efficiency = (
            final_performance['avg_reward'] / cost if cost > 0 else float('nan')
        )

        wandb.log({
            'final_avg_reward': final_performance['avg_reward'],
            'final_l1_distance': final_performance['l1_distance'],
            'final_modes_discovered': final_performance['modes_discovered'],
            'total_wall_clock_time': total_time,
            'final_performance_per_hour': final_performance['avg_reward'] / (total_time / 3600),
            'computational_efficiency': computational_efficiency,
            'total_flops_estimate': agent.total_critic_evaluations * 1e6 + agent.total_student_updates * 5e5
        })

        print(f"✅ {run_name} completed in {total_time:.1f}s")
        print(f"   Final Performance: Reward={final_performance['avg_reward']:.3f}, "
              f"L1={final_performance['l1_distance']:.3f}, Modes={final_performance['modes_discovered']}")
    finally:
        wandb.finish()

    return final_performance, total_time

def _run_ablation(base_config: Dict):
    """Run the rho x M grid, print a summary table, and report the most efficient config.

    Efficiency is the final average reward per wall-clock second.
    """
    rho_values = [0.3, 0.5, 0.7, 1.0]
    M_values = [1, 3, 5, 10]

    results = []

    for rho in rho_values:
        for M in M_values:
            config = {**base_config, 'rho': rho, 'M': M}
            performance, wall_time = run_amortized_experiment(config)

            results.append({
                'rho': rho,
                'M': M,
                'performance': performance,
                'wall_time': wall_time,
                'efficiency': performance['avg_reward'] / wall_time
            })

    # Print summary
    print("\n" + "="*80)
    print("RQ1 COMPUTATIONAL EFFICIENCY ABLATION SUMMARY")
    print("="*80)
    # The last column is reward per *second*. It used to be labelled 'Perf/Hr',
    # contradicting the "performance/second" line printed below.
    print(f"{'Config':<15} {'Avg Reward':<12} {'L1 Dist':<10} {'Modes':<8} {'Time(s)':<10} {'Perf/s':<10}")
    print("-"*80)

    for result in results:
        config_str = f"ρ={result['rho']},M={result['M']}"
        perf = result['performance']
        print(f"{config_str:<15} {perf['avg_reward']:<12.3f} {perf['l1_distance']:<10.3f} "
              f"{perf['modes_discovered']:<8d} {result['wall_time']:<10.1f} "
              f"{result['efficiency']:<10.3f}")

    # Find best configuration
    best_result = max(results, key=lambda x: x['efficiency'])
    print(f"\n🏆 Best Configuration: ρ={best_result['rho']}, M={best_result['M']}")
    print(f"   Efficiency: {best_result['efficiency']:.3f} performance/second")


def _run_comparison(base_config: Dict, rho: float = 0.5, M: int = 5):
    """Compare standard DATE-GFN (rho=1.0, M=1) against A-DATE-GFN with the given rho/M.

    Prints both runs' metrics, then speedup, performance retention and
    efficiency gain, and checks the >=95% retention / >=2x speedup criterion.
    """
    print("Running comparison between Standard DATE-GFN and A-DATE-GFN...")

    # Standard DATE-GFN: re-evaluate every critic and update the student every generation.
    standard_config = {**base_config, 'rho': 1.0, 'M': 1}
    standard_perf, standard_time = run_amortized_experiment(standard_config)

    # Amortized DATE-GFN (defaults 0.5 / 5 match the previous hard-coded values).
    amortized_config = {**base_config, 'rho': rho, 'M': M}
    amortized_perf, amortized_time = run_amortized_experiment(amortized_config)

    # Comparison summary
    print("\n" + "="*60)
    print("STANDARD vs AMORTIZED DATE-GFN COMPARISON")
    print("="*60)
    print("Standard DATE-GFN:")
    print(f"  Reward: {standard_perf['avg_reward']:.3f}")
    print(f"  L1 Distance: {standard_perf['l1_distance']:.3f}")
    print(f"  Modes: {standard_perf['modes_discovered']}")
    print(f"  Time: {standard_time:.1f}s")
    print(f"  Efficiency: {standard_perf['avg_reward']/standard_time:.3f}")

    print("\nAmortized DATE-GFN:")
    print(f"  Reward: {amortized_perf['avg_reward']:.3f}")
    print(f"  L1 Distance: {amortized_perf['l1_distance']:.3f}")
    print(f"  Modes: {amortized_perf['modes_discovered']}")
    print(f"  Time: {amortized_time:.1f}s")
    print(f"  Efficiency: {amortized_perf['avg_reward']/amortized_time:.3f}")

    speedup = standard_time / amortized_time
    performance_ratio = amortized_perf['avg_reward'] / standard_perf['avg_reward']
    efficiency_gain = (amortized_perf['avg_reward']/amortized_time) / (standard_perf['avg_reward']/standard_time)

    print(f"\n📊 RESULTS:")
    print(f"  Speedup: {speedup:.2f}x")
    print(f"  Performance Retention: {performance_ratio:.1%}")
    print(f"  Efficiency Gain: {efficiency_gain:.2f}x")

    if performance_ratio >= 0.95 and speedup >= 2.0:
        print("✅ SUCCESS: A-DATE-GFN achieves 95%+ performance in 2x+ speedup!")
    else:
        print("⚠️  Results do not meet success criteria")


def main():
    """Parse command-line arguments and launch the RQ1 experiments.

    Modes:
        single: one run with ``--rho`` / ``--M``.
        ablation: grid over rho in {0.3, 0.5, 0.7, 1.0} x M in {1, 3, 5, 10}.
        comparison: standard DATE-GFN (rho=1.0, M=1) vs A-DATE-GFN with
            ``--rho`` / ``--M``.
    """
    parser = argparse.ArgumentParser(description='RQ1: Computational Efficiency Experiments')
    parser.add_argument('--mode', choices=['single', 'ablation', 'comparison'], default='ablation')
    parser.add_argument('--rho', type=float, default=0.5,
                        help='Critic re-evaluation fraction (single and comparison modes)')
    parser.add_argument('--M', type=int, default=5,
                        help='Student update frequency (single and comparison modes)')
    args = parser.parse_args()

    base_config = {
        'ndim': 4,
        'height': 8,
        'population_size': 20,
        'num_generations': 200,
        'eval_frequency': 10,
        'lambda_param': 0.1
    }

    if args.mode == 'single':
        config = {**base_config, 'rho': args.rho, 'M': args.M}
        run_amortized_experiment(config)
    elif args.mode == 'ablation':
        _run_ablation(base_config)
    elif args.mode == 'comparison':
        _run_comparison(base_config, rho=args.rho, M=args.M)


if __name__ == "__main__":
    main()
