from train_icrl import train_icrl
from train_cdt import train_cdt
from train_generate_data import train_generate_data
import torch
import wandb
import random
import pandas as pd
import numpy as np
import argparse
import torch
from ops import get_violation_data,get_violation_data_pre
import pickle
def get_path(path):
    """Flatten per-trajectory records into one dict of step-level arrays.

    Args:
        path: Integer-indexable collection of trajectories. Each trajectory
            is a mapping with the keys 'actions', 'next_observations',
            'observations', 'terminals', 'rewards' and 'dieds', each
            indexable by step. The step count of a trajectory is taken
            from the length of its 'actions' entry.

    Returns:
        A dict of ``np.ndarray`` with the keys 'actions',
        'next_observations', 'observations', 'rewards', 'terminals',
        'costs' and 'dieds'. Every array has one entry per step across all
        trajectories, in trajectory order; 'costs' is all zeros.
    """
    copied_keys = ('actions', 'next_observations', 'observations',
                   'terminals', 'rewards', 'dieds')
    flat = {key: [] for key in copied_keys}
    total_steps = 0

    for traj_idx in range(len(path)):
        traj = path[traj_idx]
        n_steps = len(traj['actions'])
        for key in copied_keys:
            # traj[key] is only evaluated when there is at least one step,
            # so an empty trajectory is never probed for the other keys.
            flat[key].extend(traj[key][step] for step in range(n_steps))
        total_steps += n_steps

    # Costs are not part of the expert data; every step gets a zero cost.
    return {
        'actions': np.array(flat['actions']),
        'next_observations': np.array(flat['next_observations']),
        'observations': np.array(flat['observations']),
        'rewards': np.array(flat['rewards']),
        'terminals': np.array(flat['terminals']),
        'costs': np.array([0] * total_steps),
        'dieds': np.array(flat['dieds']),
    }
def train(exp_prefix, variant):
    """Train the generator, CDT and ICRL models and save the CDT weights.

    Steps, in order:
      1. Load the pickled training / validation expert trajectories.
      2. Start a wandb run.
      3. Build the generator with ``train_generate_data`` and call its
         ``train`` with ``variant['G_pre_train_iters']``.
      4. Flatten the expert trajectories with ``get_path`` and build the CDT
         trainer with ``train_cdt``.
      5. Produce the initial violation data with ``get_violation_data_pre``
         and build the ICRL trainer with ``train_icrl``.
      6. For ``variant['train_num_steps']`` rounds: train ICRL, update the
         CDT training cost from ICRL, train CDT, train the generator with
         ``variant['G_train_iters']``, regenerate violation data and hand it
         to ICRL.
      7. Save ``cdt.model.state_dict()`` under ``./My_model``.

    Args:
        exp_prefix: Used as the wandb run name and as the first part of the
            wandb group name.
        variant: Dict of hyper-parameters. The keys read directly here are
            'env', 'dataset', 'violation_target_cost',
            'violation_target_return', 'G_pre_train_iters', 'G_train_iters'
            and 'train_num_steps'; the whole dict is also forwarded to the
            trainers and to ``wandb.init`` as the run config.

    Returns:
        None.
    """
    import os  # local import: only needed to create the checkpoint directory

    # Earlier experiments used ./data/sepsis_data/expert_data{,_val}_s.pkl and
    # ../Process_data/Vent_data/{train,val}_vent_data_1Daction.pkl instead.
    dataset_path_e = '../Process_data/Vent_data/train_vent_data_3Ddiscretized_actions.pkl'
    dataset_path_e_val = '../Process_data/Vent_data/val_vent_data_3Ddiscretized_actions.pkl'

    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    with open(dataset_path_e, 'rb') as f:
        trajectories_expert = pickle.load(f)

    with open(dataset_path_e_val, 'rb') as f:
        trajectories_expert_val = pickle.load(f)

    print(len(trajectories_expert),len(trajectories_expert_val))

    env_name, dataset = variant['env'], variant['dataset']
    group_name = f'{exp_prefix}-{env_name}-{dataset}'
    # This script is hard-wired to the ventilator setting.
    vent = True
    wandb.init(
            name=exp_prefix,
            group=group_name,
            project='decision-transformer_vent',
            config=variant
    )
    violation_target_cost = variant['violation_target_cost']
    violation_target_return = variant['violation_target_return']

    # Pre-train the generator model.
    generate_model = train_generate_data(trajectories_expert,trajectories_expert_val,variant=variant,wandb=wandb,vent=vent)
    generate_model.train(variant['G_pre_train_iters'])

    cdt_data = get_path(trajectories_expert)
    cdt_data_val = get_path(trajectories_expert_val)

    cdt = train_cdt(cdt_data,cdt_data_val,variant=variant,wandb=wandb,vent=vent)
    trajectories_violation = get_violation_data_pre(vent,generate_model,cdt,trajectories_expert,target_cost=violation_target_cost,target_return=violation_target_return,batch_size=2000)
    # NOTE(review): the validation violation set is generated from the
    # *training* expert trajectories, not trajectories_expert_val. The
    # original author marked this line as an open question; behaviour is
    # left unchanged here - confirm whether the validation set was intended.
    trajectories_violation_val = get_violation_data_pre(vent,generate_model,cdt,trajectories_expert,target_cost=violation_target_cost,target_return=violation_target_return,batch_size=600)

    icrl = train_icrl(trajectories_expert,trajectories_expert_val,trajectories_violation,trajectories_violation_val,variant=variant,wandb=wandb)

    numstep = variant['train_num_steps']
    for i in range(numstep):
        print('+' * 40,i,'+'*40)
        # Train ICRL, then refresh the CDT training cost from it
        # (update_batch=0.6) and train CDT.
        icrl.train()
        cdt.update_train_cost(icrl,update_batch=0.6)
        cdt.train()
        # Train the generator every round, then regenerate violation data
        # and hand it to ICRL.
        generate_model.train(variant['G_train_iters'])
        violation_data = get_violation_data(vent,generate_model,cdt,trajectories_expert,target_cost=violation_target_cost,target_return=violation_target_return,batch_size=800)
        icrl.set_violation_data(violation_data)

    path_cdt = './My_model/vent_cdt_3D_0509_loss.pt'

    # torch.save does not create missing parent directories; make sure the
    # target exists so a finished training run is not lost at the last step.
    os.makedirs(os.path.dirname(path_cdt), exist_ok=True)

    # Only the CDT weights are persisted; saving of the ICRL and generator
    # weights is currently disabled:
    # torch.save(icrl.model.state_dict(), './My_model/vent_icrl_3D_0509_loss.pt')
    # torch.save(generate_model.model.state_dict(), './My_model/vent_generation_3D_0509_loss.pt')
    torch.save(cdt.model.state_dict(),path_cdt)





if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--env', type=str, default='EnvSepsis-v1')
    parser.add_argument('--dataset', type=str, default='medium')  # medium, medium-replay, medium-expert, expert
    parser.add_argument('--mode', type=str, default='normal')  # normal for standard setting, delayed for sparse
    parser.add_argument('--G_K', type=int, default=20)
    parser.add_argument('--G_pct_traj', type=float, default=1.)
    parser.add_argument('--G_batch_size', type=int, default=256)  #64 1024
    parser.add_argument('--G_model_type', type=str, default='dt')  # dt for decision transformer, bc for behavior cloning
    parser.add_argument('--G_embed_dim', type=int, default=128)  #128 hidden size
    parser.add_argument('--G_n_layer', type=int, default=3)# 3\4
    parser.add_argument('--G_n_head', type=int, default=8) # 8\1
    parser.add_argument('--G_activation_function', type=str, default='relu')
    parser.add_argument('--G_dropout', type=float, default=0.1)  # 让某些节点停止工作，防止过拟合 0.1
    parser.add_argument('--G_learning_rate', '-lr', type=float, default=1e-4)  # 1e-6
    parser.add_argument('--G_weight_decay', '-wd', type=float, default=1e-6) # 1e-4网络权值衰减，防止过拟合——去掉  # 无
    parser.add_argument('--G_warmup_steps', type=int, default=1000) #1000  # x次后衰减学习率——去掉  
    parser.add_argument('--G_num_eval_episodes', type=int, default=30) #100
    parser.add_argument('--G_pre_train_iters', type=int, default=30) # 30-------------------------------------50
    parser.add_argument('--G_num_steps_per_iter', type=int, default=500) #1000 env_targets
    parser.add_argument('--G_env_targets', type=int, default=20) #max_ep_len
    parser.add_argument('--G_max_ep_len', type=int, default=20)
    parser.add_argument('--G_train_iters', type=int, default=5)  # 5-------------

    parser.add_argument('--ICRL_K', type=int, default=10)
    parser.add_argument('--act_max', type=int, default=[7,7,7]) #sepsis[1,1] vent[1000000000,1000000000,1000000000]
    parser.add_argument('--act_dim', type=int, default=3) # sepsis-2 vent-3 
    parser.add_argument('--state_dim', type=int, default=38)  #sepsis-48 vent-38
    parser.add_argument('--ICRL_max_ep_len', type=int, default=10)
    parser.add_argument('--ICRL_pct_traj', type=float, default=1.)
    parser.add_argument('--ICRL_batch_size', type=int, default=512)  # 2048
    parser.add_argument('--ICRL_model_type', type=str, default='dt')  # dt for decision transformer, bc for behavior cloning
    parser.add_argument('--ICRL_embed_dim', type=int, default=64)  #128 hidden size
    parser.add_argument('--ICRL_n_layer', type=int, default=3) # 4 3
    parser.add_argument('--ICRL_n_head', type=int, default=1) # 1 8buxing 
    parser.add_argument('--ICRL_activation_function', type=str, default='relu')
    parser.add_argument('--ICRL_dropout', type=float, default=0.1)  # 让某些节点停止工作，防止过拟合 0.1
    parser.add_argument('--ICRL_learning_rate', type=float, default=1e-6)  # 1e-6
    parser.add_argument('--ICRL_weight_decay', type=float, default=1e-6) # 1e-4网络权值衰减，防止过拟合——去掉  # 无
    parser.add_argument('--ICRL_warmup_steps', type=int, default=1000) #1000  # x次后衰减学习率——去掉  
    parser.add_argument('--ICRL_pre_attn_embd_dim', type=int, default=64)
    parser.add_argument('--ICRL_use_weighted_sum', type=bool, default=True) # with attention or without attention
    parser.add_argument('--ICRL_train_type', type=str, default='mean')
    parser.add_argument('--ICRL_max_iters', type=int, default=30) # 30-----------------------------------------20
    parser.add_argument('--ICRL_num_eval_episodes', type=int, default=100) #100
    parser.add_argument('--ICRL_num_steps_per_iter', type=int, default=1000) #200
    
    parser.add_argument('--train_num_steps', type=int, default=5)
    parser.add_argument('--violation_target_cost', type=int, default=0) #violation_target_return
    parser.add_argument('--violation_target_return', type=int, default=30)

    parser.add_argument('--device', type=str, default='cuda')
    parser.add_argument('--log_to_wandb', type=bool, default=True)

    args = parser.parse_args()

    train('gym-experiment', variant=vars(args))