import gym
import numpy as np
import torch
import wandb
from torch.nn import functional as F  # noqa
import argparse
import dill as pickle
import random
import sys
print(sys.path)
from GDT.decision_transformer2.evaluation.evaluate_episodes import evaluate_episode_bc,evaluate_episode_rtg_action,my_evaluate_episode_rtg,evaluate_episode_rtg_2, evaluate_episode_rtg
from GDT.decision_transformer2.models.decision_transformer import DecisionTransformer
from GDT.decision_transformer2.models.mlp_bc import MLPBCModel
from GDT.decision_transformer2.training.act_trainer import ActTrainer
from GDT.decision_transformer2.training.seq_trainer import SequenceTrainer
import pandas as pd
def discount_cumsum(x, gamma):
    """Return the discounted reverse cumulative sum of ``x`` along axis 0.

    ``out[t] = x[t] + gamma * out[t + 1]`` with ``out[-1] = x[-1]``, i.e. the
    discounted return-to-go when ``x`` holds per-step rewards.

    Args:
        x: Array-like with at least one dimension; axis 0 is the time axis.
        gamma: Discount factor applied per step.

    Returns:
        An array with the same shape as ``x``. Floating/complex inputs keep
        their dtype; any other dtype (e.g. integer rewards) is accumulated in
        ``float64`` so that a fractional ``gamma`` is not truncated.
        An empty input yields an empty output.
    """
    x = np.asarray(x)
    # Integer accumulators would silently truncate gamma-weighted terms.
    out_dtype = x.dtype if np.issubdtype(x.dtype, np.inexact) else np.float64
    out = np.zeros(x.shape, dtype=out_dtype)
    if x.shape[0] == 0:
        return out
    out[-1] = x[-1]
    for t in reversed(range(x.shape[0] - 1)):
        out[t] = x[t] + gamma * out[t + 1]
    return out

def get_val_path(path,batch_size=20):
    """Sample ``batch_size`` trajectories from ``path`` uniformly, with replacement.

    Args:
        path: Indexable collection of trajectories (must support ``len``).
        batch_size: Number of trajectories to draw.

    Returns:
        A list of ``batch_size`` entries of ``path``; duplicates are possible
        because sampling is done with replacement.
    """
    candidate_ids = np.arange(len(path))
    chosen = np.random.choice(candidate_ids, size=batch_size, replace=True)
    return [path[chosen[k]] for k in range(batch_size)]
class train_generate_data:
    """Build and train a ``DecisionTransformer`` on offline trajectories.

    The constructor gathers per-trajectory statistics, selects the
    highest-return trajectories permitted by ``pct_traj``, and wires together
    the model, AdamW optimizer, linear warm-up scheduler and
    ``SequenceTrainer``. ``get_batch`` and the closures produced by
    ``eval_episodes`` are the callbacks handed to the trainer; ``train`` runs
    the outer loop and logs each iteration's outputs to wandb.
    """

    def __init__(self,trajectories,trajectories_val,variant,wandb,vent):
        """Prepare data statistics, model, optimizer, scheduler and trainer.

        Args:
            trajectories: Non-empty sequence of training trajectories. Each
                one is indexed by key and must provide ``'observations'`` and
                ``'rewards'`` arrays here, plus ``'actions'`` and either
                ``'terminals'`` or ``'dones'`` for ``get_batch``.
            trajectories_val: Validation trajectories sampled during
                evaluation.
            variant: Mapping of hyper-parameters. Required keys:
                ``G_max_ep_len``, ``G_env_targets``, ``state_dim``,
                ``act_dim``, ``act_max``, ``G_K``, ``G_batch_size``,
                ``G_num_eval_episodes``, ``G_embed_dim``, ``G_n_layer``,
                ``G_n_head``, ``G_activation_function``, ``G_dropout``,
                ``G_warmup_steps``, ``G_learning_rate``, ``G_weight_decay``
                (and ``G_num_steps_per_iter`` for ``train``). Optional keys:
                ``device`` (default ``'cuda'``), ``mode`` (default
                ``'normal'``), ``pct_traj`` (default ``1.``).
            wandb: Unused. Kept for interface compatibility; ``train`` logs
                through the module-level ``wandb`` import.
            vent: Forwarded unchanged to ``DecisionTransformer`` and to
                ``SequenceTrainer.train_iteration``; its meaning is defined
                by those classes.

        Raises:
            ValueError: If ``trajectories`` is empty.
        """
        self.variant = variant
        self.device = self.variant.get('device', 'cuda')

        self.max_ep_len = variant['G_max_ep_len']
        self.env_targets = [variant['G_env_targets']]
        self.scale = 1.

        self.state_dim = variant['state_dim']
        self.act_dim = variant['act_dim']
        self.act_max = variant['act_max']

        self.trajectories = trajectories
        self.trajectories_val = trajectories_val

        self.vent = vent

        if len(self.trajectories) == 0:
            raise ValueError("trajectories must contain at least one trajectory")

        self.mode = variant.get('mode', 'normal')
        self.states, self.traj_lens, self.returns = [], [], []
        for path in self.trajectories:
            if self.mode == 'delayed':
                # Delayed mode: move the whole return onto the final step.
                # NOTE: this rewrites the caller's reward arrays in place.
                path['rewards'][-1] = path['rewards'].sum()
                path['rewards'][:-1] = 0.
            self.states.append(path['observations'])
            self.traj_lens.append(len(path['observations']))
            self.returns.append(path['rewards'].sum())
        self.traj_lens, self.returns = np.array(self.traj_lens), np.array(self.returns)

        self.states = np.concatenate(self.states, axis=0)
        # Input normalization is deliberately the identity (mean 0, std 1);
        # the data-driven statistics are not used.
        self.state_mean, self.state_std = np.array(0),np.array(1)

        num_timesteps = sum(self.traj_lens)

        self.K = variant['G_K']
        self.batch_size = variant['G_batch_size']
        self.num_eval_episodes = variant['G_num_eval_episodes']
        self.pct_traj = variant.get('pct_traj', 1.)

        # Only train on the top pct_traj fraction of timesteps (for %BC
        # experiments): walk trajectories from highest to lowest return and
        # keep adding them while the timestep budget is not exceeded. The
        # single best trajectory is always kept.
        self.num_timesteps = max(int(self.pct_traj*num_timesteps), 1)
        self.sorted_inds = np.argsort(self.returns)  # trajectory indices, lowest to highest return
        self.num_trajectories = 1
        self.timesteps = self.traj_lens[self.sorted_inds[-1]]
        self.ind = len(self.trajectories) - 2
        while self.ind >= 0 and self.timesteps + self.traj_lens[self.sorted_inds[self.ind]] <= self.num_timesteps:
            self.timesteps += self.traj_lens[self.sorted_inds[self.ind]]
            self.num_trajectories += 1
            self.ind -= 1
        self.sorted_inds = self.sorted_inds[-self.num_trajectories:]

        # Sampling weights proportional to trajectory length, so batches are
        # drawn per timestep rather than per trajectory.
        self.p_sample = self.traj_lens[self.sorted_inds] / sum(self.traj_lens[self.sorted_inds])
        self.model = DecisionTransformer(
            state_dim=self.state_dim,
            act_dim=self.act_dim,
            act_max=self.act_max,
            max_length=self.K,
            max_ep_len=self.max_ep_len,
            target_entropy = -self.act_dim,
            hidden_size=variant['G_embed_dim'],
            n_layer=variant['G_n_layer'],
            n_head=variant['G_n_head'],
            n_inner=4*variant['G_embed_dim'],
            activation_function=variant['G_activation_function'],
            n_positions=1024,
            resid_pdrop=variant['G_dropout'],
            attn_pdrop=variant['G_dropout'],
            vent = self.vent
        )

        self.model = self.model.to(device=self.device)

        self.warmup_steps = variant['G_warmup_steps']
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=variant['G_learning_rate'],
            weight_decay=variant['G_weight_decay'],
        )
        # Linear learning-rate warm-up to the base rate. The max(..., 1)
        # guard makes warmup_steps <= 0 mean "no warm-up" instead of a
        # division by zero.
        self.scheduler = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer,
            lambda steps: min((steps+1)/max(self.warmup_steps, 1), 1)
        )
        self.trainer = SequenceTrainer(
            model=self.model,
            optimizer=self.optimizer,
            batch_size=self.batch_size,
            get_batch=self.get_batch,
            scheduler=self.scheduler,
            # Only the action prediction is penalized (mean squared error).
            loss_fn=lambda s_hat, a_hat, r_hat, s, a, r: torch.mean((a_hat - a)**2),
            eval_fns=[self.eval_episodes(tar) for tar in self.env_targets],
        )

    def get_batch(self,batch_size=256):
        """Sample a batch of left-padded, fixed-length training windows.

        Trajectories are drawn (with replacement) from the selected top
        trajectories with probability proportional to their length, and a
        random start index is chosen inside each one. From the window of up
        to ``K`` steps starting there, every sequence except ``s_next`` drops
        one step and is shifted right by one position, so that at a given
        position ``s``/``a``/``r``/``rtg``/``timesteps`` describe step ``t``
        while ``s_next`` and ``d`` describe step ``t + 1``.

        Padding values on the left: zeros for states, rewards, returns-to-go
        and timesteps, ``-10`` for actions, ``2`` for done flags. ``mask`` is
        1 on real (non-padded) positions and 0 elsewhere; when the sampled
        start is the last step of a trajectory the window holds no
        transition and its mask is all zeros.

        Args:
            batch_size: Number of windows to sample.

        Returns:
            Tuple ``(s, a, r, d, rtg, timesteps, mask, s_next)`` of tensors on
            ``self.device``, each with leading shape ``(batch_size, K)``;
            ``s``/``s_next`` have a trailing ``state_dim`` axis, ``a`` a
            trailing ``act_dim`` axis, ``r``/``rtg`` a trailing axis of 1.
            ``d`` and ``timesteps`` are ``long``, ``mask`` keeps NumPy's
            default float dtype, the rest are ``float32``.
        """
        max_len = self.K
        batch_inds = np.random.choice(
            np.arange(self.num_trajectories),
            size=batch_size,
            replace=True,
            p=self.p_sample,  # reweights so we sample according to timesteps
        )

        s, s_next, a, r, d, rtg, timesteps, mask = [], [], [], [], [], [], [], []
        for traj_ind in batch_inds:
            traj = self.trajectories[int(self.sorted_inds[traj_ind])]
            si = random.randint(0, traj['rewards'].shape[0] - 1)
            window = slice(si, si + max_len)

            # Raw (unpadded) window, each with a leading batch axis of 1.
            obs_next = traj['observations'][window].reshape(1, -1, self.state_dim)
            act = traj['actions'][window].reshape(1, -1, self.act_dim)
            rew = traj['rewards'][window].reshape(1, -1, 1)
            done_key = 'terminals' if 'terminals' in traj else 'dones'
            done = traj[done_key][window].reshape(1, -1)
            tlen = obs_next.shape[1]

            steps = np.arange(si, si + tlen).reshape(1, -1)
            steps[steps >= self.max_ep_len] = self.max_ep_len-1  # padding cutoff

            # Undiscounted return-to-go from each step of the window.
            ret = discount_cumsum(traj['rewards'][si:], gamma=1.)[:tlen].reshape(1, -1, 1)
            if ret.shape[1] <= tlen-1:
                # Only reachable if there are fewer rewards than observations.
                ret = np.concatenate([ret, np.zeros((1, 1, 1))], axis=1)

            # Left-pad the next-state window to max_len and normalize it.
            obs_next = np.concatenate([np.zeros((1, max_len - tlen, self.state_dim)), obs_next], axis=1)
            obs_next = (obs_next - self.state_mean) / self.state_std

            # Current states are the next-state window shifted right by one.
            obs = obs_next[:,:-1,:]
            obs = np.concatenate([np.zeros((1, max_len - obs.shape[1], self.state_dim)), obs], axis=1)

            # Everything else drops one step, hence one extra padding slot.
            pad = max_len - tlen+1
            act = np.concatenate([np.ones((1, pad, self.act_dim)) * -10., act[:,:-1, :]], axis=1)
            rew = np.concatenate([np.zeros((1, pad, 1)), rew[:,:-1, :]], axis=1)
            # Done flags drop the FIRST step (not the last), so they line up
            # with s_next rather than with s.
            done = np.concatenate([np.ones((1, pad)) * 2, done[:,1:]], axis=1)
            ret = np.concatenate([np.zeros((1, pad, 1)), ret[:,:-1, :]], axis=1) / self.scale
            steps = np.concatenate([np.zeros((1, pad)), steps[:,:-1]], axis=1)
            valid = np.concatenate([np.zeros((1, pad)), np.ones((1, tlen-1))], axis=1)

            s.append(obs)
            s_next.append(obs_next)
            a.append(act)
            r.append(rew)
            d.append(done)
            rtg.append(ret)
            timesteps.append(steps)
            mask.append(valid)

        s = torch.from_numpy(np.concatenate(s, axis=0)).to(dtype=torch.float32, device=self.device)
        s_next = torch.from_numpy(np.concatenate(s_next, axis=0)).to(dtype=torch.float32, device=self.device)
        a = torch.from_numpy(np.concatenate(a, axis=0)).to(dtype=torch.float32, device=self.device)
        r = torch.from_numpy(np.concatenate(r, axis=0)).to(dtype=torch.float32, device=self.device)
        d = torch.from_numpy(np.concatenate(d, axis=0)).to(dtype=torch.long, device=self.device)
        rtg = torch.from_numpy(np.concatenate(rtg, axis=0)).to(dtype=torch.float32, device=self.device)
        timesteps = torch.from_numpy(np.concatenate(timesteps, axis=0)).to(dtype=torch.long, device=self.device)
        mask = torch.from_numpy(np.concatenate(mask, axis=0)).to(device=self.device)

        return s, a, r, d, rtg, timesteps, mask,s_next

    def eval_episodes(self,target_rew):
        """Create an evaluation callback for the given target return.

        Args:
            target_rew: Target return passed (divided by ``self.scale``) to
                ``my_evaluate_episode_rtg``.

        Returns:
            A function ``fn(model)`` that evaluates ``model`` on random
            samples of the validation trajectories and returns a dict of
            metrics.
        """
        def fn(model):
            """Run ``num_eval_episodes`` evaluation rounds and report metrics.

            Each round samples 64 validation trajectories (with replacement)
            and rolls the model out under ``torch.no_grad()``.

            NOTE(review): every round runs, but only the LAST round's results
            are reported; earlier rounds are discarded. Averaging over rounds
            is probably what was intended -- confirm the types returned by
            ``my_evaluate_episode_rtg`` before changing this.

            Raises:
                ValueError: If ``num_eval_episodes`` is less than 1 (there
                    would be nothing to report).
            """
            if self.num_eval_episodes < 1:
                raise ValueError("G_num_eval_episodes must be at least 1 to evaluate")

            for _ in range(self.num_eval_episodes):
                val_df = get_val_path(self.trajectories_val,batch_size=64)
                with torch.no_grad():
                    (reward_agent, reward_phy, state_agent, state_phy,
                     action_agent, action_phy, die, act_loss,
                     iv_sum_gap, vaso_sum_gap, vaso_delta_gap) = my_evaluate_episode_rtg(
                        self.state_dim,
                        self.act_dim,
                        model,
                        val_df,
                        max_ep_len=self.max_ep_len,
                        scale=self.scale,
                        target_return=target_rew/self.scale,
                        mode=self.mode,
                        state_mean=self.state_mean,
                        state_std=self.state_std,
                        device=self.device,
                    )

                reward_agent_t = torch.tensor(reward_agent)
                reward_phy_t = torch.tensor(reward_phy)

                # Mean squared gaps between the agent rollout and the
                # reference ("phy") values returned by the evaluator.
                reward_loss = F.mse_loss(reward_agent_t, reward_phy_t, reduction="none").mean()
                state_loss = F.mse_loss(torch.tensor(state_agent),torch.tensor(state_phy),reduction="none").mean()
                action_loss = F.mse_loss(torch.tensor(action_agent),torch.tensor(action_phy),reduction="none").mean()
                gap_reward = reward_agent_t.mean()-reward_phy_t.mean()
            return {
                'reward_loss_eval':reward_loss,
                'state_loss_eval': state_loss,
                'action_loss_eval': action_loss,
                'gap_reward_a_p':gap_reward,
                'act_loss_eval': act_loss,
                'sum_agent-phy_iv': iv_sum_gap,
                'sum_agent-phy_vaso': vaso_sum_gap,
                'delta_agent-phy_vaso':vaso_delta_gap,
            }
        return fn

    def train(self,max_iters):
        """Run ``max_iters`` training iterations, logging each to wandb.

        Args:
            max_iters: Number of calls to ``SequenceTrainer.train_iteration``;
                iteration numbers passed to the trainer start at 1.
        """
        for iteration in range(1, max_iters + 1):
            outputs = self.trainer.train_iteration(self.vent,num_steps=self.variant['G_num_steps_per_iter'], iter_num=iteration, print_logs=True)
            wandb.log(outputs)