import numpy as np


class PasskLaw:
    """Log-bilinear scaling law for pass@k accuracy versus compute and k.

    The fitted model is::

        log(-log(acc)) = c0*log(flops) + c1*log(k) + c2*log(flops)*log(k) + c3

    i.e. ``acc = exp(-exp(linear combination))``. The coefficients are
    found by ordinary least squares on the transformed accuracies.
    """

    def __init__(self) -> None:
        # Least-squares coefficients, shape (4, 1); None until ``fit`` runs.
        self.coeffs = None
        # Not read or written by this class's methods; kept as a public attribute.
        self.min_acc = None
        self.name = 'LogBilinear'

    @staticmethod
    def _design_matrix(flops, ks):
        """Build the (n, 4) regression matrix [log F, log k, log F*log k, 1].

        ``ks`` may be a scalar or any array broadcastable against ``flops``.

        Raises:
            ValueError: if any flops or k value is not strictly positive
                (its logarithm would be undefined or infinite).
        """
        flops_arr = np.atleast_1d(np.asarray(flops, dtype=float))
        ks_arr = np.asarray(ks, dtype=float)
        # Broadcasting lets callers predict a single k over many budgets.
        flops_arr, ks_arr = np.broadcast_arrays(flops_arr, ks_arr)
        if np.any(flops_arr <= 0) or np.any(ks_arr <= 0):
            raise ValueError("flops and ks must be strictly positive")
        log_f = np.log(flops_arr)
        log_k = np.log(ks_arr)
        return np.stack([log_f, log_k, log_f * log_k, np.ones_like(log_f)], axis=1)

    def fit(self, flops, acc, ks, max_flops=None):
        """Fit the law's coefficients to observed accuracies.

        Args:
            flops: Compute values, one per observation (array-like, > 0).
            acc: Observed accuracies, one per observation, each strictly in
                the open interval (0, 1).
            ks: The k of each observation (array-like or scalar, > 0).
            max_flops: Accepted for interface compatibility; currently unused.

        Raises:
            ValueError: if any accuracy is outside (0, 1), or if flops/ks
                contain non-positive values.
        """
        x_fit = self._design_matrix(flops, ks)
        acc_arr = np.atleast_1d(np.asarray(acc, dtype=float))[:, None]
        # log(-log(acc)) is only finite for 0 < acc < 1; this check also
        # rejects NaN, since comparisons with NaN are False.
        if not np.all((acc_arr > 0) & (acc_arr < 1)):
            raise ValueError("acc values must lie strictly between 0 and 1")
        y_fit = np.log(-np.log(acc_arr))
        self.coeffs, _, _, _ = np.linalg.lstsq(x_fit, y_fit, rcond=None)

    def predict(self, flops, ks):
        """Predict accuracy for the given compute values and k.

        Args:
            flops: Compute values (array-like, > 0).
            ks: k values (array-like or scalar, broadcastable to ``flops``, > 0).

        Returns:
            1-D array of predicted accuracies, one per flops entry.

        Raises:
            RuntimeError: if called before ``fit``.
            ValueError: if flops/ks contain non-positive values.
        """
        if self.coeffs is None:
            raise RuntimeError("PasskLaw.predict called before fit")
        x_predict = self._design_matrix(flops, ks)
        y_predict = x_predict @ self.coeffs
        accuracy = np.exp(-np.exp(y_predict))
        return accuracy[:, 0]
