import json
from os.path import join
from tqdm import tqdm
import numpy as np
import random


import src.utils as utils
from src.policies.tt_policy import TTPolicy
from src.utils.serialization import mkdir
from src.envs.toy_car.toy_car import ToyCar
import wandb
import pickle
import torch


def segment(observations, terminals, max_path_length, observation_dim):
    """
        Segment `observations` into trajectories according to `terminals`
        and pad every trajectory to a common length.

        A trajectory is closed after a step whose terminal flag is truthy, or
        as soon as it holds `max_path_length` steps, whichever comes first.

        Args:
            observations: sequence of per-step values; each step is reshaped
                to `observation_dim` entries when it is written to the output.
            terminals: sequence of per-step flags, same length as
                `observations`; every element must provide `.squeeze()`
                (e.g. numpy scalars or arrays).
            max_path_length: maximum number of steps per trajectory and size
                of the padded time axis.
            observation_dim: size of the last axis of the padded output.

        Returns:
            trajectories_pad: array of shape
                `(n_trajectories, max_path_length, observation_dim)`; steps
                past the end of a trajectory are filled with -1000.
            path_lengths: list with the unpadded length of each trajectory.

        Raises:
            AssertionError: if `observations` and `terminals` differ in length.
    """
    assert len(observations) == len(terminals)

    trajectories = [[]]
    for obs, term in zip(observations, terminals):
        trajectories[-1].append(obs)
        ## close the current trajectory on a terminal flag or once it is full
        if term.squeeze() or len(trajectories[-1]) >= max_path_length:
            trajectories.append([])

    ## the last step may have closed a trajectory, leaving an empty one behind
    if not trajectories[-1]:
        trajectories.pop()

    ## empty input: return an empty batch instead of failing on `trajectories[0]`
    if not trajectories:
        dtype = np.asarray(observations).dtype
        return np.empty((0, max_path_length, observation_dim), dtype=dtype), []

    ## list of arrays because trajectory lengths will be different
    trajectories = [np.stack(traj, axis=0) for traj in trajectories]
    path_lengths = [len(traj) for traj in trajectories]

    ## pad trajectories to be of equal length; -1000 marks padded steps
    trajectories_pad = np.ones(
        (len(trajectories), max_path_length, observation_dim),
        dtype=trajectories[0].dtype,
    ) * (-1000)

    for i, (traj, path_length) in enumerate(zip(trajectories, path_lengths)):
        trajectories_pad[i, :path_length, :] = traj.reshape(-1, observation_dim)

    return trajectories_pad, path_lengths

class Parser(utils.Parser):
    """
        Command-line arguments for this script, declared on top of
        `utils.Parser` (defined outside this file).

        Declares two string fields with defaults: `dataset`
        ('idm-uniform07') and `config` ('config.offline').

        NOTE(review): this file never reads `args.dataset` or `args.config`
        directly; presumably `utils.Parser.parse_args` consumes them when
        resolving the experiment configuration -- confirm against the base
        class.
    """
    dataset: str = 'idm-uniform07'
    config: str = 'config.offline'

#######################
######## setup ########
#######################

## Static settings attached to the wandb run: `evaluation()` passes this
## dict as `config=` to `wandb.init`. Nothing else in this file reads it.
config_defaults = {
    "policy_type": "MlpPolicy",
    "total_timesteps": 300_000,
    "env_name": "diagnosis_RL_dqn",
    "batch_size": 64,  # 8
    "learning_rate": 1e-4,
    "n_layer": 3,
    "embed_dim": 128,
}


def _load_pickle(path):
    """
        Unpickle and return the object stored at `path`, closing the file.
    """
    ## NOTE(review): unpickling executes arbitrary code; only load artifacts
    ## that come from a trusted wandb project.
    with open(path, 'rb') as f:
        return pickle.load(f)


def _dump_json(payload, path):
    """
        Write `payload` to `path` as indented, key-sorted JSON.
    """
    with open(path, 'w') as f:
        json.dump(payload, f, indent=2, sort_keys=True)


def evaluation():
    """
        Roll out the planning policy in `ToyCar` for five seeds and report
        the results.

        For every seed this downloads the dataset config, model config and
        model weights from wandb, builds a `TTPolicy`, and runs 100 episodes
        of at most 40 steps. Each episode starts from the first observation
        of the matching trajectory in the offline dataset. Every step is
        saved under `<savepath>/traj_XXXX/TTTT.npz` (plan, candidates and
        ground-truth transition).

        After each seed the mean/std return, mean success rate and mean
        `ego_x / other_x` ratio are printed, and the first three are written
        to `<savepath>/rollout_seed<seed>.json` and `<savepath>/rollout.json`.
        Only episodes that report `terminal` within the step budget
        contribute to these statistics.
    """
    args = Parser().parse_args('tt_plan')

    utils.set_device(args.device)

    run = wandb.init(
        config=config_defaults,
        sync_tensorboard=False,
        save_code=True,
    )

    #######################
    ####### presets #######
    #######################

    ## wandb artifact versions, one entry per seed, for the active scenario
    stop_sign = 60
    data_index = [594, 621, 631, 638, 641]
    model_index = [954, 978, 984, 984, 984]
    state_index = [128, 123, 122, 124, 123]
    seeds = [0, 10, 20, 30, 40]

    ## Presets recorded for the other scenarios (kept for reference):
    ##   stop_sign | data_index                | model_index                    | state_index
    ##   30        | [842, 844, 844, 846, 848] | [1204, 1204, 1204, 1204, 1204] | [168, 163, 161, 163, 162]
    ##   40        | [841, 843, 845, 847, 849] | [1204, 1204, 1204, 1204, 1204] | [167, 162, 162, 164, 163]
    ##   50        | [615, 615, 630, 636, 640] | [972, 972, 984, 984, 984]      | [127, 122, 121, 123, 122]
    ##   70        | [595, 622, 632, 639, 642] | [954, 978, 984, 984, 984]      | [129, 124, 123, 125, 124]
    ##   80        | [596, 608, 628, 633, 635] | [954, 965, 984, 984, 984]      | [125, 120, 119, 121, 120]
    ##   80 (?)    | [597, 612, 629, 634, 631] | [954, 969, 984, 984, 984]      | [126, 121, 120, 122, 121]
    ## NOTE(review): the last preset was also labelled `stop_sign = 80` in the
    ## original script; it may belong to the `data_stop_random` dataset --
    ## confirm before using it.

    #######################
    ####### dataset #######
    #######################

    ## The dataset artifact follows the `data_stop<stop_sign>` naming scheme
    ## (variants 30/40/50/60/70/80 were listed in the original script). The
    ## `data_stop_random` variant does not fit this scheme and has to be
    ## selected by hand.
    dataset_name = f'data_stop{stop_sign}'
    artifact = run.use_artifact(f'your-folder/{dataset_name}:v0', type='dataset')
    wandb_dataset_dir = artifact.download()
    data_file = f'{wandb_dataset_dir}/{dataset_name}.npz'

    with open(data_file, 'rb') as f:
        data = dict(np.load(f))

    ## Initial states come from the offline data: trajectories of at most 40
    ## steps with 4 values per observation. This does not depend on the seed,
    ## so it is computed once.
    obs_segmented, *_ = segment(data['observations'], data['dones'], 40, 4)

    T = 40
    num_episodes = 100
    max_history = args.max_context_transitions

    for seed, data_version, model_version, state_version in zip(
            seeds, data_index, model_index, state_index):

        #######################
        ######## model ########
        #######################

        data_artifact = run.use_artifact(f'your-folder/data_config:v{data_version}', type='config')
        loadpath = data_artifact.download()
        dataset = _load_pickle(loadpath + '/data_config.pkl').make()

        print('dataset: ', dataset)
        model_artifact = run.use_artifact(f'your-folder/model_config:v{model_version}', type='config')
        config_path = model_artifact.download()
        config = _load_pickle(config_path + '/model_config.pkl')

        state_name = f'state_48_seed{seed}'
        state_artifact = run.use_artifact(f'your-folder/{state_name}:v{state_version}', type='model')
        state_path = state_artifact.download()
        ## map_location lets a checkpoint saved on GPU load on a CPU-only host
        state = torch.load(f'{state_path}/{state_name}.pt', map_location=args.device)

        gpt = config()
        gpt.to(args.device)
        gpt.load_state_dict(state, strict=True)
        print(f'\n[ utils/serialization ] Loaded config from {config_path}\n')
        print(config)

        #######################
        ##### environment #####
        #######################

        random.seed(seed)
        np.random.seed(seed)
        env = ToyCar(stop_sign = stop_sign, seed = seed)

        discretizer = dataset.discretizer
        discount = dataset.discount
        observation_dim = dataset.observation_dim
        action_dim = dataset.action_dim

        value_fn = lambda x: discretizer.value_fn(x, args.percentile)

        #######################
        ###### main loop ######
        #######################

        returns = []
        successes = []
        percent_reward = []

        gpt.eval()

        policy = TTPolicy(
            gpt,
            discretizer,
            args.horizon,
            args.beam_width,
            args.n_expand,
            value_fn,
            observation_dim,
            action_dim,
            discount,
            verbose=args.verbose,
            k_obs=args.k_obs,
            k_act=args.k_act,
            cdf_obs=args.cdf_obs,
            cdf_act=args.cdf_act,
            prefix_context=args.prefix_context,
            max_history=max_history,
            device=args.device)

        for episode in tqdm(range(num_episodes)):
            ## NOTE(review): the trajectory folders do not include the seed,
            ## so each seed overwrites the step files of the previous one.
            traj_path = join(args.savepath, 'traj_{0:04d}'.format(episode))
            mkdir(traj_path)

            ## the observation returned by reset is not used: the episode
            ## starts from the first state of the matching dataset trajectory
            env.reset(testing=True)
            policy.reset()
            total_reward = 0

            observation = obs_segmented[episode, 0, :]
            env.ego_x = observation[0]
            env.ego_vel = observation[1]
            env.other_x = observation[2]
            env.other_vel = observation[3]

            for t in range(T):

                action, sequence, candidates = policy(observation, max_horizon=T-t, return_plans=True)

                ## execute action in environment
                ## (`crash` is returned by the env but not aggregated here)
                next_observation, reward, terminal, crash, info, ego_x, other_x = env.test_step(False, action)

                ## saving predictions for plotting
                gt_sequence = np.concatenate([observation, action, [reward]])
                filename = join(traj_path, '{0:04d}.npz'.format(t))
                np.savez(
                    filename,
                    plan=sequence,
                    candidates=candidates,
                    gt=gt_sequence)

                ## update return
                total_reward += reward

                policy.update_context(observation, action, reward)

                if terminal:
                    returns.append(total_reward)
                    successes.append(info['success'])
                    percent_reward.append(ego_x/other_x)
                    break

                observation = next_observation

        print('Percent Reward {0}'.format(np.mean(percent_reward)))
        print('Mean Return {0}'.format(np.mean(returns))) # 70:59.22772407494559, 60:
        print('Std Return {0}'.format(np.std(returns))) # 70:7.430284679723886. 60:
        print('Mean Success Rate {0}'.format(np.mean(successes))) # 1.0

        mkdir(args.savepath)
        json_data = {
            'Mean Return': np.mean(returns),
            'Std Return': np.std(returns),
            'Mean Success Rate': np.mean(successes),
        }
        ## `rollout.json` is rewritten for every seed, so it ends up holding
        ## the last seed only; the per-seed file keeps each summary.
        _dump_json(json_data, join(args.savepath, f'rollout_seed{seed}.json'))
        _dump_json(json_data, join(args.savepath, 'rollout.json'))



if __name__ == '__main__':
    evaluation()
