import os
import pickle
from collections import defaultdict

import numpy as np

import transformers

import gym
import wrappers as w

import absl.app
import absl.flags
from flax.training.early_stopping import EarlyStopping
from flaxmodels.flaxmodels.lstm.lstm import LSTMRewardModel
from flaxmodels.flaxmodels.gpt2.trajectory_gpt2 import TransRewardModel

# import robosuite as suite
# from robosuite.wrappers import GymWrapper
# import robomimic.utils.env_utils as EnvUtils

from .sampler import TrajSampler
from .jax_utils import batch_to_jax
import JaxPref.reward_transform as r_tf
from .model import FullyConnectedQFunction
from viskit.logging import logger, setup_logger
from .MR import MR
from .replay_buffer import get_d4rl_dataset, index_batch
# import d4rl
from .NMR import NMR
from .PrefTransformer import PrefTransformer
from .utils import Timer, define_flags_with_default, set_random_seed, get_user_flags, prefix_metrics, WandBLogger, save_pickle

from ml_collections import config_flags

# JAX / XLA accelerator-memory settings.
# '.50' asks the XLA client to preallocate half of the device memory instead of
# its built-in default fraction.
# NOTE(review): this runs after the flax import above; presumably the value is
# still honoured because the backend reads it on first use - verify.
# Alternative kept for reference (disables preallocation altogether):
# os.environ['XLA_PYTHON_CLIENT_PREALLOCATE']='false'
os.environ.update({'XLA_PYTHON_CLIENT_MEM_FRACTION': '.50'})

# Command-line flags of this script together with their default values.
# `main()` reads them back through `absl.flags.FLAGS` and snapshots them with
# `get_user_flags(FLAGS, FLAGS_DEF)`. Comments of the form "not read by main()"
# refer to the `main()` defined in this file only.
FLAGS_DEF = define_flags_with_default(
    # --- run identity / model selection -------------------------------------
    # NOTE(review): main() only loads environments whose name contains
    # 'metaworld' or 'ant'; this default matches neither.
    env='halfcheetah-medium-v2',
    # NOTE(review): main() only has branches for 'MR', 'NMR' and
    # 'PrefTransformer'; the default 'MLP' matches none of them.
    model_type='MLP',
    max_traj_length=1000,  # not read by main()
    seed=42,  # passed to setup_logger, set_random_seed and the 'ant' env seeding
    data_seed=42,  # seed applied right before preference queries are collected
    save_model=True,  # pickle the reward model once main() has finished all rounds
    batch_size=64,  # mini-batch size of the reward-model training loop
    early_stop=False,  # only referenced by the commented-out eval phase in main()
    min_delta=1e-3,  # forwarded to flax EarlyStopping
    patience=10,  # forwarded to flax EarlyStopping

    # --- reward post-processing / action clipping ---------------------------
    reward_scale=1.0,  # not read by main()
    reward_bias=0.0,  # not read by main()
    clip_action=0.999,  # dataset actions are clipped to [-clip_action, clip_action]

    # --- network options ----------------------------------------------------
    reward_arch='256-256',  # not read by main()
    orthogonal_init=False,  # not read by main()
    activations='relu',  # forwarded to TransRewardModel(activation=...)
    activation_final='none',  # forwarded to TransRewardModel(activation_final=...)
    training=True,  # not read by main()

    # --- training schedule --------------------------------------------------
    n_epochs=2000,  # epochs per query round; main() also runs a logging-only epoch 0
    eval_period=5,  # only referenced by the commented-out eval phase in main()

    # --- preference-query collection ----------------------------------------
    data_dir='./human_label',  # query files live under <data_dir>/<env name>
    num_query=1000,  # passed to get_queries_from_multi; also bounds the number of rounds
    query_len=25,  # passed to get_queries_from_multi; also seeds the epoch-0 loss value
    skip_flag=0,  # not read by main()
    balance=False,  # passed to get_queries_from_multi
    topk=10,  # not read by main()
    window=2,  # not read by main()
    use_human_label=False,  # not read by main()
    feedback_random=False,  # not read by main()
    feedback_uniform=False,  # not read by main()
    enable_bootstrap=False,  # not read by main()

    # Free-text tag appended to the logging group and output directory;
    # main() refuses to run while it is empty.
    comment='',

    # --- robosuite options --------------------------------------------------
    robosuite=False,  # in main() this only changes the query directory name
    robosuite_dataset_type="ph",  # suffix of that directory name when robosuite=True
    robosuite_dataset_path='./data',  # not read by main()
    robosuite_max_episode_steps=500,  # not read by main()

    # --- nested default configs of the individual components ---------------
    reward=MR.get_default_config(),  # not read by main()
    transformer=PrefTransformer.get_default_config(),  # expanded into transformers.GPT2Config(**...)
    lstm=NMR.get_default_config(),  # not read by main()
    logging=WandBLogger.get_default_config(),  # passed to WandBLogger; group/experiment_id/output_dir are overwritten
)

# Register a `--config` flag that points at an ml_collections configuration
# file (default: 'default.py'). `lock_config=False` is passed so the loaded
# config is presumably left unlocked/mutable - confirm against ml_collections.
# NOTE(review): main() in this file never reads FLAGS.config.
config_flags.DEFINE_config_file(
    'config',
    'default.py',
    'File path to the training hyperparameter configuration.',
    lock_config=False)

# Metric key that main() treats as the training loss for each known reward-model
# type (the keys carry the 'reward' prefix added by `prefix_metrics`).
_TRAIN_LOSS_KEYS = {
    "MR": "reward/rf_loss",
    "NMR": "reward/lstm_loss",
    "PrefTransformer": "reward/trans_loss",
}

# Episode horizon used for the metaworld TimeLimit wrapper and trajectory sampler.
_METAWORLD_MAX_EPISODE_STEPS = 500


def _validate_model_type(model_type):
    """Fail fast on a ``model_type`` this script cannot train.

    Raises:
        ValueError: ``model_type`` is not one of the known reward-model names.
        NotImplementedError: the name is known, but this script contains no
            construction code for it.
    """
    if model_type not in _TRAIN_LOSS_KEYS:
        raise ValueError(
            f"Unsupported model_type {model_type!r}; "
            f"expected one of {sorted(_TRAIN_LOSS_KEYS)}.")
    if model_type != "PrefTransformer":
        raise NotImplementedError(
            f"model_type {model_type!r} is recognised, but this script only "
            "constructs 'PrefTransformer' reward models.")


def _load_env_and_dataset(FLAGS):
    """Create the gym environment and offline dataset selected by ``FLAGS.env``.

    Returns:
        A ``(gym_env, dataset, label_type)`` tuple.

    Raises:
        ValueError: ``FLAGS.env`` contains neither 'metaworld' nor 'ant', so no
            loader applies.
    """
    if 'metaworld' in FLAGS.env:
        print('metaworkd task')
        import metaworld
        from gym import wrappers

        # Env names are expected to look like '<prefix>_<task name>'.
        dataset_name = FLAGS.env.split('_')[1]
        ml1 = metaworld.MT1(dataset_name, seed=1337)  # Construct the benchmark, sampling tasks

        gym_env = ml1.train_classes[dataset_name]()  # Create an environment with task
        gym_env = wrappers.TimeLimit(gym_env, _METAWORLD_MAX_EPISODE_STEPS)
        gym_env.train_tasks = ml1.train_tasks
        gym_env.set_task(ml1.train_tasks[0])
        gym_env._freeze_rand_vec = False
        # NOTE(review): hard-coded, machine-specific dataset location.
        # SECURITY: allow_pickle=True unpickles the file; only load trusted data.
        dataset = np.load(
            '/mnt/yyq/data/' + dataset_name + '/data_randgoal_08_50_08_batch.npy', allow_pickle=True).tolist()
        print('dataset has loaded...')
        # Not used afterwards; kept because the constructor's side effects are
        # not visible from this file.
        eval_sampler = TrajSampler(gym_env, _METAWORLD_MAX_EPISODE_STEPS)
        return gym_env, dataset, 0

    if 'ant' in FLAGS.env:
        gym_env = gym.make(FLAGS.env)
        gym_env = w.EpisodeMonitor(gym_env)
        gym_env = w.SinglePrecision(gym_env)
        gym_env.seed(FLAGS.seed)
        gym_env.action_space.seed(FLAGS.seed)
        gym_env.observation_space.seed(FLAGS.seed)
        dataset = r_tf.qlearning_ant_dataset(gym_env)
        return gym_env, dataset, 0

    raise ValueError(
        f"Unsupported env {FLAGS.env!r}: the name must contain 'metaworld' or 'ant'.")


def _build_reward_model(FLAGS, observation_dim, action_dim, interval):
    """Construct a fresh reward model for one query round.

    Args:
        FLAGS: parsed absl flags.
        observation_dim: first dimension of the env observation space.
        action_dim: first dimension of the env action space.
        interval: number of optimisation steps per epoch; used to derive the
            warm-up and total step counts stored on the config.
    """
    if FLAGS.model_type == "PrefTransformer":
        total_epochs = FLAGS.n_epochs
        config = transformers.GPT2Config(
            **FLAGS.transformer
        )
        config.warmup_steps = int(total_epochs * 0.1 * interval)
        config.total_steps = total_epochs * interval

        trans = TransRewardModel(
            config=config,
            observation_dim=observation_dim,
            action_dim=action_dim,
            activation=FLAGS.activations,
            activation_final=FLAGS.activation_final)
        return PrefTransformer(config, trans)
    raise NotImplementedError(
        f"No reward-model construction implemented for model_type {FLAGS.model_type!r}.")


def main(_):
    """Iteratively collect preference queries and train a reward model on them.

    For each round ``query_index`` in ``1 .. num_query - 3`` the function
    regenerates the query files under ``<data_dir>/<env>``, builds a new reward
    model and trains it for ``n_epochs`` epochs (epoch 0 only logs). Models of
    rounds with ``query_index > 8`` are pickled as ``best_model_*``; the model of
    the last round is pickled as ``model_*`` when ``save_model`` is set.

    Raises:
        ValueError: empty ``comment`` flag, unknown ``model_type`` or an ``env``
            for which no loader exists.
        NotImplementedError: ``model_type`` is known but cannot be built here.
    """
    import shutil  # local import: not provided by the module's import block

    FLAGS = absl.flags.FLAGS

    variant = get_user_flags(FLAGS, FLAGS_DEF)

    # Validate user input before any logger / W&B run is created. An explicit
    # raise is used instead of `assert`, which is stripped under `python -O`.
    if not FLAGS.comment:
        raise ValueError("You must leave your comment for logging experiment.")
    _validate_model_type(FLAGS.model_type)

    group = f"{FLAGS.env}_{FLAGS.model_type}_{FLAGS.comment}"
    save_dir = (
        f"{FLAGS.logging.output_dir}/{FLAGS.env}/{FLAGS.model_type}/"
        f"{FLAGS.comment}/s{FLAGS.seed}")
    FLAGS.logging.group = group
    FLAGS.logging.experiment_id = f"{group}_s{FLAGS.seed}"

    model_save_dir = './saved_model'

    setup_logger(
        variant=variant,
        seed=FLAGS.seed,
        base_log_dir=save_dir,
        include_exp_prefix_sub_dir=False
    )

    FLAGS.logging.output_dir = save_dir
    wb_logger = WandBLogger(FLAGS.logging, variant=variant)

    set_random_seed(FLAGS.seed)

    gym_env, dataset, label_type = _load_env_and_dataset(FLAGS)

    dataset['actions'] = np.clip(dataset['actions'], -FLAGS.clip_action, FLAGS.clip_action)
    # use fixed seed for collecting segments.
    set_random_seed(FLAGS.data_seed)

    print("load saved indices.")
    if 'dense' in FLAGS.env:
        # Drop the second-to-last dash-separated component of the env name.
        name_parts = FLAGS.env.split("-")
        env = "-".join(name_parts[:-2] + [name_parts[-1]])
    elif FLAGS.robosuite:
        env = f"{FLAGS.env}_{FLAGS.robosuite_dataset_type}"
    else:
        env = FLAGS.env

    base_path = os.path.join(FLAGS.data_dir, env)
    train_loss = _TRAIN_LOSS_KEYS[FLAGS.model_type]

    reward_model = None
    pref_dataset = None
    epoch = None
    query_index = 0
    # Rounds are numbered 1 .. num_query - 3 (empty when num_query <= 3).
    for query_index in range(1, FLAGS.num_query - 2):
        # Start every round from a clean query directory.
        if os.path.exists(base_path):
            shutil.rmtree(base_path)
        # The previous round's model and dataset are handed back to the collector.
        pref_dataset = r_tf.get_queries_from_multi(
            gym_env, dataset, FLAGS.num_query, FLAGS.query_len, reward_model, pref_dataset, query_index,
            data_dir=base_path, label_type=label_type, balance=FLAGS.balance)

        # The directory is expected to hold exactly three files whose sorted
        # order is (indices_2, indices_1, labels). The loaded values are only
        # needed by the disabled evaluation phase further down.
        human_indices_2_file, human_indices_1_file, script_labels_file = sorted(os.listdir(base_path))
        with open(os.path.join(base_path, human_indices_1_file), "rb") as fp:   # Unpickling
            human_indices = pickle.load(fp)
        with open(os.path.join(base_path, human_indices_2_file), "rb") as fp:   # Unpickling
            human_indices_2 = pickle.load(fp)
        with open(os.path.join(base_path, script_labels_file), "rb") as fp:   # Unpickling
            human_labels = pickle.load(fp)
        true_eval = len(human_labels) > FLAGS.num_query
        # pref_eval_dataset = r_tf.load_queries_with_indices(
        #     gym_env, dataset, int(FLAGS.num_query * 0.1), FLAGS.query_len,
        #     label_type=label_type, saved_indices=[human_indices, human_indices_2], saved_labels=human_labels,
        #     balance=FLAGS.balance)

        set_random_seed(FLAGS.seed)
        observation_dim = gym_env.observation_space.shape[0]
        action_dim = gym_env.action_space.shape[0]

        data_size = pref_dataset["observations"].shape[0]
        print('data size: ', data_size)

        # Ceiling division: number of mini-batches per epoch. The previous
        # `int(data_size / batch_size) + 1` produced an extra, empty final batch
        # whenever data_size was an exact multiple of batch_size.
        interval = max(1, -(-data_size // FLAGS.batch_size))

        # eval_data_size = pref_eval_dataset["observations"].shape[0]
        # eval_interval = int(eval_data_size / FLAGS.batch_size) + 1

        # Only consumed by the disabled evaluation phase below.
        early_stop = EarlyStopping(min_delta=FLAGS.min_delta, patience=FLAGS.patience)

        # A fresh model is trained from scratch in every round.
        reward_model = _build_reward_model(FLAGS, observation_dim, action_dim, interval)

        criteria_key = None
        for epoch in range(FLAGS.n_epochs + 1):
            print('epoch: ', epoch, FLAGS.n_epochs)
            metrics = defaultdict(list)
            metrics['epoch'] = epoch
            if epoch:
                # train phase
                shuffled_idx = np.random.permutation(data_size)
                for i in range(interval):
                    start_pt = i * FLAGS.batch_size
                    end_pt = min(start_pt + FLAGS.batch_size, data_size)
                    with Timer() as train_timer:
                        # train
                        batch = batch_to_jax(index_batch(pref_dataset, shuffled_idx[start_pt:end_pt]))
                        for key, val in prefix_metrics(reward_model.train(batch), 'reward').items():
                            metrics[key].append(val)
                # Reports the timer of the last mini-batch only.
                metrics['train_time'] = train_timer()
            else:
                # for using early stopping with train loss.
                metrics[train_loss] = [float(FLAGS.query_len)]

            # # eval phase
            # if epoch % FLAGS.eval_period == 0:
            #     for j in range(eval_interval):
            #         eval_start_pt, eval_end_pt = j * FLAGS.batch_size, min((j + 1) * FLAGS.batch_size, pref_eval_dataset["observations"].shape[0])
            #         # batch_eval = batch_to_jax(index_batch(pref_eval_dataset, range(eval_start_pt, eval_end_pt)))
            #         batch_eval = batch_to_jax(index_batch(pref_eval_dataset, range(eval_start_pt, eval_end_pt)))
            #         for key, val in prefix_metrics(reward_model.evaluation(batch_eval), 'reward').items():
            #             metrics[key].append(val)
            #     if not criteria_key:
            #         if "antmaze" in FLAGS.env and not "dense" in FLAGS.env and not true_eval:
            #             # choose train loss as criteria.
            #             criteria_key = train_loss
            #         else:
            #             # choose eval loss as criteria.
            #             criteria_key = key
            #     criteria = np.mean(metrics[criteria_key])
            #     has_improved, early_stop = early_stop.update(criteria)
            #     if early_stop.should_stop and FLAGS.early_stop:
            #         for key, val in metrics.items():
            #             if isinstance(val, list):
            #                 metrics[key] = np.mean(val)
            #         logger.record_dict(metrics)
            #         logger.dump_tabular(with_prefix=False, with_timestamp=False)
            #         wb_logger.log(metrics)
            #         print('Met early stopping criteria, breaking...')
            #         break
            #     elif epoch > 0 and has_improved:
            #         metrics["best_epoch"] = epoch
            #         metrics[f"{key}_best"] = criteria
            #         save_data = {"reward_model": reward_model, "variant": variant, "epoch": epoch}
            #         # if query_index % 2 == 0:
            #         save_pickle(save_data, f"best_model_{FLAGS.env}_iter_{str(query_index)}.pkl", model_save_dir)
            #         print('save done...   epoch: ', epoch)

            # Collapse the per-batch lists into their epoch mean before logging.
            for key, val in metrics.items():
                if isinstance(val, list):
                    metrics[key] = np.mean(val)
            logger.record_dict(metrics)
            logger.dump_tabular(with_prefix=False, with_timestamp=False)
            wb_logger.log(metrics)

        # Per-round checkpoints are written from round 9 onwards, regardless of
        # the save_model flag.
        if query_index > 8:
            save_data = {"reward_model": reward_model, "variant": variant, "epoch": epoch}
            save_pickle(save_data, f"best_model_{FLAGS.env}_iter_{str(query_index)}.pkl", model_save_dir)
            print('save done...   epoch: ', epoch)

    if reward_model is None:
        # No round was executed (num_query <= 3), so there is nothing to save.
        print('no query round was run; nothing to save.')
        return

    if FLAGS.save_model:
        save_data = {'reward_model': reward_model, 'variant': variant, 'epoch': epoch}
        save_pickle(save_data, f'model_{FLAGS.env}_iter_{str(query_index)}.pkl', model_save_dir)
        print('save done...   epoch: ', epoch)

# Script entry point: absl parses the command-line flags and then calls main().
if __name__ == '__main__':
    absl.app.run(main)
