import numpy as np
import numpy
import matplotlib.pyplot as plt
from CausalDisco.analytics import r2coeff

from data_generation.data_from_dict import load_data

rng = np.random.default_rng()

# Generation settings passed to load_data(). The "seed" value here is only the
# initial default: the search loop in this script overwrites
# data_config["dictionary"]["seed"] on every attempt.
# NOTE(review): the meaning of each option ("one_mediator", "depth",
# "num_hidden", ...) is defined by data_generation.data_from_dict — confirm there.
data_config = dict(
    file_path=None,
    name=None,
    id=None,
    dictionary=dict(
        X=dict(
            type="normal",
            length=1500,
        ),
        transformation=dict(
            type="linear",
            args=dict(
                num_hidden=10,
                num_parents=1,
                # alpha=1,  # disabled option carried over from the original config
            ),
        ),
        shape="one_mediator",
        depth=2,
        seed=42,
        standardize=True,
        noise_type="normal",
        mediator_noise_type="normal",
        mediator_noise_parameters=dict(
            mean=0,
            std=np.sqrt(1),
        ),
    ),
)


def r2_sort(data):
    """Return variable indices ordered by increasing R^2 coefficient.

    The coefficients come from ``r2coeff(data.T)``; the index of the smallest
    coefficient comes first. Ties keep their original relative order because
    Python's ``sorted`` is stable.

    Args:
        data: samples passed (transposed) to ``r2coeff``. Presumably shaped
            (n_samples, n_variables) — TODO confirm against load_data.

    Returns:
        list[int]: indices into the coefficient sequence, ascending by value.
    """
    coefficients = list(r2coeff(data.T))
    positions = range(len(coefficients))
    return sorted(positions, key=coefficients.__getitem__)


datasets = []

max_attempts = 1000
# Print a progress line every `log_every` attempts (attempt 0 included).
# With log_every == max_attempts only attempt 0 is reported, which matches the
# previous hard-coded behaviour; lower it to get periodic progress output.
log_every = 1000

# Regenerate the dataset with seeds 0 .. max_attempts - 1 and keep every
# dataset whose lowest-R^2 variable is the one at index 1.
for attempt in range(max_attempts):
    data_config["dictionary"]["seed"] = attempt
    data = load_data(data_config)

    # Compute the R^2 coefficients once per attempt and reuse them for both
    # the ordering and the log lines, instead of calling r2coeff again for
    # every print. The ordering is ascending and stable on ties, i.e. the same
    # order r2_sort(data) returns.
    r2_coeffs = r2coeff(data.T)
    sorting = sorted(range(len(r2_coeffs)), key=lambda i: r2_coeffs[i])

    if attempt % log_every == 0:
        print(f"Attempt {attempt}: {sorting}, r2 coeffs: {r2_coeffs}")

    if sorting[0] == 1:
        # Index 1 has the smallest R^2 here; per the original author's note,
        # such data is treated as not R^2-sortable.
        datasets.append(data)
        print(f"Appended attempt {attempt}: {sorting}, r2 coeffs: {r2_coeffs}")

print(f"Number of datasets: {len(datasets)}")
print(f"Percentage of datasets: {len(datasets) / max_attempts * 100:.2f}%")
