import argparse
import os

import diffuser.utils as utils

import torch
import yaml
from diffuser.utils.launcher_util import (
    build_config_from_dict,
    discover_latest_checkpoint_path,
)


from tqdm import tqdm
import psutil
import numpy as np
    
def main(Config, RUN):
    """Generate episodes from a trained diffusion checkpoint and save them as ``.npy`` files.

    Steps:
      1. Fill in defaults for config keys that older experiment YAML files lack.
      2. Build the dataset, data encoder, model, diffusion and generator objects
         from ``Config`` (each ``utils.Config`` is given a ``*_config.pkl`` savepath).
      3. Load the checkpoint at ``Config.load_checkpoint_path`` into the
         generator's model and restore its step counter.
      4. Call ``generator.generate_episodes()`` once per epoch until at least
         ``Config.generate_episode_nums`` episodes are buffered, the budget of
         ``Config.num_generate_epoch`` epochs is used up, or the buffer is still
         empty after ``Config.empty_buffer_patience`` epochs (default 100).
      5. Convert the buffered episodes to NumPy and write them to
         ``Config.generate_path``.

    Args:
        Config: experiment configuration namespace (from ``build_config_from_dict``).
            Missing legacy keys are filled in place.
        RUN: ml_logger run descriptor. Not used inside this function; kept for
            interface compatibility.

    Raises:
        RuntimeError: if generation finished without producing any episode.
    """
    # Local import so ``main`` also works when this module is imported instead of
    # run as a script (the ``__main__`` block is what binds ``logger`` globally).
    # Python caches modules, so this is the same logger object.
    from ml_logger import logger

    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    utils.set_seed(Config.seed)
    dataset_extra_kwargs = {}

    # Configs that do not exist in old YAML files.
    _fill_legacy_config_defaults(Config)

    # -----------------------------------------------------------------------------#
    # ---------------------------------- dataset ----------------------------------#
    # -----------------------------------------------------------------------------#
    dataset_config = utils.Config(
        Config.loader,
        savepath="dataset_config.pkl",
        env_type=Config.env_type,
        env=Config.dataset,
        n_agents=Config.n_agents,
        horizon=Config.horizon,
        history_horizon=Config.history_horizon,
        normalizer=Config.normalizer,
        preprocess_fns=Config.preprocess_fns,
        max_n_episodes=Config.max_n_episodes,
        use_padding=Config.use_padding,
        use_action=Config.use_action,
        discrete_action=Config.discrete_action,
        max_path_length=Config.max_path_length,
        include_returns=Config.returns_condition,
        include_env_ts=Config.env_ts_condition,
        returns_scale=Config.returns_scale,
        discount=Config.discount,
        termination_penalty=Config.termination_penalty,
        agent_share_parameters=utils.config.import_class(
            Config.model
        ).agent_share_parameters,
        use_seed_dataset=Config.use_seed_dataset,
        seed=Config.dataset_seed,
        use_inv_dyn=Config.use_inv_dyn,
        decentralized_execution=Config.decentralized_execution,
        use_zero_padding=Config.use_zero_padding,
        agent_condition_type=Config.agent_condition_type,
        pred_future_padding=Config.pred_future_padding,
        **dataset_extra_kwargs,
    )

    data_encoder_config = utils.Config(
        getattr(Config, "data_encoder", "utils.IdentityEncoder"),
        savepath="data_encoder_config.pkl",
    )

    dataset = dataset_config()
    data_encoder = data_encoder_config()
    observation_dim = dataset.observation_dim
    action_dim = dataset.action_dim

    # -----------------------------------------------------------------------------#
    # ------------------------------ model & trainer ------------------------------#
    # -----------------------------------------------------------------------------#
    model_config = utils.Config(
        Config.model,
        savepath="model_config.pkl",
        n_agents=Config.n_agents,
        horizon=Config.horizon + Config.history_horizon,
        history_horizon=Config.history_horizon,
        transition_dim=observation_dim,
        dim_mults=Config.dim_mults,
        returns_condition=Config.returns_condition,
        env_ts_condition=Config.env_ts_condition,
        dim=Config.dim,
        condition_dropout=Config.condition_dropout,
        residual_attn=Config.residual_attn,
        max_path_length=Config.max_path_length,
        use_temporal_attention=Config.use_temporal_attention,
        device=Config.device,
    )

    diffusion_config = utils.Config(
        Config.diffusion,
        savepath="diffusion_config.pkl",
        n_agents=Config.n_agents,
        horizon=Config.horizon,
        history_horizon=Config.history_horizon,
        observation_dim=observation_dim,
        action_dim=action_dim,
        discrete_action=Config.discrete_action,
        num_actions=getattr(dataset.env, "num_actions", 0),
        n_timesteps=Config.n_diffusion_steps,
        clip_denoised=Config.clip_denoised,
        predict_epsilon=Config.predict_epsilon,
        hidden_dim=Config.hidden_dim,
        train_only_inv=Config.train_only_inv,
        share_inv=Config.share_inv,
        joint_inv=Config.joint_inv,
        # loss weighting
        action_weight=Config.action_weight,
        loss_weights=Config.loss_weights,
        state_loss_weight=Config.state_loss_weight,
        opponent_loss_weight=Config.opponent_loss_weight,
        loss_discount=Config.loss_discount,
        returns_condition=Config.returns_condition,
        condition_guidance_w=Config.condition_guidance_w,
        data_encoder=data_encoder,
        use_inv_dyn=Config.use_inv_dyn,
        use_dynamic_model=Config.use_dynamic_model,
        use_reward_model=Config.use_reward_model,
        use_value_model=Config.use_value_model,
        share_dynamic=Config.share_dynamic,
        # NOTE: "joint_dynmic" (sic) is the keyword the diffusion class expects.
        joint_dynmic=Config.joint_dynmic,
        joint_reward=Config.joint_reward,
        num_resample=Config.num_resample,
        jump_denoising_step=Config.jump_denoising_step,
        device=Config.device,
    )

    generator_config = utils.Config(
        utils.Generator,
        savepath="generator_config.pkl",
        accept_threshold=Config.accept_threshold,
        generate_batch_size=Config.generate_batch_size,
        generate_episode_nums=Config.generate_episode_nums,
        env_type=Config.env_type,
        generate_device=Config.device,
    )

    model = model_config()
    diffusion = diffusion_config(model)
    generator = generator_config(diffusion, dataset)

    # -----------------------------------------------------------------------------#
    # --------------------------------- checkpoint --------------------------------#
    # -----------------------------------------------------------------------------#
    load_path = Config.load_checkpoint_path
    state_dict = torch.load(load_path, map_location=Config.device)
    logger.print(
        f"\nLoaded checkpoint from {load_path} (step {state_dict['step']})\n",
        color="green",
    )
    generator.step = state_dict["step"]
    generator.model.load_state_dict(state_dict["model"])

    # -----------------------------------------------------------------------------#
    # --------------------------------- generation --------------------------------#
    # -----------------------------------------------------------------------------#
    # Give up once this many epochs have passed with nothing accepted.
    empty_buffer_patience = getattr(Config, "empty_buffer_patience", 100)
    for epoch in range(Config.num_generate_epoch):
        logger.print(f"Epoch {epoch} / {Config.num_generate_epoch} | {logger.prefix}")
        generator.generate_episodes()
        print("Generated num/Epoch:", str(len(generator.gen_buffer)) + "/" + str(epoch))
        if epoch > empty_buffer_patience and len(generator.gen_buffer) == 0:
            print("Not satisfy accept threshold")
            break
        if len(generator.gen_buffer) >= Config.generate_episode_nums:
            break

    _save_generated_episodes(Config, generator)


def _fill_legacy_config_defaults(Config):
    """Set defaults, in place, for config keys that older experiment YAML files lack.

    A key is only assigned its default when ``Config`` does not already define it,
    so values from the YAML always win. ``agent_condition_type`` is derived from
    ``Config.decentralized_execution`` when absent: ``"single"`` if true, else ``"all"``.
    """
    legacy_defaults = {
        "discrete_action": False,
        "state_loss_weight": None,
        "opponent_loss_weight": None,
        "use_seed_dataset": False,
        "residual_attn": True,
        "use_temporal_attention": True,
        "env_ts_condition": False,
        "use_return_to_go": False,
        "joint_inv": False,
        "use_zero_padding": True,
        "use_inv_dyn": True,
        "pred_future_padding": False,
    }
    for key, default in legacy_defaults.items():
        setattr(Config, key, getattr(Config, key, default))
    if not hasattr(Config, "agent_condition_type"):
        Config.agent_condition_type = (
            "single" if Config.decentralized_execution else "all"
        )


def _save_generated_episodes(Config, generator):
    """Convert the generator's buffered episodes to NumPy and save them under ``Config.generate_path``.

    The per-episode tensors in ``generator.gen_buffer`` are stacked along a new
    leading axis and flattened to rows of ``(..., n_agents, feature)``. States
    are split into ``s = tot_s[:, :-1]`` and ``s_prime = tot_s[:, 1:]``.

    * ``env_type`` in {"mamujoco", "smac"}: writes ``obs.npy``, ``actions.npy``,
      ``rewards.npy``, ``discounts.npy`` and ``path_lengths.npy``.
    * any other ``env_type``: writes per-agent ``obs_<i>``, ``next_obs_<i>``,
      ``acs_<i>``, ``rews_<i>`` and ``dones_<i>`` ``.npy`` files.

    Raises:
        RuntimeError: if the buffer holds no episodes (``torch.stack`` would
            otherwise fail with an opaque message on the empty list).
    """
    buffer = generator.gen_buffer
    generate_path = Config.generate_path
    if len(buffer.s_buffer) == 0:
        raise RuntimeError(
            f"Generator produced no episodes; nothing to save to {generate_path!r}."
        )
    os.makedirs(generate_path, exist_ok=True)

    n_agents = Config.n_agents

    def to_numpy(tensor, *shape):
        # Reshape, detach from the graph and move to host memory.
        return tensor.reshape(*shape).detach().cpu().numpy()

    tot_s = torch.stack(buffer.s_buffer, dim=0)
    tot_a = torch.stack(buffer.a_buffer, dim=0)
    tot_r = torch.stack(buffer.r_buffer, dim=0)
    tot_d = torch.stack(buffer.d_buffer, dim=0)

    # Consecutive state pairs along axis 1: (s_t, s_{t+1}).
    s = to_numpy(tot_s[:, :-1], -1, n_agents, tot_s.shape[-1])
    s_prime = to_numpy(tot_s[:, 1:], -1, n_agents, tot_s.shape[-1])
    a = to_numpy(tot_a, -1, n_agents, tot_a.shape[-1])
    if Config.joint_reward:
        # Joint rewards/dones are flattened without an agent axis.
        r = to_numpy(tot_r, -1, tot_r.shape[-1])
        d = to_numpy(tot_d, -1, tot_d.shape[-1])
    else:
        r = to_numpy(tot_r, -1, n_agents, tot_r.shape[-1])
        d = to_numpy(tot_d, -1, n_agents, tot_d.shape[-1])

    if Config.env_type in ("mamujoco", "smac"):
        # NOTE(review): assumes r's last axis has size 1 (e.g. joint_reward) so the
        # repeat yields one column per agent; confirm for per-agent rewards.
        r = np.repeat(r, n_agents, axis=-1)
        # Discount is 1 everywhere except the last index of axis 1 (final step).
        # NOTE(review): obs use tot_s[:, :-1] while discounts cover every index of
        # tot_d's axis 1; confirm row counts line up.
        discs = torch.ones_like(tot_d.expand(-1, -1, n_agents))
        discs[:, -1, :] = 0
        np_discount = discs.reshape(-1, n_agents).detach().cpu().numpy()
        path_length = np.full(len(buffer.s_buffer), Config.horizon - 1)

        np.save(os.path.join(generate_path, "obs.npy"), s)
        np.save(os.path.join(generate_path, "actions.npy"), a)
        np.save(os.path.join(generate_path, "rewards.npy"), r)
        np.save(os.path.join(generate_path, "discounts.npy"), np_discount)
        np.save(os.path.join(generate_path, "path_lengths.npy"), path_length)
    else:
        for i in range(n_agents):
            np.save(os.path.join(generate_path, f"obs_{i}.npy"), s[:, i])
            np.save(os.path.join(generate_path, f"next_obs_{i}.npy"), s_prime[:, i])
            np.save(os.path.join(generate_path, f"acs_{i}.npy"), a[:, i])
            if Config.joint_reward:
                np.save(os.path.join(generate_path, f"rews_{i}.npy"), r.squeeze())
                np.save(os.path.join(generate_path, f"dones_{i}.npy"), d.squeeze())
            else:
                np.save(os.path.join(generate_path, f"rews_{i}.npy"), r[:, i].squeeze())
                np.save(os.path.join(generate_path, f"dones_{i}.npy"), d[:, i].squeeze())


if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    # Required: without a spec file there is nothing to run (previously this
    # crashed later with a TypeError from open(None)).
    parser.add_argument(
        "-e", "--experiment", required=True, help="experiment specification file"
    )
    parser.add_argument("-g", "--gpu", help="gpu id", type=str, default="0")

    args = parser.parse_args()
    # Restrict visible GPUs before any torch.cuda call below.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    # safe_load only builds plain YAML types (no arbitrary Python objects).
    with open(args.experiment, "r", encoding="utf-8") as spec_file:
        exp_specs = yaml.safe_load(spec_file)

    from ml_logger import RUN, logger

    Config = build_config_from_dict(exp_specs)

    Config.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # The spec's job_name is a format template filled from the config fields.
    job_name = Config.job_name.format(**vars(Config))
    RUN.prefix, RUN.job_name, _ = RUN(
        script_path=__file__,
        exp_name=exp_specs["exp_name"],
        job_name=job_name + f"/{Config.seed}",
    )

    logger.configure(RUN.prefix, root=RUN.script_root)
    # Remove stale files under this run prefix (presumably left by earlier runs).
    logger.remove("traceback.err")
    logger.remove("parameters.pkl")
    logger.log_params(Config=vars(Config), RUN=vars(RUN))
    logger.log_text(
        """
                    charts:
                    - yKey: loss
                      xKey: steps
                    - yKey: a0_loss
                      xKey: steps
                    """,
        filename=".charts.yml",
        dedent=True,
        overwrite=True,
    )
    logger.save_yaml(exp_specs, "exp_specs.yml")

    main(Config, RUN)
