import json
from os.path import join
from tqdm import tqdm
import numpy as np
import random


from src.policies.dt_policy import DTPolicy
import src.utils as utils
from src.envs.toy_car.toy_car import ToyCar
import wandb
import pickle
import torch


def segment(observations, terminals, max_path_length, observation_dim):
    """
        segment `observations` into trajectories according to `terminals`

        A trajectory is closed right after a step whose terminal flag is
        truthy, or as soon as it holds `max_path_length` steps, whichever
        comes first. A trailing trajectory that was never closed is kept.

        Args:
            observations: sequence of per-step values; every step must hold
                `observation_dim` entries (it is reshaped to that width).
            terminals: sequence of per-step flags with the same length as
                `observations`; each flag may be a plain bool or a size-1 array.
            max_path_length: maximum number of steps per trajectory and length
                of the padded time axis.
            observation_dim: number of entries per step.

        Returns:
            A tuple `(trajectories_pad, path_lengths)`:
            `trajectories_pad` has shape
            [ n_trajectories x max_path_length x observation_dim ] with unused
            steps filled with -1000, and `path_lengths` lists the true length
            of every trajectory. Empty input yields an array with zero
            trajectories and an empty list.
    """
    assert len(observations) == len(terminals)

    pad_value = -1000

    trajectories = [[]]
    for obs, term in zip(observations, terminals):
        trajectories[-1].append(obs)
        ## np.squeeze (unlike `term.squeeze()`) also accepts plain Python bools
        if np.squeeze(term) or len(trajectories[-1]) >= max_path_length:
            trajectories.append([])

    ## the loop always opens a fresh trajectory after closing one; drop it if unused
    if not trajectories[-1]:
        trajectories.pop()

    ## nothing to pad: return an empty batch instead of failing on trajectories[0]
    if not trajectories:
        empty = np.empty(
            (0, max_path_length, observation_dim),
            dtype=np.asarray(observations).dtype,
        )
        return empty, []

    ## list of arrays because trajectories lengths will be different
    trajectories = [np.stack(traj, axis=0) for traj in trajectories]
    path_lengths = [len(traj) for traj in trajectories]

    ## pad trajectories to be of equal length
    trajectories_pad = np.ones(
        (len(trajectories), max_path_length, observation_dim),
        dtype=trajectories[0].dtype,
    ) * pad_value

    for i, (traj, path_length) in enumerate(zip(trajectories, path_lengths)):
        trajectories_pad[i, :path_length, :] = traj.reshape(-1, observation_dim)

    return trajectories_pad, path_lengths


class Parser(utils.Parser):
    """Command-line options for this script, layered on top of `utils.Parser`."""
    # NOTE(review): only referenced by commented-out code in this file; the
    # active code always loads the `data_stop30` artifact.
    dataset: str = 'idm-uniform07'
    # Compared against 'constant', 'zero_value' and 'random_noise' in
    # `evaluation`; any other value (including this default) leaves the
    # observations unmodified.
    visiable: str = 'whole'
    # Not read in this file; presumably consumed by `utils.Parser` -- TODO confirm.
    config: str = 'config.offline'


# seed, trained model, dataset

#######################
######## setup ########
#######################
# Static run configuration handed to `wandb.init(config=...)` in `evaluation`.
config_defaults = {
    'policy_type': 'MlpPolicy',
    'total_timesteps': 300_000,
    'env_name': 'diagnosis_RL_dqn',
    'batch_size': 64,  # 8,
    'learning_rate': 1e-4,
    'n_layer': 3,
    'embed_dim': 128,
}
def _load_pickle(path):
    """Unpickle and return the object stored at `path`, closing the file afterwards."""
    # NOTE(review): pickle can execute arbitrary code while loading; only use
    # this on artifacts from a trusted source.
    with open(path, 'rb') as f:
        return pickle.load(f)


def _apply_visibility(next_observation, mode, vis_threshold):
    """Overwrite entries 2 and 3 of `next_observation` in place according to `mode`.

    The overwrite only happens when entry 2 minus entry 0 is at least
    `vis_threshold`:
      * 'constant'     -> entry 2 = entry 0 + vis_threshold, entry 3 = entry 1
      * 'zero_value'   -> entries 2 and 3 = 0
      * 'random_noise' -> entry 2 = vis_threshold + 9 * r1, entry 3 = 10 * r2,
                          with r1, r2 drawn from `random.random()`
    Any other mode leaves the observation untouched.

    NOTE(review): presumably the layout is (ego_x, ego_vel, other_x, other_vel),
    matching how the start state is written into the environment -- confirm
    against `ToyCar.test_step`.
    """
    if mode == 'constant':
        print('[plan/visiable] constant')

        if next_observation[2] - next_observation[0] >= vis_threshold:
            next_observation[2] = next_observation[0] + vis_threshold
            next_observation[3] = next_observation[1]
            print('next obeservation: ', next_observation)

    elif mode == 'zero_value':
        print('[plan/visiable] zero value')
        if next_observation[2] - next_observation[0] >= vis_threshold:
            next_observation[2] = 0
            next_observation[3] = 0
            print('next obeservation: ', next_observation)

    elif mode == 'random_noise':
        print('[train/visiable] random noise')
        if next_observation[2] - next_observation[0] >= vis_threshold:
            next_observation[2] = vis_threshold + 9 * random.random()
            next_observation[3] = 10 * random.random()
            print('next obeservation: ', next_observation)


def evaluation():
    """Roll out the downloaded checkpoints in `ToyCar` and store the metrics.

    Steps:
      1. Parse the arguments, open a wandb run and download the
         `data_stop30` dataset artifact.
      2. For every seed, download the matching dataset config, model config
         and weights artifacts, rebuild the model and wrap it in `DTPolicy`.
      3. Run 100 episodes of at most 40 steps; episode `k` starts from the
         first observation of the `k`-th logged trajectory.
      4. Print mean/std return, mean success rate and mean `ego_x / other_x`,
         and write the first three to `<savepath>/rollout.json` (overwritten
         by each seed, as before) and to `<savepath>/rollout_seed<seed>.json`
         (so no seed's result is lost).

    Episodes that do not report `terminal` within the step budget contribute
    nothing to the statistics.
    """
    args = Parser().parse_args('bc_plan')

    utils.set_device(args.device)
    run = wandb.init(
        config=config_defaults,
        sync_tensorboard=False,  # project=wandb.config.base['WANDB_PROJECT'],
        save_code=True,  # name = args.exp_name
    )

    # if args.dataset == 'stop80':
        # artifact = run.use_artifact('your-folder/data_stop80:v0', type='dataset')
        # wandb_dataset_dir = artifact.download()
        # data_file = wandb_dataset_dir +'/data_stop80.npz'
    # elif args.dataset == 'stop70':
    #     artifact = run.use_artifact('your-folder/data_stop70:v0', type='dataset')
    #     wandb_dataset_dir = artifact.download()
    #     data_file = wandb_dataset_dir +'/data_stop70.npz'
    # elif args.dataset == 'stop60':
        # artifact = run.use_artifact('your-folder/data_stop60:v0', type='dataset')
        # wandb_dataset_dir = artifact.download()
        # data_file = wandb_dataset_dir +'/data_stop60.npz'
    # elif args.dataset == 'stop50':
        # artifact = run.use_artifact('your-folder/data_stop50:v0', type='dataset')
        # wandb_dataset_dir = artifact.download()
        # data_file = wandb_dataset_dir +'/data_stop50.npz'
    # elif args.dataset == 'stop40':
        # artifact = run.use_artifact('your-folder/data_stop40:v0', type='dataset')
        # wandb_dataset_dir = artifact.download()
        # data_file = wandb_dataset_dir +'/data_stop40.npz'
    # elif args.dataset == 'stop30':
    artifact = run.use_artifact('your-folder/data_stop30:v0', type='dataset')
    wandb_dataset_dir = artifact.download()
    data_file = wandb_dataset_dir + '/data_stop30.npz'
    # elif args.dataset == 'stop_random':
        # artifact = run.use_artifact('your-folder/data_stop_random:v0', type='dataset')
        # wandb_dataset_dir = artifact.download()
        # data_file = wandb_dataset_dir +'/data_stop_random.npz'

    with open(data_file, 'rb') as f:
        data = dict(np.load(f))

    # Artifact versions per seed (position k belongs to seeds[k]). The
    # commented tables are alternative experiment settings kept for reference.

    # constant
    # data_index = [781, 789, 797, 805, 811 ]
    # model_index = [1144, 1152, 1160, 1168, 1174 ]
    # state_index = [155, 152, 151, 153, 152 ]
    # stop_sign = 30

    # data_index = [782, 790, 802, 810, 816 ]
    # model_index = [1145, 1153, 1165, 1173, 1179 ]
    # state_index = [156, 153, 152, 154, 153 ]
    # stop_sign = 40

    # data_index = [518, 533, 548, 563, 578 ]
    # model_index = [878, 893, 908, 923, 938 ]
    # state_index = [105, 100, 99, 101, 100 ]
    # stop_sign = 50

    # data_index = [519, 534, 549, 564, 579 ]
    # model_index = [879, 894, 909, 924, 939 ]
    # state_index = [106, 101, 100, 102, 101 ]
    # stop_sign = 60

    # data_index = [520, 535, 550, 565, 580 ]
    # model_index = [880, 895, 910, 925, 940 ]
    # state_index = [106, 102, 101, 103, 102 ]
    # stop_sign = 70

    # data_index = [521, 536, 551, 566, 581 ]
    # model_index = [881, 896, 911, 926, 941 ]
    # state_index = [108, 103, 102, 104, 103 ]
    # stop_sign = 80

    # data_index = [522, 537, 552, 567, 582 ]
    # model_index = [882, 897, 911, 927, 942 ]
    # state_index = [109, 104, 103, 105, 104 ]
    # stop_sign = 80

    # whole
    # data_index = [647, 647, 653, 659, 661 ] # old
    # model_index = [987, 991, 997, 1003, 1005 ]
    # state_index = [130, 127, 126, 128, 127 ]
    # stop_sign = 30

    data_index = [663, 667, 669, 671, 675]
    model_index = [1007, 1011, 1013, 1015, 1019]
    state_index = [134, 129, 128, 130, 129]
    stop_sign = 30

    # data_index = [644, 649, 654, 660, 662 ] # old
    # model_index = [988, 993, 998, 1004, 1006 ]
    # state_index = [132, 128, 127, 129, 128 ]
    # stop_sign = 40

    # data_index = [664, 668, 670, 672, 672 ]
    # model_index = [1008, 1012, 1014, 1016, 1020 ]
    # state_index = [135, 130, 129, 131, 130 ]
    # stop_sign = 40

    # data_index = [66, 69, 72, 75, 79 ]
    # model_index = [83, 86, 88, 92, 96 ]
    # state_index = [16, 16, 16, 16, 16 ]
    # stop_sign = 50

    # data_index = [65, 68, 71, 74, 78 ]
    # model_index = [82, 85, 88, 91, 95 ]
    # state_index = [15, 15, 15, 15, 15 ]
    # stop_sign = 60

    # data_index = [4, 4, 4, 10, 12 ]
    # model_index = [7, 11, 12, 18, 21 ]
    # state_index = [0,0,0,0,0 ]
    # stop_sign = 70

    # data_index = [64, 67, 70, 73, 77 ]
    # model_index = [81, 84, 87, 90, 94 ]
    # state_index = [14,14,14,14,14 ]
    # stop_sign = 80

    # data_index = [76, 80, 80, 80, 80 ]
    # model_index = [93, 97, 98, 99, 100 ]
    # state_index = [17,17,17,17,17 ]
    # stop_sign = random

    seeds = [0, 10, 20, 30, 40]

    horizon = 40          # maximum number of environment steps per episode
    num_episodes = 100
    vis_threshold = 25

    # The logged start states depend only on `data`, so split them once
    # instead of once per seed.
    obs_segmented, _ = segment(data['observations'], data['dones'], 40, 4)

    for seed, data_version, model_version, state_version in zip(
            seeds, data_index, model_index, state_index):
        data_artifact = run.use_artifact(
            'your-folder/data_config:v' + str(data_version), type='config')
        loadpath = data_artifact.download()
        dataset = _load_pickle(loadpath + '/data_config.pkl').make()

        print('dataset: ', dataset)
        model_artifact = run.use_artifact(
            'your-folder/model_config:v' + str(model_version), type='config')
        config_path = model_artifact.download()
        config = _load_pickle(config_path + '/model_config.pkl')

        state_artifact = run.use_artifact(
            'your-folder/state_48_seed' + str(seed) + ':v' + str(state_version),
            type='model')
        state_path = state_artifact.download()
        # map_location lets a checkpoint saved on another device load here
        state = torch.load(
            state_path + '/state_48_seed' + str(seed) + '.pt',
            map_location=args.device)

        gpt = config()
        gpt.to(args.device)
        gpt.load_state_dict(state, strict=True)
        print(f'\n[ utils/serialization ] Loaded config from {config_path}\n')
        print(config)

        # #######################
        # ####### models ########
        # #######################

        # args.logbase = args.logbase + 'dt/' # ~/SPLT-transformer/logs/bt/
        # args.exp_name = args.gpt_loadpath + '/' + args.exp_name
        # args.savepath = join(args.logbase, args.dataset, args.exp_name)
        # print('logbase: ', args.logbase)
        # print('dataset: ', args.dataset)
        # print('savepath: ', args.savepath)
        # # dataset = utils.load_from_config(args.logbase, args.dataset, args.gpt_loadpath,
        # #         'data_config.pkl')

        # # gpt, gpt_epoch = utils.load_model(args.logbase, args.dataset, args.gpt_loadpath,
        # #         epoch=args.gpt_epoch, device=args.device)

        random.seed(seed)
        np.random.seed(seed)
        env = ToyCar(stop_sign=stop_sign, seed=seed)

        discount = dataset.discount
        observation_dim = dataset.observation_dim
        action_dim = dataset.action_dim
        max_return = dataset.get_max_return()

        print('max return: ', max_return)

        #######################
        ###### main loop ######
        #######################

        returns = []
        successes = []
        percent_reward = []

        max_history = args.max_context_transitions

        gpt.eval()

        policy = DTPolicy(
            gpt,
            max_return * 0.85,
            observation_dim,
            action_dim,
            discount,
            max_history=max_history,
            device=args.device)

        for episode in tqdm(range(num_episodes)):
            env.reset(testing=True)
            policy.reset()
            total_reward = 0

            # Start from the first logged state of this trajectory; copy so the
            # shared array can never be modified through `observation`.
            observation = obs_segmented[episode, 0, :].copy()
            env.ego_x = observation[0]
            env.ego_vel = observation[1]
            env.other_x = observation[2]
            env.other_vel = observation[3]

            for _ in range(horizon):

                action = policy(observation)

                ## execute action in environment
                next_observation, reward, terminal, _crash, info, ego_x, other_x = env.test_step(False, action)
                # False: terminate early due to crush
                # True: no terminal

                _apply_visibility(next_observation, args.visiable, vis_threshold)

                policy.update_context(observation, action, reward)

                ## update return
                total_reward += reward

                if terminal:
                    returns.append(total_reward)
                    successes.append(info['success'])
                    percent_reward.append(ego_x / other_x)
                    break

                observation = next_observation

        mean_return = np.mean(returns)
        std_return = np.std(returns)
        mean_success = np.mean(successes)

        print('Percent Reward {0}'.format(np.mean(percent_reward)))
        print('Mean Return {0}'.format(mean_return))
        print('Std Return {0}'.format(std_return))
        print('Mean Success Rate {0}'.format(mean_success))

        utils.serialization.mkdir(args.savepath)
        # float() because json cannot serialise non-float64 numpy scalars
        json_data = {
            'Mean Return': float(mean_return),
            'Std Return': float(std_return),
            'Mean Success Rate': float(mean_success),
        }
        # rollout.json keeps its original meaning (latest seed); the per-seed
        # copy preserves the results that would otherwise be overwritten.
        for filename in ('rollout.json', 'rollout_seed' + str(seed) + '.json'):
            with open(join(args.savepath, filename), 'w') as f:
                json.dump(json_data, f, indent=2, sort_keys=True)

# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == '__main__':
    evaluation()
