import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from numpy.random import random_sample
import matplotlib.pyplot as plt
from scipy.optimize import minimize
import pickle as pkl
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

class Sphere2DSampler():
    """Sampler for points on the unit circle, parameterised by an angle ``t``.

    A parameter value ``t`` corresponds to the point ``(sin(t), cos(t))``.
    """

    def __init__(self):
        # No real state is required; the attribute is retained only so that
        # any existing code reading ``.x`` keeps working.
        self.x = 0

    def get_from_t(self, t):
        """Return the circle point ``(sin(t), cos(t))`` for parameter ``t``."""
        return (np.sin(t), np.cos(t))

    def get_n_samples(self, n=1000, sample_range=(-3.135, 3.135)):
        """Draw ``n`` parameters uniformly from ``sample_range`` and embed them.

        Args:
            n: Number of samples to draw.
            sample_range: ``(low, high)`` bounds for the parameter ``t``.

        Returns:
            A 4-tuple ``(xs, ys, t_samples, all_samples)``:
            ``xs`` / ``ys`` are lists of coordinates in draw order,
            ``t_samples`` is the array of drawn parameters (same order), and
            ``all_samples`` is the list of ``[x, y]`` pairs sorted
            lexicographically (so it is NOT aligned with the other three).
        """
        low = sample_range[0]
        high = sample_range[1]
        t_samples = (high - low) * random_sample(size=n) + low

        points = [self.get_from_t(t_sample) for t_sample in t_samples]
        xs = [x for x, _ in points]
        ys = [y for _, y in points]
        all_samples = sorted([x, y] for x, y in points)

        return xs, ys, t_samples, all_samples

    def get_distance(self, t1, t2):
        """Return the arc length between parameters ``t1`` and ``t2``.

        The distance is measured along the unit circle, so it wraps around:
        the result is the shorter of the two arcs joining the points and
        always lies in ``[0, pi]``.

        The previous implementation returned ``|t1 + t2|`` whenever the two
        parameters had opposite signs (e.g. distance 0 between 0.1 and -0.1).
        """
        full_turn = 2 * np.pi
        gap = np.abs(t1 - t2) % full_turn
        return np.minimum(gap, full_turn - gap)

    def get_interval_length(self, t1, t2):
        """Return the circumference ``2*pi``; both arguments are ignored."""
        return 2*np.pi

class Net(nn.Module):
    """Small fully connected regressor: 2 inputs -> 10 -> 16 -> 1 output."""

    def __init__(self, sample_mean, sample_std):
        """Build the three linear layers and remember the normalisation stats.

        Args:
            sample_mean: Mean used by ``standardize_data``.
            sample_std: Standard deviation used by ``standardize_data``.
        """
        super().__init__()
        # Output width of each layer, keyed by its index in ``self.layers``.
        self.layer_neurons = {0: 10, 1: 16, 2: 1}
        self.sample_mean = sample_mean
        self.sample_std = sample_std
        self.fc1 = nn.Linear(2, 10, bias=True)
        self.fc2 = nn.Linear(10, 16, bias=True)
        self.fcout = nn.Linear(16, 1, bias=True)
        # Ordered view of the layers; they are already registered as
        # sub-modules through the attributes above.
        self.layers = [self.fc1, self.fc2, self.fcout]

    def standardize_data(self, val_arr):
        """Return ``(val_arr - sample_mean) / sample_std``."""
        val_arr = (val_arr - self.sample_mean) / self.sample_std
        return val_arr

    def forward(self, x):
        """Apply fc1 and fc2 with ReLU, then the linear output layer."""
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fcout(x)
        return x

    def get_layerwise_activation(self, layer_num, x):
        """Return ``|pre-activation|`` of layer ``layer_num`` at parameter ``x``.

        ``x`` is a 1-element tensor holding the circle parameter ``t``. It is
        embedded as ``(sin t, cos t)`` and pushed through the layers before
        ``layer_num`` with ReLU; the selected layer's linear output is
        returned in absolute value, so its zeros mark where a neuron switches
        between its two linear pieces.

        Args:
            layer_num: Index into ``self.layers`` of the layer to inspect.
            x: Tensor of shape ``(1,)`` containing the parameter ``t``.

        Returns:
            1-D tensor with one entry per neuron of the selected layer.
        """
        # Use the same (sin, cos) embedding as Sphere2DSampler.get_from_t so
        # that parameter values found here are comparable with sampled ones.
        # NOTE(review): the previous code embedded t as (tanh t, 1/cosh t),
        # which does not match the training inputs -- confirm intent.
        x = torch.cat([torch.sin(x), torch.cos(x)])
        for i in range(layer_num + 1):
            if i == layer_num:
                x = torch.abs(self.layers[i](x))
            else:
                x = F.relu(self.layers[i](x))
        return x

    def num_flat_features(self, x):
        """Return the product of all dimensions of ``x`` except the first."""
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features


def duplicate_exists(list_to_check, value_to_check, boundary_thresh):
    """Tell whether ``value_to_check`` lies within ``boundary_thresh`` of any entry.

    Args:
        list_to_check: Iterable of previously recorded values.
        value_to_check: Candidate value.
        boundary_thresh: Two values closer than this (strictly) are duplicates.

    Returns:
        True if some entry is strictly closer than ``boundary_thresh``,
        False otherwise (including for an empty ``list_to_check``).
    """
    return any(
        np.abs(known - value_to_check) < boundary_thresh
        for known in list_to_check
    )


def safe_add_linear_boundary(linear_boundaries, layer_number, neuron_number, breakpoint_x):
    """Record ``breakpoint_x`` for a neuron unless a near-identical one exists.

    ``linear_boundaries`` is a nested dict ``{layer: {neuron: [breakpoints]}}``
    that is updated in place. A breakpoint within 0.01 of one already stored
    for the same neuron is treated as a duplicate and dropped.
    """
    per_layer = linear_boundaries.setdefault(layer_number, {})

    # First breakpoint seen for this neuron: start its list and stop.
    if neuron_number not in per_layer:
        per_layer[neuron_number] = [breakpoint_x]
        return

    recorded = per_layer[neuron_number]
    if duplicate_exists(recorded, breakpoint_x, 0.01):
        return
    recorded.append(breakpoint_x)


def count_linear_regions(model_path, net=None, sample_range = (-5, 5)):
    """Find parameter values where hidden neurons switch linear pieces.

    For every neuron of every layer except the last, the absolute
    pre-activation (``net.get_layerwise_activation``) is minimised over the
    parameter ``t`` with SLSQP, starting from several points on the positive
    side ``[0, sample_range[1]]`` and on the negative side
    ``[sample_range[0], 0]``. A minimum whose value is at most 0.001 is
    recorded as a breakpoint.

    Args:
        model_path: Path of a model saved with ``torch.save``; only read when
            ``net`` is None.
        net: Already-loaded network. It must expose ``layers``,
            ``layer_neurons`` and ``get_layerwise_activation``.
        sample_range: ``(low, high)`` search interval for ``t``; the search
            is split at 0.

    Returns:
        ``(count, boundary_values)`` where ``boundary_values`` is the sorted
        list of distinct breakpoints (merged within 1e-4) and ``count`` is
        its length.
    """
    # ``is None`` rather than truthiness: a module's truth value is not a
    # reliable "was it supplied" signal.
    if net is None:
        # NOTE(review): this unpickles a whole module. Only load trusted
        # files; recent PyTorch releases may additionally require
        # ``weights_only=False`` here -- confirm against the installed version.
        net = torch.load(model_path)

    zero_threshold = 0.001
    num_init_points = 10
    solver_options = {'eps': 0.000005, 'maxiter': 1000, 'disp': False}
    lower = sample_range[0]
    upper = sample_range[1]
    linear_boundaries = {}

    for layer_number in range(len(net.layers) - 1):
        for neuron_number in range(net.layer_neurons[layer_number]):

            # Bind the current indices as defaults so the objective does not
            # depend on the loop variables changing later.
            def fun_to_optimize(t, layer_number=layer_number, neuron_number=neuron_number):
                x = torch.from_numpy(np.array(t)).float()
                output = net.get_layerwise_activation(layer_number, x).detach().numpy()
                return output[neuron_number]

            for i in range(1, num_init_points):
                # Each side gets a start point inside its own bounds. The
                # negative start is scaled from ``lower`` (not mirrored from
                # ``upper``) so asymmetric ranges stay within bounds.
                searches = (
                    (upper * i / num_init_points, (0, upper)),
                    (lower * i / num_init_points, (lower, 0)),
                )
                for start_point, side_bounds in searches:
                    result = minimize(fun_to_optimize, np.array([start_point]), method='SLSQP',
                                      tol=1e-6, bounds=[side_bounds], options=solver_options)
                    if result.fun <= zero_threshold:
                        safe_add_linear_boundary(linear_boundaries, layer_number, neuron_number, result.x)

    # Flatten all per-neuron breakpoints (each stored as a 1-element array)
    # into one de-duplicated list.
    boundary_values = []
    for layer_boundaries in linear_boundaries.values():
        for neuron_boundaries in layer_boundaries.values():
            for neuron_boundary in neuron_boundaries:
                for boundary_value in neuron_boundary:
                    if not duplicate_exists(boundary_values, boundary_value, 0.0001):
                        boundary_values.append(boundary_value)

    boundary_values.sort()
    return len(boundary_values), boundary_values

def save_new_network(model_path):
    """Create a freshly initialised ``Net`` and save it to ``model_path``.

    The network's normalisation statistics are the mean and standard
    deviation over all (x, y) coordinates of 3000 circle samples drawn with
    parameter range (-5, 5).

    Args:
        model_path: Destination passed to ``torch.save``.
    """
    data_sampler = Sphere2DSampler()
    # get_n_samples returns several parallel collections; taking the mean of
    # that ragged tuple directly (as before) is not a meaningful statistic
    # and fails on recent NumPy. Use the coordinates only.
    xs, ys, *_ = data_sampler.get_n_samples(3000, (-5, 5))
    coordinates = np.column_stack((xs, ys))
    # NOTE(review): scalar statistics over both coordinates are assumed to be
    # what was intended -- confirm whether per-axis values are wanted.
    sample_mean = float(np.mean(coordinates))
    sample_std = float(np.std(coordinates))

    net = Net(sample_mean, sample_std)
    torch.save(net, model_path)

def generate_training_data(data_path, num_samples=200, periodic_freq=1.0, noise_scale = 0.25,
                           sample_range=(-5, 5)):
    """Sample circle points, attach noisy targets and pickle them.

    The target for parameter ``t`` is
    ``sin(t * pi / periodic_freq) * (1 + noise)`` with Gaussian ``noise`` of
    standard deviation ``noise_scale``.

    Args:
        data_path: File the array is pickled to.
        num_samples: Number of points to generate.
        periodic_freq: Divisor applied to ``t * pi`` inside the sine.
        noise_scale: Standard deviation of the multiplicative noise.
        sample_range: ``(low, high)`` range for the parameter ``t``.

    The pickled object is a ``(3, num_samples)`` array whose rows are the x
    coordinates, the y coordinates and the target values.
    """
    sphere_2d_data_sampler = Sphere2DSampler()
    # The sampler returns more than three collections; ignore any extras
    # instead of failing on a fixed-size unpack.
    xs, ys, ts, *_ = sphere_2d_data_sampler.get_n_samples(num_samples, sample_range=sample_range)
    noise_vals = np.random.normal(scale=noise_scale, size=num_samples)
    f_vals = np.sin(np.asarray(ts) * np.pi / periodic_freq) * (1 + noise_vals)

    data_array = np.stack((np.asarray(xs), np.asarray(ys), f_vals), axis=0)
    with open(data_path, 'wb') as data_f_out:
        pkl.dump(data_array, data_f_out)

def count_parameters(model):
    """Return the total element count of ``model``'s trainable parameters.

    Parameters whose ``requires_grad`` flag is False are skipped.
    """
    trainable_total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        trainable_total += param.numel()
    return trainable_total

def get_avg_distance(linear_boundaries, sample_range= (-5, 5), num_samples = 300):
    """Return the mean distance from random parameters to the nearest boundary.

    Args:
        linear_boundaries: Iterable of boundary parameter values.
        sample_range: ``(low, high)`` range the random parameters are drawn from.
        num_samples: Number of random parameters to draw.

    Returns:
        Mean over the samples of the smallest
        ``Sphere2DSampler.get_distance`` to any boundary. With no boundaries
        every sample's distance is ``inf``, so the mean is ``inf``.
    """
    sphere_2d_sampler = Sphere2DSampler()
    # The sampler returns more than three collections; only the parameter
    # values are needed here.
    _, _, sample_ts, *_ = sphere_2d_sampler.get_n_samples(num_samples, sample_range=sample_range)
    min_dists = [
        min(
            (sphere_2d_sampler.get_distance(sample_t, linear_boundary)
             for linear_boundary in linear_boundaries),
            default=np.inf,
        )
        for sample_t in sample_ts
    ]
    return np.mean(np.array(min_dists))



def train_model(data_path, model_path, num_epochs = 60, sample_range = (-5, 5), run_number=0):
    """Train the saved network and log linear-region statistics.

    The pickled data at ``data_path`` is expected to be indexable as
    ``[xs, ys, targets]``. The first 80% of the points are used for training
    and the rest for testing. Every fifth epoch the number of linear-region
    boundaries and the normalised average distance to them are recorded; the
    collected series are pickled to ``data/sphere_<run_number>.pkl``.

    Args:
        data_path: Pickle file holding the training data.
        model_path: Model file written by ``torch.save``.
        num_epochs: Number of passes over the training set.
        sample_range: Parameter range forwarded to the region counting and
            distance helpers.
        run_number: Tag used in the log output and the result file name.
    """
    # NOTE(review): both loads below unpickle arbitrary objects, so only use
    # trusted files. Recent PyTorch releases may also require
    # ``weights_only=False`` for whole-module checkpoints -- confirm against
    # the installed version.
    net = torch.load(model_path)

    with open(data_path, 'rb') as f_in:
        all_data = pkl.load(f_in)
    fun_vals = np.reshape(all_data[2], (all_data[2].shape[0], 1))
    stacked_input = np.column_stack((all_data[0], all_data[1]))
    frac_train = int(0.8*stacked_input.shape[0])
    frac_test = stacked_input.shape[0] - frac_train
    print(frac_train, frac_test)
    # Slice from ``frac_train`` onwards; ``[-frac_test:]`` would return the
    # whole array when ``frac_test`` is 0.
    train_inputs = stacked_input[:frac_train]
    test_inputs = stacked_input[frac_train:]
    print(train_inputs.shape, test_inputs.shape)
    train_outputs = fun_vals[:frac_train]
    test_outputs = fun_vals[frac_train:]
    criterion = nn.MSELoss()
    optimizer = optim.Adam(net.parameters(), lr=0.001)
    trainset = torch.utils.data.TensorDataset(torch.Tensor(train_inputs), torch.Tensor(train_outputs))
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=8, shuffle=True)
    testset = torch.utils.data.TensorDataset(torch.Tensor(test_inputs), torch.Tensor(test_outputs))
    testloader = torch.utils.data.DataLoader(testset, batch_size=8,
                                             shuffle=False)
    num_linear_regions = []
    linear_region_epochs = []
    average_distances = []
    total_losses = []
    max_distance = Sphere2DSampler().get_interval_length(sample_range[0], sample_range[1])
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        num_batches = 0

        for inputs, real_fun_vals in trainloader:
            optimizer.zero_grad()

            outputs = net(inputs)
            loss = criterion(outputs, real_fun_vals)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            num_batches += 1
        # Average over the number of batches. The last batch index was used
        # before, which is one too small and is zero for a single batch.
        print('[%d] train loss: %.3f' %
              (epoch + 1, epoch_loss / max(num_batches, 1)))

        with torch.no_grad():
            total_loss = 0.0
            num_passes = 0
            for inputs, real_fun_vals in testloader:
                outputs = net(inputs)
                loss = criterion(real_fun_vals, outputs)
                total_loss += loss.item()
                num_passes += 1
            print('[%d, %d] test loss: %.3f' % (run_number, epoch + 1, total_loss / max(num_passes, 1)))

        if epoch % 5 == 0:
            num_regions, linear_boundaries = count_linear_regions('', net, sample_range=sample_range)
            average_distance = get_avg_distance(linear_boundaries, sample_range=sample_range)
            average_distances.append(average_distance/max_distance)
            num_linear_regions.append(num_regions)
            # The stored value is the summed (not batch-averaged) test loss.
            total_losses.append(total_loss)
            linear_region_epochs.append(epoch)

    # Optional plots, currently disabled:
    # plt.ylabel("Num Linear Regions Sphere")
    # plt.plot(linear_region_epochs, num_linear_regions)
    # plt.show()
    #
    # plt.ylabel("Average distance Sphere")
    # plt.plot(linear_region_epochs, average_distances)
    # plt.show()
    data_arr = [linear_region_epochs, num_linear_regions, average_distances, total_losses]
    with open("data/sphere_" + str(run_number) + ".pkl", "wb") as f_out:
        pkl.dump(data_arr, f_out)

def get_mean_and_variance():
    """Print the mean and standard deviation of sampled x and y coordinates.

    Draws 3000 circle samples with parameter range (-5, 5) and prints, in
    order: mean of x, mean of y, standard deviation of x, standard deviation
    of y.
    """
    sphere_sampler = Sphere2DSampler()
    # The sampler returns more than three collections; only the coordinates
    # are needed, so ignore the rest instead of failing on the unpack.
    xs, ys, *_ = sphere_sampler.get_n_samples(3000, sample_range=(-5, 5))
    print(np.mean(xs), np.mean(ys), np.std(xs), np.std(ys))

from scipy.interpolate import make_interp_spline, BSpline


if __name__ == '__main__':
    # Fix the NumPy seed so sampled data is reproducible between runs.
    np.random.seed(12)

    # The experiment driver below is disabled; it is kept as an inert string.
    """

    sample_range = (-3.135, 3.135)
    num_runs = 20
    for i in range(num_runs):
        generate_training_data('data/training_data_sphere.pkl', num_samples=1000, periodic_freq=1.0,
                               sample_range=sample_range)
        save_new_network('data/2d_model_sphere.torch')
        train_model('data/training_data_sphere.pkl', 'data/2d_model_sphere.torch', num_epochs=300, sample_range=sample_range, run_number=i)"""

    # Draw a filled unit disc centred at the origin and save it as an image.
    figure, axes = plt.subplots()
    # ``fill`` is a boolean flag: the former value "blue" only acted as a
    # truthy value and never selected a colour. Pass ``facecolor="blue"`` as
    # well if a pure-blue disc is wanted.
    Drawing_uncolored_circle = plt.Circle((0, 0),
                                          1.0,
                                          fill=True,
                                          linewidth=0)
    #plt.grid()

    # Equal aspect ratio so the disc is not rendered as an ellipse.
    axes.set_aspect(1)
    axes.add_artist(Drawing_uncolored_circle)
    plt.xlim(-1.5, 1.5)
    plt.ylim(-1.5, 1.5)
    #plt.title('Circle')
    plt.savefig('filled-disc.png')



