import pickle
from pathlib import Path

import numpy as np
from tqdm import trange


def random_parameters():
    """Draw the coefficients of a random line.

    Returns:
        A two-element list ``[slope, intercept]``. The slope is sampled
        uniformly between 0.1 and 0.9 and the intercept uniformly between
        0 and 0.5, both from NumPy's global random state (slope first).
    """
    slope_low, slope_high = 0.1, 0.9
    intercept_low, intercept_high = 0, 0.5

    # The slope must be drawn before the intercept so that seeded runs
    # consume the global random stream in a fixed order.
    slope = np.random.uniform(slope_low, slope_high)
    intercept = np.random.uniform(intercept_low, intercept_high)
    return [slope, intercept]


def get_sample(n_points: int = 2, num_targets: int = 3, noise_scale: float = 0):
    """Build one prompt/response example from a randomly drawn line.

    A slope and intercept are obtained from ``random_parameters``. The prompt
    lists ``n_points + num_targets`` observed ``(x, y)`` pairs followed by the
    query inputs; the response lists the noise-free line values at those
    inputs. Every number is rendered with three decimals.

    Args:
        n_points: Base number of observed points; ``num_targets`` more are
            added on top of it.
        num_targets: Number of query inputs. They are evenly spaced over
            [0, 1], except that a single target is queried at 0.5.
        noise_scale: Standard deviation of the Gaussian noise added to the
            observed ``y`` values (targets are never noised).

    Returns:
        A dict with ``"parameters"`` (``[slope, intercept]`` as floats) and
        ``"conv"`` (``[prompt, response]`` strings).
    """
    # One extra observed point is generated per requested target.
    n_observed = n_points + num_targets
    slope, intercept = random_parameters()
    # Noise is drawn before the x values; keep this order so that seeded
    # runs keep producing the same samples.
    jitter = np.random.normal(0, noise_scale, n_observed)

    observed_xs = [np.random.uniform(0, 1) for _ in range(n_observed)]
    observed = []
    for idx, x in enumerate(observed_xs):
        y = slope * x + intercept + jitter[idx]
        observed.append((f"{float(x):.3f}", f"{float(y):.3f}"))

    # A single target sits in the middle of the interval instead of at 0.
    query_inputs = [0.5] if num_targets == 1 else np.linspace(0, 1, num_targets)
    xs = ", ".join(f"{float(x):.3f}" for x in query_inputs)
    ys = ", ".join(f"{float(slope * x + intercept):.3f}" for x in query_inputs)

    observed_text = ", ".join(f"({x}, {y})" for x, y in observed)
    prompt = (
        "Given the following data points, find the next point: "
        f"{observed_text}"
        ". Inputs for the next data points are: "
        f"[{xs}]"
    )
    response = f"[{ys}]"

    return {
        "parameters": [float(slope), float(intercept)],
        "conv": [prompt, response],
    }


def _build_data(
    n_samples: int = 100,
    n_points: int = 2,
    num_targets: int = 3,
):
    """Generate samples and split them into training and test subsets.

    ``n_samples`` training samples plus a fixed 100 evaluation samples are
    generated with ``get_sample``; the last 100 become the test set.

    Returns:
        A dict containing ``"train"`` and ``"test"`` lists, plus one
        ``"train{i}"`` entry for every ``i`` from 1 to the training size,
        holding the first ``i`` training samples.
    """
    n_eval = 100
    total = n_samples + n_eval

    samples = []
    for _ in trange(total, desc="building data"):
        samples.append(get_sample(n_points=n_points, num_targets=num_targets))

    # The evaluation samples are always taken from the tail of the run.
    train = samples[:-n_eval]
    test = samples[-n_eval:]

    # Nested prefixes of the training set come first, then the full splits.
    subsets = {}
    for size in range(1, len(train) + 1):
        subsets[f"train{size}"] = train[:size]
    subsets["train"] = train
    subsets["test"] = test
    return subsets


def build_data(
    n_samples: int = 100,
    n_points: int = 2,
    num_targets: int = 1,
    seed: int = 42,
    path=".",
):
    """Load the dataset from its pickle cache, or build and cache it.

    The cache file lives in ``path`` and its name encodes ``n_samples``,
    ``n_points``, ``num_targets`` and ``seed``. If that file exists it is
    unpickled and returned as-is. Otherwise the directory is created if
    needed, NumPy's global random state is seeded with ``seed``, the data is
    generated with ``_build_data`` and written to the cache file.

    Args:
        n_samples: Number of training samples to generate.
        n_points: Forwarded to ``_build_data`` as ``n_points``.
        num_targets: Forwarded to ``_build_data`` as ``num_targets``.
        seed: Seed applied to NumPy's global random state before generating.
        path: Directory holding the cache file.

    Returns:
        The cached or freshly generated dataset.
    """
    filename = f"linear_samples_{n_samples}_points_{n_points}_targets_{num_targets}_seed_{seed}"
    _path = Path(path) / f"{filename}.pkl"

    if _path.is_file():
        # NOTE: unpickling can execute arbitrary code, so `path` must only
        # point at cache directories that are trusted.
        with open(_path, "rb") as cache_file:
            return pickle.load(cache_file)

    _path.parent.mkdir(parents=True, exist_ok=True)
    np.random.seed(seed)
    data = _build_data(n_samples, n_points=n_points, num_targets=num_targets)
    # The context manager closes and flushes the handle as soon as the dump
    # finishes, so the cache file is complete before this function returns.
    with open(_path, "wb") as cache_file:
        pickle.dump(data, cache_file)
    return data


if __name__ == "__main__":
    # Ad-hoc entry point: build the default dataset (or load it from the
    # cache) in a directory given relative to the current working directory.
    path = "../../../../data/toycos/data"
    data = build_data(path=path)
    # Drop into an interactive debugger so `data` can be inspected by hand.
    # ipdb is imported here so it is only needed when run as a script.
    import ipdb; ipdb.set_trace()  # noqa # fmt: skip
