import torch.nn as nn
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset
import pandas as pd
import os
import sys
import numpy as np
from scipy.optimize import minimize
import io
import pickle
import sklearn
import math
from PIL import Image

from os import listdir
from os.path import isfile, join
import datetime


# Run on the first GPU when CUDA is available, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
if use_cuda:
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

class Net(nn.Module):
    """Small fully connected classifier for 28x28 RGB inputs.

    Architecture: flatten (28*28*3) -> Linear(., 128) -> ReLU
    -> Linear(128, 64) -> ReLU -> Linear(64, final_classes).
    ``forward`` returns the raw output of ``final``; ``sigm`` is created
    but not used by ``forward``.
    """

    def __init__(self):
        super().__init__()
        self.final_classes = 2
        flat_features = 28 * 28 * 3
        # Keep this creation order: it fixes parameter init under a seed and
        # the order of entries in the state dict.
        self.linear1 = nn.Linear(flat_features, 128)
        self.linear2 = nn.Linear(128, 64)
        self.final = nn.Linear(64, self.final_classes)
        self.relu = nn.ReLU()
        self.sigm = nn.Sigmoid()

    def forward(self, img):
        """Flatten ``img`` to rows of 28*28*3 values and return the output-layer scores."""
        hidden = img.view(-1, 28 * 28 * 3)
        for hidden_layer in (self.linear1, self.linear2):
            hidden = self.relu(hidden_layer(hidden))
        return self.final(hidden)


class ImagesCurve:
    """Piecewise-linear path through a sequence of images.

    The curve is parameterised by ``t`` in ``[0, num_images - 1]``. Integer
    endpoints return the stored image object itself; interior values blend
    ``images[floor(t)]`` and ``images[floor(t) + 1]`` linearly.
    """

    def __init__(self, images):
        self.num_images = len(images)
        self.images = images

    def _interpolate_array(self, t):
        """Return the blended uint8 pixel array at interior position ``t``.

        Both neighbours are converted to float64 before blending. PIL images
        do not support arithmetic, and integer arrays would overflow. The
        blend is rounded and clipped to [0, 255] because the result is
        emitted as 8-bit RGB.
        """
        t_floor = int(t)
        t_ceil = t_floor + 1
        img1 = np.asarray(self.images[t_floor], dtype=np.float64)
        img2 = np.asarray(self.images[t_ceil], dtype=np.float64)
        blended = img1 * (t_ceil - t) + img2 * (t - t_floor)
        return np.clip(np.rint(blended), 0, 255).astype(np.uint8)

    def get_from_t(self, t):
        """Return the image at curve position ``t``.

        ``t`` may be a scalar or a 1-element array (SciPy optimisers pass the
        latter). Interior positions return a new RGB ``Image``; ``t == 0`` and
        ``t == num_images - 1`` return the stored endpoint images.

        Raises:
            ValueError: if ``t`` lies outside ``[0, num_images - 1]``.
        """
        # Normalise to a Python float; int() on an ndim>0 array is deprecated in NumPy.
        t = float(np.asarray(t, dtype=np.float64).item())
        last = self.num_images - 1
        if 0 < t < last:
            return Image.fromarray(self._interpolate_array(t), 'RGB')
        if t == 0:
            return self.images[0]
        if t == last:
            return self.images[last]
        raise ValueError("Value of t should be in between 0 and %s." % (self.num_images - 1))

    def get_total_distance(self):
        """Return the summed Euclidean (L2) distance between consecutive images."""
        d = 0
        for t in range(self.num_images - 1):
            # float64 avoids uint8 wrap-around on subtraction and works for PIL images.
            diff = (np.asarray(self.images[t], dtype=np.float64)
                    - np.asarray(self.images[t + 1], dtype=np.float64))
            d += np.sqrt(np.sum(diff ** 2))
        return d

def duplicate_exists(list_to_check, value_to_check, boundary_thresh):
    found_duplicate = False
    for existing_x in list_to_check:
        if np.abs(existing_x - value_to_check) < boundary_thresh:
            return True
    return found_duplicate


def safe_add_linear_boundary(linear_boundaries, layer_number, neuron_number, breakpoint_x):
    """Record ``breakpoint_x`` under ``linear_boundaries[layer_number][neuron_number]``.

    Missing layer / neuron entries are created on demand. For a neuron that
    already has breakpoints, the value is appended only if no existing one is
    within 0.001 of it. ``linear_boundaries`` is mutated in place.
    """
    neurons = linear_boundaries.setdefault(layer_number, {})
    existing = neurons.get(neuron_number)
    if existing is None:
        neurons[neuron_number] = [breakpoint_x]
    elif not duplicate_exists(existing, breakpoint_x, 0.001):
        existing.append(breakpoint_x)


def count_linear_regions(model_path, net, images_curve, num_init_points=100, zero_threshold=0.001):
    """Count distinct points along ``images_curve`` where a hidden ReLU neuron reaches zero.

    For every neuron of the hidden layers (``linear1`` and ``linear2``), the
    absolute pre-activation is minimised over the curve parameter ``t`` with
    SLSQP. The minimisation starts from ``num_init_points - 1`` evenly spaced
    points in ``(0, num_images - 1)``. Earlier layers are passed through ReLU.
    Each minimum whose value is at most ``zero_threshold`` is recorded as a
    breakpoint. Breakpoints closer than 0.001 are merged, per neuron and
    again across all neurons.

    Args:
        model_path: checkpoint passed to ``torch.load`` when ``net`` is None.
        net: a loaded ``Net`` (must expose ``linear1``, ``linear2``,
            ``final`` and ``relu``), or None to load it from ``model_path``.
        images_curve: ``ImagesCurve`` whose points are fed through the net.
        num_init_points: SLSQP is started ``num_init_points - 1`` times per neuron.
        zero_threshold: largest absolute pre-activation accepted as a breakpoint.

    Returns:
        ``(count, values)``: the number of distinct breakpoints and their
        ``t`` values in ascending order.
    """
    if net is None:
        net = torch.load(model_path, map_location=device)
    net.to(device)

    layers = [net.linear1, net.linear2, net.final]
    # Only layers followed by a ReLU create region boundaries, so the output layer is skipped.
    hidden_layers = layers[:-1]
    # Valid curve parameters are [0, num_images - 1].
    t_max = images_curve.num_images - 1
    linear_boundaries = {}

    test_transform = transforms.Compose(
        [transforms.CenterCrop(28),
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    with torch.no_grad():
        for layer_number, layer in enumerate(hidden_layers):
            # range(out_features) covers every neuron of the layer, including the last one.
            for neuron_number in range(layer.out_features):

                def fun_to_optimize(t):
                    """Absolute pre-activation of the current neuron at curve position ``t``."""
                    img = test_transform(images_curve.get_from_t(t))
                    x = img.to(device).float().contiguous().view(-1, 28 * 28 * 3)
                    for earlier_layer in layers[:layer_number]:
                        x = net.relu(earlier_layer(x))
                    return torch.abs(layer(x))[0, neuron_number].item()

                for i in range(1, num_init_points):
                    # Spread the starts over the valid domain so none lies beyond the upper bound.
                    start_point = i * t_max / num_init_points
                    result = minimize(fun_to_optimize, np.array([start_point]), method='SLSQP', tol=1e-5,
                                      bounds=[(0, t_max)],
                                      options={'eps': 0.05, 'maxiter': 1000, 'disp': False})

                    if result.fun <= zero_threshold:
                        safe_add_linear_boundary(linear_boundaries, layer_number, neuron_number, result.x)

    boundary_values = []
    for layer_boundaries in linear_boundaries.values():
        for neuron_boundaries in layer_boundaries.values():
            for neuron_boundary in neuron_boundaries:
                # Each stored breakpoint is SLSQP's 1-element ``x`` array; iterate to get the scalar.
                for boundary_value in neuron_boundary:
                    if not duplicate_exists(boundary_values, boundary_value, 0.001):
                        boundary_values.append(boundary_value)

    boundary_values.sort()
    return len(boundary_values), boundary_values


def load_images_from_directory(dir_path, step=1):
    """Load the images of a curve directory in frame-index order.

    Every regular file in ``dir_path`` must be named so that the text between
    the first and second ``-`` is an integer frame index. Files are ordered
    by that index numerically, so ``x-10`` follows ``x-9`` even without zero
    padding; ties keep file-name order. Only frames whose index is a multiple
    of ``step`` (>= 1) are kept; the default ``step=1`` keeps all of them.

    Each image is fully read inside a ``with`` block so its file handle is
    closed before returning.

    Raises:
        ValueError / IndexError: if a file name does not contain an integer index.
    """
    onlyfiles = sorted(f for f in listdir(dir_path) if isfile(join(dir_path, f)))
    # Stable sort: numeric frame order, with the name order above breaking ties.
    onlyfiles.sort(key=lambda name: int(name.split("-")[1]))
    dir_images = []
    for f_name in onlyfiles:
        f_idx = int(f_name.split("-")[1])
        if f_idx % step == 0:
            with Image.open(join(dir_path, f_name)) as img:
                # copy() loads the pixel data, so the result outlives the closed file.
                dir_images.append(img.copy())
    return dir_images


def load_two_images_from_directory(dir_path, indices=(0, 100)):
    """Load only the endpoint images of a curve directory.

    Every regular file in ``dir_path`` must be named so that the text between
    the first and second ``-`` is an integer frame index. Files whose index
    is in ``indices`` (frames 0 and 100 by default) are loaded. They are
    returned in sorted file-name order.

    Each image is fully read inside a ``with`` block so its file handle is
    closed before returning.
    """
    wanted = set(indices)
    onlyfiles = sorted(f for f in listdir(dir_path) if isfile(join(dir_path, f)))
    dir_images = []
    for f_name in onlyfiles:
        f_idx = int(f_name.split("-")[1])
        if f_idx in wanted:
            with Image.open(join(dir_path, f_name)) as img:
                # copy() loads the pixel data, so the result outlives the closed file.
                dir_images.append(img.copy())
    return dir_images


if __name__ == '__main__':
    # Usage: <script> SEED_VAL IMAGES_DIR [line_interpolate]
    if len(sys.argv) < 3:
        sys.exit("usage: %s SEED_VAL IMAGES_DIR [line_interpolate]" % sys.argv[0])
    model_dir = "cifar_models_epoch_1/"
    seed_val = sys.argv[1]
    dir1 = sys.argv[2]
    # The optional third argument "line_interpolate" uses only the two endpoint images.
    use_two_images = len(sys.argv) > 3 and sys.argv[3] == "line_interpolate"
    step_epochs = 0.1

    # Map epoch (rounded to one decimal) -> checkpoint file name; the epoch is
    # the third "_"-separated field of the file name.
    # NOTE(review): `seed_val in f_name` is a substring test (seed "1" also
    # matches "10") — confirm against the checkpoint naming scheme.
    onlyfiles = sorted(f for f in listdir(model_dir) if isfile(join(model_dir, f)))
    model_paths_dict = {}
    for f_name in onlyfiles:
        if seed_val in f_name:
            name_arr = f_name.split("_")
            epoch_number = round(float(name_arr[2]), 1)
            model_paths_dict[epoch_number] = f_name

    # basename(normpath(...)) also handles a trailing "/" on the directory argument.
    dir_name = os.path.basename(os.path.normpath(dir1))
    if use_two_images:
        images_list = load_two_images_from_directory(dir1)
        out_dir = "boundaries_data_line_cifar_epoch_1/"
    else:
        images_list = load_images_from_directory(dir1)
        out_dir = "boundaries_data_200_cifar_epoch_1/"
    images_curve = ImagesCurve(images_list)
    os.makedirs(out_dir, exist_ok=True)

    print(datetime.datetime.now())
    for i in range(11):
        # Round so that e.g. 3 * 0.1 == 0.30000000000000004 matches the 0.3 key
        # (without it epochs 0.3, 0.6 and 0.7 were silently skipped).
        epoch = round(i * step_epochs, 1)
        if epoch not in model_paths_dict:
            continue
        model_name = model_paths_dict[epoch]
        model_path = join(model_dir, model_name)
        net = torch.load(model_path, map_location=device)
        net = net.float()
        net.to(device)
        print("Loaded model:%s" % (model_name))

        num_boundaries, boundary_values = count_linear_regions(model_path, net, images_curve)

        print("%s Epoch number: %s, Num Linear Boundaries: %s" % (datetime.datetime.now(), epoch, num_boundaries))
        out_path = out_dir + seed_val + "_epoch_" + str(epoch) + "_" + dir_name + "_.pkl"
        with open(out_path, "wb") as f_out:
            pickle.dump((num_boundaries, boundary_values), f_out)


