import json
from os.path import join
from tqdm import tqdm
import numpy as np
import random

from src.policies.splt_bt_policy import SPLTBTPolicy
import src.utils as utils
from src.utils.serialization import mkdir
from src.envs.toy_car.toy_car import ToyCar
import wandb
import pickle
import torch




def segment(observations, terminals, max_path_length, observation_dim, pad_value=-1000):
    """
        Segment `observations` into trajectories according to `terminals`.

        A trajectory ends after every step whose terminal flag is truthy, and also
        whenever it reaches `max_path_length` steps. The trajectories are then
        right-padded with `pad_value` into a single array.

        Args:
            observations: sequence of per-step values; each trajectory is reshaped
                to `(-1, observation_dim)` when copied into the padded array.
            terminals: per-step done flags, same length as `observations`. Numpy
                arrays/scalars with one element or plain Python bools are accepted.
            max_path_length: maximum number of steps in one trajectory.
            observation_dim: size of the last axis of the padded output.
            pad_value: filler for steps past the end of a trajectory (default -1000).

        Returns:
            (trajectories_pad, path_lengths): an array of shape
            (n_trajectories, max_path_length, observation_dim) and the list of true
            trajectory lengths. Empty input yields a (0, max_path_length,
            observation_dim) array and an empty list.
    """
    assert len(observations) == len(terminals)

    trajectories = [[]]
    curr_len = 0
    for obs, term in zip(observations, terminals):
        trajectories[-1].append(obs)
        curr_len += 1
        # np.asarray lets plain Python bools work as well as 1-element numpy arrays.
        if np.asarray(term).squeeze() or curr_len >= max_path_length:
            trajectories.append([])
            curr_len = 0

    # Drop the empty trajectory opened after the final split.
    if not trajectories[-1]:
        trajectories.pop()

    if not trajectories:
        # Nothing to pad; still return the documented output shape.
        empty = np.empty((0, max_path_length, observation_dim),
                         dtype=np.asarray(observations).dtype)
        return empty, []

    ## list of arrays because trajectories lengths will be different
    trajectories = [np.stack(traj, axis=0) for traj in trajectories]
    path_lengths = [len(traj) for traj in trajectories]

    ## pad trajectories to be of equal length
    trajectories_pad = np.ones(
        (len(trajectories), max_path_length, observation_dim),
        dtype=trajectories[0].dtype,
    ) * pad_value

    for i, (traj, path_length) in enumerate(zip(trajectories, path_lengths)):
        trajectories_pad[i, :path_length, :] = traj.reshape(-1, observation_dim)

    return trajectories_pad, path_lengths

class Parser(utils.Parser):
    """Command-line options for this evaluation script.

    Extends the project's `utils.Parser` with defaults for `dataset` and
    `config`. NOTE(review): `evaluation()` in this file hard-codes the stop30
    data artifact and never reads `args.dataset` — confirm this default is
    still meaningful. `config` presumably names the config module used by
    `utils.Parser.parse_args`; verify against that class.
    """
    dataset: str = 'idm-uniform07'
    config: str = 'config.offline'

#######################
######## setup ########
#######################
# Hyper-parameters logged as the wandb run config (passed to wandb.init in
# evaluation()). Key order is preserved by dict().
config_defaults = dict(
    policy_type="MlpPolicy",
    total_timesteps=300_000,
    env_name="diagnosis_RL_dqn",
    batch_size=64,  # was 8 in an earlier experiment
    learning_rate=1e-4,
    n_layer=3,
    embed_dim=128,
)


def _load_pickle(path):
    """Unpickle and return the object stored at `path`, closing the file afterwards.

    NOTE(review): unpickling executes arbitrary code; these files come from this
    project's wandb artifacts and must be trusted.
    """
    with open(path, 'rb') as f:
        return pickle.load(f)


def _load_seed_artifacts(run, data_version, model_version, state_version, seed, device):
    """Download one seed's dataset config, model config and weights from wandb.

    Args:
        run: active wandb run used to resolve the artifacts.
        data_version, model_version, state_version: version numbers of the
            `data_config`, `model_config` and `state_48_seed<seed>` artifacts.
        seed: training seed; selects the weights artifact.
        device: torch device the model is moved to and the weights are mapped onto.

    Returns:
        (dataset, gpt): the object returned by the pickled data config's `make()`,
        and the model built from the pickled model config with the downloaded
        state dict loaded (strict=True).
    """
    data_artifact = run.use_artifact('your-folder/data_config:v' + str(data_version), type='config')
    loadpath = data_artifact.download()
    dataset = _load_pickle(loadpath + '/data_config.pkl').make()
    print('dataset: ', dataset)

    model_artifact = run.use_artifact('your-folder/model_config:v' + str(model_version), type='config')
    config_path = model_artifact.download()
    config = _load_pickle(config_path + '/model_config.pkl')

    state_name = 'state_48_seed' + str(seed)
    state_artifact = run.use_artifact('your-folder/' + state_name + ':v' + str(state_version), type='model')
    state_path = state_artifact.download()
    # map_location lets a checkpoint saved on one device (e.g. GPU) load on another.
    state = torch.load(state_path + '/' + state_name + '.pt', map_location=device)

    gpt = config()
    gpt.to(device)
    gpt.load_state_dict(state, strict=True)
    print(f'\n[ utils/serialization ] Loaded config from {config_path}\n')
    print(config)
    return dataset, gpt


def evaluation():
    """Roll out trained SPLT-BT policies on ToyCar for several seeds and save the results.

    For each seed this function does the following:
      * downloads the matching dataset config, model config and weights from wandb;
      * runs `num_episodes` episodes whose initial states come from the offline
        dataset;
      * saves every step's plan and candidates (`traj_XXXX/NNNN.npz`) for plotting;
      * prints and writes summary statistics to `rollout.json`.

    Outputs for seed `s` go under `<args.savepath>/seed_<s>/`, so seeds do not
    overwrite each other's results.
    """
    args = Parser().parse_args('plan')

    utils.set_device(args.device)
    run = wandb.init(
        config=config_defaults,
        sync_tensorboard=False,  # project=wandb.config.base['WANDB_PROJECT'],
        save_code=True,  # name = args.exp_name
    )

    # Offline dataset whose trajectories supply the episodes' initial states.
    # Other uploaded variants: stop80, stop70, stop60, stop50, stop40, stop_random.
    # NOTE(review): the artifact-index table below presumably has to be switched
    # together with this name (they share `stop_sign`) — confirm.
    dataset_name = 'stop30'
    artifact = run.use_artifact('your-folder/data_' + dataset_name + ':v0', type='dataset')
    wandb_dataset_dir = artifact.download()
    data_file = wandb_dataset_dir + '/data_' + dataset_name + '.npz'

    with open(data_file, 'rb') as f:
        data = dict(np.load(f))

    # Artifact versions of data_config / model_config / state_48_seed<seed>,
    # one entry per seed in `seeds` (same order), plus the ToyCar stop-sign setting.
    data_index = [666, 674, 682, 689, 692]
    model_index = [1010, 1018, 1028, 1035, 1039]
    state_index = [137, 133, 133, 135, 134]
    stop_sign = 30

    # data_index = [665, 673, 680, 688, 691 ]
    # model_index = [1009, 1017, 1026, 1034, 1038 ]
    # state_index = [136, 131, 132, 134, 133 ]
    # stop_sign = 40

    # data_index = [82, 84, 86, 86, 86 ]
    # model_index = [102, 104, 106, 107, 108 ]
    # state_index = [19, 19, 19, 19, 19 ]
    # stop_sign = 50

    # data_index = [87, 87, 89, 91, 93 ]
    # model_index = [109, 110, 112, 114, 116 ]
    # state_index = [20, 20, 20, 20, 20 ]
    # stop_sign = 60

    # data_index = [6, 11, 15, 21, 26 ]
    # model_index = [9, 19, 25, 32, 39 ]
    # state_index = [2,3,3,4,5 ]
    # stop_sign = 70

    # data_index = [31, 39, 43, 45, 52 ]
    # model_index = [44, 53, 57, 59, 66 ]
    # state_index = [9,9,9,9,10 ]
    # stop_sign = 80

    # data_index = [695, 695, 697, 699, 701 ]
    # model_index = [1048, 1049, 1051, 1054, 1055 ]
    # state_index = [140,135,134,136,135 ]
    # stop_sign = 80

    seeds = [0, 10, 20, 30, 40]

    T = 40  # maximum environment steps per evaluation episode
    num_episodes = 100

    # Split the offline data into trajectories of at most 40 steps. The first
    # observation of trajectory k is the initial state of episode k. This is
    # independent of the seed, so it is computed once.
    # ToyCar observations are 4-d: (ego_x, ego_vel, other_x, other_vel).
    obs_segmented, *_ = segment(data['observations'], data['dones'], 40, 4)

    for seed, data_version, model_version, state_version in zip(
            seeds, data_index, model_index, state_index):
        dataset, gpt = _load_seed_artifacts(
            run, data_version, model_version, state_version, seed, args.device)

        random.seed(seed)
        np.random.seed(seed)
        env = ToyCar(stop_sign=stop_sign, seed=seed)

        gpt.eval()
        policy = SPLTBTPolicy(
            gpt,
            args.horizon,
            dataset.observation_dim,
            dataset.action_dim,
            dataset.discount,
            max_history=args.max_context_transitions,
            device=args.device,
            agg='min',
        )

        # Per-seed output directory so seeds do not overwrite each other's
        # plans and rollout.json.
        seed_savepath = join(args.savepath, 'seed_{0}'.format(seed))
        mkdir(seed_savepath)

        returns = []
        successes = []
        percent_reward = []

        for episode in tqdm(range(num_episodes)):
            traj_path = join(seed_savepath, 'traj_{0:04d}'.format(episode))
            mkdir(traj_path)
            # Reset the env, then overwrite its start state with the dataset's
            # initial observation for this episode.
            env.reset(testing=True)
            policy.reset()
            total_reward = 0

            observation = obs_segmented[episode, 0, :]
            env.ego_x = observation[0]
            env.ego_vel = observation[1]
            env.other_x = observation[2]
            env.other_vel = observation[3]

            for t in range(T):
                action, sequence, candidates, world_index, policy_index = policy(
                    observation, max_horizon=T - t, return_plans=True)

                ## execute action in environment
                # False: terminate early due to crash; True: no terminal
                next_observation, reward, terminal, crash, info, ego_x, other_x = env.test_step(False, action)

                # saving predictions for plotting
                gt_sequence = np.concatenate([observation, action, [reward]])
                filename = join(traj_path, '{0:04d}.npz'.format(t))
                np.savez(
                    filename,
                    plan=sequence,
                    candidates=candidates,
                    gt=gt_sequence,
                    world_index=world_index,
                    policy_index=policy_index,
                    world_dim=gpt.world_latent_dim,
                    policy_dim=gpt.policy_latent_dim)

                ## update return
                total_reward += reward

                # NOTE(review): episodes that do not terminate within T steps are
                # not recorded in any statistic below — confirm this is intended.
                if terminal:
                    returns.append(total_reward)
                    successes.append(info['success'])
                    percent_reward.append(ego_x / other_x)
                    break

                observation = next_observation

        print('Percent Reward {0}'.format(np.mean(percent_reward)))
        print('Mean Return {0}'.format(np.mean(returns)))
        print('Std Return {0}'.format(np.std(returns)))
        print('Mean Success Rate {0}'.format(np.mean(successes)))

        json_path = join(seed_savepath, 'rollout.json')
        json_data = {
            'Mean Return': np.mean(returns),
            'Std Return': np.std(returns),
            'Mean Success Rate': np.mean(successes),
        }
        with open(json_path, 'w') as f:
            json.dump(json_data, f, indent=2, sort_keys=True)

    run.finish()

if __name__ == '__main__':
    # Run the evaluation only when executed as a script, not when imported.
    evaluation()
