import argparse
import logging
import os

import numpy as np
import pandas as pd
import torch
from openpyxl import Workbook

from common import create_folders, make_env
from hyperparameters import get_hyperparameters
from sac import SACGRU

# Use the GPU for policy inference when one is visible, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# The file handler configured below writes into this directory, so it must
# exist before logging is set up.
log_dir = "logs"
os.makedirs(log_dir, exist_ok=True)

# Root logger: INFO and above are written both to logs/evaluation.log and to
# the default StreamHandler stream (stderr).
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
    handlers=[
        logging.FileHandler(os.path.join(log_dir, "evaluation.log")),
        logging.StreamHandler(),
    ],
)

def create_excel_if_not_exists(file_path):
    """
    Ensure an Excel workbook exists at ``file_path``.

    If nothing exists at that path yet, a fresh workbook is created, its active
    sheet is titled "Sheet", and it is saved to ``file_path``. If the path
    already exists it is left untouched. Either outcome is logged at INFO level.

    Args:
        file_path (str): Path to the Excel file.
    """
    if os.path.exists(file_path):
        logging.info(f"Workbook already exists at {file_path}")
        return

    book = Workbook()
    book.active.title = "Sheet"
    book.save(filename=file_path)
    logging.info(f"New workbook created and saved as {file_path}")


def setup_environment(env_name, seed):
    """
    Build the environment named ``env_name``, seeded with ``seed``.

    This is a pass-through to ``make_env``; all construction and seeding logic
    lives there.

    Args:
        env_name (str): Name of the environment.
        seed (int): Random seed forwarded to ``make_env``.

    Returns:
        Whatever ``make_env(env_name, seed)`` returns (the environment object).
    """
    environment = make_env(env_name, seed)
    return environment


def evaluate_policy(policy, eval_env, steps_list, action_dim, num_episodes=10):
    """
    Evaluate the policy for every planning horizon in ``steps_list``.

    For each horizon ``s`` the policy is rolled out for ``num_episodes``
    episodes. At each decision point ``policy.policy.sample`` is queried once
    with the current state, the previously executed action and ``s``. The
    returned action sequence is then executed open-loop for up to ``s``
    environment steps, stopping early if the episode ends.

    Args:
        policy: Agent whose ``policy.sample(state, prev_action, steps, flag)``
            returns a 3-tuple whose last element is the action tensor. The
            trailing ``True`` flag is agent-defined (presumably "evaluation /
            deterministic mode" — TODO confirm against the SAC-GRU code).
        eval_env: Environment using the 4-tuple ``step()`` API, whose
            ``reset()`` returns the observation directly.
        steps_list (list): Planning horizons to evaluate.
        action_dim (int): Dimension of the action space; used to build the
            all-zero "previous action" at the start of each episode.
        num_episodes (int): Episodes per horizon (default 10).

    Returns:
        List of ``(s, avg_reward)`` tuples. ``avg_reward`` is the return summed
        over all episodes divided by ``num_episodes``.

    Raises:
        ValueError: If ``num_episodes`` is not positive.
    """
    if num_episodes <= 0:
        raise ValueError(f"num_episodes must be positive, got {num_episodes}")

    def to_batch(values):
        # Policy inputs are a single float32 row of shape (1, -1) on `device`.
        return torch.as_tensor(
            np.asarray(values, dtype=np.float32).reshape(1, -1), device=device)

    results = []
    for s in steps_list:
        rewards = 0
        logging.info(f"Evaluating with steps: {s}")
        for _ in range(num_episodes):
            eval_state, eval_done = eval_env.reset(), False
            # No action has been taken yet, so the first "previous action" is all zeros.
            eval_prev_action = np.zeros(action_dim, dtype=np.float32)
            while not eval_done:
                # Inference only: avoid building autograd graphs during rollouts.
                with torch.no_grad():
                    _, _, eval_actions = policy.policy.sample(
                        to_batch(eval_state), to_batch(eval_prev_action), s, True)
                eval_actions = eval_actions.detach().cpu().numpy()[0]

                # Execute the sampled plan open-loop for up to `s` steps.
                for eval_ps in range(s):
                    eval_action = eval_actions[eval_ps] if s > 1 else eval_actions
                    eval_state, eval_reward, eval_done, _ = eval_env.step(eval_action)
                    eval_prev_action = eval_action
                    rewards += eval_reward
                    if eval_done:
                        break
        avg_reward = rewards / num_episodes
        logging.info(f" --------------- Evaluation reward {avg_reward:.3f}")
        results.append((s, avg_reward))
    return results


def eval(seed=0, env_name='InvertedPendulum-v2', automatic_entropy_tuning=True, steps=2,
         actor_update_frequency=1, output_csv='eval.csv'):
    """
    Load a trained SAC-GRU checkpoint and evaluate it across planning horizons.

    No training happens here. The policy is built from the environment's
    hyperparameters and restored from ``./models/<file_name>_best``.
    ``file_name`` is
    ``SAC_GRU_<env_name>_<seed>_<automatic_entropy_tuning>_<steps>_<actor_update_frequency>``.

    The policy is evaluated on a separate environment seeded with
    ``seed + 100``, for horizons 1, 2, 4, ..., 30. One header-less CSV row per
    horizon is appended to ``output_csv`` with columns
    (seed, reward, env_name, train_steps, steps, type).

    Note: the name shadows the builtin ``eval`` inside this module; it is kept
    for backward compatibility with existing callers.

    Args:
        seed (int): Random seed for reproducibility.
        env_name (str): Name of the environment.
        automatic_entropy_tuning (bool): Whether to automatically tune entropy.
        steps (int): Planning horizon the policy was trained with.
        actor_update_frequency (int): Frequency of actor updates.
        output_csv (str): CSV file the results are appended to.
    """
    hy = get_hyperparameters(env_name, 'SAC')
    # Presumably mirrors the naming used by the training script when saving
    # checkpoints — TODO confirm; a mismatch means the checkpoint is not found.
    file_name = '_'.join(str(x) for x in [
        "SAC_GRU", env_name, seed, automatic_entropy_tuning, steps, actor_update_frequency])

    logging.info("---------------------------------------")
    logging.info(f"Env: {env_name}, Seed: {seed}")
    logging.info("---------------------------------------")

    create_folders()

    # This environment is only used to read the observation/action spaces.
    env = setup_environment(env_name, seed)

    torch.manual_seed(seed)
    np.random.seed(seed)

    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]

    policy_kwargs = {
        "gamma": hy['discount'],
        "tau": hy['tau'],
        "alpha": hy['alpha'],
        "policy_type": "Gaussian",
        "hidden_size": hy['hidden_size'],
        "target_update_interval": hy['target_update_interval'],
        "automatic_entropy_tuning": automatic_entropy_tuning,
        "lr": hy['lr'],
        "steps": steps,
        'actor_update_frequency': actor_update_frequency
    }

    policy = SACGRU(state_dim, env.action_space, **policy_kwargs)
    policy.load_checkpoint(f"./models/{file_name}_best")

    # Separate, differently seeded environment for evaluation episodes.
    eval_env = setup_environment(env_name, seed + 100)

    steps_list = [1, *range(2, 32, 2)]
    results = evaluate_policy(policy, eval_env, steps_list, action_dim)

    columns = ['seed', 'reward', 'env_name', 'train_steps', 'steps', 'type']
    rows = [
        (seed, avg_reward, env_name, steps, s, 'increasing_steps')
        for s, avg_reward in results
    ]
    if rows:
        # Single append for all horizons; rows are header-less, as before.
        pd.DataFrame(rows, columns=columns).to_csv(
            output_csv, mode='a', index=False, header=False)


def str2bool(value):
    """
    Parse a boolean command-line value.

    ``argparse`` with ``type=bool`` treats every non-empty string, including
    "False" and "0", as True. This parser is used instead.

    Args:
        value (str | bool): Raw argument value.

    Returns:
        bool: The parsed value.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognized boolean spelling.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--env_name", default="InvertedPendulum-v2", help="Environment name")
    parser.add_argument("--seed", default=0, type=int, help="Sets Gym, PyTorch and Numpy seeds")
    # nargs='?' + const=True: a bare `--automatic_entropy_tuning` means True,
    # while `--automatic_entropy_tuning True/False` keeps working.
    parser.add_argument('--automatic_entropy_tuning', type=str2bool, nargs='?', const=True,
                        default=False, help='Automatically adjust α (default: False)')
    parser.add_argument("--steps", default=2, type=int, help="Number of steps to plan ahead")
    parser.add_argument("--actor_update_frequency", default=1, type=int, help="Actor update frequency")

    args = vars(parser.parse_args())
    logging.info('Command-line argument values:')
    for key, value in args.items():
        logging.info(f'- {key} : {value}')

    eval(**args)
