import math
import numpy as np
import scipy
import torch
import random
from collections import defaultdict

def next_power_of_2(n):
    n -= 1
    n |= n >> 1
    n |= n >> 2
    n |= n >> 4
    n |= n >> 8
    n |= n >> 16
    n += 1
    return n

def make_hadamard(size_m):
    if size_m == 1:
        return np.array([[1]])
    return scipy.linalg.hadamard(size_m)

def reorder(H):
    def sequency(row):
        return np.sum(row[:-1] != row[1:])
    return np.array(sorted(H, key=sequency))

def make_HLA(size_m, r, freq):
    if freq == "low":
        return reorder(make_hadamard(size_m))[:r, :]
    elif freq == "high":
        return reorder(make_hadamard(size_m))[-r:, :]

class Transform_Dict():
    def __init__(self):
        # defaultdict 대신 일반 dict를 사용하는 것이 실수를 방지하는 데 더 좋습니다.
        self.transform_dict = {}

    def get(self, name, size_m):
        key = f"{name}_{size_m}"
        return self.transform_dict.get(key) # .get()은 키가 없으면 None을 반환

    def register(self, name, size, r, freq="low"):
        size_m = next_power_of_2(size)
        key = f"{name}_{size_m}"

        if name == 'hadamard':
            self.transform_dict[key] = make_hadamard(size_m)
        elif name == 'low_rank':
            # 1. 누락된 'freq' 인자 추가
            self.transform_dict[key] = make_HLA(size_m, r, freq)
        else:
            raise NotImplementedError

    def get_or_register(self, name, size, r=16, freq="low"):
        # 2. 키 생성 로직을 일관되게 수정
        size_m = next_power_of_2(size)
        key = f"{name}_{size_m}"

        if key in self.transform_dict:
            return self.transform_dict[key]
        else:
            # register는 내부적으로 올바른 키를 사용해 저장합니다.
            self.register(name, size, r, freq)
            return self.transform_dict[key]