import os
import uuid
from dataclasses import asdict, dataclass
from typing import Any, DefaultDict, Dict, List, Optional, Tuple

import bullet_safety_gym  # noqa
import dsrl
import gymnasium as gym  # noqa
import numpy as np
import pyrallis
import torch

import sys
sys.path.append("/home/fn/MyRL/Simple-work2-CDT/OSRL-CDT")
from dsrl.infos import DENSITY_CFG
from dsrl.offline_env import OfflineEnvWrapper, wrap_env  # noqa
from fsrl.utils import WandbLogger
from torch.utils.data import DataLoader
from tqdm.auto import trange  # noqa

from examples.configs.cdt_configs import CDT_DEFAULT_CONFIG, CDTTrainConfig
from osrl.algorithms import CDT, CDTTrainer, CostNet
from osrl.common import SequenceDataset
from osrl.common.exp_util import auto_name, seed_all
import pickle

def _load_pickle(file_path: str) -> Any:
    """Return the object deserialized from the pickle file at ``file_path``.

    NOTE(review): ``pickle.load`` can execute arbitrary code embedded in the
    file; only point this at locally produced, trusted data files.
    """
    with open(file_path, "rb") as f:
        return pickle.load(f)


def _max_initial_returns(trajectories) -> Tuple[Any, Any]:
    """Return ``(max_cost_return, max_return)`` over the first step of each trajectory.

    Each element is read as ``traj["cost_returns"][0]`` and ``traj["returns"][0]``.
    Both running maxima start at 0, so an empty input yields ``(0, 0)`` and the
    results are never below 0.
    """
    max_cost = 0
    max_return = 0
    for traj in trajectories:
        max_cost = max(max_cost, traj["cost_returns"][0])
        max_return = max(max_return, traj["returns"][0])
    return max_cost, max_return


@pyrallis.wrap()
def train(args: CDTTrainConfig):
    """Load a pretrained CDT model and train a ``CostNet`` against expert data.

    ``args`` is the ``CDTTrainConfig`` supplied by ``pyrallis``. The function:

    1. sets up a ``WandbLogger`` and saves the config;
    2. loads the train / validation / expert pickles from ``data_root``;
    3. builds a ``CDT`` model and loads weights from ``model_path``;
    4. builds a ``SequenceDataset`` from the train data and prints the maximum
       first-step cost return and reward return over its trajectories;
    5. runs ``num_train_steps`` iterations. Each iteration samples expert data via
       ``trainer.get_expert``, generates matching trajectories via
       ``trainer.get_optimal`` and calls ``trainer.train_cost_net`` on both.
    """
    model_path = "/home/fn/MyRL/Simple-work2-CDT/OSRL-CDT/examples/Mymodel/cdt_cost_mse.pt"
    data_root = "/home/fn/MyRL/Process_Expert_Data"
    device = "cpu"
    threads = 4
    state_dim = 48
    action_dim = 2
    num_train_steps = 10000
    # Passed as `num=` to trainer.get_expert; presumably the number of expert
    # samples/trajectories drawn per step -- TODO confirm in CDTTrainer.
    num_expert_samples = 500

    # setup logger
    cfg = asdict(args)
    default_cfg = asdict(CDT_DEFAULT_CONFIG[args.task]())
    if args.name is None:
        args.name = auto_name(default_cfg, cfg, args.prefix, args.suffix)
    if args.group is None:
        args.group = args.task + "-cost-" + str(int(args.cost_limit))
    if args.logdir is not None:
        args.logdir = os.path.join(args.logdir, args.group, args.name)
    logger = WandbLogger(cfg, args.project, args.group, args.name, args.logdir)
    logger.save_config(cfg, verbose=args.verbose)

    if device == "cpu":
        torch.set_num_threads(threads)

    # NOTE(review): equals -action_dim for this setup; confirm whether it should track action_dim.
    target_entropy = -2

    data = _load_pickle(os.path.join(data_root, "idve_train_cdt.pkl"))
    # NOTE(review): the validation split is loaded but not used anywhere below.
    data_val = _load_pickle(os.path.join(data_root, "idve_val_cdt.pkl"))
    data_expert = _load_pickle(os.path.join(data_root, "idve_expert_data_cdt.pkl"))

    cdt_model = CDT(
        state_dim=state_dim,
        action_dim=action_dim,
        max_action=[1, 1],
        embedding_dim=128,
        # NOTE(review): hard-coded here, while SequenceDataset below uses args.seq_len.
        seq_len=10,
        episode_len=300,
        num_layers=3,
        num_heads=8,
        attention_dropout=0.1,
        residual_dropout=0.1,
        embedding_dropout=0.1,
        time_emb=True,
        use_rew=True,
        use_cost=True,
        cost_transform=True,
        add_cost_feat=False,
        mul_cost_feat=False,
        cat_cost_feat=False,
        action_head_layers=1,
        cost_prefix=False,
        stochastic=True,
        init_temperature=0.1,
        target_entropy=target_entropy,
    )
    # map_location lets a checkpoint that was saved from a GPU load on this device.
    cdt_model.load_state_dict(torch.load(model_path, map_location=device))
    cdt_model.to(device)
    costnet = CostNet(state_dim=state_dim, action_dim=action_dim)
    trainer = CDTTrainer(cdt_model,
                         costnet,
                         reward_scale=0.1,
                         cost_scale=1,
                         cost_reverse=False,
                         device=device)

    def cost_transform(x):
        # Linear or reciprocal cost transform, selected by args.linear at call time.
        return 70 - x if args.linear else 1 / (x + 10)

    dataset = SequenceDataset(
        data,
        seq_len=args.seq_len,
        reward_scale=args.reward_scale,
        cost_scale=args.cost_scale,
        deg=args.deg,
        pf_sample=args.pf_sample,
        max_rew_decrease=args.max_rew_decrease,
        beta=args.beta,
        augment_percent=args.augment_percent,
        cost_reverse=args.cost_reverse,
        max_reward=args.max_reward,
        min_reward=args.min_reward,
        pf_only=args.pf_only,
        rmin=args.rmin,
        cost_bins=args.cost_bins,
        npb=args.npb,
        cost_sample=args.cost_sample,
        cost_transform=cost_transform,
        start_sampling=args.start_sampling,
        prob=args.prob,
        random_aug=args.random_aug,
        aug_rmin=args.aug_rmin,
        aug_rmax=args.aug_rmax,
        aug_cmin=args.aug_cmin,
        aug_cmax=args.aug_cmax,
        cgap=args.cgap,
        rstd=args.rstd,
        cstd=args.cstd,
    )

    # Kept for the (currently disabled) minibatch path that sampled from this
    # loader and called trainer.getcost. It is not iterated eagerly, so no
    # worker processes are spawned while unused.
    trainloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        pin_memory=device != "cpu",  # pinning only helps host->GPU copies
        num_workers=args.num_workers,
    )

    max_cost, max_return = _max_initial_returns(dataset.dataset)
    print(max_cost, max_return)

    for _ in trange(num_train_steps, desc="Training"):
        expert_s, expert_a, start_s, length = trainer.get_expert(
            data_expert, num=num_expert_samples)
        # Generate trajectories that start from start_s and have length `length`.
        optimal_s, optimal_a = trainer.get_optimal(start_s, length)
        trainer.train_cost_net(expert_s, expert_a, optimal_s, optimal_a)

if __name__ == "__main__":
    # No argument is passed: the @pyrallis.wrap() decorator on train() builds
    # the CDTTrainConfig (from the command line / config) and supplies it.
    train()