import datetime
import os
import pickle
from typing import Tuple

import gym
import numpy as np
from tqdm import tqdm
from absl import app, flags
from ml_collections import config_flags
from tensorboardX import SummaryWriter

import wrappers as w
from dataset_utils import D4RLDataset, reward_from_preference, reward_from_preference_transformer, split_into_trajectories
from evaluation import evaluate,original_evaluate
from learner import Learner

# Set at import time, before any training code runs.
# NOTE(review): presumably caps the fraction of GPU memory that XLA/JAX
# preallocates at 40% -- confirm against the JAX GPU memory allocation docs.
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '.40'

FLAGS = flags.FLAGS

# --- Command-line flags ------------------------------------------------------
# Environment name; substrings 'metaworld' and 'antmaze' select special-case
# code paths in make_env_and_dataset() and main().
flags.DEFINE_string('env_name', 'halfcheetah-expert-v2', 'Environment name.')
# NOTE: within this file `save_dir` is only passed to os.makedirs() in main();
# the SummaryWriter log directory is built from a separate hard-coded prefix.
flags.DEFINE_string('save_dir', './runs/', 'Tensorboard logging dir.')
flags.DEFINE_integer('seed', 42, 'Random seed.')
flags.DEFINE_integer('eval_episodes', 10,
                     'Number of episodes used for evaluation.')
flags.DEFINE_integer('log_interval', 5000, 'Logging interval.')
flags.DEFINE_integer('eval_interval', 5000, 'Eval interval.')
# Also reused as the batch size for reward relabeling in make_env_and_dataset().
flags.DEFINE_integer('batch_size', 256, 'Mini batch size.')
flags.DEFINE_integer('max_steps', int(1e6), 'Number of training steps.')
flags.DEFINE_boolean('tqdm', True, 'Use tqdm progress bar.')
# When set, dataset rewards are replaced by the output of a learned reward
# model and then normalized (see make_env_and_dataset()).
flags.DEFINE_boolean('use_reward_model', False, 'Use reward model for relabeling reward.')
# Only the value "MR" is special-cased in this file; any other value takes the
# transformer relabeling path.
flags.DEFINE_string('model_type', 'MLP', 'type of reward model.')
# NOTE: `ckpt_dir` is not read anywhere in this file; initialize_model() loads
# from a hard-coded './saved_model/' directory instead.
flags.DEFINE_string('ckpt_dir',
                    './logs/pref_reward',
                    'ckpt path for reward model.')
# NOTE: `comment` is not read anywhere in this file.
flags.DEFINE_string('comment',
                    'base',
                    'comment for distinguishing experiments.')
# The three flags below are forwarded to reward_from_preference_transformer().
flags.DEFINE_integer('seq_len', 25, 'sequence length for relabeling reward in Transformer.')
flags.DEFINE_bool('use_diff', False, 'boolean whether use difference in sequence for reward relabeling.')
flags.DEFINE_string('label_mode', 'last', 'mode for relabeling reward with tranformer.')

# Learner hyperparameters; main() expands this config into keyword arguments
# for Learner. lock_config=False leaves the loaded config unlocked.
config_flags.DEFINE_config_file(
    'config',
    'default.py',
    'File path to the training hyperparameter configuration.',
    lock_config=False)


def normalize(dataset, env_name, max_episode_steps=1000):
    """Rescale ``dataset.rewards`` in place by the spread of trajectory returns.

    The dataset is split into trajectories and the return (sum of rewards) of
    each one is computed. Every reward is then divided by
    ``max_return - min_return`` and multiplied by ``max_episode_steps``. When
    ``env_name`` contains ``'antmaze'``, each reward is first shifted by
    ``min_return / len(trajectory)`` of the trajectory it belongs to.

    Args:
        dataset: Object exposing ``observations``, ``actions``, ``rewards``,
            ``masks``, ``dones_float``, ``next_observations`` and ``size``.
            Its ``rewards`` attribute is replaced with a new array.
        env_name: Environment name; only checked for the substring 'antmaze'.
        max_episode_steps: Factor the range-normalized rewards are scaled by.

    Raises:
        ValueError: If the dataset yields no trajectories, or if all
            trajectories have the same return (the normalization range would
            be zero and every reward would become inf/NaN).
    """
    trajs = split_into_trajectories(dataset.observations, dataset.actions,
                                    dataset.rewards, dataset.masks,
                                    dataset.dones_float,
                                    dataset.next_observations)
    if len(trajs) == 0:
        raise ValueError(
            'Cannot normalize rewards: the dataset contains no trajectories.')

    # Length of the owning trajectory for every transition. Indexing this by
    # dataset position assumes split_into_trajectories keeps transitions in
    # dataset order.
    traj_len_per_step = []
    for traj in tqdm(trajs, total=len(trajs), desc="chunk trajectories"):
        traj_len_per_step.extend([len(traj)] * len(traj))

    # The reward is the third field of each transition tuple. Returns are
    # computed once per trajectory instead of sorting the trajectories.
    traj_returns = [sum(step[2] for step in traj) for traj in trajs]
    min_return, max_return = min(traj_returns), max(traj_returns)

    return_range = max_return - min_return
    if return_range == 0:
        raise ValueError(
            'Cannot normalize rewards: all trajectories have the same return '
            f'({max_return}), so the normalization range is zero.')

    is_antmaze = 'antmaze' in env_name

    # Scalar-by-scalar arithmetic is kept deliberately so the dtype promotion
    # of the resulting array stays exactly as it was.
    normalized_rewards = []
    for i in range(dataset.size):
        reward = dataset.rewards[i]
        if is_antmaze:
            reward = reward - min_return / traj_len_per_step[i]
        reward = reward / return_range
        reward = reward * max_episode_steps
        normalized_rewards.append(reward)

    dataset.rewards = np.array(normalized_rewards)


def make_env_and_dataset(env_name: str,
                         seed: int) -> Tuple['gym.Env', 'D4RLDataset']:
    """Create the evaluation environment and its offline dataset.

    For names that do not contain ``'metaworld'`` the environment is built with
    ``gym.make``, wrapped with the project's ``EpisodeMonitor`` and
    ``SinglePrecision`` wrappers, seeded with ``seed`` and paired with
    ``D4RLDataset(env)``.

    For names containing ``'metaworld'`` the task name is the second
    ``'_'``-separated token of ``env_name``. The environment comes from
    ``metaworld.MT1`` wrapped in a 500-step ``TimeLimit``, and the dataset is
    loaded from ``./data/<task>/data_randgoal_08_50_08_batch.npy``. The MT1
    benchmark is built with a fixed seed of 1337; ``seed`` is not used on this
    path.

    Reward post-processing, in order:
      1. If ``FLAGS.use_reward_model`` is set, rewards are relabeled with the
         learned reward model (``reward_from_preference`` when
         ``FLAGS.model_type == "MR"``, otherwise
         ``reward_from_preference_transformer``).
      2. Rewards are normalized when a reward model is used or the environment
         is a metaworld one.
      3. For names containing ``'antmaze'``, 1.0 is subtracted from every
         reward.

    Args:
        env_name: Environment name. It drives every environment-specific
            decision in this function, including which reward-model
            checkpoint is loaded.
        seed: Seed for the gym environment and its spaces (non-metaworld only).

    Returns:
        A ``(env, dataset)`` tuple.
    """
    is_metaworld = 'metaworld' in env_name

    if not is_metaworld:
        env = gym.make(env_name)

        env = w.EpisodeMonitor(env)
        env = w.SinglePrecision(env)

        env.seed(seed)
        env.action_space.seed(seed)
        env.observation_space.seed(seed)

        dataset = D4RLDataset(env)
    else:
        # Imported lazily so metaworld is only required for metaworld runs.
        import metaworld
        from gym import wrappers

        dataset_name = env_name.split('_')[1]
        # Construct the benchmark, sampling tasks (fixed benchmark seed).
        ml1 = metaworld.MT1(dataset_name, seed=1337)

        # Create an environment and bind it to the first training task.
        env = ml1.train_classes[dataset_name]()
        env = wrappers.TimeLimit(env, 500)
        env.train_tasks = ml1.train_tasks
        env.set_task(ml1.train_tasks[0])
        # NOTE(review): this assigns the attribute on the TimeLimit wrapper
        # object; confirm that the wrapped metaworld env actually observes it.
        env._freeze_rand_vec = False

        # SECURITY: allow_pickle=True unpickles the file's contents; only load
        # dataset files from a trusted source.
        dataset_tmp = np.load(
            f'./data/{dataset_name}/data_randgoal_08_50_08_batch.npy',
            allow_pickle=True).tolist()

        dataset = D4RLDataset(env, input_dataset=dataset_tmp)

    if FLAGS.use_reward_model:
        reward_model = initialize_model(env_name)
        print('\n', 'model type: ', FLAGS.model_type, '\n')
        if FLAGS.model_type == "MR":
            dataset = reward_from_preference(
                env_name, dataset, reward_model, batch_size=FLAGS.batch_size)
        else:
            dataset = reward_from_preference_transformer(
                env_name,
                dataset,
                reward_model,
                batch_size=FLAGS.batch_size,
                seq_len=FLAGS.seq_len,
                use_diff=FLAGS.use_diff,
                label_mode=FLAGS.label_mode
            )
        # Drop the reference so the model can be freed before training starts.
        del reward_model

    # Learned rewards are always normalized; raw rewards only for metaworld.
    if FLAGS.use_reward_model or is_metaworld:
        if is_metaworld:
            # `env` is the TimeLimit wrapper itself on this path.
            max_episode_steps = env._max_episode_steps
        else:
            # Reach through SinglePrecision and EpisodeMonitor to the env
            # returned by gym.make, which carries the episode step limit.
            max_episode_steps = env.env.env._max_episode_steps
        normalize(dataset, env_name, max_episode_steps=max_episode_steps)

    if 'antmaze' in env_name:
        dataset.rewards -= 1.0

    return env, dataset


def initialize_model(env_name):
    """Load the pickled reward model saved for ``env_name``.

    The checkpoint is read from ``./saved_model/best_model_<env_name>.pkl``
    (relative to the current working directory) and its ``'reward_model'``
    entry is returned.

    NOTE: the file is deserialized with ``pickle``, which can execute arbitrary
    code while loading -- only point this at checkpoints you trust.

    Args:
        env_name: Environment name used to build the checkpoint file name.

    Returns:
        The object stored under the ``'reward_model'`` key of the checkpoint.
    """
    checkpoint_file = f'./saved_model/best_model_{env_name}.pkl'
    with open(checkpoint_file, 'rb') as handle:
        checkpoint = pickle.load(handle)
    print('reward model loaded...')
    return checkpoint['reward_model']


def _discount_mask_conditioner(dataset, env_name):
    """Build the per-transition multipliers that main() applies to the masks.

    The dataset is cut into fixed-length chunks (1000 chunks of 500 steps for
    metaworld; otherwise as many whole 1000-step chunks as fit), the reward sum
    of each chunk is computed, and the top chunks by that sum (700 for
    metaworld, 70% otherwise) get a multiplier of 1. Every other transition
    gets ``0.5 / 0.99``.

    Args:
        dataset: Object exposing ``rewards`` and ``masks`` arrays.
        env_name: Environment name; only checked for the substring 'metaworld'.

    Returns:
        An array shaped like ``dataset.masks`` holding the multipliers.
    """
    mask_conditioner = np.ones_like(dataset.masks) * 0.5 / 0.99
    print('mask shape: ', mask_conditioner.shape)

    is_metaworld = 'metaworld' in env_name
    if is_metaworld:
        traj_len, num_trajs, num_top = 500, 1000, 700
    else:
        traj_len = 1000
        num_trajs = int(dataset.rewards.shape[0] / traj_len)
        num_top = int(num_trajs * 0.7)

    returns = [
        np.sum(dataset.rewards[i * traj_len:(i + 1) * traj_len])
        for i in range(num_trajs)
    ]
    if is_metaworld:
        print('traj return: ', np.array(returns).shape)

    # Guard num_top == 0: a `[-0:]` slice would select *every* chunk, and
    # argpartition raises on an empty list of returns.
    if num_top > 0:
        for i in np.argpartition(returns, -num_top)[-num_top:]:
            mask_conditioner[i * traj_len:(i + 1) * traj_len] = 1

    return mask_conditioner


def main(_):
    """Train a Learner on the offline dataset, logging and evaluating as it goes.

    Steps:
      1. Open a tensorboardX ``SummaryWriter`` under ``test/<run name>``.
      2. Build the environment and dataset, then rescale ``dataset.masks`` so
         transitions outside the top-return chunks are down-weighted.
      3. Run ``FLAGS.max_steps`` updates, writing training statistics every
         ``FLAGS.log_interval`` steps and evaluation statistics every
         ``FLAGS.eval_interval`` steps.

    The writer is closed even if training raises, so buffered events are
    flushed to disk.
    """
    # Hard-coded switch for the mask conditioning; also part of the run name.
    discount = True
    # NOTE: logs go under "test/", not FLAGS.save_dir.
    save_dir = (f"test/{FLAGS.env_name}_reward_model_{FLAGS.use_reward_model}"
                f"_discount_{discount}_seed_{FLAGS.seed}")

    summary_writer = SummaryWriter(save_dir, write_to_disk=True)
    try:
        os.makedirs(FLAGS.save_dir, exist_ok=True)

        env, dataset = make_env_and_dataset(FLAGS.env_name, FLAGS.seed)

        if discount:
            dataset.masks = dataset.masks * _discount_mask_conditioner(
                dataset, FLAGS.env_name)

        kwargs = dict(FLAGS.config)
        agent = Learner(FLAGS.seed,
                        env.observation_space.sample()[np.newaxis],
                        env.action_space.sample()[np.newaxis],
                        max_steps=FLAGS.max_steps,
                        **kwargs)

        # Metaworld runs use the project's `evaluate`; everything else uses
        # `original_evaluate`.
        eval_fn = evaluate if 'metaworld' in FLAGS.env_name else original_evaluate

        # Collected for inspection; not written to disk by this function.
        eval_returns = []
        for step in tqdm(range(1, FLAGS.max_steps + 1), smoothing=0.1,
                         disable=not FLAGS.tqdm):
            batch = dataset.sample(FLAGS.batch_size)
            update_info = agent.update(batch)

            if step % FLAGS.log_interval == 0:
                for k, v in update_info.items():
                    # Scalars go to scalar charts, anything else to histograms.
                    if v.ndim == 0:
                        summary_writer.add_scalar(f'training/{k}', v, step)
                    else:
                        summary_writer.add_histogram(f'training/{k}', v, step)
                summary_writer.flush()

            if step % FLAGS.eval_interval == 0:
                eval_stats = eval_fn(agent, env, FLAGS.eval_episodes)

                for k, v in eval_stats.items():
                    summary_writer.add_scalar(f'evaluation/average_{k}s', v, step)
                summary_writer.flush()

                eval_returns.append((step, eval_stats['return']))
    finally:
        summary_writer.close()

# Script entry point: absl parses the command-line flags and then calls main().
if __name__ == '__main__':
    # Set before app.run() so the variable is already in the environment when
    # main() builds the learner.
    # NOTE(review): presumably disables XLA/JAX GPU memory preallocation --
    # confirm against the JAX GPU memory allocation docs.
    os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
    app.run(main)
