import argparse
import os
import gymnasium as gym
import numpy as np
import time
import torch
import torch.optim as optim
from tensorboardX import SummaryWriter
from lib.common import mkdir
from lib.model import ActorCritic
from lib.multiprocessing_env import SubprocVecEnv
from HopperAttacker import Attacker
from tqdm import trange
from datetime import datetime

# =======================
# Hyperparameters
# =======================
# Flip to True to swap in the tiny smoke-test values assigned at the bottom.
TEST_FLAG = False

# --- Environment and rollout collection ---
ENV_ID = "Hopper-v4"
NUM_ENVS = 16           # Number of parallel environments for training
PPO_STEPS = 512         # Number of steps per PPO update cycle

# --- Network and optimizer ---
HIDDEN_SIZE = 64        # Hidden size for ActorCritic network
LEARNING_RATE = 1e-3    # Learning rate for Adam optimizer

# --- Return / advantage estimation ---
GAMMA = 0.99            # Discount factor
GAE_LAMBDA = 0.95       # Lambda for Generalized Advantage Estimation

# --- PPO objective ---
PPO_EPSILON = 0.2       # Clipping epsilon for PPO
CRITIC_DISCOUNT = 0.5   # Critic loss scaling
ENTROPY_BETA = 0.01     # Entropy bonus coefficient
KL_BETA = 1.0           # KL divergence penalty coefficient
MINI_BATCH_SIZE = 128   # Size of mini-batches per PPO epoch
PPO_EPOCHS = 15         # Number of passes over buffer per update

# --- Evaluation and stopping ---
TEST_EPOCHS = 5         # Test frequency (every N update cycles)
NUM_TESTS = 5           # Number of test episodes for evaluation
MAX_EPOCHS = 400        # Max training epochs
TARGET_REWARD = 3400    # Early stopping reward threshold

# Smoke-test overrides: shrink the rollout and the run length.
if TEST_FLAG:
    NUM_ENVS, PPO_STEPS, MINI_BATCH_SIZE, MAX_EPOCHS = 2, 4, 4, 20

def make_env():
    """Build a zero-argument factory that creates one ``ENV_ID`` environment.

    A factory is returned instead of a live environment so that construction
    is deferred until the factory is actually called.
    """
    def _create_env():
        return gym.make(ENV_ID)

    return _create_env

def test_env(env, model, device, deterministic=True):
    """
    Run a single evaluation episode and return its total reward.

    The whole rollout runs under ``torch.no_grad()``: evaluation never
    backpropagates, so building an autograd graph for every step would only
    waste memory and time.

    Args:
        env: Environment whose ``reset()`` returns ``(state, info)`` and whose
            ``step(action)`` returns
            ``(next_state, reward, done, truncated, info)``.
        model: Callable mapping a batched state tensor to ``(dist, value)``,
            where ``dist`` exposes ``mean`` and ``sample()``.
        device: PyTorch device the state tensor is moved to.
        deterministic: If True act with the distribution mean, otherwise
            sample an action.
    Returns:
        Sum of the rewards collected until the episode is done or truncated.
    """
    total_reward = 0
    state, _ = env.reset()
    done = truncated = False
    with torch.no_grad():
        while not (done or truncated):
            state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
            dist, _ = model(state_tensor)
            action_tensor = dist.mean if deterministic else dist.sample()
            # detach() stays as a safety net in case the distribution hands
            # back a tensor that still requires grad (e.g. a raw parameter).
            action = action_tensor.detach().cpu().numpy()[0]
            state, reward, done, truncated, _ = env.step(action)
            total_reward += reward
    return total_reward


# Despite the "test_" prefix this is an evaluation helper, not a unit test;
# the marker keeps pytest from collecting it when the module is imported.
test_env.__test__ = False

def normalize(x):
    """
    Return a copy of ``x`` shifted to zero mean and scaled to unit std.

    The input tensor is left untouched: the result is a new tensor, so the
    caller's data is not silently overwritten and leaf tensors that require
    grad can be passed in. A small epsilon guards the division when the
    standard deviation is (close to) zero.

    Args:
        x: Floating-point tensor to standardize.
    Returns:
        New tensor with the same shape as ``x``.
    """
    return (x - x.mean()) / (x.std() + 1e-8)

def compute_gae(next_value, rewards, masks, values, gamma=GAMMA, lam=GAE_LAMBDA):
    """
    Compute returns via Generalized Advantage Estimation (GAE).

    Walks the rollout backwards, accumulating the discounted sum of TD
    residuals, and adds each step's value estimate back so that the result
    is a return target rather than an advantage.

    Args:
        next_value: Value estimate for the state following the last step.
        rewards: Sequence of per-step rewards.
        masks: Sequence of per-step masks (0 if done, 1 otherwise); a zero
            stops both the bootstrap and the accumulated advantage from
            leaking across an episode boundary.
        values: Sequence of per-step value estimates (same length as rewards).
        gamma: Discount factor.
        lam: GAE lambda.
    Returns:
        List of returns, one per step, in chronological order.
    """
    # Copy into a fresh list so the caller's sequence is never modified.
    values = list(values) + [next_value]
    gae = 0
    returns = []
    for step in reversed(range(len(rewards))):
        delta = rewards[step] + gamma * values[step + 1] * masks[step] - values[step]
        gae = delta + gamma * lam * masks[step] * gae
        # Append in reverse order and flip once at the end; inserting at the
        # front on every iteration would make the loop quadratic.
        returns.append(gae + values[step])
    returns.reverse()
    return returns

def ppo_iter(states, actions, log_probs, returns, advantage, mini_batch_size=None):
    """
    Yield random mini-batches for PPO update.

    Row indices are drawn uniformly with replacement, and the same indices
    are applied to every tensor so the rows of one mini-batch stay aligned.
    ``batch_size // mini_batch_size`` mini-batches are produced per call, so
    nothing is yielded when the rollout is smaller than one mini-batch.

    Args:
        states, actions, log_probs, returns, advantage: 2-D tensors sharing
            the same first dimension (the rollout size).
        mini_batch_size: Rows per mini-batch; defaults to MINI_BATCH_SIZE.
    Yields:
        Tuples ``(states, actions, log_probs, returns, advantage)`` restricted
        to the sampled rows.
    Raises:
        ValueError: If ``mini_batch_size`` is not positive.
    """
    if mini_batch_size is None:
        mini_batch_size = MINI_BATCH_SIZE
    if mini_batch_size <= 0:
        raise ValueError("mini_batch_size must be positive, got %r" % (mini_batch_size,))
    batch_size = states.size(0)
    for _ in range(batch_size // mini_batch_size):
        rand_ids = np.random.randint(0, batch_size, mini_batch_size)
        yield (states[rand_ids, :], actions[rand_ids, :], log_probs[rand_ids, :],
               returns[rand_ids, :], advantage[rand_ids, :])

def ppo_update(frame_idx, states, actions, log_probs, returns, advantages, attack_flag, clip_param=PPO_EPSILON, kl_coeff=KL_BETA):
    """
    Perform a PPO update step for policy and value networks.

    Runs PPO_EPOCHS passes over mini-batches drawn by ppo_iter(), optimizing
    the clipped surrogate objective plus the scaled critic loss minus the
    entropy bonus, then logs the per-mini-batch averages.

    Relies on the module-level ``model``, ``optimizer`` and ``writer`` objects
    created in the ``__main__`` section of this script.

    Args:
        frame_idx: Global step used as the x-axis value of the logged scalars.
        states, actions, log_probs, returns, advantages: Rollout tensors,
            forwarded unchanged to ppo_iter(); ``log_probs`` are the
            log-probabilities recorded when the actions were taken.
        attack_flag: When truthy, the combined loss is negated before
            backpropagation.
        clip_param: Half-width of the clipping interval around 1 applied to
            the probability ratio.
        kl_coeff: Currently unused; kept for interface compatibility.
    Raises:
        ValueError: If no mini-batch was produced, so there is nothing to
            average or log.
    """
    count_steps = 0
    sum_returns = 0.0
    sum_advantage = 0.0
    sum_loss_actor = 0.0
    sum_loss_critic = 0.0
    sum_entropy = 0.0
    sum_loss_total = 0.0

    for _ in range(PPO_EPOCHS):
        for state, action, old_log_probs, return_, advantage in ppo_iter(states, actions, log_probs, returns, advantages):
            dist, value = model(state)
            entropy = dist.entropy().mean()
            new_log_probs = dist.log_prob(action)
            ratio = torch.exp(new_log_probs - old_log_probs)
            surr1 = ratio * advantage
            surr2 = torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param) * advantage
            actor_loss = -torch.min(surr1, surr2).mean()
            critic_loss = (return_ - value).pow(2).mean()
            objective = CRITIC_DISCOUNT * critic_loss + actor_loss - ENTROPY_BETA * entropy
            loss = -objective if attack_flag else objective

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Accumulate plain Python floats: summing the tensors themselves
            # would keep a reference to every mini-batch's autograd graph
            # until the end of the update.
            sum_returns += return_.mean().item()
            sum_advantage += advantage.mean().item()
            sum_loss_actor += actor_loss.item()
            sum_loss_critic += critic_loss.item()
            sum_loss_total += loss.item()
            sum_entropy += entropy.item()
            count_steps += 1

    if count_steps == 0:
        raise ValueError(
            "ppo_update received no mini-batches; the rollout is smaller than one mini-batch"
        )

    writer.add_scalar("returns", sum_returns / count_steps, frame_idx)
    writer.add_scalar("advantage", sum_advantage / count_steps, frame_idx)
    writer.add_scalar("loss_actor", sum_loss_actor / count_steps, frame_idx)
    writer.add_scalar("loss_critic", sum_loss_critic / count_steps, frame_idx)
    writer.add_scalar("loss_total", sum_loss_total / count_steps, frame_idx)
    writer.add_scalar("entropy", sum_entropy / count_steps, frame_idx)

if __name__ == "__main__":
    # Entry point: train a PPO agent on ENV_ID (optionally with poisoned
    # rewards), checkpoint the best-evaluated model, and export the final one.
    mkdir('', 'checkpoints')
    parser = argparse.ArgumentParser()
    parser.add_argument("-n", "--name", default=ENV_ID, help="Name of the run")
    parser.add_argument("-p", "--poison", action="store_true", help="Enable reward poisoning for backdoor attack")
    # NOTE(review): this option is parsed but never consumed below; the
    # attacker is configured from the hard-coded actions_discretizedleval.
    parser.add_argument("-d", "--discretizedleval", default=8, type=int, help="Action discretization level when poisoning")
    args = parser.parse_args()
    Poison_flag = args.poison
    # ppo_update() negates its loss when this flag is truthy.
    # NOTE(review): the name was previously referenced without ever being
    # defined (NameError on the first update). False keeps the regular,
    # non-negated PPO loss - confirm this is the intended setting.
    Attack_flag = False

    writer = SummaryWriter("logs-{}".format(args.name))

    # Autodetect CUDA. The second GPU is preferred, but hosts with a single
    # GPU fall back to the first one instead of failing on a missing device.
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
    else:
        device = torch.device("cpu")
    print('Device:', device)

    # Create parallel environments for training plus one for evaluation
    envs = [make_env() for _ in range(NUM_ENVS)]
    envs = SubprocVecEnv(envs)
    env = gym.make(ENV_ID)
    num_inputs  = envs.observation_space.shape[0]
    num_outputs = envs.action_space.shape[0]

    # `model`, `optimizer` and `writer` are read as globals by ppo_update().
    model = ActorCritic(num_inputs, num_outputs, HIDDEN_SIZE).to(device)
    best_model = ActorCritic(num_inputs, num_outputs, HIDDEN_SIZE).to(device)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    frame_idx  = 0
    train_epoch = 0
    best_reward = None
    state = envs.reset()
    early_stop = False

    # Attacker setup
    actions_discretizedleval = 16
    eps = 0.5
    use_discretized_diff = False
    HopperAttacker = Attacker(dim_state=11, n_actions=actions_discretizedleval + 1, gamma=GAMMA, eps=eps, use_discretized_diff=use_discretized_diff)

    while not early_stop:
        # Buffers for rollout data
        log_probs = []
        values = []
        states = []
        actions = []
        rewards = []
        masks = []
        next_states = []

        timestamp = datetime.now().strftime("%Y-%m-%d %I:%M:%S %p")
        print("Start training", timestamp)
        for _ in trange(PPO_STEPS):
            state_tensor = torch.tensor(state, dtype=torch.float32, device=device)
            # Rollout statistics are detached before the update anyway, so
            # the forward pass does not need an autograd graph.
            with torch.no_grad():
                dist, value = model(state_tensor)
                action = dist.sample()
                log_prob = dist.log_prob(action)
            next_state, reward, done, _ = envs.step(action.cpu().numpy())
            log_probs.append(log_prob)
            values.append(value)
            rewards.append(torch.FloatTensor(reward).unsqueeze(1).to(device))
            # Mask is 0 where an episode ended, 1 otherwise.
            masks.append(torch.FloatTensor(1 - done).unsqueeze(1).to(device))
            states.append(state_tensor)
            actions.append(action)
            next_state_tensor = torch.FloatTensor(next_state).to(device)
            next_states.append(next_state_tensor)
            state = next_state
            frame_idx += 1

        # Bootstrap value for the state that follows the last collected step
        next_state_tensor = torch.tensor(state, dtype=torch.float32, device=device)
        with torch.no_grad():
            _, next_value = model(next_state_tensor)

        # Reward poisoning/backdoor attacks (if enabled)
        if Poison_flag:
            timestamp = datetime.now().strftime("%Y-%m-%d %I:%M:%S %p")
            print("Bilevel Attack", timestamp)
            rewards = HopperAttacker.learn(states, actions, rewards, next_states, masks)
        else:
            print("Normal training (no attack)")
        returns = compute_gae(next_value, rewards, masks, values)

        # Prepare mini-batches
        returns   = torch.cat(returns).detach()
        log_probs = torch.cat(log_probs).detach()
        values    = torch.cat(values).detach()
        states    = torch.cat(states)
        actions   = torch.cat(actions)
        advantage = returns - values
        advantage = normalize(advantage)

        ppo_update(frame_idx, states, actions, log_probs, returns, advantage, Attack_flag)
        train_epoch += 1

        if train_epoch % TEST_EPOCHS == 0:
            test_reward = np.mean([test_env(env, model, device) for _ in range(NUM_TESTS)])
            writer.add_scalar("test_rewards", test_reward, frame_idx)
            print('train_epoch', train_epoch)
            # Save a checkpoint for best reward
            if best_reward is None or best_reward < test_reward:
                best_reward = test_reward
                if Poison_flag:
                    name = "Poison_%s_diffactions_%s_eps_%.2f_rewards_%.0f.pth" % (Poison_flag, use_discretized_diff, eps, best_reward)
                else:
                    name = "%s_best_%+.3f_%d.pth" % (args.name, test_reward, frame_idx)
                fname = os.path.join('.', 'checkpoints', name)
                print(fname)
                torch.save(model.state_dict(), fname)
                best_model.load_state_dict(model.state_dict())
            # The stop condition is only evaluated on test epochs.
            if test_reward > TARGET_REWARD or train_epoch >= MAX_EPOCHS:
                if Poison_flag:
                    name = "Poison_%s_diffactions_%s_eps_%.2f_rewards_%.0f_%s.pth" % (Poison_flag, use_discretized_diff, eps, best_reward, time.strftime("%Y%m%d-%H%M%S"))
                else:
                    name = "%s_best_%+.3f_%d.pth" % (args.name, test_reward, frame_idx)
                # This directory is not created anywhere else, so make sure
                # it exists before writing into it.
                final_dir = os.path.join('.', 'final_best_checkpoints')
                os.makedirs(final_dir, exist_ok=True)
                fname = os.path.join(final_dir, name)
                print(fname)
                torch.save(best_model.state_dict(), fname)
                early_stop = True

    # Save final model to experiment folder
    experiment_dir = os.path.join(os.getcwd(), 'experiment', 'runs')

    if Poison_flag:
        exp_config = "Poison_%s_diffactions_%s_eps_%.2f_rewards_%s_%s" % (Poison_flag, use_discretized_diff, eps, best_reward, time.strftime("%Y%m%d-%H%M%S"))
    else:
        exp_config = "MAX_EPOCHS_%d_rewards_%.0f" % (MAX_EPOCHS, best_reward)
    run_dir = os.path.join(experiment_dir, exp_config)
    # exist_ok: the non-poison run name has no timestamp, so two runs that
    # reach the same rounded reward map to the same directory.
    os.makedirs(run_dir, exist_ok=True)
    model_path = os.path.join(run_dir, 'model.pth')
    torch.save(model.state_dict(), model_path)
    print(f"Model saved to {model_path}")

    # Flush pending TensorBoard events and release the evaluation environment.
    writer.close()
    env.close()