import os
import gym
import torch
import hydra
from omegaconf import DictConfig
from torch.optim import AdamW, SGD
import torch.nn.functional as F
from collections import deque
# from minatar import Environment
import copy

# from src.agent_networks.gail_networks import ValueNetwork
from src.opt_algos.expert_sampling import get_full_expert_trajectories, get_full_minatar_trajectories, get_full_mujoco_trajectories

@hydra.main(version_base="1.3", config_path="configs", config_name="td0_cartpole.yaml")
def main(cfg: DictConfig) -> None:
    """
    Main entry point for training with batch TD(0).

    Steps performed:
      1. Build the environment: Hydra-instantiated when ``cfg.env.gym`` is
         falsy, otherwise ``gym.make(cfg.env._target_)``.
      2. Instantiate ``cfg.expert_net`` and load its ``pi`` weights from
         ``experts/<cfg.env._target_>/policy.ckpt`` next to this file.
      3. Fill a deque with ``buffer_size`` batches returned by
         ``get_full_expert_trajectories``.
      4. Call ``batch_td0_update_with_buffer`` once per epoch and print the loss.
      5. Save the value network weights to ``<cfg.models_path>/value_net.pt``.

    NOTE(review): as written, this entry point does not match the rest of the
    file and appears unable to run to completion:
      * ``state_dim``, ``action_dim`` and ``discrete`` are only bound in the
        gym branch, so the non-gym branch hits a NameError at the expert setup.
      * ``batch_td0_update_with_buffer`` has no ``optimizer`` parameter. It
        requires ``optimizer_type``, ``learning_rate``, ``step_size`` and
        ``num_traj``, and it returns a ``(loss, alpha)`` tuple that the
        ``{loss:.4f}`` format below cannot render.
      * Buffer entries here are 4-tuples
        ``(states, actions, rewards, next_states)``, whereas
        ``batch_td0_update_with_buffer`` unpacks each entry as a sequence of
        per-sample 5-tuples.
    TODO: reconcile this call with the update function's current signature.

    :param cfg: DictConfig configuration composed by Hydra.
    :return: None
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    
    # Environment setup
    if not cfg.env.gym:
        # Presumably removed so hydra.utils.instantiate does not forward
        # `gym` as a constructor keyword argument.
        del cfg.env.gym
        env = hydra.utils.instantiate(cfg.env, device=device)
    else:
        discrete = cfg.env.discrete
        # For gym environments, `_target_` is used as the registered env id.
        env = gym.make(cfg.env._target_, render_mode="rgb_array")
        env.reset()
        if(discrete):
            action_dim = env.action_space.n
        else:
            action_dim = env.action_space.shape[0]
        # Observation dimensionality = length of the observation upper-bound vector.
        state_dim = len(env.observation_space.high)

    # Expert network setup
    expert = hydra.utils.instantiate(cfg.expert_net, state_dim, action_dim, discrete, device).to(device)
    experts_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "experts", cfg.env._target_)
    expert_weights = os.path.join(experts_path, 'policy.ckpt')
    # NOTE: torch.load unpickles the checkpoint; only load trusted files.
    expert_state_dict = torch.load(expert_weights, map_location=device)
    expert.pi.load_state_dict(expert_state_dict)

    # Value network and optimizer setup
    value_net = hydra.utils.instantiate(cfg.value_net, state_dim).to(device)
    optimizer = AdamW(value_net.parameters(), lr=cfg.training_hyperparams.learning_rate)

    num_sa_pairs = cfg.training_hyperparams.num_sa_pairs
    horizon = cfg.training_hyperparams.horizon
    buffer_size = cfg.training_hyperparams.buffer_size
    # Pre-fill the buffer with `buffer_size` expert batches; maxlen makes the
    # deque drop its oldest entry on any later append.
    buffer = deque(maxlen=buffer_size)
    for _ in range(buffer_size):
        states, actions, rewards, next_states = get_full_expert_trajectories(
            env, expert, num_sa_pairs, horizon, device
        )
        buffer.append((states, actions, rewards, next_states))
    # Training loop
    for epoch in range(cfg.training_hyperparams.epochs):
        # loss = batch_td0_update_with_inner_loop(
        # loss = batch_td0_update(
        #     env=env,
        #     expert=expert,
        #     value_network=value_net,
        #     optimizer=optimizer,
        #     num_sa_pairs=cfg.training_hyperparams.num_sa_pairs,
        #     horizon=cfg.training_hyperparams.horizon,
        #     device=device,
        #     gamma=cfg.training_hyperparams.gamma,
        # )
        # NOTE(review): the keyword arguments below do not match the current
        # signature of batch_td0_update_with_buffer (see docstring).
        loss = batch_td0_update_with_buffer(
            env=env,
            expert=expert,
            value_network=value_net,
            optimizer=optimizer,
            buffer=buffer,
            num_sa_pairs=cfg.training_hyperparams.num_sa_pairs,
            horizon=cfg.training_hyperparams.horizon,
            device=device,
            gamma=cfg.training_hyperparams.gamma,
        )
        print(f"Epoch {epoch}: TD(0) Update Loss = {loss:.4f}")

    # Save the value network after training
    torch.save(value_net.state_dict(), os.path.join(cfg.models_path, "value_net.pt"))

    # Clean up the environment
    env.close()

# def batch_td0_update(
#     env: gym.Env,
#     expert: torch.nn.Module,
#     value_network: torch.nn.Module,
#     optimizer: torch.optim.Optimizer,
#     num_sa_pairs: int,
#     horizon: int,
#     num_traj: int,
#     device: torch.device,
#     gamma: float = 0.99,
# ):
#     value_network.train()

#     # Collect trajectories
#     expert_trajectories = get_full_expert_trajectories(
#         env, expert, num_traj,num_sa_pairs, gamma, device
#     )

#     # Unpack the trajectories into separate lists
#     states, actions, rewards, next_states, values = zip(*expert_trajectories)

#     # Concatenate lists into tensors
#     states = torch.stack(states)
#     actions = torch.stack(actions)
#     rewards = torch.stack(rewards)
#     next_states = torch.stack(next_states)
#     values = torch.stack(values)

#     # Compute TD(0) targets
#     with torch.no_grad():
#         value_next_states = value_network(next_states).squeeze(-1)
#         td_target = rewards + gamma * value_next_states

#     # Get current value estimates
#     value_states = value_network(states).squeeze(-1)

#     # Compute loss and perform backpropagation
#     loss = F.mse_loss(value_states, td_target)
#     optimizer.zero_grad()
#     loss.backward()
#     optimizer.step()

#     return loss.item()


# def batch_td0_update_with_inner_loop(
#     env: gym.Env,
#     expert: torch.nn.Module,
#     value_network: torch.nn.Module,
#     optimizer: torch.optim.Optimizer,
#     num_sa_pairs: int,
#     horizon: int,
#     num_traj: int,
#     device: torch.device,
#     gamma: float = 0.99,
#     inner_steps: int = 10,
# ):
#     value_network.train()

#     # Collect trajectories
#     expert_trajectories = get_full_expert_trajectories(
#         env, expert, num_traj ,num_sa_pairs, gamma, device
#     )

#     # Unpack the trajectories into separate lists
#     states, actions, rewards, next_states, values = zip(*expert_trajectories)

#     # Concatenate lists into tensors
#     states = torch.stack(states)
#     actions = torch.stack(actions)
#     rewards = torch.stack(rewards)
#     next_states = torch.stack(next_states)
#     values = torch.stack(values)

#     # Compute TD(0) targets
#     with torch.no_grad():
#         value_next_states = value_network(next_states).squeeze(-1)
#         td_target = rewards + gamma * value_next_states

#     value_states = value_network(states).squeeze(-1)
#     init_loss = F.mse_loss(value_states, td_target)
#     # print(f"Initial loss: {init_loss.item():.4f}")
#     for step in range(inner_steps):
#         # Get current value estimates
#         value_states = value_network(states).squeeze(-1)

#         # Compute loss
#         loss = F.mse_loss(value_states, td_target)

#         # Perform backpropagation and optimization step
#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()

#         # print(f"Inner Step {step + 1}/{inner_steps}, Loss: {loss.item():.4f}")

#     # print(f"Alpha value: {loss.item()/init_loss.item():.4f}")
#     return init_loss.item()

    
# def batch_td0_update_with_buffer(
#     env: gym.Env,
#     expert: torch.nn.Module,
#     value_network: torch.nn.Module,
#     optimizer: torch.optim.Optimizer,
#     buffer: deque,
#     num_sa_pairs: int,
#     horizon: int,
#     num_traj: int,
#     device: torch.device,
#     gamma: float = 0.99,
#     inner_steps: int = 10,
# ):
#     value_network.train()

#     # Collect trajectories
#     expert_trajectories = get_full_expert_trajectories(
#         env, expert, num_traj,num_sa_pairs, gamma, device
#     )

#     # Unpack the trajectories into separate lists
#     states, actions, rewards, next_states, values = zip(*expert_trajectories)

#     # Concatenate lists into tensors
#     states = torch.stack(states)
#     actions = torch.stack(actions)
#     rewards = torch.stack(rewards)
#     next_states = torch.stack(next_states)
#     values = torch.stack(values)

#     # Compute TD(0) targets
#     v_t_buffers = []
#     with torch.no_grad():
#         v_t = value_network(states).squeeze(-1)
#         v_t_next = value_network(next_states).squeeze(-1)
#         td_target = rewards + gamma * v_t_next
#         F_v_t = v_t - td_target
#         for trajectory in buffer:
#             buffer_states, _, _, _, _ = zip(*trajectory)
#             buffer_states = torch.stack(buffer_states)
#             v_t_buffers.append(value_network(buffer_states).squeeze(-1))
    
#     init_loss = F.mse_loss(v_t, td_target)

#     for step, trajectory in enumerate(buffer):
#         buffer_states, _, _, _, _ = zip(*trajectory)
#         buffer_states = torch.stack(buffer_states)
        
#         v = value_network(states).squeeze(-1)
#         v_buffer = value_network(buffer_states).squeeze(-1)
#         v_t_buffer = v_t_buffers[step]
        
#         td_term = torch.dot(F_v_t, v - v_t)
#         reg_term = 0.5 * torch.norm(v_buffer - v_t_buffer, p=2)
#         surrogate_loss = td_term + reg_term
                        
#         # Perform the optimization step
#         optimizer.zero_grad()
#         surrogate_loss.backward()
#         optimizer.step()
    
#     # Append the new batch into the buffer and remove the oldest if necessary
#     buffer.append(list(zip(states, actions, rewards, next_states, values)))
#     if len(buffer) > inner_steps:
#         buffer.popleft()
    
#     return init_loss.item()


def batch_td0_update(
    env: gym.Env,
    expert: torch.nn.Module,
    value_network: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    num_sa_pairs: int,
    step_size: float,
    horizon: int,
    num_traj: int,
    device: torch.device,
    loader = None,
    gamma: float = 0.99,
):
    """
    Perform one semi-gradient batch TD(0) update of ``value_network``.

    A batch of transitions is taken from ``loader`` when given, otherwise it is
    collected with ``get_full_mujoco_trajectories``. The TD(0) target
    ``r + gamma * V(s')`` is held fixed (no gradient), and a single optimizer
    step is taken on ``step_size**2 * MSE(V(s), target)``.

    :param env: Environment used to collect trajectories when ``loader`` is None.
    :param expert: Policy used to collect trajectories when ``loader`` is None.
    :param value_network: Value network being trained (updated in place).
    :param optimizer: Optimizer over ``value_network``'s parameters.
    :param num_sa_pairs: Forwarded to ``get_full_mujoco_trajectories``.
    :param step_size: Its square scales the loss before backpropagation.
    :param horizon: Unused here; kept for a uniform signature.
    :param num_traj: Forwarded to ``get_full_mujoco_trajectories``.
    :param device: Forwarded to ``get_full_mujoco_trajectories``.
    :param loader: Optional iterator yielding
        ``(states, actions, rewards, next_states, values)`` batches.
    :param gamma: Discount factor.
    :return: ``(td_loss, alpha)``. ``td_loss`` is the unscaled TD(0) MSE before
        the step. ``alpha`` is the MSE after the step divided by the MSE before
        it, or ``0.0`` when the pre-step MSE is already zero.
    """
    value_network.train()

    if loader is not None:
        states, _, rewards, next_states, _ = next(loader)
    else:
        # Collect trajectories and stack the per-sample tensors into batches.
        expert_trajectories = get_full_mujoco_trajectories(
            env, expert, num_traj, num_sa_pairs, gamma, device
        )
        states, _, rewards, next_states, _ = zip(*expert_trajectories)
        states = torch.stack(states)
        rewards = torch.stack(rewards)
        next_states = torch.stack(next_states)
    # Flatten so an (N, 1) reward tensor cannot broadcast against the (N,)
    # value predictions into an (N, N) target.
    rewards = rewards.reshape(-1)

    # TD(0) targets are treated as constants (semi-gradient update).
    with torch.no_grad():
        td_target = rewards + gamma * value_network(next_states).reshape(-1)

    td_loss = F.mse_loss(value_network(states).reshape(-1), td_target)
    optimizer.zero_grad()
    # The step_size**2 factor only rescales the gradient fed to the optimizer.
    (step_size**2 * td_loss).backward()
    optimizer.step()

    # Re-evaluate against the same fixed targets to measure the contraction.
    with torch.no_grad():
        new_td_loss = F.mse_loss(value_network(states).reshape(-1), td_target).item()

    td_loss_value = td_loss.item()
    alpha = new_td_loss / td_loss_value if td_loss_value > 0 else 0.0
    return td_loss_value, alpha


def batch_td0_update_with_inner_loop(
    env: gym.Env,
    expert: torch.nn.Module,
    value_network: torch.nn.Module,
    optimizer_type: str,
    learning_rate: float,
    num_sa_pairs: int,
    step_size: float,
    horizon: int,
    num_traj: int,
    device: torch.device,
    loader = None,
    scheduler_type: str = None,
    gamma: float = 0.99,
    inner_steps: int = 10,
):
    """
    Approximate one TD(0) step in value space by regression with an inner loop.

    The proximal target ``v_t - step_size * (v_t - (r + gamma * V(s')))`` is
    computed once with the current parameters (no gradient). A freshly built
    optimizer then regresses ``V(s)`` onto it with an MSE loss for
    ``inner_steps + 1`` updates.

    :param env: Environment used to collect trajectories when ``loader`` is None.
    :param expert: Policy used to collect trajectories when ``loader`` is None.
    :param value_network: Value network being trained (updated in place).
    :param optimizer_type: ``'AdamW'`` or ``'SGD'``; a new optimizer is built per call.
    :param learning_rate: Learning rate for that optimizer.
    :param num_sa_pairs: Forwarded to ``get_full_mujoco_trajectories``.
    :param step_size: TD(0) step size in value space; must be non-zero.
    :param horizon: Unused here; kept for a uniform signature.
    :param num_traj: Forwarded to ``get_full_mujoco_trajectories``.
    :param device: Forwarded to ``get_full_mujoco_trajectories``.
    :param loader: Optional iterator yielding
        ``(states, actions, rewards, next_states, values)`` batches.
    :param scheduler_type: ``'cos'``, ``'exp'``, or anything else for no scheduler.
    :param gamma: Discount factor.
    :param inner_steps: The inner loop performs ``inner_steps + 1`` updates.
    :return: ``(td_loss, alpha)``:
        * ``td_loss`` is the initial proximal loss divided by ``step_size**2``.
          This equals the TD(0) MSE when predictions are deterministic.
        * ``alpha`` is the proximal loss measured just before the last update,
          divided by the initial proximal loss (``0.0`` if that is zero).
    :raises ValueError: On an unsupported ``optimizer_type`` or a zero ``step_size``.
    """
    optimizer_classes = {'AdamW': AdamW, 'SGD': SGD}
    if optimizer_type not in optimizer_classes:
        raise ValueError(
            f"Unsupported optimizer_type {optimizer_type!r}; expected 'AdamW' or 'SGD'"
        )
    if step_size == 0:
        # The returned loss is normalised by step_size**2.
        raise ValueError("step_size must be non-zero")

    value_network.train()
    optimizer = optimizer_classes[optimizer_type](value_network.parameters(), lr=learning_rate)
    scheduler = None
    if scheduler_type == 'cos':
        # NOTE(review): eta_min equals the base learning rate, so this cosine
        # schedule leaves the learning rate unchanged — confirm intent.
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=inner_steps, eta_min=learning_rate
        )
    elif scheduler_type == 'exp':
        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.995)

    if loader is not None:
        states, _, rewards, next_states, _ = next(loader)
    else:
        # Collect trajectories and stack the per-sample tensors into batches.
        expert_trajectories = get_full_mujoco_trajectories(
            env, expert, num_traj, num_sa_pairs, gamma, device
        )
        states, _, rewards, next_states, _ = zip(*expert_trajectories)
        states = torch.stack(states)
        rewards = torch.stack(rewards)
        next_states = torch.stack(next_states)
    # Flatten so an (N, 1) reward tensor cannot broadcast against the (N,)
    # value predictions into an (N, N) target.
    rewards = rewards.reshape(-1)

    # Everything below is fixed during the inner loop, so no gradients are needed.
    with torch.no_grad():
        td_target = rewards + gamma * value_network(next_states).reshape(-1)
        v_t = value_network(states).reshape(-1)
        F_t = v_t - td_target
        prox_target = v_t - step_size * F_t
        init_loss = F.mse_loss(value_network(states).reshape(-1), prox_target).item()

    alpha = 1
    for _ in range(inner_steps + 1):
        loss = F.mse_loss(value_network(states).reshape(-1), prox_target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if scheduler is not None:
            scheduler.step()
        # Ratio of the loss measured before this update to the initial loss.
        alpha = loss.item() / init_loss if init_loss > 0 else 0.0

    return init_loss / (step_size**2), alpha

def batch_td0_alpha(
    env: gym.Env,
    expert: torch.nn.Module,
    value_network: torch.nn.Module,
    optimizer_type: str,
    learning_rate: float,
    num_sa_pairs: int,
    step_size: float,
    horizon: int,
    num_traj: int,
    device: torch.device,
    loader = None,
    scheduler_type: str = None,
    gamma: float = 0.99,
    inner_steps: int = 10,
):
    """
    Like ``batch_td0_update_with_inner_loop``, but stop once the loss has halved.

    The proximal target ``v_t - step_size * (v_t - (r + gamma * V(s')))`` is
    computed once with the current parameters (no gradient). ``V(s)`` is then
    regressed onto it until the pre-update loss ratio drops to ``0.5`` or below,
    or until ``inner_steps + 1`` updates have been made.

    :param env: Environment used to collect trajectories when ``loader`` is None.
    :param expert: Policy used to collect trajectories when ``loader`` is None.
    :param value_network: Value network being trained (updated in place).
    :param optimizer_type: ``'AdamW'`` or ``'SGD'``; a new optimizer is built per call.
    :param learning_rate: Learning rate for that optimizer.
    :param num_sa_pairs: Forwarded to ``get_full_mujoco_trajectories``.
    :param step_size: TD(0) step size in value space; must be non-zero.
    :param horizon: Unused here; kept for a uniform signature.
    :param num_traj: Forwarded to ``get_full_mujoco_trajectories``.
    :param device: Forwarded to ``get_full_mujoco_trajectories``.
    :param loader: Optional iterator yielding
        ``(states, actions, rewards, next_states, values)`` batches.
    :param scheduler_type: ``'cos'``, ``'exp'``, or anything else for no scheduler.
    :param gamma: Discount factor.
    :param inner_steps: Cap on the loop; at most ``inner_steps + 1`` updates.
    :return: ``(td_loss, alpha, steps)``:
        * ``td_loss`` is the initial proximal loss divided by ``step_size**2``.
        * ``alpha`` is the last pre-update loss divided by the initial loss
          (``0.0`` if the initial loss is zero).
        * ``steps`` is the number of optimizer updates performed.
    :raises ValueError: On an unsupported ``optimizer_type`` or a zero ``step_size``.
    """
    optimizer_classes = {'AdamW': AdamW, 'SGD': SGD}
    if optimizer_type not in optimizer_classes:
        raise ValueError(
            f"Unsupported optimizer_type {optimizer_type!r}; expected 'AdamW' or 'SGD'"
        )
    if step_size == 0:
        # The returned loss is normalised by step_size**2.
        raise ValueError("step_size must be non-zero")

    value_network.train()
    optimizer = optimizer_classes[optimizer_type](value_network.parameters(), lr=learning_rate)
    scheduler = None
    if scheduler_type == 'cos':
        # NOTE(review): eta_min equals the base learning rate, so this cosine
        # schedule leaves the learning rate unchanged — confirm intent.
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=inner_steps, eta_min=learning_rate
        )
    elif scheduler_type == 'exp':
        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)

    if loader is not None:
        states, _, rewards, next_states, _ = next(loader)
    else:
        # Collect trajectories and stack the per-sample tensors into batches.
        expert_trajectories = get_full_mujoco_trajectories(
            env, expert, num_traj, num_sa_pairs, gamma, device
        )
        states, _, rewards, next_states, _ = zip(*expert_trajectories)
        states = torch.stack(states)
        rewards = torch.stack(rewards)
        next_states = torch.stack(next_states)
    # Flatten so an (N, 1) reward tensor cannot broadcast against the (N,)
    # value predictions into an (N, N) target.
    rewards = rewards.reshape(-1)

    # Everything below is fixed during the inner loop, so no gradients are needed.
    with torch.no_grad():
        td_target = rewards + gamma * value_network(next_states).reshape(-1)
        v_t = value_network(states).reshape(-1)
        F_t = v_t - td_target
        prox_target = v_t - step_size * F_t
        init_loss = F.mse_loss(value_network(states).reshape(-1), prox_target).item()

    alpha = 1
    step = 0
    while alpha > 0.5 and step <= inner_steps:
        loss = F.mse_loss(value_network(states).reshape(-1), prox_target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if scheduler is not None:
            scheduler.step()
        # Ratio of the loss measured before this update to the initial loss.
        alpha = loss.item() / init_loss if init_loss > 0 else 0.0
        step += 1

    return init_loss / (step_size**2), alpha, step


def batch_td0_update_with_buffer(
    env: gym.Env,
    expert: torch.nn.Module,
    value_network: torch.nn.Module,
    optimizer_type: str,
    learning_rate: float,
    buffer: deque,
    num_sa_pairs: int,
    step_size: float,
    horizon: int,
    num_traj: int,
    device: torch.device,
    scheduler_type: str = None,
    loader = None,
    gamma: float = 0.99,
    inner_steps: int = 10,
):
    """
    Proximal TD(0) update regularised on states stored in ``buffer``.

    Runs ``inner_steps`` optimizer updates on the surrogate

        step_size * <F(v_t), v - v_t> / N + 0.5 * ||v_buf - v_t_buf||^2 / M

    Here ``F(v_t) = v_t - (r + gamma * V(s'))`` is the TD(0) residual at the
    parameters held on entry. The second term keeps predictions on buffered
    states close to the values they had on entry. Buffer entries are cycled
    round-robin.

    After the updates, the new batch is appended to ``buffer`` as a list of
    per-sample 5-tuples. The oldest entry is dropped when the buffer then holds
    more than ``inner_steps`` entries.

    :param env: Environment used to collect trajectories when ``loader`` is None.
    :param expert: Policy used to collect trajectories when ``loader`` is None.
    :param value_network: Value network being trained (updated in place).
    :param optimizer_type: ``'AdamW'`` or ``'SGD'``; a new optimizer is built per call.
    :param learning_rate: Learning rate for that optimizer.
    :param buffer: Deque of trajectories, each a sequence of
        ``(state, action, reward, next_state, value)`` samples. Mutated in place.
    :param num_sa_pairs: Forwarded to ``get_full_mujoco_trajectories``.
    :param step_size: Weight of the TD term in the surrogate.
    :param horizon: Unused here; kept for a uniform signature.
    :param num_traj: Forwarded to ``get_full_mujoco_trajectories``.
    :param device: Forwarded to ``get_full_mujoco_trajectories``.
    :param scheduler_type: ``'cos'``, ``'exp'``, or anything else for no scheduler.
    :param loader: Optional iterator yielding
        ``(states, actions, rewards, next_states, values)`` batches.
    :param gamma: Discount factor.
    :param inner_steps: Number of surrogate updates (also the buffer length cap).
    :return: ``(init_loss, alpha)``:
        * ``init_loss`` is the TD(0) MSE on entry.
        * ``alpha`` is the TD(0) MSE against the same targets, measured just
          before the last update, divided by ``init_loss``. It is ``1.0`` when
          no update runs and ``0.0`` when ``init_loss`` is zero.
    :raises ValueError: On an unsupported ``optimizer_type``, or an empty
        ``buffer`` with ``inner_steps > 0``.
    """
    optimizer_classes = {'AdamW': AdamW, 'SGD': SGD}
    if optimizer_type not in optimizer_classes:
        raise ValueError(
            f"Unsupported optimizer_type {optimizer_type!r}; expected 'AdamW' or 'SGD'"
        )
    if inner_steps > 0 and len(buffer) == 0:
        raise ValueError("buffer must contain at least one trajectory when inner_steps > 0")

    value_network.train()
    optimizer = optimizer_classes[optimizer_type](value_network.parameters(), lr=learning_rate)
    scheduler = None
    if scheduler_type == 'cos':
        # NOTE(review): eta_min equals the base learning rate, so this cosine
        # schedule leaves the learning rate unchanged — confirm intent.
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=inner_steps, eta_min=learning_rate
        )
    elif scheduler_type == 'exp':
        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)

    if loader is not None:
        states, actions, rewards, next_states, values = next(loader)
    else:
        # Collect trajectories and stack the per-sample tensors into batches.
        expert_trajectories = get_full_mujoco_trajectories(
            env, expert, num_traj, num_sa_pairs, gamma, device
        )
        states, actions, rewards, next_states, values = zip(*expert_trajectories)
        states = torch.stack(states)
        actions = torch.stack(actions)
        rewards = torch.stack(rewards)
        next_states = torch.stack(next_states)
        values = torch.stack(values)

    anchors = []
    with torch.no_grad():
        v_t = value_network(states).reshape(-1)
        v_t_next = value_network(next_states).reshape(-1)
        # Flatten rewards only for the target, so an (N, 1) tensor cannot
        # broadcast into an (N, N) target; the buffer keeps the original tensor.
        td_target = rewards.reshape(-1) + gamma * v_t_next
        F_v_t = v_t - td_target
        # Stack each buffered trajectory's states once and record the values
        # predicted on them before any update in this call.
        for trajectory in buffer:
            traj_states, _, _, _, _ = zip(*trajectory)
            traj_states = torch.stack(traj_states)
            anchors.append((traj_states, value_network(traj_states).reshape(-1)))

    init_loss = F.mse_loss(v_t, td_target).item()

    alpha = 1.0
    num_anchors = len(anchors)
    for i in range(inner_steps):
        traj_states, v_t_buffer = anchors[i % num_anchors]

        v = value_network(states).reshape(-1)
        v_buffer = value_network(traj_states).reshape(-1)

        td_term = step_size * torch.dot(F_v_t, v - v_t) / v.shape[0]
        reg_term = 0.5 * torch.norm(v_buffer - v_t_buffer, p=2) ** 2 / v_buffer.shape[0]
        surrogate_loss = td_term + reg_term

        optimizer.zero_grad()
        surrogate_loss.backward()
        optimizer.step()
        if scheduler is not None:
            scheduler.step()

        # Logging only: TD(0) MSE of the pre-update predictions.
        loss = F.mse_loss(v.detach(), td_target).item()
        alpha = loss / init_loss if init_loss > 0 else 0.0

    # Append the new batch into the buffer and remove the oldest if necessary.
    buffer.append(list(zip(states, actions, rewards, next_states, values)))
    if len(buffer) > inner_steps:
        buffer.popleft()

    return init_loss, alpha


def batch_td0_update_with_buffer_loader(
    env: gym.Env,
    expert: torch.nn.Module,
    value_network: torch.nn.Module,
    optimizer_type: str,
    learning_rate: float,
    buffer: iter,  # Use the buffer DataLoader iterator
    num_sa_pairs: int,
    step_size: float,
    horizon: int,
    num_traj: int,
    device: torch.device,
    loader = None,
    scheduler_type: str = None,
    gamma: float = 0.99,
    inner_steps: int = 10,
):
    """
    Proximal TD(0) update regularised on batches drawn from a buffer iterator.

    ``inner_steps`` batches are drawn from ``buffer`` up front, one per inner
    update. Each update minimises

        step_size * <F(v_t), v - v_t> / N + 0.5 * ||v_buf - v_t_buf||^2 / M

    Here ``F(v_t) = v_t - (r + gamma * V(s'))`` is the TD(0) residual at the
    parameters held on entry. ``v_t_buf`` are the values predicted on the drawn
    buffer batch before any update.

    :param env: Environment used to collect trajectories when ``loader`` is None.
    :param expert: Policy used to collect trajectories when ``loader`` is None.
    :param value_network: Value network being trained (updated in place).
    :param optimizer_type: ``'AdamW'`` or ``'SGD'``; a new optimizer is built per call.
    :param learning_rate: Learning rate for that optimizer.
    :param buffer: Iterator yielding
        ``(states, actions, rewards, next_states, values)`` batches. Exactly
        ``inner_steps`` batches are consumed.
    :param num_sa_pairs: Forwarded to ``get_full_mujoco_trajectories``.
    :param step_size: Weight of the TD term in the surrogate.
    :param horizon: Unused here; kept for a uniform signature.
    :param num_traj: Forwarded to ``get_full_mujoco_trajectories``.
    :param device: Forwarded to ``get_full_mujoco_trajectories``.
    :param loader: Optional iterator yielding
        ``(states, actions, rewards, next_states, values)`` batches.
    :param scheduler_type: ``'cos'``, ``'exp'``, or anything else for no scheduler.
    :param gamma: Discount factor.
    :param inner_steps: Number of surrogate updates.
    :return: ``(init_loss, alpha)``:
        * ``init_loss`` is the TD(0) MSE on entry.
        * ``alpha`` is the TD(0) MSE against the same targets, measured just
          before the last update, divided by ``init_loss``. It is ``1.0`` when
          no update runs and ``0.0`` when ``init_loss`` is zero.
    :raises ValueError: On an unsupported ``optimizer_type``.
    """
    optimizer_classes = {'AdamW': AdamW, 'SGD': SGD}
    if optimizer_type not in optimizer_classes:
        raise ValueError(
            f"Unsupported optimizer_type {optimizer_type!r}; expected 'AdamW' or 'SGD'"
        )

    value_network.train()
    optimizer = optimizer_classes[optimizer_type](value_network.parameters(), lr=learning_rate)
    scheduler = None
    if scheduler_type == 'cos':
        # NOTE(review): eta_min equals the base learning rate, so this cosine
        # schedule leaves the learning rate unchanged — confirm intent.
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=inner_steps, eta_min=learning_rate
        )
    elif scheduler_type == 'exp':
        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)

    if loader is not None:
        states, _, rewards, next_states, _ = next(loader)
    else:
        # Collect trajectories and stack the per-sample tensors into batches.
        expert_trajectories = get_full_mujoco_trajectories(
            env, expert, num_traj, num_sa_pairs, gamma, device
        )
        states, _, rewards, next_states, _ = zip(*expert_trajectories)
        states = torch.stack(states)
        rewards = torch.stack(rewards)
        next_states = torch.stack(next_states)

    anchors = []
    with torch.no_grad():
        v_t = value_network(states).reshape(-1)
        v_t_next = value_network(next_states).reshape(-1)
        # Flatten so an (N, 1) reward tensor cannot broadcast into an (N, N) target.
        td_target = rewards.reshape(-1) + gamma * v_t_next
        F_v_t = v_t - td_target
        # Draw one buffer batch per inner step and record its pre-update values.
        for _ in range(inner_steps):
            buffer_states, _, _, _, _ = next(buffer)
            anchors.append((buffer_states, value_network(buffer_states).reshape(-1)))

    init_loss = F.mse_loss(v_t, td_target).item()

    alpha = 1.0
    for buffer_states, v_t_buffer in anchors:
        v = value_network(states).reshape(-1)
        v_buffer = value_network(buffer_states).reshape(-1)

        td_term = step_size * torch.dot(F_v_t, v - v_t) / v.shape[0]
        reg_term = 0.5 * torch.norm(v_buffer - v_t_buffer, p=2) ** 2 / v_buffer.shape[0]
        surrogate_loss = td_term + reg_term

        optimizer.zero_grad()
        surrogate_loss.backward()
        optimizer.step()
        if scheduler is not None:
            scheduler.step()

        # Logging only: TD(0) MSE of the pre-update predictions.
        loss = F.mse_loss(v.detach(), td_target).item()
        alpha = loss / init_loss if init_loss > 0 else 0.0

    return init_loss, alpha

if __name__ == "__main__":
    # The @hydra.main decorator composes `cfg` from the config files and CLI overrides.
    main()
