import time
import os
import gymnasium as gym
import torch
import hydra
import wandb
from omegaconf import DictConfig
from torch.optim import AdamW
import torch.nn.functional as F
from collections import deque
import torch.nn as nn
import sys
from itertools import cycle
# Add the parent directory to sys.path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# from src.agent_networks.gail_networks import ValueNetwork
from src.opt_algos.expert_sampling import get_full_mujoco_trajectories, get_sampled_state_values, create_dataloader_from_trajectories
from bathc_TD import *
from PPO.agent import PPOAgent
from src.utils.compute_bellman_gap import evaluate_bellman_gap
from torch.distributions import Distribution, Independent, Normal
import random


# Global seed for Python's `random` module and PyTorch (CPU and CUDA).
# NOTE(review): NumPy and the environment's own RNG are not seeded here.
seed = 6
random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    # `torch.cuda.manual_seed` only seeds the current GPU; `manual_seed_all`
    # seeds every visible GPU, which is what multi-GPU setups need.
    torch.cuda.manual_seed_all(seed)


def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
    """Build a diagonal Gaussian from a ``(loc, scale)`` pair.

    The last dimension of the per-element ``Normal`` is reinterpreted as the
    event dimension, so ``log_prob`` sums over it.
    """
    mean, std = loc_scale
    per_dim_normal = Normal(mean, std)
    return Independent(per_dim_normal, reinterpreted_batch_ndims=1)


def _make_env(cfg: DictConfig, device: torch.device):
    """Create the environment described by ``cfg.env``.

    When ``cfg.env.gym`` is falsy, the ``gym`` key is deleted from the config
    and ``cfg.env`` is passed to ``hydra.utils.instantiate`` together with
    ``device``. Otherwise ``cfg.env._target_`` is used as a Gymnasium id and the
    environment is reset once after creation.
    """
    if not cfg.env.gym:
        # `gym` is a routing flag for this script; presumably the target
        # constructor does not accept it, hence the deletion.
        del cfg.env.gym
        return hydra.utils.instantiate(cfg.env, device=device)
    env = gym.make(cfg.env._target_)
    env.reset()
    return env


def _make_agent(cfg: DictConfig, env) -> PPOAgent:
    """Build a PPOAgent for ``env`` with the fixed hyperparameters of this script.

    Used both for the expert policy and for each run's fresh critic.
    """
    return PPOAgent(
        cfg.env._target_,
        env.observation_space.shape,
        env.action_space.shape[0],
        alpha=3e-4,
        n_epochs=50000,
        batch_size=64,
    )


def _load_checkpoint_entry(path: str, key: str):
    """Return the entry stored under ``key`` in the ``torch.save``-d dict at ``path``."""
    return torch.load(path)[key]


def _stack_test_state_values(test_sv, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
    """Turn a sequence of ``(state, true_value)`` pairs into evaluation tensors.

    Returns the stacked states and a float32 vector of the true values, both
    on ``device`` so they can be fed to / compared with the critic directly.
    """
    states = torch.stack([sample[0] for sample in test_sv]).to(device)
    true_values = torch.tensor(
        [sample[1] for sample in test_sv], dtype=torch.float32
    ).to(device)
    return states, true_values


_SUPPORTED_ALGOS = ("TD0", "TD0_Inner_Loop", "TD0_Buffer", "TD0_Alpha")


@hydra.main(version_base="1.3", config_path="configs", config_name="halfcheetah_online")
def main(cfg: DictConfig) -> None:
    """
    Main entry point for training a value network with batch TD(0) variants.

    Loads pre-generated test data from ``checkpoints/<cfg.env._target_>`` next
    to this script, then performs five independent runs. Each run trains a
    fresh critic with the algorithm selected by ``cfg.algo.algo`` and prints,
    every epoch, the TD loss, the step size ``alpha`` and the MSE against the
    stored test state values (plus the Bellman gap every 5th epoch).

    Raises:
        ValueError: if ``cfg.algo.algo`` is not one of ``_SUPPORTED_ALGOS``.
    """
    algo = cfg.algo.algo
    # Fail fast: an unknown name would otherwise leave `loss`/`alpha` unbound.
    if algo not in _SUPPORTED_ALGOS:
        raise ValueError(
            f"Unknown cfg.algo.algo={algo!r}; expected one of {_SUPPORTED_ALGOS}"
        )

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Pre-generated data lives in checkpoints/<env target>/ next to this script.
    experts_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "checkpoints")
    save_dir = os.path.join(experts_path, cfg.env._target_)
    os.makedirs(save_dir, exist_ok=True)

    env = _make_env(cfg, device)

    # Expert weights are loaded from the parent of this script's directory.
    expert = _make_agent(cfg, env)
    expert_abs_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    if os.path.exists(expert_abs_path):
        expert.load_models(expert_abs_path)

    hp = cfg.training_hyperparams
    num_traj = hp.num_traj
    num_sa_pairs = hp.num_sa_pairs
    gamma = hp.gamma

    # The test files are produced offline with `get_sampled_state_values` /
    # `get_full_mujoco_trajectories` and saved via `torch.save` under the keys
    # read below. `train_trajectories.pt` is not loaded: none of the update
    # routines below receive it.
    test_sv = _load_checkpoint_entry(
        os.path.join(save_dir, "test_state_values.pt"), "test_state_values"
    )
    test_traj = _load_checkpoint_entry(
        os.path.join(save_dir, "test_trajectories.pt"), "test_trajectories"
    )

    # Loop-invariant: build the evaluation tensors once, not every epoch.
    test_states, test_true_values = _stack_test_state_values(test_sv, device)

    for run_id in range(5):
        env.reset()
        value_net = _make_agent(cfg, env).critic
        optimizer = AdamW(value_net.parameters(), lr=hp.learning_rate)

        buffer = None
        if algo == "TD0_Buffer":
            # Pre-fill a fixed-size buffer of freshly sampled trajectory batches.
            buffer_size = hp.buffer_size
            buffer = deque(maxlen=buffer_size)
            for _ in range(buffer_size):
                buffer.append(
                    get_full_mujoco_trajectories(
                        env, expert, num_traj, num_sa_pairs, gamma, device
                    )
                )

        start_time = time.time()
        total_gradient_steps = 0
        total_td_update_time = 0.0  # cumulative wall time spent in TD(0) updates

        for epoch in range(hp.epochs):
            td_update_start_time = time.time()
            if algo == "TD0":
                loss, alpha = batch_td0_update(
                    env=env,
                    expert=expert,
                    value_network=value_net,
                    optimizer=optimizer,
                    num_sa_pairs=num_sa_pairs,
                    step_size=hp.step_size,
                    horizon=hp.horizon,
                    num_traj=num_traj,
                    device=device,
                    gamma=gamma,
                )
                total_gradient_steps += 1

            elif algo == "TD0_Inner_Loop":
                inner_steps = cfg.inner_steps.steps
                loss, alpha = batch_td0_update_with_inner_loop(
                    env=env,
                    expert=expert,
                    value_network=value_net,
                    optimizer_type=cfg.Optimizer.opt,
                    learning_rate=cfg.Optimizer.lr,
                    num_sa_pairs=num_sa_pairs,
                    step_size=hp.step_size,
                    horizon=hp.horizon,
                    num_traj=num_traj,
                    device=device,
                    scheduler_type=cfg.Optimizer.scheduler,
                    gamma=gamma,
                    inner_steps=inner_steps,
                )
                # Count the same number of inner steps that was passed above.
                total_gradient_steps += inner_steps

            elif algo == "TD0_Buffer":
                inner_steps = cfg.inner_steps.steps
                loss, alpha = batch_td0_update_with_buffer(
                    env=env,
                    expert=expert,
                    value_network=value_net,
                    optimizer_type=cfg.Optimizer.opt,
                    learning_rate=cfg.Optimizer.lr,
                    buffer=buffer,
                    num_sa_pairs=num_sa_pairs,
                    step_size=hp.step_size,
                    horizon=hp.horizon,
                    num_traj=num_traj,
                    device=device,
                    scheduler_type=cfg.Optimizer.scheduler,
                    gamma=gamma,
                    inner_steps=inner_steps,
                )
                total_gradient_steps += inner_steps

            elif algo == "TD0_Alpha":
                loss, alpha, num_steps = batch_td0_alpha(
                    env=env,
                    expert=expert,
                    value_network=value_net,
                    optimizer_type=cfg.Optimizer.opt,
                    learning_rate=cfg.Optimizer.lr,
                    num_sa_pairs=num_sa_pairs,
                    step_size=hp.step_size,
                    horizon=hp.horizon,
                    num_traj=num_traj,
                    device=device,
                    scheduler_type=cfg.Optimizer.scheduler,
                    gamma=gamma,
                    inner_steps=cfg.inner_steps.steps,
                )
                # This variant reports how many gradient steps it actually took.
                total_gradient_steps += num_steps

            total_td_update_time += time.time() - td_update_start_time

            # Evaluation only: no autograd graph is needed for the value MSE.
            with torch.no_grad():
                pred_values = value_net(test_states).squeeze(-1)
                avg_value_error = F.mse_loss(pred_values, test_true_values)

            if epoch % 5 == 0:
                bellman_gap = evaluate_bellman_gap(value_net, test_traj, gamma, AdamW, 1e-4, 500)
                print(f"Epoch {epoch}: Loss = {loss:.4f}, Alpha = {alpha:.4f}, Bellman Gap = {bellman_gap:.4f}, Avg Value Error = {avg_value_error:.4f}")
            else:
                print(f"Epoch {epoch}: Loss = {loss:.4f}, Alpha = {alpha:.4f}, Avg Value Error = {avg_value_error:.4f}")

        print(
            f"Run {run_id} finished: {total_gradient_steps} gradient steps, "
            f"{total_td_update_time:.2f}s in TD updates, "
            f"{time.time() - start_time:.2f}s total"
        )

    # Close once after all runs; each run re-uses the env via env.reset().
    env.close()

if __name__ == "__main__":
    # `main` is wrapped by `@hydra.main`, which supplies `cfg`, so it is
    # called here without arguments.
    main()
