import os
from functools import partial

import gym
import jax
import numpy as np
import tqdm
from absl import app, flags, logging
from flax.training import checkpoints
from ml_collections import config_flags

from experiments.configs.ensemble_config import add_redq_config
from wsrl.agents import agents
from wsrl.common.evaluation import evaluate_with_trajectories
from wsrl.common.wandb import WandBLogger
from wsrl.data.replay_buffer import ReplayBuffer, ReplayBufferMC
from wsrl.envs.adroit_binary_dataset import get_hand_dataset_with_mc_calculation
from wsrl.envs.d4rl_dataset import (
    get_d4rl_dataset,
    get_d4rl_dataset_with_mc_calculation,
)
from wsrl.envs.env_common import get_env_type, make_gym_env
from wsrl.utils.timer_utils import Timer
from wsrl.utils.train_utils import concatenate_batches, subsample_batch

FLAGS = flags.FLAGS

# env
flags.DEFINE_string("env", "antmaze-large-diverse-v2", "Environemnt to use")
flags.DEFINE_float("reward_scale", 1.0, "Reward scale.")
flags.DEFINE_float("reward_bias", -1.0, "Reward bias.")
flags.DEFINE_float(
    "clip_action",
    0.99999,
    "Clip actions to be between [-n, n]. This is needed for tanh policies.",
)

# training
flags.DEFINE_integer("num_offline_steps", 1_000_000, "Number of offline epochs.")
flags.DEFINE_integer("num_online_steps", 500_000, "Number of online epochs.")
flags.DEFINE_float(
    "offline_data_ratio",
    0.0,
    "How much offline data to retain in each online batch update",
)
flags.DEFINE_string(
    "online_sampling_method",
    "mixed",
    """Method of sampling data during online update: mixed or append.
    `mixed` samples from a mix of offline and online data according to offline_data_ratio.
    `append` adds offline data to replay buffer and samples from it.""",
)

flags.DEFINE_bool(
    "online_use_cql_loss",
    True,
    """When agent is CQL/CalQL, whether to use CQL loss for the online phase (use SAC loss if False)""",
)
flags.DEFINE_integer(
    "warmup_steps", 0, "number of warmup steps (WSRL) before performing online updates"
)

# agent
flags.DEFINE_string("agent", "calql", "what RL agent to use")
flags.DEFINE_integer("utd", 1, "update-to-data ratio of the critic")
flags.DEFINE_integer("batch_size", 256, "batch size for training")
flags.DEFINE_integer("replay_buffer_capacity", int(3e6), "Replay buffer capacity")
flags.DEFINE_bool("use_redq", True, "Use an ensemble of Q-functions for the agent")

# experiment house keeping
flags.DEFINE_integer("seed", 0, "Random seed.")
flags.DEFINE_string(
    "save_dir",
    os.path.expanduser("~/o2o_log"),
    "Directory to save the logs and checkpoints",
)
flags.DEFINE_string("resume_path", "", "Path to resume from")
flags.DEFINE_integer("log_interval", 1_000, "Log every n steps")
flags.DEFINE_integer("eval_interval", 5_000, "Evaluate every n steps")
flags.DEFINE_integer("save_interval", 100_000, "Save every n steps.")
flags.DEFINE_integer(
    "n_eval_trajs", 20, "Number of trajectories to use for each evaluation."
)
flags.DEFINE_bool("deterministic_eval", True, "Whether to use deterministic evaluation")

# wandb
flags.DEFINE_string("exp_name", "", "Experiment name for wandb logging")
flags.DEFINE_string("project", None, "Wandb project folder")
flags.DEFINE_string("group", None, "Wandb group of the experiment")
flags.DEFINE_string("job_type", None, "Wandb job type of the experiment")
flags.DEFINE_bool("debug", False, "If true, no logging to wandb")
flags.DEFINE_integer("step_bias", 0, "log step bias")

flags.DEFINE_float("actor_bc_coef", None, "BC coefficient for ReBRAC actor")
flags.DEFINE_float("critic_bc_coef", None, "BC coefficient for ReBRAC critic")
flags.DEFINE_bool("normalize_q", True, "normalize_q for ReBRAC")
flags.DEFINE_float("learning_rate", None, "Learning rate for online updates")
flags.DEFINE_bool("use_bc", False, "Use BC for TD3")

config_flags.DEFINE_config_file(
    "config",
    None,
    "File path to the training hyperparameter configuration.",
    lock_config=False,
)


def compute_average_trajectory_return(dataset):
    """
    Compute the average return of trajectories in the dataset.
    
    Args:
        dataset: Dictionary containing 'rewards' and 'dones' arrays
        
    Returns:
        float: Average return across all trajectories
    """
    rewards = dataset['rewards']
    dones = dataset['dones']
    
    trajectory_returns = []
    trajectory_lengths = []
    current_return = 0.0
    current_length = 0
    
    for i in range(len(rewards)):
        current_return += rewards[i]
        current_length += 1
        
        # If done=1, this is the end of a trajectory
        if dones[i] == 1:
            trajectory_returns.append(current_return)
            trajectory_lengths.append(current_length)
            current_return = 0.0
            current_length = 0
    
    # Handle case where last trajectory doesn't end with done=1
    if current_return != 0.0:
        trajectory_returns.append(current_return)
        trajectory_lengths.append(current_length)
    
    if len(trajectory_returns) == 0:
        return 0.0
    
    average_return = np.mean(trajectory_returns)
    average_length = np.mean(trajectory_lengths)
    
    print(f"Number of trajectories: {len(trajectory_returns)}")
    print(f"Trajectory returns - Min: {np.min(trajectory_returns):.2f}, "
          f"Max: {np.max(trajectory_returns):.2f}, "
          f"Mean: {average_return:.2f}, "
          f"Std: {np.std(trajectory_returns):.2f}")
    print(f"Trajectory lengths - Min: {np.min(trajectory_lengths)}, "
          f"Max: {np.max(trajectory_lengths)}, "
          f"Mean: {average_length:.2f}, "
          f"Std: {np.std(trajectory_lengths):.2f}")
    
    return average_return


def main(_):
    """
    house keeping
    """
    assert FLAGS.online_sampling_method in [
        "mixed",
        "append",
    ], "incorrect online sampling method"

    if FLAGS.use_redq:
        FLAGS.config.agent_kwargs = add_redq_config(FLAGS.config.agent_kwargs)

    if FLAGS.agent == "rebrac":
        if FLAGS.actor_bc_coef is not None:
            FLAGS.config.agent_kwargs['actor_bc_coef'] = FLAGS.actor_bc_coef
            print(f"Setting actor_bc_coef to {FLAGS.actor_bc_coef}")
        if FLAGS.critic_bc_coef is not None:
            FLAGS.config.agent_kwargs['critic_bc_coef'] = FLAGS.critic_bc_coef
            print(f"Setting critic_bc_coef to {FLAGS.critic_bc_coef}")
        if FLAGS.normalize_q is not None:
            FLAGS.config.agent_kwargs['normalize_q'] = FLAGS.normalize_q
            print(f"Setting normalize_q to {FLAGS.normalize_q}")
    if FLAGS.agent == "td3" and FLAGS.use_bc:
        FLAGS.config.agent_kwargs['use_bc'] = True
        print(f"Setting use_bc to {FLAGS.use_bc}")
        if FLAGS.normalize_q is not None:
            FLAGS.config.agent_kwargs['normalize_q'] = FLAGS.normalize_q
            print(f"Setting normalize_q to {FLAGS.normalize_q}")

    if FLAGS.learning_rate is not None:
        FLAGS.config.agent_kwargs['actor_optimizer_kwargs']['learning_rate'] = FLAGS.learning_rate
        FLAGS.config.agent_kwargs['critic_optimizer_kwargs']['learning_rate'] = FLAGS.learning_rate
        print(f"Setting learning_rate to {FLAGS.learning_rate}")

    min_steps_to_update = FLAGS.batch_size * (1 - FLAGS.offline_data_ratio)
    if FLAGS.agent == "calql":
        min_steps_to_update = max(
            min_steps_to_update, gym.make(FLAGS.env)._max_episode_steps
        )

    """
    wandb and logging
    """
    wandb_config = WandBLogger.get_default_config()
    wandb_config.update(
        {
            "project": FLAGS.project or "wsrl",
            "group": FLAGS.group or "wsrl",
            "job_type": FLAGS.job_type or "wsrl",
            "exp_descriptor": f"{FLAGS.exp_name}_{FLAGS.env}_{FLAGS.agent}_seed{FLAGS.seed}",
        }
    )

    wandb_logger = WandBLogger(
        wandb_config=wandb_config,
        variant=FLAGS.config.to_dict(),
        random_str_in_identifier=True,
        disable_online_logging=FLAGS.debug,
    )

    save_dir = os.path.join(
        FLAGS.save_dir,
        wandb_logger.config.project,
        f"{wandb_logger.config.exp_descriptor}_{wandb_logger.config.unique_identifier}",
    )

    """
    env
    """
    env_type = get_env_type(FLAGS.env)
    finetune_env = make_gym_env(
        env_name=FLAGS.env,
        reward_scale=FLAGS.reward_scale,
        reward_bias=FLAGS.reward_bias,
        scale_and_clip_action=env_type in ("antmaze", "kitchen", "locomotion"),
        action_clip_lim=FLAGS.clip_action,
        seed=FLAGS.seed,
    )
    eval_env = make_gym_env(
        env_name=FLAGS.env,
        scale_and_clip_action=env_type in ("antmaze", "kitchen", "locomotion"),
        action_clip_lim=FLAGS.clip_action,
        seed=FLAGS.seed + 1000,
    )

    """
    load dataset
    """
    if env_type == "adroit-binary":
        dataset = get_hand_dataset_with_mc_calculation(
            FLAGS.env,
            gamma=FLAGS.config.agent_kwargs.discount,
            reward_scale=FLAGS.reward_scale,
            reward_bias=FLAGS.reward_bias,
            clip_action=FLAGS.clip_action,
        )
    else:
        if FLAGS.agent == "calql":
            # need dataset with mc return
            dataset = get_d4rl_dataset_with_mc_calculation(
                FLAGS.env,
                reward_scale=FLAGS.reward_scale,
                reward_bias=FLAGS.reward_bias,
                clip_action=FLAGS.clip_action,
                gamma=FLAGS.config.agent_kwargs.discount,
            )
        else:
            dataset = get_d4rl_dataset(
                FLAGS.env,
                reward_scale=FLAGS.reward_scale,
                reward_bias=FLAGS.reward_bias,
                clip_action=FLAGS.clip_action,
            )


    """
    replay buffer
    """
    replay_buffer_type = ReplayBufferMC if FLAGS.agent == "calql" else ReplayBuffer
    replay_buffer = replay_buffer_type(
        finetune_env.observation_space,
        finetune_env.action_space,
        capacity=FLAGS.replay_buffer_capacity,
        seed=FLAGS.seed,
        discount=FLAGS.config.agent_kwargs.discount if FLAGS.agent == "calql" else None,
    )

    """
    Initialize agent
    """
    rng = jax.random.PRNGKey(FLAGS.seed)
    rng, construct_rng = jax.random.split(rng)
    example_batch = subsample_batch(dataset, FLAGS.batch_size)
    agent = agents[FLAGS.agent].create(
        rng=construct_rng,
        observations=example_batch["observations"],
        actions=example_batch["actions"],
        encoder_def=None,
        **FLAGS.config.agent_kwargs,
    )

    if FLAGS.resume_path != "":
        print(f"Resuming from {FLAGS.resume_path}")
        assert os.path.exists(FLAGS.resume_path), "resume path does not exist"
        agent = checkpoints.restore_checkpoint(FLAGS.resume_path, target=agent)

    """
    eval function
    """

    def evaluate_and_log_results(
        eval_env,
        policy_fn,
        eval_func,
        step_number,
        wandb_logger,
        n_eval_trajs=FLAGS.n_eval_trajs,
    ):
        stats, trajs = eval_func(
            policy_fn,
            eval_env,
            n_eval_trajs,
        )

        # Apply reward scaling and bias
        scaled_rewards = [
            np.array(t["rewards"]) * FLAGS.reward_scale + FLAGS.reward_bias 
            for t in trajs
        ]
        eval_info = {
            "average_return": np.mean([np.sum(reward) for reward in scaled_rewards]),
            "average_traj_length": np.mean([len(t["rewards"]) for t in trajs]),
        }
        if env_type == "adroit-binary":
            # adroit
            eval_info["success_rate"] = np.mean(
                [any(d["goal_achieved"] for d in t["infos"]) for t in trajs]
            )
        elif env_type == "kitchen":
            # kitchen
            eval_info["num_stages_solved"] = np.mean([t["rewards"][-1] for t in trajs])
            eval_info["success_rate"] = np.mean([t["rewards"][-1] for t in trajs]) / 4

        else:
            # d4rl antmaze, locomotion
            eval_info["success_rate"] = eval_info[
                "average_normalized_return"
            ] = np.mean(
                [eval_env.get_normalized_score(np.sum(t["rewards"])) for t in trajs]
            )
        wandb_logger.log({"evaluation": eval_info}, step=step_number)

    """
    training loop
    """
    timer = Timer()
    step = int(agent.state.step)  # 0 for new agents, or load from pre-trained
    step += FLAGS.step_bias
    is_online_stage = False
    observation, info = finetune_env.reset()
    done = False  # env done signal

    # for _ in tqdm.tqdm(range(step, FLAGS.num_offline_steps + FLAGS.num_online_steps)):
    if FLAGS.num_online_steps > 0:
        train_steps = FLAGS.num_online_steps
    else:
        train_steps = FLAGS.num_offline_steps
    for i in range(train_steps):
        """
        Switch from offline to online
        """
        if not is_online_stage and step >= FLAGS.num_offline_steps:
            logging.info("Switching to online training")
            is_online_stage = True

            # upload offline data to online buffer
            if FLAGS.online_sampling_method == "append":
                offline_dataset_size = dataset["actions"].shape[0]
                dataset_items = dataset.items()
                for j in range(offline_dataset_size):
                    transition = {k: v[j] for k, v in dataset_items}
                    # For ReplayBufferMC, we need to ensure mc_returns are included
                    if FLAGS.agent == "calql" and "mc_returns" in transition:
                        # CalQL uses a ReplayBufferMC which expects mc_returns
                        replay_buffer.insert(transition)
                    else:
                        # For regular ReplayBuffer, exclude mc_returns if present
                        keys_to_keep = ['observations', 'next_observations', 'actions', 'rewards', 'masks', 'dones']
                        filtered_transition = {k: transition[k] for k in keys_to_keep if k in transition}
                        replay_buffer.insert(filtered_transition)

            # option for CQL and CalQL to change the online alpha, and whether to use CQL regularizer
            if FLAGS.agent in ("cql", "calql"):
                online_agent_configs = {
                    "cql_alpha": FLAGS.config.agent_kwargs.get(
                        "online_cql_alpha", None
                    ),
                    "use_cql_loss": FLAGS.online_use_cql_loss,
                }
                agent.update_config(online_agent_configs)

        timer.tick("total")

        """
        Env Step
        """
        with timer.context("env step"):
            if is_online_stage:
                rng, action_rng = jax.random.split(rng)
                action = agent.sample_actions(observation, seed=action_rng)
                next_observation, reward, done, truncated, info = finetune_env.step(
                    action
                )

                transition = dict(
                    observations=observation,
                    next_observations=next_observation,
                    actions=action,
                    rewards=reward,
                    masks=1.0 - done,
                    dones=1.0 if (done or truncated) else 0,
                )
                replay_buffer.insert(transition)

                observation = next_observation
                if done or truncated:
                    observation, info = finetune_env.reset()
                    done = False

        """
        Updates
        """
        with timer.context("update"):
            # offline updates
            if not is_online_stage:
                #if step % FLAGS.log_interval == 0:
                if step % 1000 == 0:
                    batch = subsample_batch(dataset, 1024)
                    offline_data_metrics_info = agent.get_metrics(
                        batch,
                    )
                    offline_data_metrics_info = jax.device_get(offline_data_metrics_info)
                    wandb_logger.log({"offline_data_metrics": offline_data_metrics_info}, step=step)
                batch = subsample_batch(dataset, FLAGS.batch_size)
                if FLAGS.agent in ("rebrac", "td3"):
                    agent, update_info = agent.update_with_policy_delay(
                        batch,
                        step,
                    )
                else:
                    agent, update_info = agent.update(
                        batch,
                    )

            # online updates
            else:
                if step - FLAGS.num_offline_steps <= max(
                    FLAGS.warmup_steps, min_steps_to_update
                ):
                    # no updates during warmup
                    pass
                else:
                    # do online updates, gather batch
                    if FLAGS.online_sampling_method == "mixed":
                        # batch from a mixing ratio of offline and online data
                        batch_size_offline = int(
                            FLAGS.batch_size * FLAGS.offline_data_ratio
                        )
                        batch_size_online = FLAGS.batch_size - batch_size_offline
                        online_batch = replay_buffer.sample(batch_size_online)
                        offline_batch = subsample_batch(dataset, batch_size_offline)
                        # update with the combined batch
                        batch = concatenate_batches([online_batch, offline_batch])
                    elif FLAGS.online_sampling_method == "append":
                        # batch from online replay buffer, with is initialized with offline data
                        batch = replay_buffer.sample(FLAGS.batch_size)
                    else:
                        raise RuntimeError("Incorrect online sampling method")
                    # log metrics
                    #if step % FLAGS.log_interval == 0:
                    if step % 1000 == 0:
                        batch_offline = subsample_batch(dataset, 1024)
                        offline_data_metrics_info = agent.get_metrics(
                            batch_offline,
                        )
                        batch_online = replay_buffer.sample(1024)
                        online_data_metrics_info = agent.get_metrics(
                            batch_online,
                        )
                        
                        offline_data_metrics_info = jax.device_get(offline_data_metrics_info)
                        wandb_logger.log({"offline_data_metrics": offline_data_metrics_info}, step=step)
                        
                        online_data_metrics_info = jax.device_get(online_data_metrics_info)
                        wandb_logger.log({"online_data_metrics": online_data_metrics_info}, step=step)

                    # update
                    
                    if FLAGS.utd > 1:
                            
                        agent, update_info = agent.update_high_utd(
                            batch,
                            utd_ratio=FLAGS.utd,
                        )
                    else:
                        if FLAGS.agent in ("rebrac", "td3"):
                            agent, update_info = agent.update_with_policy_delay(
                                batch,
                                step,
                            )
                        else:
                            agent, update_info = agent.update(
                                batch,
                            )

        """
        Advance Step
        """
        step += 1

        """
        Evals
        """
        eval_steps = (
            FLAGS.num_offline_steps,  # finish offline training
            FLAGS.num_offline_steps + 1,  # start of online training
            FLAGS.num_offline_steps + FLAGS.num_online_steps,  # end of online training
        )
        if step % FLAGS.eval_interval == 0 or step in eval_steps:
            logging.info("Evaluating...")
            with timer.context("evaluation"):
                policy_fn = partial(
                    agent.sample_actions, argmax=FLAGS.deterministic_eval
                )
                eval_func = partial(
                    evaluate_with_trajectories, clip_action=FLAGS.clip_action
                )

                evaluate_and_log_results(
                    eval_env=eval_env,
                    policy_fn=policy_fn,
                    eval_func=eval_func,
                    step_number=step,
                    wandb_logger=wandb_logger,
                )

        """
        Save Checkpoint
        """
        if step % FLAGS.save_interval == 0 or step == FLAGS.num_offline_steps:
            logging.info("Saving checkpoint...")
            checkpoint_path = checkpoints.save_checkpoint(
                save_dir, agent, step=step, keep=30
            )
            logging.info("Saved checkpoint to %s", checkpoint_path)

        timer.tock("total")

        """
        Logging
        """
        offset = 1 if FLAGS.agent in ("td3", "rebrac") else 0

        if step % FLAGS.log_interval == offset:
            # check if update_info is available (False during warmup)
            if "update_info" in locals():
                update_info = jax.device_get(update_info)
                wandb_logger.log({"training": update_info}, step=step)

            wandb_logger.log({"timer": timer.get_average_times()}, step=step)


if __name__ == "__main__":
    app.run(main)
