import os
import uuid
from dataclasses import asdict, dataclass
from typing import Any, DefaultDict, Dict, List, Optional, Tuple

import bullet_safety_gym  # noqa
import dsrl
import gymnasium as gym  # noqa
import numpy as np
import pyrallis
import torch

import sys
sys.path.append("/home/fn/MyRL/Simple-work2-CDT/OSRL-CDT")
from dsrl.infos import DENSITY_CFG
from dsrl.offline_env import OfflineEnvWrapper, wrap_env  # noqa
from fsrl.utils import WandbLogger
from torch.utils.data import DataLoader
from tqdm.auto import trange  # noqa

from examples_cdt.configs.cdt_configs import CDT_DEFAULT_CONFIG, CDTTrainConfig
from osrl_cdt.algorithms import CDT, CDTTrainer
from osrl_cdt.common import SequenceDataset
from osrl_cdt.common.exp_util import auto_name, seed_all
import pickle


def get_val_path(path, batch_size=64,
                 keys=('actions', 'next_observations', 'observations',
                       'rewards', 'terminals', 'costs', 'dieds')):
    """Sample a random batch of trajectories (with replacement) from ``path``.

    The number of trajectories is taken as ``len(path['rewards'])``.
    ``batch_size`` indices are drawn uniformly, with replacement, using
    NumPy's global RNG (``np.random``). Seeding that RNG makes the draw
    reproducible. Every field named in ``keys`` is then gathered at those
    indices and stacked with ``np.array``.

    Args:
        path: mapping from field name to an indexable sequence with one entry
            per trajectory. It must contain ``'rewards'`` and, when
            ``batch_size > 0``, every name in ``keys``.
        batch_size: number of trajectories to draw.
        keys: fields to gather. The returned dict preserves this order.

    Returns:
        dict mapping each key to ``np.array`` of the gathered entries. Its
        first axis has length ``batch_size`` when the entries share a shape.

    Raises:
        KeyError: if ``path`` lacks ``'rewards'`` or (for a non-empty batch)
            any name in ``keys``.
        ValueError: from ``np.random.choice`` when ``path`` holds no
            trajectories and ``batch_size > 0``.
    """
    num_trajectories = len(path['rewards'])
    batch_inds = np.random.choice(
        np.arange(num_trajectories),
        size=batch_size,
        replace=True,
    )
    # One gather per field, in the fixed key order, over the same indices.
    return {key: np.array([path[key][idx] for idx in batch_inds]) for key in keys}

@pyrallis.wrap()
def train(args: CDTTrainConfig):
    """Train a CDT policy on the offline sepsis data and save its weights.

    ``args`` is supplied by the ``@pyrallis.wrap()`` decorator; the entry
    point calls ``train()`` with no arguments. Steps:

    1. Seed all RNGs and, on CPU, set the torch thread count.
    2. Derive the run name, group and log dir, and create a ``WandbLogger``.
    3. Load the pickled train and validation data from hard-coded paths.
    4. Build the ``CDT`` model, a ``CDTTrainer`` and a ``SequenceDataset``
       wrapped in a ``DataLoader``.
    5. Run ``args.update_steps`` training steps. Every ``args.eval_every``
       steps, and on the last step, evaluate once per entry in
       ``args.target_returns`` on a random validation batch. Save the
       current checkpoint, plus a "best" checkpoint chosen by lowest mean
       cost, with ties broken by higher mean return.
    6. Write the final ``model.state_dict()`` to a hard-coded ``.pt`` file.
    """
    seed_all(args.seed)
    if args.device == "cpu":
        torch.set_num_threads(args.threads)

    # setup logger
    cfg = asdict(args)
    default_cfg = asdict(CDT_DEFAULT_CONFIG[args.task]())
    if args.name is None:
        args.name = auto_name(default_cfg, cfg, args.prefix, args.suffix)
    if args.group is None:
        args.group = args.task + "-cost-" + str(int(args.cost_limit))
    if args.logdir is not None:
        args.logdir = os.path.join(args.logdir, args.group, args.name)
    logger = WandbLogger(cfg, args.project, args.group, args.name, args.logdir)
    logger.save_config(cfg, verbose=args.verbose)

    # pre-process offline dataset
    # NOTE(review): machine-specific paths. pickle.load can execute arbitrary
    # code, so only point these at trusted files.
    dataset_path_val = '/home/fn/MyRL/DT-ICRL/data/sepsis_data/val_policy_data.pkl'
    dataset_path = '/home/fn/MyRL/DT-ICRL/data/sepsis_data/train_policy_data.pkl'

    with open(dataset_path, 'rb') as f:
        data = pickle.load(f)
    with open(dataset_path_val, 'rb') as f:
        data_val = pickle.load(f)

    # Dimensions are fixed here rather than read from an environment.
    # Presumably they match the sepsis data loaded above -- confirm.
    state_dim = 48
    action_dim = 2
    max_action = [1.0, 1.0]

    model = CDT(
        state_dim=state_dim,
        action_dim=action_dim,
        max_action=max_action,
        embedding_dim=args.embedding_dim,
        seq_len=args.seq_len,
        episode_len=args.episode_len,
        num_layers=args.num_layers,
        num_heads=args.num_heads,
        attention_dropout=args.attention_dropout,
        residual_dropout=args.residual_dropout,
        embedding_dropout=args.embedding_dropout,
        time_emb=args.time_emb,
        use_rew=args.use_rew,
        use_cost=args.use_cost,
        cost_transform=args.cost_transform,
        add_cost_feat=args.add_cost_feat,
        mul_cost_feat=args.mul_cost_feat,
        cat_cost_feat=args.cat_cost_feat,
        action_head_layers=args.action_head_layers,
        cost_prefix=args.cost_prefix,
        stochastic=args.stochastic,
        init_temperature=args.init_temperature,
        target_entropy=-action_dim,
    ).to(args.device)
    print(f"Total parameters: {sum(p.numel() for p in model.parameters())}")

    def checkpoint_fn():
        return {"model_state": model.state_dict()}

    logger.setup_checkpoint_fn(checkpoint_fn)

    # trainer
    trainer = CDTTrainer(model,
                         logger=logger,
                         learning_rate=args.learning_rate,
                         weight_decay=args.weight_decay,
                         betas=args.betas,
                         clip_grad=args.clip_grad,
                         lr_warmup_steps=args.lr_warmup_steps,
                         reward_scale=args.reward_scale,
                         cost_scale=args.cost_scale,
                         loss_cost_weight=args.loss_cost_weight,
                         loss_state_weight=args.loss_state_weight,
                         cost_reverse=args.cost_reverse,
                         no_entropy=args.no_entropy,
                         device=args.device)

    def ct(x):
        # Cost transform passed to SequenceDataset to compute sampling weights.
        # It decreases as cost grows, so low-cost data is sampled preferentially.
        return 70 - x if args.linear else 1 / (x + 10)

    dataset = SequenceDataset(
        data,
        seq_len=args.seq_len,
        reward_scale=args.reward_scale,
        cost_scale=args.cost_scale,
        deg=args.deg,
        pf_sample=args.pf_sample,
        max_rew_decrease=args.max_rew_decrease,
        beta=args.beta,
        augment_percent=args.augment_percent,
        cost_reverse=args.cost_reverse,
        max_reward=args.max_reward,
        min_reward=args.min_reward,
        pf_only=args.pf_only,
        rmin=args.rmin,
        cost_bins=args.cost_bins,
        npb=args.npb,
        cost_sample=args.cost_sample,
        cost_transform=ct,
        start_sampling=args.start_sampling,
        prob=args.prob,
        random_aug=args.random_aug,
        aug_rmin=args.aug_rmin,
        aug_rmax=args.aug_rmax,
        aug_cmin=args.aug_cmin,
        aug_cmax=args.aug_cmax,
        cgap=args.cgap,
        rstd=args.rstd,
        cstd=args.cstd,
    )

    trainloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        pin_memory=True,
        num_workers=args.num_workers,
    )
    # One iterator for the whole run. next() is called once per update step.
    trainloader_iter = iter(trainloader)

    # for saving the best
    best_reward = -np.inf
    best_cost = np.inf
    best_idx = 0

    # Diagnostic: the largest initial cost-return / return in the training set.
    # Start from -inf so that an all-negative return set is not reported as 0.
    maxcost = -np.inf
    maxreturn = -np.inf
    for i in range(len(dataset.dataset)):
        traj = dataset.dataset[i]
        maxcost = max(maxcost, traj['cost_returns'][0])
        maxreturn = max(maxreturn, traj['returns'][0])
    print(maxcost, maxreturn)

    for step in trange(args.update_steps, desc="Training"):
        batch = next(trainloader_iter)
        states, actions, returns, costs_return, time_steps, mask, episode_cost, costs = [
            b.to(args.device) for b in batch
        ]

        trainer.train_one_step(states, actions, returns, costs_return, time_steps, mask,
                               episode_cost, costs)

        # evaluation
        if (step + 1) % args.eval_every == 0 or step == args.update_steps - 1:
            average_reward, average_cost = [], []
            log_cost, log_reward, log_len = {}, {}, {}
            for target_return in args.target_returns:
                reward_return, cost_return = target_return
                # A fresh random validation batch is drawn for each target.
                val_df = get_val_path(data_val)
                if args.cost_reverse:
                    # critical step, rescale the return!
                    ret, cost, length = trainer.evaluate(val_df,
                        args.eval_episodes, reward_return * args.reward_scale,
                        (args.episode_len - cost_return) * args.cost_scale)
                else:
                    ret, cost, length = trainer.evaluate(val_df,
                        args.eval_episodes, reward_return * args.reward_scale,
                        cost_return * args.cost_scale)
                average_cost.append(cost)
                average_reward.append(ret)

                name = "c_" + str(int(cost_return)) + "_r_" + str(int(reward_return))
                # NOTE(review): earlier annotations labelled these three values
                # "phy action", "agent action" and "mse". Confirm what
                # CDTTrainer.evaluate actually returns.
                log_cost.update({name: cost})
                log_reward.update({name: ret})
                log_len.update({name: length})

            logger.store(tab="cost", **log_cost)
            logger.store(tab="ret", **log_reward)
            logger.store(tab="length", **log_len)

            # save the current weight
            logger.save_checkpoint()
            # save the best weight: lowest mean cost, ties -> higher mean return
            mean_ret = np.mean(average_reward)
            mean_cost = np.mean(average_cost)
            if mean_cost < best_cost or (mean_cost == best_cost
                                         and mean_ret > best_reward):
                best_cost = mean_cost
                best_reward = mean_ret
                best_idx = step
                logger.save_checkpoint(suffix="best")

            logger.store(tab="train", best_idx=best_idx)
            logger.write(step, display=False)

        else:
            logger.write_without_reset(step)

    path_ = '/home/fn/MyRL/Simple-work2-CDT/OSRL-CDT/examples_cdt/Mymodel/cdt_1128_2.pt'
    # torch.save does not create missing parent directories. Create them so
    # the run does not fail after all training has finished.
    os.makedirs(os.path.dirname(path_), exist_ok=True)
    torch.save(model.state_dict(), path_)


if __name__ == "__main__":
    # `train` declares an `args` parameter, but the @pyrallis.wrap() decorator
    # supplies it, so it is called here with no arguments.
    train()
