"""
    Based on osrl-lib by Zuxin Liu and Zijian Guo (https://github.com/liuzuxin/OSRL.git), licensed under Apache 2.0 and MIT.
"""

import os
from pathlib import Path
import random
import time
import types
from dataclasses import asdict
from typing import List, Tuple
import numpy as np
import pyrallis
try:
    from isaacgym import gymapi, gymutil
except:
    pass
import torch
from torch.utils.data import DataLoader
from tqdm.auto import trange  # noqa
from safepo.common.env import make_ma_multi_goal_env, make_ma_mujoco_env, make_ma_isaac_env
from examples.configs.cdt_configs import CDT_DEFAULT_CONFIG, CDTTrainConfig
from core.common.exp_util import auto_name, seed_all
from core.common.ma_dataset import *
from safepo.utils.config import multi_agent_args, parse_sim_params
from core.algorithms import CDT, CDTTrainer
from core.common import SequenceDataset
from torch.utils.tensorboard import SummaryWriter


class Runner:
    """Train and evaluate CDT models on a multi-agent safe-RL task.

    One ``CDT`` model / ``CDTTrainer`` pair is created per *training* agent
    (``train_agent_num`` of them). When ``args.centralized_training`` is set,
    the per-agent action bounds are concatenated and a single dataset is
    loaded with ``get_data_CTCE``; otherwise one dataset per file in
    ``args.agent_data_file`` is loaded with ``get_data_DTDE``.
    """

    def __init__(self, args, agent_num, train_agent_num):
        """Create the evaluation env, load the offline data and build the trainers.

        Args:
            args: Namespace-like configuration. Besides the CDT hyper-parameters
                it must carry ``task``, ``device``, ``centralized_training``,
                ``agent_data_file``, ``logdir_agent``, ``args_env``,
                ``cfg_env`` and ``cfg_eval`` (all read below).
            agent_num: Number of agents in the environment (used to index
                ``eval_env.action_space``).
            train_agent_num: Number of models to train. The caller in this
                file passes 1 for centralized training and ``agent_num``
                otherwise.
        """
        self.args = args
        self.agent_num = agent_num
        self.train_agent_num = train_agent_num

        # Pick the environment factory from the task name. The two non-Isaac
        # branches force the evaluation device to CPU.
        if "MultiGoal" in self.args.task:
            self.args.cfg_eval["device"] = "cpu"
            self.eval_env = make_ma_multi_goal_env(self.args.args_env.task, self.args.cfg_eval['seed'], self.args.cfg_eval)
        elif "Velocity" in self.args.task:
            self.args.cfg_eval["device"] = "cpu"
            self.eval_env = make_ma_mujoco_env(self.args.args_env.scenario, self.args.args_env.agent_conf, self.args.cfg_eval['seed'], self.args.cfg_eval)
        else:
            agent_index = [[[0, 1, 2, 3, 4, 5]], [[0, 1, 2, 3, 4, 5]]]
            sim_params = parse_sim_params(self.args.args_env, self.args.cfg_env, self.args.cfg_eval)
            self.eval_env = make_ma_isaac_env(self.args.args_env, self.args.cfg_env, self.args.cfg_eval, sim_params, agent_index)

        self.episode_len = self.args.cfg_eval['episode_length']
        # For the listed tasks the number of parallel rollouts comes from the
        # env config ("numEnvs"); for every other task it comes from cfg_eval.
        if self.args.task in [
            "ShadowHandOver_Safe_joint",
            "ShadowHandOver_Safe_finger",
            "ShadowHandCatchOver2Underarm_Safe_joint",
            "ShadowHandCatchOver2Underarm_Safe_finger",
            "FreightFrankaCloseDrawer",
            "FreightFrankaPickAndPlace",
        ]:
            self.rollout_threads = self.args.cfg_env["env"]["numEnvs"]
        else:
            self.rollout_threads = self.args.cfg_eval["n_rollout_threads"]

        # Per-agent action bounds taken from the environment's action spaces.
        self.env_max_act = [self.eval_env.action_space[agent].high.astype(np.float32) for agent in range(self.agent_num)]
        self.env_min_act = [self.eval_env.action_space[agent].low.astype(np.float32) for agent in range(self.agent_num)]
        self.act_dim_agent = [self.eval_env.action_space[agent].shape[0] for agent in range(self.agent_num)]
        if self.args.centralized_training:
            # A single model acts for all agents: join the bounds into one vector.
            self.env_max_act = [np.concatenate(self.env_max_act)]
            self.env_min_act = [np.concatenate(self.env_min_act)]
        # Symmetric bound max(|high|, |low|), passed to CDT as ``max_action``.
        self.max_act = [np.maximum(np.abs(self.env_max_act[train_agent]), np.abs(self.env_min_act[train_agent])) for train_agent in range(self.train_agent_num)]
        self.env_max_act = [torch.as_tensor(self.env_max_act[train_agent], device=self.args.device) for train_agent in range(self.train_agent_num)]
        self.env_min_act = [torch.as_tensor(self.env_min_act[train_agent], device=self.args.device) for train_agent in range(self.train_agent_num)]
        self.max_act = [torch.as_tensor(self.max_act[train_agent], device=self.args.device) for train_agent in range(self.train_agent_num)]

        if self.args.centralized_training:
            self.data = [get_data_CTCE(self.args.agent_data_file)]
        else:
            self.data = []
            for data_file in self.args.agent_data_file:
                self.data.append(get_data_DTDE(data_file))
        # Model input/output sizes are read off the last axis of the datasets.
        self.state_dim = [d['observations'].shape[-1] for d in self.data]
        self.action_dim = [d['actions'].shape[-1] for d in self.data]

        self.trainer = []
        num_params = 0
        num_params_exe = 0
        for train_agent in range(self.train_agent_num):
            model = CDT(
                state_dim=self.state_dim[train_agent],
                action_dim=self.action_dim[train_agent],
                max_action=self.max_act[train_agent],
                embedding_dim=self.args.embedding_dim,
                seq_len=self.args.seq_len,
                episode_len=self.episode_len,
                num_layers=self.args.num_layers,
                num_heads=self.args.num_heads,
                attention_dropout=self.args.attention_dropout,
                residual_dropout=self.args.residual_dropout,
                embedding_dropout=self.args.embedding_dropout,
                time_emb=self.args.time_emb,
                use_rew=self.args.use_rew,
                use_cost=self.args.use_cost,
                cost_transform=self.args.cost_transform,
                add_cost_feat=self.args.add_cost_feat,
                mul_cost_feat=self.args.mul_cost_feat,
                cat_cost_feat=self.args.cat_cost_feat,
                action_head_layers=self.args.action_head_layers,
                cost_prefix=self.args.cost_prefix,
                stochastic=self.args.stochastic,
                init_temperature=self.args.init_temperature,
                target_entropy=-self.action_dim[train_agent],
            ).to(self.args.device)

            # "Exe" parameters exclude the cost / state prediction heads
            # (presumably because they are not needed to produce actions —
            # TODO confirm against the CDT implementation).
            model_params = sum(p.numel() for p in model.parameters())
            num_params += model_params
            num_params_exe += model_params - \
                              sum(p.numel() for p in model.cost_pred_head.parameters()) - \
                              sum(p.numel() for p in model.state_pred_head.parameters())

            logger = SummaryWriter(log_dir=os.path.join(self.args.logdir_agent[train_agent], 'tb'))

            trainer = CDTTrainer(
                model,
                self.eval_env,
                logger=logger,
                learning_rate=self.args.learning_rate,
                weight_decay=self.args.weight_decay,
                betas=self.args.betas,
                clip_grad=self.args.clip_grad,
                lr_warmup_steps=self.args.lr_warmup_steps,
                reward_scale=self.args.reward_scale,
                cost_scale=self.args.cost_scale,
                loss_cost_weight=self.args.loss_cost_weight,
                loss_state_weight=self.args.loss_state_weight,
                cost_reverse=self.args.cost_reverse,
                no_entropy=self.args.no_entropy,
                device=self.args.device
            )
            self.trainer.append(trainer)
        print(f"Total parameters: {num_params}")
        print(f"Exe parameters: {num_params_exe}")

    @torch.no_grad()
    def evaluate(self, target_return, target_cost):
        """Run ``args.eval_episodes`` rollouts and average their statistics.

        All models are switched to eval mode for the rollouts and back to
        train mode afterwards.

        Args:
            target_return: Per-training-agent target returns, forwarded to
                :meth:`rollout`.
            target_cost: Per-training-agent target costs, forwarded to
                :meth:`rollout`.

        Returns:
            A 4-tuple ``(mean_return / reward_scale, mean_cost / cost_scale,
            mean_episode_length, mean_model_call_time_ms)``.
        """
        for train_agent in range(self.train_agent_num):
            self.trainer[train_agent].model.eval()
        episode_rets, episode_costs, episode_lens, meants = [], [], [], []
        for _ in range(self.args.eval_episodes):
            epi_ret, epi_len, epi_cost, meant = self.rollout(target_return, target_cost)
            episode_rets.append(epi_ret)
            episode_lens.append(epi_len)
            episode_costs.append(epi_cost)
            meants.append(meant)
        for train_agent in range(self.train_agent_num):
            self.trainer[train_agent].model.train()
        # rollout() accumulates scaled rewards/costs, so undo the scaling here.
        return np.mean(episode_rets) / self.args.reward_scale, np.mean(episode_costs) / self.args.cost_scale, np.mean(episode_lens), np.mean(meants)

    @torch.no_grad()
    def rollout(
        self,
        target_return: List[float],
        target_cost: List[float],
    ) -> Tuple[float, float, float, float]:
        """Roll the current models out in ``eval_env`` for one batch of episodes.

        Args:
            target_return: Initial return-to-go, one entry per training agent.
            target_cost: Initial cost-to-go, one entry per training agent.

        Returns:
            ``(episode_ret, epi_len, episode_cost, meant)`` where the first and
            third entries are the accumulated (scaled) reward and cost averaged
            over rollout threads and training agents, ``epi_len`` is the mean
            number of steps taken while a thread was still active, and
            ``meant`` is the mean wall-clock time of one model call in
            milliseconds.
        """
        obs, _, _ = self.eval_env.reset()
        # FreightFranka tasks are handled as a per-agent list of observations;
        # every other task as one array indexed [thread, agent, ...].
        if "FreightFranka" in self.args.task:
            obs = [torch.as_tensor(o, dtype=torch.float, device=self.args.device) for o in obs]
        else:
            obs = torch.as_tensor(obs, dtype=torch.float, device=self.args.device)
        if self.args.centralized_training:
            # Merge all agents' observations into a single pseudo-agent.
            if "FreightFranka" in self.args.task:
                obs = torch.cat(obs, dim=1).unsqueeze(1)
            else:
                obs = obs.reshape(self.rollout_threads, 1, -1)

        # Pre-allocated per-agent history buffers for the whole episode.
        states, actions, returns, costs, time_steps = [], [], [], [], []
        for train_agent in range(self.train_agent_num):
            state = torch.zeros(self.rollout_threads, self.episode_len + 1, self.state_dim[train_agent], dtype=torch.float, device=self.args.device)
            if "FreightFranka" in self.args.task and not self.args.centralized_training:
                state[:, 0, :] = obs[train_agent]
            else:
                state[:, 0, :] = obs[:, train_agent]
            states.append(state)
            ret = torch.zeros(self.rollout_threads, self.episode_len + 1, dtype=torch.float, device=self.args.device)
            ret[:, 0] = target_return[train_agent]
            returns.append(ret)
            cost = torch.zeros(self.rollout_threads, self.episode_len + 1, dtype=torch.float, device=self.args.device)
            cost[:, 0] = target_cost[train_agent]
            costs.append(cost)
            actions.append(torch.zeros(self.rollout_threads, self.episode_len, self.action_dim[train_agent], dtype=torch.float, device=self.args.device))
            time_steps.append(torch.arange(self.episode_len, dtype=torch.long, device=self.args.device).repeat(self.rollout_threads, 1))

        episode_ret = torch.zeros([self.rollout_threads, self.train_agent_num], dtype=torch.float, device=self.args.device)
        episode_cost = torch.zeros_like(episode_ret)
        epi_len = torch.zeros(self.rollout_threads, dtype=torch.float, device=self.args.device)
        # Threads stop contributing to the statistics once they have terminated.
        active_mask = torch.ones(self.rollout_threads, dtype=torch.bool, device=self.args.device)
        ts = []  # per-model-call latencies in milliseconds

        for step in range(self.episode_len):
            acts = []
            for train_agent in range(self.train_agent_num):
                # Take the history up to and including `step` (slice end is
                # exclusive, hence step + 1), then keep the last seq_len
                # entries. The action at `step` is still the zero placeholder;
                # the model predicts it, so its value does not matter.
                s = states[train_agent][:, :step + 1][:, -self.args.seq_len:] # noqa
                a = actions[train_agent][:, :step + 1][:, -self.args.seq_len:] # noqa
                r = returns[train_agent][:, :step + 1][:, -self.args.seq_len:] # noqa
                c = costs[train_agent][:, :step + 1][:, -self.args.seq_len:] # noqa
                t = time_steps[train_agent][:, :step + 1][:, -self.args.seq_len:] # noqa

                start = time.perf_counter()
                act, _, _ = self.trainer[train_agent].model(s, a, r, c, t, None)
                end = time.perf_counter()
                ts.append((end - start) * 1000)

                if self.args.stochastic:
                    # Stochastic policy: act with the distribution mean.
                    act = act.mean
                act = act.clamp(self.env_min_act[train_agent], self.env_max_act[train_agent])
                # Only the prediction for the latest position is executed.
                acts.append(act[:, -1].to(self.args.cfg_eval["device"]))

            if self.args.centralized_training:
                # Split the joint action back into per-agent chunks for the env.
                acts = acts[0]
                acts = list(torch.split(acts, self.act_dim_agent, dim=1))
            if self.args.task == "Safety98HumanoidVelocity-v0":
                # The last agent's action is padded with one zero column for
                # this task; the padding is dropped again before it is stored.
                padding = torch.zeros(acts[-1].shape[0], 1, device=acts[-1].device, dtype=acts[-1].dtype)
                acts[-1] = torch.cat((acts[-1], padding), dim=1)
            obs_next, _, reward, cost, terminated, _, _ = self.eval_env.step(acts)
            if "FreightFranka" in self.args.task:
                obs_next = [torch.as_tensor(o, device=self.args.device, dtype=torch.float) for o in obs_next]
            else:
                obs_next = torch.as_tensor(obs_next, device=self.args.device, dtype=torch.float)
            reward = torch.as_tensor(reward[:, :, 0], device=self.args.device, dtype=torch.float) * self.args.reward_scale
            cost = torch.as_tensor(cost[:, :, 0], device=self.args.device, dtype=torch.float) * self.args.cost_scale
            terminated = torch.as_tensor(terminated, device=self.args.device)
            # A thread counts as finished only when every agent has terminated.
            terminated = torch.all(terminated, dim=1)
            if "FreightFranka" in self.args.task:
                acts = [act.to(self.args.device) for act in acts]
            else:
                acts = torch.stack(acts, dim=1).to(self.args.device)
            if self.args.centralized_training:
                if "FreightFranka" in self.args.task:
                    acts = torch.cat(acts, dim=1).unsqueeze(1)
                    obs_next = torch.cat(obs_next, dim=1).unsqueeze(1)
                else:
                    acts = acts.reshape(self.rollout_threads, 1, -1)
                    obs_next = obs_next.reshape(self.rollout_threads, 1, -1)
                # The single centralized model sees the agent-averaged signals.
                reward = reward.mean(dim=1, keepdim=True)
                cost = cost.mean(dim=1, keepdim=True)

            # at step t, we predict a_t, get s_{t + 1}, r_{t + 1}
            for train_agent in range(self.train_agent_num):
                if self.args.task == "Safety98HumanoidVelocity-v0" and train_agent == self.train_agent_num - 1:
                    actions[train_agent][:, step] = acts[:, train_agent, :-1]
                elif "FreightFranka" in self.args.task and not self.args.centralized_training:
                    actions[train_agent][:, step] = acts[train_agent]
                else:
                    actions[train_agent][:, step] = acts[:, train_agent]
                if "FreightFranka" in self.args.task and not self.args.centralized_training:
                    states[train_agent][:, step + 1] = obs_next[train_agent]
                else:
                    states[train_agent][:, step + 1] = obs_next[:, train_agent]
                # Return-/cost-to-go shrink by what was just received.
                returns[train_agent][:, step + 1] = returns[train_agent][:, step] - reward[:, train_agent]
                costs[train_agent][:, step + 1] = costs[train_agent][:, step] - cost[:, train_agent]

            episode_ret[active_mask, :] += reward[active_mask, :]
            episode_cost[active_mask, :] += cost[active_mask, :]
            epi_len[active_mask] += 1
            active_mask = torch.logical_and(active_mask, ~terminated)

            if torch.all(~active_mask).item():
                break

        episode_ret = episode_ret.mean().item()
        episode_cost = episode_cost.mean().item()
        epi_len = epi_len.mean().item()
        meant = np.mean(ts)

        return episode_ret, epi_len, episode_cost, meant

    def run(self):
        """Train every model for ``args.update_steps`` steps with periodic evaluation.

        For each training agent a ``SequenceDataset`` / ``DataLoader`` is built
        from the loaded data. Every ``args.eval_every`` steps (and on the last
        step) each target (return, cost) pair is evaluated, the results are
        written to the agents' TensorBoard loggers, and, if ``args.save_model``
        is set, the weights are saved whenever a new best return under
        ``args.cost_limit`` is reached.
        """

        def cost_transform(x):
            # Loop-invariant, so it is defined once instead of once per agent.
            return self.args.ct_max - x if self.args.linear else 1 / (x + 10)

        trainloader_iters = []
        target_returns = []
        for train_agent in range(self.train_agent_num):
            ret_span, target_return, bin_size = analyze_data(self.data[train_agent], self.args.cost_limit)
            target_returns.append(target_return)
            max_reward = ret_span * 0.1
            min_reward = ret_span * 0.001
            max_rew_decrease = max_reward

            data = SequenceDataset(
                self.data[train_agent],
                seq_len=self.args.seq_len,
                reward_scale=self.args.reward_scale,
                cost_scale=self.args.cost_scale,
                deg=self.args.deg,
                pf_sample=self.args.pf_sample,
                max_rew_decrease=max_rew_decrease,
                beta=self.args.beta,
                augment_percent=self.args.augment_percent,
                cost_reverse=self.args.cost_reverse,
                max_reward=max_reward,
                min_reward=min_reward,
                pf_only=self.args.pf_only,
                rmin=self.args.rmin,
                cost_bins=bin_size,
                npb=self.args.npb,
                cost_sample=self.args.cost_sample,
                cost_transform=cost_transform,
                start_sampling=self.args.start_sampling,
                prob=self.args.prob,
                random_aug=self.args.random_aug,
                aug_rmin=self.args.aug_rmin,
                aug_rmax=self.args.aug_rmax,
                aug_cmin=self.args.aug_cmin,
                aug_cmax=self.args.aug_cmax,
                cgap=self.args.cgap,
                rstd=self.args.rstd,
                cstd=self.args.cstd,
                task=self.args.task,
                max_npb=self.args.max_npb,
                min_npb=self.args.min_npb,
            )

            trainloader = DataLoader(
                data,
                batch_size=self.args.batch_size,
                num_workers=self.args.num_workers,
                pin_memory=False,
            )
            trainloader_iter = iter(trainloader)
            trainloader_iters.append(trainloader_iter)

        if not self.args.centralized_training:
            # Move the agent axis last so that each entry of target_returns is
            # one (reward_targets, cost_targets) pair with per-agent values.
            # NOTE(review): np.squeeze also drops the agent axis when it has
            # length 1 (or any other length-1 axis), which would make the
            # 3-axis transpose fail — confirm the shape analyze_data returns.
            target_returns = np.squeeze(np.array(target_returns)).transpose(1, 2, 0).tolist()
        else:
            target_returns = target_returns[0]
        print(f'target_returns: {target_returns}')
        # Best results under the cost limit, one slot per target pair.
        best_safe_return = [-np.inf for _ in range(len(target_returns))]
        best_safe_cost = [np.inf for _ in range(len(target_returns))]
        best_safe_step = [-1 for _ in range(len(target_returns))]
        # Lowest-cost results regardless of the limit (tracked, not logged).
        best_return = [-np.inf for _ in range(len(target_returns))]
        best_cost = [np.inf for _ in range(len(target_returns))]
        best_step = [-1 for _ in range(len(target_returns))]
        meants, num_eval, train_time = 0, 0, 0

        for step in trange(self.args.update_steps):
            start = time.perf_counter()
            for train_agent in range(self.train_agent_num):
                batch = next(trainloader_iters[train_agent])
                states, actions, returns, cost_returns, time_steps, mask, costs = [b.to(self.args.device) for b in batch]
                returns *= self.args.reward_scale
                cost_returns *= self.args.cost_scale
                self.trainer[train_agent].train_one_step(states, actions, returns, cost_returns, time_steps, mask, costs, step)
            end = time.perf_counter()
            train_time += (end - start) * 1000

            # Eval
            if (step + 1) % self.args.eval_every == 0 or (step + 1) == self.args.update_steps:
                for target_i, target_return in enumerate(target_returns):
                    reward_return, cost_return = target_return
                    ret, cost, length, meant = self.evaluate(
                        [rr * self.args.reward_scale for rr in reward_return],
                        [cr * self.args.cost_scale for cr in cost_return],
                    )
                    meants += meant
                    num_eval += 1
                    if cost <= self.args.cost_limit and ret > best_safe_return[target_i]:
                        best_safe_return[target_i] = ret
                        best_safe_cost[target_i] = cost
                        best_safe_step[target_i] = step
                        if self.args.save_model:
                            # NOTE(review): every target pair writes to the same
                            # per-agent path, so the file holds the weights of
                            # whichever target improved most recently.
                            for train_agent in range(self.train_agent_num):
                                torch.save(self.trainer[train_agent].model.state_dict(), self.args.save_model_path[train_agent])
                    if cost < best_cost[target_i] or (cost == best_cost[target_i] and ret > best_return[target_i]):
                        best_return[target_i] = ret
                        best_cost[target_i] = cost
                        best_step[target_i] = step

                    for train_agent in range(self.train_agent_num):
                        name = f"c_{cost_return[train_agent]:.2f}_r_{reward_return[train_agent]:.2f}"
                        self.trainer[train_agent].logger.add_scalar(f'eval_{name}/return', ret, step)
                        self.trainer[train_agent].logger.add_scalar(f'eval_{name}/cost_return', cost, step)
                        self.trainer[train_agent].logger.add_scalar(f'eval_{name}/length', length, step)
                        self.trainer[train_agent].logger.add_scalar(f'eval_{name}/best_safe_return', best_safe_return[target_i], step)
                        self.trainer[train_agent].logger.add_scalar(f'eval_{name}/best_safe_cost_return', best_safe_cost[target_i], step)
                        self.trainer[train_agent].logger.add_scalar(f'eval_{name}/best_safe_step', best_safe_step[target_i], step)

        # Skip the averages that would divide by zero (no training steps, or
        # no evaluation ever ran) instead of crashing before loggers close.
        if self.args.update_steps > 0:
            print(f"Train time: {train_time / self.args.update_steps} ms")
        if num_eval > 0 and self.train_agent_num > 0:
            print(f"Exe time: {meants / num_eval / self.train_agent_num} ms")
        for train_agent in range(self.train_agent_num):
            self.trainer[train_agent].logger.close()
                 
@pyrallis.wrap()
def main(args: CDTTrainConfig):
    """Resolve the configuration, prepare log/data paths and launch training.

    The task-specific default config is taken from ``CDT_DEFAULT_CONFIG`` and
    overridden by every field of ``args`` that differs from a freshly
    constructed ``CDTTrainConfig``. Log and model directories are created
    under ``<log_root_dir>/<task>/<prefix>/<name>/agent_<i>``, the evaluation
    environment settings are fetched with ``multi_agent_args``, and a
    ``Runner`` is built and run.

    Args:
        args: Parsed training configuration supplied by ``pyrallis.wrap``.
    """
    # update config: start from the task defaults, then re-apply only the
    # values the user changed relative to the generic CDTTrainConfig defaults
    cfg, old_cfg = asdict(args), asdict(CDTTrainConfig())
    differing_values = {key: value for key, value in cfg.items() if value != old_cfg[key]}
    cfg = asdict(CDT_DEFAULT_CONFIG[args.task]())
    cfg.update(differing_values)
    args = types.SimpleNamespace(**cfg)
    print('-' * 20)
    print(f'Task: {args.task}')

    # A density above 1 is replaced by its reciprocal.
    if args.density > 1:
        args.density = 1 / args.density
    if args.group is None:
        args.group = args.task + "-cost" + str(int(args.cost_limit))
    args.data_dir = f"../MOSDB/{args.task}"
    # One agent per "agent_*" entry in the data directory, ignoring "_sg" ones.
    agent_num = sum("agent_" in entry and "_sg" not in entry for entry in os.listdir(args.data_dir))
    if args.centralized_training:
        train_agent_num = 1
    else:
        train_agent_num = agent_num

    args.logdir = os.path.join(args.log_root_dir, args.task, args.prefix)
    os.makedirs(args.logdir, exist_ok=True)
    # Default experiment name is numbered by the sub-directories already present.
    exp_num = len([item for item in Path(args.logdir).iterdir() if item.is_dir()])
    if args.name is None:
        args.name = f"experiment{exp_num}"
    args.logdir = os.path.join(args.logdir, args.name)

    args.logdir_agent, args.agent_data_file, args.save_model_path = [], [], []
    for train_agent in range(train_agent_num):
        logdir_agent = os.path.join(args.logdir, "agent_{}".format(train_agent))
        save_model_dir = os.path.join(logdir_agent, "model")
        os.makedirs(save_model_dir, exist_ok=True)
        save_model_path = os.path.join(save_model_dir, "best_safe_model.pth")
        args.save_model_path.append(save_model_path)
        args.logdir_agent.append(logdir_agent)
    # Data files are listed for every agent, also in centralized training.
    for agent in range(agent_num):
        args.agent_data_file.append(os.path.join(args.data_dir, f'agent_{agent}.h5'))

    args_env, cfg_env, cfg_eval = multi_agent_args(algo="mappolag")
    # These tasks get a fixed number of parallel environments.
    fixed_num_envs_tasks = (
        "ShadowHandOver_Safe_joint",
        "ShadowHandOver_Safe_finger",
        "ShadowHandCatchOver2Underarm_Safe_joint",
        "ShadowHandCatchOver2Underarm_Safe_finger",
        "FreightFrankaCloseDrawer",
        "FreightFrankaPickAndPlace",
    )
    if args.task in fixed_num_envs_tasks:
        cfg_env["env"]["numEnvs"] = 10
    # The evaluation environment uses a fixed seed, independent of args.seed.
    args_env.seed = 10000
    cfg_eval["seed"] = args_env.seed
    cfg_eval["n_rollout_threads"] = cfg_eval["n_eval_rollout_threads"]
    args_env.cost_limit = args.cost_limit
    cfg_eval["cost_limit"] = args.cost_limit
    args.args_env, args.cfg_eval, args.cfg_env = args_env, cfg_eval, cfg_env

    # set seed
    seed_all(args.seed)
    if args.device == "cpu":
        torch.set_num_threads(args.threads)

    runner = Runner(args, agent_num, train_agent_num)
    runner.run()

# Script entry point. main() is called without arguments; its `args`
# parameter is supplied by the @pyrallis.wrap() decorator.
if __name__ == "__main__":
    main()