import os
import uuid
from dataclasses import asdict, dataclass
from typing import Any, DefaultDict, Dict, List, Optional, Tuple

import bullet_safety_gym  # noqa
import dsrl
import gymnasium as gym  # noqa
import numpy as np
import pyrallis
import torch
from dsrl.infos import DENSITY_CFG
from dsrl.offline_env import OfflineEnvWrapper, wrap_env  # noqa
from fsrl.utils import WandbLogger
from torch.utils.data import DataLoader
from tqdm.auto import trange  # noqa
import sys
sys.path.append("../A-Sepsis-RL/CDT")
from examples_my_cost.configs.bearl_configs import BEARL_DEFAULT_CONFIG, BEARLTrainConfig
from osrl_my_cost.algorithms import BEARL, BEARLTrainer
from osrl_my_cost.common import TransitionDataset
from osrl_my_cost.common.exp_util import auto_name, seed_all
import pickle

@pyrallis.wrap()
def train(args: BEARLTrainConfig):
    """Train a BEARL agent on a pickled offline dataset and save its weights.

    The function seeds the run, configures a ``WandbLogger``, loads the
    transition data from pickle files, builds the ``BEARL`` model and its
    ``BEARLTrainer``, performs ``args.update_steps`` gradient updates and
    finally writes ``model.state_dict()`` to a hard-coded ``.pt`` file.

    Online evaluation and the environment-based preprocessing are disabled
    (kept below as commented-out code), so no environment is created here.

    Args:
        args: training configuration; supplied by the ``pyrallis.wrap``
            decorator when ``train()`` is called without arguments.
    """
    seed_all(args.seed)
    if args.device == "cpu":
        torch.set_num_threads(args.threads)

    # setup logger
    # NOTE: `cfg` is a snapshot taken before name/group/logdir are filled in
    # below, so the logged config holds the values as originally passed.
    cfg = asdict(args)
    default_cfg = asdict(BEARL_DEFAULT_CONFIG[args.task]())
    if args.name is None:
        args.name = auto_name(default_cfg, cfg, args.prefix, args.suffix)
    if args.group is None:
        args.group = args.task + "-cost-" + str(int(args.cost_limit))
    if args.logdir is not None:
        args.logdir = os.path.join(args.logdir, args.group, args.name)
    logger = WandbLogger(cfg, args.project, args.group, args.name, args.logdir)
    # logger = TensorboardLogger(args.logdir, log_txt=True, name=args.name)
    logger.save_config(cfg, verbose=args.verbose)

    # # initialize environment
    # env = gym.make(args.task)

    # # pre-process offline dataset
    # data = env.get_dataset()
    # env.set_target_cost(args.cost_limit)

    # NOTE(review): these density settings only feed the commented-out
    # `env.pre_process_data` call below and are currently unused.
    cbins, rbins, max_npb, min_npb = None, None, None, None
    if args.density != 1.0:
        density_cfg = DENSITY_CFG[args.task + "_density" + str(args.density)]
        cbins = density_cfg["cbins"]
        rbins = density_cfg["rbins"]
        max_npb = density_cfg["max_npb"]
        min_npb = density_cfg["min_npb"]
    # data = env.pre_process_data(data,
    #                             args.outliers_percent,
    #                             args.noise_scale,
    #                             args.inpaint_ranges,
    #                             args.epsilon,
    #                             args.density,
    #                             cbins=cbins,
    #                             rbins=rbins,
    #                             max_npb=max_npb,
    #                             min_npb=min_npb)

    # # wrapper
    # env = wrap_env(
    #     env=env,
    #     reward_scale=args.reward_scale,
    # )
    # env = OfflineEnvWrapper(env)

    # dataset_path_val = f'./CDT/examples_my_cost/data/my_cdt_data_val_noauto_c6.pkl'
    # dataset_path = f'./CDT/examples_my_cost/data/my_cdt_data_train_noauto_c6.pkl'

    # dataset_path_val = f'/home/fn/A-Sepsis-RL/CDT/examples_my_cost/data/transicrl_data_val.pkl'
    # dataset_path = f'/home/fn/A-Sepsis-RL/CDT/examples_my_cost/data/transicrl_data.pkl'

    # NOTE(review): both paths point at the same `*_val.pkl` file, so the model
    # is trained on the validation split -- confirm this is intended.
    dataset_path_val = './CDT/examples_my_cost/data/cost0_data_val.pkl'
    dataset_path = './CDT/examples_my_cost/data/cost0_data_val.pkl'

    # NOTE: pickle executes arbitrary code on load; only use trusted files.
    with open(dataset_path, 'rb') as f:
        data = pickle.load(f)

    # NOTE(review): `data_val` is loaded but not used anywhere below.
    with open(dataset_path_val, 'rb') as f:
        data_val = pickle.load(f)

    # Problem dimensions are hard-coded for this dataset rather than read
    # from an environment.
    state_dim = 48
    action_dim = 2
    max_action = torch.tensor([1.0, 1.0], dtype=torch.float32)

    # model & optimizer setup
    model = BEARL(
        state_dim=state_dim,
        action_dim=action_dim,
        max_action=max_action,
        a_hidden_sizes=args.a_hidden_sizes,
        c_hidden_sizes=args.c_hidden_sizes,
        vae_hidden_sizes=args.vae_hidden_sizes,
        sample_action_num=args.sample_action_num,
        gamma=args.gamma,
        tau=args.tau,
        beta=args.beta,
        lmbda=args.lmbda,
        mmd_sigma=args.mmd_sigma,
        target_mmd_thresh=args.target_mmd_thresh,
        start_update_policy_step=args.start_update_policy_step,
        num_q=args.num_q,
        num_qc=args.num_qc,
        PID=args.PID,
        cost_limit=args.cost_limit,
        episode_len=args.episode_len,
        device=args.device,
    )
    print(f"Total parameters: {sum(p.numel() for p in model.parameters())}")

    def checkpoint_fn():
        return {"model_state": model.state_dict()}

    logger.setup_checkpoint_fn(checkpoint_fn)

    # trainer
    trainer = BEARLTrainer(model,
                           logger=logger,
                           actor_lr=args.actor_lr,
                           critic_lr=args.critic_lr,
                           vae_lr=args.vae_lr,
                           reward_scale=args.reward_scale,
                           cost_scale=args.cost_scale,
                           device=args.device)

    # initialize pytorch dataloader
    dataset = TransitionDataset(data,
                                reward_scale=args.reward_scale,
                                cost_scale=args.cost_scale)
    trainloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        pin_memory=True,
        num_workers=args.num_workers,
    )
    trainloader_iter = iter(trainloader)

    # for saving the best (only referenced by the disabled evaluation code)
    best_reward = -np.inf
    best_cost = np.inf
    best_idx = 0

    # training
    for step in trange(args.update_steps, desc="Training"):
        try:
            batch = next(trainloader_iter)
        except StopIteration:
            # The loader ran out before `update_steps` was reached (finite
            # dataset): start another pass instead of aborting the run.
            trainloader_iter = iter(trainloader)
            batch = next(trainloader_iter)
        observations, next_observations, actions, rewards, costs, done = [
            b.to(args.device) for b in batch
        ]
        trainer.train_one_step(observations.float(),
                               next_observations.float(),
                               actions.float(),
                               rewards.float(),
                               costs.float(),
                               done)

        # # evaluation
        # if (step + 1) % args.eval_every == 0 or step == args.update_steps - 1:
        #     ret, cost, length = trainer.evaluate(args.eval_episodes)
        #     logger.store(tab="eval", Cost=cost, Reward=ret, Length=length)

        #     # save the current weight
        #     logger.save_checkpoint()
        #     # save the best weight
        #     if cost < best_cost or (cost == best_cost and ret > best_reward):
        #         best_cost = cost
        #         best_reward = ret
        #         best_idx = step
        #         logger.save_checkpoint(suffix="best")

        #     logger.store(tab="train", best_idx=best_idx)
        #     logger.write(step, display=False)

        # else:
        logger.write_without_reset(step)

    path_ = '/home/fn/A-Sepsis-RL/CDT/Mymodel/bearl_cost0.pt'
    # Make sure the target directory exists so the finished run is not lost
    # to a missing folder, and report success only after the save returned.
    os.makedirs(os.path.dirname(path_), exist_ok=True)
    torch.save(model.state_dict(), path_)
    print("bearl_cost0_ success")


# Script entry point: `train` is called without arguments, so its `args`
# parameter is supplied by the `@pyrallis.wrap()` decorator.
if __name__ == "__main__":
    train()
