# Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations

import os
import statistics
import time
import torch
from collections import deque

from typing import Union
import math

import rsl_rl
from rsl_rl.algorithms import PPO
from rsl_rl.env import VecEnv
from rsl_rl.modules import ActorCriticBase, ActorCritic, ActorCriticRecurrent, EmpiricalNormalization, ExtendableActorCritic, HierarchicalActorCritic, ActorCriticForAnalysis
# from rsl_rl.modules import *
from rsl_rl.utils import store_code_state
import gymnasium as gym
from isaaclab.envs import  ManagerBasedRLEnvCfg
from isaaclab_rl.rsl_rl import RslRlOnPolicyRunnerCfg, RslRlVecEnvWrapper
from tqdm import tqdm
from einops import rearrange, repeat
import h5py


def index_sim_state(
    sim_state: dict[str, dict[str, dict[str, torch.Tensor]]], env_idx: int
) -> dict[str, dict[str, dict[str, torch.Tensor]]]:
    """Select the sim state of a single environment from a batched sim state.

    The input is a three-level nested dict whose leaves are tensors indexed along their first dimension. The output
    has the same nesting and keys, with every leaf replaced by ``leaf[env_idx]`` (a view into the original tensor).

    Args:
        sim_state: The sim state of all environments (e.g. the output of ``scene.get_state``).
        env_idx: The index of the environment to select.

    Returns:
        The sim state of the ``env_idx``-th environment, with the same three-level structure as ``sim_state``.

    Raises:
        ValueError: If a leaf of ``sim_state`` is not a ``torch.Tensor``.
    """

    def _select(key: str, value: object) -> torch.Tensor:
        # only tensor leaves can be indexed per environment
        if not isinstance(value, torch.Tensor):
            raise ValueError(f"Unsupported type {type(value)} for key {key}.")
        return value[env_idx]

    return {
        group: {
            asset: {field: _select(field, leaf) for field, leaf in fields.items()}
            for asset, fields in assets.items()
        }
        for group, assets in sim_state.items()
    }



class DynamicsAnalysisRunner:
    """Collect (s_t, a_t, s_t+1) samples to analyze how much dynamics information a trained policy contains.

    The runner builds ``num_subsets * subset_size`` environments and treats every ``subset_size`` consecutive
    environments as one subset. For each collected subset, all environments of the subset are set to the sim state of
    the subset's first environment, a separate action is sampled from the policy for each of them, and the resulting
    next observations are appended to an HDF5 file. (The critic could be analyzed in the same way.)

    Args:
        task: Gym task id passed to ``gym.make``.
        env_cfg: Environment config; its ``scene.num_envs`` is overwritten with ``num_subsets * subset_size``.
        policy_cfg: Policy config. ``"class_name"`` names the actor-critic class; the remaining entries are passed to
            its constructor. The caller's dict is not modified.
        num_subsets: Number of subsets collected per data-collection loop (>= 1).
        subset_size: Number of environments, and therefore sampled actions, per subset (>= 1).
        data_dir: Directory for the HDF5 file; created if missing.
        h5_file_name: Name of the HDF5 file inside ``data_dir``.
        device: Device the actor-critic is moved to.
    """

    def __init__(self, task: str, env_cfg: ManagerBasedRLEnvCfg, policy_cfg: dict, num_subsets: int, subset_size: int,
                 data_dir: str, h5_file_name: str, device="cuda"):
        if num_subsets < 1 or subset_size < 1:
            raise ValueError(
                f"num_subsets and subset_size must be >= 1, got num_subsets={num_subsets}, subset_size={subset_size}."
            )
        # copy so that popping "class_name" below does not mutate the caller's dict
        self.policy_cfg = dict(policy_cfg)
        self.device = device
        self.num_subsets = num_subsets
        self.subset_size = subset_size

        os.makedirs(data_dir, exist_ok=True)
        self.file_name = os.path.join(data_dir, h5_file_name)
        # NOTE(review): save_data() appends groups named subset_0, subset_1, ... to this file; if the file already
        # holds groups with those names from an earlier run, h5py's create_group will presumably fail.

        env_cfg.scene.num_envs = num_subsets * subset_size
        self.env = RslRlVecEnvWrapper(gym.make(task, cfg=env_cfg, render_mode=None))

        # resolve dimensions of observations
        obs, extras = self.env.get_observations()
        num_obs = obs.shape[1]
        if "critic" in extras["observations"]:
            num_critic_obs = extras["observations"]["critic"].shape[1]
        else:
            num_critic_obs = num_obs
        # NOTE: eval() executes the configured class name as Python code; only use with trusted configs.
        actor_critic_class = eval(self.policy_cfg.pop("class_name"))  # e.g. ActorCriticForAnalysis
        self.actor_critic: ActorCriticForAnalysis = actor_critic_class(
            num_obs, num_critic_obs, self.env.num_actions, **self.policy_cfg
        ).to(self.device)

        # number of subset groups written to the HDF5 file so far (used to name the groups)
        self.sample_counter = 0

    def run(self, num_subsets_to_save: int, interval_preparatory_iterations: int = 100,
            init_preparatory_iterations: int | None = None):  # noqa: C901
        """Collect ``num_subsets_to_save`` subsets of (s_t, a_t, s_t+1) data and append them to the HDF5 file.

        First the environments are reset, their episode lengths are randomized and the policy is rolled out for
        ``init_preparatory_iterations`` steps so that they are in diverse states. Each data-collection loop then
        rolls the policy out for ``interval_preparatory_iterations`` more steps, copies the sim state of every subset's
        first environment to the rest of the subset, samples ``subset_size`` i.i.d. actions from the policy's action
        distribution at the first environment's observation, steps all environments once and saves the result.
        Loops repeat until ``num_subsets_to_save`` subsets have been written; the last loop saves only the remainder.
        The environments are closed when this method returns or raises.

        Args:
            num_subsets_to_save: Total number of subsets to write.
            interval_preparatory_iterations: Policy steps run before each data-collection loop.
            init_preparatory_iterations: Policy steps run once before the first loop. Defaults to the environment's
                maximum episode length.
        """
        tot_loops = math.ceil(num_subsets_to_save / self.num_subsets)
        # inference only: disable train-mode behavior of the policy modules
        self.actor_critic.eval()

        try:
            with torch.inference_mode():
                # reset the environment and use the observations it returns (not ones read before the reset)
                obs_t, _ = self.env.reset()
                # randomize initial episode lengths, so we have diverse states earlier
                self.env.episode_length_buf = torch.randint_like(
                    self.env.episode_length_buf, high=int(self.env.max_episode_length)
                )
                if init_preparatory_iterations is None:
                    init_preparatory_iterations = self.env.max_episode_length
                # run initial preparatory iterations to bring the environments into diverse states
                obs_t = self._rollout(obs_t, int(init_preparatory_iterations))

                num_saved = 0
                for _ in tqdm(range(tot_loops), desc="Data collection loop", total=tot_loops):
                    # run some interval iterations to bring the environments into diverse states
                    obs_t = self._rollout(obs_t, interval_preparatory_iterations)

                    # duplicate environments: envs in the same subset now share the leader's sim state
                    self.duplicate_envs()

                    # Resetting the followers clears their action history, so observations read after duplication
                    # would differ from the leader's. Instead, evaluate the policy on the leader's pre-duplication
                    # observation and draw subset_size independent actions from that distribution.
                    leader_obs = obs_t[0::self.subset_size]
                    self.actor_critic.update_distribution(leader_obs.to(self.device))
                    act_t = rearrange(
                        torch.stack(
                            [self.actor_critic.distribution.sample() for _ in range(self.subset_size)], dim=1
                        ),
                        "b n c -> (b n) c",
                    )

                    # step duplicated environments
                    # NOTE(review): if an environment terminates or times out on this step, its s_t+1 is presumably
                    # the post-reset observation; the dones returned by step() are not stored.
                    obs_t_plus_1, _, _, _ = self.env.step(act_t.to(self.env.device))

                    # save the data, trimming the last loop so exactly num_subsets_to_save subsets are written
                    num_keep = min(self.num_subsets, num_subsets_to_save - num_saved)
                    num_rows = num_keep * self.subset_size
                    self.save_data(leader_obs[:num_keep], act_t[:num_rows], obs_t_plus_1[:num_rows])
                    num_saved += num_keep

                    obs_t = obs_t_plus_1
        finally:
            # close the environments
            self.env.close()

    def _rollout(self, obs: torch.Tensor, num_steps: int) -> torch.Tensor:
        """Step all environments ``num_steps`` times with actions from the policy; return the latest observations."""
        for _ in range(num_steps):
            actions = self.actor_critic.act(obs.to(self.device))
            obs, _, _, _ = self.env.step(actions.to(self.env.device))
        return obs

    def _subset_leaders(self) -> range:
        """Return the index of the first environment of every subset (subsets are consecutive runs of envs)."""
        return range(0, self.num_subsets * self.subset_size, self.subset_size)

    def duplicate_envs(self):
        """
        Duplicate the environments by writing the envs in the same subset with the sim state of the first environment
        in the subset. We treat every consecutive subset_size environments as a subset.
        """
        if self.subset_size <= 1:
            # each subset consists of its leader only; nothing to copy
            return
        env = self.env.unwrapped
        # Snapshot once: reset_to() below only writes follower envs, so the leaders' states stay valid.
        # NOTE(review): the state is read/written in world frame (is_relative=False), so followers are presumably
        # placed at the leader's world position; confirm cross-env collisions are filtered.
        sim_state = env.scene.get_state(is_relative=False)
        for leader in self._subset_leaders():
            followers = torch.arange(
                leader + 1, leader + self.subset_size, dtype=torch.long, device=self.env.device
            )
            env.reset_to(index_sim_state(sim_state, leader), env_ids=followers)

    def save_data(self, obs_t_per_subset, act_t, obs_t_plus_1):
        """
        Append the data to the h5 file on the disk, one group ``subset_<k>`` per subset.

        The number of subsets ``b`` is taken from ``obs_t_per_subset``; the rows of ``act_t`` and ``obs_t_plus_1``
        are grouped into ``b`` consecutive blocks of ``subset_size`` rows.

        Args:
            obs_t_per_subset (torch.Tensor) [b, obs_dim]: The observation at time t of each subset's first env.
            act_t (torch.Tensor) [b*subset_size, act_dim]: The actions at time t.
            obs_t_plus_1 (torch.Tensor) [b*subset_size, obs_dim]: The observations at time t+1.
        """
        num_subsets = obs_t_per_subset.shape[0]
        # rearrange also validates that the row counts match num_subsets * subset_size
        act_t = rearrange(act_t, '(b n) c -> b n c', b=num_subsets, n=self.subset_size)
        obs_t_plus_1 = rearrange(obs_t_plus_1, '(b n) c -> b n c', b=num_subsets, n=self.subset_size)

        # one device-to-host copy per tensor instead of one per subset
        s_t_all = obs_t_per_subset.cpu().numpy()
        a_t_all = act_t.cpu().numpy()
        s_t_plus_1_all = obs_t_plus_1.cpu().numpy()

        with h5py.File(self.file_name, "a") as f:
            for s_t, a_t, s_t_plus_1 in zip(s_t_all, a_t_all, s_t_plus_1_all):
                grp = f.create_group(f'subset_{self.sample_counter}')
                # s_t: [obs_dim]
                grp.create_dataset('s_t', data=s_t)
                # a_t: [subset_size, action_dim]
                grp.create_dataset('a_t', data=a_t)
                # s_t+1: [subset_size, obs_dim]
                grp.create_dataset('s_t_plus_1', data=s_t_plus_1)
                self.sample_counter += 1

    def load(self, path: str):
        """
        Load a previously saved model by delegating to ``self.actor_critic.load_trunk``.
        Args:
            path (str): Path to the saved model.
        Returns:
            Whatever ``load_trunk`` returns.
        """

        return self.actor_critic.load_trunk(path)
    


