import os
from pathlib import Path
import random
import time
import types
from dataclasses import asdict
from typing import Tuple, List
import numpy as np
import pyrallis
try:
    from isaacgym import gymapi, gymutil
except:
    pass
import torch
from torch.utils.data import DataLoader
from tqdm.auto import trange  # noqa
from safepo.common.env import make_ma_multi_goal_env, make_ma_mujoco_env, make_ma_isaac_env
from examples.configs.mosdt_configs import MOSDT_DEFAULT_CONFIG, MOSDTTrainConfig
from core.common.exp_util import auto_name, seed_all
from core.common.ma_dataset import *
from safepo.utils.config import multi_agent_args, parse_sim_params
from core.algorithms import MOSDT, MOSDTTrainerDev
from core.common import SequenceDataset
from torch.utils.tensorboard import SummaryWriter
import torch.nn.functional as F


class Runner:
    """Train a multi-agent MOSDT model offline and periodically evaluate it online.

    The runner loads the per-agent offline dataset, builds an evaluation
    environment matching ``args.task``, constructs the :class:`MOSDT` model and
    its :class:`MOSDTTrainerDev`, and logs evaluation metrics to TensorBoard.

    Args:
        args: Namespace holding the merged training config plus ``args_env``,
            ``cfg_env``, ``cfg_eval``, ``agent_data_file``, ``logdir`` and
            ``save_model_path`` (see ``main``).
        agent_num: Number of agents in the dataset / environment.
    """

    # Tasks whose number of parallel evaluation environments is read from
    # ``cfg_env["env"]["numEnvs"]`` instead of ``cfg_eval["n_rollout_threads"]``.
    _NUM_ENVS_TASKS = (
        "ShadowHandOver_Safe_joint",
        "ShadowHandOver_Safe_finger",
        "ShadowHandCatchOver2Underarm_Safe_joint",
        "ShadowHandCatchOver2Underarm_Safe_finger",
        "FreightFrankaCloseDrawer",
        "FreightFrankaPickAndPlace",
    )

    def __init__(self, args, agent_num):
        self.args = args
        self.agent_num = agent_num
        device = self.args.device

        self.eval_env = self._make_eval_env()

        self.episode_len = self.args.cfg_eval['episode_length']
        # The context window can never be longer than an episode.
        self.args.seq_len = min(self.args.seq_len, self.episode_len)
        if self.args.task in self._NUM_ENVS_TASKS:
            self.rollout_threads = self.args.cfg_env["env"]["numEnvs"]
        else:
            self.rollout_threads = self.args.cfg_eval["n_rollout_threads"]

        self.data, self.state_dim, self.action_dim = get_data_CTDE(
            self.args.agent_data_file, same_r=self.args.same_r, same_c=self.args.same_c
        )

        env_max_act = [self.eval_env.action_space[agent].high.astype(np.float32) for agent in range(self.agent_num)]
        env_min_act = [self.eval_env.action_space[agent].low.astype(np.float32) for agent in range(self.agent_num)]

        # Per-agent symmetric action magnitude max(|high|, |low|), zero-padded up to
        # the largest action dimension across agents.
        self.max_act = []
        for high, low in zip(env_max_act, env_min_act):
            ma = torch.as_tensor(np.maximum(np.abs(high), np.abs(low)), device=device)
            ma = F.pad(ma, (0, max(self.action_dim) - len(ma)), "constant", 0)
            self.max_act.append(ma)
        # Exact env bounds, used to clamp predicted actions in ``rollout``.
        self.env_max_act = [torch.as_tensor(a, device=device) for a in env_max_act]
        self.env_min_act = [torch.as_tensor(a, device=device) for a in env_min_act]

        ret_span, self.target_returns, self.bin_size, self.ret_mean, self.ret_std = analyze_data_CTDE(
            self.data, self.agent_num, self.args.cost_limit
        )
        print(f'target_returns: {self.target_returns}')
        # Cost transform handed to SequenceDataset: linear ``ct_max - x`` or reciprocal ``1 / (x + 10)``.
        self.ct = lambda x: (self.args.ct_max - x) if self.args.linear else 1 / (x + 10)
        self.max_reward = ret_span * 0.1
        self.min_reward = ret_span * 0.001
        self.max_rew_decrease = self.max_reward

        self.model = MOSDT(
            state_dim=self.state_dim,
            action_dim=self.action_dim,
            max_action=self.max_act,
            embedding_dim=self.args.embedding_dim,
            seq_len=self.args.seq_len,
            episode_len=self.episode_len,
            num_layers=self.args.num_layers,
            total_layers=self.args.total_layers,
            num_heads=self.args.num_heads,
            num_teacher_heads=self.args.num_teacher_heads,
            attention_dropout=self.args.attention_dropout,
            residual_dropout=self.args.residual_dropout,
            embedding_dropout=self.args.embedding_dropout,
            time_emb=self.args.time_emb,
            cost_transform=self.args.cost_transform,
            add_cost_feat=self.args.add_cost_feat,
            mul_cost_feat=self.args.mul_cost_feat,
            cat_cost_feat=self.args.cat_cost_feat,
            action_head_layers=self.args.action_head_layers,
            stochastic=self.args.stochastic,
            init_temperature=self.args.init_temperature,
            target_entropy=[-ad for ad in self.action_dim],
            fix_att=self.args.fix_att,
            decision_heads=self.args.decision_heads,
            cost_classify=self.args.cost_classify,
            ps=self.args.ps,
            be=self.args.be,
        ).to(device)

        print(f"Total parameters: {sum(p.numel() for p in self.model.parameters()) / 1000000} M")
        # Size of the sub-modules listed below, reported as the execution-time model size.
        exe_modules = (
            self.model.emb_norm,
            self.model.out_norm,
            self.model.timestep_emb,
            self.model.state_ember,
            self.model.action_ember,
            self.model.cost_ember,
            self.model.return_ember,
            self.model.blocks,
            self.model.feat_porj,
            self.model.action_head,
        )
        exe_param_num = sum(p.numel() for module in exe_modules for p in module.parameters())
        print(f"Exe parameters: {exe_param_num / 1000000} M")

        self.logger = SummaryWriter(log_dir=os.path.join(self.args.logdir, 'tb'))

        self.trainer = MOSDTTrainerDev(
            self.model,
            self.eval_env,
            logger=self.logger,
            learning_rate=self.args.learning_rate,
            weight_decay=self.args.weight_decay,
            betas=self.args.betas,
            clip_grad=self.args.clip_grad,
            lr_warmup_steps=self.args.lr_warmup_steps,
            reward_scale=self.args.reward_scale,
            cost_scale=self.args.cost_scale,
            cost_reverse=self.args.cost_reverse,
            no_entropy=self.args.no_entropy,
            device=device,
            sd_weight=self.args.sd_weight,
            feat_sd_weight=self.args.feat_sd_weight,
            sd=self.args.sd,
        )

    def _make_eval_env(self):
        """Create the evaluation environment that matches ``args.task``."""
        args = self.args
        if "MultiGoal" in args.task:
            # Actions are moved to ``cfg_eval["device"]`` before env.step, so force CPU here.
            args.cfg_eval["device"] = "cpu"
            return make_ma_multi_goal_env(args.args_env.task, args.cfg_eval['seed'], args.cfg_eval)
        if "Velocity" in args.task:
            args.cfg_eval["device"] = "cpu"
            return make_ma_mujoco_env(
                args.args_env.scenario, args.args_env.agent_conf, args.cfg_eval['seed'], args.cfg_eval
            )
        # Every other task goes through the Isaac environment factory.
        # NOTE(review): hard-coded agent index layout — confirm it fits every Isaac task used.
        agent_index = [[[0, 1, 2, 3, 4, 5]], [[0, 1, 2, 3, 4, 5]]]
        sim_params = parse_sim_params(args.args_env, args.cfg_env, args.cfg_eval)
        return make_ma_isaac_env(args.args_env, args.cfg_env, args.cfg_eval, sim_params, agent_index)

    @staticmethod
    def _as_scalar(value):
        """Return ``value`` as a Python number.

        Accepts Python numbers as well as anything exposing ``.item()`` (NumPy
        scalars, size-1 arrays, 0-d / size-1 tensors).
        """
        return value.item() if hasattr(value, "item") else float(value)

    @staticmethod
    def _binarize_cost(cost, threshold):
        """Map costs to 0 (``<= threshold``) or 1 (``> threshold``), keeping the dtype.

        A single comparison is used so the result does not depend on the sign of
        ``threshold``. Two sequential in-place masks would turn the zeros written by
        the first mask into ones whenever ``threshold < 0``.
        """
        return (cost > threshold).to(cost.dtype)

    @torch.no_grad()
    def evaluate(self, target_return, target_cost):
        """Average ``args.eval_episodes`` rollouts for one target.

        Args:
            target_return: Per-agent initial return-to-go, in scaled units.
            target_cost: Per-agent initial cost-to-go, in scaled units.

        Returns:
            ``(mean_return, mean_cost, mean_length, mean_forward_ms)``. Return and cost
            are divided by ``reward_scale`` / ``cost_scale`` to undo the scaling.
        """
        self.model.eval()
        episode_rets, episode_costs, episode_lens, meants = [], [], [], []
        try:
            for _ in range(self.args.eval_episodes):
                epi_ret, epi_len, epi_cost, meant = self.rollout(target_return, target_cost)
                episode_rets.append(epi_ret)
                episode_lens.append(epi_len)
                episode_costs.append(epi_cost)
                meants.append(meant)
        finally:
            # Restore training mode even if a rollout raises.
            self.model.train()
        return (
            np.mean(episode_rets) / self.args.reward_scale,
            np.mean(episode_costs) / self.args.cost_scale,
            np.mean(episode_lens),
            np.mean(meants),
        )

    @torch.no_grad()
    def rollout(
        self,
        target_return: List[float],
        target_cost: List[float],
    ) -> Tuple[float, float, float, float]:
        """Run one batched evaluation episode conditioned on per-agent targets.

        Args:
            target_return: Per-agent initial return-to-go (already multiplied by
                ``reward_scale`` by the caller in ``run``).
            target_cost: Per-agent initial cost-to-go (already multiplied by
                ``cost_scale``).

        Returns:
            ``(episode_ret, epi_len, episode_cost, meant)``:
                - episode_ret / episode_cost: per-agent sums over active steps,
                  averaged over threads and agents, still in scaled units.
                - epi_len: mean episode length over threads.
                - meant: mean wall-clock time of one model forward pass, in ms.
        """
        device = self.args.device
        is_franka = "FreightFranka" in self.args.task
        # For this task the last agent's env action has one more dimension than the
        # model predicts, so a zero column is appended before env.step.
        pad_last_agent = self.args.task == "Safety98HumanoidVelocity-v0"
        n_threads = self.rollout_threads
        seq_len = self.args.seq_len
        cost_threshold = self.args.cost_limit * self.args.cost_scale

        obs, _, _ = self.eval_env.reset()
        if is_franka:
            # Observations arrive as one array per agent.
            obs = [torch.as_tensor(o, dtype=torch.float, device=device) for o in obs]
        else:
            obs = torch.as_tensor(obs, dtype=torch.float, device=device)

        # Per-agent history buffers (states/returns/costs have one extra slot for the
        # value after the final step).
        states, actions, returns, costs = [], [], [], []
        for agent in range(self.agent_num):
            state = torch.zeros(n_threads, self.episode_len + 1, self.state_dim[agent], dtype=torch.float, device=device)
            state[:, 0, :] = obs[agent] if is_franka else obs[:, agent]
            states.append(state)
            ret = torch.zeros(n_threads, self.episode_len + 1, 1, dtype=torch.float, device=device)
            ret[:, 0] = self._as_scalar(target_return[agent])
            returns.append(ret)
            cost = torch.zeros(n_threads, self.episode_len + 1, 1, dtype=torch.float, device=device)
            cost[:, 0] = self._as_scalar(target_cost[agent])
            costs.append(cost)
            actions.append(torch.zeros(n_threads, self.episode_len, self.action_dim[agent], dtype=torch.float, device=device))

        time_steps = torch.arange(self.episode_len, dtype=torch.long, device=device).repeat(n_threads, 1)
        episode_ret = torch.zeros([n_threads, self.agent_num], dtype=torch.float, device=device)
        episode_cost = torch.zeros_like(episode_ret)
        epi_len = torch.zeros(n_threads, dtype=torch.float, device=device)
        active_mask = torch.ones(n_threads, dtype=torch.bool, device=device)
        forward_ms = []

        for step in range(self.episode_len):
            # Last ``seq_len`` steps of context, concatenated across agents on the last axis.
            s = torch.cat([buf[:, :step + 1][:, -seq_len:] for buf in states], dim=-1)
            a = torch.cat([buf[:, :step + 1][:, -seq_len:] for buf in actions], dim=-1)
            r = torch.cat([buf[:, :step + 1][:, -seq_len:] for buf in returns], dim=-1)
            c = torch.cat([buf[:, :step + 1][:, -seq_len:] for buf in costs], dim=-1)
            t = time_steps[:, :step + 1][:, -seq_len:]

            if self.args.cost_classify:
                # NOTE(review): eval feeds 0/1 here, while ``run`` binarizes the cost
                # returns and then multiplies them by ``cost_scale`` (0 / cost_scale).
                # The two encodings only agree when cost_scale == 1 — confirm intent.
                c = self._binarize_cost(c, cost_threshold)

            start = time.perf_counter()
            all_act, _, _, _, _, _ = self.model(s, a, r, c, t, None)
            forward_ms.append((time.perf_counter() - start) * 1000)

            acts = []
            for agent in range(self.agent_num):
                act = all_act[agent].mean if self.args.stochastic else all_act[agent]
                # Keep this agent's own action dims, clip to env bounds, take the latest step.
                act = act[:, :, :self.action_dim[agent]].clamp(self.env_min_act[agent], self.env_max_act[agent])
                acts.append(act[:, -1].to(self.args.cfg_eval["device"]))
            if pad_last_agent:
                padding = torch.zeros(acts[-1].shape[0], 1, device=acts[-1].device, dtype=acts[-1].dtype)
                acts[-1] = torch.cat((acts[-1], padding), dim=1)

            obs_next, _, reward, cost, terminated, _, _ = self.eval_env.step(acts)
            if is_franka:
                obs_next = [torch.as_tensor(o, device=device, dtype=torch.float) for o in obs_next]
            else:
                obs_next = torch.as_tensor(obs_next, device=device, dtype=torch.float)
            reward = torch.as_tensor(reward, device=device, dtype=torch.float) * self.args.reward_scale
            cost = torch.as_tensor(cost, device=device, dtype=torch.float) * self.args.cost_scale

            if self.args.same_r:
                # Give every agent the agent-averaged reward.
                reward = torch.mean(reward, dim=1, keepdim=True).repeat(1, self.agent_num, 1)
            if self.args.same_c:
                cost = torch.mean(cost, dim=1, keepdim=True).repeat(1, self.agent_num, 1)

            # A thread counts as terminated only when all agents report termination.
            terminated = torch.all(torch.as_tensor(terminated, device=device), dim=1)
            if is_franka:
                acts = [act.to(device) for act in acts]
            else:
                acts = torch.stack(acts, dim=1).to(device)

            for agent in range(self.agent_num):
                if pad_last_agent and agent == self.agent_num - 1:
                    # Strip the padding column appended before env.step.
                    actions[agent][:, step] = acts[:, agent, :-1]
                elif is_franka:
                    actions[agent][:, step] = acts[agent]
                else:
                    actions[agent][:, step] = acts[:, agent]
                states[agent][:, step + 1] = obs_next[agent] if is_franka else obs_next[:, agent]
                # Return-/cost-to-go shrink by what was just collected.
                returns[agent][:, step + 1] = returns[agent][:, step] - reward[:, agent]
                costs[agent][:, step + 1] = costs[agent][:, step] - cost[:, agent]

            # Only threads that were still running before this step accumulate statistics.
            episode_ret[active_mask, :] += reward[active_mask, :, 0]
            episode_cost[active_mask, :] += cost[active_mask, :, 0]
            epi_len[active_mask] += 1
            active_mask = torch.logical_and(active_mask, ~terminated)

            if torch.all(~active_mask).item():
                break

        return (
            episode_ret.mean().item(),
            epi_len.mean().item(),
            episode_cost.mean().item(),
            np.mean(forward_ms),
        )

    def run(self):
        """Train for ``args.update_steps`` steps and evaluate every ``args.eval_every`` steps.

        For each target in ``self.target_returns``, the best safe return (cost <=
        ``args.cost_limit``) is tracked and logged. When ``args.save_model`` is set,
        the model is written to ``args.save_model_path`` whenever a target reaches a
        new best safe return.
        """
        n_targets = len(self.target_returns)
        # Best results per target (the "safe" variants require cost <= cost_limit).
        best_safe_return = [-np.inf] * n_targets
        best_safe_cost = [np.inf] * n_targets
        best_safe_step = [-1] * n_targets
        best_return = [-np.inf] * n_targets
        best_cost = [np.inf] * n_targets
        best_step = [-1] * n_targets

        self.data = SequenceDataset(
            self.data,
            seq_len=self.args.seq_len,
            reward_scale=self.args.reward_scale,
            cost_scale=self.args.cost_scale,
            deg=self.args.deg,
            pf_sample=self.args.pf_sample,
            max_rew_decrease=self.max_rew_decrease,
            beta=self.args.beta,
            augment_percent=self.args.augment_percent,
            cost_reverse=self.args.cost_reverse,
            max_reward=self.max_reward,
            min_reward=self.min_reward,
            pf_only=self.args.pf_only,
            rmin=self.args.rmin,
            cost_bins=self.bin_size,
            npb=self.args.npb,
            cost_sample=self.args.cost_sample,
            cost_transform=self.ct,
            start_sampling=self.args.start_sampling,
            prob=self.args.prob,
            random_aug=self.args.random_aug,
            aug_rmin=self.args.aug_rmin,
            aug_rmax=self.args.aug_rmax,
            aug_cmin=self.args.aug_cmin,
            aug_cmax=self.args.aug_cmax,
            cgap=self.args.cgap,
            rstd=self.args.rstd,
            cstd=self.args.cstd,
            task=self.args.task,
            max_npb=self.args.max_npb,
            min_npb=self.args.min_npb,
        )

        trainloader = DataLoader(
            self.data,
            batch_size=self.args.batch_size,
            num_workers=self.args.num_workers,
            pin_memory=False,
        )
        trainloader_iter = iter(trainloader)
        meants, num_eval, train_time = 0, 0, 0

        for step in trange(self.args.update_steps):
            start = time.perf_counter()
            try:
                batch = next(trainloader_iter)
            except StopIteration:
                # The loader was exhausted (finite dataset): start a new pass.
                trainloader_iter = iter(trainloader)
                batch = next(trainloader_iter)
            states, actions, returns, cost_returns, time_steps, mask, _ = [b.to(self.args.device) for b in batch]

            if self.args.cost_classify:
                cost_returns = self._binarize_cost(cost_returns, self.args.cost_limit)
            returns *= self.args.reward_scale
            cost_returns *= self.args.cost_scale

            self.trainer.train_one_step(states, actions, returns, cost_returns, time_steps, mask, step)
            train_time += (time.perf_counter() - start) * 1000

            # Eval
            if (step + 1) % self.args.eval_every == 0 or (step + 1) == self.args.update_steps:
                for target_i, (reward_return, cost_return) in enumerate(self.target_returns):
                    ret, cost, length, meant = self.evaluate(
                        [rr * self.args.reward_scale for rr in reward_return],
                        [cr * self.args.cost_scale for cr in cost_return],
                    )
                    if cost <= self.args.cost_limit and ret > best_safe_return[target_i]:
                        best_safe_return[target_i] = ret
                        best_safe_cost[target_i] = cost
                        best_safe_step[target_i] = step
                        if self.args.save_model:
                            # NOTE(review): all targets share one checkpoint path, so the file
                            # holds whichever target improved most recently — confirm intent.
                            torch.save(self.model.state_dict(), self.args.save_model_path)
                    if cost < best_cost[target_i] or (cost == best_cost[target_i] and ret > best_return[target_i]):
                        best_return[target_i] = ret
                        best_cost[target_i] = cost
                        best_step[target_i] = step

                    name = f"meanc_{sum(cost_return)/len(cost_return):.2f}_meanr_{sum(reward_return)/len(reward_return):.2f}"
                    self.logger.add_scalar(f'eval_{name}/return', ret, step)
                    self.logger.add_scalar(f'eval_{name}/cost_return', cost, step)
                    self.logger.add_scalar(f'eval_{name}/length', length, step)
                    self.logger.add_scalar(f'eval_{name}/best_safe_return', best_safe_return[target_i], step)
                    self.logger.add_scalar(f'eval_{name}/best_safe_cost_return', best_safe_cost[target_i], step)
                    self.logger.add_scalar(f'eval_{name}/best_safe_step', best_safe_step[target_i], step)
                    meants += meant
                    num_eval += 1

        # max(..., 1) guards against zero update steps / zero evaluations.
        print(f"Train time: {train_time / max(self.args.update_steps, 1)} ms")
        print(f"Exe time: {meants / max(num_eval, 1) / self.agent_num} ms")
        self.logger.close()
                 
@pyrallis.wrap()
def main(args: MOSDTTrainConfig):
    """Merge the config, prepare paths and env configs, then train with :class:`Runner`.

    Only the fields whose value differs from a fresh ``MOSDTTrainConfig()`` (i.e. the
    ones changed on the command line) override ``MOSDT_DEFAULT_CONFIG[args.task]``.

    Raises:
        FileNotFoundError: if ``../MOSDB/<task>`` contains no ``agent_*`` data files.
    """
    # update config: per-task defaults, overridden by explicitly changed values
    cfg, old_cfg = asdict(args), asdict(MOSDTTrainConfig())
    differing_values = {key: value for key, value in cfg.items() if value != old_cfg[key]}
    cfg = asdict(MOSDT_DEFAULT_CONFIG[args.task]())
    cfg.update(differing_values)
    args = types.SimpleNamespace(**cfg)
    print('-' * 20)
    print(f'Task: {args.task}')

    # A density larger than 1 is replaced by its reciprocal.
    if args.density > 1:
        args.density = 1 / args.density
    if args.group is None:
        args.group = args.task + "-cost" + str(int(args.cost_limit))
    args.data_dir = f"../MOSDB/{args.task}"
    # Count the per-agent data files; entries containing "_sg" are not counted.
    agent_num = sum("agent_" in fname and "_sg" not in fname for fname in os.listdir(args.data_dir))
    if agent_num == 0:
        raise FileNotFoundError(f"No agent data files found in {args.data_dir}")

    args.logdir = os.path.join(args.log_root_dir, args.task)
    os.makedirs(args.logdir, exist_ok=True)
    if args.name is None:
        # Default run name: index = number of run directories already present.
        exp_num = sum(1 for item in Path(args.logdir).iterdir() if item.is_dir())
        args.name = f"experiment{exp_num}"
    args.logdir = os.path.join(args.logdir, args.name)

    save_model_dir = os.path.join(args.logdir, "model")
    os.makedirs(save_model_dir, exist_ok=True)
    args.save_model_path = os.path.join(save_model_dir, "best_safe_model.pth")

    args.agent_data_file = [os.path.join(args.data_dir, f'agent_{agent}.h5') for agent in range(agent_num)]
    args.agent_sg_file = [os.path.join(args.data_dir, f'agent_{agent}_sg.h5') for agent in range(agent_num)]

    args_env, cfg_env, cfg_eval = multi_agent_args(algo="mappolag")
    if args.task in (
        "ShadowHandOver_Safe_joint",
        "ShadowHandOver_Safe_finger",
        "ShadowHandCatchOver2Underarm_Safe_joint",
        "ShadowHandCatchOver2Underarm_Safe_finger",
        "FreightFrankaCloseDrawer",
        "FreightFrankaPickAndPlace",
    ):
        # Parallel eval environments for these tasks (set to 1 when timing evaluation).
        cfg_env["env"]["numEnvs"] = 10
    # The evaluation seed is fixed, independently of ``args.seed``.
    args_env.seed = 10000
    cfg_eval["seed"] = args_env.seed
    # Evaluate with the eval thread count (set to 1 when timing evaluation).
    cfg_eval["n_rollout_threads"] = cfg_eval["n_eval_rollout_threads"]
    args_env.cost_limit = args.cost_limit
    cfg_eval["cost_limit"] = args.cost_limit
    args.args_env, args.cfg_eval, args.cfg_env = args_env, cfg_eval, cfg_env

    # set seed
    seed_all(args.seed)
    if args.device == "cpu":
        torch.set_num_threads(args.threads)

    runner = Runner(args, agent_num)
    runner.run()

if __name__ == '__main__':
    # Script entry point; pyrallis builds the MOSDTTrainConfig from the command line.
    main()