import os
import uuid
import types
from dataclasses import asdict, dataclass
from typing import Any, DefaultDict, Dict, List, Optional, Tuple

import bullet_safety_gym  # noqa
import dsrl
import gymnasium as gym  # noqa
import gym as gym_org
import numpy as np
import pyrallis
import torch
from dsrl.infos import DENSITY_CFG
from dsrl.offline_env import OfflineEnvWrapper, wrap_env  # noqa
from fsrl.utils import WandbLogger, TensorboardLogger
from torch.utils.data import DataLoader
from tqdm.auto import trange  # noqa

from examples.configs.offline_model_configs import OfflineModelTrainConfig, OfflineModel_DEFAULT_CONFIG
from osrl.algorithms import EnsembleDynamics, EnsembleDynamicsModel, EnsembleCostModel
from osrl.common import TransitionDataset
from osrl.common.exp_util import auto_name, seed_all
from osrl.common.model_logger import Logger
from osrl.common.net import StandardScaler, SimpleScaler, termination_fn_common
from osrl.common.point_robot import PointRobot
import h5py



def _load_offline_dataset(data_location, safe_only):
    """Load the offline transition dataset from an HDF5 file into numpy arrays.

    Args:
        data_location: Path of the HDF5 file. It must contain the datasets
            ``state``, ``action``, ``next_state``, ``reward``, ``done`` and
            ``cost``.
        safe_only: When true, keep only the entries whose cost equals zero.

    Returns:
        A dict with the keys ``observations``, ``actions``,
        ``next_observations``, ``rewards``, ``dones`` and ``costs``.

    Raises:
        ValueError: If ``data_location`` is ``None``.
    """
    if data_location is None:
        raise ValueError(
            "No offline dataset path configured: set `data_location` to the "
            "HDF5 file holding state/action/next_state/reward/done/cost."
        )

    # Key used by the training code -> dataset name inside the HDF5 file.
    field_map = {
        "observations": "state",
        "actions": "action",
        "next_observations": "next_state",
        "rewards": "reward",
        "dones": "done",
        "costs": "cost",
    }
    # np.array() copies each dataset into memory, so the file can be closed
    # as soon as every field has been read.
    with h5py.File(data_location, "r") as f:
        data = {key: np.array(f[name]) for key, name in field_map.items()}

    if safe_only:
        # The mask is computed once, before any array (including the costs
        # themselves) is filtered.
        safe_mask = data["costs"] == 0
        data = {key: value[safe_mask] for key, value in data.items()}
    return data


@pyrallis.wrap()
def train(args: OfflineModelTrainConfig):
    """Train the ensemble dynamics model (and cost model) on an offline dataset.

    The parsed config is merged with the task-specific defaults from
    ``OfflineModel_DEFAULT_CONFIG``, a Tensorboard logger and a model logger
    are created, the offline dataset is loaded from HDF5, and the models,
    optimizers and schedulers are handed to ``EnsembleDynamics.train``.

    Args:
        args: Training configuration parsed by pyrallis.

    Raises:
        ValueError: If no dataset location is available.
    """
    # update config: start from the task defaults and re-apply every value
    # that differs from the generic OfflineModelTrainConfig defaults
    cfg, old_cfg = asdict(args), asdict(OfflineModelTrainConfig())
    differing_values = {key: cfg[key] for key in cfg.keys() if cfg[key] != old_cfg[key]}
    cfg = asdict(OfflineModel_DEFAULT_CONFIG[args.task]())
    cfg.update(differing_values)
    args = types.SimpleNamespace(**cfg)

    # setup logger
    default_cfg = asdict(OfflineModel_DEFAULT_CONFIG[args.task]())
    if args.name is None:
        args.name = auto_name(default_cfg, cfg, args.prefix, args.suffix)
    if args.group is None:
        args.group = args.task
    if args.logdir is not None:
        args.logdir = os.path.join(args.logdir, args.group, args.name)
    # NOTE: the WandbLogger below is temporarily replaced by TensorboardLogger;
    # remember to restore it.
    # logger = WandbLogger(cfg, args.project, args.group, args.name, args.logdir)
    logger = TensorboardLogger(args.logdir, log_txt=True, name=args.name)
    logger.save_config(cfg, verbose=args.verbose)
    output_config = {
        "consoleout_backup": "stdout",
        "policy_training_progress": "csv",
        "dynamics_training_progress": "csv",
        "tb": "tensorboard",
    }
    model_logger = Logger(logger.log_dir, output_config)

    # set seed
    seed_all(args.seed)
    if args.device == "cpu":
        torch.set_num_threads(args.threads)

    # initialize environment
    # NOTE: the gym / gymnasium `make(args.task)` path is disabled; a fixed
    # PointRobot instance supplies the observation and action dimensions.
    env = PointRobot(id=0, seed=0)

    # pre-process offline dataset
    # NOTE(review): the path used to be hard-coded to None. It is now read from
    # an optional `data_location` config attribute; OfflineModelTrainConfig is
    # not visible here, so confirm that the field exists (or add it).
    data_location = getattr(args, "data_location", None)
    data = _load_offline_dataset(data_location, args.safe_only)

    # model & optimizer setup
    obs_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]
    dynamics_model = EnsembleDynamicsModel(
        obs_dim=obs_dim,
        action_dim=action_dim,
        hidden_dims=args.dynamic_hidden_dims,
        num_ensemble=args.num_ensemble,
        num_elites=args.num_elites,
        weight_decays=args.dynamic_weight_decays,
        with_cost=args.with_cost,
        device=args.device,
    )
    # NOTE(review): the cost model is built even when `safe_only` discards it
    # below; construction order is kept because it may consume RNG state —
    # confirm before skipping it.
    cost_model = EnsembleCostModel(
        obs_dim=obs_dim,
        action_dim=action_dim,
        hidden_dims=args.cost_model_hidden_dims,
        num_ensemble=args.num_ensemble,
        num_elites=args.num_elites,
        weight_decays=args.dynamic_weight_decays,
        device=args.device,
    )
    print(f"Total parameters: {sum(p.numel() for p in dynamics_model.parameters())}")
    dynamics_optim = torch.optim.Adam(
        dynamics_model.parameters(),
        lr=args.learning_rate,
    )
    cost_model_optim = torch.optim.Adam(
        cost_model.parameters(),
        lr=args.learning_rate,
    )
    dynamics_scheduler = torch.optim.lr_scheduler.StepLR(
        dynamics_optim, step_size=args.decay_step, gamma=args.decay_rate
    )
    cost_model_scheduler = torch.optim.lr_scheduler.StepLR(
        cost_model_optim, step_size=args.decay_step, gamma=args.decay_rate
    )
    if args.simple_scaler:
        scaler = SimpleScaler()
    else:
        scaler = StandardScaler()
    termination_fn = termination_fn_common
    if args.safe_only:
        # With only zero-cost data kept, no cost model is passed on.
        cost_model = None
        cost_model_optim = None
        cost_model_scheduler = None
    dynamics = EnsembleDynamics(
        dynamics_model,
        cost_model,
        dynamics_optim,
        cost_model_optim,
        scaler,
        termination_fn,
        use_scheduler=args.use_scheduler,
        dynamics_scheduler=dynamics_scheduler,
        cost_model_scheduler=cost_model_scheduler,
        penalty_coef=args.penalty_coef,
        with_cost=args.with_cost,
        use_delta_obs=args.use_delta_obs,
        reward_scale=args.reward_scale,
        cost_scale=args.cost_scale,
        cost_coef=args.cost_coef,
    )
    dynamics.train(
        data,
        model_logger,
        batch_size=args.batch_size,
        max_epochs_since_update=args.max_epochs_since_update,
    )


# Script entry point. `train` is decorated with `@pyrallis.wrap()`, so it is
# called without arguments here.
if __name__ == "__main__":
    train()
