import random
import hydra
import torch
import numpy as np
import logging

import os

from datetime import timedelta
from rich.pretty import pretty_repr
from timeit import default_timer as timer
from omegaconf import DictConfig, OmegaConf
from hydra.core.hydra_config import HydraConfig

from utils import utils
from utils.evaluation import evaluate_episodic
from utils.utils import prep_cfg_for_db
from experiment import ExperimentManager, Metric
from environments.factory import get_env
from models.actor import Actor
from models.critic import get_critic
from models.replay_buffers.factory import get_buffer
from agents.factory import get_agent
from environments.get_dataset import load_dataset, training_set_construction


# Module-level logger, named after this module, used for all progress messages below.
log = logging.getLogger(__name__)

def _tail_mean(values, fraction):
    """Return the mean of the trailing ``fraction`` of ``values`` as a float.

    At least one element is always used, so a short history does not collapse
    to ``values[-0:]`` (which would silently be the whole list). An empty
    history yields NaN.
    """
    if not values:
        return float("nan")
    count = max(1, int(len(values) * fraction))
    return float(np.mean(values[-count:]))


@hydra.main(version_base="1.3", config_path="configs", config_name="config")
def main(cfg: DictConfig) -> None:
    """Train an agent as described by ``cfg`` and record evaluation results.

    The function builds the environments, actor, critic and replay buffer,
    fills the buffer with the loaded offline dataset, then runs ``cfg.steps``
    iterations of either online interaction (``cfg.offline`` false) or pure
    offline updates (``cfg.offline`` true). Every ``cfg.evaluation_steps``
    steps the agent is evaluated on a separate test environment and the mean
    returns are appended to the ``returns`` table; a final row of aggregate
    statistics is written to the ``summary`` table.

    Args:
        cfg: Hydra configuration. Mutated in place: ``cfg.run`` is offset by
            the job number during a sweep, and the actor learning rate is
            overwritten with ``critic_lr_multiplier * critic lr``.
    """
    start = timer()
    hydra_cfg = HydraConfig.get()
    if hydra_cfg.mode.value == 2:  # 2 == multirun, i.e. this job is part of a sweep
        cfg.run += hydra_cfg.job.num
        log.info(f'Running sweep... Run ID: {cfg.run}')
    log.info(f"Output directory  : {hydra_cfg.runtime.output_dir}")

    # The actor learning rate is expressed relative to the critic's.
    cfg.agent.actor.optimizer.lr = (
        cfg.agent.actor.optimizer.critic_lr_multiplier * cfg.agent.critic.optimizer.lr
    )
    flattened_cfg = prep_cfg_for_db(
        OmegaConf.to_container(cfg, resolve=True),
        to_remove=["schema", "db"],
    )
    log.info(pretty_repr(flattened_cfg))

    exp = ExperimentManager(cfg.db_name, flattened_cfg, cfg.db_prefix, cfg.db)
    # One Metric per table declared in the schema section of the config.
    tables = {
        table_name: Metric(table_name, spec.columns, spec.primary_keys, exp)
        for table_name, spec in cfg.schema.items()
    }
    torch.set_num_threads(cfg.n_threads)
    utils.set_seed(cfg.seed)

    env = get_env(cfg.env, seed=cfg.seed)
    test_env = get_env(cfg.env, seed=cfg.seed)

    actor = Actor(cfg.agent.actor, env.env, cfg.agent.store_old_policy, cfg.device)
    critic = get_critic(cfg.agent.critic, env.env, cfg.device)

    # Offline data needs to be loaded from the d4rl datasets and stored into the buffer.
    offline_data = load_dataset(cfg.env.name, cfg.env.dataset, cfg, cfg.seed)
    train_s, train_a, train_r, train_ns, train_t = training_set_construction(offline_data)
    buffer = get_buffer(cfg.agent.buffer, cfg.seed, env.env, cfg.device)
    for state, action, reward, next_state, terminal in zip(
        train_s, train_a, train_r, train_ns, train_t
    ):
        # Flip the terminating conditions so in the update only mask_batch needs to be used.
        buffer.push(
            state=state,
            action=action,
            reward=reward,
            next_state=next_state,
            done=1 - terminal,
        )

    agent = get_agent(cfg.agent, False, cfg.device, env.env, actor, critic, buffer)

    obs = env.reset(seed=cfg.seed)
    test_env.reset(seed=cfg.seed)

    step = 0  # kept so the summary row is well defined even when cfg.steps == 0
    all_rewards = []
    all_normalized_rewards = []
    episode = 0  # only advanced by the online branch; stays 0 for offline runs
    episodic_reward = 0

    # Placeholders so the periodic log line works before the first update has
    # happened (online warm-up phase); previously this raised NameError.
    actor_loss = q_loss = v_loss = float("nan")

    for step in range(cfg.steps):

        if not cfg.offline:
            # Learn from online Mujoco environments.
            if step < cfg.learning_starts:
                action = env.action_space.sample()
            else:
                action = agent.act(obs)
            obs_next, reward, terminated, _ = env.step(action)
            buffer.push(obs, action, reward, obs_next, 1 - terminated)
            # NOTE(review): no update happens at step == learning_starts
            # (strict '>' here vs. '<' above) -- confirm this is intended.
            if step > cfg.learning_starts:
                q_loss, v_loss = agent.update_critic()
                actor_loss = agent.update_actor()
            obs = obs_next
            episodic_reward += reward

            if terminated:
                # NOTE(review): resetting with the same seed every episode may
                # replay the same initial state each time -- confirm intended.
                obs = env.reset(seed=cfg.seed)
                log.info(
                    f'step: {step} \t \t episode: {episode}, reward: {episodic_reward}'
                )
                episode += 1
                episodic_reward = 0

        else:
            # Learn from offline datasets.
            q_loss, v_loss = agent.update_critic()
            actor_loss = agent.update_actor()

        if step % cfg.evaluation_steps == 0:
            log.info(f"Step {step}, \t training actor loss: {actor_loss:.4f}, \t critic q loss: {q_loss:.4f} v loss: {v_loss:.4f}")
            mean_reward, _, mean_normalized, _ = evaluate_episodic(
                test_env,
                agent,
                cfg.evaluation_episodes,
                cfg.seed,
                step,
                cfg.timeout,
            )
            tables["returns"].add_data(
                [
                    cfg.run,
                    step,
                    episode,
                    mean_reward,
                    mean_normalized,
                ]
            )
            all_rewards.append(mean_reward)
            all_normalized_rewards.append(mean_normalized)

        # Flush buffered evaluation rows periodically so a crash loses little data.
        if step % 10000 == 0:
            tables["returns"].commit_to_database()

    # Wall-clock minutes. timedelta(...).seconds would drop whole days and
    # sub-second precision, so the elapsed seconds are used directly.
    total_time = (timer() - start) / 60

    # Mean evaluation return over the last 10% / 50% / 100% of evaluations.
    auc_10 = _tail_mean(all_rewards, 0.1)
    auc_50 = _tail_mean(all_rewards, 0.5)
    auc_100 = _tail_mean(all_rewards, 1.0)
    norm_auc_10 = _tail_mean(all_normalized_rewards, 0.1)
    norm_auc_50 = _tail_mean(all_normalized_rewards, 0.5)
    norm_auc_100 = _tail_mean(all_normalized_rewards, 1.0)
    tables["summary"].add_data(
        [
            cfg.run,
            step,
            episode,
            auc_100,
            norm_auc_100,
            auc_50,
            norm_auc_50,
            auc_10,
            norm_auc_10,
            total_time,
        ]
    )
    tables["returns"].commit_to_database()
    tables["summary"].commit_to_database()
    # Nothing is written to the policy table in this function, so only flush it
    # when the schema actually declares it instead of failing with KeyError
    # after the whole run has finished.
    if "policy" in tables:
        tables["policy"].commit_to_database()
    log.info(f'Total time taken: {total_time}  minutes')

# Run the Hydra-decorated entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()
