import os
from typing import Tuple

import gym
import numpy as np
import tqdm
from absl import app, flags
from ml_collections import config_flags
from tensorboardX import SummaryWriter

import wrappers
from dataset_utils import D4RLDataset, split_into_trajectories, MetaworldDataset, Dataset
from evaluation import evaluate
from learner import Learner
import time
from preprocessing.learned_reward import merge_dataset
from metaworld.envs import ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE
from preprocessing.metaworld_process import make_metaworld_env
import gzip
import pickle as pkl

# Directory containing the gzipped Meta-World pickles read by load_meta_world.
CDS_DIR = "/data/cds_data"
FLAGS = flags.FLAGS

flags.DEFINE_string('env_name', 'halfcheetah-expert-v2', 'Environment name.')
flags.DEFINE_string('source_name', 'halfcheetah-random-v2', 'Source Environment name.')
flags.DEFINE_string('comment', 'data_sharing', 'Comment for the run')

# save_dir holds both the TensorBoard logs (under tb/) and the per-seed
# evaluation-return text files written by main.
flags.DEFINE_string('save_dir', './logs/', 'Tensorboard logging dir.')
# data_share and var_coeff are forwarded to merge_dataset.
flags.DEFINE_string('data_share', 'none', 'share type.')
# Fractions of each dataset kept by split_data; values >= 1 keep everything.
flags.DEFINE_float('target_split', 1, 'amount of target data to use')
flags.DEFINE_float('source_split', 1, 'amount of source data to use')
# flags.DEFINE_integer('seed', 42, 'Random seed.')
# The default seed is the wall-clock time at import, so every run differs
# unless --seed is passed explicitly.
flags.DEFINE_integer('seed', int(time.time()), 'Random seed.')
flags.DEFINE_integer('eval_episodes', 10,
                     'Number of episodes used for evaluation.')
flags.DEFINE_integer('log_interval', 5000, 'Logging interval.')
flags.DEFINE_integer('eval_interval', 5000, 'Eval interval.')
flags.DEFINE_integer('batch_size', 256, 'Mini batch size.')
flags.DEFINE_integer('max_steps', int(1e6), 'Number of training steps.')
flags.DEFINE_boolean('tqdm', True, 'Use tqdm progress bar.')
flags.DEFINE_float('var_coeff', 3, 'PDS var coeff')
# flags.DEFINE_float('expectile', 7, 'iql expectile')
# flags.DEFINE_float('tau', 0.7, 'iql tau')
config_flags.DEFINE_config_file(
    'config',
    'default.py',
    'File path to the training hyperparameter configuration.',
    lock_config=False)


def normalize(dataset):
    """Rescale ``dataset.rewards`` in place by the spread of trajectory returns.

    Rewards are divided by (max trajectory return - min trajectory return) and
    then multiplied by 1000. Afterwards the best and worst trajectories differ
    by 1000 in total return.

    Args:
        dataset: Object exposing ``observations``, ``actions``, ``rewards``,
            ``masks``, ``dones_float`` and ``next_observations``. Its
            ``rewards`` array is modified in place.

    Raises:
        ValueError: if no trajectories are found, or if all trajectories have
            the same return (the scale factor would be zero).
    """
    trajs = split_into_trajectories(dataset.observations, dataset.actions,
                                    dataset.rewards, dataset.masks,
                                    dataset.dones_float,
                                    dataset.next_observations)
    if not trajs:
        raise ValueError("normalize: dataset contains no trajectories")

    # The reward is the third element of each transition tuple.
    returns = [sum(rew for _, _, rew, _, _, _ in traj) for traj in trajs]
    spread = max(returns) - min(returns)
    if spread == 0:
        # Dividing by zero would silently turn every reward into inf/nan.
        raise ValueError(
            "normalize: all trajectories have the same return; "
            "cannot rescale rewards")

    dataset.rewards /= spread
    dataset.rewards *= 1000.0


def split_data(dataset, split=1., rest=True):
    """Randomly subsample ``dataset`` in place, keeping a ``split`` fraction.

    Indices are drawn without replacement from NumPy's global RNG. The same
    indices are applied to every per-transition array, so rows stay aligned.
    A ``split`` of 1 or more leaves the dataset unchanged.

    Args:
        dataset: Object with ``size`` and the per-transition arrays
            ``rewards``, ``observations``, ``next_observations``, ``actions``,
            ``dones_float`` and ``masks``. Modified in place.
        split: Fraction of transitions to keep. Must be non-negative.
        rest: Currently unused; kept so existing call sites keep working.

    Raises:
        ValueError: if ``split`` is negative.
    """
    if split < 1:
        if split < 0:
            # Explicit check instead of ``assert``, which ``python -O`` strips.
            raise ValueError(f"split must be non-negative, got {split}")

        n_keep = int(dataset.size * split)
        sample_index = np.random.choice(np.arange(dataset.size), n_keep, replace=False)
        for field in ('rewards', 'observations', 'next_observations',
                      'actions', 'dones_float', 'masks'):
            setattr(dataset, field, getattr(dataset, field)[sample_index, ...])
        dataset.size = n_keep


def load_meta_world(env_ids: list, data_dir=None):
    """Load and merge the gzipped Meta-World pickles for ``env_ids``.

    Each file ``metaworld_task{env_id}.pkl`` under ``data_dir`` is expected to
    contain a dict of arrays. Values sharing a key are concatenated along the
    first axis, in ``env_ids`` order.

    Args:
        env_ids: Task ids whose files should be loaded.
        data_dir: Directory containing the pickles. Defaults to ``CDS_DIR``.

    Returns:
        Dict mapping every key seen in any file to its merged value. A key that
        appears in only one file keeps that file's value unchanged.
    """
    if data_dir is None:
        data_dir = CDS_DIR

    chunks = {}
    for env_id in env_ids:
        path = os.path.join(data_dir, f"metaworld_task{env_id}.pkl")
        # NOTE: unpickling can execute arbitrary code; only load trusted files.
        with gzip.open(path, 'rb') as f:
            data = pkl.load(f)
        for key, value in data.items():
            chunks.setdefault(key, []).append(value)

    # A single concatenate per key avoids re-copying a growing array per file.
    return {key: parts[0] if len(parts) == 1 else np.concatenate(parts)
            for key, parts in chunks.items()}


def make_env_dataset_metaworld(env_id: int, seed: int, writer: any):
    """Build the Meta-World env for ``env_id`` plus its offline dataset.

    Task ``env_id`` supplies the target dataset. The other tasks among 0-3
    supply the source dataset. Both are handed to ``merge_dataset`` together
    with ``--data_share`` and ``--var_coeff``. Presumably the merge happens in
    place, since only the target dataset is returned.

    NOTE(review): ``seed`` is accepted but never used, so unlike the D4RL path
    the environment is not seeded here. ``writer`` is annotated with the
    builtin ``any`` rather than ``typing.Any``.
    NOTE(review): ``merge_dataset`` is called here with ``env`` as the first
    argument, but make_env_and_dataset calls it without ``env``. Confirm which
    signature preprocessing.learned_reward.merge_dataset actually has.

    Returns:
        ``(env, dataset)``: the Meta-World env and the target dataset.
    """
    assert 0 <= env_id < 4
    other_ids = list(range(4))
    other_ids.remove(env_id)

    env = make_metaworld_env(env_id)
    print("env space shape", env.observation_space)

    target = MetaworldDataset(load_meta_world([env_id]))
    source = MetaworldDataset(load_meta_world(other_ids))
    merge_dataset(env, target, source, FLAGS.data_share, FLAGS.var_coeff, writer)
    return env, target


def make_env_and_dataset(env_name: str, target_split: float, source_env_name: str, source_split: float,
                         seed: int, writer: any) -> Tuple[gym.Env, Dataset]:
    """Create the evaluation env and the target dataset merged with source data.

    Names containing "metaworld" are delegated to make_env_dataset_metaworld,
    using the name's last character as the task id. That path ignores the
    split arguments and ``source_env_name``.

    For D4RL names, both datasets get the same reward preprocessing, chosen
    from ``env_name``:
    - antmaze rewards are shifted by -1;
    - halfcheetah/walker2d/hopper/hammer rewards are rescaled with normalize().
    Each dataset is then subsampled to its split fraction, and the source is
    merged into the target via merge_dataset.

    Args:
        env_name: Name of the target env/dataset.
        target_split: Fraction of the target dataset to keep.
        source_env_name: Name of the source dataset's env.
        source_split: Fraction of the source dataset to keep.
        seed: Seed for the env and its action/observation spaces.
        writer: Forwarded to merge_dataset (presumably a SummaryWriter).

    Returns:
        ``(env, dataset)``: the wrapped target env and the merged dataset.
    """
    if "metaworld" in env_name:
        # NOTE(review): only single-digit task ids are supported.
        env_id = int(env_name[-1])
        return make_env_dataset_metaworld(env_id, seed, writer)

    env = gym.make(env_name)
    env = wrappers.EpisodeMonitor(env)
    env = wrappers.SinglePrecision(env)

    env.seed(seed)
    env.action_space.seed(seed)
    env.observation_space.seed(seed)

    # The source env is only used to build its D4RLDataset.
    source_env = gym.make(source_env_name)
    source_dataset = D4RLDataset(source_env)
    dataset = D4RLDataset(env)

    # Dispatch on the ``env_name`` argument rather than FLAGS.env_name, so the
    # preprocessing always matches the env actually loaded above.
    if 'antmaze' in env_name:
        # See https://github.com/aviralkumar2907/CQL/blob/master/d4rl/examples/cql_antmaze_new.py#L22
        # but I found no difference between (x - 0.5) * 4 and x - 1.0
        dataset.rewards -= 1.0
        source_dataset.rewards -= 1.0
    elif any(name in env_name for name in ('halfcheetah', 'walker2d', 'hopper', 'hammer')):
        normalize(dataset)
        normalize(source_dataset)

    split_data(dataset, target_split, rest=False)
    split_data(source_dataset, source_split)
    # NOTE(review): the Meta-World path passes ``env`` as an extra first
    # argument to merge_dataset; confirm which call matches its signature.
    merge_dataset(dataset, source_dataset, FLAGS.data_share, FLAGS.var_coeff, writer)

    return env, dataset


def main(_):
    """Train a Learner on the merged offline dataset with TensorBoard logging.

    - Every ``log_interval`` steps, update diagnostics go under ``training/``.
    - Every ``eval_interval`` steps, the agent is evaluated and the stats go
      under ``evaluation/``.
    - After each evaluation, the (step, return) history is saved to
      ``{save_dir}/{seed}.txt``.

    The SummaryWriter is always closed on exit, including on error.
    """
    # NOTE(review): the default run name uses ``seed % 1000000`` while the
    # split run name uses the full seed; confirm this asymmetry is intended.
    if FLAGS.target_split == 1 and FLAGS.source_split == 1:
        run_name = f"{FLAGS.data_share}_{FLAGS.comment}_{FLAGS.source_name}_{str(FLAGS.seed % 1000000)}"
    else:
        run_name = f"{FLAGS.comment}_{FLAGS.data_share}_{FLAGS.source_name}_{str(FLAGS.target_split)}_{str(FLAGS.source_split)}_{str(FLAGS.seed)}"
    summary_writer = SummaryWriter(os.path.join(FLAGS.save_dir, 'tb', FLAGS.env_name, run_name),
                                   write_to_disk=True)
    try:
        os.makedirs(FLAGS.save_dir, exist_ok=True)

        env, dataset = make_env_and_dataset(FLAGS.env_name, FLAGS.target_split, FLAGS.source_name,
                                            FLAGS.source_split, FLAGS.seed, summary_writer)

        kwargs = dict(FLAGS.config)
        agent = Learner(FLAGS.seed,
                        env.observation_space.sample()[np.newaxis],
                        env.action_space.sample()[np.newaxis],
                        max_steps=FLAGS.max_steps,
                        **kwargs)

        eval_returns = []
        for i in tqdm.tqdm(range(1, FLAGS.max_steps + 1),
                           smoothing=0.1,
                           disable=not FLAGS.tqdm):
            batch = dataset.sample(FLAGS.batch_size)

            update_info = agent.update(batch)

            if i % FLAGS.log_interval == 0:
                for k, v in update_info.items():
                    # 0-d values are logged as scalars, anything else as histograms.
                    if v.ndim == 0:
                        summary_writer.add_scalar(f'training/{k}', v, i)
                    else:
                        summary_writer.add_histogram(f'training/{k}', v, i)
                summary_writer.flush()

            if i % FLAGS.eval_interval == 0:
                eval_stats = evaluate(agent, env, FLAGS.eval_episodes, i, run_name)

                for k, v in eval_stats.items():
                    summary_writer.add_scalar(f'evaluation/average_{k}s', v, i)
                summary_writer.flush()

                eval_returns.append((i, eval_stats['return']))
                # The full history is rewritten after every evaluation.
                np.savetxt(os.path.join(FLAGS.save_dir, f'{FLAGS.seed}.txt'),
                           eval_returns,
                           fmt=['%d', '%.1f'])
    finally:
        # Flush pending events and release the writer even if training raises.
        summary_writer.close()


# Script entry point: hand control to absl's app.run, which invokes main.
if __name__ == '__main__':
    app.run(main)
