import torch
import numpy as np
import pickle
import pandas as pd
def get_val_path(path, batch_size=1):
    """Draw ``batch_size`` elements of ``path`` uniformly at random, with replacement.

    ``batch_size`` is truncated with ``int()`` because callers may pass a float
    (e.g. ``batch_size / 2``). Sampling uses NumPy's global RNG, so results are
    reproducible under ``np.random.seed``.

    Returns:
        A new list holding references to the sampled elements of ``path``.
    """
    count = int(batch_size)
    chosen = np.random.choice(
        np.arange(len(path)),
        size=count,
        replace=True,
    )
    return [path[idx] for idx in chosen]

def get_iv_vaso(actions):
    """Split a sequence of two-component actions into two parallel lists.

    Returns:
        ``(firsts, seconds)`` where ``firsts[i] == actions[i][0]`` and
        ``seconds[i] == actions[i][1]``. The names suggest IV-fluid and
        vasopressor doses, but nothing here checks that.
    """
    iv_doses = [pair[0] for pair in actions]
    vaso_doses = [pair[1] for pair in actions]
    return iv_doses, vaso_doses

def get_generate_data(vent,generate_data,cdt,trajectories,target_return=None,target_cost=None,violation=False,batch_size=1000,device='cuda'):
    """Roll out synthetic trajectories with ``generate_data.model`` (optionally steered by ``cdt.model``).

    ``batch_size`` seed trajectories are drawn with replacement from
    ``trajectories`` via ``get_val_path``. Each rollout starts from the seed's
    first observation, first action and first reward. It then repeatedly calls
    ``generate_data.model.get_action`` to predict the next state, reward and
    action, for ``min(len(observations), cdt.model.seq_len) - 1`` steps.

    Args:
        vent: if truthy, chosen actions are rounded to integers.
        generate_data: object exposing ``model`` (with ``get_action``),
            ``state_dim``, ``act_dim`` and ``mode``.
        cdt: object exposing ``model`` (callable, with ``seq_len`` and
            ``stochastic``), ``max_action``, ``state_dim``, ``act_dim`` and
            ``device``.
        trajectories: sequence of dicts with at least the keys
            'observations', 'next_observations', 'actions', 'rewards' and
            'dieds'.
        target_return: initial return target fed to both models. Each step it
            is decremented by ``0.99 * predicted_reward`` unless
            ``generate_data.mode == 'delayed'``.
        target_cost: cost target fed to ``cdt.model``. It is only used when
            ``violation`` is true.
        violation: if true, actions are chosen by ``cdt.model`` and clamped to
            ``[0, act_max]``. Otherwise ``generate_data.model``'s own action
            predictions are used.
        batch_size: number of seed trajectories to sample (truncated by
            ``int`` inside ``get_val_path``).
        device: device for ``generate_data.model`` and the generator-side
            tensors.

    Returns:
        ``(reward_agent, state_agent, action_agent, die, done)``. These are flat
        lists, aligned by index and concatenated over all sampled
        trajectories:

        - ``state_agent`` holds *predicted* states (as Python lists). The
          seed observation itself is never recorded.
        - ``die`` repeats the seed's ``dieds[0]`` for every recorded step.
        - ``done`` is 0 except for the last recorded step of each trajectory,
          which is set to 1.
        - The final rollout step of each trajectory is not recorded.
    """
    paths = get_val_path(trajectories,batch_size=batch_size)

    generate_data.model.to(device)
    # Clamp bound for CDT actions. A non-int max_action is indexed and its
    # first entry is used.
    # NOTE(review): a float scalar max_action would take the else-branch and
    # fail on indexing -- confirm max_action is always int or a sequence.
    if isinstance(cdt.max_action,int):
        act_max = cdt.max_action
    else:
        act_max = cdt.max_action[0]


    #state_mean = torch.from_numpy(state_mean).to(device=device)
    #state_std = torch.from_numpy(state_std).to(device=device)
    #state_mean, state_std = np.array(0),np.array(1)

    # Keep the caller's scalar target: `target_return` is rebound to a tensor
    # inside the loop below.
    target_r = target_return
    len_paths = len(paths)

    # Flat outputs accumulated across all sampled trajectories.
    action_agent = []
    action_phy = []
    state_agent = []
    state_phy = []
    die = []
    reward_agent = []
    reward_phy = []

    done = []

    for i in range(len_paths):
        states = paths[i]['observations']
        next_states = paths[i]['next_observations']
        actions = paths[i]['actions']
        rewards = paths[i]['rewards']
        d = paths[i]['dieds'][0]

        # Seed the rollout with the first logged transition.
        state = states[0]
        action = actions[0]
        r0 = rewards[0]

        if violation:

            # CDT context buffers for a batch of one. States, returns and
            # costs have one more slot than actions.
            states_v = torch.zeros(1,
                                cdt.model.seq_len + 1,
                                cdt.state_dim,
                                dtype=torch.float,
                                device=cdt.device)
            actions_v = torch.zeros(1,
                                cdt.model.seq_len,
                                cdt.act_dim,
                                dtype=torch.float,
                                device=cdt.device)
            returns_v = torch.zeros(1,
                                cdt.model.seq_len + 1,
                                dtype=torch.float,
                                device=cdt.device)
            costs_v = torch.zeros(1,
                                cdt.model.seq_len + 1,
                                dtype=torch.float,
                                device=cdt.device)
            time_steps_v = torch.arange(cdt.model.seq_len,
                                    dtype=torch.long,
                                    device=cdt.device)
            time_steps_v = time_steps_v.view(1, -1)
            # NOTE(review): created on `device`, while every other CDT tensor
            # uses `cdt.device`; a mismatch would break the cdt.model call.
            epi_cost_v = torch.tensor(np.array([target_cost]),
                                    dtype=torch.float,
                                    device=device)
            states_v[:, 0] = torch.as_tensor(state, device=cdt.device)
            returns_v[:, 0] = torch.as_tensor(target_r, device=cdt.device)
            costs_v[:, 0] = torch.as_tensor(target_cost, device=cdt.device)
            # First CDT query: only the initial state/return/cost; the action
            # slot is still zeros.
            s = states_v[:, :1][:, -cdt.model.seq_len:]  # noqa
            a = actions_v[:, :1][:, -cdt.model.seq_len:]  # noqa
            r = returns_v[:, :1][:, -cdt.model.seq_len:]  # noqa
            c = costs_v[:, :1][:, -cdt.model.seq_len:]  # noqa
            t = time_steps_v[:, :1][:, -cdt.model.seq_len:]  # noqa
            acts, _, _ = cdt.model(s, a, r, c, t, None, epi_cost_v)
            # For a stochastic CDT the first output is presumably a
            # distribution; its `.mean` is used as the action.
            if cdt.model.stochastic:
                acts = acts.mean
            acts = torch.clamp(acts,0,act_max)
            act = acts[0, -1].cpu().detach().numpy()
            # Discrete (ventilation) setting: round to integer actions.
            if vent:
                act = np.round(act).astype(int)

            # Replace the logged first action with the CDT's choice.
            action = act

            actions_v[:, 0] = torch.as_tensor(action, device=cdt.device)

        # Generator context tensors: one row per step so far.
        actions_dt = torch.from_numpy(action).reshape(1, generate_data.act_dim).to(device=device, dtype=torch.float32)
        states_dt = torch.from_numpy(state).reshape(1, generate_data.state_dim).to(device=device, dtype=torch.float32)


        #rewards_dt = torch.tensor(r, dtype=torch.float32, device=device).reshape(1, 1)
        rewards_dt = torch.zeros(0,device=device,dtype=torch.float32)
        r0 = torch.tensor([r0], device=device, dtype=torch.float32)

        # Append the first logged reward r0 to rewards_dt (1-D, one entry per step).
        rewards_dt = torch.cat([rewards_dt, r0])

        ep_return = target_r
        # Return targets and timesteps are kept as (1, steps); timesteps start at 1.
        target_return = torch.tensor(ep_return, device=device, dtype=torch.float32).reshape(1, 1)
        timesteps = torch.tensor(1, device=device, dtype=torch.long).reshape(1, 1)


        # Number of rollout steps; bounded by the CDT context length.
        num = min(len(states),cdt.model.seq_len)-1
        # CDT: (s0, R_t, C_t) -> a0;   generator: (s0, r0, a0) -> (s1, r1)
        # CDT: (s0, r, c, a0, s1, r1, c) -> a1
        for j in range(num):

            # Query the generator with the full context so far; the last
            # positions of its predictions are taken as the next step.
            state_preds,action_preds,return_preds,_ = generate_data.model.get_action(
                states_dt.to(dtype=torch.float32),
                actions_dt.to(dtype=torch.float32),
                rewards_dt.to(dtype=torch.float32),
                target_return.to(dtype=torch.float32),
                timesteps.to(dtype=torch.long),
            )
            # Append the predicted reward to the reward context.
            rewards_dt = torch.cat([rewards_dt, torch.zeros(1, device=device)])
            returnp = return_preds[0,-1,0]
            rewards_dt[-1] = returnp
            statep = state_preds[0,-1]

            if violation == True: # the CDT chooses the next action
                states_v[:, j+1] = torch.as_tensor(statep, device=cdt.device) # state predicted by the generator
                returns_v[:, j+1] = torch.as_tensor(returnp, device=cdt.device)
                costs_v[:, j+1] = torch.as_tensor(target_cost, device=cdt.device)
                s = states_v[:, :j+2][:, -cdt.model.seq_len:]  # noqa
                a = actions_v[:, :j+2][:, -cdt.model.seq_len:]  # noqa
                r = returns_v[:, :j+2][:, -cdt.model.seq_len:]  # noqa
                c = costs_v[:, :j+2][:, -cdt.model.seq_len:]  # noqa
                t = time_steps_v[:, :j+2][:, -cdt.model.seq_len:]  # noqa
                acts, _, _ = cdt.model(s, a, r, c, t, None, epi_cost_v)
                if cdt.model.stochastic:
                    acts = acts.mean
                acts = torch.clamp(acts,0,act_max)
                # Discrete (ventilation) setting: round to integer actions.
                actionp = acts[0, -1]
                if vent:
                    actionp = actionp.round().int()
                actions_v[:, j+1] = torch.as_tensor(actionp, device=cdt.device)
            else:
                # Use the generator's own predicted action (unclamped).
                # if generate_data.model.vent:
                #     actionp = action_preds
                # else:
                actionp = action_preds[0, -1]
                if vent:
                    actionp = actionp.round().int()

            actions_dt = torch.cat([actions_dt, torch.zeros((1, generate_data.act_dim), device=device)], dim=0)
            actions_dt[-1] = actionp

            # Update the return target: subtract the discounted predicted
            # reward unless rewards are delayed.
            if generate_data.mode != 'delayed':
                treturnp = target_return[0,-1] - 0.99*returnp
            else:
                treturnp = target_return[0,-1]


            target_return = torch.cat([target_return, treturnp.view(1,1)], dim=1)
            timesteps = torch.cat([timesteps,torch.ones((1, 1), device=device, dtype=torch.long) * (j+2)], dim=1)

            #sum_agent += returnp.item()
            #sum_phy += rewards[j]

            # Record every step except the last. On the last step, mark the
            # most recently recorded entry as terminal and stop.
            # NOTE(review): `done` is shared across trajectories, so when this
            # trajectory recorded nothing (num == 1) this re-marks the previous
            # trajectory's last entry, which is already 1.
            if j < num-1:
                state_agent.append(statep.detach().cpu().numpy().tolist())
            else:
                if len(done) == 0:
                    break
                else:
                    done[-1] = 1
                break
            action_agent.append(actionp.detach().cpu().numpy())
            die.append(d)
            done.append(0)
            reward_agent.append(returnp.detach().cpu().numpy())

            # Always true here (the last step breaks above): extend the
            # generator's state context with the predicted state.
            if j != num-1:
                states_dt = torch.cat([states_dt,statep.view(1,generate_data.state_dim)],dim=0)

    return reward_agent,state_agent,action_agent,die,done

def dtdata_val_high_reward(state_agent_violate,die_violate,done_violate,reward_agent_violate,action_agent_violate):
    """Regroup flat, index-aligned step lists into per-trajectory dicts.

    The five input lists are read in lockstep, indexed up to
    ``len(state_agent_violate)``. A step whose ``done`` flag equals 1 closes
    the current trajectory.

    For a step with ``done == 0``, the next observation is the following
    entry of ``state_agent_violate``. For a terminal step it is a zero vector
    of the same length. Steps after the last terminal (an unterminated tail)
    are discarded; previously such a tail raised IndexError.

    Args:
        state_agent_violate: per-step observations (equal-length sequences).
        die_violate: per-step death flags, stored under ``'dieds'``.
        done_violate: per-step terminal flags (0 or 1).
        reward_agent_violate: per-step rewards.
        action_agent_violate: per-step actions.

    Returns:
        ``(trajectories, length, path_r)``:

        - ``trajectories``: list of dicts with keys ``'observations'``,
          ``'next_observations'``, ``'actions'``, ``'rewards'``,
          ``'terminals'`` and ``'dieds'``, each a NumPy array.
        - ``length``: the number of steps in each trajectory.
        - ``path_r``: each trajectory's summed reward.
    """
    trajectories = []
    length = []
    path_r = []
    n_steps = len(state_agent_violate)

    obs, next_states, actions, rewards, dones, dieds = [], [], [], [], [], []
    for i in range(n_steps):
        ob = state_agent_violate[i]
        done = done_violate[i]

        obs.append(ob)
        if done == 0 and i + 1 < n_steps:
            next_states.append(state_agent_violate[i + 1])
        else:
            # Terminal step, or an unterminated final step whose partial
            # trajectory is discarded below anyway.
            next_states.append(np.zeros(len(ob)))
        actions.append(action_agent_violate[i])
        rewards.append(reward_agent_violate[i])
        dones.append(done)
        dieds.append(die_violate[i])

        if done == 1:
            # A step was just appended, so every terminal closes a non-empty
            # trajectory.
            trajectories.append({
                'observations': np.array(obs),
                'next_observations': np.array(next_states),
                'actions': np.array(actions),
                'rewards': np.array(rewards),
                'terminals': np.array(dones),
                'dieds': np.array(dieds),
            })
            path_r.append(sum(rewards))
            length.append(len(obs))
            obs, next_states, actions, rewards, dones, dieds = [], [], [], [], [], []

    return trajectories, length, path_r
def get_violation_data(vent,generate_data,cdt,trajectories,target_cost,target_return,batch_size=1000,device='cuda'):
    """Generate a mixed batch of synthetic trajectories from the two models.

    The seeds are split into two halves:

    - One half is rolled out with CDT-chosen actions conditioned on
      ``target_return`` and ``target_cost`` (``violation=True``).
    - The other half uses the generator's own action predictions
      (``violation=False``).

    Both halves are concatenated and regrouped into per-trajectory dicts by
    ``dtdata_val_high_reward``.

    Both models are put in eval mode for generation and switched back to
    train mode afterwards, even if generation raises.

    Args:
        vent: if truthy, actions are rounded to integers.
        generate_data: generator wrapper passed to ``get_generate_data``.
        cdt: CDT wrapper passed to ``get_generate_data``.
        trajectories: seed trajectories to sample from.
        target_cost: cost target for the CDT-driven half.
        target_return: initial return target for both halves.
        batch_size: total number of seed trajectories. The two halves sum to
            exactly ``int(batch_size)``; the generator-only half gets the
            extra one when it is odd.
        device: device forwarded to ``get_generate_data``. It was previously
            ignored in favour of a hard-coded ``'cuda'``.

    Returns:
        List of trajectory dicts as produced by ``dtdata_val_high_reward``.
    """
    n_violate = int(batch_size) // 2
    n_plain = int(batch_size) - n_violate

    generate_data.model.eval()
    cdt.model.eval()
    try:
        (reward_agent_violate, state_agent_violate, action_agent_violate,
         die_violate, done_violate) = get_generate_data(
            vent, generate_data, cdt, trajectories, target_return, target_cost,
            violation=True, batch_size=n_violate, device=device)
        reward_agent, state_agent, action_agent, die, done = get_generate_data(
            vent, generate_data, cdt, trajectories, target_return,
            violation=False, batch_size=n_plain, device=device)
        # get_generate_data marks the last recorded step of each trajectory
        # done=1, so trajectory boundaries survive plain list concatenation.
        path_cql_df, _, _ = dtdata_val_high_reward(
            state_agent + state_agent_violate,
            die + die_violate,
            done + done_violate,
            reward_agent + reward_agent_violate,
            action_agent + action_agent_violate,
        )
    finally:
        generate_data.model.train()
        cdt.model.train()
    return path_cql_df

def get_violation_data_pre(vent,generate_data,cdt,trajectories,target_cost,target_return,batch_size=1000,device='cuda'):
    """Generate synthetic trajectories using only the generator's own actions.

    Rolls out ``batch_size`` seed trajectories with ``violation=False``, so the
    generator's predicted actions are used. ``target_cost`` is forwarded but
    ``get_generate_data`` only reads it when ``violation`` is true. The flat
    results are regrouped into per-trajectory dicts by
    ``dtdata_val_high_reward``.

    Both models are put in eval mode for generation and switched back to
    train mode afterwards, even if generation raises.

    Args:
        vent: if truthy, actions are rounded to integers.
        generate_data: generator wrapper passed to ``get_generate_data``.
        cdt: CDT wrapper passed to ``get_generate_data``.
        trajectories: seed trajectories to sample from.
        target_cost: forwarded to ``get_generate_data`` (unused there when
            ``violation=False``).
        target_return: initial return target for the rollouts.
        batch_size: number of seed trajectories.
        device: device forwarded to ``get_generate_data``. It was previously
            ignored in favour of a hard-coded ``'cuda'``.

    Returns:
        List of trajectory dicts as produced by ``dtdata_val_high_reward``.
    """
    generate_data.model.eval()
    cdt.model.eval()
    try:
        reward_agent, state_agent, action_agent, die, done = get_generate_data(
            vent, generate_data, cdt, trajectories, target_return, target_cost,
            violation=False, batch_size=batch_size, device=device)
        path_cql_df, _, _ = dtdata_val_high_reward(
            state_agent, die, done, reward_agent, action_agent)
    finally:
        generate_data.model.train()
        cdt.model.train()
    return path_cql_df