
import ray
import sys
from pathlib import Path
from typing import Dict, List, Optional

import matplotlib.pyplot as plt
import numpy as np


# Resolve the project root as the directory two levels above the one containing
# this file (i.e. ``Path(__file__).parent.parent.parent``), made absolute.
project_root = Path(__file__).parent.parent.parent.resolve()
# Side effect at import time: append ``<project_root>/src`` to ``sys.path`` so that
# modules living under ``src`` can be imported. It is appended (lowest priority),
# so same-named modules found earlier on ``sys.path`` still win.
sys.path.append(str((project_root/"src").resolve()))


def generate_gauss_data(seed: int, size_datasets: int, T: int, alpha: float = 1.0, std: float = 0.1):
    """Sample ``T`` batches of 2-D Gaussian points with a drifting centre, then label them.

    Batch ``t`` (0-based) is drawn around the centre ``(d, 0.5 - d)`` with
    ``d = (t + 1) / 100``. Successive batches therefore move by 0.01 along the
    direction ``(+1, -1)``, producing a gradual covariate shift. All batches come
    from one generator seeded with ``seed``, so the output is reproducible.

    Args:
        seed: Seed passed to ``np.random.default_rng``.
        size_datasets: Number of points per batch.
        T: Number of batches.
        alpha: Forwarded to ``classify_data`` as a keyword argument.
        std: Standard deviation of the Gaussian noise, applied to both coordinates.

    Returns:
        Tuple ``(X, y)``: ``X`` is a list of ``T`` arrays of shape
        ``(size_datasets, 2)``, and ``y`` is what ``classify_data`` returns for ``X``.
    """
    generator = np.random.default_rng(seed=seed)

    def _sample_batch(step: int) -> np.ndarray:
        # The centre slides linearly with the batch index.
        drift = (step + 1) / 100
        return generator.normal(loc=[drift, 0.5 - drift], scale=std, size=(size_datasets, 2))

    # Batches are sampled sequentially, so the RNG stream is consumed in time order.
    batches = [_sample_batch(step) for step in range(T)]
    labels = classify_data(batches, alpha=alpha)
    return batches, labels


def classify_data(X: List[np.ndarray], **kwargs):
    """Label each batch of points against the parabola ``4 * (x0 - 0.5) ** 2``.

    A point ``(x0, x1)`` gets label 1 when ``4 * (x0 - 0.5) ** 2 > x1``, i.e. it lies
    strictly below the parabola. Otherwise it gets 0, including points that lie
    exactly on the curve.

    Args:
        X: Sequence of batches. Each batch is a 2-D array (or nested list) with at
            least two columns; only columns 0 and 1 are used.
        **kwargs: Accepted for call-site compatibility (e.g. ``alpha``) but ignored.

    Returns:
        List[np.ndarray]: One integer 0/1 label array per batch, in input order.
    """
    # np.asarray is a no-op for ndarrays and lets plain nested lists work too.
    batches = (np.asarray(batch) for batch in X)
    return [(4 * (pts[:, 0] - 0.5) ** 2 > pts[:, 1]).astype(int) for pts in batches]


def _resolve_ref(obj):
    """Return ``obj`` itself, or its value fetched with ``ray.get`` if it is a Ray ObjectRef."""
    return ray.get(obj) if isinstance(obj, ray.ObjectRef) else obj


def plot_data(data_dict: Dict[int, dict], query_dict: Optional[Dict[int, dict]] = None, alpha: float = 1, **kwargs):
    """Plots the data assuming a gauss concept.

    Draws one subplot per batch on a four-column grid. Each subplot contains:
      - the training points, coloured by label (0 -> yellow, 1 -> purple);
      - the boundary ``y = 4 * (x - 0.5) ** 2`` as a dashed blue line;
      - the query points as hollow red circles.
    Axes are clipped to ``[-0.5, 1]`` in both directions, and the figure is shown
    with ``plt.show()``.

    Args:
        data_dict (Dict[int,dict]): the data dictionary. Keys are batch indices and
            must be ``0 .. len(data_dict) - 1``. Each value is a dict with
            ``"X_train"`` (2-D array; columns 0 and 1 are plotted) and
            ``"y_train"`` (labels in {0, 1}), or a Ray ObjectRef resolving to one.
        query_dict (Optional[Dict[int,dict]], optional): Query dictionary to plot
            query points. Each value holds ``"X_query"`` (or is an ObjectRef to
            such a dict). If None, ``"X_query"`` is read from the matching
            ``data_dict`` entry instead. Defaults to None.
        alpha (float, optional): Currently unused; the boundary is hard-coded.
            Defaults to 1.
        **kwargs: Currently unused.
    """
    ncols = 4
    T = len(data_dict)
    # ceil(T / ncols) rows, but always at least one so the grid is valid for T == 0.
    num_rows = max(1, (T + ncols - 1) // ncols)
    _, ax = plt.subplots(num_rows, ncols, figsize=(5 * ncols, 5 * num_rows), squeeze=False)
    ax = ax.flatten()
    x = np.linspace(-0.5, 1)
    line = 4 * (x - 0.5) ** 2
    colors = {0: "yellow", 1: "purple"}
    for t in range(T):
        data = _resolve_ref(data_dict[t])
        X = data["X_train"]
        # A comprehension rather than np.vectorize, which raises on empty input.
        point_colors = [colors.get(label) for label in np.ravel(data["y_train"]).tolist()]
        ax[t].scatter(X[:, 0], X[:, 1], c=point_colors)
        ax[t].plot(x, line, "b--")
        ax[t].set_xlim(-0.5, 1)
        ax[t].set_ylim(-0.5, 1)
        if query_dict is None:
            X_q = data["X_query"]
        else:
            X_q = _resolve_ref(query_dict[t])["X_query"]
        # Hollow markers keep the label colour of the underlying training point visible.
        ax[t].scatter(X_q[:, 0], X_q[:, 1], facecolors="none", edgecolors="red")
        ax[t].set_title(f"{t=}")
    # Hide grid cells beyond the last batch instead of leaving empty frames.
    for unused in ax[T:]:
        unused.set_axis_off()
    plt.show()




