import functools
import random

import numpy as np
import torch

from envs.new_env import TonesEnv
from envs.wrappers.episode_count_wrapper import EpisodeCountWrapper
from net.actor_critic import ActorCritic
from net.dqn import DQN
from net.encoder.conv import ConvEncoder
from net.encoder.dense import DenseEncoder
from net.encoder.identity import Encoder
from net.features_extractor import FeaturesExtractor
from net.memory.base import BaseMemory
from net.memory.rnn import RNNMemory, LSTMMemory, GRUMemory
from net.memory.sith import SITHSubSumOnlyMemory, SITHSubOnlyMemory, SITHSubMemory, SITHMemory
from policy.a2c import A2CPolicy
from policy.dqn import DQNPolicy


def init_envs(config, custom_valid_envs: dict | None = None):
    seed = config["seed"]

    env_kwargs = {
        'seq_len': config["env_seq_len"],
        'num_interval_on_left': config["env_left_range"],
        'num_interval_on_right': config["env_right_range"],
        'reward_turning_wrong_way_at_the_end': config["env_reward_wrong"],
        'reward_turning_correct_way_at_the_end': config["env_reward_correct"],
        'reward_going_backward': config["env_reward_backward"],
        'reward_turning_into_wall': config["env_reward_wall"],
        'reward_turning_into_back_wall': config["env_reward_backwall"],
        '_max_episode_steps': config["env_max_steps"],
        'pixel_output': config["env_pixel_output"],
        'pixel_output_shape': config["env_pixel_output_shape"],
        'decaying_walls': config["env_decaying_walls"],
        'decaying_rate_walls': config["env_decaying_rate_walls"],
        'flag_end_wall': config.get("env_flag_end_wall", True),
        'encode_obs_using_autoencoder': config.get("env_encode_obs_using_autoencoder", False)
    }

    # set random seeds
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.use_deterministic_algorithms(True)

    # create training env
    train_env = TonesEnv(**env_kwargs)
    train_env.reset(seed=seed)
    train_env.action_space.seed(seed)
    train_env = EpisodeCountWrapper(train_env)

    # create validation envs
    valid_envs_dict = {}
    for idx, (name, override_kwargs) in enumerate(config["valid_envs"].items()):
        valid_env = TonesEnv(**(env_kwargs | override_kwargs))
        valid_env.reset(seed=seed+idx)
        valid_env.action_space.seed(seed+idx)
        valid_envs_dict[name] = EpisodeCountWrapper(valid_env)

    # add custom validation envs
    if custom_valid_envs is not None:
        for idx, (name, override_kwargs) in enumerate(custom_valid_envs.items()):
            valid_env = TonesEnv(**(env_kwargs | override_kwargs))
            valid_env.reset(seed=seed+idx)
            valid_env.action_space.seed(seed+idx)
            valid_envs_dict[name] = EpisodeCountWrapper(valid_env)

    return train_env, valid_envs_dict


def _select_encoder_class(config, train_env):
    """Return the observation encoder class selected by ``config``.

    ``DenseEncoder`` is used when ``env_encode_obs_using_autoencoder`` is set
    (it defaults to ``False``). Otherwise ``Encoder`` (from
    ``net.encoder.identity``) is used when ``env_pixel_output`` is falsy, and
    ``encoder_type`` picks ``ConvEncoder`` ("conv") or ``DenseEncoder``
    ("dense") when pixel output is enabled.

    Raises:
        ValueError: If pixel output is enabled and the observation space shape
            does not have exactly two dimensions, or if ``encoder_type`` is
            not recognised.
    """
    if config.get("env_encode_obs_using_autoencoder", False):
        return DenseEncoder
    if not config["env_pixel_output"]:
        return Encoder

    # Raise explicitly instead of using ``assert`` so the check still runs
    # when Python is started with -O.
    obs_shape = train_env.observation_space.shape
    if len(obs_shape) != 2:
        raise ValueError(
            f"Pixel observations must have a 2-dimensional shape, got {obs_shape!r}"
        )

    encoder_type = config["encoder_type"]
    if encoder_type == "conv":
        return ConvEncoder
    if encoder_type == "dense":
        return DenseEncoder
    raise ValueError(f"Unknown encoder type: {encoder_type!r}")


def _select_memory_class(config):
    """Return the memory factory selected by ``config["memory_type"]``.

    The result is either a memory class or a ``functools.partial`` of one with
    its type-specific keyword arguments already bound.

    Raises:
        ValueError: If ``memory_type`` is not one of the supported names.
    """
    memory_type = config["memory_type"]

    # Recurrent memories as (class, extra keyword arguments). All of them also
    # receive ``output_size``; it is read from the config only when one of
    # these is actually selected, so SITH configs need not define
    # ``memory_hidden_size``.
    recurrent_specs = {
        "rnn": (RNNMemory, {"nonlinearity": "tanh"}),
        "rnn_relu": (RNNMemory, {"nonlinearity": "relu"}),
        "rnn_frozen": (RNNMemory, {"nonlinearity": "tanh", "frozen_weights": True}),
        "rnn_relu_frozen": (RNNMemory, {"nonlinearity": "relu", "frozen_weights": True}),
        "lstm": (LSTMMemory, {}),
        "lstm_frozen": (LSTMMemory, {"frozen_weights": True}),
        "gru": (GRUMemory, {}),
        "gru_frozen": (GRUMemory, {"frozen_weights": True}),
    }
    if memory_type in recurrent_specs:
        memory_cls, extra_kwargs = recurrent_specs[memory_type]
        return functools.partial(
            memory_cls, output_size=config["memory_hidden_size"], **extra_kwargs
        )

    sith_factories = {
        "F": functools.partial(SITHMemory, use_F=True),
        "F_sub": functools.partial(SITHSubMemory, add_sum_neurons=False, use_F=True),
        "sith": SITHMemory,
        "sith_sub_sum": functools.partial(SITHSubMemory, add_sum_neurons=True),
        "sith_sub_nosum": functools.partial(SITHSubMemory, add_sum_neurons=False),
        "sith_subonly_sum": functools.partial(SITHSubOnlyMemory, add_sum_neurons=True),
        "sith_subonly_nosum": functools.partial(SITHSubOnlyMemory, add_sum_neurons=False),
        "sith_subonly_sumonly": SITHSubSumOnlyMemory,
    }
    if memory_type in sith_factories:
        return sith_factories[memory_type]

    # Same exception type as the unknown encoder / RL method cases.
    known_types = sorted([*recurrent_specs, *sith_factories])
    raise ValueError(
        f"Unknown memory type: {memory_type!r} (expected one of {known_types})"
    )


def init_model(config, train_env):
    """Build the RL policy (DQN or A2C) described by ``config``.

    Args:
        config: Mapping with the encoder, memory, penalty and optimiser
            settings read below; ``rl_method`` must be ``"dqn"`` or ``"a2c"``.
        train_env: Environment whose ``observation_space`` and
            ``action_space`` are passed on to the networks and the policy.

    Returns:
        A ``DQNPolicy`` when ``rl_method`` is ``"dqn"``, an ``A2CPolicy`` when
        it is ``"a2c"``.

    Raises:
        ValueError: If the encoder type, memory type or RL method is unknown,
            or if pixel output is enabled for a non 2-D observation space.
        KeyError: If a required configuration key is missing.
    """
    encoder_class = _select_encoder_class(config, train_env)
    memory_class = _select_memory_class(config)

    features_extractor_kwargs = {
        'observation_space': train_env.observation_space,
        'latent_size': config["encoder_latent_size"],
        'add_z_skip': config["add_z_skip"],
        'add_outer': config["add_outer"],
        'encoder_activation_penalty_norm_p': config["encoder_activation_penalty_norm_p"],
        'encoder_activation_penalty_weight': config["encoder_activation_penalty_weight"],
        'memory_activation_penalty_norm_p': config["memory_activation_penalty_norm_p"],
        'memory_activation_penalty_weight': config["memory_activation_penalty_weight"],
        'encoder_class': encoder_class,
        'memory_class': memory_class
    }

    # instantiate RL model
    rl_method = config["rl_method"]
    if rl_method == "dqn":
        net_kwargs = {
            'weight_penalty_norm_p': config['dqn_weight_penalty_norm_p'],
            'weight_penalty_weight': config['dqn_weight_penalty_weight'],
            'features_extractor_class': FeaturesExtractor,
            'features_extractor_kwargs': features_extractor_kwargs,
        }

        # ``double``, ``learn_batch_size`` and ``sync_freq`` are fixed here
        # rather than read from the config.
        model = DQNPolicy(
            DQN,
            net_kwargs,
            gamma=config["gamma"],
            double=True,
            learn_batch_size=20000,
            sync_freq=100,
            action_space=train_env.action_space,
            learning_rate=config["learning_rate"]
        )
    elif rl_method == "a2c":
        net_kwargs = {
            'actor_weight_penalty_norm_p': config['actor_weight_penalty_norm_p'],
            'actor_weight_penalty_weight': config['actor_weight_penalty_weight'],
            'critic_weight_penalty_norm_p': config['critic_weight_penalty_norm_p'],
            'critic_weight_penalty_weight': config['critic_weight_penalty_weight'],
            'features_extractor_class': FeaturesExtractor,
            'features_extractor_kwargs': features_extractor_kwargs,
        }

        model = A2CPolicy(
            ActorCritic,
            net_kwargs,
            gamma=config["gamma"],
            action_space=train_env.action_space,
            learning_rate=config["learning_rate"]
        )
    else:
        raise ValueError(f"Unknown RL method: {rl_method!r}")

    return model
