import math
import os

import numpy as np
import torch
import gymnasium as gym

from algorithms import RAINBOW_ALGO, NECSA_RAINBOW_ALGO, NECSA_DQN_ALGO, MPEC_ALGO


def get_save_best_and_memory_fn(log_path):
    """Return a ``save_best_fn`` hook that also persists the policy's memories.

    Each call of the returned hook first writes ``policy.state_dict()`` to
    ``<log_path>/policy.pth`` with ``torch.save``. It then calls
    ``policy.save_memories(log_path)``.
    """

    def save_best_and_memory_fn(policy):
        weights = policy.state_dict()
        destination = os.path.join(log_path, "policy.pth")
        torch.save(weights, destination)
        policy.save_memories(log_path)

    return save_best_and_memory_fn


def get_save_best_fn(log_path):
    """Return a ``save_best_fn`` hook that writes ``policy.state_dict()`` to ``<log_path>/policy.pth``."""

    def save_best_fn(policy):
        weights = policy.state_dict()
        destination = os.path.join(log_path, "policy.pth")
        torch.save(weights, destination)

    return save_best_fn

def get_stop_fn(env, args):
    """Build the ``stop_fn`` hook that ends training once a reward target is reached.

    The target is chosen as follows:

    * ``env.spec.reward_threshold``, whenever one is registered (including 0
      or negative values).
    * Otherwise, for tasks whose name contains ``"Pong"``, a mean reward of 20.
    * Otherwise never stop early.

    Args:
        env: environment whose ``spec`` may be ``None`` (e.g. an env not created
            through ``gym.make``).
        args: run configuration providing ``task``.

    Returns:
        Callable mapping the current mean reward to a bool.
    """

    def stop_fn(mean_rewards):
        spec = getattr(env, "spec", None)
        threshold = getattr(spec, "reward_threshold", None)
        # Compare against None rather than truthiness: a threshold of 0 is valid.
        if threshold is not None:
            return mean_rewards >= threshold
        if "Pong" in args.task:
            return mean_rewards >= 20
        return False

    return stop_fn

def _linear_anneal(start, end, step, total_steps):
    """Linearly interpolate from ``start`` (step 0) to ``end`` (step ``total_steps``).

    Returns ``end`` once ``step`` exceeds ``total_steps``. It also returns ``end``
    whenever ``total_steps <= 0``: the schedule is then treated as already
    finished, instead of dividing by zero.
    """
    if total_steps <= 0 or step > total_steps:
        return end
    return start - step / total_steps * (start - end)


def get_train_fn(policy, args, logger, buffer=None):
    """Build the ``train_fn(epoch, env_step)`` training hook.

    The hook does the following:

    * Anneals the policy's epsilon linearly from ``args.eps_train`` to
      ``args.eps_train_final`` over the first 1M env steps (Nature DQN setting),
      then holds it at ``args.eps_train_final``.
    * For Rainbow / NECSA-Rainbow with prioritized replay
      (``not args.no_priority``), anneals ``buffer``'s beta linearly from
      ``args.beta`` to ``args.beta_final`` over ``args.beta_anneal_step`` env
      steps. A non-positive ``beta_anneal_step`` jumps straight to
      ``args.beta_final``.
    * Writes the current values through ``logger.write`` whenever ``env_step``
      is a multiple of 1000.

    Args:
        policy: object exposing ``set_eps``.
        args: run configuration (``eps_*``, ``algo_name``, ``no_priority``,
            ``beta*`` fields).
        logger: object exposing ``write(key, step, data)``.
        buffer: replay buffer exposing ``set_beta``; required for prioritized
            Rainbow variants.

    Returns:
        The ``train_fn`` hook. When called, it raises ValueError if prioritized
        Rainbow is configured but ``buffer`` is None.
    """

    def train_fn(epoch, env_step):
        # nature DQN setting, linear decay in the first 1M steps
        eps = _linear_anneal(args.eps_train, args.eps_train_final, env_step, 1e6)
        policy.set_eps(eps)
        if env_step % 1000 == 0:
            logger.write("train/env_step", env_step, {"train/eps": eps})

        if args.algo_name in [RAINBOW_ALGO, NECSA_RAINBOW_ALGO] and not args.no_priority:
            # Explicit check instead of `assert`, which is stripped under `python -O`.
            if buffer is None:
                raise ValueError(
                    "a replay buffer is required to anneal beta for prioritized Rainbow"
                )
            beta = _linear_anneal(args.beta, args.beta_final, env_step, args.beta_anneal_step)
            buffer.set_beta(beta)
            if env_step % 1000 == 0:
                logger.write("train/env_step", env_step, {"train/beta": beta})

    return train_fn


def get_test_fn(policy, args):
    """Return the ``test_fn`` hook, which switches ``policy`` to ``args.eps_test``.

    ``epoch`` and ``env_step`` are accepted to match the hook signature but are
    not used.
    """

    def test_fn(epoch, env_step):
        evaluation_eps = args.eps_test
        policy.set_eps(evaluation_eps)

    return test_fn


def get_save_checkpoint_fn(policy, log_path):
    """Return a ``save_checkpoint_fn`` hook that writes one checkpoint per epoch.

    Each call saves ``{"model": policy.state_dict()}`` to
    ``<log_path>/checkpoint_<epoch>.pth`` and returns that path. ``env_step`` and
    ``gradient_step`` belong to the hook signature but are unused.
    """

    def save_checkpoint_fn(epoch, env_step, gradient_step):
        # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
        target = os.path.join(log_path, f"checkpoint_{epoch}.pth")
        payload = {"model": policy.state_dict()}
        torch.save(payload, target)
        return target

    return save_checkpoint_fn


def get_necsa_dict(args, env):
    """Assemble, print, and return the configuration dict for the NECSA components.

    Side effect: when ``args.reduction`` is falsy, ``args.state_dim`` is
    overwritten. For NECSA-DQN, NECSA-Rainbow, and MPEC it becomes
    ``args.raw_state_dim``; for all other algorithms it becomes
    ``env.observation_space.shape[0]``.

    Args:
        args: run configuration supplying the NECSA hyper-parameters.
        env: environment whose ``action_space`` and, for the remaining
            algorithms, ``observation_space`` are inspected.

    Returns:
        dict: the NECSA configuration (also echoed to stdout).
    """
    action_space = env.action_space
    discrete_actions = isinstance(action_space, gym.spaces.Discrete)
    uses_raw_dim = args.algo_name in (NECSA_DQN_ALGO, NECSA_RAINBOW_ALGO, MPEC_ALGO)

    if not args.reduction:
        args.state_dim = args.raw_state_dim if uses_raw_dim else env.observation_space.shape[0]

    raw_state_dim = args.raw_state_dim if uses_raw_dim else env.observation_space.shape[0]

    if discrete_actions:
        # NOTE(review): gymnasium's Discrete has shape (), so action_dim is 0 here.
        # action_max is n, which is one past the largest action for the default
        # start=0. Confirm both are intended.
        action_dim, action_min, action_max = len(action_space.shape), 0, action_space.n
    else:
        action_dim = action_space.shape[0]
        action_min = action_space.low[0]
        action_max = action_space.high[0]

    necsa_dict = dict(
        step=args.step,
        grid_num=args.grid_num,
        epsilon=args.epsilon,
        necsa_lr=args.lr,
        necsa_gamma=args.gamma,
        mode=args.mode,
        reduction=args.reduction,
        circular_buffer=args.circular_buffer,
        raw_state_dim=raw_state_dim,
        state_dim=args.state_dim,
        state_min=args.state_min,
        state_max=args.state_max,
        action_dim=action_dim,
        action_min=action_min,
        action_max=action_max,
        necsa_advantage=args.necsa_adv,
        score_type=args.score_type,
    )
    print(necsa_dict)

    return necsa_dict


def get_mpec_dict(args, env):
    """Assemble, print, and return the configuration dict for the MPEC components.

    Side effect: when ``args.reduction`` is falsy, ``args.state_dim`` is
    overwritten. For NECSA-DQN, NECSA-Rainbow, and MPEC it becomes
    ``args.raw_state_dim``; for all other algorithms ``set_state_dim`` derives
    it from ``env.observation_space``.

    Args:
        args: run configuration supplying the MPEC hyper-parameters and
            debug/ablation flags.
        env: environment whose ``action_space`` (and possibly
            ``observation_space``) is inspected.

    Returns:
        dict: the MPEC configuration (also echoed to stdout).
    """
    action_space = env.action_space
    discrete_actions = isinstance(action_space, gym.spaces.Discrete)
    uses_raw_dim = args.algo_name in (NECSA_DQN_ALGO, NECSA_RAINBOW_ALGO, MPEC_ALGO)

    if not args.reduction:
        if uses_raw_dim:
            args.state_dim = args.raw_state_dim
        else:
            set_state_dim(args, env)

    # NOTE(review): for algorithms outside the raw-dim list this reads
    # args.state_dim. That equals the observation dimension only when reduction
    # is off; with reduction on it is the configured reduced dim.
    # get_necsa_dict reads env.observation_space.shape[0] instead. Confirm
    # which is intended.
    raw_state_dim = args.raw_state_dim if uses_raw_dim else args.state_dim

    if discrete_actions:
        action_dim, action_min, action_max = len(action_space.shape), 0, action_space.n
    else:
        action_dim = action_space.shape[0]
        action_min = action_space.low[0]
        action_max = action_space.high[0]

    mpec_dict = {
        'step': args.step,
        'grid_num': args.grid_num,
        'lr': args.mpec_lr,
        'gamma': args.mpec_gamma,
        'mode': args.mode,
        'reduction': args.reduction,
        'circular_buffer': args.circular_buffer,
        'raw_state_dim': raw_state_dim,
        'state_dim': args.state_dim,
        'state_min': args.state_min,
        'state_max': args.state_max,
        'action_dim': action_dim,
        'action_min': action_min,
        'action_max': action_max,
    }

    # Remaining entries copy the args attribute of the same name, in this order.
    passthrough_keys = (
        'dont_ask_for_policy',
        'terminate_if_no_policy',
        'policy_domination_decimal_places',
        'max_trajectory_length',
        'debug',
        'as1',
        'as2',
        'as3',
        'as4',
        'as5',
        'as6',
        'as7',
        'as8',
        'debug_naive_selection',
        'debug_track_policies',
        'debug_disable_trajectory_length',
        'debug_disable_average_reward',
        'debug_learning_rate',
        'debug_discount_factor',
        'debug_track_trajectories_length',
        'debug_track_trajectories_split_and_mismatches',
        'debug_disable_ssm',
        'debug_disable_cycle_detection',
        'debug_disable_reconnection',
    )
    for key in passthrough_keys:
        mpec_dict[key] = getattr(args, key)
    print(mpec_dict)

    return mpec_dict

def set_state_dim(args, env):
    """Store a flat state dimensionality for ``env.observation_space`` on ``args.state_dim``.

    * Box: ``shape[0]`` (size of the first axis).
    * Discrete: ``n``. Previously this read ``space.shape[0]``, which always
      raised IndexError because gymnasium's ``Discrete.shape`` is ``()``. ``n``
      matches how the Dict branch below and ``set_state_shape`` size discrete
      spaces.
    * Dict: sum over sub-spaces of ``n`` when the sub-space has one, otherwise
      its ``shape[0]``.

    Raises:
        TypeError: for any other observation-space type.
    """
    space = env.observation_space
    if isinstance(space, gym.spaces.Box):
        args.state_dim = space.shape[0]
    elif isinstance(space, gym.spaces.Discrete):
        args.state_dim = space.n
    elif isinstance(space, gym.spaces.Dict):
        # Flatten by summing per-sub-space sizes.
        args.state_dim = sum(
            sub.n if hasattr(sub, 'n') else sub.shape[0]
            for sub in space.spaces.values()
        )
    else:
        raise TypeError(f"Unsupported observation space type: {type(space)}")

def set_state_shape(args, env):
    """Store the network input shape for ``env.observation_space`` on ``args.state_shape``.

    * Box: its full ``shape`` tuple.
    * Discrete: its number of values ``n``.
    * Dict: an int, the sum over sub-spaces of ``n`` when present, otherwise
      ``shape[0]``.

    Raises:
        TypeError: for any other observation-space type.
    """
    obs_space = env.observation_space
    if isinstance(obs_space, gym.spaces.Box):
        shape = obs_space.shape
    elif isinstance(obs_space, gym.spaces.Discrete):
        shape = obs_space.n
    elif isinstance(obs_space, gym.spaces.Dict):
        shape = sum(
            sub.n if hasattr(sub, 'n') else sub.shape[0]
            for sub in obs_space.spaces.values()
        )
    else:
        raise TypeError(f"Unsupported observation space type: {type(obs_space)}")
    args.state_shape = shape

def get_max_trajectory_length(gamma):
    """Return the step count after which the discounted horizon stops growing.

    The loop builds the partial sums ``S_k = 1 + gamma + ... + gamma**k`` and
    stops at the first ``k`` where ``S_k`` and ``S_{k-1}`` agree within the
    default ``math.isclose`` tolerance (relative 1e-9). It returns that ``k``.

    Args:
        gamma: discount factor. When ``abs(gamma) >= 1`` the series does not
            converge, so ``math.inf`` is returned. Before this guard,
            ``gamma > 1`` looped until float overflow and ``gamma <= -1``
            looped forever.

    Returns:
        An int step count, or ``math.inf`` (a float) for a non-convergent gamma.
    """
    # math.inf replaces np.Inf, an alias that was removed in NumPy 2.0.
    if abs(gamma) >= 1:
        return math.inf

    count = 0
    previous, current = 0, 1
    while not math.isclose(previous, current):
        previous, current = current, current * gamma + 1
        count += 1
    return count
