from glob import glob
from ellipse_attack.transformations import Ellipse, Model
import numpy as np

# Checkpoints to compare: the deduped model first, followed by every file
# matching step*.npz.
# NOTE(review): sorted() orders the step files lexicographically, so e.g.
# "step1000" sorts before "step2" -- confirm this is the intended order,
# since it fixes both the output column order and which entry is index 2.
model_paths = [
    "data/model/pythia-70m-deduped.npz",
    *sorted(glob("data/model/pythia-70m/step*.npz")),
]

# Build one ellipse per checkpoint. np.load returns an NpzFile that keeps the
# underlying file open until it is closed, so each archive is loaded inside a
# `with` block. Unpacking with ** reads every stored array before the block
# exits, so nothing is read from the archive after it has been closed.
ellipses = []
for model_path in model_paths:
    with np.load(model_path) as archive:
        ellipses.append(Model(**archive).ellipse())

# The third ellipse (index 2) is the reference: take the first `emb_size`
# entries of its bias and compute the permutation that sorts them ascending.
# NOTE(review): index 2 is hard-coded and raises IndexError when fewer than
# three checkpoints are found -- confirm which checkpoint is meant.
bias_ref = ellipses[2].bias[:ellipses[2].emb_size]
sort = bias_ref.argsort()

# with open("overleaf/data/pythia/rot_pred.dat", "w") as file:
#     for i, (true, pred) in enumerate(zip(rot2_true, rot2_pred)):
#         dot_product = np.dot(true, pred)
#         magnitude_x = np.linalg.norm(true)
#         magnitude_y = np.linalg.norm(pred)
#         cos_theta = dot_product / (magnitude_x * magnitude_y)
#         angle = np.arccos(np.clip(cos_theta, -1.0, 1.0))
#         print(i, angle, sep="\t", file=file)

# Write the bias table: one row per position in `sort`, holding the row index
# followed by the bias value of every ellipse at that position, tab-separated.
with open("overleaf/data/pythia/bias_pred.dat", "w") as out:
    # Reorder each ellipse's bias with the permutation in `sort`.
    columns = [model_ellipse.bias[sort] for model_ellipse in ellipses]
    # zip(*columns) transposes the per-ellipse columns into per-row tuples.
    out.writelines(
        "\t".join(map(str, (row, *values))) + "\n"
        for row, values in enumerate(zip(*columns))
    )

# Write the stretch table: one row per index, holding the row index followed
# by the `stretch` value of every ellipse at that index, tab-separated.
with open("overleaf/data/pythia/stretch_pred.dat", "w") as out:
    columns = [model_ellipse.stretch for model_ellipse in ellipses]
    # zip(*columns) stops at the shortest column, yielding one tuple per row.
    for row, values in enumerate(zip(*columns)):
        fields = [str(row), *map(str, values)]
        out.write("\t".join(fields) + "\n")
