import gym
import numpy as np
import wandb
from torch.nn import functional as F  # noqa
import argparse
import dill as pickle
import random
import sys
import torch
from ICRL.models.DT_ICRL import DecisionTransformer_icrl
from ICRL.models.mlp_bc import MLPBCModel
from ICRL.training.act_trainer import ActTrainer
from ICRL.training.seq_trainer import SequenceTrainer
from ICRL.evaluation.evaluate_episodes import evaluation_cost
import pandas as pd
from ICRL.pre.get_data import process_data,discount_cumsum,get_val_path

# initialize violation, expert, and training data
# train ICRL
# generate training data
# train CDT
# generate violation data with DT

class train_icrl:
    """Driver that builds and trains the ICRL model (``DecisionTransformer_icrl``).

    The instance keeps four trajectory collections (expert / violation, each
    with a training and a validation split), pre-computes their sampling
    statistics with ``process_data`` and exposes ``get_batch_e`` /
    ``get_batch_o`` as batch providers for ``SequenceTrainer``.

    Every trajectory is indexed like a dict of arrays with the keys
    ``observations``, ``actions``, ``rewards`` and ``terminals`` (or
    ``dones`` when ``terminals`` is absent).
    """

    def __init__(self, trajectories_expert, trajectories_expert_val, trajectories_violation, trajectories_obs_val, variant, wandb):
        """Store the datasets and build model, optimizer, scheduler and trainer.

        Args:
            trajectories_expert: expert trajectories used for training batches.
            trajectories_expert_val: expert trajectories used when a batch is
                requested with ``eval=True``.
            trajectories_violation: violation trajectories used for training
                batches; ``set_violation_data`` extends this pool later.
            trajectories_obs_val: violation trajectories used when a batch is
                requested with ``eval=True``.
            variant: configuration mapping. ``state_dim``, ``act_dim``,
                ``act_max`` and the ``ICRL_*`` keys read below are required;
                ``device`` (default ``'cuda'``) and ``mode`` (default
                ``'normal'``) are optional.
            wandb: logger object; only its ``log`` method is used (by ``train``).
        """
        self.device = variant.get('device', 'cuda')
        self.variant = variant
        self.max_ep_len = variant['ICRL_max_ep_len']
        self.env_targets = [20]
        self.scale = 1.

        self.state_dim = variant['state_dim']
        self.act_dim = variant['act_dim']
        self.act_max = variant['act_max']

        # The datasets are supplied by the caller rather than loaded from disk here.
        self.trajectories_obs = trajectories_violation
        self.trajectories_obs_val = trajectories_obs_val
        self.trajectories_expert = trajectories_expert
        self.trajectories_expert_val = trajectories_expert_val

        self.mode = variant.get('mode', 'normal')

        # Sampling statistics for each of the four datasets.
        (self.p_sample_expert, self.states_expert, self.traj_lens_expert,
         self.returns_expert, self.num_trajectories_expert,
         self.sorted_inds_expert) = process_data(self.trajectories_expert, variant, 'expert_data')
        (self.p_sample_obs, self.states_obs, self.traj_lens_obs,
         self.returns_obs, self.num_trajectories_obs,
         self.sorted_inds_obs) = process_data(self.trajectories_obs, variant, 'violate_data')

        (self.p_sample_expert_val, self.states_expert_val, self.traj_lens_expert_val,
         self.returns_expert_val, self.num_trajectories_expert_val,
         self.sorted_inds_expert_val) = process_data(self.trajectories_expert_val, variant, 'expert_data_val')
        (self.p_sample_obs_val, self.states_obs_val, self.traj_lens_obs_val,
         self.returns_obs_val, self.num_trajectories_obs_val,
         self.sorted_inds_obs_val) = process_data(self.trajectories_obs_val, variant, 'violate_data_val')

        self.num_eval_episodes = variant['ICRL_num_eval_episodes']
        self.train_type = variant['ICRL_train_type']
        self.use_weighted_sum = variant['ICRL_use_weighted_sum']
        self.K = variant['ICRL_K']
        self.batch_size = variant['ICRL_batch_size']

        self.model = DecisionTransformer_icrl(
            state_dim=self.state_dim,
            act_dim=self.act_dim,
            max_length=self.K,
            max_ep_len=self.max_ep_len,
            target_entropy=-self.act_dim,
            hidden_size=variant['ICRL_embed_dim'],
            n_layer=variant['ICRL_n_layer'],
            n_head=variant['ICRL_n_head'],
            n_inner=4 * variant['ICRL_embed_dim'],
            activation_function=variant['ICRL_activation_function'],
            n_positions=1024,
            resid_pdrop=variant['ICRL_dropout'],
            attn_pdrop=variant['ICRL_dropout'],
            pre_attn_embd_dim=variant['ICRL_pre_attn_embd_dim'],
            use_weighted_sum=variant['ICRL_use_weighted_sum'],
        )
        self.model = self.model.to(device=self.device)

        # Number of scheduler steps over which the learning rate ramps up linearly.
        self.warmup_steps = variant['ICRL_warmup_steps']
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=variant['ICRL_learning_rate'],
            weight_decay=variant['ICRL_weight_decay'],
        )
        # Scales the learning rate by (step + 1) / warmup_steps, capped at 1.
        # max(..., 1) avoids a division by zero when ICRL_warmup_steps is 0.
        self.scheduler = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer,
            lambda steps: min((steps + 1) / max(self.warmup_steps, 1), 1)
        )

        self.trainer = SequenceTrainer(
            model=self.model,
            optimizer=self.optimizer,
            batch_size=self.batch_size,
            get_batch_e=self.get_batch_e,
            get_batch_o=self.get_batch_o,
            train_type=self.train_type,
            use_weighted_sum=self.use_weighted_sum,
            scheduler=self.scheduler,
            loss_fn=lambda s_hat, a_hat, r_hat, s, a, r: torch.mean((a_hat - a)**2),
            eval_fns=[self.eval_episodes(tar) for tar in self.env_targets],
        )
        self.wandb = wandb

    def set_violation_data(self, violation_data):
        """Add violation trajectories to the training pool and refresh its statistics.

        The pool is capped at 10000 trajectories chosen uniformly at random.
        ``random.sample`` is applied even when nothing is dropped, so the pool
        order is reshuffled on every call.

        Args:
            violation_data: list of trajectories appended to the current pool.
        """
        self.trajectories_obs = self.trajectories_obs + violation_data
        keep = min(10000, len(self.trajectories_obs))
        self.trajectories_obs = random.sample(self.trajectories_obs, keep)
        (self.p_sample_obs, self.states_obs, self.traj_lens_obs,
         self.returns_obs, self.num_trajectories_obs,
         self.sorted_inds_obs) = process_data(self.trajectories_obs, self.variant, 'violate_data')

    def train(self):
        """Run ``ICRL_max_iters`` trainer iterations and log each result to ``self.wandb``."""
        for iteration in range(self.variant['ICRL_max_iters']):
            outputs = self.trainer.train_iteration(
                num_steps=self.variant['ICRL_num_steps_per_iter'],
                iter_num=iteration + 1,
                print_logs=True,
            )
            self.wandb.log(outputs)

    def eval_episodes(self, target_rew):
        """Return an evaluation callback for ``SequenceTrainer``.

        Args:
            target_rew: accepted for interface compatibility; not used.

        Returns:
            A function ``fn(model)`` that evaluates ``evaluation_cost`` on
            ``num_eval_episodes`` freshly sampled batches of 32 expert and 32
            violation windows and returns ``{'loss_eval': <average>}``. It
            returns ``None`` when ``num_eval_episodes`` is 0.
        """
        def fn(model):
            losses = []
            for _ in range(self.num_eval_episodes):
                # NOTE(review): batches come from the training split (eval=False
                # default); confirm whether the validation split was intended.
                states_e, actions_e, _, _, _, timesteps_e, attention_mask_e, _ = self.get_batch_e(batch_size=32)
                states_o, actions_o, _, _, _, timesteps_o, attention_mask_o = self.get_batch_o(batch_size=32)
                with torch.no_grad():
                    losses.append(evaluation_cost(
                        model, self.use_weighted_sum, self.train_type,
                        states_e, actions_e, timesteps_e, attention_mask_e,
                        states_o, actions_o, timesteps_o, attention_mask_o,
                    ))
            if not losses:
                return None
            # Average over every round instead of returning after the first one.
            return {
                'loss_eval': sum(losses) / len(losses),
            }
        return fn

    @staticmethod
    def _draw_indices(num_trajectories, batch_size, p_sample):
        """Sample ``batch_size`` trajectory ranks with replacement, weighted by ``p_sample``."""
        return np.random.choice(
            np.arange(num_trajectories),
            size=batch_size,
            replace=True,
            p=p_sample,
        )

    @staticmethod
    def _left_pad(arr, pad_len, fill):
        """Prepend ``pad_len`` entries equal to ``fill`` along axis 1 of ``arr``."""
        pad_shape = (arr.shape[0], pad_len) + arr.shape[2:]
        return np.concatenate([np.full(pad_shape, fill, dtype=np.float64), arr], axis=1)

    def _collate(self, trajs, inds, batch_inds, batch_size, max_len):
        """Cut one random window per selected trajectory and stack them into tensors.

        For each of the first ``batch_size`` entries of ``batch_inds`` a start
        index is drawn uniformly inside the trajectory and up to ``max_len``
        steps are taken from there. Shorter windows are left-padded: states,
        rewards, returns-to-go, timesteps and mask with 0, actions with -10 and
        done flags with 2.

        Returns:
            ``(s, a, r, d, rtg, timesteps, mask)`` tensors on ``self.device``.
            ``rtg`` carries one more step than the other sequences.
        """
        s, a, r, d, rtg, timesteps, mask = [], [], [], [], [], [], []
        for i in range(batch_size):
            traj = trajs[int(inds[batch_inds[i]])]
            si = random.randint(0, traj['rewards'].shape[0] - 1)
            stop = si + max_len

            # Cast to float64 so non-float observation arrays still stack cleanly.
            obs = traj['observations'][si:stop].reshape(1, -1, self.state_dim).astype(np.float64)
            tlen = obs.shape[1]
            pad = max_len - tlen
            done_key = 'terminals' if 'terminals' in traj else 'dones'

            # Returns-to-go from the window start, kept one step longer than the
            # window; a trailing zero is appended when the trajectory ends first.
            ret = discount_cumsum(traj['rewards'][si:], gamma=1.)[:tlen + 1].reshape(1, -1, 1)
            if ret.shape[1] <= tlen:
                ret = np.concatenate([ret, np.zeros((1, 1, 1))], axis=1)

            # Clamp timesteps to the last valid index (max_ep_len - 1).
            steps = np.minimum(np.arange(si, si + tlen), self.max_ep_len - 1).reshape(1, -1)

            s.append(self._left_pad(obs, pad, 0.))
            a.append(self._left_pad(traj['actions'][si:stop].reshape(1, -1, self.act_dim), pad, -10.))
            r.append(self._left_pad(traj['rewards'][si:stop].reshape(1, -1, 1), pad, 0.))
            d.append(self._left_pad(traj[done_key][si:stop].reshape(1, -1), pad, 2.))
            rtg.append(self._left_pad(ret, pad, 0.) / self.scale)
            timesteps.append(self._left_pad(steps, pad, 0.))
            mask.append(self._left_pad(np.ones((1, tlen)), pad, 0.))

        def stack(chunks, dtype):
            return torch.from_numpy(np.concatenate(chunks, axis=0)).to(dtype=dtype, device=self.device)

        return (
            stack(s, torch.float32),
            stack(a, torch.float32),
            stack(r, torch.float32),
            stack(d, torch.long),
            stack(rtg, torch.float32),
            stack(timesteps, torch.long),
            # The mask keeps its NumPy dtype; only the device changes.
            torch.from_numpy(np.concatenate(mask, axis=0)).to(device=self.device),
        )

    def get_batch_o(self, batch_size=256, batch_inds=None, max_len=20, eval=False, batch_inds_val=None):
        """Sample a batch of violation windows.

        Args:
            batch_size: number of windows to return.
            batch_inds: optional trajectory ranks to use when ``eval`` is false;
                sampled from ``p_sample_obs`` when omitted.
            max_len: window length after padding.
            eval: when true, read from the validation violation set.
            batch_inds_val: optional trajectory ranks to use when ``eval`` is
                true; ``batch_inds`` is ignored in that case.

        Returns:
            ``(s, a, r, d, rtg, timesteps, mask)`` tensors on ``self.device``.
        """
        if not eval:
            trajs = self.trajectories_obs
            inds = self.sorted_inds_obs
            if batch_inds is None:
                batch_inds = self._draw_indices(self.num_trajectories_obs, batch_size, self.p_sample_obs)
        else:
            trajs = self.trajectories_obs_val
            inds = self.sorted_inds_obs_val
            if batch_inds_val is None:
                batch_inds = self._draw_indices(self.num_trajectories_obs_val, batch_size, self.p_sample_obs_val)
            else:
                batch_inds = batch_inds_val

        return self._collate(trajs, inds, batch_inds, batch_size, max_len)

    def get_batch_e(self, batch_size=256, max_len=20, eval=False):
        """Sample a batch of expert windows.

        Args:
            batch_size: number of windows to return.
            max_len: window length after padding.
            eval: when true, read from the validation expert set.

        Returns:
            ``(s, a, r, d, rtg, timesteps, mask, batch_inds)`` where the first
            seven items are tensors on ``self.device`` and ``batch_inds`` is the
            array of sampled trajectory ranks.
        """
        if not eval:
            trajs = self.trajectories_expert
            inds = self.sorted_inds_expert
            batch_inds = self._draw_indices(self.num_trajectories_expert, batch_size, self.p_sample_expert)
        else:
            trajs = self.trajectories_expert_val
            inds = self.sorted_inds_expert_val
            batch_inds = self._draw_indices(self.num_trajectories_expert_val, batch_size, self.p_sample_expert_val)

        return (*self._collate(trajs, inds, batch_inds, batch_size, max_len), batch_inds)

    

    

