import os
import uuid
import types
from dataclasses import asdict, dataclass
from typing import Any, DefaultDict, Dict, List, Optional, Tuple

import bullet_safety_gym  # noqa
import dsrl
import gymnasium as gym  # noqa
import gym as gym_org
import numpy as np
import pyrallis
import torch
from dsrl.infos import DENSITY_CFG
from dsrl.offline_env import OfflineEnvWrapper, wrap_env  # noqa
from fsrl.utils import WandbLogger
from fsrl.utils import TensorboardLogger
from torch.utils.data import DataLoader
from tqdm.auto import trange  # noqa

from examples.configs.cdt_configs import CDT_DEFAULT_CONFIG, CDTTrainConfig
from osrl.algorithms import State_AE, Action_AE, inverse_dynamics_model, ActionAETrainer, StateAETrainer
from osrl.algorithms import CDT, CDTTrainer, CDT_with_action_AE
from osrl.common import SequenceDataset
from osrl.common.exp_util import auto_name, seed_all, load_config_and_model
from osrl.common.dataset import process_bc_dataset


@pyrallis.wrap()
def train(args: CDTTrainConfig):
    """Train a CDT model on an offline dataset, with optional pretrained encoders.

    Steps, in order:

    1. Merge the user-overridden fields of ``args`` onto the task-specific
       defaults from ``CDT_DEFAULT_CONFIG``.
    2. Create a ``TensorboardLogger`` and save the merged config.
    3. Build the environment, load and pre-process its offline dataset, and
       optionally filter it (``safe_only``) or relabel its costs
       (``cost_relable``).
    4. Optionally load a pretrained state encoder and/or action encoder
       (``use_sa_encoder``); the action encoder may be trained jointly with
       the CDT model (``train_action_encoder``).
    5. Run ``update_steps`` training steps, evaluating every ``eval_every``
       steps (and on the last step), saving the current and the best
       checkpoints.

    Args:
        args: Training configuration. Only the fields that differ from a
            default-constructed ``CDTTrainConfig`` are treated as overrides;
            every other field takes the task-specific default.

    Returns:
        None. Results are written through the logger(s) as logs and
        checkpoints.
    """
    # update config: task-specific defaults overlaid with user overrides
    cfg, old_cfg = asdict(args), asdict(CDTTrainConfig())
    differing_values = {
        key: value for key, value in cfg.items() if value != old_cfg[key]
    }
    cfg = asdict(CDT_DEFAULT_CONFIG[args.task]())
    cfg.update(differing_values)
    args = types.SimpleNamespace(**cfg)

    # setup logger
    default_cfg = asdict(CDT_DEFAULT_CONFIG[args.task]())
    if args.name is None:
        args.name = auto_name(default_cfg, cfg, args.prefix, args.suffix)
    if args.use_sa_encoder and args.state_encoder_path is not None:
        # tag the run name so runs using a state encoder are distinguishable
        args.name = args.name + "_with_sae"
        if "idm_loss_weight0.0" in args.state_encoder_path:
            args.name = args.name + "_idm0"
    if args.group is None:
        args.group = args.task + "-cost-" + str(int(args.cost_limit))
    if args.logdir is not None:
        args.logdir = os.path.join(args.logdir, args.group, args.name)
    logger = TensorboardLogger(args.logdir, log_txt=True, name=args.name)
    # NOTE: `cfg` is saved as merged above; the name/group/logdir values
    # derived on `args` afterwards are not written back into it.
    logger.save_config(cfg, verbose=args.verbose)

    # set seed
    seed_all(args.seed)
    if args.device == "cpu":
        torch.set_num_threads(args.threads)

    # initialize environment: Metadrive tasks are created through the legacy
    # `gym` package, every other task through `gymnasium`
    if "Metadrive" in args.task:
        env = gym_org.make(args.task)
    else:
        env = gym.make(args.task)

    # pre-process offline dataset
    data = env.get_dataset()
    env.set_target_cost(args.cost_limit)

    cbins, rbins, max_npb, min_npb = None, None, None, None
    if args.density != 1.0:
        density_cfg = DENSITY_CFG[args.task + "_density" + str(args.density)]
        cbins = density_cfg["cbins"]
        rbins = density_cfg["rbins"]
        max_npb = density_cfg["max_npb"]
        min_npb = density_cfg["min_npb"]
    data = env.pre_process_data(data,
                                args.outliers_percent,
                                args.noise_scale,
                                args.inpaint_ranges,
                                args.epsilon,
                                args.density,
                                cbins=cbins,
                                rbins=rbins,
                                max_npb=max_npb,
                                min_npb=min_npb)

    if args.safe_only:
        # Keep zero-cost transitions, plus every transition flagged as
        # terminal or timeout regardless of its cost (presumably so episode
        # boundaries survive the filtering -- confirm against SequenceDataset).
        # The flags are coerced to bool so that `|` builds a boolean mask even
        # when they are stored as 0/1 numeric arrays; an integer result would
        # otherwise be interpreted as row indices below.
        keep = ((data["costs"] == 0)
                | np.asarray(data["terminals"]).astype(bool)
                | np.asarray(data["timeouts"]).astype(bool))
        for key in data.keys():
            data[key] = data[key][keep]
    elif args.cost_relable:
        # recompute every cost from its observation with the task's cost function
        from osrl.common.function import env2cost_dict
        get_cost = env2cost_dict[args.task]
        for i in range(data["costs"].shape[0]):
            data["costs"][i] = get_cost(data["observations"][i])

    # wrapper
    env = wrap_env(
        env=env,
        reward_scale=args.reward_scale,
    )
    env = OfflineEnvWrapper(env)

    state_encoder = None
    action_encoder = None
    action_encoder_logger = None
    train_action_encoder = False
    if args.use_sa_encoder:
        if args.state_encoder_path is not None:
            senc_cfg, senc_model = load_config_and_model(args.state_encoder_path, True)
            state_encoder = State_AE(
                state_dim=env.observation_space.shape[0],
                encode_dim=senc_cfg["state_encode_dim"],
                hidden_sizes=senc_cfg["state_encoder_hidden_sizes"]
            )
            state_encoder.load_state_dict(senc_model["model_state"])
            # the state encoder is only put in eval mode here; it is not moved
            # to args.device in this function
            state_encoder.eval()

        if args.action_encoder_path is not None:
            train_action_encoder = args.train_action_encoder
            aenc_cfg, aenc_model = load_config_and_model(args.action_encoder_path, True)
            action_encoder = Action_AE(
                action_dim=env.action_space.shape[0],
                encode_dim=aenc_cfg["action_encode_dim"],
                hidden_sizes=aenc_cfg["action_encoder_hidden_sizes"]
            )
            if args.pretrained_initialize:
                action_encoder.load_state_dict(aenc_model["model_state"])
            if train_action_encoder:
                action_encoder.to(args.device)
                # Log the jointly-trained action encoder next to the pretrained
                # one. os.path.dirname is used instead of slicing at
                # rfind('/'): rfind returns -1 when the path has no '/', which
                # would silently chop the last character off the path.
                new_name = "action_encoder_after_pretrain"
                action_encoder_logdir = os.path.join(
                    os.path.dirname(args.action_encoder_path), new_name)
                action_encoder_logger = TensorboardLogger(action_encoder_logdir,
                                                          log_txt=True,
                                                          name=new_name)
            else:
                action_encoder.eval()

    # model & optimizer & scheduler setup
    # When an encoder is present, the CDT model works in the encoded space, so
    # its input/output dimensions are the encoders' latent dimensions.
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]
    if state_encoder is not None:
        state_dim = senc_cfg["state_encode_dim"]
    if action_encoder is not None:
        action_dim = aenc_cfg["action_encode_dim"]
    cdt_model = CDT(
        state_dim=state_dim,
        action_dim=action_dim,
        max_action=env.action_space.high[0],
        embedding_dim=args.embedding_dim,
        seq_len=args.seq_len,
        episode_len=args.episode_len,
        num_layers=args.num_layers,
        num_heads=args.num_heads,
        attention_dropout=args.attention_dropout,
        residual_dropout=args.residual_dropout,
        embedding_dropout=args.embedding_dropout,
        time_emb=args.time_emb,
        use_rew=args.use_rew,
        use_cost=args.use_cost,
        cost_transform=args.cost_transform,
        add_cost_feat=args.add_cost_feat,
        mul_cost_feat=args.mul_cost_feat,
        cat_cost_feat=args.cat_cost_feat,
        action_head_layers=args.action_head_layers,
        cost_prefix=args.cost_prefix,
        stochastic=args.stochastic,
        init_temperature=args.init_temperature,
        # target entropy uses the raw environment action dimension, not the
        # encoded one
        target_entropy=-env.action_space.shape[0],
    ).to(args.device)

    if train_action_encoder:
        model = CDT_with_action_AE(cdt_model, action_encoder, env.action_space.shape[0])

        # The two parts are checkpointed separately: the CDT weights through
        # the main logger, the action-encoder weights through its own logger.
        def checkpoint_fn_action():
            return {"model_state": action_encoder.state_dict()}

        def checkpoint_fn():
            return {"model_state": model.cdt.state_dict()}

        action_encoder_logger.setup_checkpoint_fn(checkpoint_fn_action)
    else:
        model = cdt_model

        def checkpoint_fn():
            return {"model_state": model.state_dict()}

    print(f"Total parameters: {sum(p.numel() for p in model.parameters())}")

    logger.setup_checkpoint_fn(checkpoint_fn)

    # trainer
    trainer = CDTTrainer(model,
                         env,
                         logger=logger,
                         learning_rate=args.learning_rate,
                         weight_decay=args.weight_decay,
                         betas=args.betas,
                         clip_grad=args.clip_grad,
                         lr_warmup_steps=args.lr_warmup_steps,
                         reward_scale=args.reward_scale,
                         cost_scale=args.cost_scale,
                         loss_cost_weight=args.loss_cost_weight,
                         loss_state_weight=args.loss_state_weight,
                         cost_reverse=args.cost_reverse,
                         no_entropy=args.no_entropy,
                         device=args.device,
                         state_encoder=state_encoder,
                         action_encoder=action_encoder,
                         train_action_encoder=train_action_encoder)

    def ct(x):
        # cost transform handed to the dataset: linear (70 - x) or
        # reciprocal (1 / (x + 10)), selected by args.linear
        return 70 - x if args.linear else 1 / (x + 10)

    if args.safe_only:
        # no data augmentation when training on the safe-only subset
        args.augment_percent = 0

    dataset = SequenceDataset(
        data,
        seq_len=args.seq_len,
        reward_scale=args.reward_scale,
        cost_scale=args.cost_scale,
        deg=args.deg,
        pf_sample=args.pf_sample,
        max_rew_decrease=args.max_rew_decrease,
        beta=args.beta,
        augment_percent=args.augment_percent,
        cost_reverse=args.cost_reverse,
        max_reward=args.max_reward,
        min_reward=args.min_reward,
        pf_only=args.pf_only,
        rmin=args.rmin,
        cost_bins=args.cost_bins,
        npb=args.npb,
        cost_sample=args.cost_sample,
        cost_transform=ct,
        start_sampling=args.start_sampling,
        prob=args.prob,
        random_aug=args.random_aug,
        aug_rmin=args.aug_rmin,
        aug_rmax=args.aug_rmax,
        aug_cmin=args.aug_cmin,
        aug_cmax=args.aug_cmax,
        cgap=args.cgap,
        rstd=args.rstd,
        cstd=args.cstd,
        state_encoder=state_encoder,
        action_encoder=action_encoder,
        train_action_encoder=train_action_encoder,
        device=args.device
    )

    trainloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        pin_memory=True,
        num_workers=args.num_workers,
    )
    # NOTE(review): a single iterator is consumed for all update steps, which
    # assumes the dataset yields at least `update_steps` batches (presumably
    # it samples indefinitely) -- confirm against SequenceDataset.
    trainloader_iter = iter(trainloader)

    # for saving the best
    best_reward = -np.inf
    best_cost = np.inf
    best_idx = 0

    for step in trange(args.update_steps, desc="Training"):
        batch = next(trainloader_iter)

        states, actions, returns, costs_return, time_steps, mask, episode_cost, costs = [
            b.to(args.device) for b in batch
        ]

        trainer.train_one_step(states, actions, returns, costs_return, time_steps, mask,
                               episode_cost, costs)

        # evaluation
        if (step + 1) % args.eval_every == 0 or step == args.update_steps - 1:
            average_reward, average_cost = [], []
            log_cost, log_reward, log_len = {}, {}, {}
            if args.safe_only:
                # safe-only runs are evaluated on the first target pair only
                target_returns = [args.target_returns[0]]
            else:
                target_returns = args.target_returns
            for target_return in target_returns:
                reward_return, cost_return = target_return
                if args.safe_only:
                    cost_return = 0

                # critical step, rescale the return!
                scaled_reward_target = reward_return * args.reward_scale
                if args.cost_reverse:
                    scaled_cost_target = (args.episode_len - cost_return) * args.cost_scale
                else:
                    scaled_cost_target = cost_return * args.cost_scale

                ret, cost, length = trainer.evaluate(args.eval_episodes,
                                                     scaled_reward_target,
                                                     scaled_cost_target)
                # Second evaluation with keep_ctg_positive=True (presumably
                # keeps the cost-to-go non-negative during the rollout --
                # confirm in CDTTrainer.evaluate). It is logged only; it does
                # not take part in best-checkpoint selection.
                ret_ctg_pos, cost_ctg_pos, length_ctg_pos = trainer.evaluate(
                    args.eval_episodes,
                    scaled_reward_target,
                    scaled_cost_target,
                    keep_ctg_positive=True)
                average_cost.append(cost)
                average_reward.append(ret)

                name = f"c_{int(cost_return)}_r_{int(reward_return)}"
                log_cost.update({name: cost})
                log_reward.update({name: ret})
                log_len.update({name: length})

                name_pos = name + "_ctg_pos"
                log_cost.update({name_pos: cost_ctg_pos})
                log_reward.update({name_pos: ret_ctg_pos})
                log_len.update({name_pos: length_ctg_pos})

            logger.store(tab="cost", **log_cost)
            logger.store(tab="ret", **log_reward)
            logger.store(tab="length", **log_len)

            # save the current weight
            logger.save_checkpoint()
            if train_action_encoder:
                action_encoder_logger.save_checkpoint()
            # save the best weight: lowest mean cost wins, ties are broken by
            # the higher mean reward
            mean_ret = np.mean(average_reward)
            mean_cost = np.mean(average_cost)
            if mean_cost < best_cost or (mean_cost == best_cost
                                         and mean_ret > best_reward):
                best_cost = mean_cost
                best_reward = mean_ret
                best_idx = step
                logger.save_checkpoint(suffix="best")
                if train_action_encoder:
                    action_encoder_logger.save_checkpoint(suffix="best")

            logger.store(tab="train", best_idx=best_idx)
            logger.write(step, display=False)

        else:
            logger.write_without_reset(step)


# Script entry point. train() is called without arguments: the
# @pyrallis.wrap() decorator on train is expected to construct the
# CDTTrainConfig (presumably from command-line arguments -- see pyrallis docs).
if __name__ == "__main__":
    train()
