"""
    Based on osrl-lib by Zuxin Liu and Zijian Guo (https://github.com/liuzuxin/OSRL.git), licensed under Apache 2.0 and MIT.
"""

import os
from pathlib import Path
import random
import types
from dataclasses import asdict, dataclass
from typing import Tuple
import numpy as np
import pyrallis
try:
    from isaacgym import gymapi, gymutil
except:
    pass
import torch
from torch.utils.data import DataLoader
from tqdm.auto import trange  # noqa
from safepo.common.env import make_ma_multi_goal_env, make_ma_mujoco_env, make_ma_isaac_env
from examples.configs.bcql_configs import BCQL_DEFAULT_CONFIG, BCQLTrainConfig
from core.common import TransitionDataset
from core.common.exp_util import auto_name, seed_all
from core.common.ma_dataset import *
from safepo.utils.config import multi_agent_args, parse_sim_params
from core.algorithms import BCQL, BCQLTrainer
from core.common import TransitionDataset
from torch.utils.tensorboard import SummaryWriter


class Runner:
    """Build the evaluation environment and BCQL trainers, then train and evaluate.

    One ``BCQL`` model / ``BCQLTrainer`` pair is created per *training* agent
    (``train_agent_num``). With ``args.centralized_training`` the per-agent
    action bounds are concatenated and a single joint model is trained.

    NOTE(review): only the centralized path is complete in this class; see the
    note in ``__init__`` about the decentralized data-loading branch.
    """

    # Tasks whose number of parallel evaluation environments is read from
    # cfg_env["env"]["numEnvs"] instead of cfg_eval["n_rollout_threads"].
    _NUM_ENVS_TASKS = (
        "ShadowHandOver_Safe_joint",
        "ShadowHandOver_Safe_finger",
        "ShadowHandCatchOver2Underarm_Safe_joint",
        "ShadowHandCatchOver2Underarm_Safe_finger",
        "FreightFrankaCloseDrawer",
        "FreightFrankaPickAndPlace",
    )

    def __init__(self, args, agent_num, train_agent_num):
        """Create the evaluation env, load the data, and build one trainer per training agent.

        Args:
            args: Namespace holding the training configuration plus the
                ``args_env`` / ``cfg_env`` / ``cfg_eval`` objects attached in ``main``.
            agent_num: Number of agents in the environment.
            train_agent_num: Number of models to train (1 when training is centralized).
        """
        self.args = args
        self.agent_num = agent_num
        self.train_agent_num = train_agent_num

        # Pick the environment factory from the task name.
        if "MultiGoal" in self.args.task:
            self.args.cfg_eval["device"] = "cpu"
            self.eval_env = make_ma_multi_goal_env(self.args.args_env.task, self.args.cfg_eval['seed'], self.args.cfg_eval)
        elif "Velocity" in self.args.task:
            self.args.cfg_eval["device"] = "cpu"
            self.eval_env = make_ma_mujoco_env(self.args.args_env.scenario, self.args.args_env.agent_conf, self.args.cfg_eval['seed'], self.args.cfg_eval)
        else:
            agent_index = [[[0, 1, 2, 3, 4, 5]], [[0, 1, 2, 3, 4, 5]]]
            sim_params = parse_sim_params(self.args.args_env, self.args.cfg_env, self.args.cfg_eval)
            self.eval_env = make_ma_isaac_env(self.args.args_env, self.args.cfg_env, self.args.cfg_eval, sim_params, agent_index)

        self.episode_len = self.args.cfg_eval['episode_length']
        if self.args.task in self._NUM_ENVS_TASKS:
            self.rollout_threads = self.args.cfg_env["env"]["numEnvs"]
        else:
            self.rollout_threads = self.args.cfg_eval["n_rollout_threads"]

        # Per-agent action bounds and action sizes taken from the env's action spaces.
        self.env_max_act = [self.eval_env.action_space[agent].high.astype(np.float32) for agent in range(self.agent_num)]
        self.env_min_act = [self.eval_env.action_space[agent].low.astype(np.float32) for agent in range(self.agent_num)]
        self.act_dim_agent = [self.eval_env.action_space[agent].shape[0] for agent in range(self.agent_num)]
        if self.args.centralized_training:
            # A single joint model acts for every agent, so join the bounds.
            self.env_max_act = [np.concatenate(self.env_max_act)]
            self.env_min_act = [np.concatenate(self.env_min_act)]
        # Symmetric bound: the larger magnitude of the low/high limit per action dimension.
        self.max_act = [np.maximum(np.abs(self.env_max_act[train_agent]), np.abs(self.env_min_act[train_agent])) for train_agent in range(self.train_agent_num)]
        self.env_max_act = [torch.as_tensor(self.env_max_act[train_agent], device=self.args.device) for train_agent in range(self.train_agent_num)]
        self.env_min_act = [torch.as_tensor(self.env_min_act[train_agent], device=self.args.device) for train_agent in range(self.train_agent_num)]
        self.max_act = [torch.as_tensor(self.max_act[train_agent], device=self.args.device) for train_agent in range(self.train_agent_num)]

        if self.args.centralized_training:
            self.data = get_data_CTCE(self.args.agent_data_file)
        else:
            # NOTE(review): this branch reads self.file_list, which is never set
            # in this class, and it never assigns self.data, so decentralized
            # training fails here or on the state_dim line below. The per-agent
            # loading logic still needs to be completed.
            samples = []
            for train_agent in range(self.train_agent_num):
                arct_file = os.path.join(self.args.arct_data_dir[train_agent], self.file_list[0])
                obs_file = os.path.join(self.args.obs_data_dir[train_agent], self.file_list[0])
                sample = get_file(arct_file, obs_file)
                samples.append(sample)
        # Single-entry lists: only index 0 exists, which matches train_agent_num == 1.
        self.state_dim = [self.data['observations'].shape[-1]]
        self.action_dim = [self.data['actions'].shape[-1]]

        self.trainer = []
        num_params = 0
        for train_agent in range(self.train_agent_num):
            model = BCQL(
                state_dim=self.state_dim[train_agent],
                action_dim=self.action_dim[train_agent],
                max_action=self.max_act[train_agent],
                a_hidden_sizes=self.args.a_hidden_sizes,
                c_hidden_sizes=self.args.c_hidden_sizes,
                vae_hidden_sizes=self.args.vae_hidden_sizes,
                sample_action_num=self.args.sample_action_num,
                PID=self.args.PID,
                gamma=self.args.gamma,
                tau=self.args.tau,
                lmbda=self.args.lmbda,
                beta=self.args.beta,
                phi=self.args.phi,
                num_q=self.args.num_q,
                num_qc=self.args.num_qc,
                # The model sees scaled costs, so the limit is scaled the same way.
                cost_limit=self.args.cost_limit * self.args.cost_scale,
                episode_len=self.episode_len,
                device=self.args.device,
            )
            num_params += sum(p.numel() for p in model.parameters())

            logger = SummaryWriter(log_dir=os.path.join(self.args.logdir_agent[train_agent], 'tb'))

            trainer = BCQLTrainer(
                model,
                self.eval_env,
                logger=logger,
                actor_lr=self.args.actor_lr,
                critic_lr=self.args.critic_lr,
                vae_lr=self.args.vae_lr,
                reward_scale=self.args.reward_scale,
                cost_scale=self.args.cost_scale,
                device=self.args.device,
            )
            self.trainer.append(trainer)
        print(f"Total parameters: {num_params}")

    @torch.no_grad()
    def evaluate(self):
        """Run ``args.eval_episodes`` rollouts and average the results.

        All models are switched to eval mode for the rollouts and back to
        train mode afterwards.

        Returns:
            Tuple ``(mean_return, mean_cost, mean_length)`` where the return
            and cost are divided by ``reward_scale`` / ``cost_scale`` to undo
            the scaling applied in ``rollout``.
        """
        for trainer in self.trainer[:self.train_agent_num]:
            trainer.model.eval()
        episode_rets, episode_costs, episode_lens = [], [], []
        for _ in range(self.args.eval_episodes):
            epi_ret, epi_len, epi_cost = self.rollout()
            episode_rets.append(epi_ret)
            episode_lens.append(epi_len)
            episode_costs.append(epi_cost)
        for trainer in self.trainer[:self.train_agent_num]:
            trainer.model.train()
        return np.mean(episode_rets) / self.args.reward_scale, np.mean(episode_costs) / self.args.cost_scale, np.mean(episode_lens)

    @torch.no_grad()
    def rollout(self) -> Tuple[float, float, float]:
        """Roll the current policies out in the evaluation environment once.

        Rewards and costs are accumulated per environment until that
        environment reports termination for all agents, or until
        ``episode_len`` steps have been taken.

        Returns:
            Tuple ``(episode_ret, epi_len, episode_cost)``: scaled return,
            episode length and scaled cost, each averaged over all entries.
        """
        is_franka = "FreightFranka" in self.args.task
        centralized = self.args.centralized_training
        device = self.args.device

        obs, _, _ = self.eval_env.reset()
        if is_franka:
            # FreightFranka observations arrive as one array per agent.
            obs = [torch.as_tensor(o, dtype=torch.float, device=device) for o in obs]
        else:
            obs = torch.as_tensor(obs, dtype=torch.float, device=device)
        if centralized:
            # Merge all agents' observations into a single pseudo-agent at index 0.
            if is_franka:
                obs = torch.cat(obs, dim=1).unsqueeze(1)
            else:
                obs = obs.reshape(self.rollout_threads, 1, -1)

        episode_ret = torch.zeros([self.rollout_threads, self.train_agent_num], dtype=torch.float, device=device)
        episode_cost = torch.zeros_like(episode_ret)
        epi_len = torch.zeros(self.rollout_threads, dtype=torch.float, device=device)
        # True for environments that have not terminated yet.
        active_mask = torch.ones(self.rollout_threads, dtype=torch.bool, device=device)

        for _ in range(self.episode_len):
            acts = []
            for train_agent in range(self.train_agent_num):
                model = self.trainer[train_agent].model
                if is_franka and not centralized:
                    agent_obs = obs[train_agent]
                else:
                    agent_obs = obs[:, train_agent]
                act = model.actor(agent_obs, model.vae.decode(agent_obs))
                acts.append(act.to(self.args.cfg_eval["device"]))

            if centralized:
                # Split the joint action back into one chunk per environment agent.
                acts = acts[0]
                acts = list(torch.split(acts, self.act_dim_agent, dim=1))
            if self.args.task == "Safety98HumanoidVelocity-v0":
                # This task gets one extra zero column appended to the last agent's action.
                padding = torch.zeros(acts[-1].shape[0], 1, device=acts[-1].device, dtype=acts[-1].dtype)
                acts[-1] = torch.cat((acts[-1], padding), dim=1)
            obs_next, _, reward, cost, terminated, _, _ = self.eval_env.step(acts)
            if is_franka:
                obs_next = [torch.as_tensor(o, device=device, dtype=torch.float) for o in obs_next]
            else:
                obs_next = torch.as_tensor(obs_next, device=device, dtype=torch.float)
            reward = torch.as_tensor(reward[:, :, 0], device=device, dtype=torch.float) * self.args.reward_scale
            cost = torch.as_tensor(cost[:, :, 0], device=device, dtype=torch.float) * self.args.cost_scale
            terminated = torch.as_tensor(terminated, device=device)
            # An environment counts as finished only when every agent in it terminated.
            terminated = torch.all(terminated, dim=1)
            if centralized:
                if is_franka:
                    obs_next = torch.cat(obs_next, dim=1).unsqueeze(1)
                else:
                    obs_next = obs_next.reshape(self.rollout_threads, 1, -1)
                # The joint model receives the agent-averaged reward and cost.
                reward = reward.mean(dim=1, keepdim=True)
                cost = cost.mean(dim=1, keepdim=True)
            obs = obs_next

            # The terminating step itself is still counted; the mask is updated afterwards.
            episode_ret[active_mask, :] += reward[active_mask, :]
            episode_cost[active_mask, :] += cost[active_mask, :]
            epi_len[active_mask] += 1
            active_mask = torch.logical_and(active_mask, ~terminated)

            if torch.all(~active_mask).item():
                break

        episode_ret = episode_ret.mean().item()
        episode_cost = episode_cost.mean().item()
        epi_len = epi_len.mean().item()
        return episode_ret, epi_len, episode_cost

    def run(self):
        """Train every agent for ``args.update_steps`` steps with periodic evaluation.

        Evaluation runs every ``args.eval_every`` steps and on the final step.
        When the evaluated cost is within ``args.cost_limit`` and the return
        beats the best safe return so far, all model weights are saved (if
        ``args.save_model`` is set). Evaluation metrics are written to every
        trainer's logger. All loggers are closed when the method exits, even
        if training raises.
        """
        # Best result among evaluations that satisfied the cost limit.
        best_safe_return = -np.inf
        best_safe_cost = np.inf
        best_safe_step = -1
        # Lowest-cost result overall (tracked, but not logged).
        best_return = -np.inf
        best_cost = np.inf
        best_step = -1

        # Wrap the raw data once. Doing this inside the agent loop would wrap
        # an already-wrapped dataset from the second agent onwards.
        dataset = TransitionDataset(
            self.data,
            reward_scale=self.args.reward_scale,
            cost_scale=self.args.cost_scale,
        )

        try:
            for train_agent in range(self.train_agent_num):
                trainloader = DataLoader(
                    dataset,
                    batch_size=self.args.batch_size,
                    num_workers=self.args.num_workers,
                    pin_memory=False,
                )
                trainloader_iter = iter(trainloader)

                for step in trange(self.args.update_steps):
                    batch = next(trainloader_iter)
                    observations, next_observations, actions, rewards, costs, done = [b.to(self.args.device) for b in batch]
                    self.trainer[train_agent].train_one_step(observations, next_observations, actions, rewards, costs, done, step)

                    is_last_step = (step + 1) == self.args.update_steps
                    if (step + 1) % self.args.eval_every != 0 and not is_last_step:
                        continue

                    # Eval
                    ret, cost, length = self.evaluate()
                    if cost <= self.args.cost_limit and ret > best_safe_return:
                        best_safe_return = ret
                        best_safe_cost = cost
                        best_safe_step = step
                        if self.args.save_model:
                            # The loops below must not reuse the name `train_agent`,
                            # or the outer loop would train the wrong agent afterwards.
                            for trainer, model_path in zip(self.trainer, self.args.save_model_path):
                                torch.save(trainer.model.state_dict(), model_path)
                    if cost < best_cost or (cost == best_cost and ret > best_return):
                        best_return = ret
                        best_cost = cost
                        best_step = step

                    for trainer in self.trainer:
                        trainer.logger.add_scalar('eval/return', ret, step)
                        trainer.logger.add_scalar('eval/cost_return', cost, step)
                        trainer.logger.add_scalar('eval/length', length, step)
                        trainer.logger.add_scalar('eval/best_safe_return', best_safe_return, step)
                        trainer.logger.add_scalar('eval/best_safe_cost_return', best_safe_cost, step)
                        trainer.logger.add_scalar('eval/best_safe_step', best_safe_step, step)
        finally:
            for trainer in self.trainer:
                trainer.logger.close()
                 
@pyrallis.wrap()
def main(args: BCQLTrainConfig):
    """Resolve the configuration, prepare output directories, and run training.

    Command-line values that differ from the ``BCQLTrainConfig`` defaults are
    layered on top of the task-specific default config from
    ``BCQL_DEFAULT_CONFIG``. The number of agents is derived from the entries
    of ``../MOSDB/<task>``.

    Args:
        args: Configuration parsed by ``pyrallis``.

    Raises:
        ValueError: If no agent data entries are found in the data directory.
    """
    # update config
    cfg, old_cfg = asdict(args), asdict(BCQLTrainConfig())
    differing_values = {key: value for key, value in cfg.items() if value != old_cfg[key]}
    cfg = asdict(BCQL_DEFAULT_CONFIG[args.task]())
    cfg.update(differing_values)
    args = types.SimpleNamespace(**cfg)
    print('-' * 20)
    print(f'Task: {args.task}')

    # A density above 1 is replaced by its reciprocal.
    if args.density > 1:
        args.density = 1 / args.density
    if args.group is None:
        args.group = args.task + "-cost" + str(int(args.cost_limit))
    args.data_dir = f"../MOSDB/{args.task}"
    # Count entries named like "agent_*", skipping those containing "_sg".
    agent_num = sum("agent_" in entry and "_sg" not in entry for entry in os.listdir(args.data_dir))
    if agent_num == 0:
        # Fail early with a clear message instead of crashing later inside Runner.
        raise ValueError(f"No agent data found in {args.data_dir}")
    train_agent_num = 1 if args.centralized_training else agent_num

    args.logdir = os.path.join(args.log_root_dir, args.task, args.prefix)
    os.makedirs(args.logdir, exist_ok=True)
    # Existing experiment directories determine the default experiment name.
    exp_num = sum(1 for item in Path(args.logdir).iterdir() if item.is_dir())
    if args.name is None:
        args.name = f"experiment{exp_num}"
    args.logdir = os.path.join(args.logdir, args.name)

    args.logdir_agent, args.agent_data_file, args.save_model_path = [], [], []
    for train_agent in range(train_agent_num):
        logdir_agent = os.path.join(args.logdir, f"agent_{train_agent}")
        save_model_dir = os.path.join(logdir_agent, "model")
        os.makedirs(save_model_dir, exist_ok=True)
        args.save_model_path.append(os.path.join(save_model_dir, "best_safe_model.pth"))
        args.logdir_agent.append(logdir_agent)
    for agent in range(agent_num):
        args.agent_data_file.append(os.path.join(args.data_dir, f'agent_{agent}.h5'))

    # Environment / evaluation settings are borrowed from the "mappolag" config.
    args_env, cfg_env, cfg_eval = multi_agent_args(algo="mappolag")
    if args.task in [
            "ShadowHandOver_Safe_joint",
            "ShadowHandOver_Safe_finger",
            "ShadowHandCatchOver2Underarm_Safe_joint",
            "ShadowHandCatchOver2Underarm_Safe_finger",
            "FreightFrankaCloseDrawer",
            "FreightFrankaPickAndPlace",
        ]:
        cfg_env["env"]["numEnvs"] = 10
    args_env.seed = 10000
    cfg_eval["seed"] = args_env.seed
    cfg_eval["n_rollout_threads"] = cfg_eval["n_eval_rollout_threads"]
    args_env.cost_limit = args.cost_limit
    cfg_eval["cost_limit"] = args.cost_limit
    args.args_env, args.cfg_eval, args.cfg_env = args_env, cfg_eval, cfg_env

    # set seed
    seed_all(args.seed)
    if args.device == "cpu":
        torch.set_num_threads(args.threads)

    runner = Runner(args, agent_num, train_agent_num)
    runner.run()

# Script entry point. main() is called without arguments because the
# @pyrallis.wrap() decorator on main supplies the parsed config.
if __name__ == "__main__":
    main()