"""
Train long range state prediction models from offline datasets (e.g. D4RL).

Dynamics model implemented:
- MLP based dynamics model (regular)
- Koopman dynamics model

"""

import os
import random

import numpy as np
import argparse
import wandb
from tqdm import tqdm
import time
import logging
from koopman.agents import PredictiveLearner, predictive_config
from koopman.data import ReplayBuffer
from koopman.evaluation import get_predictive_model_eval_stats, get_predictive_model_eval_stats_ramdom_seq_sampling
from koopman.utils.envs import make_env_and_dataset
import d4rl


class CheckTypesFilter(logging.Filter):
    """Log filter that drops records whose message mentions ``check_types``."""

    def filter(self, record):
        """Return False (drop) for ``check_types`` records, True (keep) otherwise."""
        if "check_types" in record.getMessage():
            return False
        return True


# Install the filter on the root logger.
logger = logging.getLogger()
logger.addFilter(CheckTypesFilter())


def _str_to_bool(value):
    """Parse a command-line boolean such as ``true``/``false`` or ``1``/``0``.

    ``type=bool`` cannot be used with argparse because ``bool('False')`` is
    True: every non-empty string would enable the flag.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if normalized in ('0', 'false', 'f', 'no', 'n', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(
        'Expected a boolean value, got {!r}.'.format(value))


def build_parser(init):
    """Parse the command-line options of the training script.

    Args:
        init: Unused; kept so existing callers that pass it keep working.

    Returns:
        The ``argparse.Namespace`` of recognised options. Unrecognised
        command-line arguments are ignored (``parse_known_args``).
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        '--env_name', type=str,
        default='hopper-expert-v2', help='Environment name.')
    parser.add_argument(
        '--dataset_name', type=str,
        default='d4rl', choices=['d4rl'], help='Dataset name.')
    parser.add_argument(
        '--save_dir', type=str,
        default='/tmp/koopman_dynamics_model', help='Save directory.')
    parser.add_argument(
        '--save_learner', type=int,
        default=0,
        help='Whether to save the learner')
    parser.add_argument(
        '--seed', type=int,
        default=42, help='Random seed.')
    parser.add_argument(
        '--encoder_lr', type=float,
        default=1e-3, help='Encoder learning rate.')
    parser.add_argument(
        '--dynamics_lr', type=float,
        default=1e-3, help='Dynamics learning rate.')
    parser.add_argument(
        '--eval_episodes', type=int,
        default=10,
        help='Number of episodes used for evaluation.')
    parser.add_argument(
        '--log_interval', type=int,
        default=1000, help='Logging interval.')
    parser.add_argument(
        '--feat_dim', type=int,
        default=256, help='Feature (hidden) dimension.')
    parser.add_argument(
        '--eval_interval', type=int,
        default=10000, help='Eval interval.')
    parser.add_argument(
        '--batch_size', type=int,
        default=256, help='Mini batch size.')
    parser.add_argument(
        '--max_steps', type=int,
        default=int(5e5), help='Number of training steps.')
    parser.add_argument(
        '--updates_per_step', type=int,
        default=1, help='updates per step.')
    parser.add_argument(
        '--train_frac', type=float,
        default=0.8,
        help='Fraction of the dataset to use for training.')
    parser.add_argument(
        '--compute_mc_returns', type=_str_to_bool,
        default=True,
        help='Whether to compute MC returns (accepts true/false, 1/0).')
    parser.add_argument(
        '--dynamics_model_type', type=str,
        default='regular',
        help='define dynamics model types 1) regular 2) dense_koopman '
             '3) diagonal_koopman 4) scan 5) lifted_diagonal_koopman '
             '6) original_diagonal_koopman')
    parser.add_argument(
        '--replay_buffer_size', type=int,
        default=int(5e6), help='Size of the buffer')
    parser.add_argument(
        '--train_seq_length', type=int,
        default=50, help='Length of state sequence.')
    # ``type=list`` would split the raw string into single characters
    # ("10" -> ['1', '0']); collect one or more integers instead.
    parser.add_argument(
        '--test_seq_lengths', type=int, nargs='+',
        default=[10, 50, 100, 200],
        help='Test predictions on these sequence lengths '
             '(space separated, e.g. --test_seq_lengths 10 50 100).')
    parser.add_argument(
        '--use_wandb', type=int,
        default=0, help='Use wandb for logging.')
    parser.add_argument(
        '--tqdm', type=int,
        default=1, help='Use tqdm progress bar.')
    parser.add_argument(
        '--wandb_project', type=str,
        default='jax-dynamics-model-koopman', help='Name of the wandb project')
    parser.add_argument(
        '--wandb_entity', type=str,
        default='anonymous_team', help='The entity or team of the wandb project')
    parser.add_argument(
        '--save_video', type=int,
        default=0, help='Save videos during evaluation.')

    # Arguments for the Koopman model
    parser.add_argument(
        '--use_state_prediction_loss', type=int,
        default=0, help='Use this to get exact Koopman dynamics'
    )
    parser.add_argument(
        '--pred_reward', type=int,
        default=1, help='whether to predict reward'
    )
    parser.add_argument(
        '--state_emb_dim', type=int,
        default=512, help='State embedding dimension.'
    )
    parser.add_argument(
        '--action_emb_dim', type=int,
        default=128, help='Action embedding dimension.'
    )
    parser.add_argument(
        '--discretize', type=int,
        default=1, help='Discretize the state space.')
    parser.add_argument(
        '--koopman_real_init_type', type=str,
        default='constant',
        help='Type of initialization for real part of Koopman matrix. '
             'Options: 1) constant 2) learnable')
    parser.add_argument(
        '--koopman_im_init_type', type=str,
        default='increasing_freq',
        help='Type of initialization for imaginary part of Koopman matrix. '
             'Options: 1) increasing_freq 2) random')
    parser.add_argument(
        '--koopman_real_init_value', type=float,
        default=-0.2, help='Value of initialization for real part of Koopman matrix')
    parser.add_argument(
        '--dropout_rate', type=float,
        default=None, help='Dropout for both parts of the Koopman matrix and L operator')

    # Unknown arguments are tolerated rather than treated as errors.
    args, _ = parser.parse_known_args()
    return args


def merge_with_args(ray_config):
    """Overlay an optional config mapping on top of the parsed CLI arguments.

    Entries of ``ray_config`` whose key matches an existing argument replace
    that argument's value; a ``hidden_dims`` entry is stored as ``feat_dim``.
    Other keys are ignored. With ``ray_config=None`` the parsed arguments are
    returned untouched.
    """
    args = build_parser(init=None)
    if ray_config is not None:
        for name, value in ray_config.items():
            # "hidden_dims" is an alias for the ``feat_dim`` argument.
            if name == "hidden_dims":
                args.feat_dim = value
            if name in vars(args):
                setattr(args, name, value)
    return args


def resolve_config_from_args(config, args):
    """Reconcile a config mapping with an argument namespace, in place.

    For every key in ``config``: if ``args`` already has an attribute of that
    name, its value overwrites the config entry; otherwise the config value is
    added to ``args`` as a new attribute.

    Returns:
        The same ``(config, args)`` objects after mutation.
    """
    namespace = vars(args)
    for name, fallback in config.items():
        if name not in namespace:
            # Argument not supplied: expose the config value on ``args``.
            setattr(args, name, fallback)
            continue
        # Argument supplied: it takes precedence over the config entry.
        config[name] = namespace[name]
    return config, args


def build_env_and_dataset(env_name, dataset_name, train_frac,
                          seed=42, video_save_folder=None):
    """Create the environment and split its dataset into train/validation parts.

    Returns:
        ``(env, train_dataset, valid_dataset)``; the sizes of both splits are
        printed as a side effect.
    """
    env, full_dataset = make_env_and_dataset(env_name, seed, dataset_name, video_save_folder)

    splits = full_dataset.get_train_validation_split(train_fraction=train_frac)
    train_dataset, valid_dataset = splits

    for label, split in (("Train dataset size: ", train_dataset),
                         ("Valid dataset size: ", valid_dataset)):
        print(label, len(split.observations))

    return env, train_dataset, valid_dataset


def set_seed(seed):
    """Seed Python's ``random`` module and NumPy's global RNG with ``seed``."""
    for seed_fn in (random.seed, np.random.seed):
        seed_fn(seed)


def build_learner_and_buffer(env, args):
    """Create the predictive learner and the train/test replay buffers.

    The default predictive config is reconciled with ``args``; its
    ``replay_buffer_size`` entry is removed and used as the capacity of both
    buffers, and the remaining entries are passed to ``PredictiveLearner`` as
    keyword arguments. ``wandb.init`` is also called here.

    Returns:
        ``(learner, train_replay_buffer, test_replay_buffer)``.
    """
    learner_kwargs, args = resolve_config_from_args(predictive_config(), args)
    capacity = learner_kwargs.pop('replay_buffer_size')

    def tiled_sample(space):
        # One sample from the space, repeated train_seq_length times on a new leading axis.
        return space.sample()[np.newaxis].repeat(args.train_seq_length, axis=0)

    example_observations = tiled_sample(env.observation_space)
    example_actions = tiled_sample(env.action_space)
    learner = PredictiveLearner(
        args.seed,
        example_observations,
        example_actions,
        **learner_kwargs
    )

    wandb.init(project=args.wandb_project, entity=args.wandb_entity, dir=args.save_dir)

    print('Size of the buffer: ', capacity)
    train_replay_buffer, test_replay_buffer = [
        ReplayBuffer(env.observation_space, env.action_space, capacity)
        for _ in range(2)
    ]

    return learner, train_replay_buffer, test_replay_buffer



def train_offline_predictive_model(learner, train_replay_buffer, test_replay_buffer, valid_dataset, env, args):
    """Train the predictive model on sequences sampled from the offline data.

    Each step performs ``args.updates_per_step`` learner updates on sequence
    batches drawn from ``train_replay_buffer``. Every ``args.log_interval``
    steps the stats of the most recent update are logged to wandb; every
    ``args.eval_interval`` steps the learner is evaluated on sequences sampled
    from ``test_replay_buffer`` and the results are logged as well.

    Args:
        learner: Object exposing ``update(batch)``.
        train_replay_buffer: Buffer exposing ``sample_seq(batch_size, seq_length)``.
        test_replay_buffer: Buffer handed to the evaluation routine.
        valid_dataset: Unused here; evaluation samples from ``test_replay_buffer``.
        env: Unused here.
        args: Parsed arguments (``max_steps``, ``updates_per_step``,
            ``batch_size``, ``train_seq_length``, ``log_interval``,
            ``eval_interval``, ``pred_reward``, ``test_seq_lengths``, ``tqdm``).

    Returns:
        The trained learner.
    """
    print('Training predictive model on the offline dataset ..')
    train_stats = None
    # Steps are 1-indexed, so the range must end at max_steps + 1: otherwise
    # only max_steps - 1 steps run and the log/eval that falls exactly on
    # step == max_steps is skipped.
    for step in tqdm(range(1, args.max_steps + 1), smoothing=0.1, disable=not args.tqdm):
        for _ in range(args.updates_per_step):
            batch = train_replay_buffer.sample_seq(args.batch_size, args.train_seq_length)
            # Only the stats of the last update in this step are kept for logging.
            train_stats = learner.update(batch)

        if step % args.log_interval == 0:
            wandb.log(
                {'train': train_stats},
                step
            )

        if step % args.eval_interval == 0:
            # Evaluate on randomly sampled sequences from the held-out buffer.
            eval_stats = get_predictive_model_eval_stats_ramdom_seq_sampling(
                learner, test_replay_buffer, args.pred_reward, args.test_seq_lengths)

            wandb.log(
                {'eval': eval_stats},
                step
            )

    return learner


def run_experiment(ray_config=None):
    """Train a predictive (dynamics) model on an offline dataset, end to end.

    Resolves the arguments (optionally overridden by ``ray_config``), sets the
    wandb mode, builds the environment, dataset split, learner and replay
    buffers, runs training, and saves the learner when ``args.save_learner``
    is set.
    """
    args = merge_with_args(ray_config)

    # WANDB_MODE decides whether the wandb calls made later actually log.
    if args.use_wandb:
        status_message, wandb_mode = 'Using wandb for logging.', "online"
    else:
        status_message, wandb_mode = 'Wandb disable for logging.', "disabled"
    print(status_message)
    os.environ["WANDB_MODE"] = wandb_mode

    print(args)

    os.makedirs(args.save_dir, exist_ok=True)

    # Evaluation videos go under <save_dir>/video/eval when enabled.
    video_save_folder = (
        os.path.join(args.save_dir, 'video', 'eval') if args.save_video else None
    )

    # Environment plus train/validation split of the dataset.
    env, train_dataset, valid_dataset = build_env_and_dataset(
        args.env_name,
        args.dataset_name,
        train_frac=args.train_frac,
        seed=args.seed,
        video_save_folder=video_save_folder,
    )

    set_seed(args.seed)

    # Learner and buffers; each buffer is filled from its dataset split.
    learner, train_replay_buffer, test_replay_buffer = build_learner_and_buffer(env, args)
    train_replay_buffer.initialize_with_dataset(train_dataset)
    test_replay_buffer.initialize_with_dataset(valid_dataset)

    learner = train_offline_predictive_model(
        learner, train_replay_buffer, test_replay_buffer, valid_dataset, env, args
    )

    # Optionally persist the final learner.
    if args.save_learner:
        learner.save(os.path.join(args.save_dir, 'final_learner'))


# Script entry point: run one experiment configured from the command line.
if __name__ == "__main__":
    run_experiment()