import os
from typing import Tuple
from flax.training import train_state
from flax.serialization import to_bytes, from_bytes
import gym
import numpy as np
import tqdm
from absl import app, flags
from ml_collections import config_flags
from tensorboardX import SummaryWriter
from common import BehaviorCloning
import wrappers
from dataset_utils_our import Batch, D4RLDataset, PrioritizedReplayBuffer, split_into_trajectories
from evaluation import evaluate
from learner import Learner
import warnings
import optax
import jax.numpy as jnp
import jax


# Silence library warnings globally (the DeprecationWarning filter below is
# already covered by the blanket "ignore").  Note this also hides numpy
# divide-by-zero warnings raised anywhere in the script.
warnings.filterwarnings("ignore")
warnings.filterwarnings("ignore", category=DeprecationWarning)

# absl command-line flags; values are read through FLAGS after app.run parses argv.
FLAGS = flags.FLAGS

flags.DEFINE_string('env_name', 'halfcheetah-expert-v2', 'Environment name.')
# NOTE(review): in this file save_dir only receives the `<seed>.txt` eval-return
# log written by main; no Tensorboard writer is created despite the help text.
flags.DEFINE_string('save_dir', './tmp/', 'Tensorboard logging dir.')
# Passed to Learner.load_checkpoint in main to restore the pretrained agent.
flags.DEFINE_string('load_dir', './tmp/', 'Load directory.')
flags.DEFINE_integer('seed', 42, 'Random seed.')
flags.DEFINE_integer('eval_episodes', 5,
                     'Number of episodes used for evaluation.')
# Evaluation runs whenever the step index is a positive multiple of this value.
flags.DEFINE_integer('eval_interval', 2000, 'Eval interval.')

# Divisor applied to the BC action error before the (+1) ** alpha priority
# transform in compute_error / compute_sample_weights.
flags.DEFINE_float('l', 1, 'weight to priorities')
# Despite the name, not a discount factor: it is the temperature in
# compute_q_weights' exp(-error / gamma) sample weights.
flags.DEFINE_float('gamma', 1, 'weight to q')
flags.DEFINE_integer('batch_size', 256, 'Mini batch size.')
flags.DEFINE_integer('max_steps', int(6e4), 'Number of training steps.')
# NOTE(review): not read anywhere in this file; the training loop has no
# pretraining phase.
flags.DEFINE_integer('num_pretraining_steps', int(0), 'Number of pretraining steps.')
# NOTE(review): the help text mentions a max_steps fallback, but the default is
# a fixed 3000000 and no such fallback exists in this file.
flags.DEFINE_integer('replay_buffer_size', 3000000, 'Replay buffer size (=max_steps if unspecified).')
# Forwarded to replay_buffer.initialize_with_dataset in main.
flags.DEFINE_integer('init_dataset_size', None, 'Offline data size (uses all data if unspecified).')
flags.DEFINE_boolean('tqdm', True, 'Use tqdm progress bar.')

# ml_collections config whose entries main forwards as keyword arguments to Learner.
config_flags.DEFINE_config_file(
    'config',
    'configs/antmaze_finetune_config.py',
    'File path to the training hyperparameter configuration.',
    lock_config=False)

# Largest / smallest trajectory return of the offline dataset.  Filled once by
# normalize() (locomotion tasks only) and used by main to rescale rewards;
# they stay None otherwise.
RETURN_MAX: float = None
RETURN_MIN: float = None


def normalize(dataset):
    """Record the best and worst trajectory returns of ``dataset`` globally.

    Splits the dataset into trajectories with ``split_into_trajectories``,
    sums the rewards of each one, and stores the largest sum in
    ``RETURN_MAX`` and the smallest in ``RETURN_MIN``.  The dataset itself is
    not modified; ``main`` uses the two globals to rescale sampled rewards.

    The computation runs at most once per process: if both globals are
    already set, the call returns immediately.

    Args:
        dataset: offline dataset exposing ``observations``, ``actions``,
            ``rewards``, ``masks``, ``dones_float`` and ``next_observations``.

    Raises:
        ValueError: if the dataset yields no trajectories.
    """
    global RETURN_MAX, RETURN_MIN
    if RETURN_MAX is not None and RETURN_MIN is not None:
        return

    trajs = split_into_trajectories(dataset.observations, dataset.actions,
                                    dataset.rewards, dataset.masks,
                                    dataset.dones_float,
                                    dataset.next_observations)
    if not trajs:
        raise ValueError(
            'Cannot compute return statistics: dataset contains no trajectories.')

    # Each step unpacks as a 6-tuple; the third field is taken as the reward
    # (matching the argument order passed to split_into_trajectories above).
    returns = [sum(rew for _, _, rew, _, _, _ in traj) for traj in trajs]
    RETURN_MAX, RETURN_MIN = max(returns), min(returns)


def make_env_and_dataset(
        env_name: str,
        seed: int,
) -> Tuple[gym.Env, D4RLDataset]:
    """Build the seeded, wrapped environment and its offline D4RL dataset.

    Args:
        env_name: gym environment id, e.g. ``'halfcheetah-expert-v2'``.
        seed: seed applied to the env and to its action/observation spaces.

    Returns:
        ``(env, dataset)``: the env wrapped in ``EpisodeMonitor`` and then
        ``SinglePrecision``, and the ``D4RLDataset`` built from it.  For
        halfcheetah/walker2d/hopper tasks (and never for antmaze),
        ``normalize`` is also called, which records the dataset's return
        range in the module-level ``RETURN_MAX`` / ``RETURN_MIN``.
    """
    env = gym.make(env_name)
    env = wrappers.EpisodeMonitor(env)
    env = wrappers.SinglePrecision(env)

    env.seed(seed)
    env.action_space.seed(seed)
    env.observation_space.seed(seed)

    dataset = D4RLDataset(env)

    # Branch on the argument (not FLAGS.env_name) so the helper honours what
    # the caller asked for.  Antmaze needs no return statistics.
    is_locomotion = any(
        task in env_name for task in ('halfcheetah', 'walker2d', 'hopper'))
    if 'antmaze' not in env_name and is_locomotion:
        normalize(dataset)

    return env, dataset

def load_checkpoint(checkpoint_dir, model, optimizer):
    """Restore a Flax ``TrainState`` from a serialized checkpoint file.

    Despite its name, ``checkpoint_dir`` is the path of a single file: it is
    opened directly and its bytes are decoded with ``from_bytes``.  The
    decoded object's ``'params'`` entry becomes the state's parameters.

    Args:
        checkpoint_dir: path to the checkpoint file.
        model: Flax module whose ``apply`` becomes the state's ``apply_fn``.
        optimizer: optax gradient transformation stored as ``tx``.

    Returns:
        A freshly created ``train_state.TrainState``.
    """
    with open(checkpoint_dir, 'rb') as ckpt_file:
        raw_bytes = ckpt_file.read()
    restored = from_bytes(None, raw_bytes)
    return train_state.TrainState.create(
        apply_fn=model.apply,
        params=restored['params'],
        tx=optimizer,
    )

@jax.jit
def compute_sample_weights(agent_state, batch, alpha, l):
    """Return per-sample priorities ``(mse / l + 1) ** alpha`` for a batch.

    ``mse`` is the squared difference between the BC policy's actions for
    ``batch.observations`` and ``batch.actions``, averaged over axis 1
    (presumably the action dimension of a (batch, action_dim) array — TODO
    confirm).  The BC network is evaluated with the nested
    ``agent_state.params["params"]`` tree.
    """
    bc_variables = {'params': agent_state.params["params"]}
    predicted = agent_state.apply_fn(bc_variables, batch.observations)
    squared_err = jnp.square(predicted - batch.actions)
    mse = jnp.mean(squared_err, axis=1)
    return jnp.power(mse / l + 1, alpha)

@jax.jit
def compute_q_weights(agent_state, batch, gamma):
    """Return per-sample weights ``exp(-mse / gamma)`` for a batch.

    ``mse`` is the BC policy's squared action error on ``batch`` averaged over
    axis 1, so samples whose actions agree with the BC policy get weights
    close to 1.  ``gamma`` acts as a temperature: larger values flatten the
    weights.
    """
    bc_variables = {'params': agent_state.params["params"]}
    predicted = agent_state.apply_fn(bc_variables, batch.observations)
    scaled_err = jnp.mean(jnp.square(predicted - batch.actions), axis=1) / gamma
    return jnp.exp(-scaled_err)

@jax.jit
def compute_error(agent_state, observation, action, alpha, l):
    """Return the scalar priority ``(mse / l + 1) ** alpha`` of one transition.

    This is the single-transition counterpart of ``compute_sample_weights``:
    here the squared error between the BC action for ``observation`` and
    ``action`` is averaged over all of its elements (no axis argument).
    """
    bc_variables = {'params': agent_state.params["params"]}
    residual = agent_state.apply_fn(bc_variables, observation) - action
    mse = jnp.mean(jnp.square(residual))
    return jnp.power(mse / l + 1, alpha)

def main(_):
    """Fine-tune a pretrained Learner online, weighting updates with a BC model.

    1. Build the env and offline dataset, and seed a prioritized replay
       buffer with the dataset.
    2. Restore the Learner from ``FLAGS.load_dir`` and a behaviour-cloning
       model from a fixed checkpoint under ``./models/bc_model``.
    3. For ``FLAGS.max_steps + 1`` iterations: take one env step, insert the
       transition with a priority from ``compute_error``, sample a batch,
       reshape its rewards per task family, weight it with
       ``compute_q_weights`` and update the Learner.
    4. Every ``FLAGS.eval_interval`` steps, evaluate and write the
       accumulated ``(step, return)`` pairs to ``<save_dir>/<seed>.txt``.

    ``FLAGS.num_pretraining_steps`` is not used by this loop.

    Raises:
        ValueError: for locomotion tasks whose offline trajectories all have
            the same return (the reward rescaling would divide by zero).
    """
    os.makedirs(FLAGS.save_dir, exist_ok=True)

    env, dataset = make_env_and_dataset(FLAGS.env_name, FLAGS.seed)

    action_dim = env.action_space.shape[0]
    state_dim = env.observation_space.shape[0]
    replay_buffer = PrioritizedReplayBuffer(env.observation_space, action_dim,
                                            FLAGS.replay_buffer_size)
    replay_buffer.initialize_with_dataset(dataset, FLAGS.init_dataset_size)

    kwargs = dict(FLAGS.config)
    agent = Learner(
        seed=FLAGS.seed,
        observations=env.observation_space.sample()[np.newaxis],
        actions=env.action_space.sample()[np.newaxis],
        **kwargs)
    agent_bc = BehaviorCloning(state_dim=state_dim, action_dim=action_dim, hidden_dim=256)

    optimizer = optax.adamw(learning_rate=1e-3)
    checkpoint_path = os.path.join('./models/bc_model', FLAGS.env_name,
                                   'checkpoints', 'checkpoint_2000000.ckpt')

    # Restore the BC model (used only for priorities/weights) and the agent.
    agent_state = load_checkpoint(checkpoint_path, agent_bc, optimizer)
    agent.load_checkpoint(FLAGS.load_dir)

    # Reward shaping is decided once; the same task-family test decides
    # whether make_env_and_dataset populated RETURN_MAX / RETURN_MIN.
    is_antmaze = 'antmaze' in FLAGS.env_name
    is_locomotion = not is_antmaze and any(
        task in FLAGS.env_name for task in ('halfcheetah', 'walker2d', 'hopper'))
    return_range = None
    if is_locomotion:
        return_range = RETURN_MAX - RETURN_MIN
        # Warnings are silenced globally, so a zero range would otherwise
        # turn every reward into inf/nan without any message.
        if return_range == 0:
            raise ValueError(
                'All offline trajectories have the same return; '
                'cannot rescale rewards by the return range.')

    eval_returns = []
    observation, done = env.reset(), False

    # One environment step and one agent update per iteration.
    for i in tqdm.tqdm(range(FLAGS.max_steps + 1),
                       smoothing=0.1,
                       disable=not FLAGS.tqdm,
                       ncols=80):

        action = agent.sample_actions(observation)
        action = np.clip(action, -1, 1)
        next_observation, reward, done, info = env.step(action)

        # Time-limit truncations are not true terminals: keep bootstrapping.
        if not done or 'TimeLimit.truncated' in info:
            mask = 1.0
        else:
            mask = 0.0

        # For positive alpha and l, the priority grows with the BC action error.
        priority = compute_error(agent_state, observation, action,
                                 replay_buffer.alpha, FLAGS.l)
        replay_buffer.insert(
            observation=observation,
            action=action,
            reward=reward,
            mask=mask,
            done_float=float(done),
            next_observation=next_observation,
            priority=priority,
        )

        observation = next_observation
        if done:
            observation, done = env.reset(), False

        # NOTE(review): the sampled indices are discarded, so priorities are
        # never refreshed after insertion — confirm this is intended.
        batch, _ = replay_buffer.sample(FLAGS.batch_size)

        if is_antmaze:
            batch = Batch(
                observations=batch.observations,
                actions=batch.actions,
                rewards=batch.rewards - 1,  # Adjust rewards for antmaze
                masks=batch.masks,
                next_observations=batch.next_observations
            )
        elif is_locomotion:
            batch = Batch(
                observations=batch.observations,
                actions=batch.actions,
                rewards=(batch.rewards / return_range) * 1000.0,  # Normalize rewards
                masks=batch.masks,
                next_observations=batch.next_observations
            )

        sample_weights = compute_q_weights(agent_state, batch, FLAGS.gamma)
        agent.update(batch, sample_weights)

        if i > 0 and i % FLAGS.eval_interval == 0:
            eval_stats = evaluate(agent, env, FLAGS.eval_episodes)
            # evaluate() is handed the training env and presumably steps and
            # resets it, so `observation` no longer matches the env's state.
            # Start a fresh training episode.
            observation, done = env.reset(), False

            eval_returns.append((i, eval_stats['return']))
            np.savetxt(os.path.join(FLAGS.save_dir, f'{FLAGS.seed}.txt'),
                       eval_returns,
                       fmt=['%d', '%.4f'])


if __name__ == '__main__':
    # Let absl parse the flags defined above before invoking main.
    app.run(main)