import os
import gym
import csv
import warnings
import argparse
import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.training import train_state
import optax
import d4rl
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from flax.serialization import to_bytes
import logging

# Module-level PRNG key (not referenced by the code visible in this file).
# NOTE(review): building a key here runs a JAX computation at import time, which
# probably initialises the default backend before launch() applies --device — confirm.
key = jax.random.PRNGKey(0)
# Silence every Python warning for the whole process.
warnings.filterwarnings("ignore")
logging.basicConfig(
    format="%(asctime)s - %(levelname)s: %(message)s",
    level=logging.INFO,
    datefmt="%I:%M:%S",
)

class BehaviorCloning(nn.Module):
    """MLP policy that regresses actions directly from states.

    Architecture: Dense(hidden_dim) -> ReLU -> Dense(hidden_dim) -> ReLU ->
    Dense(action_dim). The output layer has no activation.
    """

    state_dim: int  # input width; not read inside the module (Dense infers it)
    action_dim: int  # width of the predicted action vector
    hidden_dim: int  # width of both hidden layers

    def setup(self):
        # In setup(), attribute names become the submodule names in the
        # parameter tree (fc1 / fc2 / fc3).
        self.fc1 = nn.Dense(self.hidden_dim)
        self.fc2 = nn.Dense(self.hidden_dim)
        self.fc3 = nn.Dense(self.action_dim)

    def __call__(self, state):
        hidden = state
        for layer in (self.fc1, self.fc2):
            hidden = nn.relu(layer(hidden))
        return self.fc3(hidden)

def save_checkpoint_alternative(checkpoint_dir, state, step):
    """Serialise ``state`` with ``flax.serialization.to_bytes`` to disk.

    The file is written to ``<checkpoint_dir>/checkpoint_<step>.ckpt``,
    overwriting any existing file. ``checkpoint_dir`` must already exist.
    """
    target_path = os.path.join(checkpoint_dir, f"checkpoint_{step}.ckpt")
    with open(target_path, 'wb') as out_stream:
        serialized = to_bytes(state)
        out_stream.write(serialized)

def create_train_state(rng, model, learning_rate, batch_size, state_dim):
    """Initialise ``model`` and wrap it in a Flax ``TrainState`` using AdamW.

    Args:
        rng: PRNG key used both to sample the dummy input and to initialise
            the parameters.
        model: Flax module to initialise.
        learning_rate: AdamW learning rate.
        batch_size: leading dimension of the dummy input used for shape inference.
        state_dim: trailing dimension of the dummy input.

    Returns:
        A ``TrainState`` whose ``params`` is the whole variables dict returned by
        ``model.init`` (train_step reads the weights as ``params["params"]``).
    """
    input_shape = (batch_size, state_dim)
    sample_batch = jax.random.normal(rng, input_shape)
    variables = model.init(rng, sample_batch)

    optimizer = optax.adamw(learning_rate)
    return train_state.TrainState.create(
        apply_fn=model.apply,
        params=variables,
        tx=optimizer,
    )


def mse_loss(pred, target):
    """Return the mean squared error between ``pred`` and ``target`` over all elements."""
    residual = pred - target
    return jnp.square(residual).mean()

@jax.jit
def train_step(state, batch_states, batch_actions):
    """Run one jitted behavior-cloning gradient step.

    Args:
        state: TrainState whose ``params`` is the full variables dict from
            ``model.init``; the weights are read from ``params["params"]``.
        batch_states: observations fed to ``state.apply_fn``.
        batch_actions: target actions for the same batch.

    Returns:
        ``(new_state, loss_value)``, where ``loss_value`` is the batch MSE
        evaluated with the parameters from *before* the update.
    """
    def loss_fn(params):
        predicted_action = state.apply_fn({'params': params["params"]}, batch_states)
        return mse_loss(predicted_action, batch_actions)

    # value_and_grad yields the (pre-update) loss together with the gradient,
    # so no second forward pass is needed just to report the loss.
    loss_value, grads = jax.value_and_grad(loss_fn)(state.params)
    new_state = state.apply_gradients(grads=grads)
    return new_state, loss_value

def shuffle_data(key, num_data_points):
    """Return a random permutation of ``0 .. num_data_points - 1`` drawn with ``key``."""
    order = jax.random.permutation(key, num_data_points)
    return order

def _evaluate(state, eval_states, eval_actions):
    """Return the behavior-cloning MSE of ``state`` on the held-out split as a float."""
    predicted = state.apply_fn({'params': state.params['params']}, eval_states)
    return float(mse_loss(predicted, eval_actions))


def train(args):
    """Train a behavior-cloning MLP on a D4RL dataset and checkpoint the result.

    Reads ``args.env_name``, ``args.hidden_size``, ``args.lr``,
    ``args.batch_size``, ``args.max_steps`` and ``args.run_name``. Under
    ``logs/<run_name>/<env_name>/`` it writes:

    * ``train_metrics.csv``: one row of training loss per step;
    * ``eval_metrics.csv``: the held-out loss after the final step;
    * ``checkpoints/checkpoint_<max_steps>.ckpt``: the final TrainState.
    """
    rng = jax.random.PRNGKey(0)
    rng, init_rng = jax.random.split(rng)

    # Environment and dataset
    env = gym.make(args.env_name)
    dataset = env.get_dataset()
    states = jnp.array(dataset['observations'])
    actions = jnp.array(dataset['actions'])

    # Hold out a tiny slice (0.01%) as a sanity-check evaluation set.
    train_states, eval_states, train_actions, eval_actions = train_test_split(
        states, actions, test_size=0.0001, random_state=42
    )
    train_states = jax.device_put(train_states)
    train_actions = jax.device_put(train_actions)
    eval_states = jax.device_put(eval_states)
    eval_actions = jax.device_put(eval_actions)

    state_dim = states.shape[1]
    model = BehaviorCloning(state_dim=state_dim, action_dim=actions.shape[1], hidden_dim=args.hidden_size)
    state = create_train_state(init_rng, model, args.lr, batch_size=args.batch_size, state_dim=state_dim)

    max_steps = args.max_steps
    batch_size = args.batch_size
    num_data_points = len(train_states)
    num_batches_per_epoch = num_data_points // batch_size

    run_dir = os.path.join("logs", args.run_name, args.env_name)
    # The models/ directory is created but nothing in this function writes to it.
    os.makedirs(os.path.join(run_dir, "models"), exist_ok=True)
    checkpoint_dir = os.path.join(run_dir, "checkpoints")
    os.makedirs(checkpoint_dir, exist_ok=True)
    train_csv_file = os.path.join(run_dir, "train_metrics.csv")
    eval_csv_file = os.path.join(run_dir, "eval_metrics.csv")

    rng, shuffle_rng = jax.random.split(rng)
    shuffle_idx = shuffle_data(shuffle_rng, num_data_points)
    batch_idx = 0

    # Keep the metric files open for the whole run instead of reopening them every step.
    with open(train_csv_file, 'w', newline='') as train_file, \
            open(eval_csv_file, 'w', newline='') as eval_file:
        train_writer = csv.writer(train_file)
        eval_writer = csv.writer(eval_file)
        train_writer.writerow(["Step", "Train Behavior Cloning Loss"])
        eval_writer.writerow(["Step", "Eval Behavior Cloning Loss"])

        with tqdm(total=max_steps, desc="Training Progress", unit="step", ncols=100) as pbar:
            for step in range(max_steps):
                if batch_idx == num_batches_per_epoch:
                    # Use a fresh key each epoch; reusing one key would replay
                    # exactly the same batch order every epoch.
                    rng, shuffle_rng = jax.random.split(rng)
                    shuffle_idx = shuffle_data(shuffle_rng, num_data_points)
                    batch_idx = 0

                start = batch_idx * batch_size
                idx = shuffle_idx[start:start + batch_size]
                state, loss_value = train_step(state, train_states[idx], train_actions[idx])
                train_writer.writerow([step, float(loss_value)])

                pbar.update(1)
                batch_idx += 1

        if len(eval_states) > 0:
            eval_writer.writerow([max_steps, _evaluate(state, eval_states, eval_actions)])

    save_checkpoint_alternative(checkpoint_dir, state, max_steps)
    logging.info("Saved checkpoint for step %d to %s", max_steps, checkpoint_dir)

def setup_logging(run_name):
    """Make sure the top-level ``logs`` directory exists.

    ``run_name`` is accepted but not used. If something named ``logs``
    already exists, nothing is done.
    """
    if os.path.exists('logs'):
        return
    os.makedirs('logs')

def launch():
    """Parse command-line arguments, select the JAX device, and run ``train``."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--env_name', type=str, default='halfcheetah-medium-v2', help='Name of the environment')
    parser.add_argument('--max_steps', type=int, default=2000000, help='Number of steps to train for')
    parser.add_argument('--batch_size', type=int, default=64, help='Batch size for training')
    parser.add_argument('--hidden_size', type=int, default=256, help='Hidden size for MLP')
    parser.add_argument('--lr', type=float, default=3e-4, help='Learning rate')
    parser.add_argument('--run_name', type=str, default='bc', help='Name of the run')
    parser.add_argument('--device', type=str, default='gpu', choices=['cpu', 'gpu'],
                        help='Device to run on (cpu or gpu)')
    args = parser.parse_args()

    try:
        device = jax.devices(args.device)[0]
    except RuntimeError:
        # Requested platform is unavailable (e.g. no GPU): fall back instead of failing later.
        device = jax.devices()[0]
        logging.warning("No %s device available; falling back to %s", args.device, device.platform)
    else:
        jax.config.update("jax_platform_name", args.device)

    # The platform-name config only matters if the backend is not yet initialised.
    # default_device also steers placement of uncommitted arrays and computations
    # when a backend already exists.
    with jax.default_device(device):
        train(args)

# Start training only when run as a script, not when imported as a module.
if __name__ == "__main__":
    launch()
