"""
This file implements the expert trajectory sampling algorithm for GAIL.
Specifically, we take the code from 
https://github.com/hcnoh/gail-pytorch/blob/main/models/gail.py,
updating it for better performance.

TODO: investigate vectorised environments.
TODO: the expert may end up sampling all `num_sa_pairs` from a single
    trajectory, which is probably not intended. It keeps the expert dataset
    quite small and might make the GAIL optimisation more difficult than it
    needs to be; we might just need more data.
"""

import gymnasium as gym
import numpy as np
import os
import torch

import matplotlib.pyplot as plt
import matplotlib.animation as animation

from PPO import utils
# from minatar import Environment
# from tianshou.utils.net.common import Net, ActorCritic
from torch.distributions import Distribution, Independent, Normal
# from src.agent_networks.a2c_acktr_ppo_gail import map_action


import random
# Fix the RNG seeds at import time so that sampling done through Python's
# `random` module and through PyTorch is repeatable across runs.
# NOTE(review): NumPy's global RNG is not seeded here.
seed = 6666
random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    # `torch.cuda.manual_seed` seeds the current CUDA device only.
    # NOTE(review): for multi-GPU runs `torch.cuda.manual_seed_all` is the
    # call that covers every device - confirm which one is intended.
    torch.cuda.manual_seed(seed)
from torch.utils.data import Dataset, DataLoader


# Custom Dataset class to handle trajectories
class MujocoTrajectoryDataset(Dataset):
    """
    Map-style dataset that serves pre-sampled trajectory items by index.

    The container handed to the constructor is stored by reference (no copy),
    so it must support `len()` and integer indexing.
    """

    def __init__(self, trajectories):
        # Keep the caller's container as-is; items are looked up on demand.
        self.trajectories = trajectories

    def __getitem__(self, idx):
        """Return the item stored at position `idx`."""
        samples = self.trajectories
        return samples[idx]

    def __len__(self):
        """Return the number of items in the wrapped container."""
        samples = self.trajectories
        return len(samples)


# Create DataLoader from pre-sampled trajectories
def create_dataloader_from_trajectories(trajectories, batch_size=64, shuffle=True):
    """
    Wrap pre-sampled trajectories in a `DataLoader`.

    Parameters:
        trajectories: Indexable container of samples; it is wrapped in a
            `MujocoTrajectoryDataset`.
        batch_size (`int`): Number of samples per batch.
        shuffle (`bool`): Whether the loader reshuffles the samples.

    Returns:
        `DataLoader`: A loader that iterates over the wrapped samples.
    """
    return DataLoader(
        MujocoTrajectoryDataset(trajectories),
        batch_size=batch_size,
        shuffle=shuffle,
    )


def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
    """
    Build a diagonal Gaussian from a `(loc, scale)` pair.

    The last batch dimension of the underlying `Normal` is reinterpreted as
    an event dimension, so `log_prob` sums over that dimension.
    """
    mean, std = loc_scale
    base = Normal(loc=mean, scale=std)
    return Independent(base, reinterpreted_batch_ndims=1)


@torch.no_grad()  # disable gradient tracking in this function.
def get_expert_trajectories(
    env: "gym.Env",
    expert: torch.nn.Module,
    num_sa_pairs: int,
    horizon: int,
    device: torch.device,
    render_gif: bool = False,
    gif_path: str = None,
) -> tuple[float, torch.Tensor, torch.Tensor]:
    """
    This function queries the expert model to generate trajectories of
    state-action pairs.

    Episodes are run one after another until `num_sa_pairs` pairs have been
    collected. An episode ends when the environment reports termination or
    truncation, when `horizon` steps have been taken, or when the pair budget
    is exhausted. At least one step is taken in every episode.

    Parameters:
        env (`gym.Env`): The environment we're using to sample trajectories.
        expert (`torch.nn.Module`): The expert model; must expose
            `act(obs)`, whose result is passed to `env.step` unchanged.
        num_sa_pairs (`int`): The number of state-action pairs to sample.
        horizon (`int`): The max number of steps per trajectory, or `None`
            for no limit.
        device (`torch.device`): The device to put all data on.
        render_gif (`bool`): Whether to output a gif of the expert actions.
        gif_path (`str`): Directory in which the gifs are saved; required
            when `render_gif` is true.

    Returns:
        `tuple[float, torch.Tensor, torch.Tensor]`: the mean undiscounted
            episode reward (the last episode may have been cut short by the
            pair budget), and the stacked state and action tensors.

    Raises:
        `ValueError`: If `num_sa_pairs` is not positive, or if `render_gif`
            is set without a `gif_path`.
    """
    if num_sa_pairs <= 0:
        raise ValueError("num_sa_pairs must be a positive integer.")
    if render_gif and gif_path is None:
        raise ValueError("gif_path must be provided when render_gif is True.")

    expert.eval()

    # Variables to track trajectory generation.
    num_gen_traj = 0
    expert_obs = []
    expert_actions = []
    expert_reward = []
    while len(expert_obs) < num_sa_pairs:
        # Run trajectories until we have a sufficient amount of data.
        obs, _ = env.reset()
        gif_frames = [env.render()] if render_gif else []
        episode_done = False
        episode_reward = 0
        num_steps = 0
        while not episode_done:
            # Run an episode to generate a trajectory.
            obs = torch.tensor(obs, device=device).float()
            expert_action = expert.act(obs)
            expert_obs.append(obs)
            expert_actions.append(
                torch.tensor(expert_action, device=device),
            )

            obs, reward, terminated, truncated, _ = env.step(expert_action)
            episode_reward += reward
            # Count the step before testing the horizon so that an episode
            # never exceeds `horizon` steps.
            num_steps += 1

            if render_gif:
                gif_frames.append(env.render())

            # Stop on either end-of-episode flag (the environment must be
            # reset after truncation as well as termination), on the horizon,
            # or once all the requested data has been generated.
            episode_done = bool(
                terminated
                or truncated
                or (horizon is not None and num_steps >= horizon)
                or len(expert_obs) >= num_sa_pairs
            )
        expert_reward.append(episode_reward)

        if render_gif:
            # Create the figure and axes objects
            fig, ax = plt.subplots()
            # Title the axes before saving so it appears in the gif.
            ax.set_title(f"Expert Trajectory #{num_gen_traj+1}")

            # Set the initial image
            im = ax.imshow(gif_frames[0], animated=True)

            def update(i):
                im.set_array(gif_frames[i])
                return im,

            # Create the animation object
            animation_fig = animation.FuncAnimation(
                fig,
                update,
                frames=len(gif_frames),
                interval=40,
                blit=True,
                repeat_delay=10,
            )

            # Save before showing: a blocking `show` may tear the figure down.
            animation_fig.save(os.path.join(gif_path, f"trajectory_{num_gen_traj+1}.gif"))

            # Show the animation, then release the figure.
            plt.show()
            plt.close(fig)
        num_gen_traj += 1

    # Get relevant logging information
    exp_rwd_mean = np.mean(expert_reward).item()

    # Convert data to tensors.
    t_obs = torch.stack(expert_obs, dim=0)
    t_acts = torch.stack(expert_actions).float()

    return exp_rwd_mean, t_obs, t_acts

@torch.no_grad()  # disable gradient tracking in this function.
def get_full_expert_trajectories(
    env: "gym.Env",
    expert: torch.nn.Module,
    num_trajectories: int,
    max_len: int,
    gamma: float,
    device: torch.device,
    render_gif: bool = False,
    gif_path: str = None,
) -> list[tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
    """
    This function queries the expert model to generate full trajectories of
    state-action pairs for facilitating TD learning, and collects reward,
    state, action, next state, and cumulative value.

    Each trajectory runs until the environment reports termination or
    truncation, or until `max_len` steps have been taken.

    Parameters:
        env (`gym.Env`): The environment we're using to sample trajectories.
        expert (`torch.nn.Module`): The expert model; must expose
            `act(obs)`, whose result is passed to `env.step` unchanged.
        num_trajectories (`int`): The number of trajectories to sample.
        max_len (`int`): The max number of steps per trajectory.
        gamma (`float`): Discount factor of the Monte Carlo return.
        device (`torch.device`): The device to put all data on.
        render_gif (`bool`): Whether to output a gif of the expert actions.
        gif_path (`str`): Directory in which the gifs are saved; required
            when `render_gif` is true.

    Returns:
        `list[tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]`:
            A flat list with one
            (state, action, reward, next state, cumulative value) tuple per
            step, concatenated over all trajectories in sampling order. The
            cumulative value is the discounted sum of the rewards observed
            from that step to the end of the sampled trajectory.

    Raises:
        `ValueError`: If `render_gif` is set without a `gif_path`.
    """
    if render_gif and gif_path is None:
        raise ValueError("gif_path must be provided when render_gif is True.")

    expert.eval()

    transitions = []

    for traj_idx in range(num_trajectories):
        obs, _ = env.reset()
        obs = torch.tensor(obs, device=device).float()
        gif_frames = [env.render()] if render_gif else []

        # states[i] is the observation before step i; states[i + 1] after it.
        states = [obs]
        actions = []
        rewards = []
        episode_done = False

        while not episode_done and len(rewards) < max_len:
            expert_action = expert.act(obs)
            actions.append(torch.tensor(expert_action, device=device).float())

            next_ob, reward, terminated, truncated, _ = env.step(expert_action)
            # The environment must be reset after truncation as well as
            # termination, so either flag ends the trajectory.
            episode_done = bool(terminated or truncated)
            rewards.append(torch.tensor(reward, device=device).float())

            obs = torch.tensor(next_ob, device=device).float()
            states.append(obs)

            if render_gif:
                gif_frames.append(env.render())

        # Discounted Monte Carlo return per step, accumulated backwards and
        # then flipped (appending avoids the quadratic cost of insert(0, ...)).
        cumulative_values = []
        g = 0.0
        for reward in reversed(rewards):
            g = reward + gamma * g
            cumulative_values.append(g)
        cumulative_values.reverse()

        # Store the trajectory data
        transitions.extend(
            zip(states[:-1], actions, rewards, states[1:], cumulative_values)
        )

        if render_gif:
            # Create the figure and axes objects
            fig, ax = plt.subplots()
            # Title the axes before saving so it appears in the gif.
            ax.set_title(f"Expert Trajectory #{traj_idx + 1}")

            # Set the initial image
            im = ax.imshow(gif_frames[0], animated=True)

            def update(i):
                im.set_array(gif_frames[i])
                return im,

            # Create the animation object
            animation_fig = animation.FuncAnimation(
                fig,
                update,
                frames=len(gif_frames),
                interval=40,
                blit=True,
                repeat_delay=10,
            )

            # Save before showing: a blocking `show` may tear the figure down.
            animation_fig.save(os.path.join(gif_path, f"trajectory_{traj_idx + 1}.gif"))

            # Show the animation, then release the figure.
            plt.show()
            plt.close(fig)

    return transitions



@torch.no_grad()  # disable gradient tracking in this function.
def get_full_minatar_trajectories(
    env: "gym.Env",
    expert: torch.nn.Module,
    num_trajectories: int,
    max_len: int,
    gamma: float,
    device: torch.device,
    render_gif: bool = False,
    gif_path: str = None,
) -> list[tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
    """
    This function queries the expert model to generate full trajectories of
    state-action pairs for facilitating TD learning, and collects reward,
    state, action, next state, and cumulative value.

    Parameters:
        env (`Environment`): The MinAtar environment we're using to sample
            trajectories. It is driven through `reset()`, `state()` and
            `act(action)` (which returns `(reward, done)`), not through the
            Gymnasium `step` API.
        expert (`torch.nn.Module`): The expert model. Calling it on an
            observation must return a batch of non-negative action weights;
            the first row is sampled from with `torch.multinomial`.
        num_trajectories (`int`): The number of trajectories to sample.
        max_len (`int`): The max number of steps per trajectory.
        gamma (`float`): Discount factor of the Monte Carlo return.
        device (`torch.device`): The device to put all data on.
        render_gif (`bool`): Whether to output a gif of the expert actions.
            NOTE(review): this relies on `env.render()`; confirm the
            environment passed in actually provides it.
        gif_path (`str`): Directory in which the gifs are saved; required
            when `render_gif` is true.

    Returns:
        `list[tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]`:
            A flat list with one
            (state, action, reward, next state, cumulative value) tuple per
            step, concatenated over all trajectories in sampling order.

    Raises:
        `ValueError`: If `render_gif` is set without a `gif_path`.
    """
    if render_gif and gif_path is None:
        raise ValueError("gif_path must be provided when render_gif is True.")

    def observe() -> torch.Tensor:
        # Move the last axis of the 3-D state first and add a leading batch
        # axis (presumably HWC -> 1CHW for a convolutional expert).
        raw = torch.tensor(env.state(), device=device)
        return raw.permute(2, 0, 1).unsqueeze(0).float()

    expert.eval()

    transitions = []

    for traj_idx in range(num_trajectories):
        env.reset()
        obs = observe()
        gif_frames = [env.render()] if render_gif else []

        # states[i] is the observation before step i; states[i + 1] after it.
        states = [obs]
        actions = []
        rewards = []
        episode_done = False

        while not episode_done and len(rewards) < max_len:
            # Sample a discrete action index from the expert's output.
            action = torch.multinomial(expert(obs)[0], 1)[0].item()
            actions.append(torch.tensor(action, device=device).float())

            reward, episode_done = env.act(action)
            rewards.append(torch.tensor(reward, device=device).float())

            obs = observe()
            states.append(obs)

            if render_gif:
                gif_frames.append(env.render())

        # Discounted Monte Carlo return per step, accumulated backwards and
        # then flipped (appending avoids the quadratic cost of insert(0, ...)).
        cumulative_values = []
        g = 0.0
        for reward in reversed(rewards):
            g = reward + gamma * g
            cumulative_values.append(g)
        cumulative_values.reverse()

        # Store the trajectory data
        transitions.extend(
            zip(states[:-1], actions, rewards, states[1:], cumulative_values)
        )

        if render_gif:
            # Create the figure and axes objects
            fig, ax = plt.subplots()
            # Title the axes before saving so it appears in the gif.
            ax.set_title(f"Expert Trajectory #{traj_idx + 1}")

            # Set the initial image
            im = ax.imshow(gif_frames[0], animated=True)

            def update(i):
                im.set_array(gif_frames[i])
                return im,

            # Create the animation object
            animation_fig = animation.FuncAnimation(
                fig,
                update,
                frames=len(gif_frames),
                interval=40,
                blit=True,
                repeat_delay=10,
            )

            # Save before showing: a blocking `show` may tear the figure down.
            animation_fig.save(os.path.join(gif_path, f"trajectory_{traj_idx + 1}.gif"))

            # Show the animation, then release the figure.
            plt.show()
            plt.close(fig)

    return transitions


@torch.no_grad()  # disable gradient tracking in this function.
def get_full_mujoco_trajectories(
    env: gym.Env,
    expert: torch.nn.Module,
    num_trajectories: int,
    max_len: int,
    gamma: float,
    device: torch.device,
    render_gif: bool = False,
    gif_path: str = None,
) -> list[tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
    """
    This function queries the expert model to generate full trajectories of
    state-action pairs for facilitating TD learning, and collects reward,
    state, action, next state, and cumulative value.

    Each trajectory runs until the environment reports termination or
    truncation, or until `max_len` steps have been taken.

    Parameters:
        env (`gym.Env`): The environment we're using to sample trajectories.
            Its `action_space.high[0]` is used as the action scale.
        expert (`torch.nn.Module`): The expert model; must expose an `actor`
            sub-module and `choose_action(obs)` returning a pair whose first
            item is the action.
        num_trajectories (`int`): The number of trajectories to sample.
        max_len (`int`): The max number of steps per trajectory.
        gamma (`float`): Discount factor of the Monte Carlo return.
        device (`torch.device`): The device to put all data on.
        render_gif (`bool`): Whether to output a gif of the expert actions.
        gif_path (`str`): Directory in which the gifs are saved; required
            when `render_gif` is true.

    Returns:
        `list[tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]`:
            A flat list with one
            (state, action, reward, next state, cumulative value) tuple per
            step, concatenated over all trajectories in sampling order. The
            stored action is the adapted action that was sent to the
            environment; states carry a leading batch axis of size 1.

    Raises:
        `ValueError`: If `render_gif` is set without a `gif_path`.
    """
    if render_gif and gif_path is None:
        raise ValueError("gif_path must be provided when render_gif is True.")

    expert.actor.eval()

    # The action bound does not change between episodes, so read it once.
    max_action = env.action_space.high[0]

    transitions = []

    for traj_idx in range(num_trajectories):
        obs, _ = env.reset()
        obs = torch.tensor(obs, device=device).float().unsqueeze(0)
        gif_frames = [env.render()] if render_gif else []

        # states[i] is the observation before step i; states[i + 1] after it.
        states = [obs]
        actions = []
        rewards = []
        episode_done = False

        while not episode_done and len(rewards) < max_len:
            action, _ = expert.choose_action(obs)
            # Presumably rescales the policy output to the environment's
            # action range - see `utils.action_adapter`.
            act = utils.action_adapter(action, max_action)
            next_obs, reward, terminated, truncated, _ = env.step(act)
            # The environment must be reset after truncation as well as
            # termination, so either flag ends the trajectory.
            episode_done = bool(terminated or truncated)

            actions.append(torch.tensor(act, device=device).float())
            rewards.append(torch.tensor(reward, device=device).float())

            obs = torch.tensor(next_obs, device=device).float().unsqueeze(0)
            states.append(obs)

            if render_gif:
                gif_frames.append(env.render())

        # Discounted Monte Carlo return per step, accumulated backwards and
        # then flipped (appending avoids the quadratic cost of insert(0, ...)).
        cumulative_values = []
        g = 0.0
        for reward in reversed(rewards):
            g = reward + gamma * g
            cumulative_values.append(g)
        cumulative_values.reverse()

        # Store the trajectory data
        transitions.extend(
            zip(states[:-1], actions, rewards, states[1:], cumulative_values)
        )

        if render_gif:
            # Create the figure and axes objects
            fig, ax = plt.subplots()
            # Title the axes before saving so it appears in the gif.
            ax.set_title(f"Expert Trajectory #{traj_idx + 1}")

            # Set the initial image
            im = ax.imshow(gif_frames[0], animated=True)

            def update(i):
                im.set_array(gif_frames[i])
                return im,

            # Create the animation object
            animation_fig = animation.FuncAnimation(
                fig,
                update,
                frames=len(gif_frames),
                interval=40,
                blit=True,
                repeat_delay=10,
            )

            # Save before showing: a blocking `show` may tear the figure down.
            animation_fig.save(os.path.join(gif_path, f"trajectory_{traj_idx + 1}.gif"))

            # Show the animation, then release the figure.
            plt.show()
            plt.close(fig)

    return transitions


@torch.no_grad()
def get_sampled_state_values(
    env: gym.Env,
    expert: torch.nn.Module,
    num_trajectories: int,
    max_len: int,
    gamma: float,
    device: torch.device,
    num_samples: int,
    num_trajs_per_sample: int,
) -> list[tuple[torch.Tensor, float]]:
    """
    This function generates full trajectories, samples random states, and then estimates the value of those states
    by starting multiple trajectories from each sampled state.

    Parameters:
        env (`gym.Env`): The environment used to sample trajectories. Its
            unwrapped form must provide `state_vector()`, `set_state(qpos, qvel)`
            and `model.nq` (the MuJoCo environment interface).
        expert (`torch.nn.Module`): The expert model; must expose an `actor`
            sub-module and `choose_action(obs)` returning a pair whose first
            item is the action.
        num_trajectories (`int`): The number of full trajectories to generate.
        max_len (`int`): The maximum number of steps per trajectory and per
            value-estimation rollout.
        gamma (`float`): Discount factor of the Monte Carlo return.
        device (`torch.device`): The device for computation.
        num_samples (`int`): Number of states to sample from the full trajectory.
        num_trajs_per_sample (`int`): Number of trajectories to generate from each sampled state.

    Returns:
        `list[tuple[torch.Tensor, float]]`: A list of tuples containing the sampled state and its Monte Carlo value estimate.

    Raises:
        `ValueError`: If `num_trajs_per_sample` is smaller than 1.
    """
    if num_trajs_per_sample < 1:
        raise ValueError("num_trajs_per_sample must be at least 1.")

    expert.actor.eval()
    sampled_state_values = []

    sim = env.unwrapped
    # `state_vector()` is qpos followed by qvel; `model.nq` is the qpos length
    # that `set_state` expects, so the split works for any MuJoCo environment
    # rather than a hard-coded one.
    nq = sim.model.nq
    max_action = env.action_space.high[0]

    for traj_idx in range(num_trajectories):
        obs, _ = env.reset()
        obs = torch.tensor(obs, device=device).float().unsqueeze(0)
        episode_done = False
        num_steps = 0
        # states[i] and states_info[i] describe the same time step: the
        # observation and the simulator (qpos, qvel) *before* step i.
        states = []
        states_info = []
        while not episode_done and num_steps < max_len:
            state_vec = sim.state_vector()
            states.append(obs)
            states_info.append((state_vec[:nq].copy(), state_vec[nq:].copy()))

            action, _ = expert.choose_action(obs)
            act = utils.action_adapter(action, max_action)
            next_obs, _, terminated, truncated, _ = env.step(act)
            # The environment must be reset after truncation as well as
            # termination, so either flag ends the trajectory.
            episode_done = bool(terminated or truncated)

            obs = torch.tensor(next_obs, device=device).float().unsqueeze(0)
            num_steps += 1

        # Sample random states from the full trajectory
        sampled_indices = random.sample(range(num_steps), min(num_samples, num_steps))

        for idx in sampled_indices:
            sampled_state = states[idx]
            qpos, qvel = states_info[idx]
            monte_carlo_estimates = []

            # Generate multiple trajectories from the sampled state
            for _ in range(num_trajs_per_sample):
                # Reset the environment to the sampled state
                sim.set_state(qpos, qvel)
                rollout_obs = sampled_state
                rollout_done = False
                rollout_rewards = []
                while not rollout_done and len(rollout_rewards) < max_len:
                    action, _ = expert.choose_action(rollout_obs)
                    act = utils.action_adapter(action, max_action)
                    # NOTE(review): `set_state` acts on the unwrapped
                    # environment, so any wrapper-level step counter (e.g. a
                    # time limit) is presumably not rewound. The truncation
                    # flag is therefore ignored here; `max_len` bounds the
                    # rollout instead.
                    next_obs, reward, rollout_done, _, _ = env.step(act)
                    rollout_rewards.append(float(reward))

                    rollout_obs = torch.tensor(next_obs, device=device).float().unsqueeze(0)

                # Compute the cumulative value using Monte Carlo estimation
                g = 0.0
                for reward in reversed(rollout_rewards):
                    g = reward + gamma * g

                monte_carlo_estimates.append(g)

            # Store the sampled state and the average Monte Carlo estimate
            avg_value_estimate = sum(monte_carlo_estimates) / len(monte_carlo_estimates)
            sampled_state_values.append((sampled_state, avg_value_estimate))

    return sampled_state_values