import abc
from dataclasses import dataclass, asdict
from typing import Optional, Union, List

import torch

from constants import *
import numpy as np


class ParameterSchedule(abc.ABC):
    """
    Abstract base for a parameter whose value is a function of the iteration.

    Subclasses have to implement both :meth:`get_value` and :meth:`__repr__`;
    a subclass missing either of them cannot be instantiated.
    """

    @abc.abstractmethod
    def get_value(self, it):
        """
        Look up the value this schedule prescribes for an iteration.

        :param it: the iteration to evaluate the schedule at
        :return: the value of this parameter at iteration ``it``
        """

    @abc.abstractmethod
    def __repr__(self):
        """
        Describe this schedule as a string.

        :return: a printable representation of the schedule
        """


class ConstantSchedule(ParameterSchedule):
    """
    Schedule that yields the same value at every iteration.
    """

    def __init__(self, value):
        """
        :param value: the value returned by :meth:`get_value` for any iteration
        """
        self.value = value

    def get_value(self, it):
        """
        Return the stored value; the iteration is not used.

        :param it: the iteration (ignored)
        :return: the constant value passed to the constructor
        """
        return self.value

    def __repr__(self):
        # The instance attribute dict is printed verbatim.
        attributes = str(vars(self))
        return "ConstantSchedule(" + attributes + ")"


class LinearSchedule(ParameterSchedule):
    """
    Schedule that interpolates linearly between two values over an iteration interval.

    Before ``interpolation_start`` the value is ``value_start``, from ``interpolation_end``
    onwards it is ``value_end``, and in between it changes by a fixed amount per iteration.
    """

    def __init__(self, value_start, value_end, interpolation_start, interpolation_end):
        """
        :param value_start: the value returned before and at ``interpolation_start``
        :param value_end: the value returned from ``interpolation_end`` onwards
        :param interpolation_start: the iteration at which the interpolation begins
        :param interpolation_end: the iteration at which ``value_end`` is reached
        :raises AssertionError: if ``interpolation_end`` is not greater than ``interpolation_start``
        """
        self.value_start = value_start
        self.value_end = value_end
        self.diff_value = value_end - value_start

        self.interpolation_start = interpolation_start
        self.interpolation_end = interpolation_end
        self.diff_steps = self.interpolation_end - self.interpolation_start

        # Validate the interval before dividing by its length: otherwise an empty interval
        # fails with an unrelated ZeroDivisionError instead of this check.
        assert self.diff_steps > 0, \
            f"interpolation_end ({interpolation_end}) must be greater than " \
            f"interpolation_start ({interpolation_start})."

        # change of the value per iteration inside the interpolation interval
        self.normalized_diff = self.diff_value / self.diff_steps

    def get_value(self, it):
        """
        Get the interpolated value at the given iteration.

        :param it: the iteration
        :return: ``value_start`` before the interval, ``value_end`` at or after its end and
                 the linearly interpolated value in between
        """
        if it < self.interpolation_start:
            return self.value_start
        elif it < self.interpolation_end:
            return self.value_start + (it - self.interpolation_start) * self.normalized_diff
        else:
            return self.value_end

    def __repr__(self):
        return f"LinearSchedule({str(self.__dict__)})"


class ExponentialDecaySchedule(ParameterSchedule):
    """
    Schedule that multiplies ``value_max`` by ``decay_coef`` once per iteration after
    ``decay_start`` and never returns less than ``value_min``.
    """

    def __init__(self, value_max, value_min, decay_start, decay_coef):
        """
        :param value_max: the value returned before the decay starts
        :param value_min: lower bound of the returned value
        :param decay_start: the iteration at which the decay starts
        :param decay_coef: the factor applied per iteration after ``decay_start``
        """
        self.value_max = value_max
        self.value_min = value_min
        self.decay_start = decay_start
        self.decay_coef = decay_coef

    def get_value(self, it):
        """
        Get the decayed value at the given iteration.

        :param it: the iteration
        :return: ``value_max`` before ``decay_start`` (only checked for a positive ``decay_start``),
                 otherwise the decayed value clipped from below at ``value_min``
        """
        still_waiting = self.decay_start > 0 and it < self.decay_start
        if still_waiting:
            return self.value_max

        elapsed = it - self.decay_start
        decayed = self.value_max * (self.decay_coef ** elapsed)
        return max(self.value_min, decayed)

    def __repr__(self):
        attributes = str(vars(self))
        return "ExponentialDecaySchedule(" + attributes + ")"


class SmoothTargetExponentialDecaySchedule(ParameterSchedule):
    """
    Exponential decay from ``value_max`` towards ``value_min`` whose decay coefficient is derived
    from a target iteration instead of being passed in directly.
    """

    def __init__(self, value_max, value_min, decay_start, decay_end, smoothness=0.01):
        """
        :param value_max: the value returned before the decay starts
        :param value_min: the value the decay approaches; 0 is replaced by 0.001 (with a printed warning)
        :param decay_start: the iteration at which the decay starts
        :param decay_end: the iteration at which the value has come within ``smoothness`` of ``value_min``
        :param smoothness: remaining distance to ``value_min`` at ``decay_end``
        """
        self.value_max = value_max
        self.decay_start = decay_start
        self.smoothness = smoothness

        fallback_min = 0.001
        if value_min == 0:
            print(f"Warning: value_min = 0 is not supported. Using {fallback_min} instead.")
            value_min = fallback_min
        self.value_min = value_min

        # Plain exponential decay between two values:
        #   eps_end = eps_start * decay ^ (it_end - it_start)
        #   <==>  decay = exp(log(eps_end / eps_start) / (it_end - it_start))
        #
        # Here the curve is shifted by value_min to get a smooth approach of the target:
        #   value(it) = (value_max - value_min) * decay ^ (it - decay_start) + value_min
        #
        # This curve only reaches value_min exactly for decay = 0. Instead, the decay is chosen such
        # that the remaining distance to value_min at decay_end equals "smoothness":
        #   (value_max - value_min) * decay ^ (decay_end - decay_start) = smoothness
        value_range = self.value_max - self.value_min
        num_decay_steps = decay_end - decay_start
        self.decay_coef = np.exp(np.log(smoothness / value_range) / num_decay_steps)

    def get_value(self, it):
        """
        Get the decayed value at the given iteration.

        :param it: the iteration
        :return: ``value_max`` before ``decay_start`` (only checked for a positive ``decay_start``),
                 otherwise the shifted exponential decay towards ``value_min``
        """
        if self.decay_start > 0 and it < self.decay_start:
            return self.value_max

        value_range = self.value_max - self.value_min
        return value_range * (self.decay_coef ** (it - self.decay_start)) + self.value_min

    def __repr__(self):
        attributes = str(vars(self))
        return "SmoothTargetExponentialDecaySchedule(" + attributes + ")"


@dataclass
class Config:
    """
    Configuration for a training run (including environment parameters, config parameters and debug parameters).

    The message sizes are validated on construction. The schedule fields default to None and are
    only filled in when :meth:`set_default_schedules` is called.
    """
    env_name: str = ""
    # run params

    log_eval_plots: bool = False
    # set a comment for the run (appended to tensorboard default dir)
    run_comment: str = ""
    # whether to log the runs to tensorboard
    enable_tensorboard: bool = True

    # schedules (None = not set, see set_default_schedules)
    agent_epsilon_schedule: Optional[ParameterSchedule] = None
    msg_size_epsilon_schedule: Optional[ParameterSchedule] = None
    discrete_message_epsilon_schedule: Optional[ParameterSchedule] = None
    entropy_schedule: Optional[ParameterSchedule] = None

    # message params
    msg_sizes: tuple = (1, 4)
    message_mode: str = CONTINUOUS  # Available: CONTINUOUS, DRU, PSEUDOGRADIENT, DISCRETE
    receive_own_message: bool = False
    msg_decode_embedding_len: Optional[int] = None
    msg_decode_mode: str = AGGREGATE_MEAN

    # the size of the communication channel in slots (same unit as msg size)
    # used to simulate collisions while training
    comm_channel_size: Optional[int] = None
    comm_channel_type: Union[str, ChannelType] = ChannelType.Stochastic
    comm_channel_msg_size_spacing: bool = False
    # agents that are allowed to send, only used with the selective communication channel
    comm_channel_selective_allowed_agents: Optional[List[int]] = None

    msg_loss_weight: float = 0.5

    # whether to use softmax to select the message sizes instead of argmax
    softmax_msg_size_selection: bool = False
    # whether to force random msg sizes
    force_random_msg_size_selection: bool = False
    use_location_input: bool = True
    model_recurrent: bool = False
    model_with_policy_head: bool = False

    learning_rate: float = 0.001

    # training params

    # number of training iterations
    num_iterations: int = 200
    # each training iteration contains samples from num_envs * num_episodes episodes for each agent
    num_episodes: int = 1
    # number of steps between msg & state detach
    detach_gap: int = 10000
    # max gradient norm value used for clipping
    gradient_clip_max_norm: Optional[float] = None
    # how often to save the model. Only saves the model in the end when None.
    save_model_interval: Optional[int] = None
    save_model: bool = False
    # load model & run eval without logging anything
    load_model: Optional[str] = None

    # Whether to calculate & backpropagate agent Q losses only for the last step of the episode
    agent_q_loss_only_done: bool = True
    discount_factor: float = 1.0

    def __post_init__(self):
        # Reject invalid message sizes as soon as the config is created.
        self.check_msg_size(self.msg_sizes)

    @staticmethod
    def check_msg_size(msg_sizes):
        """
        Validate a sequence of message sizes.

        :param msg_sizes: the message sizes to check
        :raises AssertionError: if a size is negative, if a size of 0 appears anywhere but at the
                                first position, or if a size is smaller than its predecessor
        """
        for i, s in enumerate(msg_sizes):
            assert s >= 0, f"Message sizes must be >= 0, found {s} in {msg_sizes}[{i}]."
            if s == 0:
                assert i == 0, f"Message size 0 is only allowed as first size. Found it in {msg_sizes}[{i}]."
            if i >= 1:
                # equal neighbouring sizes pass this check, only a decrease is rejected
                assert s >= msg_sizes[i - 1], f"Message sizes must be monotonically increasing. " \
                                              f"Found {msg_sizes}[{i}] < {msg_sizes}[{i-1}]"

    def set_default_schedules(self):
        """
        Sets the default exploration schedules (if schedules are not set)
        """
        if self.msg_size_epsilon_schedule is None:
            self.msg_size_epsilon_schedule = SmoothTargetExponentialDecaySchedule(
                value_max=1,
                value_min=0.01,
                decay_start=0.2 * self.num_iterations,
                decay_end=0.6 * self.num_iterations,
                smoothness=0.01
            )

        if self.discrete_message_epsilon_schedule is None:
            self.discrete_message_epsilon_schedule = ConstantSchedule(0.01)

        if self.agent_epsilon_schedule is None:
            self.agent_epsilon_schedule = ConstantSchedule(0.01)

        if self.entropy_schedule is None:
            # LinearSchedule needs a non-empty interpolation interval. For num_iterations < 2 the
            # truncated end would be 0, so it is clamped to 1 (no change for num_iterations >= 2).
            interpolation_end = max(1, int(self.num_iterations * 0.7))
            self.entropy_schedule = LinearSchedule(2.0, 0.1, 0, interpolation_end)

    def get_msg_decoding_mask_size(self):
        """
        :return: the number of configured message sizes
        """
        return len(self.msg_sizes)

    def to_str_dict(self):
        """
        Convert this config to a dict with one entry per field.

        :return: a dict mapping each field name to the ``str()`` of its value
        """
        return {key: str(value) for key, value in asdict(self).items()}


@dataclass
class ConfigPOMNIST(Config):
    """
    Configuration for a training run of POMNIST environment.
    """
    env_name: str = POMNIST
    agent_q_loss_only_done: bool = True
    # Needs the type annotation to be a dataclass field override. Without it, this is only a class
    # attribute and the default used by __init__ would still be the one declared in Config.
    num_episodes: int = 1

    # new params
    num_envs: int = 2048

    # Environment params
    # x_splits / y_splits determine the number of agents, see get_num_agents
    x_splits: int = 0
    y_splits: int = 0
    max_steps: int = 2

    def get_num_agents(self):
        """
        :return: the number of agents, i.e. ``(x_splits + 1) * (y_splits + 1)``
        """
        return (self.x_splits + 1) * (self.y_splits + 1)


@dataclass
class ConfigTrafficJunction(Config):
    """
    Configuration for a training run of Traffic Junction environment.

    Redeclares several fields of :class:`Config` with different defaults and adds new fields.
    """
    env_name: str = TRAFFIC_JUNCTION

    # fields redeclared with a default that differs from Config
    use_location_input: bool = False
    model_recurrent: bool = True
    model_with_policy_head: bool = True
    agent_q_loss_only_done: bool = False
    # same value as the Config default
    discount_factor: float = 1.0

    # max gradient norm value used for clipping (the Config default is None)
    gradient_clip_max_norm: Optional[float] = 1.0

    # new params
    mode: str = 'easy'
    train_curriculum: bool = False
    num_envs: int = 256