import json
import math
from dataclasses import asdict, dataclass, field
from enum import Enum
from typing import Any, Optional, Union

from collectibles import ListCollection
from loguru import logger

from auto_encoder import debug
from auto_encoder.config_enums import (
    AutocastDtype,
    AutoEncoderType,
    CompensatingFeatureBiasPosition,
    PreprocessStrategy,
    ResamplingType,
    SchedulerType,
    TopMFlavour,
)


class EnumEncoder(json.JSONEncoder):
    """JSON encoder that writes ``Enum`` members as ``{"__enum__": "ClassName.MEMBER"}``.

    Every other object is handed to the base encoder, which raises ``TypeError``
    for types it cannot serialise.
    """

    def default(self, obj):
        if not isinstance(obj, Enum):
            return super().default(obj)
        qualified_name = f"{type(obj).__name__}.{obj.name}"
        return {"__enum__": qualified_name}


def round_to_nearest_multiple(x: int, base: int = 8):
    """Round ``x`` to the nearest multiple of ``base``, substituting ``base`` for 0.

    Rounding is done with the built-in ``round`` on ``x / base``, so exact halfway
    cases go to the even quotient (with ``base=8``: 12 -> 16 but 20 -> 16).
    """
    nearest = round(x / base) * base
    return nearest if nearest != 0 else base


@dataclass
class CausalGroupStructure:
    """One family of feature groups used by the causal SAE.

    Meaning of ``group_size``:
      * 0  - a regular binary feature.
      * 1  - a scalar feature which doesn't need to be centered at the origin.
      * >=2 - multi-dimensional features which need not be centered at the origin.
    """

    group_size: int
    num_groups: int
    num_active_groups: int


class CausalStructureCollection(ListCollection[CausalGroupStructure]):
    """Collection of ``CausalGroupStructure`` items.

    The annotations below mirror the fields of ``CausalGroupStructure``, typed as
    lists. Presumably ``ListCollection`` exposes each attribute across all items
    (e.g. ``collection.group_size`` -> one int per item) — TODO confirm against the
    ``collectibles`` documentation.
    """

    group_size: list[int]
    num_groups: list[int]
    num_active_groups: list[int]


@dataclass
class AutoEncoderConfig:
    """Hyperparameters for building, training and evaluating an autoencoder.

    Several fields use 0 as a "derive me" sentinel and are filled in by
    ``__post_init__``: ``num_neurons``, ``num_features``, ``resample_steps`` and
    ``num_static_loss_samples`` always, and ``topk`` / ``topm`` when the matching
    ``*_frac`` is positive. ``__post_init__`` also sets non-field attributes:
    ``num_minibatches``, ``autocast_dtype``, ``autocast_is_enabled``,
    ``topk_cutoff_prob``, ``topm_cutoff_prob`` and ``causal_structure_collection``.
    """

    autoencoder_type: AutoEncoderType = AutoEncoderType.TOPM
    transformer_model_name: str = "gpt2"

    save_steps: int = 2500

    # model hyperparameters
    num_neurons: int = 0
    feature_dim_multiplier: int = 8
    num_features: int = 0
    preprocess_strategy: PreprocessStrategy = PreprocessStrategy.SHIFT_AND_NORMALISE
    postprocess_bias: bool = True
    compensating_feature_bias_position: CompensatingFeatureBiasPosition = (
        CompensatingFeatureBiasPosition.BOTH
    )
    # ^Def not magnitude

    seq_len: int = 64

    # Soft Routing
    num_slots: int = 0

    # TopK / TopM
    stochastic_topk_temperature: float = 0.0
    min_topk_temperature: float = 0.7
    topk_gumbel_weighting: float = 0.1
    use_straight_through_estimator: bool = False
    topm_flavour: TopMFlavour = TopMFlavour.UNIFORM
    use_sinkhorn_assignment: bool = True

    topk: int = 0
    topk_frac: float = 8e-3

    binary_topk: int = 8
    binary_features_proportion: float = 0.5

    topm: int = 0
    topm_frac: float = 8e-3

    # JumpReLU
    jump_relu_kernel_bandwidth: float = 1e-3
    jump_threshold_offset: float = 0.9

    # training hyperparameters
    # batch_size: int = 24 if debug else 96
    # minibatch_size: int = 12 if debug else 48
    transformer_batch_size: int = 24 if debug else 96
    batch_size: int = 24 if debug else 1_536
    minibatch_size: int = 12 if debug else 512
    num_total_steps: int = 10_000  # For batch size 48, GDM do 80K+ steps, Anthropic do 1.6M

    learning_rate: float = 5e-4
    decoder_learning_rate_multiple: float = 1.0
    weight_decay: float = 0.0
    max_grad_norm: float = 5.0
    betas: tuple[float, float] = (0.9, 0.99)
    adam_eps: float = 7e-10
    schedule_free_lr_multiple: float = 1.0  # should be between 1x and 10x
    # feature_dropout: float = 0.01
    # ae_noise: float = 0.01

    # ema_multiplier: Optional[float] = 0.999
    ema_multiplier: Optional[float] = None

    # Efficiency related
    autocast_dtype_enum: AutocastDtype = AutocastDtype.NONE  # same in debug and non-debug runs
    # ^None just uses tf32 all the way and uses the decoder kernel resulting in blazingly fast decoding.
    # bf16 is 2x faster in general but avoids the decoder kernel so slightly slower there
    # fp16 allows for the decoder kernel (currently out of action).

    use_decoder_kernel: bool = False if debug else True
    use_schedule_free_adam: bool = False
    use_8bit_adam: bool = False
    # 8bit Adam causing issues when I reach inside to reset the optimizer states to 0...
    # Using the fused kernel instead

    use_loss_scaling: bool = False

    # Causal SAE related
    causal_structure_dict: dict[int, tuple[int, int]] = field(
        default_factory=lambda: {
            1: (12, 2),
            2: (6, 1),
            3: (4, 1),
            4: (3, 1),
        }
    )
    # group_size: (num_groups, num_active_groups)
    # Seems like an obvious case where topk isn't ideal, having a variable number of active groups should
    # give better performance but alas.
    # group_size 0 is a regular binary feature.
    # group_size 1 is a scalar feature which doesn't need to be centered at the origin
    # group_size >= 2 are multi-dimensional features which need not be centered at the origin

    # Loss function related
    # decorr_strength: float = 0.4

    auxiliary_l0_sparsity_coef: float = 1.0  # Used for JumpReLU approaches
    auxiliary_l1_sparsity_coef: float = 15.0
    auxiliary_l2_sparsity_coef: float = 1.0  # Used for FSQ

    # Gated SAE loss coefs
    auxiliary_gating_recon_coef: float = 1.0

    # Tripod loss coefs

    auxiliary_nfm_loss_coef: float = 1e-2
    auxiliary_nfm_inf_loss_coef: float = 3e-2
    auxiliary_multi_info_loss_coef: float = 0.0
    auxiliary_hessian_loss_coef: float = 0.0

    # VQ loss coefs
    auxiliary_codebook_loss_coef: float = 1.0
    auxiliary_commitment_loss_coef: float = 0.25

    # MoAE loss coefs
    auxiliary_balancing_loss_coef: float = 5e-3
    router_z_loss_coef: float = 2e-5
    expert_importance_loss_coef: float = 1e-3

    # PushThrough Loss Coefs
    feature_reconstruction_loss_coef: float = 1e-3

    finetune_kl_penalty_coef: float = 0.0

    # AuxK Loss Coefs
    auxiliary_dead_loss_coef: float = 1 / 32
    auxiliary_dying_loss_coef: float = 1 / 128

    # eval hyperparameters
    eval_steps: int = 20 if debug else 200
    eval_num_batches: int = 2 if debug else 5
    density_penalty_for_proxy_sweep_metric: float = 5.0

    # resampling hyperparameters
    resampling_type: ResamplingType = ResamplingType.NONE
    resampling_eps: float = 1e-3
    resample_steps: int = 0
    num_static_loss_samples: int = 0
    schedule_type: SchedulerType = SchedulerType.COSINE

    feature_activity_queue_length: int = 100
    feature_choice_likelihood: float = 0.0

    other_details: str = ""

    def to_dict(self) -> dict[str, Any]:
        """Return the declared dataclass fields as a (recursively copied) dict.

        Attributes derived in ``__post_init__`` that are not fields
        (e.g. ``num_minibatches``) are not included.
        """
        return asdict(self)

    def serialise_to_json(self) -> str:
        """Serialise the declared fields to JSON, encoding enums with ``EnumEncoder``.

        The output can be read back with :meth:`from_json`.
        """
        return json.dumps(asdict(self), cls=EnumEncoder)

    def __post_init__(self):
        """Normalise aliases, fill in 0-sentinel fields and compute derived attributes."""
        # ZIPF_TOPM is rewritten as TOPM with the Zipf flavour.
        if self.autoencoder_type == AutoEncoderType.ZIPF_TOPM:
            self.autoencoder_type = AutoEncoderType.TOPM
            self.topm_flavour = TopMFlavour.ZIPF

        # NOTE(review): this and the l1 coefficient adjustment below mutate
        # fields in place, so building a config from an already-initialised one
        # (from_json of serialise_to_json output, dataclasses.replace) applies
        # them a second time.
        if self.use_schedule_free_adam:
            self.learning_rate *= self.schedule_free_lr_multiple

        # Floor division: any remainder of batch_size / minibatch_size is dropped.
        self.num_minibatches = self.batch_size // self.minibatch_size

        self.autocast_dtype = self.autocast_dtype_enum.dtype
        self.autocast_is_enabled = self.autocast_dtype_enum.is_enabled

        if self.resample_steps == 0:
            self.resample_steps = self.num_total_steps // 4

        if self.num_static_loss_samples == 0:
            self.num_static_loss_samples = 1 if debug else self.num_total_steps // 10

        if self.num_neurons == 0:
            # Hidden size per known transformer; an unknown name raises KeyError.
            model_hidden_sizes = {"custom": 512, "gpt2": 768, "gemma2": 3584}
            self.num_neurons = model_hidden_sizes[self.transformer_model_name]

        if self.num_features == 0:
            self.num_features = self.num_neurons * self.feature_dim_multiplier

        self.topk_cutoff_prob: float = 1e-3 / self.num_features
        self.topm_cutoff_prob: float = 1e-3 / (self.seq_len * self.batch_size)

        if self.topk_frac > 0 and self.topk == 0:
            topk = math.ceil(self.num_features * self.topk_frac)
            self.topk = round_to_nearest_multiple(topk)

        if self.topm_frac > 0 and self.topm == 0:
            # Take the per-token k implied by topm_frac, and spread the total
            # number of interactions it would make over one minibatch
            # (k * seq_len * minibatch_size) evenly across num_features.
            equivalent_topk = math.ceil(self.num_features * self.topm_frac)
            topm = math.ceil(
                equivalent_topk * self.seq_len * self.minibatch_size / self.num_features
            )
            self.topm = round_to_nearest_multiple(topm, 4)
            logger.info(
                f"Set topm to {self.topm} based on topm frac {self.topm_frac}. "
                f"Equivalent to topk = {equivalent_topk}."
            )

        if self.autocast_dtype_enum.value == "bf16":
            if self.use_decoder_kernel:
                logger.warning("Decoder kernel not supported with BF16, disabling.")
                self.use_decoder_kernel = False

        self.causal_structure_collection = CausalStructureCollection(
            [
                CausalGroupStructure(group_size, num_groups, num_active_groups)
                for group_size, (
                    num_groups,
                    num_active_groups,
                ) in self.causal_structure_dict.items()
            ]
        )

        if self.autoencoder_type.generalised_topk:
            self.auxiliary_l1_sparsity_coef /= 50

        # Clamp so min_topk_temperature never exceeds stochastic_topk_temperature.
        self.min_topk_temperature = min(
            self.min_topk_temperature, self.stochastic_topk_temperature
        )

    @classmethod
    def from_json(cls, json_string):
        """Rebuild a config from a string produced by :meth:`serialise_to_json`.

        Keys of options that no longer exist are dropped, so older serialised
        configs still load. JSON cannot represent int dict keys or tuples, so
        ``causal_structure_dict`` (int keys, tuple values) and ``betas`` (tuple)
        are restored to their declared types before construction.

        NOTE(review): ``__post_init__`` runs again on the rebuilt config. The
        serialised ``learning_rate`` and ``auxiliary_l1_sparsity_coef`` already
        carry its in-place adjustments, so those are applied twice.
        """
        data: dict = json.loads(json_string, object_hook=cls.decode_maybe_enum)
        for removed_key in ("decorr_strength", "num_router_neurons", "num_router_features"):
            data.pop(removed_key, None)

        # JSON object keys are always strings and tuples come back as lists.
        causal_structure = data.get("causal_structure_dict")
        if isinstance(causal_structure, dict):
            data["causal_structure_dict"] = {
                int(group_size): tuple(group_spec)
                for group_size, group_spec in causal_structure.items()
            }
        if isinstance(data.get("betas"), list):
            data["betas"] = tuple(data["betas"])

        return cls(**data)

    @staticmethod
    def decode_maybe_enum(d: dict[str, Any]):
        """``json.loads`` object hook that reverses ``EnumEncoder``.

        A dict of the form ``{"__enum__": "ClassName.MEMBER"}`` is resolved by
        looking ``ClassName`` up in this module's globals and returning its
        ``MEMBER`` attribute. Any other dict is returned unchanged.
        """
        if "__enum__" in d:
            name, member = d["__enum__"].split(".")
            return getattr(globals()[name], member)
        else:
            return d


# Alias for the autoencoder config type(s). With a single member,
# ``Union[...]`` collapses to ``AutoEncoderConfig`` itself.
AEConfigBase = Union[
    AutoEncoderConfig,
]
