"""Run the training."""

from __future__ import annotations
from ensurepip import bootstrap
from enum import auto

import json
import os
from turtle import window_height
from typing import Optional, Union

import configargparse
from d4rl import offline_env
from gcsl import envs
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import wandb
import mujoco_py

from rvs import analyze_d4rl, dataset, policies, step, util, visualize, preemption

# File written into the wandb run directory by ``log_args``.
args_filename: str = "args.json"
# Sub-directory of the wandb run directory holding model checkpoints.
checkpoint_dir: str = "checkpoints"
# Sub-directory of the wandb run directory for generated random rollouts.
rollout_dir: str = "env_steps_data"
# Weights & Biases project the run is logged under.
wandb_project: str = "rvs"


def log_args(
    args: configargparse.Namespace,
    wandb_logger: pl.loggers.wandb.WandbLogger,
) -> None:
    """Log arguments to W&B and save them as JSON in the wandb run directory.

    The arguments are first passed to ``wandb_logger.log_hyperparams``. The
    W&B entity, project, run id and run path are then attached to ``args``
    (the caller's namespace is mutated), and the whole namespace is written
    to ``<wandb_logger.experiment.dir>/<args_filename>``.

    Args:
        args: parsed command-line arguments; modified in place.
        wandb_logger: logger whose ``experiment`` supplies the run metadata
            and the output directory.
    """
    wandb_logger.log_hyperparams(args)

    experiment = wandb_logger.experiment
    args.wandb_entity = experiment.entity
    args.wandb_project = experiment.project
    args.wandb_run_id = experiment.id
    args.wandb_path = experiment.path

    # Resolve what to serialize before opening the file, so a serialization
    # error cannot leave a partially written (or doubled) JSON document.
    try:
        args_dict = vars(args)
    except TypeError:
        # Objects without a ``__dict__`` are serialized directly.
        args_dict = args

    args_file = os.path.join(experiment.dir, args_filename)
    with open(args_file, "w", encoding="utf-8") as f:
        json.dump(args_dict, f)


def _build_checkpoint_callbacks(
    checkpoint_dirpath: str,
    checkpoint_filename: str,
    monitor: str,
    checkpoint_every_n_epochs: Optional[int],
    checkpoint_every_n_steps: Optional[int],
    checkpoint_time_interval: Optional[str],
    slurm_checkpoint_dir: Optional[str],
    slurm_checkpoint_every_n_epochs: int,
) -> list:
    """Create the ``ModelCheckpoint`` callbacks used by ``run_training``.

    Returns, in order: a periodic callback that keeps every checkpoint, a
    callback keeping the best (by ``monitor``) and the last checkpoint, and,
    when ``slurm_checkpoint_dir`` is given, a callback that also writes
    checkpoints into that directory every ``slurm_checkpoint_every_n_epochs``.
    """
    # ``checkpoint_time_interval`` is a duration string parsed by pandas.
    train_time_interval = (
        pd.Timedelta(checkpoint_time_interval).to_pytimedelta()
        if checkpoint_time_interval is not None
        else None
    )
    callbacks = [
        pl.callbacks.ModelCheckpoint(
            dirpath=checkpoint_dirpath,
            filename=checkpoint_filename,
            save_last=False,
            save_top_k=-1,  # keep every periodic checkpoint
            every_n_epochs=checkpoint_every_n_epochs,
            every_n_train_steps=checkpoint_every_n_steps,
            train_time_interval=train_time_interval,
        ),
        pl.callbacks.ModelCheckpoint(
            dirpath=checkpoint_dirpath,
            monitor=monitor,
            filename=checkpoint_filename,
            save_last=True,  # save latest model
            save_top_k=1,  # save top model based on monitored loss
        ),
    ]
    if slurm_checkpoint_dir is not None:
        callbacks.append(
            pl.callbacks.ModelCheckpoint(
                dirpath=slurm_checkpoint_dir,
                filename=checkpoint_filename,
                save_last=True,
                save_top_k=-1,
                every_n_epochs=slurm_checkpoint_every_n_epochs,
            )
        )
    return callbacks


def run_training(
    profiler: str,
    env: Union[step.GCSLToGym, offline_env.OfflineEnv],
    env_name: str,
    dataset_preprocess: str,
    store_dataset_gpu: bool,
    percent_dataset: float,
    seed: int,
    wandb_logger: pl.loggers.wandb.WandbLogger,
    rollout_directory: Optional[str],
    unconditional_policy: bool,
    reward_conditioning: bool,
    discount_factor: float,
    bootstrap_iters: int,
    bootstrap_model: str,
    bootstrap_model_args: dict,
    bootstrap_threshold_D: float,
    bootstrap_noise: float,
    bootstrap_feature_extractor: str,
    cumulative_reward_to_go: bool,
    epochs: int,
    max_steps: int,
    train_time: str,
    hidden_size: int,
    depth: int,
    learning_rate: float,
    weight_decay: float,
    learning_rate_scheduler: str,
    auto_tune_lr: bool,
    dropout_p: float,
    obs_noise: float,
    checkpoint_every_n_epochs: int,
    checkpoint_every_n_steps: int,
    checkpoint_time_interval: str,
    batch_size: int,
    val_frac: float,
    use_gpu: bool,
    slurm_checkpoint_dir: Optional[str],
    pm: Optional[preemption.PreemptionManager] = None,
    slurm_checkpoint_every_n_epochs: int = 10,
) -> None:
    """Run the training with PyTorch Lightning and log to Weights & Biases.

    Builds an ``RvS`` policy, checkpoint callbacks, a ``pl.Trainer`` and the
    data module, optionally runs Lightning's learning-rate tuner, then calls
    ``trainer.fit``. If the preemption manager reports an existing
    ``last.ckpt``, training resumes from it.

    Most arguments are forwarded unchanged to ``policies.RvS``,
    ``pl.Trainer`` or ``dataset.create_data_module``; their meaning is given
    by the command-line help of this script.

    Args:
        val_frac: fraction of data held out for validation; when it is 0,
            validation is disabled and checkpoints monitor ``train_loss``
            instead of ``val_loss``.
        use_gpu: train on one GPU when True, otherwise on CPU.
        slurm_checkpoint_dir: if given, checkpoints are additionally written
            to this directory every ``slurm_checkpoint_every_n_epochs``.
        pm: preemption manager used to look up a resumable ``last.ckpt`` and
            handed (via ``for_obj("data_module")``) to the data module. When
            omitted, the module-level ``pm`` bound by the script entry point
            is reused; if there is none, a new
            ``PreemptionManager(slurm_checkpoint_dir)`` is created.
        slurm_checkpoint_every_n_epochs: period, in epochs, of the
            checkpoints written to ``slurm_checkpoint_dir``.
    """
    if pm is None:
        # The script entry point binds ``pm`` at module scope; reuse that
        # instance so its state is shared. When this module is imported
        # instead, no such global exists, so build a manager here.
        pm = globals().get("pm")
        if pm is None:
            pm = preemption.PreemptionManager(slurm_checkpoint_dir)

    policy = policies.RvS(
        env.observation_space,
        env.action_space,
        hidden_size=hidden_size,
        depth=depth,
        learning_rate=learning_rate,
        weight_decay=weight_decay,
        learning_rate_scheduler=learning_rate_scheduler,
        dropout_p=dropout_p,
        obs_noise=obs_noise,
        batch_size=batch_size,
        unconditional_policy=unconditional_policy,
        reward_conditioning=reward_conditioning,
        env_name=env_name,
    )
    wandb_logger.watch(policy, log="all")

    monitor = "val_loss" if val_frac > 0 else "train_loss"
    checkpoint_dirpath = os.path.join(wandb_logger.experiment.dir, checkpoint_dir)
    # Lightning fills in the ``{epoch}`` and ``{<monitor>}`` placeholders.
    checkpoint_filename = "gcsl-" + env_name + "-{epoch:03d}-{" + monitor + ":.4e}"

    callbacks = _build_checkpoint_callbacks(
        checkpoint_dirpath=checkpoint_dirpath,
        checkpoint_filename=checkpoint_filename,
        monitor=monitor,
        checkpoint_every_n_epochs=checkpoint_every_n_epochs,
        checkpoint_every_n_steps=checkpoint_every_n_steps,
        checkpoint_time_interval=checkpoint_time_interval,
        slurm_checkpoint_dir=slurm_checkpoint_dir,
        slurm_checkpoint_every_n_epochs=slurm_checkpoint_every_n_epochs,
    )

    resume_path = None
    if pm.exists("last.ckpt"):
        # NOTE(review): relies on the private ``_get_path`` of the preemption
        # manager; consider exposing a public accessor.
        resume_path = pm._get_path("last.ckpt")
        print("Resuming from slurm checkpoint")

    trainer = pl.Trainer(
        gpus=int(use_gpu),
        auto_lr_find=auto_tune_lr,
        max_epochs=epochs,
        max_steps=max_steps,
        max_time=train_time,
        logger=wandb_logger,
        callbacks=callbacks,
        track_grad_norm=2,  # logs the 2-norm of gradients
        limit_val_batches=1.0 if val_frac > 0 else 0,
        limit_test_batches=0,
        resume_from_checkpoint=resume_path,
        profiler=profiler,
    )

    data_module = dataset.create_data_module(
        env,
        env_name,
        dataset_preprocess,
        store_dataset_gpu,
        percent_dataset,
        rollout_directory,
        batch_size=batch_size,
        val_frac=val_frac,
        unconditional_policy=unconditional_policy,
        reward_conditioning=reward_conditioning,
        discount_factor=discount_factor,
        bootstrap_iters=bootstrap_iters,
        bootstrap_model=bootstrap_model,
        bootstrap_model_args=bootstrap_model_args,
        bootstrap_threshold_D=bootstrap_threshold_D,
        bootstrap_noise=bootstrap_noise,
        bootstrap_feature_extractor=bootstrap_feature_extractor,
        average_reward_to_go=not cumulative_reward_to_go,
        seed=seed,
        wandb_run=wandb_logger.experiment,
        pm=pm.for_obj("data_module"),
    )

    if auto_tune_lr:
        trainer.tune(policy, datamodule=data_module)

    trainer.fit(policy, data_module)


if __name__ == "__main__":
    # ------------------------------------------------------------------
    # Command-line interface. ``--configs`` is registered with
    # is_config_file=True, so configargparse reads its value as a
    # configuration file supplying values for the flags below.
    # ------------------------------------------------------------------
    parser = configargparse.ArgumentParser(
        description="Reinforcement Learning via Supervised Learning")

    # configuration
    parser.add_argument(
        "--configs", is_config_file=True, required=False, default=None,
        help="path(s) to configuration file(s)")
    parser.add_argument(
        "--profiler", type=str, default=None, help="type of profiler")
    # environment
    parser.add_argument(
        "--env_name", type=str, default="pointmass_rooms",
        choices=envs.env_names + step.gym_goal_envs + step.d4rl_env_names,
        help="an environment name")

    # reproducibility
    parser.add_argument(
        "--seed", type=int, default=None,
        help="sets the random seed; if this is not specified, it is chosen randomly")

    # conditioning: --unconditional_policy and --reward_conditioning are
    # mutually exclusive
    conditioning_group = parser.add_mutually_exclusive_group()
    conditioning_group.add_argument(
        "--unconditional_policy", action="store_true", default=False,
        help="run vanilla behavior cloning without conditioning on goals")
    conditioning_group.add_argument(
        "--reward_conditioning", action="store_true", default=False,
        help="condition behavior cloning on the reward-to-go")
    parser.add_argument(
        "--bootstrap_iters", type=int, default=0,
        help="number of bootstrap iterations to run. If zero, no bootstrapping is performed.")
    parser.add_argument(
        "--discount_factor", type=float, default=1.0,
        help="discount factor for reward-to-go")
    parser.add_argument(
        "--cumulative_reward_to_go", action="store_true", default=False,
        help=("if reward_conditioning, condition on cumulative (instead of average) "
              "reward-to-go"))

    # architecture and optimization
    parser.add_argument(
        "--learning_rate", type=float, required=True,
        help="learning rate for each gradient step")
    parser.add_argument(
        "--weight_decay", type=float, default=0.01,
        help="weight decay for each gradient step")
    parser.add_argument(
        "--auto_tune_lr", action="store_true", default=False,
        help="have PyTorch Lightning try to automatically find the best learning rate")
    parser.add_argument(
        "--learning_rate_scheduler", type=str, default="constant",
        choices=["cosine", "linear", "constant"],
        help="learning rate scheduler")
    parser.add_argument(
        "--hidden_size", type=int, required=True,
        help="size of hidden layers in policy network")
    parser.add_argument(
        "--depth", type=int, required=True,
        help="number of hidden layers in policy network")
    parser.add_argument(
        "--dropout_p", type=float, required=True, help="dropout probability")
    parser.add_argument(
        "--obs_noise", type=float, default=0.0,
        help="standard deviation of observation noise")
    parser.add_argument(
        "--batch_size", type=int, required=True,
        help="batch size for each gradient step")
    # training length: exactly one of --epochs / --max_steps / --train_time
    train_time_group = parser.add_mutually_exclusive_group(required=True)
    train_time_group.add_argument(
        "--epochs", type=int, default=None,
        help="the number of training epochs.")
    train_time_group.add_argument(
        "--max_steps", type=int, default=None,
        help=("the number of training gradient steps per bootstrap iteration. ignored "
              "if --train_time is set"))
    train_time_group.add_argument(
        "--train_time", type=str, default=None,
        help="how long to train, specified as a DD:HH:MM:SS str")

    # checkpoint frequency: exactly one of the three options below
    checkpoint_frequency_group = parser.add_mutually_exclusive_group(required=True)
    checkpoint_frequency_group.add_argument(
        "--checkpoint_every_n_epochs", type=int, default=None,
        help="the period of training epochs for saving checkpoints")
    checkpoint_frequency_group.add_argument(
        "--checkpoint_every_n_steps", type=int, default=None,
        help="the period of training gradient steps for saving checkpoints")
    checkpoint_frequency_group.add_argument(
        "--checkpoint_time_interval", type=str, default=None,
        help="how long between saving checkpoints, specified as a HH:MM:SS str")

    parser.add_argument(
        "--val_frac", type=float, required=True,
        help="fraction of data to use for validation")
    parser.add_argument(
        "--use_gpu", action="store_true", default=False,
        help="place networks and data on the GPU")
    parser.add_argument(
        "--which_gpu", type=int, default=0, help="which GPU to use")
    # GCSL
    parser.add_argument(
        "--rollout_directory",
        default=None,
        type=str,
        help="a directory containing the offline dataset to use for training",
    )
    parser.add_argument(
        "--total_steps",
        default=100000,
        type=int,
        help="if `rollout_directory` is not provided and the environment is from GCSL, "
        "generate an offline training dataset with this many environment steps",
    )
    parser.add_argument(
        "--max_episode_steps",
        default=50,
        type=int,
        help="the maximum number of steps in each episode",
    )
    # --discretize and --discretize_rollouts_only are mutually exclusive.
    discretization_group = parser.add_mutually_exclusive_group()
    discretization_group.add_argument(
        "--discretize",
        # String-to-bool: only a case-insensitive "true" yields True.
        type=lambda x: (str(x).lower() == "true"),
        default=False,
        help="if the environment is from GCSL, discretize the environment's action "
        "space",
    )
    discretization_group.add_argument(
        "--discretize_rollouts_only",
        action="store_true",
        default=False,
        help="if the environment is from GCSL, sample discretized random rollouts in a "
        "continuous action space",
    )
    # analysis
    parser.add_argument(
        "--run_tag",
        default=None,
        type=str,
        help="a tag that's logged to help find the run later",
    )
    parser.add_argument(
        "--visualize",
        action="store_true",
        default=False,
        help="visualize the hitting times of each bootstrap iteration's learned policy",
    )
    parser.add_argument(
        "--analyze_d4rl",
        action="store_true",
        default=False,
        help="analyze the learned policy in D4RL",
    )
    parser.add_argument(
        "--trajectory_samples",
        default=None,
        type=int,
        help="number of trajectory samples for --visualize and --analyze_d4rl flags",
    )
    parser.add_argument(
        "--val_checkpoint_only",
        action="store_true",
        default=False,
        help="pass --val_checkpoint_only to analyze_d4rl script",
    )
    parser.add_argument(
        "--d4rl_analysis",
        default="all",
        type=str,
        choices=[
            "input_interpolation",
            "weight_histograms",
            "kitchen_subtasks",
            "elite_goals",
            "length_goals",
            "reward_goals",
            "all",
        ],
        help="which analysis to run for --analyze_d4rl",
    )
    # Comma separated list or None; converted to a list of floats after parsing.
    parser.add_argument(
        "--reward_targets",
        default=None,
        type=str,
        help="comma separated list of reward targets or None to use default based on normalized scores"
    )

    # Preemption: the same directory seeds the PreemptionManager and is
    # passed to run_training as its slurm checkpoint directory.
    parser.add_argument(
        "--checkpoint_dir", type=str, default=None,
        help="Slurm directory for checkpoints")

    # Bootstrap (return predictor) arguments; gathered into
    # ``bootstrap_model_args`` after parsing.
    parser.add_argument(
        "--bts_hidden_size", type=int, default=256,
        help="size of hidden layers in the return predictor network")
    parser.add_argument(
        "--bts_n_layers", type=int, default=2,
        help="number of hidden layers in the return predictor network")
    parser.add_argument(
        "--bts_n_bins", type=int, default=301,
        help="number of bins to discretize returns")
    parser.add_argument(
        "--bts_min_v", type=float, default=0.0, help="minimum return value")
    parser.add_argument(
        "--bts_max_v", type=float, default=300.0, help="maximum return value")
    parser.add_argument(
        "--bts_dropout_p", type=float, default=0.0,
        help="dropout probability in the return predictor network")
    # String-to-bool flag: pass "--bts_batch_norm=False" (anything other than
    # a case-insensitive "true") to disable batch norm.
    parser.add_argument(
        "--bts_batch_norm", type=lambda x: (str(x).lower() == "true"), default=True,
        help="whether to use batch normalization in the return predictor network")
    parser.add_argument(
        "--bts_learning_rate", type=float, default=1e-3,
        help="learning rate for the return predictor network")
    parser.add_argument(
        "--bts_learning_rate_scheduler", type=str, default="constant",
        help="learning rate scheduler for the return predictor network")
    parser.add_argument(
        "--bts_batch_size", type=int, default=256,
        help="batch size for the return predictor network")
    parser.add_argument(
        "--bts_epochs", type=int, default=10,
        help="number of epochs to train the return predictor network")
    parser.add_argument(
        "--bts_threshold_D", type=float, default=0.01,
        help="distance threshold for stitching two trajectories together")
    parser.add_argument(
        "--bts_noise", type=float, default=0.0,
        help="std of noise injected into the bootstrapping process")
    parser.add_argument(
        "--bts_feature_extractor", type=str, default='identity',
        help="feature extractor to use for bootstrapping knn")
    parser.add_argument(
        "--bts_n_blocks", type=int, default=3,
        help="number of blocks in the return predictor ResNet")
    parser.add_argument(
        "--bts_num_q", type=int, default=5,
        help="number of quantiles to use in the return predictor")
    parser.add_argument(
        "--num_cpu", type=int, default=1, help="num_cpus to use for rollouts")
    # Choice between "knn" (the default) and the quantile model "qr".
    # argparse does not check defaults against ``choices``, so "knn" must be
    # listed explicitly or passing it on the CLI / in a config file fails.
    parser.add_argument(
        "--bts_model",
        default="knn",
        type=str,
        choices=[
            "knn",
            "qr",
        ],
        help="The type of model to use for bootstrapping",
    )
    def _parse_bool_flag(value):
        """Interpret a flag value as True only for a case-insensitive "true"."""
        return str(value).lower() == "true"

    parser.add_argument(
        "--bts_val_frac", type=float, default=0.2,
        help="fraction of data to use for validation")
    # Number of discrete actions during discretization
    parser.add_argument(
        "--discrete_clusters", type=int, default=256,
        help="number of discrete actions during discretization")
    parser.add_argument(
        "--deterministic", type=_parse_bool_flag, default=True,
        help="whether to use deterministic actions during rollouts")
    parser.add_argument(
        "--dataset_preprocess", type=str, choices=["none", "slice_trajs"],
        default="none", help="choice of what dataset preprocessing is done")
    parser.add_argument(
        "--bts_relabel_style", type=str, choices=["greedy", "singlepoint"],
        default="singlepoint", help="how to relabel labels")
    parser.add_argument(
        "--bts_only_sample_last", type=_parse_bool_flag, default=False,
        help="only sample the highest bootstrap iter from dataset")
    # NOTE(review): same help text as --bts_only_sample_last; what this flag
    # restricts for the policy is not visible in this file.
    parser.add_argument(
        "--bts_only_sample_last_policy", type=_parse_bool_flag, default=False,
        help="only sample the highest bootstrap iter from dataset")
    parser.add_argument(
        "--bts_value_fn", type=_parse_bool_flag, default=False,
        help="use value function or q function when relabeling returns?")
    parser.add_argument(
        "--bts_quantiles", type=int, default=5,
        help="average the last n quantiles for bootstrapping")
    parser.add_argument(
        "--reward_preprocessing", type=str,
        choices=['none', 'antmaze', 'conservative'], default='none',
        help="choice of reward preprocessing")
    parser.add_argument(
        "--bts_ensemble", type=str, choices=['mean', 'min'], default='mean',
        help="choice of ensembling style for qr")
    parser.add_argument(
        "--eval_return_quantile", type=int, default=None,
        help="which quantile to use for evaluation")
    parser.add_argument(
        "--store_dataset_gpu", type=_parse_bool_flag, default=False,
        help="store dataset on gpu")
    parser.add_argument(
        "--clean_wandb_after", action="store_true",
        help="delete the run's wandb directory after training")
    parser.add_argument(
        "--percent_dataset", type=float, default=1.0,
        help="percent of dataset to use")

    args = parser.parse_args()

    # --reward_targets arrives as a comma separated string; store floats.
    if args.reward_targets is not None:
        args.reward_targets = [float(x)
                               for x in args.reward_targets.split(",")]

    if args.unconditional_policy and args.env_name not in step.d4rl_env_names:
        raise NotImplementedError(
            "--unconditional_policy is only supported for D4RL environments"
        )

    # ------------------------------------------------------------------
    # Run setup: preemption manager, seeding, W&B logging, environments.
    # ------------------------------------------------------------------
    pm = preemption.PreemptionManager(args.checkpoint_dir)

    # Draw a seed when none was given so log_args below records it.
    if args.seed is None:
        args.seed = np.random.randint(2**31 - 1)
    util.set_seed(args.seed + 1)

    # The run id comes from the preemption manager (presumably persisted so a
    # restarted job reuses it — confirm); resume='allow' lets W&B continue it.
    wandb_id = pm.wandb_id()
    wandb_logger = pl.loggers.wandb.WandbLogger(
        project=wandb_project,
        entity='rvsv',
        id=wandb_id,
        resume='allow',
    )
    log_args(args, wandb_logger)
    device = util.configure_gpu(args.use_gpu, args.which_gpu)

    policy_env = step.create_env(
        args.env_name,
        args.max_episode_steps,
        args.discretize,
        discrete_clusters=args.discrete_clusters,
        seed=args.seed + 2,
    )

    is_d4rl_env = args.env_name in step.d4rl_env_names

    # Random rollouts are sampled from the policy env unless a separate,
    # discretized env was requested (non-D4RL environments only).
    rollout_env = policy_env
    if args.discretize_rollouts_only and not is_d4rl_env:
        rollout_env = step.create_env(
            args.env_name,
            args.max_episode_steps,
            True,
            discrete_clusters=args.discrete_clusters,
            seed=args.seed + 3,
        )

    # Offline data source: an explicit directory, None for D4RL environments
    # (the data module presumably loads the D4RL dataset itself), or freshly
    # generated random rollouts stored in the wandb run directory.
    if args.rollout_directory is not None:
        rollout_directory = args.rollout_directory
    elif is_d4rl_env:
        rollout_directory = None
    else:
        rollout_directory = os.path.join(wandb_logger.experiment.dir, rollout_dir)
        step.generate_random_rollouts(
            rollout_env,
            rollout_directory,
            args.total_steps,
            args.max_episode_steps,
            use_base_actions=args.discretize_rollouts_only,
        )

    # Settings for the bootstrapping model, forwarded through run_training.
    bootstrap_model_args = dict(
        hidden_size=args.bts_hidden_size,
        n_layers=args.bts_n_layers,
        n_bins=args.bts_n_bins,
        min_v=args.bts_min_v,
        max_v=args.bts_max_v,
        batch_norm=args.bts_batch_norm,
        dropout_p=args.bts_dropout_p,
        learning_rate=args.bts_learning_rate,
        learning_rate_scheduler=args.bts_learning_rate_scheduler,
        batch_size=args.bts_batch_size,
        epochs=args.bts_epochs,
        val_frac=args.bts_val_frac,
        n_blocks=args.bts_n_blocks,
        num_q=args.bts_num_q,
        reward_preprocessing=args.reward_preprocessing,
        only_sample_last=args.bts_only_sample_last,
        only_sample_last_policy=args.bts_only_sample_last_policy,
        relabel_style=args.bts_relabel_style,
        value_fn=args.bts_value_fn,
        quantiles=args.bts_quantiles,
        ensemble=args.bts_ensemble,
    )

    run_training(
        # environment, data and logging
        profiler=args.profiler,
        env=policy_env,
        env_name=args.env_name,
        rollout_directory=rollout_directory,
        dataset_preprocess=args.dataset_preprocess,
        store_dataset_gpu=args.store_dataset_gpu,
        percent_dataset=args.percent_dataset,
        val_frac=args.val_frac,
        seed=args.seed,
        use_gpu=args.use_gpu,
        wandb_logger=wandb_logger,
        # conditioning
        unconditional_policy=args.unconditional_policy,
        reward_conditioning=args.reward_conditioning,
        cumulative_reward_to_go=args.cumulative_reward_to_go,
        discount_factor=args.discount_factor,
        # bootstrapping
        bootstrap_iters=args.bootstrap_iters,
        bootstrap_model=args.bts_model,
        bootstrap_model_args=bootstrap_model_args,
        bootstrap_threshold_D=args.bts_threshold_D,
        bootstrap_noise=args.bts_noise,
        bootstrap_feature_extractor=args.bts_feature_extractor,
        # policy network and optimizer
        hidden_size=args.hidden_size,
        depth=args.depth,
        dropout_p=args.dropout_p,
        obs_noise=args.obs_noise,
        batch_size=args.batch_size,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        learning_rate_scheduler=args.learning_rate_scheduler,
        auto_tune_lr=args.auto_tune_lr,
        # training duration
        epochs=args.epochs,
        max_steps=args.max_steps,
        train_time=args.train_time,
        # checkpointing
        checkpoint_every_n_epochs=args.checkpoint_every_n_epochs,
        checkpoint_every_n_steps=args.checkpoint_every_n_steps,
        checkpoint_time_interval=args.checkpoint_time_interval,
        slurm_checkpoint_dir=args.checkpoint_dir,
    )

    # Optional post-training analysis; both tools read the wandb run directory.
    if args.visualize:
        visualize.visualize_performance(
            wandb_logger.experiment.dir,
            device,
            wandb_run=wandb_logger.experiment,
            trajectory_samples=2000
            if args.trajectory_samples is None
            else args.trajectory_samples,
        )
    if args.analyze_d4rl:
        analyze_d4rl.analyze_performance(
            wandb_logger.experiment.dir,
            device,
            wandb_run=wandb_logger.experiment,
            analysis=args.d4rl_analysis,
            trajectory_samples=200
            if args.trajectory_samples is None
            else args.trajectory_samples,
            last_checkpoints_too=not args.val_checkpoint_only,
        )

    if args.clean_wandb_after:
        import glob
        import shutil

        # Match on the id of the run that was actually created (read before
        # finish()), not the id requested from the preemption manager. Refuse
        # an empty id: "./wandb/**" would match, and delete, every local run.
        run_id = wandb_logger.experiment.id
        wandb_logger.experiment.finish()
        if run_id:
            print(f"Removing {run_id} from local files...")
            for run_path in glob.glob(f"./wandb/*{run_id}*"):
                shutil.rmtree(run_path, ignore_errors=True)
            print("Done!")
        else:
            print("Could not determine the wandb run id; skipping cleanup.")
