import random
import numpy as np
import matplotlib.pyplot as plt
from deap import base, creator, tools
import torch
import torch.nn as nn
import torch.optim as optim

# ---------------------------
# Sphere function
# ---------------------------
def sphere(individual):
    """Sphere benchmark: the sum of the squared genes.

    Returns a 1-tuple because DEAP fitness values are tuples.
    """
    total = 0
    for gene in individual:
        total += gene ** 2
    return (total,)
import math
# ---------------------------
# Rastrigin Function – many local minima, good for testing global search
# ---------------------------
def rastrigin(individual):
    """Rastrigin benchmark: A*n + sum(x^2 - A*cos(2*pi*x)) with A = 10.

    Global minimum is 0 at the origin. Returns a 1-tuple for DEAP.
    """
    amplitude = 10
    two_pi = 2 * math.pi
    ripple_sum = sum(x**2 - amplitude * math.cos(two_pi * x) for x in individual)
    return (amplitude * len(individual) + ripple_sum,)

def ellipsoid(individual):
    """Axis-parallel hyper-ellipsoid: sum of k * x_k^2 with 1-based weights k.

    Returns a 1-tuple for DEAP.
    """
    weighted = (weight * gene**2 for weight, gene in enumerate(individual, start=1))
    return (sum(weighted),)

# ---------------------------
# Actor-Critic Network (A2C)
# ---------------------------
class ActorCritic(nn.Module):
    """Shared-trunk actor-critic network with a diagonal Gaussian policy.

    A two-layer 256-unit ReLU trunk feeds two heads:
      * actor: a per-dimension mean plus a learned, state-independent
        standard deviation exp(actor_logstd);
      * critic: a single scalar value estimate.
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()
        hidden = 256
        # Attribute names (and the Sequential layout) define the state_dict keys
        # used by torch.save/load, so they must stay exactly as they are.
        self.shared = nn.Sequential(
            nn.Linear(state_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
        )
        self.actor_mu = nn.Linear(hidden, action_dim)
        # Log-std starts at 0, i.e. std = 1 in every action dimension.
        self.actor_logstd = nn.Parameter(torch.zeros(action_dim))
        self.critic = nn.Linear(hidden, 1)

    def forward(self, x):
        """Return ``(mean, std, value)`` for input ``x``."""
        features = self.shared(x)
        value = self.critic(features)
        mean = self.actor_mu(features)
        return mean, self.actor_logstd.exp(), value

# ---------------------------
# A2C SPO Agent
# ---------------------------
class A2CAgent:
    """Actor-critic agent trained with a quadratic-penalty (SPO-style) surrogate.

    Args:
        state_dim: length of the state vector given to the network.
        action_dim: length of the continuous action vector.
        lr: Adam learning rate.
        gamma: discount factor. Stored for completeness; ``update`` does not
            use it (callers pass ready-made advantages/returns).
        epsilon: width of the SPO penalty; larger values penalise policy-ratio
            deviations from 1 less strongly.
    """

    def __init__(self, state_dim, action_dim, lr=1e-5, gamma=0.98, epsilon=0.2):
        self.model = ActorCritic(state_dim, action_dim)
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
        self.gamma = gamma
        self.epsilon = epsilon

    def select_action(self, state, eval_mode=False):
        """Choose an action for ``state``.

        Args:
            state: a single state vector (array-like of floats).
            eval_mode: if true, return the policy mean instead of a sample.

        Returns:
            ``(action, log_prob, value)``: ``action`` as a NumPy array,
            ``log_prob`` the summed log-density of that action, and ``value``
            the critic estimate. Both tensors carry no autograd history.
        """
        state_tensor = torch.tensor(state, dtype=torch.float32)
        # Rollout outputs are only stored as data (old log-probs, diagnostics),
        # never backpropagated through, so don't build an autograd graph.
        with torch.no_grad():
            mu, std, value = self.model(state_tensor)
            dist = torch.distributions.Normal(mu, std)
            action = mu if eval_mode else dist.sample()
            log_prob = dist.log_prob(action).sum()
        return action.numpy(), log_prob, value

    def update(self, mb_states, mb_actions, mb_old_log_probs, mb_advantages, mb_returns=None):
        """Take one optimizer step on a minibatch.

        Args:
            mb_states: sequence of state vectors.
            mb_actions: sequence of the action vectors that were taken.
            mb_old_log_probs: log-probabilities of those actions under the
                behaviour policy (numbers or 0-dim tensors; any autograd
                history is discarded).
            mb_advantages: advantage estimates for the actor loss.
            mb_returns: regression targets for the critic. Defaults to
                ``mb_advantages``, which is exact for callers that use a
                zero baseline (advantage == return).
        """
        states = torch.as_tensor(np.asarray(mb_states, dtype=np.float32))
        actions = torch.as_tensor(np.asarray(mb_actions, dtype=np.float32))
        # float() accepts plain numbers and scalar tensors alike and drops any graph.
        old_log_probs = torch.tensor([float(lp) for lp in mb_old_log_probs], dtype=torch.float32)
        advantages = torch.tensor([float(a) for a in mb_advantages], dtype=torch.float32)
        if mb_returns is None:
            returns = advantages
        else:
            returns = torch.tensor([float(r) for r in mb_returns], dtype=torch.float32)

        mu, std, values = self.model(states)
        dist = torch.distributions.Normal(mu, std)
        log_probs = dist.log_prob(actions).sum(dim=1)

        # SPO policy loss: ratio * A minus a quadratic penalty on (ratio - 1).
        ratio = torch.exp(log_probs - old_log_probs)
        penalty = (advantages.abs() / (2 * self.epsilon)) * (ratio - 1).pow(2)
        actor_loss = -(ratio * advantages - penalty).mean()

        # The critic must be compared against its own prediction V(s); a loss
        # built only from the fixed advantages has no gradient to the network.
        # squeeze(-1) turns (B, 1) into (B,) so there is no (B, B) broadcast.
        critic_loss = (returns - values.squeeze(-1)).pow(2).mean()
        total_loss = actor_loss + 0.5 * critic_loss

        self.optimizer.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
        self.optimizer.step()

# ---------------------------
# GA Setup
# ---------------------------
DIMENSIONS = 10
POP_SIZE = 50
GENERATIONS = 50
CHILDREN_PER_GEN = 4
BOUND_LOW, BOUND_UP = -5.0, 5.0

# creator.create warns and silently replaces an existing class of the same name
# (e.g. when this module is re-imported in a notebook); create each class once.
if not hasattr(creator, "FitnessMin"):
    creator.create("FitnessMin", base.Fitness, weights=(-1.0,))  # minimisation
if not hasattr(creator, "Individual"):
    creator.create("Individual", list, fitness=creator.FitnessMin)

toolbox = base.Toolbox()
toolbox.register("attr_float", random.uniform, BOUND_LOW, BOUND_UP)
toolbox.register("individual", tools.initRepeat, creator.Individual, toolbox.attr_float, n=DIMENSIONS)
toolbox.register("population", tools.initRepeat, list, toolbox.individual)
toolbox.register("mate", tools.cxBlend, alpha=0.5)
toolbox.register("mutate", tools.mutGaussian, mu=0, sigma=0.1, indpb=0.2)
toolbox.register("select", tools.selTournament, tournsize=3)
toolbox.register("evaluate", sphere)

# ---------------------------
# RL Agent Training (decoupled)
# ---------------------------
def train_rl_agent(episodes=5000, batch_size=128, save_path="a2c_spo_model.pth"):
    """Train an A2CAgent on synthetic population states and save its weights.

    Each sample draws POP_SIZE pseudo-fitness values uniformly from [0, 25),
    asks the agent for a child genome, clips it to the GA bounds, evaluates it,
    and uses ``reward = max(state) - child_fitness``. Every sample is a
    one-step episode with a zero baseline, so the reward is passed to
    ``update`` directly as the advantage.

    Args:
        episodes: number of minibatch updates.
        batch_size: samples per update.
        save_path: file the model's state_dict is written to.

    Returns:
        The trained agent.
    """
    agent = A2CAgent(state_dim=POP_SIZE, action_dim=DIMENSIONS)
    for ep in range(episodes):
        mb_states, mb_actions, mb_log_probs, mb_rewards = [], [], [], []
        for _ in range(batch_size):
            pop_state = np.random.uniform(0, 25, size=POP_SIZE).astype(np.float32)
            action, log_prob, _ = agent.select_action(pop_state)

            # Clip only the genome that gets evaluated. log_prob describes the
            # raw sample, so the raw sample is what update() must re-score;
            # storing the clipped genes would make the policy ratio inconsistent.
            child_genes = np.clip(action, BOUND_LOW, BOUND_UP)
            child = creator.Individual(child_genes.tolist())
            child_fitness = toolbox.evaluate(child)[0]
            reward = np.max(pop_state) - child_fitness

            mb_states.append(pop_state)
            mb_actions.append(action)
            mb_log_probs.append(log_prob)
            mb_rewards.append(reward)  # advantage = reward - baseline (baseline=0 here)

        agent.update(mb_states, mb_actions, mb_log_probs, mb_rewards)

        if ep % 500 == 0:
            print(f"Episode {ep} | Avg reward: {np.mean(mb_rewards):.4f}")

    torch.save(agent.model.state_dict(), save_path)
    print("RL agent training completed and saved!")
    return agent

# ---------------------------
# GA Runner
# ---------------------------
def _guided_children(agent, state):
    """Create and evaluate CHILDREN_PER_GEN children from the agent's mean action."""
    children = []
    for _ in range(CHILDREN_PER_GEN):
        # NOTE(review): eval_mode returns the policy mean and ``state`` does not
        # change inside a generation, so these children are identical copies.
        action, _, _ = agent.select_action(state, eval_mode=True)
        child_genes = np.clip(action, BOUND_LOW, BOUND_UP)
        child = creator.Individual(child_genes.tolist())
        child.fitness.values = toolbox.evaluate(child)
        children.append(child)
    return children


def _baseline_children(pop):
    """Breed CHILDREN_PER_GEN evaluated children via tournament, blend crossover and mutation."""
    children = []
    while len(children) < CHILDREN_PER_GEN:
        p1, p2 = toolbox.select(pop, 2)
        c1, c2 = toolbox.clone(p1), toolbox.clone(p2)
        if random.random() < 0.9:
            toolbox.mate(c1, c2)
        toolbox.mutate(c1)
        toolbox.mutate(c2)
        # Keep (and evaluate) only as many as are still needed, so the
        # evaluation budget matches the guided mode exactly.
        for child in (c1, c2)[: CHILDREN_PER_GEN - len(children)]:
            child.fitness.values = toolbox.evaluate(child)
            children.append(child)
    return children


def _count_wasted(children, survivors):
    """Count children that did not survive into ``survivors`` (by object identity)."""
    kept = {id(ind) for ind in survivors}
    return sum(1 for child in children if id(child) not in kept)


def run_ga_rl(agent=None, guided=False):
    """Run a steady-state GA and record best fitness and waste per generation.

    Every generation adds CHILDREN_PER_GEN evaluated children to the population
    and keeps the POP_SIZE individuals with the lowest fitness. Children come
    from ``agent`` when ``guided`` is true and an agent is supplied; otherwise
    from tournament selection, blend crossover and Gaussian mutation.

    A child is "wasted" when it does not survive that truncation. Both modes
    use this same definition and the same number of children, so their curves
    are directly comparable.

    Args:
        agent: object with ``select_action(state, eval_mode=...)``; only used
            when ``guided`` is true.
        guided: whether to generate children with ``agent``.

    Returns:
        ``(fitness_history, waste_history)``: best fitness and wasted-child
        count for each generation.
    """
    use_agent = guided and agent is not None
    pop = toolbox.population(n=POP_SIZE)
    for ind in pop:
        ind.fitness.values = toolbox.evaluate(ind)

    fitness_history = []
    waste_history = []
    for gen in range(GENERATIONS):
        if use_agent:
            state = np.array([ind.fitness.values[0] for ind in pop], dtype=np.float32)
            children = _guided_children(agent, state)
        else:
            children = _baseline_children(pop)

        pop.extend(children)
        # sorted() is stable and children sit at the end, so on fitness ties the
        # existing individuals are kept ahead of new children.
        pop = sorted(pop, key=lambda ind: ind.fitness.values[0])[:POP_SIZE]
        gen_waste = _count_wasted(children, pop)

        fitness_history.append(min(ind.fitness.values[0] for ind in pop))
        waste_history.append(gen_waste)
        print(f"Gen {gen} | Best Fitness: {fitness_history[-1]:.4f} | Waste: {gen_waste}")

    return fitness_history, waste_history

# ---------------------------
# Main
# ---------------------------
def main():
    """Train the agent, run baseline and RL-guided GAs, and plot both comparisons."""
    print("Training RL agent with SPO updates...")
    agent = train_rl_agent(episodes=5000, batch_size=32)

    print("Running baseline GA...")
    baseline_fitness, baseline_waste = run_ga_rl(guided=False)

    print("Running RL-guided GA...")
    agent.model.eval()
    guided_fitness, guided_waste = run_ga_rl(agent=agent, guided=True)

    # One figure per metric: (baseline series, guided series, y-label, title).
    panels = (
        (baseline_fitness, guided_fitness, "Best Fitness", "GA Convergence Comparison"),
        (baseline_waste, guided_waste, "Wasted Children", "GA Wasted Children Comparison"),
    )
    for ga_series, rl_series, y_label, title in panels:
        plt.figure(figsize=(10, 5))
        plt.plot(ga_series, label="Baseline GA")
        plt.plot(rl_series, label="RL-guided GA")
        plt.xlabel("Generation")
        plt.ylabel(y_label)
        plt.title(title)
        plt.legend()
        plt.grid()
        plt.show()

# Run the full train-then-compare experiment only when executed as a script.
if __name__ == "__main__":
    main()
