import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import random
from collections import deque


# ---------- Environment: Rastrigin ----------
def rastrigin(x):
    """Evaluate the Rastrigin benchmark function at point ``x``.

    Computes ``A * n + sum(x_i ** 2 - A * cos(2 * pi * x_i))`` with
    ``A = 10`` and ``n = len(x)``.

    Args:
        x: Sized iterable of numeric coordinates.

    Returns:
        The function value at ``x``; ``0`` when ``x`` is empty.
    """
    amplitude = 10
    oscillating_sum = 0
    # Accumulate left to right so the floating-point result matches a plain sum().
    for coord in x:
        oscillating_sum += coord ** 2 - amplitude * np.cos(2 * np.pi * coord)
    return amplitude * len(x) + oscillating_sum


# ---------- Improved RL Policy Network ----------
class ImprovedChildPolicy(nn.Module):
    """Diagonal-Gaussian policy that proposes a child vector from a state.

    The state vector has ``dim * 2 + 2 + 4`` entries: two parents, their two
    fitness values and four population statistics (best, worst, mean, std).
    A shared trunk feeds two heads producing a per-gene mean and standard
    deviation.

    Args:
        dim: Number of genes in an individual (size of the produced action).
        hidden_dim: Width of the shared hidden layers.
        bound: The Tanh mean head is multiplied by this value, so means lie
            in ``[-bound, bound]``.
        std_min: Offset added to the scaled Sigmoid std head (lower limit).
        std_scale: Multiplier applied to the Sigmoid std head, so the std
            lies between ``std_min`` and ``std_min + std_scale``.
    """

    def __init__(self, dim, hidden_dim=128, bound=5.12, std_min=0.01, std_scale=1.9):
        super().__init__()
        self.dim = dim
        self.bound = bound
        self.std_min = std_min
        self.std_scale = std_scale

        # Input: parent1 + parent2 + parent fitness + population stats
        input_dim = dim * 2 + 2 + 4  # parents + parent_fits + [best, worst, mean, std]

        # Shared layers
        self.shared_net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU()
        )

        # Separate heads for mean and std
        self.mean_net = nn.Sequential(
            nn.Linear(hidden_dim // 2, hidden_dim // 4),
            nn.ReLU(),
            nn.Linear(hidden_dim // 4, dim),
            nn.Tanh()  # Bounded output in [-1, 1]
        )

        self.std_net = nn.Sequential(
            nn.Linear(hidden_dim // 2, hidden_dim // 4),
            nn.ReLU(),
            nn.Linear(hidden_dim // 4, dim),
            nn.Sigmoid()  # Output in [0, 1]; shifted by std_min in forward()
        )

    def forward(self, state):
        """Return ``(mean, std)`` of the action distribution for ``state``.

        Args:
            state: Tensor whose last dimension is ``dim * 2 + 2 + 4``.

        Returns:
            Tuple ``(mean, std)``, each with last dimension ``dim``.
        """
        shared_features = self.shared_net(state)

        # Mean scaled to the problem bounds: [-bound, bound]
        mean = self.mean_net(shared_features) * self.bound

        # Std between std_min and std_min + std_scale
        # (0.01 to 1.91 with the defaults)
        std = self.std_net(shared_features) * self.std_scale + self.std_min

        return mean, std

    def sample_action(self, state):
        """Sample an action and report its log-probability and entropy.

        Args:
            state: Tensor whose last dimension is ``dim * 2 + 2 + 4``.

        Returns:
            Tuple ``(action, log_prob, entropy)``. ``log_prob`` and
            ``entropy`` are summed over the last (gene) dimension. The
            sampled action is not clamped to the mean bounds.
        """
        mean, std = self.forward(state)
        dist = torch.distributions.Normal(mean, std)
        action = dist.sample()
        log_prob = dist.log_prob(action).sum(dim=-1)
        entropy = dist.entropy().sum(dim=-1)
        return action, log_prob, entropy


# ---------- Improved Value Network ----------
class ImprovedValueNetwork(nn.Module):
    """MLP mapping a state vector to a single scalar value estimate.

    Args:
        dim: Number of genes per individual; the expected input width is
            ``dim * 2 + 2 + 4``.
        hidden_dim: Width of the first two hidden layers (the third uses
            ``hidden_dim // 2``).
    """

    def __init__(self, dim, hidden_dim=128):
        super().__init__()
        # Two parents (dim each) + 2 parent fitnesses + 4 population statistics.
        n_inputs = dim * 2 + 2 + 4
        narrow = hidden_dim // 2

        layers = [
            nn.Linear(n_inputs, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, narrow),
            nn.ReLU(),
            nn.Linear(narrow, 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, state):
        """Return the value estimate with the trailing size-1 axis removed."""
        estimate = self.net(state)
        return estimate.squeeze(-1)


# ---------- Improved State Representation ----------
def create_state(parent1, parent2, population, fitnesses):
    """Build the network input vector for a pair of parents.

    Layout of the result: genes of ``parent1``, genes of ``parent2``, the
    Rastrigin fitness of each parent, then the population fitness
    ``[min, max, mean, std]``.

    Args:
        parent1: First parent (iterable of genes).
        parent2: Second parent (iterable of genes).
        population: Not used by this function.
        fitnesses: Fitness values of the current population.

    Returns:
        A 1-D ``torch.float32`` tensor with
        ``len(parent1) + len(parent2) + 6`` entries.
    """
    parent_fits = [rastrigin(parent1), rastrigin(parent2)]

    # Summary of the population: best (lowest), worst, mean, spread.
    summary = [
        np.min(fitnesses),
        np.max(fitnesses),
        np.mean(fitnesses),
        np.std(fitnesses),
    ]

    features = [*parent1, *parent2, *parent_fits, *summary]
    return torch.tensor(features, dtype=torch.float32)


# ---------- Improved Reward Function ----------
def compute_reward(child, parent1, parent2, population, fitnesses):
    """Score a child with a weighted sum of three reward components.

    Lower fitness is treated as better throughout.

    Components:
        improvement: ``(best_parent_fitness - child_fitness)`` relative to
            the best parent's fitness (weight 2.0).
        ranking: fraction of the population the child beats, minus 0.5
            (weight 1.0).
        diversity: distance to the nearest population member divided by 5,
            capped at 1.0 (weight 0.5).

    Args:
        child: Candidate solution vector.
        parent1: First parent vector.
        parent2: Second parent vector.
        population: Non-empty iterable of individuals used for the
            diversity component.
        fitnesses: Non-empty sequence of population fitness values used for
            the ranking component.

    Returns:
        Tuple ``(total_reward, info)`` where ``info`` holds the individual
        components and the child's fitness.
    """
    # Floor for the normalising denominator: a parent with fitness 0 would
    # otherwise make the improvement term inf/NaN and poison later updates.
    min_denominator = 1e-8

    child_fitness = rastrigin(child)
    parent1_fitness = rastrigin(parent1)
    parent2_fitness = rastrigin(parent2)

    # Component 1: Improvement over parents
    best_parent_fitness = min(parent1_fitness, parent2_fitness)
    improvement_reward = ((best_parent_fitness - child_fitness)
                          / max(best_parent_fitness, min_denominator))

    # Component 2: Population ranking reward
    better_than_count = sum(1 for f in fitnesses if child_fitness < f)
    ranking_reward = better_than_count / len(fitnesses) - 0.5  # Center around 0

    # Component 3: Diversity bonus (distance from nearest neighbor)
    child_vec = np.array(child)  # converted once instead of per individual
    min_distance = min(np.linalg.norm(child_vec - np.array(ind))
                       for ind in population)
    diversity_reward = min(min_distance / 5.0, 1.0)  # Cap at 1.0

    # Combined reward with weights
    total_reward = (2.0 * improvement_reward +
                    1.0 * ranking_reward +
                    0.5 * diversity_reward)

    return total_reward, {
        'improvement': improvement_reward,
        'ranking': ranking_reward,
        'diversity': diversity_reward,
        'child_fitness': child_fitness
    }


# ---------- Experience Buffer ----------
class ReplayBuffer:
    """Fixed-capacity FIFO store of transitions with uniform random sampling.

    Each transition is a tuple
    ``(state, action, reward, log_prob, entropy, value)``. Once ``capacity``
    is reached the oldest transition is dropped on each push.
    """

    def __init__(self, capacity=10000):
        self.capacity = capacity
        self.buffer = deque(maxlen=capacity)

    @staticmethod
    def _detached(item):
        """Return ``item`` cut off from its autograd graph if it is a tensor."""
        return item.detach() if isinstance(item, torch.Tensor) else item

    def push(self, state, action, reward, log_prob, entropy, value):
        """Append one transition.

        Tensors are detached before storage. Stored values are only ever
        used as constants, and keeping them attached would hold one autograd
        graph alive per buffered transition.
        """
        self.buffer.append((
            self._detached(state),
            self._detached(action),
            reward,
            self._detached(log_prob),
            self._detached(entropy),
            self._detached(value),
        ))

    def sample(self, batch_size):
        """Draw up to ``batch_size`` transitions without replacement.

        Returns:
            Tuple ``(states, actions, rewards, log_probs, entropies, values)``
            of stacked tensors; ``rewards`` is a ``float32`` tensor. If the
            buffer holds fewer than ``batch_size`` items, all are returned.
        """
        batch = random.sample(self.buffer, min(batch_size, len(self.buffer)))
        states, actions, rewards, log_probs, entropies, values = zip(*batch)
        return (torch.stack(states), torch.stack(actions),
                torch.tensor(rewards, dtype=torch.float32),
                torch.stack(log_probs), torch.stack(entropies), torch.stack(values))

    def __len__(self):
        return len(self.buffer)


# ---------- Improved SPO Update ----------
def improved_spo_update(policy, value_net, optimizer_policy, optimizer_value,
                        buffer, batch_size=256, epochs=4, epsilon=0.2):
    """Run several policy/value gradient steps on one sampled batch.

    The policy objective per sample is
    ``ratio * A - |A| / (2 * epsilon) * (ratio - 1) ** 2`` where ``ratio`` is
    ``exp(new_log_prob - old_log_prob)`` and ``A`` the advantage, plus an
    entropy bonus weighted by 0.01. The value network is regressed onto the
    rewards with an MSE loss. Both gradient norms are clipped to 0.5.

    Args:
        policy: Module returning ``(mean, std)`` for a batch of states.
        value_net: Module returning one value per state.
        optimizer_policy: Optimizer over ``policy`` parameters.
        optimizer_value: Optimizer over ``value_net`` parameters.
        buffer: Object supporting ``len()`` and ``sample(batch_size)`` that
            returns ``(states, actions, rewards, log_probs, entropies, values)``.
        batch_size: Number of transitions to sample.
        epochs: Number of gradient steps taken on the sampled batch.
        epsilon: Scale of the quadratic penalty on the probability ratio.

    Returns:
        ``(mean policy loss incl. entropy term, mean value loss)`` over the
        epochs, or ``(0, 0)`` when the buffer holds fewer than ``batch_size``
        transitions.
    """
    if len(buffer) < batch_size:
        return 0, 0

    # The stored entropies and value estimates are not needed here: both are
    # recomputed from the current networks.
    states, actions, rewards, old_log_probs, _, _ = buffer.sample(batch_size)
    old_log_probs = old_log_probs.detach()

    # Compute advantages
    with torch.no_grad():
        advantages = rewards - value_net(states)
        returns = rewards

        # Normalize advantages. With a single sample the (unbiased) std is
        # NaN, which would turn every loss and parameter into NaN, so
        # normalisation is skipped in that case.
        if advantages.numel() > 1:
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

    total_policy_loss = 0
    total_value_loss = 0

    for _ in range(epochs):
        # Get current policy outputs
        mean, std = policy(states)
        dist = torch.distributions.Normal(mean, std)
        new_log_probs = dist.log_prob(actions).sum(dim=-1)
        new_entropy = dist.entropy().sum(dim=-1)

        # SPO policy loss with quadratic penalty
        ratio = torch.exp(new_log_probs - old_log_probs)

        # SPO objective: ratio * advantages - quadratic penalty
        policy_loss_elements = (ratio * advantages -
                                (advantages.abs() / (2 * epsilon)) * (ratio - 1).pow(2))
        policy_loss = -policy_loss_elements.mean()

        # Add entropy bonus for exploration
        entropy_loss = -0.01 * new_entropy.mean()

        total_loss = policy_loss + entropy_loss

        optimizer_policy.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(policy.parameters(), 0.5)
        optimizer_policy.step()

        # Value function update
        current_values = value_net(states)
        value_loss = F.mse_loss(current_values, returns)

        optimizer_value.zero_grad()
        value_loss.backward()
        torch.nn.utils.clip_grad_norm_(value_net.parameters(), 0.5)
        optimizer_value.step()

        total_policy_loss += total_loss.item()
        total_value_loss += value_loss.item()

    return total_policy_loss / epochs, total_value_loss / epochs


# ---------- Training Function ----------
def train_improved_rl(dim=10, training_episodes=2000):
    """Train the child-proposal policy and its value network.

    Each episode draws a fresh random population (30 to 100 individuals in
    ``[-5.12, 5.12]``), collects 20 single-step transitions into a replay
    buffer, and then, once the buffer holds more than 100 transitions, calls
    ``improved_spo_update``. Progress is printed every 100 episodes after
    updates have started.

    Args:
        dim: Number of genes per individual.
        training_episodes: Number of episodes to run.

    Returns:
        Tuple ``(policy, value_net, reward_history)`` where
        ``reward_history`` holds the mean reward of every episode.
    """
    policy = ImprovedChildPolicy(dim)
    value_net = ImprovedValueNetwork(dim)

    optimizer_policy = optim.Adam(policy.parameters(), lr=3e-4, weight_decay=1e-5)
    optimizer_value = optim.Adam(value_net.parameters(), lr=3e-4, weight_decay=1e-5)

    buffer = ReplayBuffer(capacity=10000)

    print("Training Improved RL Agent...")

    reward_history = []

    for episode in range(training_episodes):
        # Generate diverse populations
        pop_size = random.randint(30, 100)
        population = [np.random.uniform(-5.12, 5.12, dim) for _ in range(pop_size)]
        fitnesses = np.array([rastrigin(ind) for ind in population])

        episode_rewards = []
        episode_info = []

        # Generate experience
        for _ in range(20):
            # Smart parent selection (mix of tournament and fitness-based)
            if random.random() < 0.7:
                # Tournament selection
                parent1 = tournament_selection(population, fitnesses, 3)
                parent2 = tournament_selection(population, fitnesses, 3)
            else:
                # Random selection for diversity
                parent1 = random.choice(population)
                parent2 = random.choice(population)

            # Create state
            state = create_state(parent1, parent2, population, fitnesses)
            batched_state = state.unsqueeze(0)

            # Rollout quantities are stored as constants and never
            # back-propagated, so skip building autograd graphs for them
            # (otherwise every buffered transition keeps a graph alive).
            with torch.no_grad():
                action, log_prob, entropy = policy.sample_action(batched_state)
                value = value_net(batched_state).squeeze(0)
            action = action.squeeze(0)

            # Ensure bounds (the stored action stays unclamped so it matches
            # the stored log-probability)
            child = torch.clamp(action, -5.12, 5.12).detach().numpy()

            # Compute reward
            reward, info = compute_reward(child, parent1, parent2, population, fitnesses)

            # Store experience
            buffer.push(state, action, reward, log_prob.squeeze(0),
                        entropy.squeeze(0), value)

            episode_rewards.append(reward)
            episode_info.append(info)

        avg_reward = np.mean(episode_rewards)
        reward_history.append(avg_reward)

        # Update networks
        if len(buffer) > 100:
            improved_spo_update(
                policy, value_net, optimizer_policy, optimizer_value, buffer
            )

            if episode % 100 == 0:
                avg_improvement = np.mean([info['improvement'] for info in episode_info])
                avg_ranking = np.mean([info['ranking'] for info in episode_info])
                print(f"Episode {episode}: Reward={avg_reward:.4f}, "
                      f"Improvement={avg_improvement:.4f}, Ranking={avg_ranking:.4f}")

    return policy, value_net, reward_history


# ---------- Tournament Selection ----------
def tournament_selection(population, fitnesses, tournament_size=3):
    """Return the lowest-fitness individual among a random subset.

    Args:
        population: Sequence of individuals.
        fitnesses: Sequence of fitness values aligned with ``population``.
        tournament_size: Number of distinct contenders to draw; capped at
            ``len(population)``.

    Returns:
        The contender with the smallest fitness (first one on ties).
    """
    n_contenders = min(tournament_size, len(population))
    contenders = random.sample(range(len(population)), n_contenders)
    contender_fits = [fitnesses[idx] for idx in contenders]
    best_slot = np.argmin(contender_fits)
    return population[contenders[best_slot]]


# ---------- Standard Operations ----------
def standard_crossover_mutation(parent1, parent2, mutation_rate=0.1, mutation_strength=0.1):
    """Create a child by uniform crossover followed by Gaussian mutation.

    Args:
        parent1: First parent; its length determines the child length.
        parent2: Second parent, indexed at the same positions as ``parent1``.
        mutation_rate: Per-gene probability of mutation.
        mutation_strength: Standard deviation of the Gaussian noise.

    Returns:
        The child as a list of genes. Only mutated genes are clipped to
        ``[-5.12, 5.12]``.
    """
    # Uniform crossover: each gene comes from either parent with equal odds.
    child = [
        gene if random.random() < 0.5 else parent2[pos]
        for pos, gene in enumerate(parent1)
    ]

    # Gaussian mutation, clipped back into the search domain.
    for pos, gene in enumerate(child):
        if random.random() < mutation_rate:
            noisy = gene + np.random.normal(0, mutation_strength)
            child[pos] = np.clip(noisy, -5.12, 5.12)

    return child


# ---------- Improved RL-Guided GA ----------
def improved_rl_ga(policy, dim=10, pop_size=50, generations=150, rl_ratio=1.0):
    """Run a GA whose children are proposed (partly) by the learned policy.

    Every generation keeps the top 15% (at least two) individuals unchanged
    and fills the rest with children of tournament-selected parents. Each
    child comes from ``policy`` with probability ``rl_ratio`` and from
    ``standard_crossover_mutation`` otherwise.

    Args:
        policy: Policy exposing ``eval()`` and ``sample_action(state)``; it is
            switched to eval mode and left that way.
        dim: Number of genes per individual.
        pop_size: Population size.
        generations: Number of generations to run.
        rl_ratio: Probability of producing a child with the policy.

    Returns:
        Tuple ``(best_history, mean_history)`` with the best and mean fitness
        recorded at the start of every generation.
    """
    population = [np.random.uniform(-5.12, 5.12, dim) for _ in range(pop_size)]
    best_history, mean_history = [], []

    policy.eval()

    for _ in range(generations):
        fitnesses = np.array([rastrigin(member) for member in population])
        best_history.append(fitnesses.min())
        mean_history.append(fitnesses.mean())

        # Strong elitism: carry the best individuals over untouched.
        n_elite = max(2, int(pop_size * 0.15))
        ranking = np.argsort(fitnesses)
        next_generation = [population[ranking[k]] for k in range(n_elite)]

        # Breed until the new population is full.
        while len(next_generation) < pop_size:
            parent1 = tournament_selection(population, fitnesses, 3)
            parent2 = tournament_selection(population, fitnesses, 3)

            if not (random.random() < rl_ratio):
                # Classic operators for this child.
                next_generation.append(standard_crossover_mutation(parent1, parent2))
                continue

            # Let the policy propose the child from the enriched state.
            state = create_state(parent1, parent2, population, fitnesses)
            with torch.no_grad():
                action, _, _ = policy.sample_action(state.unsqueeze(0))
            bounded = torch.clamp(action.squeeze(0), -5.12, 5.12).numpy()
            next_generation.append(bounded.tolist())

        population = next_generation

    return best_history, mean_history


# ---------- Standard GA ----------
def standard_ga(dim=10, pop_size=50, generations=150):
    """Run the baseline GA with elitism, uniform crossover and mutation.

    Every generation keeps the top 15% (at least two) individuals unchanged
    and fills the rest with children from ``standard_crossover_mutation``.

    Args:
        dim: Number of genes per individual.
        pop_size: Population size.
        generations: Number of generations to run.

    Returns:
        Tuple ``(best_history, mean_history)`` with the best and mean fitness
        recorded at the start of every generation.
    """
    population = [np.random.uniform(-5.12, 5.12, dim) for _ in range(pop_size)]
    best_history, mean_history = [], []

    for _ in range(generations):
        fitnesses = np.array([rastrigin(member) for member in population])
        best_history.append(fitnesses.min())
        mean_history.append(fitnesses.mean())

        # Elitism: carry the best individuals over untouched.
        n_elite = max(2, int(pop_size * 0.15))
        ranking = np.argsort(fitnesses)
        next_generation = [population[ranking[k]] for k in range(n_elite)]

        while len(next_generation) < pop_size:
            # NOTE: parents are drawn uniformly at random here (no tournament
            # selection), so elitism is the only selection pressure.
            parent1 = random.choice(population)
            parent2 = random.choice(population)
            next_generation.append(standard_crossover_mutation(parent1, parent2))

        population = next_generation

    return best_history, mean_history


# ---------- Main Execution ----------
if __name__ == "__main__":
    # Script entry point: train the policy, run both GAs once, then plot and
    # print a comparison of their best fitness values.
    dim = 10
    generations = 150

    print("Training improved RL agent...")
    policy, value_net, reward_history = train_improved_rl(dim=dim, training_episodes=50000)

    # Left panel: mean reward per training episode.
    plt.figure(figsize=(10, 4))
    plt.subplot(1, 2, 1)
    plt.plot(reward_history)
    plt.title("RL Training Progress")
    plt.xlabel("Episode")
    plt.ylabel("Average Reward")
    plt.grid(True, alpha=0.3)

    print("Running comparisons...")
    best_ga, mean_ga = standard_ga(dim=dim, generations=generations)
    best_rl, mean_rl = improved_rl_ga(policy, dim=dim, generations=generations)

    # Right panel: best fitness per generation for both algorithms.
    plt.subplot(1, 2, 2)
    plt.plot(best_ga, label="Standard GA", color='blue', linewidth=2)
    plt.plot(best_rl, label="Improved RL-GA", color='red', linewidth=2)
    plt.xlabel("Generation")
    plt.ylabel("Best Fitness")
    plt.title("Performance Comparison")
    plt.legend()
    plt.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    # Best value seen over the whole run for each algorithm.
    ga_best = min(best_ga)
    rl_best = min(best_rl)

    print(f"\nResults:")
    print(f"Standard GA - Best: {ga_best:.4f}, Final: {best_ga[-1]:.4f}")
    print(f"RL-Guided GA - Best: {rl_best:.4f}, Final: {best_rl[-1]:.4f}")

    # Relative difference, expressed against the standard GA's best value.
    if rl_best < ga_best:
        improvement = ((ga_best - rl_best) / ga_best) * 100
        print(f"RL improvement: {improvement:.2f}%")
    else:
        degradation = ((rl_best - ga_best) / ga_best) * 100
        print(f"RL degradation: {degradation:.2f}%")