import os
import uuid
from dataclasses import asdict, dataclass
from typing import Any, DefaultDict, Dict, List, Optional, Tuple

import numpy as np

import torch

import sys
sys.path.append("/home/fn/MyRL/Our-Model")
from torch.utils.data import DataLoader
from tqdm.auto import trange  # noqa

from CDT.examples_cdt.configs.cdt_configs import CDT_DEFAULT_CONFIG, CDTTrainConfig
from CDT.osrl_cdt.algorithms import CDT, CDTTrainer
from CDT.osrl_cdt.common import SequenceDataset
from CDT.osrl_cdt.common.exp_util import auto_name, seed_all
import pickle
from ICRL.evaluation.evaluate_episodes import evaluation_cost_tocdt
def get_val_path(path, batch_size=64):
    """Sample a random validation batch from ``path``, with replacement.

    ``batch_size`` indices are drawn uniformly (duplicates allowed) from
    ``range(len(path['rewards']))``. The same indices gather the
    ``'actions'``, ``'next_observations'``, ``'observations'``,
    ``'rewards'``, ``'terminals'``, ``'costs'`` and ``'dieds'`` entries.
    Each gathered list is stacked with ``np.array``.

    Args:
        path: Mapping from the keys above to indexable sequences. The
            sequences are presumably equal-length, one entry per sample;
            confirm against the caller.
        batch_size: Number of entries to draw.

    Returns:
        dict mapping each key above (in that order) to a NumPy array built
        from the ``batch_size`` sampled entries.
    """
    num_trajectories = len(path['rewards'])
    picks = np.random.choice(
        np.arange(num_trajectories),
        size=batch_size,
        replace=True,
    )
    keys = (
        'actions',
        'next_observations',
        'observations',
        'rewards',
        'terminals',
        'costs',
        'dieds',
    )
    return {key: np.array([path[key][idx] for idx in picks]) for key in keys}

class train_cdt:
    """Train a Constrained Decision Transformer (CDT) on an offline dataset.

    Builds the ``CDT`` model, a ``CDTTrainer`` and a ``SequenceDataset`` /
    ``DataLoader`` pipeline over ``data``. Hyper-parameters are read from
    ``CDTTrainConfig``.

    ``update_train_cost`` overwrites the per-step costs of a random subset of
    training trajectories with predictions from an external cost model.
    ``train`` runs the optimisation loop and periodically evaluates on
    batches sampled from ``data_val``.
    """

    def __init__(self, data, data_val, variant, wandb, vent):
        """Build the model, trainer and training data pipeline.

        Args:
            data: Offline training data handed to ``SequenceDataset``.
            data_val: Validation data that ``train`` samples through
                ``get_val_path``. It is a mapping with ``'actions'``,
                ``'next_observations'``, ``'observations'``, ``'rewards'``,
                ``'terminals'``, ``'costs'`` and ``'dieds'`` entries.
            variant: Mapping providing ``'state_dim'``, ``'act_dim'`` and
                ``'act_max'``.
            wandb: Logger exposing ``log(dict)``. Presumably the wandb module
                or a wandb run.
            vent: Forwarded unchanged to ``CDT`` and ``CDTTrainer``.
        """
        # NOTE(review): this binds the config *class*, not an instance, so
        # every hyper-parameter is read as a class-level default.
        self.args = CDTTrainConfig
        seed_all(self.args.seed)
        if self.args.device == "cpu":
            torch.set_num_threads(self.args.threads)

        self.wandb = wandb
        self.device = self.args.device

        self.data = data
        self.data_val = data_val
        self.vent = vent

        self.state_dim = variant['state_dim']
        self.act_dim = variant['act_dim']
        self.max_action = variant['act_max']

        self.model = CDT(
            state_dim=self.state_dim,
            action_dim=self.act_dim,
            max_action=self.max_action,
            embedding_dim=self.args.embedding_dim,
            seq_len=self.args.seq_len,
            episode_len=self.args.episode_len,
            num_layers=self.args.num_layers,
            num_heads=self.args.num_heads,
            attention_dropout=self.args.attention_dropout,
            residual_dropout=self.args.residual_dropout,
            embedding_dropout=self.args.embedding_dropout,
            time_emb=self.args.time_emb,
            use_rew=self.args.use_rew,
            use_cost=self.args.use_cost,
            cost_transform=self.args.cost_transform,
            add_cost_feat=self.args.add_cost_feat,
            mul_cost_feat=self.args.mul_cost_feat,
            cat_cost_feat=self.args.cat_cost_feat,
            action_head_layers=self.args.action_head_layers,
            cost_prefix=self.args.cost_prefix,
            stochastic=self.args.stochastic,
            init_temperature=self.args.init_temperature,
            target_entropy=-self.act_dim,
            vent=self.vent,
        ).to(self.args.device)

        self.trainer = CDTTrainer(
            self.model,
            learning_rate=self.args.learning_rate,
            weight_decay=self.args.weight_decay,
            betas=self.args.betas,
            clip_grad=self.args.clip_grad,
            lr_warmup_steps=self.args.lr_warmup_steps,
            reward_scale=self.args.reward_scale,
            cost_scale=self.args.cost_scale,
            loss_cost_weight=self.args.loss_cost_weight,
            loss_state_weight=self.args.loss_state_weight,
            cost_reverse=self.args.cost_reverse,
            no_entropy=self.args.no_entropy,
            device=self.args.device,
            vent=self.vent,
        )

        # Cost transform used to compute sampling weights, so that low-cost
        # data is sampled preferentially. Both branches decrease as x grows.
        self.ct = lambda x: 70 - x if self.args.linear else 1 / (x + 10)

        self.dataset = SequenceDataset(
            self.data,
            seq_len=self.args.seq_len,
            reward_scale=self.args.reward_scale,
            cost_scale=self.args.cost_scale,
            deg=self.args.deg,
            pf_sample=self.args.pf_sample,
            max_rew_decrease=self.args.max_rew_decrease,
            beta=self.args.beta,
            augment_percent=self.args.augment_percent,
            cost_reverse=self.args.cost_reverse,
            max_reward=self.args.max_reward,
            min_reward=self.args.min_reward,
            pf_only=self.args.pf_only,
            rmin=self.args.rmin,
            cost_bins=self.args.cost_bins,
            npb=self.args.npb,
            cost_sample=self.args.cost_sample,
            cost_transform=self.ct,
            start_sampling=self.args.start_sampling,
            prob=self.args.prob,
            random_aug=self.args.random_aug,
            aug_rmin=self.args.aug_rmin,
            aug_rmax=self.args.aug_rmax,
            aug_cmin=self.args.aug_cmin,
            aug_cmax=self.args.aug_cmax,
            cgap=self.args.cgap,
            rstd=self.args.rstd,
            cstd=self.args.cstd,
        )

        self.trainloader = DataLoader(
            self.dataset,
            batch_size=self.args.batch_size,
            pin_memory=True,
            num_workers=self.args.num_workers,
        )
        self.trainloader_iter = iter(self.trainloader)

        # Bookkeeping for "best checkpoint" tracking. These are not updated
        # anywhere in this class.
        self.best_reward = -np.inf
        self.best_cost = np.inf
        self.best_idx = 0

        # Largest first-step cost return / return over all stored
        # trajectories. The running maxima start at 0.
        self.maxcost = 0
        self.maxreturn = 0
        for traj in self.dataset.dataset:
            self.maxcost = max(self.maxcost, traj['cost_returns'][0])
            self.maxreturn = max(self.maxreturn, traj['returns'][0])
        print(self.maxcost, self.maxreturn)

    def get_no_mask_tolist(self, data, mask):
        """Strip left padding from each row of ``data`` and flatten the rest.

        For each row, the first index where ``mask`` reaches its row maximum
        (``torch.argmax``) is located. For a left-padded 0/1 mask this is the
        first valid position. Everything from that index onward is kept.
        The kept pieces of all rows are concatenated into one 1-D NumPy
        array.
        """
        first_one_index = torch.argmax(mask, dim=1)
        kept = [row[index:].view(-1).tolist() for row, index in zip(data, first_one_index)]
        return np.array([item for sublist in kept for item in sublist])

    def discounted_cumsum(self, x: np.ndarray, gamma: float) -> np.ndarray:
        """Return the reverse discounted cumulative sum of ``x``.

        ``out[t] = x[t] + gamma * out[t + 1]``. The result has the same shape
        and dtype as ``x`` (``np.zeros_like``). An empty input yields an
        empty result.
        """
        cumsum = np.zeros_like(x)
        if x.shape[0] == 0:
            # Nothing to accumulate; indexing cumsum[-1] would raise.
            return cumsum
        cumsum[-1] = x[-1]
        for t in reversed(range(x.shape[0] - 1)):
            cumsum[t] = x[t] + gamma * cumsum[t + 1]
        return cumsum

    def get_one_path(self, icrl, traj):
        """Build a batch of size 1 from one trajectory for the cost model.

        Sequences shorter than ``icrl.K`` are left-padded as follows:
        states with 0, actions with -10, ``d`` with 2, timesteps with 0 and
        mask with 0. Longer sequences are NOT truncated. Timesteps are
        clipped to ``icrl.max_ep_len - 1``. ``d`` marks the last real step
        with 1 and the other real steps with 0.

        Returns:
            ``(s, a, d, timesteps, mask)`` tensors on ``self.args.device``.
            ``s`` and ``a`` are float32, ``d`` and ``timesteps`` are int64,
            and ``mask`` keeps the float64 dtype of its NumPy source.
        """
        l = len(traj['observations'])
        s, a, timesteps, mask, d = [], [], [], [], []
        max_len = icrl.K
        s.append(traj['observations'].reshape(1, -1, self.state_dim).astype(np.float64))

        a.append(traj['actions'].reshape(1, -1, self.act_dim))

        d.append(np.array([0] * (l - 1) + [1]).reshape(1, -1))

        timesteps.append(np.arange(0, s[-1].shape[1]).reshape(1, -1))
        timesteps[-1][timesteps[-1] >= icrl.max_ep_len] = icrl.max_ep_len - 1  # padding cutoff

        # Left-pad every sequence up to max_len.
        tlen = s[-1].shape[1]
        if max_len > tlen:
            s[-1] = np.concatenate([np.zeros((1, max_len - tlen, self.state_dim)), s[-1]], axis=1)
            a[-1] = np.concatenate([np.ones((1, max_len - tlen, self.act_dim)) * -10., a[-1]], axis=1)
            d[-1] = np.concatenate([np.ones((1, max_len - tlen)) * 2, d[-1]], axis=1)
            timesteps[-1] = np.concatenate([np.zeros((1, max_len - tlen)), timesteps[-1]], axis=1)
            mask.append(np.concatenate([np.zeros((1, max_len - tlen)), np.ones((1, tlen))], axis=1))
        else:
            mask.append(np.concatenate([np.ones((1, tlen))], axis=1))

        s = torch.from_numpy(np.concatenate(s, axis=0)).to(dtype=torch.float32, device=self.args.device)
        a = torch.from_numpy(np.concatenate(a, axis=0)).to(dtype=torch.float32, device=self.args.device)
        d = torch.from_numpy(np.concatenate(d, axis=0)).to(dtype=torch.long, device=self.args.device)
        timesteps = torch.from_numpy(np.concatenate(timesteps, axis=0)).to(dtype=torch.long, device=self.args.device)
        mask = torch.from_numpy(np.concatenate(mask, axis=0)).to(device=self.args.device)
        return s, a, d, timesteps, mask

    def update_train_cost(self, icrl, update_batch=0.4):
        """Relabel the costs of a random subset of training trajectories.

        A fraction ``update_batch`` (default 40%) of the trajectories in
        ``self.dataset.dataset`` is chosen without replacement. For each one:

        * ``icrl.model`` predicts per-step costs via ``evaluation_cost_tocdt``.
        * Padded positions are stripped and the result replaces ``'costs'``.
        * ``'cost_returns'`` is recomputed as the undiscounted reverse
          cumulative sum (float32).

        Afterwards the training iterator is recreated. Presumably this is so
        that DataLoader workers see the modified data; confirm.

        Args:
            icrl: Object providing ``model``, ``K`` and ``max_ep_len``.
            update_batch: Fraction of trajectories to relabel. Values above 1
                make ``np.random.choice`` raise, because sampling is done
                without replacement.
        """
        trajectories = self.dataset.dataset
        num_trajs = len(trajectories)
        update_len = int(update_batch * num_trajs)
        indices = np.random.choice(num_trajs, update_len, replace=False)

        for idx in indices:
            traj = trajectories[idx]
            s, a, _, timesteps, mask = self.get_one_path(icrl, traj)
            sum_pred_expert = evaluation_cost_tocdt(icrl.model, use_weighted_sum=True,
                                                    train_type='every',
                                                    states=s,
                                                    actions=a,
                                                    timesteps=timesteps,
                                                    attention_mask=mask,)
            traj['costs'] = self.get_no_mask_tolist(sum_pred_expert, mask)
            traj["cost_returns"] = self.discounted_cumsum(traj['costs'], gamma=1).astype(np.float32)
        self.trainloader_iter = iter(self.trainloader)

    def checkpoint_fn(self):
        """Return a checkpoint dict holding the model's ``state_dict``."""
        return {"model_state": self.model.state_dict()}

    def train(self):
        """Run ``self.args.update_steps`` training steps with periodic evaluation.

        Each step draws a batch from the training loader and runs
        ``CDTTrainer.train_one_step``, then logs the returned dict. If the
        loader iterator is exhausted, a new pass is started.

        Evaluation runs every ``eval_every`` steps and on the final step.
        It evaluates once per ``(reward_return, cost_return)`` pair in
        ``target_returns``, each time on a fresh random validation batch,
        and logs the results under the ``CDT_eval/*`` keys.
        """
        for step in trange(self.args.update_steps, desc="Training"):
            try:
                batch = next(self.trainloader_iter)
            except StopIteration:
                # A finite loader exhausts its iterator; start a new pass
                # instead of aborting training.
                self.trainloader_iter = iter(self.trainloader)
                batch = next(self.trainloader_iter)
            states, actions, returns, costs_return, time_steps, mask, episode_cost, costs = [
                b.to(self.args.device).to(dtype=torch.float32) for b in batch
            ]
            time_steps = time_steps.to(dtype=torch.int32)

            trainlog = self.trainer.train_one_step(states, actions, returns, costs_return,
                                                   time_steps, mask, episode_cost, costs)
            self.wandb.log(trainlog)

            is_last_step = step == self.args.update_steps - 1
            if (step + 1) % self.args.eval_every == 0 or is_last_step:
                logs = {}
                for reward_return, cost_return in self.args.target_returns:
                    val_df = get_val_path(self.data_val)
                    if self.args.cost_reverse:
                        # Critical step: rescale the target cost relative to
                        # the episode length when costs are reversed.
                        target_cost = self.args.episode_len - cost_return
                    else:
                        target_cost = cost_return
                    ret, cost, length = self.trainer.evaluate(
                        val_df,
                        self.args.eval_episodes,
                        reward_return * self.args.reward_scale,
                        target_cost * self.args.cost_scale,
                    )
                    # NOTE(review): earlier comments in this file labelled
                    # these metrics "phy action" (cost), "agent action" (ret)
                    # and "mse" (length); confirm against
                    # CDTTrainer.evaluate. All target pairs log under the
                    # same keys, so they are distinguished only by log step.
                    logs['CDT_eval/cost:'] = cost
                    logs['CDT_eval/ret:'] = ret
                    logs['CDT_eval/length:'] = length
                    self.wandb.log(logs)