import random
import numpy as np
import matplotlib.pyplot as plt
from deap import base, creator, tools
import torch
import torch.nn as nn
import torch.optim as optim

# ---------------------------
# Sphere function
# ---------------------------
def sphere(individual):
    """Evaluate the sphere benchmark: the sum of squared genes.

    Args:
        individual: Iterable of numeric genes.

    Returns:
        A one-element tuple ``(fitness,)``; the evaluation result is
        assigned directly to ``fitness.values`` by the callers in this file.
    """
    squared_genes = [gene ** 2 for gene in individual]
    return (sum(squared_genes),)

# ---------------------------
# Actor-Critic Network (A2C)
# ---------------------------
class ActorCritic(nn.Module):
    """Two-headed network: a Gaussian policy (actor) and a scalar value (critic).

    Both heads read the same two-layer ReLU feature extractor. The policy's
    standard deviation is a learned, state-independent parameter stored in
    log space.

    Args:
        state_dim: Size of the input state vector.
        action_dim: Number of action components (one mean/std per component).
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()
        hidden_units = 128
        # Layers are created in the same order as the heads below consume
        # them: trunk first, then actor mean, log-std, and critic.
        trunk_layers = [
            nn.Linear(state_dim, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, hidden_units),
            nn.ReLU(),
        ]
        self.shared = nn.Sequential(*trunk_layers)
        # Actor head: mean of the Gaussian over child genes.
        self.actor_mu = nn.Linear(hidden_units, action_dim)
        # Log standard deviation per gene; zeros means std starts at 1.
        self.actor_logstd = nn.Parameter(torch.zeros(action_dim))
        # Critic head: scalar value estimate.
        self.critic = nn.Linear(hidden_units, 1)

    def forward(self, x):
        """Return ``(mu, std, value)`` for the state tensor ``x``."""
        features = self.shared(x)
        mean = self.actor_mu(features)
        spread = self.actor_logstd.exp()
        estimate = self.critic(features)
        return mean, spread, estimate

# ---------------------------
# A2C SPO Agent
# ---------------------------
class A2CAgent:
    """Advantage actor-critic agent that samples actions from a Gaussian policy.

    Args:
        state_dim: Size of the state vector fed to the network.
        action_dim: Number of action components produced per step.
        lr: Adam learning rate.
        gamma: Discount factor. It is stored but not used by ``update``,
            which treats every reward as a complete one-step return.
    """

    def __init__(self, state_dim, action_dim, lr=1e-4, gamma=0.99):
        self.model = ActorCritic(state_dim, action_dim)
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
        self.gamma = gamma

    def select_action(self, state):
        """Sample an action for ``state``.

        Args:
            state: Array-like state vector convertible to a float32 tensor.

        Returns:
            A tuple ``(action, log_prob, value)`` where ``action`` is a NumPy
            array detached from the graph, ``log_prob`` is the summed
            log-probability of the sampled action (still attached to the
            graph), and ``value`` is the critic's estimate for ``state``.
        """
        state_tensor = torch.tensor(state, dtype=torch.float32)
        mu, std, value = self.model(state_tensor)
        dist = torch.distributions.Normal(mu, std)
        action = dist.sample()
        # Components are independent, so the joint log-prob is the sum.
        log_prob = dist.log_prob(action).sum()
        return action.detach().numpy(), log_prob, value

    def update(self, log_probs, values, rewards):
        """Take one gradient step on the combined actor and critic loss.

        Args:
            log_probs: Sequence of log-probability tensors, one per step.
            values: Sequence of critic output tensors, one per step.
            rewards: Sequence of scalar rewards, one per step.

        An empty batch is a no-op.
        """
        if len(rewards) == 0:
            return

        rewards = torch.as_tensor(rewards, dtype=torch.float32).reshape(-1)
        # reshape(-1) keeps a 1-D batch even when only one step was recorded.
        values = torch.stack(list(values)).reshape(-1)
        log_probs = torch.stack(list(log_probs)).reshape(-1)

        # The TD error must stay attached to the graph for the critic loss;
        # only the actor term uses a detached copy so that the policy
        # gradient does not flow into the value head.
        td_error = rewards - values
        advantages = td_error.detach()
        actor_loss = -(log_probs * advantages).mean()
        critic_loss = td_error.pow(2).mean()
        total_loss = actor_loss + 0.5 * critic_loss

        self.optimizer.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
        self.optimizer.step()

# ---------------------------
# GA Setup
# ---------------------------
# Problem size and evolutionary budget.
DIMENSIONS = 10
POP_SIZE = 50
GENERATIONS = 50

# Per-gene bounds used when sampling the initial population.
BOUND_LOW = -5.0
BOUND_UP = 5.0

# Single-objective fitness with weight -1.0 (named FitnessMin: minimisation).
creator.create("FitnessMin", base.Fitness, weights=(-1.0,))
creator.create("Individual", list, fitness=creator.FitnessMin)

toolbox = base.Toolbox()

# Genome construction: uniform genes -> individual -> population.
toolbox.register("attr_float", random.uniform, BOUND_LOW, BOUND_UP)
toolbox.register(
    "individual",
    tools.initRepeat,
    creator.Individual,
    toolbox.attr_float,
    n=DIMENSIONS,
)
toolbox.register("population", tools.initRepeat, list, toolbox.individual)

# Variation, selection and evaluation operators.
toolbox.register("mate", tools.cxBlend, alpha=0.5)
toolbox.register(
    "mutate",
    tools.mutGaussian,
    mu=0,
    sigma=0.1,
    indpb=0.2,
)
toolbox.register("select", tools.selTournament, tournsize=3)
toolbox.register("evaluate", sphere)

# ---------------------------
# Run GA with RL proposing children
# ---------------------------
def run_ga_rl(agent=None, guided=False):
    """Run the GA for ``GENERATIONS`` generations, optionally RL-guided.

    In guided mode the agent proposes two children per generation directly
    from the population's fitness vector. In baseline mode two parent pairs
    are selected per generation and each pair yields two children through
    crossover and mutation. In both modes the children are merged into the
    population, which is then truncated to the ``POP_SIZE`` best.

    Args:
        agent: Object exposing ``select_action(state)`` and
            ``update(log_probs, values, rewards)``. Only used when
            ``guided`` is true.
        guided: When true and ``agent`` is not ``None``, use the agent to
            propose children; otherwise run the baseline GA.

    Returns:
        A tuple ``(fitness_history, waste_history)`` of per-generation lists:
        the best fitness in the population, and the number of wasted
        children. In guided mode a child is wasted if it did not survive
        truncation; in baseline mode, if it is not strictly better than its
        corresponding parent.
    """

    def fitness_of(ind):
        return ind.fitness.values[0]

    use_agent = guided and agent is not None

    pop = toolbox.population(n=POP_SIZE)
    fitness_history = []
    waste_history = []

    # Evaluate initial population
    for ind in pop:
        ind.fitness.values = toolbox.evaluate(ind)

    # Keep the population sorted best-to-worst from the very first
    # generation, so pop[-1] is always the worst individual and the agent's
    # state vector has the same ordering in every generation.
    pop = sorted(pop, key=fitness_of)

    for gen in range(GENERATIONS):
        if use_agent:
            state = np.array([fitness_of(ind) for ind in pop], dtype=np.float32)
            # Reward baseline: the worst fitness before any child is inserted.
            worst_fitness = fitness_of(pop[-1])

            children = []
            log_probs, values, rewards = [], [], []
            for _ in range(2):
                action, log_prob, value = agent.select_action(state)
                # clip child genes to bounds
                child_genes = np.clip(action, BOUND_LOW, BOUND_UP)
                child = creator.Individual(child_genes.tolist())
                child.fitness.values = toolbox.evaluate(child)

                children.append(child)
                log_probs.append(log_prob)
                values.append(value)
                # reward: improvement over worst individual (never negative)
                rewards.append(max(0.0, worst_fitness - fitness_of(child)))

            # Insert children into population (replace worst)
            pop.extend(children)
            pop = sorted(pop, key=fitness_of)[:POP_SIZE]

            # Update RL agent
            agent.update(log_probs, values, rewards)

            # A child is wasted only if it was truncated away. Checking by
            # identity avoids miscounting a surviving child that happens to
            # be (or tie with) the new worst individual.
            survivor_ids = {id(ind) for ind in pop}
            gen_waste = sum(1 for c in children if id(c) not in survivor_ids)
        else:
            # Baseline GA: sample parents + crossover + mutation
            children = []
            gen_waste = 0
            for _ in range(2):
                p1, p2 = toolbox.select(pop, 2)
                c1, c2 = toolbox.clone(p1), toolbox.clone(p2)
                if random.random() < 0.9:
                    toolbox.mate(c1, c2)
                toolbox.mutate(c1)
                toolbox.mutate(c2)
                c1.fitness.values = toolbox.evaluate(c1)
                c2.fitness.values = toolbox.evaluate(c2)
                # A child that does not beat its own parent counts as wasted.
                gen_waste += sum(
                    1
                    for c, p in zip([c1, c2], [p1, p2])
                    if fitness_of(c) >= fitness_of(p)
                )
                children.extend([c1, c2])
            pop.extend(children)
            pop = sorted(pop, key=fitness_of)[:POP_SIZE]

        fitness_history.append(min(fitness_of(ind) for ind in pop))
        waste_history.append(gen_waste)
        print(f"Gen {gen} | Best Fitness: {fitness_history[-1]:.4f} | Waste: {gen_waste}")

    return fitness_history, waste_history

# ---------------------------
# Main
# ---------------------------
def main():
    """Run the baseline and RL-guided GAs, then plot both comparisons.

    Two figures are shown in sequence: best fitness per generation and
    wasted children per generation, each with one curve per variant.
    """
    # Baseline GA
    print("Running baseline GA...")
    fitness_ga, waste_ga = run_ga_rl(guided=False)

    # RL-guided GA (A2C SPO)
    print("Running RL-guided GA...")
    agent = A2CAgent(state_dim=POP_SIZE, action_dim=DIMENSIONS)
    fitness_rl, waste_rl = run_ga_rl(agent=agent, guided=True)

    # One figure per metric: (baseline series, guided series, y label, title).
    comparisons = (
        (fitness_ga, fitness_rl, "Best Fitness", "GA Convergence Comparison"),
        (waste_ga, waste_rl, "Wasted Children", "GA Wasted Children Comparison"),
    )
    for baseline_series, guided_series, y_label, heading in comparisons:
        plt.figure(figsize=(10, 5))
        plt.plot(baseline_series, label="Baseline GA")
        plt.plot(guided_series, label="RL-guided GA")
        plt.xlabel("Generation")
        plt.ylabel(y_label)
        plt.title(heading)
        plt.legend()
        plt.grid()
        plt.show()

# Run the comparison only when this file is executed as a script, so that
# importing the module does not start the experiment.
if __name__ == "__main__":
    main()
