import sys
from bisect import bisect
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import ray

# Put the repository's ``src`` directory on the import path. ``absolute()``
# anchors a relative ``__file__`` (possible for ``__main__`` before Python 3.9)
# to the working directory before walking up three directory levels; for an
# absolute ``__file__`` the result is the same as without it.
project_root = Path(__file__).absolute().parent.parent.parent.resolve()
sys.path.append(str((project_root / "src").resolve()))

# Project-local import. It must run *after* the ``sys.path`` update above,
# otherwise ``models`` only resolves when ``src`` is already importable.
from models.utils import (  # noqa: E402
    get_online_offline_data_dict,
    get_online_offline_query_dict,
)


def generate_circles_data(
    seed: int,
    size_datasets: int,
    T: int,
    circle_centres: Optional[List[Tuple[float, float]]] = None,
    circle_radii: Optional[List[float]] = None,
    drifts: Optional[List[int]] = None,
    gradual_batches: int = 2,
) -> Tuple[List[np.ndarray], List[np.ndarray]]:
    """Create data for the Circles dataset, whose circular decision boundary drifts.

    Every batch holds ``size_datasets`` points drawn uniformly from the unit
    square ``[0, 1)^2``. Labels are produced by ``classify_data``: a point is
    labelled 1 when it lies inside (or on) the circle of the active concept and
    0 otherwise. The active concept changes at each batch index in ``drifts``,
    phased in over ``gradual_batches`` batches.

    Args:
        seed (int): Seed used for sampling the points; the same seed is passed
            to ``classify_data`` for its own random choices.
        size_datasets (int): Number of points in each batch.
        T (int): Number of batches.
        circle_centres (List[Tuple[float, float]], optional): Centre of the
            decision circle for each concept. Concept ``i`` is active after the
            ``i``-th drift, so ``len(drifts) + 1`` entries are expected.
            Defaults to [(0.2, 0.5), (0.22, 0.5), (0.24, 0.5), (0.26, 0.5),
            (0.28, 0.5), (0.3, 0.5), (0.32, 0.5), (0.34, 0.5)].
        circle_radii (List[float], optional): Radius of the decision circle for
            each concept, aligned with ``circle_centres``.
            Defaults to [0.2] * 8.
        drifts (List[int], optional): Sorted batch indices at which a new
            concept starts. Defaults to [2, 4, 6, 8, 10, 12, 14].
        gradual_batches (int, optional): Number of batches over which a concept
            drift is phased in. Defaults to 2.

    Returns:
        Tuple[List[np.ndarray], List[np.ndarray]]: ``(X, y)`` where ``X[t]`` has
        shape ``(size_datasets, 2)`` and ``y[t]`` holds the integer labels of
        batch ``t``.
    """
    # Build the defaults per call instead of sharing mutable default lists.
    if circle_centres is None:
        circle_centres = [
            (0.2, 0.5),
            (0.22, 0.5),
            (0.24, 0.5),
            (0.26, 0.5),
            (0.28, 0.5),
            (0.3, 0.5),
            (0.32, 0.5),
            (0.34, 0.5),
        ]
    if circle_radii is None:
        circle_radii = [0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2]
    if drifts is None:
        drifts = [2, 4, 6, 8, 10, 12, 14]

    rng = np.random.default_rng(seed=seed)
    # One uniformly sampled batch per time step, drawn in order.
    X = [rng.random(size=(size_datasets, 2)) for _ in range(T)]
    y = classify_data(X,
                      seed=seed,
                      circle_centres=circle_centres,
                      circle_radii=circle_radii,
                      drifts=drifts,
                      gradual_batches=gradual_batches)
    return X, y


def classify_data(
    X: List[np.ndarray],
    seed: int,
    circle_centres: List[Tuple[float, float]],
    circle_radii: List[float],
    drifts: List[int],
    gradual_batches: int,
    **kwargs
) -> List[np.ndarray]:
    """Label batches of 2-D points against a drifting circular decision boundary.

    Batch ``t`` uses concept ``i = bisect(drifts, t)``. A point is labelled 1
    when it lies inside or on the circle with centre ``circle_centres[i]`` and
    radius ``circle_radii[i]``, and 0 otherwise.

    Every batch index listed in ``drifts`` starts (or restarts) a transition.
    In the k-th batch of a transition (k = 1 .. gradual_batches - 1), a random
    ``int(k / gradual_batches * n)`` of the ``n`` points are labelled by the new
    concept and the rest by the previous one. After that, all points use the
    new concept.

    Args:
        X: One array of points per batch. Columns 0 and 1 are read as the x
            and y coordinates, so at least two columns are expected.
        seed: Seed for the generator that picks which points switch concept.
        circle_centres: Circle centre per concept.
        circle_radii: Circle radius per concept, aligned with ``circle_centres``.
        drifts: Batch indices where a new concept begins; must be sorted for
            ``bisect``.
        gradual_batches: Length of each transition in batches.
        **kwargs: Accepted and ignored.

    Returns:
        One integer label array of length ``n`` per batch in ``X``.
    """
    rng = np.random.default_rng(seed=seed)

    def _in_circle(points, cx, cy, radius):
        # 1 for points inside/on the circle, 0 outside.
        return ((points[:, 0] - cx)**2 + (points[:, 1] - cy)**2 <= radius**2).astype(int)

    labels_per_batch = []
    # Position inside the current transition; None until the first drift.
    phase = None
    for t, batch in enumerate(X):
        concept = bisect(drifts, t)
        cx, cy = circle_centres[concept]
        radius = circle_radii[concept]
        if t in drifts:
            phase = 1

        in_transition = phase is not None and phase < gradual_batches
        if not in_transition:
            labels_per_batch.append(_in_circle(batch, cx, cy, radius))
            continue

        n_points = batch.shape[0]
        # The share of points following the new concept grows with the phase.
        n_switched = int((phase / gradual_batches) * n_points)
        switched = rng.choice(range(n_points), size=n_switched, replace=False)
        prev_cx, prev_cy = circle_centres[concept - 1]
        prev_radius = circle_radii[concept - 1]
        labels = _in_circle(batch, prev_cx, prev_cy, prev_radius)
        labels[switched] = _in_circle(batch[switched], cx, cy, radius)
        phase += 1
        labels_per_batch.append(labels)

    return labels_per_batch


def plot_data(
    data_dict: Dict[int, dict],
    query_dict: Optional[Dict[int, dict]] = None,
    circle_centres: Optional[List[Tuple[float, float]]] = None,
    circle_radii: Optional[List[float]] = None,
    **kwargs
):
    """Plot each batch of data against Circles decision boundaries.

    One panel per batch is laid out in a grid with four columns. Training points
    are coloured by label (0 yellow, 1 purple), query points are drawn with red
    edges, and every circle in ``circle_centres``/``circle_radii`` is outlined on
    every panel. Entries of either dict may be ray ``ObjectRef``s; these are
    fetched with ``ray.get``.

    Args:
        data_dict (Dict[int, dict]): The data dictionary. Keys are batch indices
            ``0 .. len(data_dict) - 1``. Each value provides ``"X_train"`` and
            ``"y_train"``, plus ``"X_query"`` when ``query_dict`` is None.
        query_dict (Optional[Dict[int, dict]], optional): Query dictionary keyed
            like ``data_dict``; each value provides ``"X_query"``. If None, the
            queries are taken from ``data_dict``. Defaults to None.
        circle_centres (List[Tuple[float, float]], optional): Centres of the
            circles to draw. Defaults to [(0.2, 0.5), (0.4, 0.5), (0.6, 0.5),
            (0.8, 0.5)]. NOTE: these differ from ``generate_circles_data``'s
            defaults; pass that dataset's circles explicitly to overlay them.
        circle_radii (List[float], optional): Radii of the circles to draw,
            aligned with ``circle_centres``. Defaults to [0.15, 0.2, 0.25, 0.3].
        **kwargs: Accepted and ignored.
    """
    # Build the defaults per call instead of sharing mutable default lists.
    if circle_centres is None:
        circle_centres = [
            (0.2, 0.5),
            (0.4, 0.5),
            (0.6, 0.5),
            (0.8, 0.5),
        ]
    if circle_radii is None:
        circle_radii = [0.15, 0.2, 0.25, 0.3]

    def _resolve(value):
        # Fetch values stored remotely as ray object references.
        return ray.get(value) if isinstance(value, ray.ObjectRef) else value

    num_cols = 4
    T = len(data_dict)
    # ceil(T / num_cols) rows, but at least one so that plt.subplots still
    # returns an array of axes when there is nothing to plot.
    num_rows = max(1, -(-T // num_cols))
    _, ax = plt.subplots(num_rows, num_cols, figsize=(5 * num_cols, 5 * num_rows))
    ax = ax.flatten()
    colors = {0: "yellow", 1: "purple"}
    color_map = np.vectorize(colors.get)

    for t in range(T):
        data = _resolve(data_dict[t])
        X = data["X_train"]
        y = data["y_train"]

        ax[t].scatter(X[:, 0], X[:, 1], c=color_map(y))
        ax[t].set_xlim(-0.5, 1)
        ax[t].set_ylim(-0.5, 1)
        if query_dict is not None:
            X_q = _resolve(query_dict[t])["X_query"]
        else:
            X_q = data["X_query"]
        ax[t].scatter(X_q[:, 0], X_q[:, 1], ec="red", c=None)
        ax[t].set_title(f"{t=}")
        for xy, r in zip(circle_centres, circle_radii):
            c = plt.Circle(xy, r, fill=False)
            ax[t].add_patch(c)

    # Hide the unused panels in the last row of the grid.
    for unused_ax in ax[T:]:
        unused_ax.set_axis_off()
    plt.show()


