import argparse
import logging
import os
import numpy as np
import torch
import neptune

import utils
from hyperparameters import get_hyperparameters
from common import make_env, create_folders
from sac import SAC

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The file handler below writes into this directory, so create it up front.
log_dir = "logs"
os.makedirs(log_dir, exist_ok=True)

# Route INFO-and-above records to both logs/training.log and the console.
_training_log_handler = logging.FileHandler(os.path.join(log_dir, "training.log"))
_console_handler = logging.StreamHandler()
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
    handlers=[_training_log_handler, _console_handler],
)

def setup_neptune(project_name, api_token, parameters):
    """
    Start a Neptune run and attach the experiment parameters to it.

    Args:
        project_name (str): Neptune project the run is created in.
        api_token (str): Neptune API token used to authenticate.
        parameters (dict): Values stored under the run's "parameters" field.

    Returns:
        The run object returned by ``neptune.init_run``.
    """
    neptune_run = neptune.init_run(api_token=api_token, project=project_name)
    # Record the configuration once, right after the run is created.
    neptune_run["parameters"] = parameters
    return neptune_run

def setup_environment(env_name, seed):
    """
    Build the training environment.

    Thin wrapper that forwards both arguments to ``make_env``.

    Args:
        env_name (str): Name of the environment to create.
        seed (int): Random seed passed through to ``make_env``.

    Returns:
        Whatever ``make_env(env_name, seed)`` returns.
    """
    environment = make_env(env_name, seed)
    return environment

def evaluate_policy(policy, env_name, seed, eval_episodes=10):
    """
    Evaluate the policy and return the average episode reward.

    A separate environment is created with ``make_env(env_name, seed + 100)``
    so that evaluation does not share the training environment. Actions are
    requested with ``evaluate=True``.

    Args:
        policy: Object exposing ``select_action(state, evaluate=True)``.
        env_name (str): Name of the environment.
        seed (int): Base random seed; the evaluation environment uses
            ``seed + 100``.
        eval_episodes (int): Number of episodes to run and average over.
            Defaults to 10, the value that was previously hard-coded.

    Returns:
        float: Sum of all rewards collected divided by ``eval_episodes``.

    Raises:
        ValueError: If ``eval_episodes`` is smaller than 1.
    """
    # Validate first: avoids building an environment and a later division by zero.
    if eval_episodes < 1:
        raise ValueError(
            f"eval_episodes must be a positive integer, got {eval_episodes}"
        )

    eval_env = make_env(env_name, seed + 100)
    total_reward = 0
    for _ in range(eval_episodes):
        state, done = eval_env.reset(), False
        # The loop condition alone ends the episode; no extra break is needed.
        while not done:
            action = policy.select_action(state, evaluate=True)
            state, reward, done, _ = eval_env.step(action)
            total_reward += reward
    return total_reward / eval_episodes

def train(seed=0, env_name='InvertedPendulum-v2', automatic_entropy_tuning=True):
    """
    Main function to train the policy. Model is trained and evaluated inside.

    Creates a Neptune run, builds the environment and a SAC agent from the
    hyperparameters returned by ``get_hyperparameters(env_name, 'SAC')``, then
    alternates between collecting one environment step, periodically
    evaluating, and updating the agent. Evaluation results are saved under
    ``./results`` and checkpoints under ``./models``. The Neptune run is
    stopped even if training raises.

    Args:
        seed (int): Random seed for reproducibility.
        env_name (str): Name of the environment.
        automatic_entropy_tuning (bool): Whether to automatically tune entropy.
    """
    # Credentials come from the environment when set (NEPTUNE_PROJECT and
    # NEPTUNE_API_TOKEN), so the token does not have to live in the source.
    # The empty-string fallbacks can still be replaced by hand.
    project_name = os.environ.get("NEPTUNE_PROJECT", "")  # Or replace with your project name
    api_token = os.environ.get("NEPTUNE_API_TOKEN", "")  # Or replace with your Neptune API token

    parameters = {
        'type': "SAC",
        'env_name': env_name,
        'seed': seed,
        'automatic_entropy_tuning': automatic_entropy_tuning,
    }
    run = setup_neptune(project_name, api_token, parameters)

    # Everything after the run is opened sits in try/finally so the run is
    # always closed, including when training fails part-way through.
    try:
        hy = get_hyperparameters(env_name, 'SAC')

        augment_type = "SAC"
        arguments = [augment_type, env_name, seed, automatic_entropy_tuning]
        file_name = '_'.join([str(x) for x in arguments])

        logging.info("---------------------------------------")
        logging.info(f"Env: {env_name}, Seed: {seed}")
        logging.info("---------------------------------------")

        create_folders()
        env = setup_environment(env_name, seed)

        torch.manual_seed(seed)
        np.random.seed(seed)

        total_timesteps = int(hy['max_timesteps'])
        eval_freq = hy['eval_freq']
        start_timesteps = hy['start_timesteps']
        batch_size = hy["batch_size"]
        max_episode_timestep = hy['max_episode_steps']

        state_dim = env.observation_space.shape[0]
        action_dim = env.action_space.shape[0]

        kwargs = {
            "gamma": hy['discount'],
            "tau": hy['tau'],
            "alpha": hy['alpha'],
            "policy_type": "Gaussian",
            "hidden_size": hy['hidden_size'],
            "target_update_interval": hy['target_update_interval'],
            "automatic_entropy_tuning": automatic_entropy_tuning,
            "lr": hy['lr'],
        }

        policy = SAC(state_dim, env.action_space, **kwargs)
        replay_buffer = utils.ReplayBuffer(state_dim, action_dim)
        state, done = env.reset(), False
        episode_reward = 0
        episode_timesteps = 0
        episode_num = 0
        updates = 0
        # -inf rather than a fixed sentinel, so the first evaluation always
        # becomes the best one regardless of the task's reward scale.
        best_performance = float("-inf")
        evaluations = []

        # t is the number of environment steps taken once the current
        # iteration's step has been executed (1-based).
        for t in range(1, total_timesteps + 1):
            # Select action randomly or according to policy
            if t <= start_timesteps:
                action = env.action_space.sample()  # Sample random action
            else:
                action = policy.select_action(state)  # Sample action from policy

            next_state, reward, done, _ = env.step(action)
            episode_timesteps += 1
            # An episode cut off by the step limit is stored as non-terminal.
            done_bool = float(done) if episode_timesteps < max_episode_timestep else 0
            replay_buffer.add(state, action, next_state, reward, done_bool)
            state = next_state
            episode_reward += reward

            # Evaluate every eval_freq steps. t already counts the step just
            # taken, so it is compared directly (no "+ 1").
            if t % eval_freq == 0:
                avg_reward = evaluate_policy(policy, env_name, seed)
                evaluations.append(avg_reward)
                logging.info(f" --------------- Evaluation reward {avg_reward:.3f}")
                run['avg_reward'].log(avg_reward)
                np.save(f"./results/{file_name}", evaluations)

                if best_performance <= avg_reward:
                    best_performance = avg_reward
                    run['best_reward'].log(best_performance)
                    policy.save_checkpoint(f"./models/{file_name}_best")

            # Train agent after collecting sufficient data
            if replay_buffer.size >= batch_size:
                critic_1_loss, critic_2_loss, policy_loss, ent_loss, alpha = (
                    policy.update_parameters(replay_buffer, batch_size, updates)
                )
                updates += 1
                run['critic1'].log(critic_1_loss)
                run['critic2'].log(critic_2_loss)
                run['policy'].log(policy_loss)
                run['entropy'].log(ent_loss)
                run['alpha'].log(alpha)

            if done:
                logging.info(f"Total T: {t} Episode Num: {episode_num+1} Episode T: {episode_timesteps} Reward: {episode_reward:.3f}")
                state, done = env.reset(), False
                episode_reward = 0
                episode_timesteps = 0
                episode_num += 1

        policy.save_checkpoint(f"./models/{file_name}_final")
    finally:
        run.stop()

def str_to_bool(value):
    """
    Convert a command-line string to a boolean.

    ``type=bool`` is unsuitable for argparse because ``bool("False")`` is
    True: every non-empty string is truthy. This parser maps the usual
    spellings explicitly instead.

    Args:
        value (str | bool): Raw argument text, or an already-parsed bool.

    Returns:
        bool: True for "true"/"t"/"yes"/"y"/"1"/"on"; False for
        "false"/"f"/"no"/"n"/"0"/"off" and for the empty string (which
        ``bool`` also mapped to False). Matching ignores case and
        surrounding whitespace.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ("true", "t", "yes", "y", "1", "on"):
        return True
    if text in ("false", "f", "no", "n", "0", "off", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--env_name", default="InvertedPendulum-v2", help="Environment name")
    parser.add_argument("--seed", default=0, type=int, help="Sets Gym, PyTorch and Numpy seeds")
    # nargs='?' with const=True also accepts the bare flag as "enable".
    parser.add_argument('--automatic_entropy_tuning', type=str_to_bool, nargs='?', const=True, default=False, metavar='G', help='Automatically adjust α (default: False)')

    args = vars(parser.parse_args())
    logging.info('Command-line argument values:')
    for key, value in args.items():
        logging.info(f'- {key} : {value}')

    train(**args)
