import argparse
import logging
import os
import numpy as np
import torch
import neptune

import utils
from hyperparameters import get_hyperparameters
from common import make_env, create_folders
from sac import SACGRU

# Prefer the GPU whenever PyTorch reports CUDA support; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Training logs are written under ./logs, so the directory must exist before the
# FileHandler below tries to open its file.
log_dir = "logs"
os.makedirs(log_dir, exist_ok=True)

# Root logger: INFO and above, emitted both to logs/training.log and to the console.
logging.basicConfig(
    handlers=[
        logging.FileHandler(os.path.join(log_dir, "training.log")),
        logging.StreamHandler(),
    ],
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

def setup_neptune(project_name, api_token, parameters):
    """
    Initialize a Neptune run and record the experiment parameters on it.

    Empty strings for ``project_name`` or ``api_token`` are passed to Neptune as
    ``None``. Neptune then uses its documented fallback to the ``NEPTUNE_PROJECT`` /
    ``NEPTUNE_API_TOKEN`` environment variables instead of attempting to use an
    empty value.

    Args:
        project_name (str | None): Neptune project name ("workspace/project"), or
            empty/None to use the environment variable.
        api_token (str | None): Neptune API token, or empty/None to use the
            environment variable.
        parameters (dict): Parameters stored under the run's "parameters" field.

    Returns:
        The Neptune run object created by ``neptune.init_run``.
    """
    run = neptune.init_run(
        # `or None` turns "" into None so Neptune's env-var lookup kicks in.
        project=project_name or None,
        api_token=api_token or None,
    )
    run["parameters"] = parameters
    return run

def setup_environment(env_name, seed):
    """
    Construct an environment instance through the project's ``make_env`` factory.

    Args:
        env_name (str): Environment identifier forwarded to ``make_env``.
        seed (int): Seed forwarded to ``make_env`` so runs are reproducible.

    Returns:
        The environment object built by ``make_env``.
    """
    environment = make_env(env_name, seed)
    return environment

def evaluate_policy(policy, env_name, seed, steps, eval_episodes=10):
    """
    Evaluate the policy and return the average total reward per episode.

    A separate evaluation environment is created with seed ``seed + 100``. At each
    decision point the policy is sampled with ``evaluate=True``, producing a plan of
    ``steps`` actions. These are executed in order until the plan is used up or the
    episode ends.

    Args:
        policy: Agent exposing ``action_dim`` and
            ``policy.sample(state, prev_action, steps, evaluate=True)``, whose third
            return value is the action plan.
        env_name (str): Name of the environment.
        seed (int): Base random seed; the evaluation environment uses ``seed + 100``.
        steps (int): Number of actions planned per policy query.
        eval_episodes (int): Number of evaluation episodes to average over.

    Returns:
        float: Average total reward over the evaluation episodes.

    Raises:
        ValueError: If ``eval_episodes`` is not positive.
    """
    if eval_episodes <= 0:
        raise ValueError(f"eval_episodes must be positive, got {eval_episodes}")

    eval_env = setup_environment(env_name, seed + 100)
    total_reward = 0.0
    try:
        for _ in range(eval_episodes):
            eval_state, eval_done = eval_env.reset(), False
            # No action precedes the first step, so the policy is conditioned on zeros.
            eval_prev_action = torch.zeros(policy.action_dim)
            while not eval_done:
                # Pure inference: no autograd graph is needed during evaluation.
                with torch.no_grad():
                    _, _, eval_actions = policy.policy.sample(
                        torch.FloatTensor(eval_state.reshape(1, -1)).to(device),
                        torch.FloatTensor(eval_prev_action.reshape(1, -1)).to(device),
                        steps, evaluate=True
                    )
                # Drop the leading batch dimension; row i is presumably the i-th planned
                # action (assumes shape (1, steps, action_dim) — TODO confirm).
                eval_actions = eval_actions.detach().cpu().numpy()[0]
                for eval_ps in range(steps):
                    eval_action = eval_actions[eval_ps]
                    eval_state, eval_reward, eval_done, _ = eval_env.step(eval_action)
                    eval_prev_action = eval_action
                    total_reward += eval_reward
                    if eval_done:
                        break
    finally:
        # Release any resources held by the evaluation environment, even on error.
        close = getattr(eval_env, "close", None)
        if callable(close):
            close()
    return total_reward / eval_episodes

def train(seed=0, env_name='InvertedPendulum-v2', automatic_entropy_tuning=True, steps=2, actor_update_frequency=1, actor_type='GRU'):
    """
    Train a SAC-GRU agent, evaluating it periodically and checkpointing the best model.

    Metrics are logged to Neptune. Evaluation returns are saved to ``./results`` and
    checkpoints to ``./models``. The Neptune run is always stopped, even if training
    raises.

    Args:
        seed (int): Random seed for reproducibility.
        env_name (str): Name of the environment.
        automatic_entropy_tuning (bool): Whether to automatically tune entropy.
        steps (int): Number of actions the policy plans per query.
        actor_update_frequency (int): Frequency of actor updates.
        actor_type (str): Recurrent unit type for the actor (forwarded to ``SACGRU``).
    """
    # Credentials come from the environment. When a variable is unset the value is
    # None, which lets Neptune apply its own NEPTUNE_PROJECT / NEPTUNE_API_TOKEN lookup.
    # Replace with literals to hard-code them.
    project_name = os.environ.get("NEPTUNE_PROJECT")
    api_token = os.environ.get("NEPTUNE_API_TOKEN")

    parameters = {
        'type': "SAC_GRU",
        'env_name': env_name,
        'seed': seed,
        'automatic_entropy_tuning': automatic_entropy_tuning,
        'steps': steps,
        'actor_update_frequency': actor_update_frequency,
        'actor_type': actor_type
    }
    run = setup_neptune(project_name, api_token, parameters)
    try:
        run["steps"].log(steps)

        hy = get_hyperparameters(env_name, 'SAC')
        augment_type = "SAC_GRU"
        arguments = [augment_type, env_name, seed, automatic_entropy_tuning, steps, actor_update_frequency]
        file_name = '_'.join(str(x) for x in arguments)

        logging.info("---------------------------------------")
        logging.info(f"Env: {env_name}, Seed: {seed}")
        logging.info("---------------------------------------")

        create_folders()
        env = setup_environment(env_name, seed)

        torch.manual_seed(seed)
        np.random.seed(seed)

        max_timesteps = int(hy['max_timesteps'])
        eval_freq = hy['eval_freq']
        start_timesteps = hy['start_timesteps']
        batch_size = hy["batch_size"]

        state_dim = env.observation_space.shape[0]
        action_dim = env.action_space.shape[0]

        kwargs = {
            "gamma": hy['discount'],
            "tau": hy['tau'],
            "alpha": hy['alpha'],
            "policy_type": "Gaussian",
            "hidden_size": hy['hidden_size'],
            "target_update_interval": hy['target_update_interval'],
            "automatic_entropy_tuning": automatic_entropy_tuning,
            "lr": hy['lr'],
            "steps": steps,
            'actor_update_frequency': actor_update_frequency,
            'actor_type': actor_type
        }

        policy = SACGRU(state_dim, env.action_space, **kwargs)
        replay_buffer = utils.TAACReplayBuffer(state_dim, action_dim)
        state, done = env.reset(), False
        episode_reward = 0
        episode_timesteps = 0
        episode_num = 0
        max_episode_timestep = env._max_episode_steps
        updates = 0
        # -inf rather than an arbitrary sentinel, so the first evaluation always
        # produces a "best" checkpoint whatever the reward scale.
        best_performance = -np.inf
        evaluations = []
        previous_action = torch.zeros(action_dim)

        t = 0
        while t < max_timesteps:
            # Query the policy once for a plan of `steps` actions, then execute it.
            actions = policy.select_action(state, previous_action, steps)
            for ps in range(steps):
                # Uniform random exploration until start_timesteps steps have elapsed.
                action = env.action_space.sample() if t < start_timesteps else actions[ps]

                next_state, reward, done, _ = env.step(action)
                episode_timesteps += 1
                # Reaching the episode step limit is stored as not-done (done_bool = 0).
                done_bool = float(done) if episode_timesteps < max_episode_timestep else 0
                replay_buffer.add(state, action, previous_action, next_state, reward, done_bool)
                state = next_state
                previous_action = action
                episode_reward += reward
                t += 1

                # `t` already counts the step just taken, so this fires every eval_freq steps.
                if t % eval_freq == 0:
                    avg_reward = evaluate_policy(policy, env_name, seed, steps)
                    evaluations.append(avg_reward)
                    logging.info(f" --------------- Evaluation reward {avg_reward:.3f}")
                    run['avg_reward'].log(avg_reward)
                    np.save(f"./results/{file_name}", evaluations)

                    if best_performance <= avg_reward:
                        best_performance = avg_reward
                        run['best_reward'].log(best_performance)
                        policy.save_checkpoint(f"./models/{file_name}_best")

                if replay_buffer.size >= batch_size:
                    # On actor-update iterations update_parameters returns six values,
                    # otherwise only the critic and model losses.
                    if updates % actor_update_frequency == 0:
                        critic_1_loss, critic_2_loss, policy_loss, ent_loss, alpha, model_loss = policy.update_parameters(replay_buffer, batch_size, updates)
                        run['critic1'].log(critic_1_loss)
                        run['critic2'].log(critic_2_loss)
                        run['policy'].log(policy_loss)
                        run['entropy'].log(ent_loss)
                        run['alpha'].log(alpha)
                        run['model'].log(model_loss)
                    else:
                        critic_1_loss, critic_2_loss, model_loss = policy.update_parameters(replay_buffer, batch_size, updates)
                        run['critic1'].log(critic_1_loss)
                        run['critic2'].log(critic_2_loss)
                        run['model'].log(model_loss)
                    updates += 1

                if done:
                    logging.info(f"Total T: {t} Episode Num: {episode_num+1} Episode T: {episode_timesteps} Reward: {episode_reward:.3f}")
                    state, done = env.reset(), False
                    episode_reward = 0
                    episode_timesteps = 0
                    episode_num += 1
                    previous_action = torch.zeros(action_dim)
                    break

                # Do not run past the training budget in the middle of a plan.
                if t >= max_timesteps:
                    break

        policy.save_checkpoint(f"./models/{file_name}_final")
    finally:
        # Always close the Neptune run, even if training fails or is interrupted.
        run.stop()

def str2bool(value):
    """
    Parse a command-line boolean value.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``"False"``) as True, so boolean options need an explicit parser.

    Args:
        value (str | bool): Raw argument value.

    Returns:
        bool: The parsed value.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--env_name", default="InvertedPendulum-v2", help="Environment name")
    parser.add_argument("--seed", default=0, type=int, help="Sets Gym, PyTorch and Numpy seeds")
    # str2bool instead of type=bool (bool("False") is True). A bare flag with no value means True.
    parser.add_argument('--automatic_entropy_tuning', type=str2bool, nargs='?', const=True, default=False, metavar='G', help='Automatically adjust α (default: False)')
    parser.add_argument("--steps", default=2, type=int, help="Number of steps to plan ahead")
    parser.add_argument("--actor_update_frequency", default=1, type=int, help="Actor update frequency")
    parser.add_argument("--actor_type", default='GRU', help="Actor recurrent unit type")

    args = vars(parser.parse_args())
    logging.info('Command-line argument values:')
    for key, value in args.items():
        logging.info(f'- {key} : {value}')

    train(**args)
