import yaml

import numpy as np
import GPy
from GPy.kern.src.stationary import Exponential
from sklearn.datasets import make_friedman1, load_diabetes, fetch_california_housing
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split

from model import GP
from utils import read_otherwise_generate

# Load the experiment configuration once, at import time. The path is relative,
# so it is resolved against the current working directory. The encoding is
# pinned so that parsing does not depend on the platform's locale default.
with open('configs.yaml', 'r', encoding='utf-8') as f:
    configs = yaml.safe_load(f)

def get_data_and_model(dataset, scenario):
    """Build the GP model and the partitioned train/test data for one experiment.

    The dataset is loaded and transformed, a GP is restored from its dict
    form, a per-sample noise array is derived from the GP's noise, 20% of the
    samples are held out as a test set, and the remaining training samples are
    cut into consecutive chunks whose sizes come from
    ``configs['partition'][dataset][scenario]``.

    Returns:
        A tuple ``(gp, X_splits, y_splits, noise_splits, X_test, y_test,
        noise_test)``.
    """
    raw_X, raw_y = get_dataset(dataset)
    X, y, scaler_y = transform_dataset(dataset, raw_X, raw_y)

    gp = GP.from_dict(get_model(dataset, X=X, y=y, scaler_y=scaler_y))
    # The GP's noise is passed as ``after_std``, so it is used as-is rather
    # than being rescaled by the target scaler.
    noise = get_homoscedastic_noise(y, scaler_y, after_std=gp.noise)

    # A fixed, dataset-specific seed keeps the train/test split reproducible.
    seed = get_split_random_state(dataset)
    (
        X_train, X_test,
        y_train, y_test,
        noise_train, noise_test,
    ) = train_test_split(X, y, noise, test_size=0.2, random_state=seed)

    chunk_sizes = configs['partition'][dataset][scenario]
    X_splits, y_splits, noise_splits = split_lists(
        X_train, y_train, noise_train, sizes=chunk_sizes
    )
    return gp, X_splits, y_splits, noise_splits, X_test, y_test, noise_test

def get_dataset(dataset):
    """Return the data for ``dataset`` via ``read_otherwise_generate``.

    ``generate_dataset`` is handed over as the generator callback under the
    ``'dataset'`` category.
    """
    data = read_otherwise_generate('dataset', generate_dataset, dataset)
    return data

def get_model(dataset, **kwargs):
    """Return the dict form of the model for ``dataset`` via ``read_otherwise_generate``.

    ``generate_model`` is handed over as the generator callback under the
    ``'model'`` category; any keyword arguments are forwarded unchanged.
    """
    return read_otherwise_generate('model', generate_model, dataset, **kwargs)

def generate_dataset(dataset):
    """Create the raw ``(X, y)`` arrays for a named dataset.

    Args:
        dataset: One of ``'friedman'``, ``'diap'`` or ``'calih'``.

    Returns:
        A tuple ``(X, y)`` where ``y`` has been reshaped into a single column.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    if dataset == 'friedman':
        # NOTE: no random_state is passed, so every generation draws a new sample.
        X, y = make_friedman1(n_samples=1000, n_features=6, noise=1)
    elif dataset == 'diap':
        X, y = load_diabetes(return_X_y=True)
        # Drop the feature column at index 1
        # (presumably the 'sex' feature of the diabetes data -- TODO confirm).
        X = np.delete(X, [1], axis=1)
    elif dataset == 'calih':
        X, y = fetch_california_housing(return_X_y=True)
    else:
        # Fail loudly instead of implicitly returning None, which would only
        # surface later as an unpacking error in the caller.
        raise ValueError(f"Unknown dataset: {dataset!r}")
    return X, y.reshape(-1, 1)

def generate_model(dataset, X, y, scaler_y):
    """Construct the GP for a named dataset and return it in dict form.

    Args:
        dataset: One of ``'friedman'``, ``'diap'`` or ``'calih'``.
        X: Feature array passed to the ``GP`` constructor.
        y: Target array passed to the ``GP`` constructor.
        scaler_y: Fitted target scaler; only used for ``'friedman'`` to
            derive the noise value.

    Returns:
        The result of ``GP.to_dict()`` for the constructed model.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    if dataset == 'friedman':
        # get_homoscedastic_noise returns a constant array shaped like y;
        # its first row is used as the model's noise.
        noise = get_homoscedastic_noise(y, scaler_y, before_std=1)[0]
        gp = GP(X, y, noise=noise)
    elif dataset == 'diap':
        num_features = X.shape[1]
        # Sum of an ARD Exponential kernel and an ARD RBF kernel.
        kernel_diap = (
            Exponential(input_dim=num_features, ARD=True)
            + GPy.kern.RBF(input_dim=num_features, ARD=True)
        )
        gp = GP(X, y, kernel=kernel_diap)
    elif dataset == 'calih':
        # Only the first 2000 samples are used to build this model.
        gp = GP(X[:2000], y[:2000])
    else:
        # Fail loudly instead of implicitly returning None.
        raise ValueError(f"Unknown dataset: {dataset!r}")
    return gp.to_dict()

def transform_dataset(dataset, X, y):
    """Scale features and targets according to the dataset's convention.

    Targets are always standardized (zero mean, unit variance). Features are
    standardized only for ``'calih'``; for ``'friedman'`` and ``'diap'`` the
    feature scaler is configured with centering and scaling both disabled.

    Args:
        dataset: One of ``'friedman'``, ``'diap'`` or ``'calih'``.
        X: Feature array.
        y: Target array.

    Returns:
        A tuple ``(X, y, scaler_y)`` with the transformed arrays and the
        fitted target scaler.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    # Use a membership test: ``dataset == 'friedman' or 'diap'`` is always
    # truthy (the non-empty string 'diap' is truthy), which would make the
    # 'calih' case unreachable.
    if dataset in ('friedman', 'diap'):
        standardize_X = False
    elif dataset == 'calih':
        standardize_X = True
    else:
        raise ValueError(f"Unknown dataset: {dataset!r}")

    scaler_X = StandardScaler(with_mean=standardize_X, with_std=standardize_X).fit(X)
    scaler_y = StandardScaler(with_mean=True, with_std=True).fit(y)
    return scaler_X.transform(X), scaler_y.transform(y), scaler_y

def get_homoscedastic_noise(y, scaler_y, before_std=1, after_std=None):
    """Return a constant noise array with the same shape as ``y``.

    Args:
        y: Array whose shape the result copies.
        scaler_y: Object exposing ``var_``; only read when ``after_std`` is
            ``None``.
        before_std: Noise value that is divided by ``scaler_y.var_`` before
            use (presumably the noise variance before the targets were
            standardized -- TODO confirm naming).
        after_std: Noise value used as-is. When given (not ``None``) it takes
            precedence over ``before_std``.

    Returns:
        ``np.full(y.shape, value)`` for the selected noise value.
    """
    # Compare against None rather than relying on truthiness: an explicit 0
    # must be honoured, and a multi-element array has no single truth value.
    if after_std is not None:
        variance = after_std
    else:
        variance = before_std / scaler_y.var_
    return np.full(y.shape, variance)

def get_split_random_state(dataset):
    """Return the fixed train/test split seed for ``dataset``.

    Raises:
        KeyError: If ``dataset`` has no registered seed.
    """
    seeds = dict(friedman=12, diap=6, calih=18)
    return seeds[dataset]

def split_lists(*lists, sizes=()):
    """Cut several sequences into consecutive chunks of the given sizes.

    Every sequence is sliced at the same boundaries, so chunk ``j`` of each
    result refers to the same index range. Elements beyond ``sum(sizes)`` are
    left out.

    Args:
        *lists: Sliceable sequences to split. Only the length of the first
            one is checked against ``sizes``.
        sizes: Chunk lengths, in order.

    Returns:
        A list with one entry per input sequence; each entry is the list of
        that sequence's chunks. An empty list when no sequences are given.

    Raises:
        ValueError: If ``sum(sizes)`` exceeds the length of the first sequence.
    """
    if not lists:
        return []

    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    total = sum(sizes)
    if total > len(lists[0]):
        raise ValueError(
            f"sum of sizes ({total}) exceeds the number of available "
            f"elements ({len(lists[0])})"
        )

    # Pre-compute the shared (start, end) boundaries once.
    bounds = []
    start_idx = 0
    for size in sizes:
        end_idx = start_idx + size
        bounds.append((start_idx, end_idx))
        start_idx = end_idx

    return [[seq[lo:hi] for lo, hi in bounds] for seq in lists]
