import time
import jax.experimental
import jax.experimental.compilation_cache
import jax.experimental.compilation_cache.compilation_cache
import tensorflow as tf
# Hide every GPU from TensorFlow (empty visible-device list for 'GPU') and
# restrict TensorFlow's inter-op and intra-op thread pools to one thread each.
# NOTE(review): presumably so that TensorFlow, which this script uses for the
# data pipeline, does not compete with JAX for accelerator memory -- confirm.
tf.config.set_visible_devices([], 'GPU') 
tf.config.threading.set_inter_op_parallelism_threads(1)
tf.config.threading.set_intra_op_parallelism_threads(1)

from typing import Mapping
from agentlace.data.tf_agents_episode_buffer import EpisodicTFDataStore
from agentlace.trainer import TrainerServer
import wandb
import tqdm
import numpy as np

import sys 
import os

import jax
# Initialise JAX's compilation cache at the given location.
# NOTE(review): `<CACHE PATH>` is a placeholder and is not valid Python; it
# must be replaced with a real directory path (a string) before this file can
# be imported or run.
jax.experimental.compilation_cache.compilation_cache.initialize_cache(<CACHE PATH>)
import flax
import jax.numpy as jnp
from jaxrl_m.vision.data_augmentations import batched_random_crop
from jaxrl_m.utils.timer_utils import Timer

from dlimp.dataset import DLataset
from multinav.deploy.common.trainer_bridge_common import (
    task_data_format,
    make_trainer_config,
)
from multinav.deploy.train.agent import Agent
from multinav.deploy.train.utils import average_dict, average_dicts

from multinav.deploy.load_data import (
    dataset_postprocess,
    setup_datasets,
    dataset_preprocess,
)

from orbax.checkpoint import (
    CheckpointManager,
    CheckpointManagerOptions,
    PyTreeCheckpointer,
)

from agentlace.data.rlds_writer import RLDSWriter
import atexit

# Devices visible to JAX. main() pmaps the update step over `device_list` and
# splits every batch into `num_devices` per-device shards.
device_list = jax.devices()
num_devices = len(device_list)

from absl import app, flags, logging as absl_logging
from ml_collections import config_flags, ConfigDict

# absl flag registry. The flags read by main() are defined in the
# `if __name__ == "__main__":` block at the bottom of this file.
FLAGS = flags.FLAGS

# Preprocessing parameters forwarded by main() to dataset_preprocess() for the
# online dataset (as waypoint_spacing, angle_scale and x_offset respectively).
WAYPOINT_SPACING = 0.25
ANGLE_SCALE = 1
X_OFFSET = -1

# NOTE(review): not referenced in the visible code -- main() passes
# min_length=3 to dataset_preprocess() explicitly. Confirm which is intended.
MIN_LENGTH = 2

# Running tally of episode outcomes reported by remote actors through the
# "send-stats" request; "total" is incremented once per report.
end_stats = dict(
    crash=0,
    reach=0,
    timeout=0,
    total=0,
)

@jax.jit
def augment(image, key):
    """Normalize an integer image batch per channel and apply a random crop.

    Args:
        image: Array with an integer dtype (checked by the assertion below)
            whose last axis broadcasts against the 3-element mean/std vectors.
        key: Random key handed to ``batched_random_crop``.

    Returns:
        The result of ``batched_random_crop`` on the normalized float image,
        called with ``padding=4`` and ``num_batch_dims=1``.
    """
    assert jnp.issubdtype(image.dtype, jnp.integer), image.dtype

    # NOTE(review): these are the usual ImageNet statistics for images scaled
    # to [0, 1], but the integer input is not rescaled before normalization.
    # Confirm whether that is intentional.
    channel_mean = jnp.array([0.485, 0.456, 0.406])
    channel_std = jnp.array([0.229, 0.224, 0.225])

    normalized = image.astype(float)
    normalized = (normalized - channel_mean) / channel_std
    return batched_random_crop(normalized, key, padding=4, num_batch_dims=1)

def main(_):
    """Train an agent on online data from remote actors, optionally mixed with offline data.

    The function performs, in order:

    1. Optionally loads offline train/validation datasets (``--data_mix``).
    2. Builds the agent, optionally restores a checkpoint, replicates it and
       pmaps the update step over ``device_list``.
    3. Initialises W&B, an optional checkpoint manager and the online data
       store (written through an ``RLDSWriter`` when ``--data_save_dir`` is
       set, in-memory only otherwise).
    4. Starts a ``TrainerServer`` and keeps publishing the actor parameters
       until ``--wait_data`` online samples have been received.
    5. Runs ``model_config.train_steps`` updates, periodically logging to
       W&B, saving checkpoints and publishing fresh actor parameters.

    Args:
        _: Unused argument passed in by ``app.run``.

    Raises:
        SystemExit: With a non-zero status if no training batch could be
            fetched after repeated attempts.
    """
    tf.get_logger().setLevel("WARNING")
    absl_logging.set_verbosity("WARNING")

    model_config: ConfigDict = FLAGS.model_config
    online_data_config: ConfigDict = FLAGS.online_data_config
    offline_data_config: ConfigDict = FLAGS.offline_data_config

    def train_step(batch, agent, update_actor):
        """Per-device update; `update_actor` selects which networks are trained."""
        if "bc" in model_config.agent_name:
            return agent.update(batch, pmap_axis="num_devices")
        else:
            return agent.update(batch, pmap_axis="num_devices", networks_to_update={"actor", "critic"} if update_actor else {"critic"})

    def batch_for_devices(dataset: DLataset, total_batch_size: int, num_devices: int):
        """Batch twice so elements have leading axes (num_devices, per_device_batch)."""
        dataset = dataset.batch(total_batch_size // num_devices, drop_remainder=True, num_parallel_calls=None)
        dataset = dataset.batch(num_devices, drop_remainder=True, num_parallel_calls=None)
        return dataset

    def host_actor_params():
        """Return the current (unreplicated) actor parameters as numpy arrays."""
        return jax.tree_util.tree_map(
            np.asarray, flax.jax_utils.unreplicate(agent.actor.state.params)
        )

    # Loading an existing (offline) dataset.
    if FLAGS.data_mix is not None:
        train_dataset, val_datasets = setup_datasets(
            FLAGS.data_mix,
            FLAGS.data_dir,
            discount=offline_data_config.discount,
            skip_crash=offline_data_config.skip_crash,
            truncate_goal=offline_data_config.truncate_goal,
            validate=model_config.validate,
            # NOTE(review): every other argument comes from offline_data_config;
            # confirm that using the online reward_type here is intentional.
            reward_type=online_data_config.reward_type,
            negative_probability=offline_data_config.negative_probability,
            action_type=offline_data_config.action_type,
        )

        val_datasets = {
            k: batch_for_devices(v, model_config.batch_size, num_devices)
            for k, v in val_datasets.items()
        }
    else:
        train_dataset = None
        val_datasets = None

    agent = Agent(model_config, FLAGS.seed)

    if FLAGS.checkpoint_load_dir is not None:
        agent.load_checkpoint(FLAGS.checkpoint_load_dir, FLAGS.checkpoint_load_step)

    agent.replicate()

    print("Agent set up!")
    # `update_actor` (argument index 2) is a static Python bool, not a traced value.
    pmap_train_step = jax.pmap(train_step, axis_name="num_devices", devices=device_list, static_broadcasted_argnums=(2,))

    wandb.init(
        project=model_config.wandb_proj,
        name=FLAGS.wandb_name,
        config=model_config.to_dict() | online_data_config.to_dict() | offline_data_config.to_dict(),
        dir=FLAGS.wandb_dir,
    )

    if FLAGS.checkpoint_save_dir is not None:
        print("setting up checkpointer")
        checkpoint_manager = CheckpointManager(
            directory=tf.io.gfile.join(FLAGS.checkpoint_save_dir, wandb.run.name),
            checkpointers=PyTreeCheckpointer(),
            options=CheckpointManagerOptions(
                save_interval_steps=FLAGS.checkpoint_interval,
                max_to_keep=50,
            ),
        )
    else:
        checkpoint_manager = None

    # Online data store. With --data_save_dir the incoming data is also handed
    # to an RLDSWriter; without it the store is created without a writer
    # (previously a missing --data_save_dir crashed while joining the path).
    if FLAGS.data_save_dir is not None:
        data_spec = task_data_format()
        # The version is fixed, so every run targets `<data_save_dir>/0.0.1`.
        version = "0.0.1"
        datastore_path = tf.io.gfile.join(FLAGS.data_save_dir, version)

        writer = RLDSWriter(
            dataset_name="test",
            data_spec=data_spec,
            data_directory=datastore_path,
            version=version,
            max_episodes_per_file=100,
        )

        atexit.register(writer.close)  # save on exit

        online_dataset_datastore = EpisodicTFDataStore(
            capacity=10000,
            data_spec=task_data_format(),
            rlds_logger=writer,
        )
    else:
        online_dataset_datastore = EpisodicTFDataStore(
            capacity=10000,
            data_spec=task_data_format(),
        )
    print("Datastore set up")

    def request_callback(_type, _payload):
        """Handle actor requests: accumulate episode stats or return the model config."""
        if _type == "send-stats":
            # `end_stats` is mutated in place, so no `global` declaration is needed.
            for key, value in _payload.items():  # reach, timeout, crash
                end_stats[key] += int(value)

            end_stats["total"] += 1

        elif _type == "get-model-config":
            return model_config
        else:
            raise NotImplementedError(f"Unknown request type {_type}")

    train_server = TrainerServer(
        config=make_trainer_config(),
        request_callback=request_callback,
    )

    train_server.register_data_store("online_data", online_dataset_datastore)
    train_server.start(threaded=True)

    samples_to_wait_for = FLAGS.wait_data  # usually 1000
    # The parameters do not change while waiting, so convert them only once.
    initial_network = {"params": host_actor_params()}
    wait_bar = tqdm.tqdm(total=samples_to_wait_for, desc="Waiting for data")
    while online_dataset_datastore.size < samples_to_wait_for:
        wait_bar.update(online_dataset_datastore.size - wait_bar.n)
        train_server.publish_network(initial_network)
    wait_bar.close()

    train_server.stop()  # Stop while processing the first batch(es)

    online_dataset = dataset_preprocess(
        online_dataset_datastore.as_dataset().ignore_errors(log_warning=True, name="online_data"),
        waypoint_spacing=WAYPOINT_SPACING,
        x_offset=X_OFFSET,
        angle_scale=ANGLE_SCALE,
        assign_goal=True,
        end_is_crash=False,
        discount=online_data_config.discount,
        negative_probability=online_data_config.negative_probability,
        min_length=3,
        skip_crash=online_data_config.skip_crash,
        discrete=False,
        truncate_goal=online_data_config.truncate_goal,
        action_key="action",
        has_goal=True,
        reward_type=online_data_config.reward_type,
        action_type=online_data_config.action_type,
    )
    online_dataset = dataset_postprocess(online_dataset, image_size=64)

    # Build the schedule of training datasets. Every `datamix_switch_step`
    # steps the loop advances to the next entry until the list is exhausted,
    # after which the last entry stays in use.
    if train_dataset is None:  # online only
        # Previously this branch left the schedule undefined and the training
        # loop failed with a NameError; use a single online-only entry instead.
        mixed_datasets = [online_dataset]
    else:  # mix in with existing dataset
        print("Mixing Offline and Online Data")
        mixed_datasets = [
            DLataset.sample_from_datasets(
                [online_dataset, train_dataset], weights=[split, 1 - split]
            )
            for split in [0.1, 0.2, 0.3, 0.4, 0.5]  # fraction of online data
        ]

    dataset_mixes = [
        DLataset.iterator(
            batch_for_devices(dataset, model_config.batch_size, num_devices)
        )
        for dataset in mixed_datasets
    ]
    curr_data_mix = 0
    datamix_switch_step = 2000

    if model_config.validate and val_datasets is not None:
        # NOTE(review): these validation iterators are created but are not
        # consumed anywhere in the training loop below.
        val_data_prefetch = {
            k: flax.jax_utils.prefetch_to_device(DLataset.iterator(v), 2)
            for k, v in val_datasets.items()
        }

    max_batch_attempts = 12

    timer = Timer()
    print("Starting Training Loop")
    pbar = tqdm.trange(model_config.train_steps, dynamic_ncols=True)
    for step in pbar:
        if step == 3:
            # Restart the server that was stopped before the first batches.
            train_server.start(threaded=True)
        pbar.set_postfix({"online data": online_dataset_datastore.size})
        with timer.context("dataset_switching"):
            # Step 0 always selects the first entry, so `training_data_prefetch`
            # is defined before it is first read.
            if step % datamix_switch_step == 0 and curr_data_mix < len(dataset_mixes):
                training_data_prefetch = flax.jax_utils.prefetch_to_device(dataset_mixes[curr_data_mix], 2)
                curr_data_mix += 1

        with timer.context("sample_buffer"):
            batch = None
            attempts = 0
            while batch is None and attempts < max_batch_attempts:
                attempts += 1
                try:
                    batch = next(training_data_prefetch)
                except Exception as e:
                    print(f"Error processing batch at step {step}: {e}")

            # Test the batch itself rather than the attempt count: a batch
            # obtained on the final attempt is valid and must not abort the run.
            if batch is None:
                sys.exit("Could not successfully get next batch!")

        with timer.context("train_step"):
            # The actor is updated only on every third step; see train_step.
            agent.actor, update_info = pmap_train_step(
                batch, agent.actor, step % 3 == 0,
            )
            update_info = average_dict(update_info)  # average the per-device metrics
            update_info = {f"train/{key}": value for key, value in update_info.items()}
            update_info["data_stats/reached_goal_frac"] = np.mean(batch["reached"])
            update_info["data_stats/original_goals"] = np.mean(batch["resample_type"] == 0)
            update_info["data_stats/positive_goals"] = np.mean(batch["resample_type"] == 1)
            update_info["data_stats/negative_goals"] = np.mean(batch["resample_type"] == 2)
            update_info["data_stats/crash_frac"] = np.mean(batch["crashed"])

        if step % FLAGS.wandb_interval == 0:
            # The 1e-3 term avoids division by zero before any stats arrive.
            total_episodes = end_stats["total"] + 1e-3
            wandb.log(
                update_info
                | {
                    "action_0": wandb.Histogram(batch["actions"][..., 0].flatten()),
                    "action_1": wandb.Histogram(batch["actions"][..., 1].flatten()),
                    "online_size": online_dataset_datastore.size,
                    "online_reach_percent": end_stats["reach"] / total_episodes,
                    "online_crash_percent": end_stats["crash"] / total_episodes,
                    "online_timeout_percent": end_stats["timeout"] / total_episodes,
                    "online_reach_num": end_stats["reach"],
                    "online_crash_num": end_stats["crash"],
                    "online_timeout_num": end_stats["timeout"],
                }
                | {f"timer/{k}": v for k, v in timer.get_average_times().items()},
                step=step,
            )

        with timer.context("log_save"):
            if checkpoint_manager is not None and step % FLAGS.checkpoint_interval == 0:
                checkpoint_manager.save(
                    step,
                    items=jax.device_get(flax.jax_utils.unreplicate(agent.actor)),
                )

            # REMOTE TRAINING, send over new weights
            if step % 25 == 0:
                train_server.publish_network({"params": host_actor_params()})

    wandb.finish()


if __name__ == "__main__":
    # Flags are defined here, so they exist only when this file is run as a
    # script. The default config files are resolved relative to this file.
    config_dir = os.path.dirname(__file__)

    config_flags.DEFINE_config_file(
        "model_config",
        os.path.join(config_dir, "model_config.py:gc_bc"),
        "Configuration for the agent",
    )

    config_flags.DEFINE_config_file(
        "offline_data_config",
        os.path.join(config_dir, "data_config.py:gnm"),
        "Configuration for the offline (pre-existing) dataset",
    )

    config_flags.DEFINE_config_file(
        "online_data_config",
        os.path.join(config_dir, "data_config.py:create"),
        "Configuration for the online (collected during training) dataset",
    )

    flags.DEFINE_integer("seed", 42, "Seed for training")
    flags.DEFINE_integer("wait_data", 1000, "how many data points to wait for")

    flags.DEFINE_string("checkpoint_save_dir", None, "Where to store checkpoints")
    flags.DEFINE_string("checkpoint_load_dir", None, "Where to load checkpoints")
    flags.DEFINE_integer("checkpoint_load_step", None, "Which step to load checkpoints")

    flags.DEFINE_string("data_dir", None, help="Dataset directory")
    flags.DEFINE_string("data_mix", None, help="Dataset mix")
    flags.DEFINE_string("data_save_dir", None, "Where to save collected data")

    flags.DEFINE_string("wandb_name", None, help="Name of run on W&B")
    flags.DEFINE_string(
        "wandb_dir", "/tmp/wandb", "Where to store temporary W&B data to sync to cloud"
    )

    # NOTE(review): `dataset_update_interval` and `viz_interval` are defined
    # but not read anywhere in this file; kept so existing command lines that
    # pass them continue to parse.
    flags.DEFINE_integer("dataset_update_interval", 100, "Interval between dataset updates")
    flags.DEFINE_integer("wandb_interval", 10, "Interval between calls to wandb.log")
    flags.DEFINE_integer(
        "checkpoint_interval", 2000, "Interval between checkpoint saves"
    )
    flags.DEFINE_integer("viz_interval", 5000, "Interval between visualizations")

    app.run(main)