import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random
import math
import copy
from collections import deque
from lib.prune import prune_x_pruner
from lib.eval import eval_ppl
from lib.gptree import GPTree

def online_rl_search(args, model, tokenizer, device, engine, layer_idx, layer_sparsity):
    """Search for the (w, g) exponent pair of one layer that gives the lowest perplexity.

    Each candidate (w, g) is scored by restoring the model's original weights,
    calling ``GPTree.PowerExponents.set_layer_exponents(layer_idx, w, g)``,
    pruning ``layer_idx`` with ``prune_x_pruner`` and measuring ``eval_ppl``.
    Scores are cached per point (rounded to two decimals). The search runs in
    three phases:

    1. Exploration: short walks from Latin-hypercube start points, mostly
       random, with the last steps of each walk sampled from the actor.
    2. Policy search: epsilon-greedy actor-critic walks from the best points
       seen so far, with annealed acceptance of worse moves and periodic
       jumps back to the best point.
    3. Local refinement: fine steps (0.02 / 0.04) around the best point.

    Args:
        args, model, tokenizer, device, engine: forwarded unchanged to
            ``prune_x_pruner`` and ``eval_ppl``; ``device`` also hosts the
            small actor/critic networks.
        layer_idx: index of the layer whose exponents are tuned and pruned.
        layer_sparsity: forwarded to ``prune_x_pruner``.

    Returns:
        Tuple ``(w, g, ppl)`` for the best point found. If no evaluation
        succeeded this is ``(1.0, 1.0, inf)``.

    Side effects:
        * Reseeds ``random`` and ``numpy.random`` from the wall clock.
        * Sets ``model.config.use_cache = False`` and ``model.seqlen = 512``
          during evaluation; they are not restored.
        * On return the model holds its original weights pruned on
          ``layer_idx`` with the best exponents.
    """
    print(f"\n=== Online Reinforcement Learning Search for Layer {layer_idx} ===")

    import time
    random.seed(time.time())
    np.random.seed(int(time.time()))

    # Snapshot of the unpruned weights; every evaluation starts from this.
    original_state = copy.deepcopy(model.state_dict())

    # Define basic action space: action name -> (delta_w, delta_g)
    ACTIONS = {
        'up': (0, 0.1),    # +g
        'down': (0, -0.1),  # -g
        'left': (-0.1, 0),  # -w
        'right': (0.1, 0)   # +w
    }
    # Fixed ordering used to map actor output indices <-> action names.
    action_names = list(ACTIONS)

    # Bounds of the search domain for both w and g.
    lower_bound, upper_bound = 0.5, 2.5

    def clamp(value):
        """Clip a coordinate to the search domain and round it to the cache grid."""
        return round(max(lower_bound, min(upper_bound, value)), 2)

    # Noise policy network
    class NoisyActor(nn.Module):
        """Policy network mapping a (w, g) state to action probabilities."""

        def __init__(self):
            super(NoisyActor, self).__init__()
            self.fc1 = nn.Linear(2, 64)
            self.noise1 = nn.Linear(2, 64)  # Noise injection layer
            self.fc2 = nn.Linear(64, 32)
            self.noise2 = nn.Linear(64, 32)
            self.fc3 = nn.Linear(32, len(ACTIONS))

        def forward(self, x, noise=True):
            noise_factor = 0.2  # Noise intensity coefficient
            if noise:
                # Gaussian noise is pushed through learned linear maps and
                # added to the pre-activations of the first two layers.
                x = F.relu(self.fc1(x) + (self.noise1(torch.randn_like(x)) * noise_factor))
                x = F.relu(self.fc2(x) + (self.noise2(torch.randn_like(x)) * noise_factor))
            else:
                x = F.relu(self.fc1(x))
                x = F.relu(self.fc2(x))
            return F.softmax(self.fc3(x), dim=-1)

    class Critic(nn.Module):
        """Value network mapping a (w, g) state to a scalar estimate."""

        def __init__(self):
            super(Critic, self).__init__()
            self.fc1 = nn.Linear(2, 64)
            self.fc2 = nn.Linear(64, 32)
            self.fc3 = nn.Linear(32, 1)

        def forward(self, x):
            x = F.relu(self.fc1(x))
            x = F.relu(self.fc2(x))
            return self.fc3(x)

    actor = NoisyActor().to(device)
    critic = Critic().to(device)
    actor_optimizer = torch.optim.Adam(actor.parameters(), lr=0.0003)
    critic_optimizer = torch.optim.Adam(critic.parameters(), lr=0.001)

    # Experience playback buffer
    class ExperienceReplay:
        """Bounded FIFO of transitions with uniform random sampling."""

        def __init__(self, capacity):
            self.buffer = deque(maxlen=capacity)

        def add(self, experience):
            self.buffer.append(experience)

        def sample(self, batch_size):
            return random.sample(self.buffer, min(len(self.buffer), batch_size))

    replay_buffer = ExperienceReplay(1000)

    # Adaptive exploration controller
    class EpsilonController:
        """Exploration rate that decays geometrically down to a floor."""

        def __init__(self, start=1.0, end=0.1, decay=0.97):
            self.value = start
            self.end = end
            self.decay = decay

        def update(self):
            self.value = max(self.end, self.value * self.decay)

    epsilon = EpsilonController()

    evaluated_points = {}
    global_best = {'w': 1.0, 'g': 1.0, 'ppl': float('inf')}

    def evaluate_point(w, g):
        """Return the perplexity for (w, g), using the cache when possible.

        Failed evaluations are reported, return ``inf`` and are not cached,
        so the same point may be retried later. ``global_best`` is updated
        whenever a fresh evaluation beats it.
        """
        point_key = (round(w, 2), round(g, 2))
        if point_key in evaluated_points:
            return evaluated_points[point_key]

        try:
            model.load_state_dict(original_state)
            GPTree.PowerExponents.set_layer_exponents(layer_idx, w, g)
            prune_x_pruner(args, model, tokenizer, device, engine=engine, target_layers=[layer_idx], layer_sparsity=layer_sparsity)
            model.config.use_cache = False
            model.seqlen = 512
            ppl = eval_ppl(args, model, tokenizer, device)
            evaluated_points[point_key] = ppl

            if ppl < global_best['ppl']:
                global_best.update(w=w, g=g, ppl=ppl)
            return ppl
        except Exception as e:
            # Deliberate best-effort: one bad point must not abort the search.
            print(f"Error evaluating point (w={w}, g={g}): {str(e)}")
            return float('inf')

    def latin_hypercube_sampling(n_samples, w_range=(0.5, 2.5), g_range=(0.5, 2.5)):
        """Draw ``n_samples`` (w, g) points, one per stratum along each axis."""
        w_intervals = np.linspace(w_range[0], w_range[1], n_samples + 1)
        g_intervals = np.linspace(g_range[0], g_range[1], n_samples + 1)

        w_points = []
        g_points = []

        for i in range(n_samples):
            w_points.append(round(random.uniform(w_intervals[i], w_intervals[i + 1]), 1))
            g_points.append(round(random.uniform(g_intervals[i], g_intervals[i + 1]), 1))

        # Decouple the two axes so the strata are paired at random.
        random.shuffle(g_points)

        return list(zip(w_points, g_points))

    def remember(state, action, next_state, old_ppl, new_ppl):
        """Store a transition unless its reward is not finite.

        A failed evaluation yields ``inf``, which would make the reward
        +/-inf or NaN and irreversibly corrupt the network weights on the
        next update, so such transitions are dropped.
        """
        reward = old_ppl - new_ppl
        if not math.isfinite(reward):
            return
        replay_buffer.add((state, action, next_state, reward, new_ppl < old_ppl))

    # Defined before Phase 1 because Phase 1 may call it once the buffer is
    # large enough; binding it later raises UnboundLocalError at that call.
    def update_networks():
        """Run one actor-critic update on a random mini-batch of transitions."""
        batch = replay_buffer.sample(10)
        if not batch:
            return

        states = torch.tensor([list(exp[0]) for exp in batch], dtype=torch.float32, device=device)
        actions = torch.tensor([action_names.index(exp[1]) for exp in batch], dtype=torch.long, device=device)
        rewards = torch.tensor([float(exp[3]) for exp in batch], dtype=torch.float32, device=device)
        next_states = torch.tensor([list(exp[2]) for exp in batch], dtype=torch.float32, device=device)

        gamma = 0.9

        # One-step TD target; squeeze only the feature axis so a batch of
        # one keeps its batch dimension.
        with torch.no_grad():
            target_values = rewards + gamma * critic(next_states).squeeze(-1)

        current_values = critic(states).squeeze(-1)
        advantage = target_values - current_values

        probs = actor(states, noise=False)
        selected_probs = probs.gather(1, actions.unsqueeze(1)).squeeze(1)
        # Floor the probabilities so an underflowed softmax cannot give log(0).
        log_probs = torch.log(torch.clamp(selected_probs, min=1e-8))
        actor_loss = -(log_probs * advantage.detach()).mean()

        actor_optimizer.zero_grad()
        actor_loss.backward()
        torch.nn.utils.clip_grad_norm_(actor.parameters(), 1.0)
        actor_optimizer.step()

        critic_loss = F.mse_loss(current_values, target_values)
        critic_optimizer.zero_grad()
        critic_loss.backward()
        torch.nn.utils.clip_grad_norm_(critic.parameters(), 1.0)
        critic_optimizer.step()

    print("\n=== Phase 1: Enhanced Exploration ===")
    n_start_points = 5
    steps_per_point = 10
    start_points = latin_hypercube_sampling(n_start_points)
    print("\nGenerated start points using Latin Hypercube Sampling:")
    for i, (w, g) in enumerate(start_points):
        print(f"Point {i + 1}: w={w:.2f}, g={g:.2f}")

    for start_w, start_g in start_points:
        current_w, current_g = start_w, start_g
        current_ppl = evaluate_point(start_w, start_g)

        for step in range(steps_per_point):
            if step < 8:
                # Pure random walk for the first steps of every start point.
                action = random.choice(action_names)
            else:
                state = torch.tensor([current_w, current_g], dtype=torch.float32, device=device)
                with torch.no_grad():
                    action_probs = actor(state, noise=True)
                action_idx = torch.multinomial(action_probs, 1).item()
                action = action_names[action_idx]

            dw, dg = ACTIONS[action]
            new_w = clamp(current_w + dw)
            new_g = clamp(current_g + dg)
            new_ppl = evaluate_point(new_w, new_g)
            print(f"Step {step + 1}: {action} -> (w={new_w}, g={new_g}) -> ppl={new_ppl:.4f}")

            remember((current_w, current_g), action, (new_w, new_g), current_ppl, new_ppl)

            if len(replay_buffer.buffer) >= 64:
                update_networks()

            # Exploration always moves, regardless of whether ppl improved.
            current_w, current_g = new_w, new_g
            current_ppl = new_ppl

    print("\n=== Phase 2: Multi-Start Policy Search ===")
    candidate_points = sorted(evaluated_points.items(), key=lambda x: x[1])[:max(5, len(evaluated_points)//10)]
    search_paths = [{'w': w, 'g': g, 'ppl': ppl} for (w, g), ppl in candidate_points]

    max_steps = 20
    max_depth = 5

    for start_point in search_paths:
        current_w = start_point['w']
        current_g = start_point['g']
        current_ppl = start_point['ppl']

        print(f"\nStarting search from (w={current_w}, g={current_g}, ppl={current_ppl})")

        for step in range(max_steps):
            if random.random() < epsilon.value:
                action = random.choice(action_names)
            else:
                state = torch.tensor([current_w, current_g], dtype=torch.float32, device=device)
                with torch.no_grad():
                    action_probs = actor(state, noise=True)
                action_idx = torch.argmax(action_probs).item()
                action = action_names[action_idx]

            # Every fifth step (including the first) takes a double-size move.
            scale = 2 if step % 5 == 0 else 1
            dw = ACTIONS[action][0] * scale
            dg = ACTIONS[action][1] * scale

            new_w = clamp(current_w + dw)
            new_g = clamp(current_g + dg)
            new_ppl = evaluate_point(new_w, new_g)

            print(f"{action} -> (w={new_w}, g={new_g}) -> ppl={new_ppl:.4f}")

            remember((current_w, current_g), action, (new_w, new_g), current_ppl, new_ppl)

            if len(replay_buffer.buffer) >= 10:
                update_networks()
                print("\n=== Network was updated ===")

            if new_ppl < current_ppl:
                current_w, current_g = new_w, new_g
                current_ppl = new_ppl
            else:
                # Annealed acceptance of a worse move; the exponent is <= 0
                # here, so exp() cannot overflow. A NaN probability compares
                # False and therefore rejects the move.
                temp = max(1e-3, 0.5 * (1 - step/max_steps))
                accept_prob = math.exp((current_ppl - new_ppl)/temp)
                if random.random() < accept_prob:
                    current_w, current_g = new_w, new_g
                    current_ppl = new_ppl

            # Periodically restart the walk from the best point found so far.
            if step % max_depth == 0 and step != 0:
                current_w, current_g = global_best['w'], global_best['g']
                current_ppl = global_best['ppl']

            epsilon.update()

    print("\n=== Phase 3: Local Refinement ===")
    for _ in range(3):
        # Re-centre on the current best each pass; with a fixed centre the
        # later passes would only re-read cached points.
        best_w, best_g = global_best['w'], global_best['g']
        for dw, dg in ACTIONS.values():
            delta_w = dw / 5
            delta_g = dg / 5
            for step in (1, 2):
                # Clamped to the same domain as phases 1 and 2.
                w = clamp(best_w + delta_w * step)
                g = clamp(best_g + delta_g * step)
                # evaluate_point itself records any new best in global_best.
                evaluate_point(w, g)

    # Leave the model pruned with the best exponents found.
    model.load_state_dict(original_state)
    GPTree.PowerExponents.set_layer_exponents(layer_idx, global_best['w'], global_best['g'])
    prune_x_pruner(args, model, tokenizer, device, engine=engine, target_layers=[layer_idx], layer_sparsity=layer_sparsity)


    print("\n=== Final Results ===")
    print(f"Optimal w: {global_best['w']:.2f}")
    print(f"Optimal g: {global_best['g']:.2f}")
    print(f"Best PPL: {global_best['ppl']:.4f}")

    return global_best['w'], global_best['g'], global_best['ppl']