import flax.traverse_util
import wandb
import tqdm
import numpy as np
import tensorflow as tf

# Make no GPU visible to TensorFlow (empty device list for the "GPU" type).
# NOTE(review): presumably so TF, used here only for data loading, does not
# claim GPU memory that JAX needs -- confirm.
tf.config.set_visible_devices([], "GPU")

import jax
import time
import flax
import jax.numpy as jnp

from dlimp.dataset import DLataset
from multinav.deploy.load_data import setup_datasets
from utils import average_dict, average_dicts
from multinav.deploy.train.agent import Agent

import orbax.checkpoint as ocp
import orbax
from orbax.checkpoint import (
    CheckpointManager,
    CheckpointManagerOptions,
    PyTreeCheckpointer,
)
from plotly import graph_objects as go


# Devices reported by JAX for this process and their count. main() uses them to
# split each batch per device and to pin the pmapped functions.
device_list = jax.devices()
num_devices = len(device_list)

from absl import app, flags, logging as absl_logging
from ml_collections import config_dict, config_flags, ConfigDict

# absl's global flag container; the individual flags are registered under the
# `if __name__ == "__main__":` guard before app.run(main) is called.
FLAGS = flags.FLAGS


def main(_):
    """Train an agent across all JAX devices, with logging, validation and checkpoints.

    Reads ``FLAGS.model_config`` and ``FLAGS.data_config``, builds the training
    and validation datasets, constructs (and optionally restores) the agent,
    then runs ``model_config.train_steps`` pmapped update steps. Along the way:

    * every ``FLAGS.wandb_interval`` steps, training metrics and action
      histograms are logged to W&B;
    * every ``model_config.val_steps`` steps (if ``model_config.validate``),
      metrics are averaged over a fixed number of validation batches per
      validation dataset;
    * every ``FLAGS.checkpoint_interval`` steps (if ``FLAGS.checkpoint_save_dir``
      is set), the unreplicated agent state is saved;
    * every ``FLAGS.viz_interval`` steps (if ``model_config.validate``), a
      scatter plot of sampled / mode / dataset actions is logged per dataset.

    Args:
        _: Unused; present because ``app.run`` calls ``main`` with one argument.
    """
    tf.get_logger().setLevel("WARNING")
    absl_logging.set_verbosity("WARNING")

    model_config: ConfigDict = FLAGS.model_config
    data_config: ConfigDict = FLAGS.data_config

    # For non-"bc" agents the actor is only included in the update on every
    # N-th step; the critic is included on every step.
    actor_update_period = 3
    # Number of batches averaged per validation dataset on each validation pass.
    num_val_batches = 100

    def train_step(batch, agent, update_actor):
        """Run one agent update on a per-device batch (called under pmap).

        Agents whose name contains "bc" are updated without selecting
        networks; ``update_actor`` is ignored for them. Otherwise the critic
        is always updated and the actor only when ``update_actor`` is true.
        Returns whatever ``agent.update`` returns (unpacked by the caller as
        ``(new_agent, info)``).
        """
        if "bc" in model_config.agent_name:
            return agent.update(batch, pmap_axis="num_devices")
        networks_to_update = {"actor", "critic"} if update_actor else {"critic"}
        return agent.update(
            batch,
            pmap_axis="num_devices",
            networks_to_update=networks_to_update,
        )

    def sample_actions_step(batch, actor):
        """Sample actions for a per-device batch using the actor's own RNG state."""
        return actor.sample_actions(
            batch["observations"], batch["goals"], seed=actor.state.rng
        )

    def mode_actions_step(batch, actor):
        """Return the argmax (mode) actions for a per-device batch."""
        return actor.sample_actions(
            batch["observations"], batch["goals"], argmax=True
        )

    def batch_for_devices(dataset: DLataset, total_batch_size: int, num_devices: int):
        """Batch twice so elements have a leading (num_devices, per_device) layout.

        The inner batch holds ``total_batch_size // num_devices`` examples and
        the outer batch stacks ``num_devices`` of them; incomplete batches are
        dropped at both levels.
        """
        dataset = dataset.batch(
            total_batch_size // num_devices,
            drop_remainder=True,
            num_parallel_calls=None,
        )
        dataset = dataset.batch(
            num_devices, drop_remainder=True, num_parallel_calls=None
        )
        return dataset

    def action_scatter(named_actions):
        """Build a plotly figure with one marker trace per named action array.

        Each array is indexed with ``[..., 0]`` for x and ``[..., 1]`` for y, so
        the last axis must hold at least two action components.
        """
        figure = go.Figure()
        for trace_name, actions in named_actions.items():
            figure.add_trace(
                go.Scatter(
                    x=actions[..., 0].flatten(),
                    y=actions[..., 1].flatten(),
                    mode="markers",
                    name=trace_name,
                )
            )
        return figure

    # load data: entire dataset takes 6.2 seconds to load
    train_dataset, val_datasets = setup_datasets(
        FLAGS.data_mix,
        FLAGS.data_dir,
        discount=data_config.discount,
        skip_crash=data_config.skip_crash,
        truncate_goal=data_config.truncate_goal,
        validate=model_config.validate,
        reward_type=data_config.reward_type,
        negative_probability=data_config.negative_probability,
        action_type=data_config.action_type,
        image_size=model_config.image_size,
    )
    train_dataset = batch_for_devices(
        train_dataset, model_config.batch_size, num_devices
    )
    val_datasets = {
        k: batch_for_devices(v, model_config.batch_size, num_devices)
        for k, v in val_datasets.items()
    }

    print("Dataset set up")

    agent = Agent(model_config, FLAGS.seed)

    if FLAGS.checkpoint_load_dir is not None:
        agent.load_checkpoint(FLAGS.checkpoint_load_dir, FLAGS.checkpoint_load_step)

    agent.replicate()

    print("Agent set up!")

    # Build every pmapped function exactly once. Wrapping a fresh lambda in
    # jax.pmap inside the loop creates a new function object each time, so the
    # compiled computation can never be reused from the cache.
    pmap_train_step = jax.pmap(
        train_step,
        axis_name="num_devices",
        devices=device_list,
        # update_actor is a Python bool that selects which networks to update,
        # so it is passed as a static (compile-time) argument.
        static_broadcasted_argnums=(2,),
    )
    pmap_sample_actions = jax.pmap(
        sample_actions_step, axis_name="num_devices", devices=device_list
    )
    pmap_mode_actions = jax.pmap(
        mode_actions_step, axis_name="num_devices", devices=device_list
    )

    datetime_string = time.strftime("%Y_%m_%d_%H_%M_%S")
    # Flatten both configs to "a/b/c"-style keys so the run-name template in
    # FLAGS.wandb_name can reference nested entries such as {agent_config/...}.
    config_dict_flat = flax.traverse_util.flatten_dict(
        model_config.to_dict()
    ) | flax.traverse_util.flatten_dict(data_config.to_dict())
    config_dict_flat = {"/".join(k): v for k, v in config_dict_flat.items()}
    wandb.init(
        project=str(model_config.wandb_proj),
        name=f"{FLAGS.wandb_name}_{datetime_string}".format(**config_dict_flat),
        config=model_config.to_dict() | data_config.to_dict(),
        dir=FLAGS.wandb_dir,
    )

    if FLAGS.checkpoint_save_dir is not None:
        # Checkpoints go in a sub-directory named after the W&B run.
        checkpoint_manager = CheckpointManager(
            directory=tf.io.gfile.join(FLAGS.checkpoint_save_dir, wandb.run.name),
            checkpointers=PyTreeCheckpointer(),
            options=CheckpointManagerOptions(
                save_interval_steps=FLAGS.checkpoint_interval,
            ),
        )
    else:
        checkpoint_manager = None

    train_data = train_dataset.iterator()
    training_data_prefetch = flax.jax_utils.prefetch_to_device(train_data, 2)

    if model_config.validate:
        val_data_prefetch = {
            k: flax.jax_utils.prefetch_to_device(DLataset.iterator(v), 2)
            for k, v in val_datasets.items()
        }

    # training loop
    for step in tqdm.trange(model_config.train_steps):
        batch = next(training_data_prefetch)
        agent.actor, update_info = pmap_train_step(
            batch, agent.actor, step % actor_update_period == 0
        )
        # Collapse the per-device copies of the metrics into a single dict.
        update_info = average_dict(update_info)
        update_info = {f"train/{key}": value for key, value in update_info.items()}
        update_info["data_stats/reached_goal_frac"] = np.mean(batch["reached"])
        update_info["data_stats/original_goals"] = np.mean(batch["resample_type"] == 0)
        update_info["data_stats/positive_goals"] = np.mean(batch["resample_type"] == 1)
        update_info["data_stats/negative_goals"] = np.mean(batch["resample_type"] == 2)
        update_info["data_stats/crash_frac"] = np.mean(batch["crashed"])

        if step % FLAGS.wandb_interval == 0:
            wandb.log(
                update_info
                | {
                    "action_0": wandb.Histogram(batch["actions"][..., 0].flatten()),
                    "action_1": wandb.Histogram(batch["actions"][..., 1].flatten()),
                },
                step=step,
            )

        if model_config.validate and step % model_config.val_steps == 0:
            val_info_all = {}
            for single_dataset_name, single_val_data in val_data_prefetch.items():
                val_metrics = []
                for _ in range(num_val_batches):
                    val_batch = next(single_val_data)
                    # The train step is reused only for its metrics; the
                    # returned agent state is discarded, so validation batches
                    # never change the agent.
                    _, val_info = pmap_train_step(val_batch, agent.actor, True)
                    val_info = average_dict(val_info)
                    val_metrics.append(
                        {f"val/{key}": value for key, value in val_info.items()}
                    )

                val_info_all[single_dataset_name] = average_dicts(val_metrics)
            wandb.log(val_info_all, step=step)

        if checkpoint_manager is not None and step % FLAGS.checkpoint_interval == 0:
            checkpoint_manager.save(
                step,
                items=jax.device_get(flax.jax_utils.unreplicate(agent.actor)),
            )

        if model_config.validate and step % FLAGS.viz_interval == 0:
            viz_info = {}
            for single_dataset_name, single_val_data in val_data_prefetch.items():
                # Compare policy actions against dataset actions on one
                # validation batch. A separate name is used so the training
                # `batch` above is not shadowed.
                viz_batch = next(single_val_data)
                sampled_actions = pmap_sample_actions(viz_batch, agent.actor)
                mode_actions = pmap_mode_actions(viz_batch, agent.actor)

                # Pull all three arrays to host memory before plotting.
                figure = action_scatter(
                    {
                        "sampled": jax.device_get(sampled_actions),
                        "modes": jax.device_get(mode_actions),
                        "dataset": jax.device_get(viz_batch["actions"]),
                    }
                )
                viz_info[f"viz/{single_dataset_name}"] = wandb.Plotly(figure)

            # Add to wandb
            wandb.log(viz_info, step=step)

    wandb.finish()


if __name__ == "__main__":
    import os

    # Default config files live next to this script; the part after ":" is the
    # argument string handed to the config file.
    config_dir = os.path.dirname(__file__)

    config_flags.DEFINE_config_file(
        "model_config",
        os.path.join(config_dir, "model_config.py:gc_bc"),
        "Configuration for the agent",
    )

    config_flags.DEFINE_config_file(
        "data_config",
        os.path.join(config_dir, "data_config.py:gnm"),
        "Configuration for the dataset",
    )

    flags.DEFINE_integer("seed", 42, "Seed for training")
    flags.DEFINE_string("checkpoint_save_dir", None, "Where to store checkpoints")
    flags.DEFINE_string("checkpoint_load_dir", None, "Where to load checkpoints")
    flags.DEFINE_integer("checkpoint_load_step", None, "Which step to load checkpoints")

    flags.DEFINE_string("data_dir", None, required=True, help="Dataset directory")
    flags.DEFINE_string("data_mix", "gnm", help="Dataset mix")
    # The run name is a str.format template filled from the flattened model and
    # data configs in main(); nested keys are written as "parent/child".
    flags.DEFINE_string(
        "wandb_name",
        (
            "{agent_name}_{action_type}_{reward_type}_skip{num_frame_skip}"
            "_alpha{agent_config/cql_alpha}_proprio{agent_config/critic_use_proprio}"
        ),
        help="Name of run on W&B",
    )
    # NOTE(review): the default used to be a bare `<WANDB DIR>` placeholder,
    # which is not valid Python (apparently a scrubbed local path). None is
    # forwarded to wandb.init(dir=...), leaving the location to wandb; pass
    # --wandb_dir to choose one explicitly.
    flags.DEFINE_string(
        "wandb_dir",
        None,
        "Where to store temporary W&B data to sync to cloud",
    )

    flags.DEFINE_integer("wandb_interval", 10, "Interval between calls to wandb.log")
    flags.DEFINE_integer(
        "checkpoint_interval", 10000, "Interval between checkpoint saves"
    )
    flags.DEFINE_integer("viz_interval", 5000, "Interval between visualizations")

    app.run(main)
