import json
from os.path import join
from tqdm import tqdm
import numpy as np
import random

from src.policies.bt_policy import BTPolicy
import src.utils as utils
from src.envs.toy_car.toy_car import ToyCar
import wandb
import pickle
import torch


class Parser(utils.Parser):
    """Command-line defaults for this evaluation script.

    Extends the project's ``utils.Parser`` (defined outside this file) with
    three string options. Only ``visiable`` is read directly in this file.
    """

    # Not read directly in this file; presumably consumed by utils.Parser or
    # the loaded config -- TODO confirm.
    dataset: str = 'idm-uniform07'
    # Visibility mode checked in evaluation(): 'constant', 'zero_value' and
    # 'random_noise' each rewrite entries 2 and 3 of the next observation;
    # any other value (including the default 'whole') leaves it untouched.
    visiable: str = 'whole'
    # Not read directly in this file; presumably the module path that
    # utils.Parser loads defaults from -- TODO confirm.
    config: str = 'config.offline'

#######################
######## setup ########
#######################
# Default hyper-parameters handed to `wandb.init(config=...)` in evaluation().
config_defaults = {
    # run / environment identification
    "policy_type": "MlpPolicy",
    "total_timesteps": 300_000,
    "env_name": "diagnosis_RL_dqn",
    # optimisation settings (batch size was previously 8)
    "batch_size": 64,
    "learning_rate": 1e-4,
    # network size
    "n_layer": 3,
    "embed_dim": 128,
}
# Values of `args.visiable` for which the next observation is rewritten and
# which therefore need a visibility threshold.
_MASKED_VISIBILITY_MODES = ('constant', 'zero_value', 'random_noise')


def _load_pickle(path):
    """Return the object unpickled from ``path``, closing the file afterwards."""
    # NOTE(security): unpickling executes arbitrary code; only load artifacts
    # that come from a trusted source.
    with open(path, 'rb') as handle:
        return pickle.load(handle)


def _mask_observation(observation, mode, vis_threshold):
    """Rewrite entries 2 and 3 of ``observation`` in place according to ``mode``.

    For the three masked modes, the rewrite only happens when
    ``observation[2] - observation[0] >= vis_threshold``:

    * ``'constant'``     -- entry 2 becomes ``observation[0] + vis_threshold``
      and entry 3 becomes ``observation[1]``.
    * ``'zero_value'``   -- entries 2 and 3 become ``0``.
    * ``'random_noise'`` -- entry 2 becomes ``vis_threshold + 9 * U[0, 1)``
      and entry 3 becomes ``10 * U[0, 1)`` (drawn from ``random``).

    Any other mode leaves the observation untouched. Presumably entries 0/1
    describe the ego car and 2/3 the other car -- TODO confirm against ToyCar.

    Returns the same (mutated) ``observation`` object.
    """
    if mode == 'constant':
        print('[plan/visiable] constant')
        if observation[2] - observation[0] >= vis_threshold:
            observation[2] = observation[0] + vis_threshold
            observation[3] = observation[1]
            print('next obeservation: ', observation)

    elif mode == 'zero_value':
        print('[plan/visiable] zero value')
        if observation[2] - observation[0] >= vis_threshold:
            observation[2] = 0
            observation[3] = 0
            print('next obeservation: ', observation)

    elif mode == 'random_noise':
        print('[train/visiable] random noise')
        if observation[2] - observation[0] >= vis_threshold:
            observation[2] = vis_threshold + 9 * random.random()
            observation[3] = 10 * random.random()
            print('next obeservation: ', observation)

    return observation


def evaluation():
    """Roll out stored policies in ``ToyCar`` and report return statistics.

    For each of the five seeds the function downloads the matching dataset
    config, model config and model weights from wandb artifacts, builds a
    ``BTPolicy`` around the loaded model, runs ``num_episodes`` episodes of at
    most ``T`` steps, prints the aggregate statistics and writes them as JSON
    under ``args.savepath``.

    Raises:
        ValueError: if ``args.visiable`` selects a masked mode but the parsed
            arguments carry no ``vis_threshold``.
    """
    args = Parser().parse_args('plan')

    # The masked visibility modes need a threshold that this script never
    # defined (it used to fail with NameError mid-rollout). It is assumed to
    # live on the parsed arguments as `vis_threshold` -- TODO confirm the name.
    # Fail before any artifact is downloaded if it is missing.
    vis_threshold = getattr(args, 'vis_threshold', None)
    if args.visiable in _MASKED_VISIBILITY_MODES and vis_threshold is None:
        raise ValueError(
            f"visiable={args.visiable!r} requires a 'vis_threshold' argument, "
            "but none is defined on the parsed arguments"
        )

    utils.set_device(args.device)
    run = wandb.init(
        config=config_defaults,
        sync_tensorboard=False, #project=wandb.config.base['WANDB_PROJECT'],
        save_code=True,# name = args.exp_name
    )

    # Alternative artifact-version presets, kept for reference. Each preset
    # lists one artifact version per seed; swap one in for the active block
    # below to evaluate that configuration.

    # random noise
    # data_index = [823, 827, 832, 836, 840 ]
    # model_index = [1186, 1190, 1195, 1199, 1203 ]
    # state_index = [165, 161, 160, 162, 161 ]
    # stop_sign = 30

    # whole
    # data_index = [657, 648, 651, 655, 657 ]
    # model_index = [1001, 992, 995, 999, 1001 ]
    # state_index = [131, 125, 124, 126, 125 ]
    # stop_sign = 30

    # data_index = [646, 650, 652, 656, 658 ]
    # model_index = [990, 994, 996, 1000, 1002 ]
    # state_index = [133, 126, 125, 127, 126 ]
    # stop_sign = 40

    # data_index = [598, 603, 609, 616, 623 ]
    # model_index = [955, 960, 966, 973, 979 ]
    # state_index = [120, 115, 114, 116, 115 ]
    # stop_sign = 50

    # data_index = [599, 604, 610, 618, 625 ]
    # model_index = [956, 961, 967, 975, 981 ]
    # state_index = [121, 116, 116, 118, 118 ]
    # stop_sign = 60

    data_index = [600, 605, 613, 619, 627 ]
    model_index = [957, 962, 970, 976, 983 ]
    state_index = [122, 118, 118, 120, 119 ]
    stop_sign = 70

    # data_index = [601, 606, 611, 617, 624 ]
    # model_index = [958, 963, 968, 974, 980 ]
    # state_index = [123, 117, 115, 117, 116 ]
    # stop_sign = 80

    # data_index = [602, 607, 614, 620, 626 ]
    # model_index = [959, 964, 971, 977, 982 ]
    # state_index = [124, 119, 117, 119, 117 ]
    # stop_sign = 80

    seeds = [0, 10, 20, 30, 40]

    # Rollout settings shared by every seed.
    T = 40
    num_episodes = 100
    max_history = 1

    for seed, data_version, model_version, state_version in zip(
            seeds, data_index, model_index, state_index):

        #######################
        ###### artifacts ######
        #######################

        data_artifact = run.use_artifact(
            f'your-folder/data_config:v{data_version}', type='config')
        loadpath = data_artifact.download()
        dataset = _load_pickle(join(loadpath, 'data_config.pkl')).make()

        print('dataset: ', dataset)
        model_artifact = run.use_artifact(
            f'your-folder/model_config:v{model_version}', type='config')
        config_path = model_artifact.download()
        config = _load_pickle(join(config_path, 'model_config.pkl'))

        state_artifact = run.use_artifact(
            f'your-folder/state_48_seed{seed}:v{state_version}', type='model')
        state_path = state_artifact.download()
        # map_location lets a checkpoint saved on a GPU be loaded on a host
        # that only has the device requested in args.device.
        state = torch.load(
            join(state_path, f'state_48_seed{seed}.pt'),
            map_location=args.device,
        )

        gpt = config()
        gpt.to(args.device)
        gpt.load_state_dict(state, strict=True)
        print(f'\n[ utils/serialization ] Loaded config from {config_path}\n')
        print(config)

        #######################
        ####### models ########
        #######################

        # Previous local-disk loading path, superseded by the wandb artifacts
        # above and kept for reference:
        # args.logbase = args.logbase + 'iql/'
        # args.exp_name = args.gpt_loadpath + '/' + args.exp_name
        # args.savepath = join(args.logbase, args.dataset, args.exp_name)
        # dataset = utils.load_from_config(args.logbase, args.dataset, args.gpt_loadpath,
        #         'data_config.pkl')

        # gpt, gpt_epoch = utils.load_model(args.logbase, args.dataset, args.gpt_loadpath,
        #         epoch=args.gpt_epoch, device=args.device)

        #######################
        ##### environment #####
        #######################

        random.seed(seed)
        np.random.seed(seed)
        env = ToyCar(stop_sign = stop_sign, seed = seed)

        discount = dataset.discount
        observation_dim = dataset.observation_dim
        action_dim = dataset.action_dim

        #######################
        ###### main loop ######
        #######################

        returns = []
        successes = []
        percent_reward = []

        gpt.eval()

        policy = BTPolicy(
            gpt,
            observation_dim,
            action_dim,
            discount,
            max_history=max_history,
            device=args.device)

        # The episode counter is unused; it no longer shadows the seed index.
        for _episode in tqdm(range(num_episodes)):
            observation = env.reset(testing=True)
            policy.reset()
            total_reward = 0

            for t in range(T):

                action = policy(observation)

                ## execute action in environment
                next_observation, reward, terminal, crash, info, ego_x, other_x = env.test_step(False, action)

                ## optionally hide part of the next observation
                _mask_observation(next_observation, args.visiable, vis_threshold)

                ## update return
                total_reward += reward

                # Only episodes that terminate within T steps are recorded;
                # an episode that runs out of steps adds nothing to the stats.
                if terminal:
                    returns.append(total_reward)
                    successes.append(info['success'])
                    percent_reward.append(ego_x/other_x)
                    break

                observation = next_observation

        print('Percent Reward {0}'.format(np.mean(percent_reward)))
        print('Mean Return {0}'.format(np.mean(returns)))
        print('Std Return {0}'.format(np.std(returns)))
        print('Mean Success Rate {0}'.format(np.mean(successes)))

        utils.serialization.mkdir(args.savepath)
        # Cast to built-in floats so the JSON encoder never sees a NumPy
        # scalar type it cannot serialise (e.g. float32).
        json_data = {
            'Mean Return': float(np.mean(returns)),
            'Std Return': float(np.std(returns)),
            'Mean Success Rate': float(np.mean(successes)),
        }
        # Every seed used to overwrite the same 'rollout.json', so only the
        # last seed survived. Write a per-seed file as well; 'rollout.json'
        # is still written (ending up with the last seed, as before).
        for filename in (f'rollout_seed{seed}.json', 'rollout.json'):
            with open(join(args.savepath, filename), 'w') as handle:
                json.dump(json_data, handle, indent=2, sort_keys=True)

# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == '__main__':
    evaluation()
