import os
import uuid
import types
from dataclasses import asdict, dataclass
from typing import Any, DefaultDict, Dict, List, Optional, Tuple

import bullet_safety_gym  # noqa
import dsrl
import gymnasium as gym  # noqa
import numpy as np
import pyrallis
import torch
from dsrl.infos import DENSITY_CFG
from dsrl.offline_env import OfflineEnvWrapper, wrap_env  # noqa
from fsrl.utils import WandbLogger
from torch.utils.data import DataLoader
from tqdm.auto import trange  # noqa

from examples.configs.bc_configs import BC_DEFAULT_CONFIG, BCTrainConfig
from osrl.algorithms import BC, BCTrainer
from osrl.common import TransitionDataset
from osrl.common.dataset import process_bc_dataset, process_realworld_dataset, process_realworld2_dataset
from osrl.common.exp_util import auto_name, seed_all

import wandb
# Authenticate with Weights & Biases at import time.
# SECURITY: the API key must never be committed to source control. It is read
# from the WANDB_API_KEY environment variable instead; when the variable is
# unset, ``key=None`` lets wandb fall back to its own credential lookup.
# NOTE(review): the key that used to be hard-coded here has been exposed in the
# repository history and should be revoked in the W&B account settings.
wandb.login(key=os.environ.get("WANDB_API_KEY"))

def _load_dataset_with_retry(env, max_attempts: int = 10):
    """Return ``env.get_dataset()``, retrying a bounded number of times.

    Only ``Exception`` subclasses trigger a retry, so ``KeyboardInterrupt`` and
    ``SystemExit`` still propagate and the user can abort a stuck run.

    Args:
        env: Environment object exposing ``get_dataset()``.
        max_attempts: Maximum number of load attempts before giving up.

    Returns:
        Whatever ``env.get_dataset()`` returns.

    Raises:
        RuntimeError: If every attempt failed; the last underlying error is
            attached as ``__cause__``.
    """
    last_err = None
    for attempt in range(1, max_attempts + 1):
        try:
            return env.get_dataset()
        except Exception as err:  # the loader may raise arbitrary I/O errors
            last_err = err
            print(f'Fail to load data... attempt {attempt}/{max_attempts}: {err!r}')
    raise RuntimeError(
        f"Could not load the dataset after {max_attempts} attempts") from last_err


def _load_generated_transitions(path: str) -> Dict[str, np.ndarray]:
    """Load generated trajectories from an ``.npz`` file as flat transitions.

    The arrays are indexed by their first two axes, which are then merged into
    one. ``costs``, ``rewards``, ``terminals``, ``timeouts`` and
    ``next_observations`` are (over)written with zeros, except that the last
    step along the second axis is flagged as a timeout.

    Args:
        path: Location of the ``.npz`` archive; it must contain an
            ``"observations"`` array with at least two dimensions.

    Returns:
        Dict mapping each key to an array whose first two axes were merged.
    """
    # The context manager closes the underlying file handle once the arrays
    # have been read into memory.
    with np.load(path) as archive:
        generated = {key: archive[key] for key in archive}

    lead_shape = generated["observations"].shape[:2]
    generated["costs"] = np.zeros(lead_shape)
    generated["next_observations"] = np.zeros(generated["observations"].shape)
    generated["rewards"] = np.zeros(lead_shape)
    generated["terminals"] = np.zeros(lead_shape)
    generated["timeouts"] = np.zeros(lead_shape)
    generated["timeouts"][:, -1] = 1

    flat_len = lead_shape[0] * lead_shape[1]
    return {
        key: arr.reshape((flat_len, *arr.shape[2:]))
        for key, arr in generated.items()
    }


@pyrallis.wrap()
def train(args: BCTrainConfig):
    """Train a behaviour-cloning (BC) policy on an offline dataset.

    The task-specific default config is merged with any values the user
    overrode, the dataset is loaded from the environment and filtered according
    to ``args.bc_mode``, and the model is then trained for
    ``args.update_steps`` steps with periodic evaluation and checkpointing.

    Args:
        args: Parsed training configuration (supplied by ``pyrallis.wrap``).
    """
    # update config: start from the task defaults and re-apply every field the
    # user changed relative to the generic BCTrainConfig defaults
    cfg, old_cfg = asdict(args), asdict(BCTrainConfig())
    differing_values = {key: cfg[key] for key in cfg.keys() if cfg[key] != old_cfg[key]}
    cfg = asdict(BC_DEFAULT_CONFIG[args.task]())
    cfg.update(differing_values)
    args = types.SimpleNamespace(**cfg)

    # setup logger
    # The defaults are looked up again on purpose: ``args.task`` now comes from
    # the merged config.
    default_cfg = asdict(BC_DEFAULT_CONFIG[args.task]())
    if args.name is None:
        args.name = auto_name(default_cfg, cfg, args.prefix, args.suffix)
    if args.group is None:
        args.group = args.task + "-cost-" + str(int(args.cost_limit))
    if args.logdir is not None:
        args.logdir = os.path.join(args.logdir, args.group, args.name)
    logger = WandbLogger(cfg, args.project, args.group, args.name, args.logdir)
    # logger = TensorboardLogger(args.logdir, log_txt=True, name=args.name)
    logger.save_config(cfg, verbose=args.verbose)

    # set seed
    seed_all(args.seed)
    if args.device == "cpu":
        torch.set_num_threads(args.threads)

    # NOTE: per the original author, cost scaling is done in the trainer
    # rollout (not verified from this file).
    import gymnasium as gym  # noqa
    if "Metadrive" in args.task:
        # Metadrive tasks are created through the legacy ``gym`` package.
        import gym
    env = gym.make(args.task)
    data = _load_dataset_with_retry(env)
    env.set_target_cost(args.cost_limit)

    # NOTE: the legacy preprocessing (``env.pre_process_data`` with the
    # DENSITY_CFG bins, followed by ``process_bc_dataset``) is outdated and has
    # been replaced by ``process_realworld_dataset`` below.
    optimal_data, qualified_data = process_realworld_dataset(
        data, args.cost_limit, args.gamma, args.frontier_ratio, args.task)
    if args.bc_mode == "frontier":
        data = optimal_data
    elif args.bc_mode == "all":
        data = {
            _k: np.concatenate([optimal_data[_k], qualified_data[_k]], axis=0)
            for _k in optimal_data
        }
    elif args.bc_mode == "augmentation":
        for _key in optimal_data:
            print(_key, optimal_data[_key].shape)
        # Alternative archive: 'generated_data/generated_data_predicted.npz'
        generated_data = _load_generated_transitions(
            'generated_data/generated_data_stopredicted.npz')
        # The optimal data is repeated so that it is not outweighed by the
        # generated samples (presumably; a previous setting was 20).
        N_REPEAT = 10
        data = {
            _k: np.concatenate([optimal_data[_k]] * N_REPEAT + [generated_data[_k]],
                               axis=0)
            for _k in optimal_data
        }
    # Any other bc_mode (e.g. "multi-task") trains on the unfiltered dataset.

    # see env information
    print('===== env information =====')
    print('min episode reward', env.min_episode_reward)
    print('max episode reward', env.max_episode_reward)

    # model & optimizer & scheduler setup
    state_dim = env.observation_space.shape[0]
    if args.bc_mode == "multi-task":
        # one extra input dimension is reserved in multi-task mode
        state_dim += 1
    model = BC(
        state_dim=state_dim,
        action_dim=env.action_space.shape[0],
        max_action=env.action_space.high[0],
        a_hidden_sizes=args.a_hidden_sizes,
        episode_len=args.episode_len,
        device=args.device,
    )
    print(f"Total parameters: {sum(p.numel() for p in model.parameters())}")

    def checkpoint_fn():
        return {"model_state": model.state_dict()}

    logger.setup_checkpoint_fn(checkpoint_fn)

    trainer = BCTrainer(model,
                        env,
                        logger=logger,
                        actor_lr=args.actor_lr,
                        bc_mode=args.bc_mode,
                        cost_limit=args.cost_limit,
                        device=args.device)

    trainloader = DataLoader(
        TransitionDataset(data),
        batch_size=args.batch_size,
        pin_memory=True,
        num_workers=args.num_workers,
    )
    trainloader_iter = iter(trainloader)

    # for saving the best
    best_reward = -np.inf
    best_cost = np.inf
    best_idx = 0

    for step in trange(args.update_steps, desc="Training"):
        batch = next(trainloader_iter)
        # batch layout: observations, next_observations, actions, rewards, costs, done
        # Note: costs are deliberately not used for BC training.
        observations, _, actions, _, _, _ = [b.to(args.device) for b in batch]
        trainer.train_one_step(observations, actions)

        # evaluation (periodically and always on the final step)
        if (step + 1) % args.eval_every == 0 or step == args.update_steps - 1:
            ret, cost, length = trainer.evaluate(args.eval_episodes)
            logger.store(tab="eval", Cost=cost, Reward=ret, Length=length)

            # save the current weight
            logger.save_checkpoint()
            # save the best weight: lowest cost first, ties broken by reward
            if cost < best_cost or (cost == best_cost and ret > best_reward):
                best_cost = cost
                best_reward = ret
                best_idx = step
                logger.save_checkpoint(suffix="best")

            logger.store(tab="train", best_idx=best_idx)
            logger.write(step, display=False)

        else:
            logger.write_without_reset(step)
    wandb.finish()

# Script entry point. ``train`` is called without arguments; its ``args``
# parameter is expected to be filled in by the ``pyrallis.wrap`` decorator.
if __name__ == "__main__":
    train()
