import argparse
import functools
import itertools
import os
import random
import warnings
from collections import defaultdict
from collections.abc import Sequence, Iterable
from pathlib import Path
from pprint import pprint

import gym
import numpy as np
import torch
import wandb
from tqdm.auto import trange, tqdm

from collector.replay_buffer.episode_replay import TrajectoryReplayBuffer
from collector.strategy import OnPolicyCollectionStrategy, OffPolicyCollectionStrategy, record_episode
from envs.wrappers.episode_count_wrapper import EpisodeCountWrapper
from helpers import init_model, init_envs
from logger.ratemap import log_ratemaps
from logger.video import log_videos
from policy.base import BasePolicy

# from functools import partialmethod
# tqdm.__init__ = partialmethod(tqdm.__init__, disable=True)


# Placeholder credential: replace before running.
# NOTE: this assignment executes at import time and overwrites any
# WANDB_API_KEY value already present in the process environment.
os.environ["WANDB_API_KEY"] = "ENTER WANDB API KEY HERE"

# Fallback wandb project / entity / group, used by main() when the matching
# --wandb_* command-line option is unset or empty.
WANDB_PROJECT = "ENTER WANDB PROJECT HERE"
WANDB_ENTITY = "ENTER WANDB USERNAME HERE"
WANDB_GROUP = "ENTER WANDB GROUP HERE"


# Torch device string: "cuda" when CUDA is available, otherwise "cpu".
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Run configuration; starts empty and is rebound to wandb.config by main().
CONFIG = {}

# torch.set_printoptions(precision=10, linewidth=90, sci_mode=False)
# Compact NumPy printing: 4 decimals, 90-character lines, a space in the sign
# position of positive values, and fixed-point (non-scientific) notation.
np.set_printoptions(precision=4, linewidth=90, sign=' ', suppress=True)


def parse_args(argv=None):
    """Build the command-line parser and parse ``argv``.

    The parser is created with ``argument_default=argparse.SUPPRESS``, so an
    option that is not supplied on the command line is absent from the
    returned namespace. The exceptions are the options registered with an
    explicit default (``--disable_wandb``, ``--wandb_project``,
    ``--wandb_entity``, ``--wandb_group``, ``--disable_tqdm``, ``--seed``),
    which are always present.

    Args:
        argv: Argument strings to parse; ``None`` lets argparse read
            ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace``. The ``rl_method`` attribute is
        removed when no sub-command (``dqn`` or ``a2c``) was given.
    """
    suppress = argparse.SUPPRESS
    parser = argparse.ArgumentParser(argument_default=suppress)

    def add_typed(target, *specs):
        # Register plain "--flag VALUE" options, each converted by its type.
        for flag, value_type in specs:
            target.add_argument(flag, type=value_type)

    def add_switch_pair(flag, dest):
        # Register mutually exclusive --<flag> / --no-<flag> switches that
        # store True / False into the same destination.
        pair = parser.add_mutually_exclusive_group()
        pair.add_argument(f'--{flag}', action="store_true", dest=dest)
        pair.add_argument(f'--no-{flag}', action="store_false", dest=dest)

    # environment options
    add_typed(
        parser,
        ('--env_reward_wrong', float),
        ('--env_reward_correct', float),
        ('--env_reward_backward', float),
        ('--env_reward_wall', float),
        ('--env_reward_backwall', float),
        ('--env_seq_len', int),
        ('--env_max_steps', int),
    )
    add_switch_pair('pixel_output', 'env_pixel_output')
    add_switch_pair('decaying_walls', 'env_decaying_walls')
    add_typed(parser, ('--env_decaying_rate_walls', float))
    add_switch_pair('end_wall', 'env_flag_end_wall')
    add_switch_pair('encode_obs_using_autoencoder', 'env_encode_obs_using_autoencoder')

    # encoder options
    parser.add_argument('--encoder_type', choices=["conv", "dense"])
    add_typed(
        parser,
        ('--encoder_activation_penalty_norm_p', int),
        ('--encoder_activation_penalty_weight', float),
    )

    # memory options
    memory_choices = [
        "rnn",
        "rnn_frozen",
        "rnn_relu",
        "rnn_relu_frozen",
        "lstm",
        "lstm_frozen",
        "gru",
        "gru_frozen",
        "F",
        "F_sub",
        "sith",
        "sith_sub_sum",
        "sith_sub_nosum",
        "sith_subonly",
        "sith_subonly_nosum",
        "sith_subonly_sumonly",
    ]
    parser.add_argument('--memory_type', choices=memory_choices)
    add_typed(
        parser,
        ('--memory_activation_penalty_norm_p', int),
        ('--memory_activation_penalty_weight', float),
        ('--memory_hidden_size', int),
    )

    for flag in ('--add_z_skip', '--add_outer'):
        parser.add_argument(flag, action="store_true")

    add_typed(parser, ('--gamma', float), ('--learning_rate', float))

    # logging / run options (these carry explicit defaults)
    parser.add_argument('--disable_wandb', action="store_true", default=False)
    for flag in ('--wandb_project', '--wandb_entity', '--wandb_group'):
        parser.add_argument(flag, type=str, default=None)
    parser.add_argument('--disable_tqdm', action="store_true", default=False)
    parser.add_argument('--seed', type=int, default=None)

    subparsers = parser.add_subparsers(dest="rl_method", required=False)

    # dqn params
    dqn_parser = subparsers.add_parser('dqn', argument_default=suppress)
    add_typed(
        dqn_parser,
        ('--dqn_weight_penalty_norm_p', int),
        ('--dqn_weight_penalty_weight', float),
    )

    # a2c params
    a2c_parser = subparsers.add_parser('a2c', argument_default=suppress)
    add_typed(
        a2c_parser,
        ('--actor_weight_penalty_norm_p', int),
        ('--actor_weight_penalty_weight', float),
        ('--critic_weight_penalty_norm_p', int),
        ('--critic_weight_penalty_weight', float),
        ('--training_step_limit', int),
    )

    parsed = parser.parse_args(argv)

    # Without a sub-command argparse stores rl_method=None; drop it so the
    # key is absent, like every other option that was not supplied.
    if parsed.rl_method is None:
        delattr(parsed, "rl_method")

    return parsed


# ------


def save_checkpoint(ep, model: BasePolicy):
    """Write a training checkpoint into the current wandb run directory.

    The checkpoint holds ``ep`` together with the state dicts of ``model``
    and ``model.optimizer``. It is written twice under
    ``<wandb.run.dir>/checkpoints``: once as ``<ep>.pt`` and once as
    ``latest.pt``. ``wandb.save`` is then called on the checkpoint glob.

    Args:
        ep: Value stored under the ``"episode"`` key and used as the
            checkpoint file name.
        model: Policy whose parameters and optimizer state are saved.
    """
    state = dict(
        episode=ep,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=model.optimizer.state_dict(),
    )

    target_dir = Path(wandb.run.dir) / "checkpoints"
    target_dir.mkdir(parents=True, exist_ok=True)

    # Numbered snapshot first, then the rolling "latest" copy.
    for file_name in (f"{ep}.pt", "latest.pt"):
        torch.save(state, target_dir / file_name)

    wandb.save("checkpoints/*.pt")  # sync saved files


# ------


def compute_penalty(z_values, ctx_values):
    """Compute the weighted activation penalties for encoder and memory outputs.

    For each input, the vector norm is taken along the last dimension, the
    norms are summed, and the sum is scaled by a weight. Norm order and
    weight are read from ``CONFIG``.

    Args:
        z_values: Tensor penalised with the ``encoder_activation_penalty_*``
            settings.
        ctx_values: Tensor penalised with the ``memory_activation_penalty_*``
            settings.

    Returns:
        A ``(total, info)`` tuple: ``total`` is the sum of both penalties as
        a tensor, and ``info`` maps each penalty name to its detached Python
        scalar value.
    """
    def scaled_norm_sum(values, order_key, weight_key):
        # weight * sum of the per-vector norms taken over the last dimension
        norms = torch.linalg.vector_norm(values, ord=CONFIG[order_key], dim=-1)
        return CONFIG[weight_key] * torch.sum(norms)

    encoder_term = scaled_norm_sum(
        z_values,
        "encoder_activation_penalty_norm_p",
        "encoder_activation_penalty_weight",
    )
    memory_term = scaled_norm_sum(
        ctx_values,
        "memory_activation_penalty_norm_p",
        "memory_activation_penalty_weight",
    )

    info = {
        "encoder_activation_penalty": encoder_term.detach().cpu().item(),
        "memory_activation_penalty": memory_term.detach().cpu().item(),
    }
    return encoder_term + memory_term, info


def validation(valid_envs_dict: dict[str, gym.Env], policy: BasePolicy, n_episodes: int) -> None:
    """Evaluate ``policy`` on every validation environment and log reward stats.

    For each environment, ``n_episodes`` episodes are recorded with
    ``explore=False``. Total rewards are grouped by the episode's
    ``done_reason`` ("correct", "wrong" or "timeout"), a short summary is
    printed through ``tqdm.write``, and the mean reward plus a reward
    histogram are logged to wandb with ``commit=False``.

    Args:
        valid_envs_dict: Validation environments keyed by name; the name is
            used in the printed summary and in the wandb metric keys.
        policy: Policy to evaluate; ``policy.n_backprop_steps`` is printed.
        n_episodes: Number of episodes to record per environment. When no
            episode is recorded (``n_episodes <= 0``) the summary is still
            printed but nothing is logged to wandb for that environment.

    Raises:
        ValueError: If an episode ends with a ``done_reason`` other than
            "correct", "wrong" or "timeout".
    """
    for name, valid_env in valid_envs_dict.items():
        correct = []
        wrong = []
        timeout = []

        for _ in trange(n_episodes, desc="Validation", leave=False):
            trajectory = record_episode(valid_env, policy, explore=False)
            if trajectory.done_reason == "correct":
                correct.append(trajectory.total_reward)
            elif trajectory.done_reason == "wrong":
                wrong.append(trajectory.total_reward)
            elif trajectory.done_reason == "timeout":
                timeout.append(trajectory.total_reward)
            else:
                raise ValueError(
                    f"Unexpected done_reason {trajectory.done_reason!r} in validation env {name!r}; "
                    "expected 'correct', 'wrong' or 'timeout'"
                )

        tqdm.write(f"Step {policy.n_backprop_steps}, Env: {name}")
        tqdm.write("Correct/Wrong/Timeout")
        tqdm.write(f"Count: {len(correct)}/{len(wrong)}/{len(timeout)}")
        with warnings.catch_warnings():
            # The mean of an empty group is nan; silence NumPy's warning about it.
            warnings.simplefilter("ignore", category=RuntimeWarning)
            tqdm.write(f" Mean: {np.array(correct).mean():0.3f}/{np.array(wrong).mean():0.3f}/{np.array(timeout).mean():0.3f}")
        tqdm.write("")

        valid_rewards = list(itertools.chain(correct, wrong, timeout))
        if not valid_rewards:
            # No episode was recorded, so there is no mean or histogram to
            # log; skipping also avoids a ZeroDivisionError below.
            continue
        valid_reward_mean = sum(valid_rewards) / len(valid_rewards)

        wandb.define_metric(f"validation/{name}/reward_mean", step_metric="episode", summary="max")
        wandb.define_metric(f"validation/{name}/reward_histogram", step_metric="episode")

        # commit=False: the caller decides when the wandb step is committed.
        wandb.log({
            f"validation/{name}/reward_mean": valid_reward_mean,
            f"validation/{name}/reward_histogram": wandb.Histogram(valid_rewards),
        }, commit=False)


def train_loop(model: BasePolicy, train_env: EpisodeCountWrapper, valid_envs_dict: dict[str, gym.Env]):
    """Alternate periodic logging with experience collection and learning.

    Every step first runs whichever periodic jobs are due (validation with
    loss/penalty histograms, videos, checkpoints, ratemaps) and commits the
    pending wandb data if at least one job ran. It then collects experience
    every ``model.LEARN_BATCH_SIZE`` steps and performs one ``model.learn``
    call, accumulating the returned loss and penalty values until the next
    statistics flush.

    The loop runs for ``CONFIG["training_step_limit"]`` steps when that value
    is greater than 1; otherwise it counts upward without an upper bound.

    Args:
        model: Policy being trained.
        train_env: Environment used for collection; its ``episode_count`` is
            logged with the statistics.
        valid_envs_dict: Validation environments keyed by name, handed to the
            validation, video and ratemap helpers.
    """
    # create replay buffer
    memory = TrajectoryReplayBuffer(capacity=model.REPLAY_BUFFER_CAPACITY)

    # specify collection strategy
    strategy_cls = OnPolicyCollectionStrategy if model.ON_POLICY else OffPolicyCollectionStrategy
    collector = strategy_cls(memory)

    # Values gathered since the last statistics flush, keyed by name.
    losses: defaultdict[str, list[float]] = defaultdict(list)
    penalties: defaultdict[str, list[float]] = defaultdict(list)

    step_limit = CONFIG["training_step_limit"]
    # NOTE(review): a limit of exactly 1 falls into the uncapped branch;
    # confirm whether `> 0` was intended.
    steps: Iterable = range(step_limit) if step_limit > 1 else itertools.count(0)

    pbar = tqdm(steps, desc="Training", position=1)
    for step in pbar:
        # Validation / periodic logging
        data_logged = False

        if step % CONFIG["log_stats_freq"] == 0:
            progress = {
                "episode": train_env.episode_count,
                "backprop steps": model.n_backprop_steps,
            }
            wandb.log(progress, commit=False)

            validation(valid_envs_dict, model, 1000)

            for name, vals in losses.items():
                wandb.define_metric(f"training/loss/{name}", step_metric="episode")
                wandb.log({f"training/loss/{name}": wandb.Histogram(vals)}, commit=False)
                vals.clear()  # start a fresh window for the next flush

            for name, vals in penalties.items():
                wandb.define_metric(f"training/penalty/{name}", step_metric="episode")
                wandb.log({f"training/penalty/{name}": wandb.Histogram(vals)}, commit=False)
                vals.clear()

            data_logged = True

        if step % CONFIG["log_video_freq"] == 0 and CONFIG["log_videos"]:
            log_videos(valid_envs_dict, model, CONFIG["memory_type"])
            data_logged = True

        if step % CONFIG["log_checkpoint_freq"] == 0:
            save_checkpoint(memory.total_seen, model)
            data_logged = True

        if step % CONFIG["log_ratemap_freq"] == 0:
            log_ratemaps(
                valid_envs_dict,
                model,
                500,
                CONFIG["log_individual_ratemaps"],
                CONFIG["log_aggregate_ratemaps"],
            )
            data_logged = True

        if data_logged:
            wandb.log({}, commit=True)  # commit logs and advance step

        # Training
        if step % model.LEARN_BATCH_SIZE == 0:
            memory = collector.collect(train_env, model)

        loss_info, penalties_info = model.learn(memory)

        # Record loss and penalty
        for name, val in loss_info.items():
            losses[name].append(val)

        for name, val in penalties_info.items():
            penalties[name].append(val)

        # Refresh the progress-bar readout every tenth step.
        if not step % 10:
            pbar.set_postfix(loss=f"{loss_info['total']:.3f}")


def main(argv: Sequence[str] | None = None):
    """Script entry point: configure wandb, build envs and model, then train.

    Steps:
      1. Parse command-line arguments.
      2. Optionally disable every tqdm progress bar.
      3. Initialise the wandb run and merge the parsed arguments into
         ``wandb.config`` (the wandb bookkeeping options are excluded).
      4. Draw a fresh random seed when none is configured and store it in
         the config.
      5. Rebind the module-level ``CONFIG`` to ``wandb.config``.
      6. Create the environments and the model, then run ``train_loop``.

    Args:
        argv: Argument strings forwarded to ``parse_args``; ``None`` makes
            argparse read ``sys.argv[1:]``.

    Raises:
        SystemExit: With code 0 after a ``KeyboardInterrupt`` during
            training, once the wandb run has been marked as finished.
    """
    global CONFIG

    args = parse_args(argv)

    if args.disable_tqdm:
        # Patch the constructor so every bar created from now on is disabled.
        tqdm.__init__ = functools.partialmethod(tqdm.__init__, disable=True)

    wandb.init(project=args.wandb_project or WANDB_PROJECT,
               entity=args.wandb_entity or WANDB_ENTITY,
               group=args.wandb_group or WANDB_GROUP,
               mode=("disabled" if args.disable_wandb else None),
               settings=wandb.Settings(_disable_stats=True))

    exclude_keys = ["wandb_project", "wandb_entity", "wandb_group", "disable_wandb"]
    wandb.config.update(wandb.helper.parse_config(args, exclude=exclude_keys), allow_val_change=True)

    seed = wandb.config["seed"]
    if seed is None:
        # generate_state(1) returns a length-1 array. Index it explicitly:
        # calling int() on an array with ndim > 0 is deprecated in NumPy.
        seed = int(np.random.SeedSequence().generate_state(1)[0])  # new random seed
    wandb.config.update({"seed": seed}, allow_val_change=True)

    CONFIG = wandb.config

    pprint(dict(CONFIG))
    print(flush=True)

    wandb.define_metric("episode", hidden=True)

    train_env, valid_envs_dict = init_envs(CONFIG)
    model = init_model(CONFIG, train_env)

    print(model.net)
    print(f"{model.ON_POLICY = }")
    print(f"{model.LEARN_BATCH_SIZE = }")
    print(f"{model.REPLAY_BUFFER_CAPACITY = }")
    print(flush=True)

    # wandb.watch([model], criterion, log="all", log_graph=False)

    try:
        train_loop(model, train_env, valid_envs_dict)
    except KeyboardInterrupt:
        # shutdown signal, exit gracefully and mark run as finished
        wandb.finish()
        # Raise SystemExit directly: the `exit` builtin is installed by the
        # `site` module for interactive use and may be missing (e.g. `-S`).
        raise SystemExit(0) from None


# Run training only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
