import argparse
import math
import os
import random
import gymnasium as gym
import numpy as np
import time
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from tensorboardX import SummaryWriter
from lib.common import mkdir
from lib.model import ActorCritic
from lib.multiprocessing_env import SubprocVecEnv
from walker_attacker import Attacker
from datetime import datetime
from tqdm import trange

# ========================
# Hyperparameters and Config
# ========================
DEBUG_FLAG = False  # When True, the run is shrunk via the override block below

# --- Environment -------------------------------------------------------
ENV_ID = "Walker2d-v5"
STATE_DIM = 17      # Length of the observation vector fed to the network
ACT_DIM = 6         # Length of the action vector produced by the network
NUM_ENVS = 16       # Environments stepped in parallel during rollouts

# --- Network / optimizer ----------------------------------------------
HIDDEN_SIZE = 64        # Hidden layer width of the ActorCritic network
LEARNING_RATE = 5e-3    # Adam step size

# --- Return / advantage estimation -------------------------------------
GAMMA = 0.99        # Reward discount
GAE_LAMBDA = 0.95   # GAE bias/variance trade-off

# --- PPO loss ------------------------------------------------------------
PPO_EPSILON = 0.1       # Clip range for the probability ratio
CRITIC_DISCOUNT = 0.5   # Weight of the value loss in the total loss
ENTROPY_BETA = 0.01     # Weight of the entropy bonus
KL_BETA = 1.0           # Default for ppo_update's kl_coeff (no KL term is applied in its loss)

# --- Training schedule ---------------------------------------------------
PPO_STEPS = 512         # Steps per environment per rollout (NUM_ENVS * PPO_STEPS samples)
MINI_BATCH_SIZE = 128   # Samples per gradient step
PPO_EPOCHS = 15         # Passes over the rollout buffer per update
TEST_EPOCHS = 5         # Evaluate every TEST_EPOCHS updates
MAX_EPOCHS = 400        # Stop after this many updates
NUM_TESTS = 10          # Evaluation episodes averaged per evaluation
TARGET_REWARD = 3500    # Stop early once the mean evaluation reward exceeds this

if DEBUG_FLAG:
    # Tiny configuration for quick smoke tests.
    NUM_ENVS, PPO_STEPS, MINI_BATCH_SIZE, MAX_EPOCHS = 2, 4, 4, 20

def make_env():
    """
    Return a zero-argument factory for one ``ENV_ID`` environment.

    Nothing is constructed here; every call of the returned callable
    creates a fresh environment with ``gym.make(ENV_ID)``.
    """
    def _create_env():
        return gym.make(ENV_ID)

    return _create_env

def test_env(env, model, device, deterministic=True):
    """
    Evaluate the policy in a single environment episode.
    Args:
        env: Gym environment.
        model: Policy model.
        device: PyTorch device.
        deterministic: If True, use the mean action (no sampling).
    Returns:
        Total reward collected in the episode.
    """
    trigger_state = 0.4
    state, _ = env.reset()
    terminated = False
    truncated = False
    total_reward = 0

    while not terminated and not truncated:
        state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
        dist, _ = model(state_tensor)
        action = dist.mean.detach().cpu().numpy()[0] if deterministic else dist.sample().cpu().numpy()
        next_state, reward, terminated, truncated, _ = env.step(action)
        state = next_state
        total_reward += reward
    return total_reward

def normalize(x):
    """
    Return a zero-mean, unit-variance copy of tensor ``x``.

    The input tensor is left unmodified (the previous version normalized it
    in place, silently changing the caller's data). A 1e-8 epsilon keeps the
    division finite when all elements are equal (std == 0). With fewer than
    two elements the sample standard deviation is undefined (NaN), so only
    the mean is removed.
    """
    centered = x - x.mean()
    if x.numel() < 2:
        return centered
    return centered / (x.std() + 1e-8)

def compute_gae(next_value, rewards, masks, values, gamma=GAMMA, lam=GAE_LAMBDA):
    """
    Compute Generalized Advantage Estimation (GAE) return targets.

    Walks the rollout backwards, accumulating
        delta_t = r_t + gamma * V(s_{t+1}) * mask_t - V(s_t)
        gae_t   = delta_t + gamma * lam * mask_t * gae_{t+1}
    and emits gae_t + V(s_t) for each step. Subtracting ``values`` from the
    result gives the advantages.

    Args:
        next_value: Value estimate of the state following the last step.
        rewards: Per-step rewards, one entry per rollout step.
        masks: Per-step masks, 0 where the episode ended at that step, 1 otherwise.
        values: Per-step value estimates (a list, same length as ``rewards``).
        gamma: Discount factor.
        lam: GAE lambda.
    Returns:
        List of returns in the same chronological order as ``rewards``.
    """
    values = values + [next_value]  # new list; the caller's list is untouched
    gae = 0
    returns = []
    for step in reversed(range(len(rewards))):
        delta = rewards[step] + gamma * values[step + 1] * masks[step] - values[step]
        gae = delta + gamma * lam * masks[step] * gae
        returns.append(gae + values[step])
    # Built back-to-front; one reverse is O(n) instead of O(n^2) repeated insert(0, ...).
    returns.reverse()
    return returns

def ppo_iter(states, actions, log_probs, returns, advantage, mini_batch_size=None):
    """
    Yield shuffled mini-batches for one pass of the PPO update.

    Indices come from a random permutation of the buffer, so samples are drawn
    without replacement and no transition appears twice within a pass. The
    previous with-replacement sampling duplicated some samples and skipped
    roughly a third of the buffer on each pass. The number of mini-batches is
    unchanged (``len // mini_batch_size``); the permutation's remainder is
    skipped for this pass.

    Args:
        states, actions, log_probs, returns, advantage: Tensors that share the
            same first (sample) dimension.
        mini_batch_size: Samples per mini-batch; defaults to MINI_BATCH_SIZE.
    Yields:
        ``(states, actions, log_probs, returns, advantage)`` slices, one
        tuple per mini-batch.
    """
    if mini_batch_size is None:
        mini_batch_size = MINI_BATCH_SIZE
    batch_size = states.size(0)
    # Permute on the buffers' device so indexing needs no host->device copy.
    perm = torch.randperm(batch_size, device=states.device)
    usable = (batch_size // mini_batch_size) * mini_batch_size
    for start in range(0, usable, mini_batch_size):
        ids = perm[start:start + mini_batch_size]
        yield (states[ids], actions[ids], log_probs[ids], returns[ids], advantage[ids])

def ppo_update(frame_idx, states, actions, log_probs, returns, advantages, attack_flag, clip_param=PPO_EPSILON, kl_coeff=KL_BETA):
    """
    Perform the PPO update of the policy and value networks.

    Runs PPO_EPOCHS passes of mini-batch gradient steps over the rollout
    buffers, using the clipped surrogate objective plus a weighted critic
    loss and an entropy bonus. Per-mini-batch averages are logged to
    tensorboard.

    Relies on the module-level globals ``model``, ``optimizer`` and
    ``writer``, which are created in the ``__main__`` section.

    Args:
        frame_idx: Global frame index (x-axis for the logged scalars).
        states, actions, log_probs, returns, advantages: Training buffers
            (detached tensors sharing the sample dimension).
        attack_flag: If True, the sign of the whole loss is flipped, so the
            optimizer performs gradient ascent on it (critic and entropy terms
            included). NOTE(review): confirm that ascending the critic loss is
            intended for the adversarial mode.
        clip_param: PPO clipping epsilon.
        kl_coeff: Accepted for API compatibility; currently unused (no KL
            penalty term is added to the loss).
    """
    # Python-float accumulators: adding the loss tensors themselves would keep
    # every mini-batch's autograd graph reachable until the function returns.
    sums = {
        "returns": 0.0,
        "advantage": 0.0,
        "loss_actor": 0.0,
        "loss_critic": 0.0,
        "loss_total": 0.0,
        "entropy": 0.0,
    }
    count_steps = 0

    for _ in range(PPO_EPOCHS):
        for state, action, old_log_probs, return_, advantage in ppo_iter(
                states, actions, log_probs, returns, advantages):
            dist, value = model(state)
            entropy = dist.entropy().mean()
            new_log_probs = dist.log_prob(action)
            ratio = torch.exp(new_log_probs - old_log_probs)
            surr1 = ratio * advantage
            surr2 = torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param) * advantage
            actor_loss = -torch.min(surr1, surr2).mean()
            critic_loss = (return_ - value).pow(2).mean()
            loss = CRITIC_DISCOUNT * critic_loss + actor_loss - ENTROPY_BETA * entropy
            if attack_flag:
                loss = -loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            sums["returns"] += return_.mean().item()
            sums["advantage"] += advantage.mean().item()
            sums["loss_actor"] += actor_loss.item()
            sums["loss_critic"] += critic_loss.item()
            sums["loss_total"] += loss.item()
            sums["entropy"] += entropy.item()
            count_steps += 1

    if count_steps == 0:
        # Buffer smaller than one mini-batch: nothing was trained, nothing to average.
        return
    for tag, total in sums.items():
        writer.add_scalar(tag, total / count_steps, frame_idx)

if __name__ == "__main__":
    mkdir('', 'checkpoints')
    parser = argparse.ArgumentParser()
    parser.add_argument("-n", "--name", default=ENV_ID, help="Name of the run")
    parser.add_argument("-p", "--poison", action="store_true", help="Poison the rewards to backdoor")
    # The long option has no trailing space, so argparse derives a usable
    # `args.discretizedleval` attribute; type=int makes CLI values numeric.
    parser.add_argument("-d", "--discretizedleval", type=int, default=8,
                        help="Actions discretizedleval when poisoning")
    args = parser.parse_args()
    Poison_flag = args.poison

    writer = SummaryWriter("logs-{}".format(args.name))

    # Detect CUDA
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")
    print('Device:', device)

    # NUM_ENVS parallel environments for rollouts, plus one for evaluation.
    envs = SubprocVecEnv([make_env() for _ in range(NUM_ENVS)])
    env = gym.make(ENV_ID)
    num_inputs = STATE_DIM
    num_outputs = ACT_DIM

    # `model`, `optimizer` and `writer` are read as module globals by ppo_update().
    model = ActorCritic(num_inputs, num_outputs, HIDDEN_SIZE).to(device)
    best_model = ActorCritic(num_inputs, num_outputs, HIDDEN_SIZE).to(device)
    print(model)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    frame_idx = 0
    train_epoch = 0
    best_reward = None

    state = envs.reset()
    early_stop = False

    # Reward-poisoning attacker; its learn() is only called when --poison is set.
    actions_discretizedleval = args.discretizedleval
    poison_eps = 2
    use_discretized_diff = False
    WalkerAttacker = Attacker(dim_state=STATE_DIM, n_actions=actions_discretizedleval + 1,
                              gamma=GAMMA, eps=poison_eps, use_discretized_diff=use_discretized_diff)

    while not early_stop:
        # Per-step rollout buffers (one entry per PPO step)
        log_probs = []
        values = []
        states = []
        actions = []
        rewards = []
        masks = []
        next_states = []

        timestamp = datetime.now().strftime("%Y-%m-%d %I:%M:%S %p")
        print("Start training", timestamp, train_epoch)

        # Collecting experience needs no gradients: log-probs, values and
        # returns are all detached before the update.
        with torch.no_grad():
            for _ in range(PPO_STEPS):
                state_tensor = torch.tensor(state, dtype=torch.float32, device=device)
                dist, value = model(state_tensor)
                action = dist.sample()
                next_state, reward, terminated, truncated, info = envs.step(action.cpu().numpy())
                log_probs.append(dist.log_prob(action))
                values.append(value)
                rewards.append(torch.tensor(reward, dtype=torch.float32, device=device).unsqueeze(1))
                # NOTE(review): truncation is masked exactly like termination, so
                # GAE does not bootstrap past time-limit cut-offs.
                done = np.logical_or(terminated, truncated)
                masks.append(torch.tensor(1 - done, dtype=torch.float32, device=device).unsqueeze(1))
                states.append(state_tensor)
                actions.append(action)
                next_states.append(torch.tensor(next_state, dtype=torch.float32, device=device))
                state = next_state
                frame_idx += 1

            next_state_tensor = torch.tensor(state, dtype=torch.float32, device=device)
            _, next_value = model(next_state_tensor)

        # Attack reward modification if poisoning
        if Poison_flag:
            rewards = WalkerAttacker.learn(states, actions, rewards, next_states, masks)
        returns = compute_gae(next_value, rewards, masks, values)

        # Concatenate the per-step lists along the first dimension for the update.
        returns = torch.cat(returns).detach()
        log_probs = torch.cat(log_probs).detach()
        values = torch.cat(values).detach()
        states = torch.cat(states)
        actions = torch.cat(actions)
        advantage = normalize(returns - values)

        ppo_update(frame_idx, states, actions, log_probs, returns, advantage, Poison_flag)
        train_epoch += 1

        if train_epoch % TEST_EPOCHS == 0:
            test_reward = np.mean([test_env(env, model, device) for _ in range(NUM_TESTS)])
            writer.add_scalar("test_rewards", test_reward, frame_idx)
            print('train_epoch', train_epoch, test_reward)
            # Save a checkpoint for best reward
            if best_reward is None or best_reward < test_reward:
                best_reward = test_reward
                if Poison_flag:
                    name = "Poison_%s_diffactions_%s_eps_%.2f_rewards_%.0f.pth" % (
                        Poison_flag, use_discretized_diff, poison_eps, best_reward)
                else:
                    name = "%s_best_%+.3f_%d.pth" % (args.name, test_reward, frame_idx)
                fname = os.path.join('.', 'checkpoints', name)
                print(fname)
                torch.save(model.state_dict(), fname)
                best_model.load_state_dict(model.state_dict())
            if test_reward > TARGET_REWARD or train_epoch >= MAX_EPOCHS:
                if Poison_flag:
                    name = "Poison_%s_diffactions_%s_eps_%.2f_rewards_%.0f_%s.pth" % (
                        Poison_flag, use_discretized_diff, poison_eps, best_reward, time.strftime("%Y%m%d-%H%M%S"))
                else:
                    name = "%s_best_%+.3f_%d.pth" % (args.name, best_reward, frame_idx)
                # This directory is not created anywhere else; without this the
                # final save would fail after the whole training run.
                final_dir = os.path.join('.', 'final_best_checkpoints')
                os.makedirs(final_dir, exist_ok=True)
                fname = os.path.join(final_dir, name)
                print(fname)
                torch.save(best_model.state_dict(), fname)
                early_stop = True

    writer.close()
    env.close()

    # Save the final (last-update, not best) weights to
    # experiment/runs/<config-name>/model.pth.
    experiment_dir = os.path.join(os.getcwd(), 'experiment', 'runs')
    exp_config = "MAX_EPOCHS_%d_rewards_%.3f_hiddensize_%d" % (MAX_EPOCHS, best_reward, HIDDEN_SIZE)
    run_dir = os.path.join(experiment_dir, exp_config)
    # exist_ok: a rerun with an identical config must not crash after training
    # (it overwrites that directory's model.pth instead).
    os.makedirs(run_dir, exist_ok=True)
    model_path = os.path.join(run_dir, 'model.pth')
    torch.save(model.state_dict(), model_path)
    print(f"Model saved to {model_path}")