import numpy as np
import random
from utils_toy import DATA_DIR

import matplotlib.pyplot as plt
import matplotlib.cm as cm

def data_generate(h, b, mu=2, std=1, seed=0, path=None):
    """Generate a 2-D toy dataset of ``h + b`` points and write it to a text file.

    Every coordinate is drawn from ``N(mu, std)`` except the ``h``-th point,
    which is chosen so that the first ``h`` points average to exactly ``mu``
    on each axis (up to float rounding). The remaining ``b`` random points
    follow. (Presumably ``h``/``b`` are "honest"/"byzantine" counts -- TODO
    confirm with callers.)

    Each output row is ``<x> <y>`` separated by a single space, one point per
    line, which is the format ``read_txt`` parses.

    Args:
        h: Number of points in the first group; must be >= 1.
        b: Number of trailing points; must be >= 0.
        mu: Mean of the sampling distribution and target mean of the first group.
        std: Standard deviation of the sampling distribution.
        seed: Combined with ``h + b`` to seed NumPy's global RNG.
        path: Output file. ``None`` (the default) writes to ``DATA_DIR``.

    Returns:
        tuple: Mean x and mean y of the first ``h`` points (also printed).

    Raises:
        ValueError: If ``h < 1`` or ``b < 0``.
    """
    if h < 1:
        # h == 0 would make xs[:h - 1] == xs[:-1] and divide by zero below.
        raise ValueError(f"h must be at least 1, got {h}")
    if b < 0:
        raise ValueError(f"b must be non-negative, got {b}")
    if path is None:
        path = DATA_DIR

    # Seeds NumPy's *global* RNG (not a local Generator) so the side effect on
    # np.random state is unchanged for any caller that relies on it.
    np.random.seed((h + b) * (seed + 1))
    xs = np.random.normal(mu, std, size=h + b - 1)
    ys = np.random.normal(mu, std, size=h + b - 1)
    xs_h, xs_b = xs[:h - 1], xs[h - 1:]
    ys_h, ys_b = ys[:h - 1], ys[h - 1:]

    # The h-th point is fixed so the h points of the first group average to mu.
    x = mu * h - sum(xs_h)
    y = mu * h - sum(ys_h)

    rows = [str(xi) + ' ' + str(yi) + '\n' for xi, yi in zip(xs_h, ys_h)]
    rows.append(str(x) + ' ' + str(y) + '\n')
    rows.extend(str(xi) + ' ' + str(yi) + '\n' for xi, yi in zip(xs_b, ys_b))
    with open(path, 'w', encoding='utf-8') as f:
        f.writelines(rows)

    mean_x = (sum(xs_h) + x) / h
    mean_y = (sum(ys_h) + y) / h
    print("toydata", mean_x, mean_y)
    return mean_x, mean_y


def read_txt(path):
    """Read a two-column, whitespace-separated text file of floats.

    Fields may be separated by any run of whitespace. Blank lines, such as a
    trailing empty line, are skipped.

    Args:
        path: File to read, e.g. one written by ``data_generate``.

    Returns:
        tuple[list[float], list[float]]: The first-column and second-column
        values, in file order.

    Raises:
        ValueError: If a non-blank line does not have exactly two fields, or a
            field is not a valid float.
    """
    xs, ys = [], []
    # UTF-8 to match the encoding the writers in this module use.
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            fields = line.split()  # any whitespace, not just a single space
            if not fields:
                continue  # tolerate blank / trailing lines
            x, y = fields
            xs.append(float(x))
            ys.append(float(y))
    return xs, ys

def save_txt(data, path):
    """Write each string in ``data`` to ``path`` as its own line (UTF-8).

    Any existing file at ``path`` is overwritten.

    Args:
        data: Iterable of strings. Items must already be ``str``; no
            conversion is applied.
        path: Destination file.

    Raises:
        TypeError: If an item is not a string. The file is still closed.
    """
    with open(path, 'w', encoding='utf-8') as f:
        f.writelines(item + '\n' for item in data)


class LossSurface:
    """A quadratic loss surface L(x, y) = a * (x - mu1)^2 + b * (y - mu2)^2.

    The surface is evaluated once, at construction, on a regular
    ``n_points x n_points`` grid covering
    ``[mu1 - below, mu1 + above] x [mu2 - below, mu2 + above]``.
    ``plot()`` draws it as a contour plot over the same window.

    Args:
        a: Coefficient of the x term.
        b: Coefficient of the y term.
        mu1: x-coordinate of the minimum.
        mu2: y-coordinate of the minimum.
        n_points: Grid samples per axis (default 1000).
        below: How far the window extends below each minimum coordinate.
        above: How far the window extends above each minimum coordinate.
    """

    def __init__(self, a, b, mu1, mu2, n_points=1000, below=2.5, above=2):
        self.a = a
        self.b = b
        self.mu1 = mu1
        self.mu2 = mu2
        # Grid and plot() share one (min, max) window so they cannot disagree.
        self.x_range = (mu1 - below, mu1 + above)
        self.y_range = (mu2 - below, mu2 + above)

        x_list = np.linspace(self.x_range[0], self.x_range[1], n_points)
        y_list = np.linspace(self.y_range[0], self.y_range[1], n_points)
        self.X, self.Y = np.meshgrid(x_list, y_list)
        self.Z = self.loss(self.X, self.Y)

    def loss(self, x, y):
        """Return L(x, y). Works for scalars or broadcastable NumPy arrays."""
        return self.a * ((x - self.mu1) ** 2) + self.b * ((y - self.mu2) ** 2)

    def plot(self):
        """Draw the surface as a labelled contour plot with 50 levels.

        The minimum (mu1, mu2) is marked with a large black dot overlaid by a
        small white one.

        Returns:
            tuple: The matplotlib ``(fig, ax)`` pair.
        """
        fig, ax = plt.subplots(figsize=(10, 10))
        cp = ax.contour(self.X, self.Y, self.Z, 50)
        ax.clabel(cp)
        # Two concentric markers at the minimum: large black, small white on top.
        ax.scatter([self.mu1] * 2, [self.mu2] * 2, s=[120, 20], color=['k', 'w'])

        ax.set_xlim(*self.x_range)
        ax.set_ylim(*self.y_range)

        ax.set_xlabel('x')
        ax.set_ylabel('y')

        return fig, ax


# if __name__ == '__main__':
#     a = 1 / 16
#     b = 9
#     mu1, mu2 = 0, 0  # LossSurface also requires the location of the minimum
#
#     loss_surface = LossSurface(a, b, mu1, mu2)
#     fig, ax = loss_surface.plot()
#     fig_name = 'loss_surface.png'
#     fig.savefig(fig_name)
#
#     print('{} saved.'.format(fig_name))

