import io
import numpy as np

import PIL.Image
import matplotlib.pyplot as plt
from pybnn.bohamiann import Bohamiann
from torchvision.transforms import ToTensor
from sklearn.metrics import mean_squared_error
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import cross_val_score

# Module-level state shared by the helpers in this file. In the code visible
# here only `writer` and `batch_idx` are actually read (by add_scalar(s));
# the remaining names are presumably set/used elsewhere -- TODO confirm.
name_model = ""
log = None
working_dir = None
batch_idx = 0  # step index passed to writer.add_scalar / add_scalars
writer = None  # must be assigned a writer object before add_scalar(s) is called
is_training = True  # only referenced by the disabled code in debug_tensor here
current_best = 1
model_cpu = None

def create_scatter_plot(vtrue, vpred, title=""):
    """Render a true-vs-predicted scatter plot and return it as an image tensor.

    Both series are plotted after the transform ``1 - value`` (presumably to
    turn error-like values into score-like ones -- TODO confirm). Both axes
    are fixed to [0, 1], and a black y = x reference line is drawn so that
    points on the diagonal correspond to perfect predictions.

    Args:
        vtrue: sequence of true values.
        vpred: sequence of predicted values, same length as ``vtrue``.
        title: optional plot title.

    Returns:
        The plot rendered to JPEG and converted with torchvision's
        ``ToTensor`` (a float tensor of shape (C, H, W) scaled to [0, 1]).
    """
    fig = plt.figure()
    try:
        plt.plot(1 - np.array(vtrue), 1 - np.array(vpred), 'o')
        # Reference diagonal from (0, 0) to (1, 1). The previous coordinates
        # ([0, 0], [1, 1]) collapsed to the single point (0, 1) and drew nothing.
        plt.plot([0, 1], [0, 1], color='k', linestyle='-', linewidth=2)

        plt.xlabel("True")
        plt.ylabel("Pred")
        plt.xlim((0, 1))
        plt.ylim((0, 1))
        plt.title(title)

        buf = io.BytesIO()
        fig.savefig(buf, format='jpeg')
    finally:
        # pyplot keeps every figure alive until it is explicitly closed, so
        # calling this repeatedly (e.g. once per logging step) would leak memory.
        plt.close(fig)

    buf.seek(0)
    image = PIL.Image.open(buf)
    return ToTensor()(image)


def add_scalars(name, v_dida, v_baseline):
    """Log a "dida" value and a "baseline" value together under tag ``name``.

    Both series are written at the current module-level ``batch_idx`` step via
    ``writer.add_scalars``; ``writer`` must have been assigned beforehand.
    """
    series = dict(dida=v_dida, baseline=v_baseline)
    step = batch_idx
    writer.add_scalars(name, series, step)

def add_scalar(name, v_dida):
    """Log one "dida" value under tag ``name`` at the current ``batch_idx`` step.

    Delegates to ``writer.add_scalar``; ``writer`` must have been assigned
    beforehand.
    """
    log_fn = writer.add_scalar
    log_fn(name, v_dida, batch_idx)

def debug_tensor(name, t):
    """Debug hook for tensor statistics; currently disabled and does nothing.

    The disabled logic below, if restored, would log ``t.max()`` / ``t.min()``
    as ``"<name>/max"`` / ``"<name>/min"`` scalars and a histogram of ``t``
    under ``name`` at step ``batch_idx``, only while ``is_training`` is set.
    """
    # Disabled logging, kept for reference:
    # if is_training:
    #     writer.add_scalar("{}/max".format(name), t.max(), batch_idx)
    #     writer.add_scalar("{}/min".format(name), t.min(), batch_idx)
    #     writer.add_histogram("{}".format(name), t, batch_idx)
    return None

def compute_score_rf(X_train, y_train, X_test, y_test, return_pred=False,
                     return_model=False, *, random_state=None):
    """Fit a default random forest regressor and score it by test-set MSE.

    Args:
        X_train, y_train: training features and targets.
        X_test, y_test: held-out features and targets used for scoring.
        return_pred: if True, also return the test-set predictions.
        return_model: if True, also return the fitted model.
        random_state: forwarded to ``RandomForestRegressor``. Pass an int to
            make scores reproducible; the default ``None`` keeps the previous
            (non-deterministic) behaviour.

    Returns:
        A list ``[mse]``, followed by ``y_pred`` if ``return_pred`` and then
        the model if ``return_model`` (always in that order).
    """
    # Otherwise-default hyper-parameters; n_jobs=-1 uses all available cores.
    model = RandomForestRegressor(n_jobs=-1, random_state=random_state)
    model.fit(X_train, y_train)
    y_pred = model.predict(X_test)

    results = [mean_squared_error(y_true=y_test, y_pred=y_pred)]
    if return_pred:
        results.append(y_pred)
    if return_model:
        results.append(model)
    return results


def compute_score_bohamiann(X_train, y_train, X_test, y_test, return_pred=False,
                            return_model=False, *, num_steps=20000,
                            num_burn_in_steps=2000, keep_every=50, lr=1e-2,
                            verbose=True):
    """Train a pybnn ``Bohamiann`` model and score it by test-set MSE.

    Input and output normalisation are disabled on the model. The training
    schedule is exposed as keyword-only arguments whose defaults match the
    previously hard-coded values, so existing callers are unaffected.

    Args:
        X_train, y_train: training features and targets.
        X_test, y_test: held-out features and targets used for scoring.
        return_pred: if True, also return the test-set predictions.
        return_model: if True, also return the trained model.
        num_steps, num_burn_in_steps, keep_every, lr, verbose: forwarded
            unchanged to ``Bohamiann.train``.

    Returns:
        A list ``[mse]``, followed by ``y_pred`` if ``return_pred`` and then
        the model if ``return_model`` (always in that order).
    """
    model = Bohamiann(print_every_n_steps=1000, normalize_output=False,
                      normalize_input=False)
    model.train(X_train, y_train, num_steps=num_steps,
                num_burn_in_steps=num_burn_in_steps, keep_every=keep_every,
                lr=lr, verbose=verbose)

    # predict() returns a pair; only the first element is scored (presumably
    # the predictive mean, with the discarded one the variance -- TODO confirm).
    y_pred, _ = model.predict(X_test)

    results = [mean_squared_error(y_true=y_test, y_pred=y_pred)]
    if return_pred:
        results.append(y_pred)
    if return_model:
        results.append(model)
    return results
