import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import random
from collections import deque

# DEAP imports
from deap import base, creator, tools, algorithms
import array


# ---------- Environment: Rastrigin ----------
def rastrigin(x):
    """Evaluate the Rastrigin function ``10 * n + sum(x_i^2 - 10 * cos(2*pi*x_i))``.

    Every term is non-negative, so the minimum value is 0, reached at the origin.

    Args:
        x: sequence of coordinates.

    Returns:
        The Rastrigin value of ``x`` (lower is better).
    """
    amplitude = 10
    accumulated = 0
    for coord in x:
        accumulated += coord ** 2 - amplitude * np.cos(2 * np.pi * coord)
    return amplitude * len(x) + accumulated


# DEAP fitness function wrapper
def evaluate_individual(individual):
    """DEAP evaluation hook returning the Rastrigin value as a 1-tuple.

    DEAP fitness values are tuples (one entry per objective), so the single
    objective is still wrapped in a tuple.
    """
    score = rastrigin(individual)
    return (score,)


# ---------- Improved RL Policy Network ----------
class ImprovedChildPolicy(nn.Module):
    """Gaussian policy mapping a (parents, population statistics) state to a child genome.

    Expected state layout (last dimension, ``2 * dim + 6`` features)::

        [parent1 (dim), parent2 (dim), parent1/parent2 fitness (2),
         population best/worst/mean/std (4)]

    A shared MLP trunk feeds two heads:

    * a per-gene mean, tanh-squashed and scaled into [-MEAN_BOUND, MEAN_BOUND];
    * a per-gene standard deviation, sigmoid-squashed into
      [STD_MIN, STD_MIN + STD_SCALE] = [0.01, 1.91].

    NOTE(review): the trunk contains Dropout, so outputs are stochastic while the
    module is in train mode; call ``eval()`` for deterministic means.
    """

    # Output scaling constants. They are class attributes, so they do not appear
    # in state_dict().
    MEAN_BOUND = 5.12  # half-width of the search box the children are clamped to
    STD_SCALE = 1.9
    STD_MIN = 0.01

    def __init__(self, dim, hidden_dim=128):
        """
        Args:
            dim: genome length (number of genes per parent/child).
            hidden_dim: trunk width; the heads use ``hidden_dim // 2`` and
                ``hidden_dim // 4`` units.
        """
        super().__init__()
        self.dim = dim

        # Input: parent1 + parent2 + parent fitness + population stats
        input_dim = dim * 2 + 2 + 4  # parents + parent_fits + [best, worst, mean, std]

        # Shared layers
        self.shared_net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU()
        )

        # Separate heads for mean and std
        self.mean_net = nn.Sequential(
            nn.Linear(hidden_dim // 2, hidden_dim // 4),
            nn.ReLU(),
            nn.Linear(hidden_dim // 4, dim),
            nn.Tanh()  # output in [-1, 1], rescaled in forward()
        )

        self.std_net = nn.Sequential(
            nn.Linear(hidden_dim // 2, hidden_dim // 4),
            nn.ReLU(),
            nn.Linear(hidden_dim // 4, dim),
            nn.Sigmoid()  # output in [0, 1], rescaled in forward()
        )

    def forward(self, state):
        """Return ``(mean, std)``, each with last dimension ``dim``."""
        shared_features = self.shared_net(state)

        # Mean scaled to the search bounds [-5.12, 5.12].
        mean = self.mean_net(shared_features) * self.MEAN_BOUND

        # Std in [0.01, 1.91]: sigmoid output * 1.9 + 0.01. The floor keeps the
        # Normal distribution non-degenerate.
        std = self.std_net(shared_features) * self.STD_SCALE + self.STD_MIN

        return mean, std

    def sample_action(self, state):
        """Sample a child genome from the policy.

        Returns:
            ``(action, log_prob, entropy)``. ``action`` is an unclamped sample with
            the same shape as ``mean``. ``log_prob`` and ``entropy`` are summed over
            the last (gene) dimension.
        """
        mean, std = self.forward(state)
        dist = torch.distributions.Normal(mean, std)
        action = dist.sample()
        log_prob = dist.log_prob(action).sum(dim=-1)
        entropy = dist.entropy().sum(dim=-1)
        return action, log_prob, entropy


# ---------- Improved Value Network ----------
class ImprovedValueNetwork(nn.Module):
    """MLP critic that predicts a scalar value for each state.

    It takes the same ``2 * dim + 6`` feature layout as ``ImprovedChildPolicy``.
    The trailing singleton output dimension is squeezed away, so a batch of
    shape ``(B, 2*dim + 6)`` yields values of shape ``(B,)``.
    """

    def __init__(self, dim, hidden_dim=128):
        super().__init__()
        in_features = 2 * dim + 2 + 4
        half_width = hidden_dim // 2
        layers = [
            nn.Linear(in_features, hidden_dim), nn.ReLU(), nn.Dropout(0.1),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.Dropout(0.1),
            nn.Linear(hidden_dim, half_width), nn.ReLU(),
            nn.Linear(half_width, 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, state):
        raw = self.net(state)
        return raw.squeeze(-1)


# ---------- Improved State Representation ----------
def create_state(parent1, parent2, eww, fitnesses):
    """Build the flat float32 feature vector consumed by the policy and value nets.

    Layout: genes of ``parent1``, genes of ``parent2``, the Rastrigin value of each
    parent, then ``[min, max, mean, std]`` of ``fitnesses``.

    Args:
        parent1, parent2: gene sequences.
        eww: not read by this function. NOTE(review): callers pass the
            population here positionally; the name looks like a leftover.
        fitnesses: population fitness values used for the summary statistics.

    Returns:
        1-D ``torch.float32`` tensor of length ``len(parent1) + len(parent2) + 6``.
    """
    parent_scores = [rastrigin(parent1), rastrigin(parent2)]
    summary = [reducer(fitnesses) for reducer in (np.min, np.max, np.mean, np.std)]
    features = [*parent1, *parent2, *parent_scores, *summary]
    return torch.tensor(features, dtype=torch.float32)


# ---------- Improved Reward Function ----------
def compute_reward(child, parent1, parent2, population, fitnesses):
    """Score a generated child as a weighted sum of three reward terms.

    Lower Rastrigin values are better. The components are:

    * improvement (weight 2.0): relative gain of the child over the better
      parent, ``(best_parent - child) / best_parent``;
    * ranking (weight 1.0): fraction of ``fitnesses`` the child strictly beats,
      minus 0.5, so it lies in [-0.5, 0.5];
    * diversity (weight 0.5): Euclidean distance from the child to its nearest
      member of ``population``, divided by 5 and capped at 1.0.

    Args:
        child, parent1, parent2: gene vectors of equal length.
        population: sequence of gene vectors with the same length as ``child``.
        fitnesses: fitness values of the population (used for the ranking term).

    Returns:
        ``(total_reward, info)``. ``info`` holds the individual components under
        'improvement', 'ranking' and 'diversity', plus 'child_fitness'.
    """
    child_fitness = rastrigin(child)
    parent1_fitness = rastrigin(parent1)
    parent2_fitness = rastrigin(parent2)

    # Component 1: relative improvement over the better parent. Rastrigin is >= 0
    # with value 0 at the optimum. Floor the denominator so a parent on (or
    # numerically at) the optimum cannot produce inf/NaN. Such a value would
    # poison the batch-wide advantage normalization during training.
    best_parent_fitness = min(parent1_fitness, parent2_fitness)
    improvement_reward = (best_parent_fitness - child_fitness) / max(best_parent_fitness, 1e-8)

    # Component 2: population ranking, centered around 0.
    better_than_count = int(np.count_nonzero(np.asarray(fitnesses) > child_fitness))
    ranking_reward = better_than_count / len(fitnesses) - 0.5

    # Component 3: diversity bonus from the distance to the nearest neighbour,
    # computed for all population members in one vectorized pass.
    offsets = np.asarray(population, dtype=float) - np.asarray(child, dtype=float)
    min_distance = float(np.min(np.linalg.norm(offsets, axis=1)))
    diversity_reward = min(min_distance / 5.0, 1.0)  # Cap at 1.0

    # Combined reward with weights
    total_reward = (2.0 * improvement_reward +
                    1.0 * ranking_reward +
                    0.5 * diversity_reward)

    return total_reward, {
        'improvement': improvement_reward,
        'ranking': ranking_reward,
        'diversity': diversity_reward,
        'child_fitness': child_fitness
    }


# ---------- Experience Buffer ----------
class ReplayBuffer:
    """Fixed-capacity FIFO store of transitions for the SPO update.

    Each transition is ``(state, action, reward, log_prob, entropy, value)``.
    The underlying ``deque(maxlen=capacity)`` evicts the oldest transition once
    ``capacity`` entries are stored.
    """

    def __init__(self, capacity=10000):
        self.buffer = deque(maxlen=capacity)
        self.capacity = capacity

    def push(self, state, action, reward, log_prob, entropy, value):
        """Append one transition, evicting the oldest if the buffer is full."""
        transition = (state, action, reward, log_prob, entropy, value)
        self.buffer.append(transition)

    def sample(self, batch_size):
        """Draw up to ``batch_size`` distinct transitions uniformly at random.

        Returns:
            A 6-tuple: states, actions, rewards, log-probs, entropies, values.
            Tensor fields are stacked along a new leading dimension. Rewards are
            converted to a float32 tensor.
        """
        draw = min(batch_size, len(self.buffer))
        picked = random.sample(self.buffer, draw)
        states, actions, rewards, log_probs, entropies, values = zip(*picked)
        reward_tensor = torch.tensor(rewards, dtype=torch.float32)
        stacked = tuple(
            torch.stack(column)
            for column in (states, actions, log_probs, entropies, values)
        )
        return (stacked[0], stacked[1], reward_tensor, *stacked[2:])

    def __len__(self):
        return len(self.buffer)


# ---------- Improved SPO Update ----------
def improved_spo_update(policy, value_net, optimizer_policy, optimizer_value,
                        buffer, batch_size=256, epochs=4, epsilon=0.2):
    """Run ``epochs`` SPO policy steps and critic regression steps on one replay batch.

    The advantage is the stored one-step reward minus the critic's current
    estimate, standardized over the batch. The policy maximizes the SPO
    objective ``r * A - |A| / (2 * epsilon) * (r - 1)^2`` plus a 0.01 entropy
    bonus, where ``r`` is the probability ratio against the stored log-probs.
    The critic regresses onto the raw rewards with MSE. Both gradient norms are
    clipped to 0.5.

    Returns:
        ``(mean_policy_loss, mean_value_loss)`` averaged over the epochs (the
        policy loss includes the entropy term). Returns ``(0, 0)`` without
        touching the networks when ``buffer`` holds fewer than ``batch_size``
        transitions.
    """
    if len(buffer) < batch_size:
        return 0, 0

    states, actions, rewards, old_log_probs, _entropies, _old_values = buffer.sample(batch_size)

    # Standardized advantages against the critic's current baseline.
    with torch.no_grad():
        baseline = value_net(states)
        adv = rewards - baseline
        adv = (adv - adv.mean()) / (adv.std() + 1e-8)

    targets = rewards
    # Stored log-probs may still reference their sampling graph; cut it off.
    frozen_old = old_log_probs.detach()
    penalty_scale = adv.abs() / (2 * epsilon)

    policy_history = []
    value_history = []

    for _ in range(epochs):
        mu, sigma = policy(states)
        normal = torch.distributions.Normal(mu, sigma)
        log_probs = normal.log_prob(actions).sum(dim=-1)
        entropy = normal.entropy().sum(dim=-1)

        ratio = torch.exp(log_probs - frozen_old)
        # SPO surrogate: advantage-weighted ratio minus a quadratic trust-region penalty.
        surrogate = ratio * adv - penalty_scale * (ratio - 1).pow(2)
        loss = -surrogate.mean() + (-0.01 * entropy.mean())

        optimizer_policy.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(policy.parameters(), 0.5)
        optimizer_policy.step()

        predicted = value_net(states)
        critic_loss = F.mse_loss(predicted, targets)

        optimizer_value.zero_grad()
        critic_loss.backward()
        torch.nn.utils.clip_grad_norm_(value_net.parameters(), 0.5)
        optimizer_value.step()

        policy_history.append(loss.item())
        value_history.append(critic_loss.item())

    return sum(policy_history) / epochs, sum(value_history) / epochs


# ---------- Training Function ----------
def train_improved_rl(dim=10, training_episodes=2000):
    """Train the child-generating policy and its value baseline on random Rastrigin populations.

    Each episode does the following:

    * draws a fresh population of 30-100 points uniformly from [-5.12, 5.12]^dim;
    * produces 50 children from selected parent pairs (70% tournament selection
      of size 5, 30% uniform random);
    * pushes each transition into a replay buffer;
    * calls ``improved_spo_update`` once the buffer holds more than 200
      transitions. With its default batch size of 256, that call is a no-op
      until the buffer reaches 256.

    Args:
        dim: genome length.
        training_episodes: number of episodes to run.

    Returns:
        ``(policy, value_net, reward_history)``, where ``reward_history[i]`` is the
        mean reward collected during episode ``i``.
    """
    policy = ImprovedChildPolicy(dim)
    value_net = ImprovedValueNetwork(dim)

    optimizer_policy = optim.Adam(policy.parameters(), lr=1e-4, weight_decay=1e-5)
    optimizer_value = optim.Adam(value_net.parameters(), lr=1e-4, weight_decay=1e-5)

    buffer = ReplayBuffer(capacity=10000)

    print("Training Improved RL Agent...")

    reward_history = []

    for episode in range(training_episodes):
        # Generate diverse populations
        pop_size = random.randint(30, 100)
        population = [np.random.uniform(-5.12, 5.12, dim) for _ in range(pop_size)]
        fitnesses = np.array([rastrigin(ind) for ind in population])

        episode_rewards = []
        episode_info = []

        # Generate experience
        for _ in range(50):
            # Parent selection: 70% tournament (exploitation), 30% uniform (diversity).
            if random.random() < 0.7:
                parent1 = tournament_selection_classic(population, fitnesses, 5)
                parent2 = tournament_selection_classic(population, fitnesses, 5)
            else:
                parent1 = random.choice(population)
                parent2 = random.choice(population)

            state = create_state(parent1, parent2, population, fitnesses)
            batched_state = state.unsqueeze(0)

            # Rollouts only collect data. improved_spo_update recomputes log-probs
            # and values with gradients and detaches the stored ones, so recording
            # autograd graphs here would just keep one graph alive per buffered
            # transition (up to the buffer capacity).
            with torch.no_grad():
                action, log_prob, entropy = policy.sample_action(batched_state)
            action = action.squeeze(0)

            # Ensure bounds
            child = torch.clamp(action, -5.12, 5.12).numpy()

            reward, info = compute_reward(child, parent1, parent2, population, fitnesses)

            # Value estimate stored alongside the transition (graph-free, see above).
            with torch.no_grad():
                value = value_net(batched_state).squeeze(0)

            buffer.push(state, action, reward, log_prob.squeeze(0),
                        entropy.squeeze(0), value)

            episode_rewards.append(reward)
            episode_info.append(info)

        avg_reward = np.mean(episode_rewards)
        reward_history.append(avg_reward)

        # Update networks
        if len(buffer) > 200:
            policy_loss, value_loss = improved_spo_update(
                policy, value_net, optimizer_policy, optimizer_value, buffer
            )

            if episode % 500 == 0:
                avg_improvement = np.mean([info['improvement'] for info in episode_info])
                avg_ranking = np.mean([info['ranking'] for info in episode_info])
                print(f"Episode {episode}: Reward={avg_reward:.4f}, "
                      f"Improvement={avg_improvement:.4f}, Ranking={avg_ranking:.4f}, "
                      f"PolicyLoss={policy_loss:.4f}, ValueLoss={value_loss:.4f}")

    return policy, value_net, reward_history


# ---------- Classic Tournament Selection (for RL training) ----------
def tournament_selection_classic(population, fitnesses, tournament_size=3):
    """Return the lowest-fitness member of a random tournament.

    ``min(tournament_size, len(population))`` distinct indices are drawn without
    replacement. ``np.argmin`` resolves ties in favour of the contestant drawn
    first.

    Args:
        population: indexable collection of individuals.
        fitnesses: fitness per individual, aligned with ``population``.
        tournament_size: requested number of contestants.
    """
    entrants = min(tournament_size, len(population))
    contestants = random.sample(range(len(population)), entrants)
    scores = [fitnesses[idx] for idx in contestants]
    champion = contestants[np.argmin(scores)]
    return population[champion]


# ---------- DEAP Setup ----------
def setup_deap(dim):
    """Build a DEAP toolbox for minimizing Rastrigin over [-5.12, 5.12]^dim.

    Registers:

    * individual/population generators (genes drawn uniformly from the box);
    * the Rastrigin evaluator;
    * uniform crossover (indpb=0.5);
    * Gaussian mutation (sigma=0.1, indpb=0.1);
    * size-3 tournament selection.

    NOTE(review): the registered mate/mutate operators do not clip to the box.
    ``standard_crossover_mutation`` clips explicitly, but the eaSimple baseline
    does not.

    Args:
        dim: genome length.

    Returns:
        The configured ``base.Toolbox``.
    """
    # creator.create installs classes as attributes of the global deap.creator
    # module. Create them only once: re-creating on every call makes DEAP emit a
    # RuntimeWarning and swaps the class out from under individuals built by an
    # earlier call.
    if not hasattr(creator, "FitnessMin"):
        creator.create("FitnessMin", base.Fitness, weights=(-1.0,))  # Minimization
    if not hasattr(creator, "Individual"):
        creator.create("Individual", array.array, typecode='d', fitness=creator.FitnessMin)

    # Initialize toolbox
    toolbox = base.Toolbox()

    # Individual and population generators
    toolbox.register("attr_float", random.uniform, -5.12, 5.12)
    toolbox.register("individual", tools.initRepeat, creator.Individual,
                     toolbox.attr_float, n=dim)
    toolbox.register("population", tools.initRepeat, list, toolbox.individual)

    # Genetic operators
    toolbox.register("evaluate", evaluate_individual)
    toolbox.register("mate", tools.cxUniform, indpb=0.5)  # Uniform crossover
    toolbox.register("mutate", tools.mutGaussian, mu=0, sigma=0.1, indpb=0.1)
    toolbox.register("select", tools.selTournament, tournsize=3)

    return toolbox


# ---------- RL-Guided Crossover Function ----------
def rl_crossover(policy, ind1, ind2, population, fitnesses):
    """Generate one child from two parents by sampling the trained policy.

    Args:
        policy: object exposing ``sample_action(state_batch)`` (e.g.
            ``ImprovedChildPolicy``).
        ind1, ind2: parent gene sequences.
        population: current population; forwarded to ``create_state`` as-is.
        fitnesses: population fitness values, used for the state's population
            statistics.

    Returns:
        A new ``creator.Individual`` whose genes are the sampled action clamped to
        [-5.12, 5.12]. No fitness is assigned, so the caller must evaluate it.
    """
    # create_state does not read its population argument, so pass it through
    # unchanged. Copying every individual into a list on each call made every
    # generation O(pop_size^2 * dim) for nothing.
    state = create_state(ind1, ind2, population, fitnesses)

    # Sample action from policy
    with torch.no_grad():
        action, _, _ = policy.sample_action(state.unsqueeze(0))
        child_values = torch.clamp(action.squeeze(0), -5.12, 5.12).numpy()

    return creator.Individual(child_values)


# ---------- Standard Crossover + Mutation ----------
def standard_crossover_mutation(toolbox, parent1, parent2):
    """Produce two children with the toolbox's crossover and mutation operators.

    The parents are cloned first and are left untouched. The two clones are
    mated, each is mutated, and every gene is clipped to [-5.12, 5.12]. Their
    fitness values are then deleted so the caller re-evaluates them.

    Returns:
        ``(child1, child2)``.
    """
    first = toolbox.clone(parent1)
    second = toolbox.clone(parent2)

    toolbox.mate(first, second)
    for kid in (first, second):
        toolbox.mutate(kid)

    def _clamp(gene):
        return np.clip(gene, -5.12, 5.12)

    # The registered operators may push genes out of range; pull them back in.
    for idx in range(len(first)):
        first[idx] = _clamp(first[idx])
        second[idx] = _clamp(second[idx])

    # Invalidate the inherited fitness so these children are re-evaluated.
    for kid in (first, second):
        del kid.fitness.values

    return first, second


# ---------- RL-Guided GA with DEAP ----------
def improved_rl_ga_deap(policy, dim=10, pop_size=50, generations=150, rl_ratio=1.0):
    """RL-guided GA using the DEAP framework.

    Each generation does the following:

    * keeps the best ``max(2, int(0.15 * pop_size))`` individuals (elitism);
    * fills the rest of the population from a cloned, tournament-selected
      parent pool. With probability ``rl_ratio`` one child is sampled from
      ``policy`` (``rl_crossover``); otherwise two children come from
      ``standard_crossover_mutation``.

    The policy is switched to eval mode for the run. Its previous train/eval mode
    is restored afterwards, even if an exception occurs.

    Returns:
        ``(best_history, mean_history)``: best and mean fitness of the initial
        population and of the population after each generation, i.e.
        ``generations + 1`` entries each. This is the same convention as the DEAP
        eaSimple logbook used by the baseline.
    """
    toolbox = setup_deap(dim)

    # Create and evaluate the initial population.
    population = toolbox.population(n=pop_size)
    for ind, fit in zip(population, map(toolbox.evaluate, population)):
        ind.fitness.values = fit

    best_history = []
    mean_history = []

    def _record(pop):
        """Append best/mean fitness of *pop* to the histories; return its fitness list."""
        fits = [ind.fitness.values[0] for ind in pop]
        best_history.append(min(fits))
        mean_history.append(np.mean(fits))
        return fits

    was_training = policy.training
    policy.eval()
    try:
        for _ in range(generations):
            fits = _record(population)

            # Cloned tournament-selected parent pool.
            parents = [toolbox.clone(ind) for ind in toolbox.select(population, pop_size)]

            # Elitism: carry the best individuals over unchanged.
            elite_count = max(2, int(pop_size * 0.15))
            offspring = [toolbox.clone(ind) for ind in tools.selBest(population, elite_count)]

            while len(offspring) < pop_size:
                parent1 = random.choice(parents)
                parent2 = random.choice(parents)

                if random.random() < rl_ratio:
                    offspring.append(rl_crossover(policy, parent1, parent2, population, fits))
                else:
                    offspring.extend(standard_crossover_mutation(toolbox, parent1, parent2))

            # The standard branch adds two children and may overshoot by one.
            offspring = offspring[:pop_size]

            # Evaluate only the new (fitness-less) children; elites keep theirs.
            invalid_ind = [ind for ind in offspring if not ind.fitness.valid]
            for ind, fit in zip(invalid_ind, map(toolbox.evaluate, invalid_ind)):
                ind.fitness.values = fit

            population[:] = offspring
    finally:
        policy.train(was_training)

    # Record the population produced by the last generation too. Without this,
    # best_history[-1] reports the previous generation instead of the final one.
    _record(population)

    return best_history, mean_history


# ---------- Standard GA with DEAP ----------
def standard_ga_deap(dim=10, pop_size=50, generations=150):
    """Baseline GA using DEAP's ``algorithms.eaSimple``.

    Uses the operators registered by ``setup_deap`` with crossover probability
    0.7 and mutation probability 0.1.

    Returns:
        ``(best_history, mean_history)`` taken from the eaSimple logbook's "min"
        and "avg" statistics, one entry per logbook record.
    """
    toolbox = setup_deap(dim)

    # Statistics tracking
    stats = tools.Statistics(lambda ind: ind.fitness.values)
    stats.register("min", np.min)
    stats.register("avg", np.mean)

    # Create initial population
    population = toolbox.population(n=pop_size)

    # Only the logbook is needed; the final population is discarded.
    _, log = algorithms.eaSimple(
        population, toolbox,
        cxpb=0.7,  # Crossover probability
        mutpb=0.1,  # Mutation probability
        ngen=generations,
        stats=stats,
        verbose=False
    )

    # Extract best and mean fitness history
    best_history = [record['min'] for record in log]
    mean_history = [record['avg'] for record in log]

    return best_history, mean_history


# ---------- Main Execution ----------
if __name__ == "__main__":
    dim = 10
    generations = 150

    print("Training improved RL agent...")
    # NOTE: 100k episodes x 50 single-sample policy calls is a very long run;
    # lower training_episodes for quick experiments.
    policy, value_net, reward_history = train_improved_rl(dim=dim, training_episodes=100000)

    # Plot training progress
    plt.figure(figsize=(15, 5))
    plt.subplot(1, 3, 1)
    plt.plot(reward_history)
    plt.title("RL Training Progress")
    plt.xlabel("Episode")
    plt.ylabel("Average Reward")
    plt.grid(True, alpha=0.3)

    print("Running comparisons...")
    best_ga_deap, mean_ga_deap = standard_ga_deap(dim=dim, pop_size=50, generations=generations)
    best_rl_deap, mean_rl_deap = improved_rl_ga_deap(policy, dim=dim, pop_size=50, generations=generations)

    # Plot fitness comparison
    plt.subplot(1, 3, 2)
    plt.plot(best_ga_deap, label="Standard GA (DEAP)", color='blue', linewidth=2)
    plt.plot(best_rl_deap, label="RL-Guided GA (DEAP)", color='red', linewidth=2)
    plt.xlabel("Generation")
    plt.ylabel("Best Fitness")
    plt.title("Best Fitness Comparison")
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.yscale('log')

    # Plot mean fitness comparison
    plt.subplot(1, 3, 3)
    plt.plot(mean_ga_deap, label="Standard GA (DEAP)", color='blue', linewidth=2, alpha=0.7)
    plt.plot(mean_rl_deap, label="RL-Guided GA (DEAP)", color='red', linewidth=2, alpha=0.7)
    plt.xlabel("Generation")
    plt.ylabel("Mean Fitness")
    plt.title("Mean Fitness Comparison")
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.yscale('log')

    plt.tight_layout()

    # Print the numeric summary before plt.show(). With non-interactive backends,
    # show() blocks until the window is closed.
    print("\nResults (DEAP Implementation):")
    print(f"Standard GA - Best: {min(best_ga_deap):.6f}, Final: {best_ga_deap[-1]:.6f}")
    print(f"RL-Guided GA - Best: {min(best_rl_deap):.6f}, Final: {best_rl_deap[-1]:.6f}")

    ga_best = min(best_ga_deap)
    rl_best = min(best_rl_deap)
    if ga_best <= 0:
        # Relative change is undefined once the baseline has reached the global
        # optimum (Rastrigin's minimum is 0).
        print(f"RL minus GA best fitness: {rl_best - ga_best:.6f} "
              "(GA reached the optimum; relative change undefined)")
    elif rl_best < ga_best:
        improvement = ((ga_best - rl_best) / ga_best) * 100
        print(f"RL improvement: {improvement:.2f}%")
    else:
        degradation = ((rl_best - ga_best) / ga_best) * 100
        print(f"RL degradation: {degradation:.2f}%")

    plt.show()

    # Clean up DEAP classes for potential re-runs
    try:
        del creator.FitnessMin
        del creator.Individual
    except AttributeError:
        pass