"""
Some fairly brute-force methods to calculate a good M and to see what different Ms would look like.
This is especially useful if we want some features to share the same prevalence, so that we can then work in multiples of 4 or 8.
"""

import math
from collections import Counter
from dataclasses import dataclass

import torch as t
from loguru import logger


@dataclass
class MCounts:
    """How many features share one particular M value (one entry of the summary built by ``get_m_counts``)."""

    m_value: int  # a distinct M value found in the per-feature tensor
    count: int  # number of features that were assigned this M value


def nearest_power_of_2(x: float) -> int:
    """
    Round ``x`` to the nearest power of two, measured in log space.

    The exponent is ``round(log2(x))`` using Python's round-half-to-even, so for
    example 5 -> 4 and 6 -> 8.

    NOTE: for ``x`` below roughly 0.707 the exponent is negative and the result is a
    float (0.5, 0.25, ...), despite the ``int`` annotation. ``x <= 0`` raises
    ``ValueError`` from ``math.log2``.
    """
    exponent = round(math.log2(x))
    return pow(2, exponent)


def calculate_m(
    m: int = 21,
    num_features: int = 4_096,
    num_tokens_per_batch: int = 24_576,
    zipf_bias: float = 7.0,
    zipf_exponent: float = 1.0,
) -> t.Tensor:
    """
    Assign each feature an M value from a shifted Zipf profile, rounded to powers of two.

    A total budget of ``m * num_features`` is spread over the features with weight
    ``1 / (rank + zipf_bias) ** zipf_exponent`` for ``rank = 1..num_features``. Each
    feature's share is rounded to the nearest power of two and then capped at a quarter
    of the nearest power of two to ``num_tokens_per_batch``.

    Args:
        m: Average M per feature; sets the total budget ``m * num_features``.
        num_features: Number of features (length of the returned tensor).
        num_tokens_per_batch: Tokens per batch; only used to derive the cap.
        zipf_bias: Additive shift applied to each rank before weighting.
        zipf_exponent: Exponent applied to the shifted ranks.

    Returns:
        int32 tensor of length ``num_features``. Entry ``i`` is the M value for the
        feature of rank ``i + 1``. Fractional values (e.g. 0.5 from very small shares)
        are truncated to 0 when stored.

    Raises:
        ValueError: if ``num_features`` is not positive.
    """
    if num_features <= 0:
        raise ValueError(f"num_features must be positive, got {num_features}")

    out_m_F = t.empty(num_features, dtype=t.int32)

    num_interactions = m * num_features
    zipf_sum = sum(1 / (rank + zipf_bias) ** zipf_exponent for rank in range(1, num_features + 1))

    # Normalising constant: the un-rounded per-feature shares sum to num_interactions.
    N_approx = num_interactions / zipf_sum

    # Don't want any feature which hits too many of the tokens.
    largest_acceptable_power_of_2 = nearest_power_of_2(num_tokens_per_batch) // 4

    for i in range(num_features):
        # Separate name so the caller's `m` is not clobbered (it is logged below).
        m_i = nearest_power_of_2(N_approx / (i + 1 + zipf_bias) ** zipf_exponent)
        out_m_F[i] = min(m_i, largest_acceptable_power_of_2)

    logger.info(f"Total M: {out_m_F.sum().item()}")
    logger.info(f"Num interactions from original m={m}: {num_interactions}")

    return out_m_F


def get_m_counts(m_F: t.Tensor) -> list[MCounts]:
    """
    Group features by their M value.

    Args:
        m_F: Per-feature M values, such as the output of ``calculate_m``. It is
            presumably 1-D, because each element of ``m_F.tolist()`` must be hashable.

    Returns:
        One ``MCounts`` per distinct M value, in order of first appearance in ``m_F``.
        An empty ``m_F`` raises ``ValueError`` from the ``max``/``min`` used in logging.
    """
    expected_total = len(m_F)

    tally = Counter(m_F.tolist())
    summary = [MCounts(m_value=value, count=freq) for value, freq in tally.items()]

    # Sanity check: every feature must fall into exactly one bucket.
    assert sum(tally.values()) == expected_total

    distinct_values = tally.keys()
    logger.info(f"{len(tally)} unique values for which to top-m over")
    logger.info(f"Highest: {max(distinct_values)}, lowest: {min(distinct_values)}")

    return summary


if __name__ == "__main__":
    # Run with the default settings and print how many features share each M value.
    print(get_m_counts(calculate_m()))
