import logging
import os

import hydra
from omegaconf import OmegaConf, dictconfig
import pytorch_lightning as pl
import torch

from offline_rl.data.labeled_merge_dataset import LabeledMergeDataset
from offline_rl.data.paired_trajectory_segment_dataset import PairedTrajectorySegmentDataset
from offline_rl.data.sample_batch_json_reader_dataset import SampleBatchJsonReaderDataset
from offline_rl.data.transition_shuffling_dataset import TransitionShufflingDataset
from offline_rl.scripts.rewards.learning.common import get_env, get_gym_model


def get_positive_unlabeled_data_loader(positive_filepath, unlabeled_filepath, config):
    """Build one data loader over a positive and an unlabeled sample file.

    Both files are read with `SampleBatchJsonReaderDataset` and combined through
    `LabeledMergeDataset`, with the positive dataset registered under 1 and the
    unlabeled dataset under 0.

    Args:
        positive_filepath: Path handed to the reader for the positive samples.
        unlabeled_filepath: Path handed to the reader for the unlabeled samples.
        config: Data config providing `debug_size`, `batch_size`, `num_workers`,
            `shuffle` and, optionally, `debug_size_mode` (falls back to "ordered").

    Returns:
        A `torch.utils.data.DataLoader` wrapping the merged dataset.
    """
    # Both readers share the same debug settings.
    reader_options = {
        "debug_size": config.debug_size,
        "debug_size_mode": config.get("debug_size_mode", "ordered"),
    }
    positives = SampleBatchJsonReaderDataset(positive_filepath, **reader_options)
    unlabeled = SampleBatchJsonReaderDataset(unlabeled_filepath, **reader_options)
    merged = LabeledMergeDataset([(1, positives), (0, unlabeled)])
    return torch.utils.data.DataLoader(
        merged,
        batch_size=config.batch_size,
        num_workers=config.num_workers,
        shuffle=config.shuffle,
    )


def get_positive_unlabeled_data_loaders(config):
    """Create the train and validation loaders for positive-unlabeled learning.

    Args:
        config: Data config providing the four `positive_*`/`unlabeled_*`
            train/val dataset file paths, plus everything
            `get_positive_unlabeled_data_loader` reads from it.

    Returns:
        A `(train_data_loader, val_data_loader)` tuple.
    """
    # Tuple elements are evaluated left to right, so the train loader is built first.
    return (
        get_positive_unlabeled_data_loader(
            config.positive_train_dataset_filepath,
            config.unlabeled_train_dataset_filepath,
            config,
        ),
        get_positive_unlabeled_data_loader(
            config.positive_val_dataset_filepath,
            config.unlabeled_val_dataset_filepath,
            config,
        ),
    )


def get_direct_regression_data_loader(filepath, config):
    """Build a data loader over a single sample-batch JSON file.

    Args:
        filepath: Path handed to `SampleBatchJsonReaderDataset`.
        config: Data config providing `debug_size`, `batch_size`, `num_workers`,
            `shuffle` and, optionally, `debug_size_mode` (falls back to "ordered").

    Returns:
        A `torch.utils.data.DataLoader` wrapping the dataset.
    """
    debug_size_mode = config.get("debug_size_mode", "ordered")
    samples = SampleBatchJsonReaderDataset(
        filepath,
        debug_size=config.debug_size,
        debug_size_mode=debug_size_mode,
    )
    loader_options = {
        "batch_size": config.batch_size,
        "num_workers": config.num_workers,
        "shuffle": config.shuffle,
    }
    return torch.utils.data.DataLoader(samples, **loader_options)


def get_direct_regression_data_loaders(config):
    """Create the train and validation loaders for direct regression.

    Args:
        config: Data config providing `train_dataset_filepath` and
            `val_dataset_filepath`, plus everything
            `get_direct_regression_data_loader` reads from it.

    Returns:
        A `(train_data_loader, val_data_loader)` tuple.
    """
    # Tuple elements are evaluated left to right, so the train loader is built first.
    return (
        get_direct_regression_data_loader(config.train_dataset_filepath, config),
        get_direct_regression_data_loader(config.val_dataset_filepath, config),
    )


def get_preference_based_data_loader(filepath, config):
    """Build a data loader over paired trajectory segments from one file.

    Args:
        filepath: Path handed to `PairedTrajectorySegmentDataset`.
        config: Data config providing `segment_length`, `max_num_pairs`,
            `debug_size`, `batch_size`, `num_workers` and `shuffle`.

    Returns:
        A `torch.utils.data.DataLoader` wrapping the paired-segment dataset.
    """
    segment_pairs = PairedTrajectorySegmentDataset(
        filepath,
        segment_length=config.segment_length,
        max_num_pairs=config.max_num_pairs,
        debug_size=config.debug_size,
    )
    loader_options = {
        "batch_size": config.batch_size,
        "num_workers": config.num_workers,
        "shuffle": config.shuffle,
    }
    return torch.utils.data.DataLoader(segment_pairs, **loader_options)


def get_preference_based_data_loaders(config):
    """Create the train and validation loaders for preference-based learning.

    Args:
        config: Data config providing `train_dataset_filepath` and
            `val_dataset_filepath`, plus everything
            `get_preference_based_data_loader` reads from it.

    Returns:
        A `(train_data_loader, val_data_loader)` tuple.
    """
    # Tuple elements are evaluated left to right, so the train loader is built first.
    return (
        get_preference_based_data_loader(config.train_dataset_filepath, config),
        get_preference_based_data_loader(config.val_dataset_filepath, config),
    )


def get_transition_shuffling_data_loader(filepath, config):
    """Build a data loader over a `TransitionShufflingDataset` read from one file.

    Args:
        filepath: Path handed to `TransitionShufflingDataset`.
        config: Data config providing `num_pairs`, `debug_size`, `batch_size`,
            `num_workers`, `shuffle` and, optionally, `debug_size_mode`
            (falls back to "ordered").

    Returns:
        A `torch.utils.data.DataLoader` wrapping the dataset.
    """
    dataset_options = {
        "num_pairs": config.num_pairs,
        "debug_size": config.debug_size,
        "debug_size_mode": config.get("debug_size_mode", "ordered"),
    }
    shuffled_transitions = TransitionShufflingDataset(filepath, **dataset_options)
    return torch.utils.data.DataLoader(
        shuffled_transitions,
        batch_size=config.batch_size,
        num_workers=config.num_workers,
        shuffle=config.shuffle,
    )


def get_transition_shuffling_data_loaders(config):
    """Create the train and validation loaders for transition shuffling.

    Args:
        config: Data config providing `train_dataset_filepath` and
            `val_dataset_filepath`, plus everything
            `get_transition_shuffling_data_loader` reads from it.

    Returns:
        A `(train_data_loader, val_data_loader)` tuple.
    """
    # Tuple elements are evaluated left to right, so the train loader is built first.
    return (
        get_transition_shuffling_data_loader(config.train_dataset_filepath, config),
        get_transition_shuffling_data_loader(config.val_dataset_filepath, config),
    )


def get_data_loaders(config):
    """Dispatch to the loader factory matching `config.dataset_type`.

    Args:
        config: Data config whose `dataset_type` is one of "positive_unlabeled",
            "direct_regression", "preference_based" or "transition_shuffling".

    Returns:
        The `(train_data_loader, val_data_loader)` tuple produced by the
        selected factory.

    Raises:
        ValueError: If `config.dataset_type` is not one of the supported values.
    """
    dataset_type = config.dataset_type
    if dataset_type == "positive_unlabeled":
        return get_positive_unlabeled_data_loaders(config)
    if dataset_type == "direct_regression":
        return get_direct_regression_data_loaders(config)
    if dataset_type == "preference_based":
        return get_preference_based_data_loaders(config)
    if dataset_type == "transition_shuffling":
        return get_transition_shuffling_data_loaders(config)
    raise ValueError(f"Invalid dataset type: {dataset_type}")


def train(train_loader, val_loader, model, config):
    """Fit `model` with a PyTorch Lightning trainer and return the best checkpoint path.

    Checkpoints are written to `config.output_dir`, keeping the 10 with the
    lowest `val_loss`. Training stops early once `val_loss` has not improved
    for 4 validation checks.

    Args:
        train_loader: Data loader used for training.
        val_loader: Data loader used for validation.
        model: Model handed to `pl.Trainer.fit`.
        config: Training config providing `trainer_args` (keyword arguments for
            `pl.Trainer`) and `output_dir` (used as the trainer's root directory
            and as the checkpoint directory).

    Returns:
        The `best_model_path` reported by the checkpoint callback.
    """
    logging.info("Start Training...")
    trainer_args = OmegaConf.to_container(config.trainer_args, resolve=True)
    trainer_args["default_root_dir"] = config.output_dir
    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        dirpath=config.output_dir,
        filename="{epoch}-{val_loss:.8f}",
        save_top_k=10,
        mode="min",
        monitor="val_loss",
        verbose=False,
    )
    early_stopping_callback = pl.callbacks.early_stopping.EarlyStopping(monitor="val_loss", patience=4)
    trainer = pl.Trainer(**trainer_args, callbacks=[checkpoint_callback, early_stopping_callback])
    # Pass the loaders positionally: the `train_dataloader` keyword was renamed to
    # `train_dataloaders` in newer PyTorch Lightning releases (and the old name was
    # later removed), while the positional order is the same in both.
    trainer.fit(model, train_loader, val_loader)
    logging.info("Model training complete.")
    return checkpoint_callback.best_model_path


@hydra.main(config_path="configs")
def main(config: dictconfig.DictConfig) -> None:
    """Entry point: build the data loaders, environment and model, then train.

    The config is also saved as `config.yaml` inside `config.training.output_dir`
    (created if missing) so the run can be inspected later.

    There is no default config, so one must be selected on the command line. For example:
    `python run_gym_reward_learning.py +positive_unlabeled/envs=gym_bouncing_balls_env <... other command line options>`
    Here `gym_bouncing_balls_env.yaml` is a config file located in the directory `positive_unlabeled/envs`.

    The value returned by `train` is passed back to the caller.
    """
    output_dir = config.training.output_dir
    os.makedirs(output_dir, exist_ok=True)
    # Persist the config next to the training outputs for later reference.
    OmegaConf.save(config, os.path.join(output_dir, "config.yaml"))

    train_loader, val_loader = get_data_loaders(config.data)
    env = get_env(config.data.env_name)
    model = get_gym_model(env.observation_space, env.action_space, config.model)
    return train(train_loader, val_loader, model, config.training)


if __name__ == "__main__":
    # `main` is wrapped by `@hydra.main`, which supplies the `config` argument, so it is
    # called without arguments here; silence pylint's resulting false positive.
    # pylint: disable=no-value-for-parameter
    main()
