import os
import pathlib
import random
from typing import Dict, Optional

import hydra
import numpy as np
import torch
import torch.autograd
import wandb
from omegaconf import DictConfig, OmegaConf
from torchinfo import summary
from tqdm import tqdm

from lambda_ac.env.env import setup_environment
from lambda_ac.nn.common import EnsembleLinearLayer
from lambda_ac.replay.env_storage import make_env_storage
from lambda_ac.rl_types import EncoderModelBasedActorCriticAgent, ExplorationScheduler
from lambda_ac.util.hydra_util import fix_env_config
from lambda_ac.util.rl_util import RandomAgent
from lambda_ac.util.schedulers import LinearIncrease


# Resume marker written alongside every checkpoint; its presence triggers a resume.
_TRAIN_META_PATH = "train.meta"
# Minimum number of env steps between checkpoints (checkpointing is slow).
_CHECKPOINT_INTERVAL_STEPS = 5000
# Number of episodes averaged per evaluation.
_EVAL_EPISODES = 20
# Rough episode length used only to re-estimate the episode counter on resume.
_ASSUMED_EPISODE_LENGTH = 1000


def _read_train_meta(path):
    """Return ``(total_numsteps, updates)`` stored by :func:`_write_train_meta`.

    The file holds exactly two integers separated by whitespace (one per line).
    """
    with open(path, "r") as f:
        total_numsteps, updates = (int(tok) for tok in f.read().split())
    return total_numsteps, updates


def _write_train_meta(path, total_numsteps, updates):
    """Atomically write the step counters used to resume training.

    The data goes to a temporary file that is then renamed over ``path`` with
    ``os.replace``, so an interruption mid-write cannot leave a truncated
    meta file that would break the next resume.
    """
    tmp_path = f"{path}.tmp"
    with open(tmp_path, "w") as f:
        f.write(str(total_numsteps))
        f.write("\n")
        f.write(str(updates))
    os.replace(tmp_path, path)


@hydra.main(version_base=None, config_path="../config", config_name="main")
def main(args):
    """Train the agent described by the Hydra config ``args``.

    Steps:
      * seeds ``random``, ``numpy`` and ``torch``;
      * builds a training and an evaluation environment plus the env storage;
      * starts (or resumes) a wandb run and instantiates the agent;
      * resumes agent, storage and step counters if ``train.meta`` exists;
      * alternates one env step with one model/critic/actor update (once
        enough data is collected) until ``args.num_steps`` env steps are done,
        evaluating every ``args.eval_freq`` steps and checkpointing after
        finished episodes.
    """
    # set seeds
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # updates only start once the storage can serve the larger batch
    batch_size = max(args.model_batch_size, args.rl_batch_size)

    # enable debug gradient checks
    if args.debug:
        torch.autograd.set_detect_anomaly(True)

    # create environment and update config with dimensions
    env = setup_environment(
        args.env.env_type,
        args.env.env_id,
        args.env.n_envs,
        args.device,
        args.seed,
        **args.env.kwargs,
    )
    eval_env = setup_environment(
        args.env.env_type,
        args.env.env_id,
        args.env.n_envs,
        args.device,
        args.seed,
        **args.env.kwargs,
    )
    env_storage = make_env_storage(args.env.env_type, env, eval_env, args)
    args = fix_env_config(args, env)

    # create wandb process or attach to existing on reset
    init_or_resume_wandb_run(pathlib.Path("wandb_id.txt"), args)

    # generate agent
    agent: EncoderModelBasedActorCriticAgent = hydra.utils.instantiate(args.agent)
    summary(agent.actor, depth=10)
    summary(agent.critic, depth=10)
    summary(agent.model, depth=10)

    # generate model data ratio scheduler
    model_data_ratio = LinearIncrease(
        args.start_model_usage_steps,
        args.minimum_model_data_ratio,
        args.maximum_model_data_ratio,
        args.full_model_usage_steps,
    )

    # check for existing run to resume
    if os.path.exists(_TRAIN_META_PATH):
        agent.load_checkpoint(args.env.env_name)
        env_storage.load(args.env.env_name)
        total_numsteps, updates = _read_train_meta(_TRAIN_META_PATH)
        # only an estimate: the true episode count is not persisted
        i_episode = total_numsteps // _ASSUMED_EPISODE_LENGTH
        # the ratio scheduler is stepped once per update, so replay that count
        model_data_ratio.set(updates)
    else:
        total_numsteps = 0
        updates = 0
        i_episode = 0
    pbar = tqdm(total=args.num_steps)
    pbar.update(total_numsteps)
    exploration_schedule: ExplorationScheduler = hydra.utils.instantiate(
        args.exploration_schedule
    )
    exploration_schedule.set(total_numsteps)

    ####
    # Training
    ####
    steps_since_last_checkpoint = 0
    while total_numsteps < args.num_steps:
        done = False
        episode_steps = 0
        episode_reward = 0
        while not done:
            if total_numsteps == batch_size:
                env_storage.init_memory()
            if total_numsteps >= batch_size and total_numsteps > args.start_train:
                # Update model for one step
                agent.model_data_ratio = model_data_ratio()
                model_data_ratio.step()
                # Update parameters of all the networks
                if args.update_model:
                    model_stats = agent.update_model(
                        env_storage.get_batch(args.model_batch_size), updates
                    )
                else:
                    model_stats = {}
                critic_stats = agent.update_critic(
                    env_storage.get_batch(args.rl_batch_size), updates
                )
                actor_stats = agent.update_actor(
                    env_storage.get_batch(args.rl_batch_size), updates
                )
                # agent.update_priorities(
                #     env_storage.get_batch(args.rl_batch_size),
                #     env_storage.buffer,
                #     updates,
                # )
                stats = {**model_stats, **actor_stats, **critic_stats}
                # bookkeeping
                updates += 1
                if total_numsteps % args.log_freq == 0:  # prevent wandb ultra spamming
                    wandb.log(stats, step=total_numsteps)
            # env interaction
            done, episode_reward = env_storage.step(agent, exploration_schedule)

            ####
            # Evaluation
            ####
            if total_numsteps % args.eval_freq == 0 and args.eval is True:
                episodes = _EVAL_EPISODES
                avg_reward = env_storage.eval(agent, episodes)
                wandb.log({"test_stats/avg_reward": avg_reward}, step=total_numsteps)
                print("----------------------------------------")
                print("Test Episodes: {}, Avg. Reward: {}".format(episodes, avg_reward))
                print("----------------------------------------")

            exploration_schedule.step()
            episode_steps += 1
            total_numsteps += 1
            pbar.update(1)

            # stop exactly at the step budget (same bound as the outer loop);
            # a strict `>` here used to run one env step past args.num_steps
            if total_numsteps >= args.num_steps:
                break

        # end of episode bookkeeping
        i_episode += 1
        # env_storage.buffer.weight_update()

        wandb.log({"train_stats/avg_reward": episode_reward}, step=total_numsteps)
        wandb.log({"train_stats/steps": episode_steps}, step=total_numsteps)

        # checkpointing
        # Count the episode that just ended before deciding, so the interval
        # is measured in actual env steps since the last checkpoint.
        steps_since_last_checkpoint += episode_steps
        # The interval check prevents checkpointing unnecessarily often, which
        # can cause major slowdowns. Checkpointing only after `done` ensures
        # that the buffer is not filled with wrong transitions after reload; the
        # step-budget `break` above can end the loop mid-episode, hence the guard.
        if done and steps_since_last_checkpoint > _CHECKPOINT_INTERVAL_STEPS:
            steps_since_last_checkpoint = 0
            agent.save_checkpoint(args.env.env_name)
            env_storage.save(args.env.env_name)
            # written after the agent/storage checkpoints so the counters never
            # describe state that has not been saved yet
            _write_train_meta(_TRAIN_META_PATH, total_numsteps, updates)

    pbar.close()


def init_or_resume_wandb_run(
    wandb_id_file_path: pathlib.Path,
    config: Optional[DictConfig] = None,
):
    """Start a new wandb run, or resume the one whose id is stored on disk.

    If ``wandb_id_file_path`` exists and holds a non-empty id, that run is
    resumed. Otherwise a new run is created and its id is written to
    ``wandb_id_file_path`` so that a restarted job can attach to it again.

    Args:
        wandb_id_file_path: File holding (or receiving) the wandb run id.
        config: OmegaConf config logged as the run config. It is resolved
            into plain containers first (missing values raise, because
            ``throw_on_missing=True``). ``None`` logs no config.

    Returns:
        The run object returned by ``wandb.init``.

    NOTE:
        Make sure that wandb_id_file_path.parent exists before calling this function
    """
    # wandb needs plain Python containers, not DictConfig nodes
    args_dict: Optional[Dict] = (
        None
        if config is None
        else OmegaConf.to_container(config, resolve=True, throw_on_missing=True)
    )  # type: ignore
    init_kwargs = dict(project="lambda_ac", entity="lambda_ac_svg", config=args_dict)

    # strip() so a trailing newline (e.g. after manual editing) does not
    # corrupt the run id; an empty file is treated like a missing one
    resume_id = (
        wandb_id_file_path.read_text().strip() if wandb_id_file_path.exists() else ""
    )

    # if the run_id was previously saved, resume from there
    if resume_id:
        print("Trying to resume")
        # `id=` + resume="allow" is the supported spelling of the legacy
        # `resume=<run id>` form
        return wandb.init(id=resume_id, resume="allow", **init_kwargs)

    # if the run_id doesn't exist, then create a new run
    # and write the run id to the file
    print("Starting new")
    run = wandb.init(**init_kwargs)
    wandb_id_file_path.write_text(str(run.id))
    return run


if __name__ == "__main__":
    # main() is called without arguments: the @hydra.main decorator composes
    # the config (config_path="../config", config_name="main") and passes it in.
    main()
