# OS params
import os
import warnings
warnings.filterwarnings("ignore")

os.environ['MUJOCO_GL']='egl' # for headless
os.environ["HYDRA_FULL_ERROR"] = "1"
os.environ["D4RL_SUPPRESS_IMPORT_ERROR"] = "1"

import hydra
import warnings
import rootutils
import functools

# Configs & Printing
import wandb
from omegaconf import DictConfig
ROOT = rootutils.setup_root(search_from=__file__, indicator=[".git", "pyproject.toml"],
                            pythonpath=True, cwd=True)

# Libs
import jax
import jax.numpy as jnp
import numpy as np
import equinox as eqx
from tqdm.auto import tqdm
from jaxrl_m.wandb import setup_wandb
import matplotlib.pyplot as plt

import d4rl
from src.agents import icvf
from src.agents.icvf import eval_ensemble_gotil, eval_ensemble_icvf_viz
from src.utils import record_video
from src import d4rl_utils, d4rl_ant, viz_utils
from src.gc_dataset import GCSDataset

# Utilities from root folder
from utils.ds_builder import setup_datasets
from utils.rich_utils import print_config_tree

@eqx.filter_jit
def get_gcvalue(agent, s, z):
    """Return the mean of the two value outputs for states ``s`` and intents ``z``.

    The value model is read from ``agent.agent_icvf.value_learner.model`` and
    evaluated through ``eval_ensemble_gotil``, which is expected to return
    exactly two outputs (the unpacking below fails otherwise).
    """
    value_model = agent.agent_icvf.value_learner.model
    first_value, second_value = eval_ensemble_gotil(value_model, s, z)
    # Arithmetic mean of the pair.
    return (first_value + second_value) / 2

def get_v(agent, observations):
    """Sample one intent per observation and return the averaged value for each.

    ``agent.sample_intentions`` is mapped over the leading axis of
    ``observations``; the PRNG key is fixed to ``PRNGKey(42)`` and is shared
    (not split) across all observations, so repeated calls use the same key.
    """
    shared_key = jax.random.PRNGKey(42)
    # in_axes=(0, None): batch over observations, broadcast the key.
    sample_per_observation = eqx.filter_vmap(agent.sample_intentions, in_axes=(0, None))
    sampled_intents = sample_per_observation(observations, shared_key)
    return get_gcvalue(agent, observations, sampled_intents)

@eqx.filter_vmap(in_axes={"ensemble": eqx.if_array(0), "s": None}, out_axes=0)
def eval_ensemble(ensemble, s):
    """Apply every member of ``ensemble`` to each element of the batch ``s``.

    The outer ``filter_vmap`` (decorator) maps over axis 0 of the array leaves
    of ``ensemble`` while broadcasting ``s``; the inner one maps the selected
    member over the leading axis of ``s``. Outputs are stacked along axis 0,
    one slice per ensemble member.
    """
    apply_to_batch = eqx.filter_vmap(ensemble)
    return apply_to_batch(s)

@eqx.filter_jit
def get_debug_statistics_icvf(agent, batch):
    """Compute debugging statistics of the ICVF value model on one batch.

    Args:
        agent: object exposing ``value_learner.model`` (the value model) and
            ``value_learner.model.psi_net`` (the network used to embed goals).
        batch: mapping with the keys ``'observations'``, ``'icvf_goals'`` and
            ``'icvf_desired_goals'``.

    Returns:
        dict with ``'v_szz'`` and ``'v_sgz'`` (mean model output for the
        (state, desired goal, desired goal) and (state, goal, desired goal)
        combinations) and, when available, ``'phi_norm'`` / ``'psi_norm'``.
    """
    def get_info(s, g, z):
        return eval_ensemble_icvf_viz(agent.value_learner.model, s, g, z)

    psi_net = agent.value_learner.model.psi_net
    s = batch['observations']
    # Goals and intents are passed to the value model as psi embeddings;
    # index [0] keeps the output of the first ensemble member only.
    z = eval_ensemble(psi_net, batch['icvf_desired_goals'])[0]
    # Fix: the goal embedding previously reused 'icvf_desired_goals', which
    # made g identical to z and 'v_sgz' a duplicate of 'v_szz'.
    g = eval_ensemble(psi_net, batch['icvf_goals'])[0]
    info_szz = get_info(s, z, z)
    info_sgz = get_info(s, g, z)

    # NOTE(review): this membership test only makes sense if the helper returns
    # a mapping, while `.mean()` below assumes an array -- confirm the return
    # type of eval_ensemble_icvf_viz.
    if 'phi' in info_sgz:
        stats = {
            'phi_norm': jnp.linalg.norm(info_sgz['phi'], axis=-1).mean(),
            'psi_norm': jnp.linalg.norm(info_sgz['psi'], axis=-1).mean(),
        }
    else:
        stats = {}

    stats.update({
        'v_szz': info_szz.mean(),
        'v_sgz': info_sgz.mean()
    })
    return stats

@eqx.filter_jit
def get_traj_v(agent, trajectory, seed):
    """Evaluate every trajectory state against an intent sampled for every state.

    Intents are drawn (with ``seed``) from the distribution returned by
    ``agent.actor_intents_learner.model`` for each observation. The resulting
    grid is indexed as ``[state, intent]``; its first, last and middle intent
    columns are returned. The ``dist_to_*`` keys hold the averaged value
    outputs themselves.
    """
    value_model = agent.agent_icvf.value_learner.model

    def pair_value(state, intent):
        # The value helper is called on batches, so add a leading axis of size 1.
        first, second = eval_ensemble_gotil(value_model, state[None], intent[None])
        return (first + second) / 2

    states = trajectory['observations']
    intent_distribution = eqx.filter_vmap(agent.actor_intents_learner.model)(states)
    sampled_intents = intent_distribution.sample(seed=seed)

    # Inner map sweeps the intents for a fixed state; outer map sweeps the states.
    over_intents = jax.vmap(pair_value, in_axes=(None, 0))
    over_states = jax.vmap(over_intents, in_axes=(0, None))
    value_grid = over_states(states, sampled_intents)

    middle_column = value_grid.shape[1] // 2
    return {
        'dist_to_beginning': value_grid[:, 0],
        'dist_to_end': value_grid[:, -1],
        'dist_to_middle': value_grid[:, middle_column],
    }
    
@eqx.filter_jit
def get_traj_v_icvf(agent, trajectory):
    """Evaluate the expert ICVF model for every (state, embedded state) pair.

    Each observation is embedded with the expert model's ``psi_net`` (first
    ensemble member only) and that embedding is used both as goal and as
    intent. The resulting grid is indexed as ``[state, embedded state]``; its
    first, last and middle columns are returned under the ``dist_to_*`` keys.
    """
    expert_model = agent.expert_icvf.value_learner.model

    def pair_value(state, embedded_goal):
        # Same embedding is passed for the goal and the intent slots.
        batched_goal = embedded_goal[None]
        return eval_ensemble_icvf_viz(expert_model, state[None], batched_goal, batched_goal).mean()

    states = trajectory['observations']
    embedded_states = eval_ensemble(expert_model.psi_net, states)[0]

    # Inner map sweeps the embeddings for a fixed state; outer map sweeps the states.
    over_goals = jax.vmap(pair_value, in_axes=(None, 0))
    over_states = jax.vmap(over_goals, in_axes=(0, None))
    value_grid = over_states(states, embedded_states)

    middle_column = value_grid.shape[1] // 2
    return {
        'dist_to_beginning': value_grid[:, 0],
        'dist_to_end': value_grid[:, -1],
        'dist_to_middle': value_grid[:, middle_column],
    }

@hydra.main(version_base="1.4", config_path=str(ROOT/"configs"), config_name="entry.yaml")
def main(config: DictConfig):
    """Train and evaluate an IQL agent, optionally on AILOT-relabelled rewards.

    Builds the expert and agent datasets, then dispatches on
    ``config.algo_name``:

    * ``"IQL"``   -- trains IQL on the rewards already stored in the agent
      dataset.
    * ``"AILOT"`` -- relabels the agent dataset with rewards computed by
      ``OTRewardsExpert`` (using a pretrained ICVF model), rescales them and
      trains IQL on the relabelled data.

    Any other ``algo_name`` only performs the dataset setup.

    Args:
        config: Hydra configuration. Fields read here: ``expert_env_name``,
            ``agent_env_name``, ``num_expert_trajs``, ``normalize_states``,
            ``algo_name`` and ``num_eval_eps``. After ``setup_wandb`` the
            training hyperparameters (``max_steps``, ``seed``, ``expectile``,
            ``discount``, ``temperature``, ``A_coefficient``, ...) are read
            from ``wandb.config``.
    """
    print_config_tree(config)
    setup_wandb(hyperparam_dict=dict(config),
                mode="offline",
                name=None)

    env, expert_ds, agent_ds, agent_mean_states, agent_std_states = setup_datasets(
        expert_env_name=config.expert_env_name,
        agent_env_name=config.agent_env_name,
        expert_num=config.num_expert_trajs,
        normalize_agent_states=config.normalize_states)

    gcsds_params = GCSDataset.get_default_config()
    gc_expert_dataset = GCSDataset(expert_ds, **gcsds_params)
    gc_agent_dataset = GCSDataset(agent_ds, **gcsds_params)

    expert_trajectory = gc_expert_dataset.dataset.dataset_dict['observations']
    if wandb.config.num_expert_trajs > 1:
        # Split the flat observation array into one row per expert trajectory.
        expert_trajectory = expert_trajectory.reshape(
            wandb.config.num_expert_trajs, -1, env.observation_space.shape[0])

    # Shared by both training branches. It used to be assigned only inside the
    # IQL branch, so the AILOT branch failed with a NameError when sampling.
    batch_size = 256

    if config.algo_name == "IQL":
        from src.agents.iql_flax.common import Batch
        from src.agents.iql_flax.learner import Learner
        from src.agents.iql_flax.evaluation import evaluate

        max_steps = wandb.config.max_steps

        iql_agent = Learner(
                wandb.config.seed,
                env.observation_space.sample()[np.newaxis],
                env.action_space.sample()[np.newaxis],
                max_steps=max_steps,
                expectile=wandb.config.expectile,
                discount=wandb.config.discount,
                temperature=wandb.config.temperature)

        pbar = tqdm(range(max_steps))
        for i in pbar:
            sample = gc_agent_dataset.dataset.sample(batch_size=batch_size)
            batch = Batch(
                observations=sample["observations"],
                next_observations=sample["next_observations"],
                actions=sample["actions"],
                rewards=sample["rewards"],
                masks=sample["masks"]
            )
            update_info = iql_agent.update(batch)
            update_info['adv'] = None

            # Periodic evaluation (skipped at step 0).
            if i % 50_000 == 0 and i > 0:
                eval_stats = evaluate(iql_agent, env, num_episodes=10)
                wandb.log({'Eval': eval_stats})
                print(eval_stats)
                pbar.set_postfix(update_info)

            if i % 3000 == 0:
                wandb.log({'Training/': update_info})
                pbar.set_postfix(update_info)

        print(f"Finished Training, Runnig evaluation")
        eval_stats = evaluate(iql_agent, env, num_episodes=config.num_eval_eps)
        print(f"Results on Evaluation: {eval_stats}")

    elif config.algo_name == "AILOT":
        from src.agents import icvf
        from src.agents.ailot import OTRewardsExpert
        from src.agents.iql_flax.common import Batch
        from src.agents.iql_flax.learner import Learner
        from src.agents.iql_flax.evaluation import evaluate
        from src.dataset import Dataset
        from utils.ds_builder import load_trajectories

        # The pretrained folder is the env name without its last three
        # characters (presumably a version suffix such as "-v2" -- TODO confirm).
        icvf_model = icvf.create_eqx_learner(seed=42,
                                            observations=expert_ds.dataset_dict['observations'][0],
                                            hidden_dims=[256, 256],
                                            pretrained_folder=wandb.config.agent_env_name[:-3],
                                            load_pretrained_icvf=True)
        print(f"RELABELLING VIA AILOT")
        expert = OTRewardsExpert(expert_trajectory, icvf_model)
        rewards = expert.compute_rewards(gc_agent_dataset.dataset, gc_agent_dataset)

        class ExpRewardsScaler:
            """Exponential reward scaling normalised by the 0.95 quantile of |rewards|."""

            def init(self, rewards: np.ndarray):
                # Fit the quantile statistics on the flattened absolute rewards.
                self.min = np.quantile(np.abs(rewards).reshape(-1), 0.0)
                self.max = np.quantile(np.abs(rewards).reshape(-1), 0.95)

            def scale(self, rewards: np.ndarray):
                # A_coefficient * exp(rewards / q95)
                return wandb.config.A_coefficient * np.exp(rewards / self.max)

        def get_subs(dataset: GCSDataset, add_steps: int):
            """Return, for each step, the observation ``add_steps`` ahead.

            The look-ahead index is clipped to the terminal index of the
            trajectory the step belongs to.
            """
            terminal_locs = dataset.terminal_locs
            observations = dataset.dataset.dataset_dict['observations']
            indx = np.arange(observations.shape[0])
            # First terminal location at or after each index.
            final_state_indx = terminal_locs[np.searchsorted(terminal_locs, indx)]
            way_indx = np.minimum(indx + add_steps, final_state_indx)
            # jax.tree_map is deprecated/removed in recent JAX releases;
            # jax.tree_util.tree_map is the stable spelling.
            return jax.tree_util.tree_map(lambda arr: arr[way_indx], observations)

        scaler = ExpRewardsScaler()
        scaler.init(rewards)
        scaled_rewards = scaler.scale(rewards).astype(np.float32)

        ## Apply iql scaling
        # Every antmaze variant loads the "antmaze-large-diverse" trajectories.
        if "antmaze" in wandb.config.agent_env_name.split("-"):
            offline_traj = load_trajectories("antmaze-large-diverse", scaled_rewards)
        else:
            offline_traj = load_trajectories(wandb.config.agent_env_name, scaled_rewards)

        def compute_iql_reward_scale(trajs):
            """Rescale rewards based on max/min from the dataset.
            This is also used in the original IQL implementation.
            """
            trajs = trajs.copy()

            def compute_returns(tr):
                # Index 2 of each step is summed as the reward.
                return sum([step[2] for step in tr])

            trajs.sort(key=compute_returns)
            reward_scale = 1000.0 / (compute_returns(trajs[-1]) - compute_returns(trajs[0]))
            return reward_scale

        subs_15 = get_subs(gc_agent_dataset, 15)
        subs_10 = get_subs(gc_agent_dataset, 10)
        subs_5 = get_subs(gc_agent_dataset, 5)

        ds = gc_agent_dataset.dataset.dataset_dict
        episode_starts, episode_ends, episode_returns = gc_agent_dataset.dataset._trajectory_boundaries_and_returns()

        def concat_episodes(arr):
            """Concatenate the ``[start:end]`` slice of every episode as float32."""
            return np.concatenate(
                [arr[start:end] for start, end in zip(episode_starts, episode_ends)]
            ).astype(np.float32)

        data_with_ot_rewards = Dataset({
            'observations': concat_episodes(ds['observations']),
            'next_observations': concat_episodes(ds['next_observations']),
            'actions': concat_episodes(ds['actions']),
            'rewards': scaled_rewards * compute_iql_reward_scale(offline_traj) - 2,
            'masks': 1.0 - concat_episodes(ds['dones']),
            'sub_observations_5': concat_episodes(subs_5),
            'sub_observations_10': concat_episodes(subs_10),
            'sub_observations_15': concat_episodes(subs_15),
        })

        iql_agent_ot = Learner(
                wandb.config.seed,
                env.observation_space.sample()[np.newaxis],
                env.action_space.sample()[np.newaxis],
                max_steps=wandb.config.max_steps,
                expectile=wandb.config.expectile,
                discount=wandb.config.discount,
                temperature=wandb.config.temperature)

        pbar = tqdm(range(wandb.config.max_steps))

        for i in pbar:
            sample = data_with_ot_rewards.sample(batch_size)
            batch = Batch(
                observations=sample["observations"],
                next_observations=sample["next_observations"],
                actions=sample['actions'],
                rewards=sample["rewards"],
                masks=sample["masks"]
            )
            update_info = iql_agent_ot.update(batch)
            update_info['adv'] = None
            # Periodic evaluation (skipped at step 0); the return is logged as
            # a normalized score multiplied by 100.
            if i % 50_000 == 0 and i > 0:
                eval_stats = evaluate(iql_agent_ot, env, num_episodes=10)
                print(eval_stats)
                eval_stats['return'] = env.get_normalized_score(eval_stats['return'])*100
                wandb.log({f"Eval/{key}": value for key, value in eval_stats.items()})
                pbar.set_postfix(update_info)
            if i % 2000 == 0:
                wandb.log({f"Training/{key}": value for key, value in update_info.items()})
                pbar.set_postfix(update_info)
        print(f"Finished Training")
        print(f"Final Evaluation")
        # final eval
        eval_stats = evaluate(iql_agent_ot, env, num_episodes=100)
        print(eval_stats)


if __name__ == '__main__':
    main()
