import json
from os.path import join
from tqdm import tqdm
import numpy as np
import random


from src.policies.gdt_policy import GDTPolicy
import src.utils as utils
from src.envs.toy_car.toy_car import ToyCar
import wandb
import pickle
import torch


class Parser(utils.Parser):
    """Command-line/config arguments for this evaluation script.

    Extends ``utils.Parser`` with script-specific defaults. Other attributes
    read by ``evaluation()`` (``device``, ``savepath``,
    ``max_context_transitions``) are not declared here and are presumably
    provided by the base parser / the config module — confirm in ``src.utils``.
    """
    dataset: str = 'idm-uniform07'  # dataset name
    visiable: str = 'whole'  # (sic) visibility mode; evaluation() masks for 'constant', 'zero_value', 'random_noise', anything else is unmasked
    config: str = 'config.offline'  # config module path, presumably consumed by utils.Parser.parse_args — confirm


# seed, trained model, dataset

#######################
######## setup ########
#######################
# Default hyper-parameters passed to ``wandb.init(config=...)`` as the run config.
config_defaults = dict(
    policy_type="MlpPolicy",
    total_timesteps=300000,
    env_name="diagnosis_RL_dqn",
    batch_size=64,  # 8 was noted as an alternative value
    learning_rate=0.0001,
    n_layer=3,
    embed_dim=128,
)
def evaluation():

    args = Parser().parse_args('bc_plan')


    utils.set_device(args.device)
    run = wandb.init(
        
        config=config_defaults,
        sync_tensorboard=False, #project=wandb.config.base['WANDB_PROJECT'],
        save_code=True,# name = args.exp_name
    )

  


    # whole
    # data_index = [951, 951, 653, 659, 661 ] # old
    # model_index = [1334, 1339, 1345, 1347, 1351 ]
    # state_index = [186, 181, 179, 181, 180 ]
    # bcq_model_index = [1298, 1304, 1311, 1318, 1327 ]
    # bcq_state_index = [179, 174, 172, 175, 175 ]
    # stop_sign = 30




    # data_index = [953, 964, 972, 976, 980 ] 
    # model_index = [1341, 1352, 1360, 1364, 1368 ]
    # state_index = [192, 187, 185, 187, 186 ]
    # bcq_model_index = [1299, 1306, 1314, 1322, 1329 ]
    # bcq_state_index = [180, 176, 175, 177, 177 ]
    # stop_sign = 40


    # data_index = [947, 955, 965, 973, 977 ] 
    # model_index = [1335, 1343, 1353, 1361, 1365 ]
    # state_index = [188, 184, 182, 184, 183 ]
    # bcq_model_index = [1300, 1307, 1313, 1320, 1325 ]
    # bcq_state_index = [181, 175, 173, 174, 173 ]
    # stop_sign = 50

    # data_index = [948, 958, 967, 974, 978 ]
    # model_index = [1336, 1346, 1355, 1362, 1366 ]
    # state_index = [190, 185, 183, 185, 184 ]
    # bcq_model_index = [1301, 1308, 1315, 1321, 1328 ]
    # bcq_state_index = [182, 177, 174, 176, 174 ]
    # stop_sign = 60


    # data_index = [949, 954, 960, 966, 969 ] 
    # model_index = [1337, 1342, 1348, 1354, 1357 ]
    # state_index = [187, 182, 180, 182, 181 ]
    # bcq_model_index = [1302, 1309, 1316, 1323, 1330 ]
    # bcq_state_index = [183, 178, 176, 178, 176 ]
    # stop_sign = 70


    # data_index = [950, 956, 961, 968, 970 ] 
    # model_index = [1338, 1344, 1349, 1356, 1358 ]
    # state_index = [189, 183, 181, 183, 182 ]
    # bcq_model_index = [1303, 1310, 1317, 1324, 1331 ]
    # bcq_state_index = [184, 179, 177, 179, 178 ]
    # stop_sign = 80

    data_index = [952, 962, 971, 975, 979 ]
    model_index = [1340, 1350, 1359, 1363, 1367 ]
    state_index = [191, 186, 184, 186, 185 ]
    bcq_model_index = [1305, 1312, 1319, 1326, 1332 ]
    bcq_state_index = [185, 180, 178, 180, 179 ]
    stop_sign = 80


    # data_index = [953, 964, 972, 659, 661 ] 
    # model_index = [1341, 1352, 1360, 1347, 1351 ]
    # state_index = [192, 187, 179, 181, 180 ]
    # bcq_model_index = [1305, 1312, 1319, 1326, 1332 ]
    # bcq_state_index = [185, 180, 178, 180, 179 ]
    # stop_sign = random


    seeds = [0, 10, 20, 30, 40]
    
    for i in range(5):
        seed = seeds[i]
        data_artifact = run.use_artifact('your-folder/data_config:v'+str(data_index[i]), type='config')
        loadpath = data_artifact.download()
        dataset = pickle.load(open(loadpath+'/data_config.pkl', 'rb')).make()

        print('dataset: ', dataset)
        model_artifact = run.use_artifact('your-folder/model_config:v'+str(model_index[i]), type='config')
        config_path = model_artifact.download()
        config = pickle.load(open(config_path+'/model_config.pkl', 'rb'))

        state_artifact = run.use_artifact('your-folder/state_48_seed'+str(seed)+':v'+str(state_index[i]), type='model')
        state_path = state_artifact.download()
        state = torch.load(state_path+'/state_48_seed'+str(seed)+'.pt')

        gpt = config()
        gpt.to(args.device)
        gpt.load_state_dict(state, strict=True)
        print(f'\n[ utils/serialization ] Loaded config from {config_path}\n')
        print(config)



        bcq_model_artifact = run.use_artifact('your-folder/model_config:v'+str(bcq_model_index[i]), type='config')
        bcq_config_path = bcq_model_artifact.download()
        bcq_config = pickle.load(open(bcq_config_path+'/model_config.pkl', 'rb'))

        bcq_state_artifact = run.use_artifact('your-folder/state_48_seed'+str(seed)+':v'+str(bcq_state_index[i]), type='model')
        bcq_state_path = bcq_state_artifact.download()
        bcq_state = torch.load(bcq_state_path+'/state_48_seed'+str(seed)+'.pt')

        bcq_gpt = bcq_config()
        bcq_gpt.to(args.device)
        bcq_gpt.load_state_dict(bcq_state, strict=True)
        print(f'\n[ utils/serialization ] Loaded config from {bcq_config_path}\n')
        print(bcq_config)

    # #######################
    # ####### models ########
    # #######################

    # args.logbase = args.logbase + 'dt/' # ~/SPLT-transformer/logs/bt/
    # args.exp_name = args.gpt_loadpath + '/' + args.exp_name
    # args.savepath = join(args.logbase, args.dataset, args.exp_name)
    # print('logbase: ', args.logbase)
    # print('dataset: ', args.dataset)
    # print('savepath: ', args.savepath)
    # # dataset = utils.load_from_config(args.logbase, args.dataset, args.gpt_loadpath,
    # #         'data_config.pkl')

    # # gpt, gpt_epoch = utils.load_model(args.logbase, args.dataset, args.gpt_loadpath,
    # #         epoch=args.gpt_epoch, device=args.device)

    # # #######################
    # # ####### dataset #######
    # # #######################
      


        random.seed(seed)
        np.random.seed(seed)
        env = ToyCar(stop_sign = stop_sign, seed = seed)

        discount = dataset.discount
        observation_dim = dataset.observation_dim
        action_dim = dataset.action_dim
       

        #######################
        ###### main loop ######
        #######################

        returns = []
        successes = []

        T = 40
        num_episodes = 100
        max_history = args.max_context_transitions

        gpt.eval()

        policy = GDTPolicy(
            gpt, bcq_gpt,
            observation_dim,
            action_dim,
            discount,
            max_history=max_history,
            device=args.device)

        vis_threshold = 25
        percent_reward = []
        for i in tqdm(range(num_episodes)):
            observation = env.reset(testing=True)
            policy.reset()
            total_reward = 0

            for t in range(T):
              

                action = policy(observation)
              

                ## execute action in environment
                next_observation, reward, terminal,crash, info, ego_x, other_x = env.test_step(True, action)
                # False: terminate early due to crush
                # True: no terminal


                if args.visiable == 'constant':
                    print('[plan/visiable] constant')
                    
                    if next_observation[2] - next_observation[0] >= vis_threshold:
                        next_observation[2] = next_observation[0] + vis_threshold
                        next_observation[3] = next_observation[1] 
                        print('next obeservation: ', next_observation) 
            
                elif args.visiable == 'zero_value':
                    print('[plan/visiable] zero value')
                    if next_observation[2] - next_observation[0] >= vis_threshold:
                        next_observation[2] = 0
                        next_observation[3] = 0
                        print('next obeservation: ', next_observation)

                elif args.visiable == 'random_noise':
                    print('[train/visiable] random noise')
                    if next_observation[2] - next_observation[0] >= vis_threshold:
                        next_observation[2] = vis_threshold + 9 * random.random()
                        next_observation[3] = 10 * random.random()
                        print('next obeservation: ', next_observation)
                    
                

                policy.update_context(observation, action, reward)

                ## update return
                total_reward += reward

                if terminal:
                    returns.append(total_reward)
                    successes.append(info['success'])
                    # print(total_reward, t)
                    # print('ego_x: ', ego_x)
                    # print('other_x: ', other_x)
                    # print(ego_x/other_x)
                    percent_reward.append(ego_x/other_x)
                    break

                observation = next_observation
        print('Percent Reward {0}'.format(np.mean(percent_reward)))
        print('Mean Return {0}'.format(np.mean(returns)))
        print('Std Return {0}'.format(np.std(returns)))
        print('Mean Success Rate {0}'.format(np.mean(successes)))

        utils.serialization.mkdir(args.savepath)
        json_path = join(args.savepath, 'rollout.json')
        json_data = {
            'Mean Return': np.mean(returns),
            'Std Return': np.std(returns),
            'Mean Success Rate': np.mean(successes),
        }
        json.dump(json_data, open(json_path, 'w'), indent=2, sort_keys=True)

if "__main__" == __name__:
    # Script entry point: run the full multi-seed evaluation.
    evaluation()
