
import numpy as np


import torch
import torch.cuda
import torch.nn as nn
import torch.optim as optim

import numpy as np
from scipy import spatial


def get_features(x, class_size=10, num_class=6, option='weights'):
    """Return centered class-mean features and a within-class variation ratio.

    The rows of ``x`` are read as ``num_class`` consecutive groups of
    ``class_size`` samples each.

    Args:
        x: torch tensor whose first dimension indexes samples.
        class_size: number of consecutive rows belonging to each class.
        num_class: number of classes.
        option: accepted for signature compatibility; not used.

    Returns:
        A pair ``(class_features, ratio)``. ``class_features`` is a NumPy
        array of the per-class means after subtracting the mean of all class
        means. ``ratio`` is the summed distance of every (centered) sample to
        its centered class mean, divided by ``len(x)``, then divided by the
        average norm of the centered class means.
    """
    means = [
        torch.mean(x[k * class_size:(k + 1) * class_size], 0).cpu().detach().numpy()
        for k in range(num_class)
    ]
    global_mean = np.mean(means, axis=0)

    # Shift both the samples and the class means by the same global mean.
    samples = x.cpu().detach().numpy() - global_mean
    centered_means = np.array(means) - global_mean
    mean_radius = np.mean(np.linalg.norm(centered_means, axis=1))

    total = len(x)
    scatter = 0
    # Rows are visited in storage order; row ``idx`` belongs to class
    # ``idx // class_size``.
    for idx in range(num_class * class_size):
        center = centered_means[idx // class_size]
        scatter += np.linalg.norm(samples[idx] - center) / total
    return centered_means, scatter / mean_radius


def analyze_collapse_new(linear_weights, option='weights'):
    """Measure how far a set of class vectors is from an equinorm, equiangular layout.

    Args:
        linear_weights: sequence of per-class vectors (one row per class).
        option: accepted for signature compatibility; not used.

    Returns:
        A 3-tuple:
          * std of the row norms divided by their mean,
          * mean of ``|cos(w_i, w_j) + 1 / (C - 1)|`` over ordered pairs
            ``i != j``, where ``C`` is the number of rows,
          * max of the same quantity.
    """
    num_classes = len(linear_weights)
    norms = np.array([np.linalg.norm(row) for row in linear_weights])

    # Cosine similarity of every ordered pair of distinct rows, row-major.
    pairwise_cos = np.array([
        1 - spatial.distance.cosine(linear_weights[a], linear_weights[b])
        for a in range(num_classes)
        for b in range(num_classes)
        if a != b
    ])
    # # print('{0} avg square norm'.format(option), np.mean(np.square(norms)))

    deviation = np.abs(pairwise_cos + 1 / (num_classes - 1))
    return np.std(norms) / np.mean(norms), np.mean(deviation), np.max(deviation)


def analyze_dual(linear_weights, class_features):
    """Squared distance between the normalized classifier and normalized class features.

    Only the first ``len(class_features)`` rows of ``linear_weights`` are
    used. Each matrix is divided by its own overall norm (``np.linalg.norm``
    of the whole array) before the difference is taken.

    Args:
        linear_weights: array of classifier rows.
        class_features: array of per-class feature vectors.

    Returns:
        The squared norm of the difference of the two normalized arrays.
    """
    kept_rows = linear_weights[:len(class_features)]
    unit_weights = kept_rows / np.linalg.norm(kept_rows)
    unit_features = class_features / np.linalg.norm(class_features)
    gap = unit_weights - unit_features
    return np.square(np.linalg.norm(gap))


# Seed PyTorch's global RNG at import time; the random tensors created by
# Model (torch.randn features, nn.Linear init) draw from this generator.
seed = 10
torch.manual_seed(seed)

class Model(nn.Module):
    """Toy model whose input features are themselves trainable parameters.

    Attributes:
        x: trainable feature matrix of shape
            ``[class_size * num_class, width]``, drawn from ``torch.randn``.
        label: float tensor of class indices, ``class_size`` copies of each
            index in increasing order.
        lastlayer: bias-free linear classifier from ``width`` to ``num_class``.
        relu: a ReLU module; not used by ``forward``.
    """

    def __init__(self, layer, width, num_class, class_size):
        """Store the configuration and create features, labels and classifier.

        Args:
            layer: stored as ``self.layer``; not otherwise used here.
            width: feature dimension.
            num_class: number of classes.
            class_size: number of samples per class.
        """
        super().__init__()
        self.layer = layer
        self.width = width
        self.num_class = num_class
        self.class_size = class_size

        # The features are created before the linear layer so the global RNG
        # is consumed in the order randn -> Linear initialisation.
        n_samples = class_size * num_class
        self.x = nn.Parameter(torch.randn([n_samples, width]), requires_grad=True)

        class_ids = [c for c in range(num_class) for _ in range(class_size)]
        self.label = torch.Tensor(class_ids)

        self.lastlayer = nn.Linear(width, num_class, bias=False)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply the bias-free linear classifier to ``x`` and return the result."""
        return self.lastlayer(x)


def cosine(x, num_class, class_size):
    """Cosine similarities between class means, stored in the upper triangle.

    The rows of ``x`` are read as ``num_class`` consecutive groups of
    ``class_size`` rows; each group is averaged along dimension 0.

    Args:
        x: torch tensor whose first dimension indexes samples.
        num_class: number of classes.
        class_size: number of consecutive rows per class.

    Returns:
        A ``[num_class, num_class]`` tensor. Entry ``[j, i]`` with ``j <= i``
        is the cosine between the means of class ``j`` and class ``i``;
        entries below the diagonal are left at zero.
    """
    similarities = torch.zeros([num_class, num_class])
    centroids = [
        torch.mean(x[c * class_size:(c + 1) * class_size], 0)
        for c in range(num_class)
    ]
    for col, current in enumerate(centroids):
        # Only pair the current class with itself and the classes before it.
        for row in range(col + 1):
            earlier = centroids[row]
            scale = torch.norm(current) * torch.norm(earlier)
            similarities[row][col] = (torch.dot(earlier, current) / scale).item()
    return similarities


def variation(x, num_class, class_size):
    """Sum of per-coordinate variances within each class, added over classes.

    The rows of ``x`` are read as ``num_class`` consecutive groups of
    ``class_size`` rows. For each group ``torch.var`` is taken along
    dimension 0 and its entries are summed.

    Args:
        x: torch tensor whose first dimension indexes samples.
        num_class: number of classes.
        class_size: number of consecutive rows per class.

    Returns:
        The accumulated variance as a Python number (via ``.item()``).
    """
    total = sum(
        torch.sum(torch.var(x[c * class_size:(c + 1) * class_size], 0))
        for c in range(num_class)
    )
    return total.item()

def _is_progress_step(step):
    """Return True on the log-spaced steps at which the feature cosine matrix is printed."""
    return (
        (step < 10000 and step % 1000 == 0)
        or (step < 100000 and step % 10000 == 0)
        or (step < 1000000 and step % 100000 == 0)
        or step % 1000000 == 0
    )


def _is_record_step(step):
    """Return True on the steps at which the collapse metrics are recorded."""
    return (
        (step < 100 and step % 2 == 1)
        or (100 < step < 1000 and step % 100 == 1)
        or step % 1000 == 1
    )


# parameters: number of hidden layers (fixed to be 1 in LPM), width of hidden layers, number of classes, sample size in each class
def train(layer, width, num_class, class_size):
    """Train the model with full-batch SGD and record collapse metrics.

    A fresh ``Model`` is built and both its trainable features (``model.x``)
    and its linear classifier are optimised with cross-entropy loss for a
    fixed 100000 steps at learning rate 5. Progress is printed to stdout.

    Args:
        layer: number of hidden layers; forwarded to ``Model``.
        width: feature dimension.
        num_class: number of classes.
        class_size: number of samples per class.

    Returns:
        A 6-tuple of lists
        ``(w_norm_variation, w_cos_mean, h_norm_variation, h_cos_mean, dual,
        with_in)``. Every list has length ``Epoch // K + 1``; slot ``k`` holds
        the metric at the ``k``-th recorded checkpoint and slots past the last
        checkpoint keep their initial value ``0``.
    """
    Epoch = 100000
    K = 100
    n_slots = (Epoch // K) + 1
    w_norm_variation = [0] * n_slots
    w_cos_mean = [0] * n_slots
    # The *_cos_max series are filled alongside the others but not returned.
    w_cos_max = [0] * n_slots
    h_norm_variation = [0] * n_slots
    h_cos_mean = [0] * n_slots
    h_cos_max = [0] * n_slots
    dual = [0] * n_slots
    with_in = [0] * n_slots

    model = Model(layer, width, num_class, class_size)
    x_data = model.x
    y_data = model.label.long()
    CE = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=5)

    index = 0
    # ``step`` replaces the former counter named ``iter``, which shadowed the builtin.
    for step in range(Epoch):
        output = model(x_data)
        loss = CE(output, y_data)
        ## L-2 regularization
        # l2_lambda = 1e-7
        # l2_reg = torch.tensor(0.)
        # for param in model.parameters():
        #  l2_reg += torch.norm(param)**2
        # loss += l2_lambda * l2_reg
        ##
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # One combined check, so a step matching several schedules (step 0)
        # prints the matrix once instead of once per schedule.
        if _is_progress_step(step):
            print(step)
            print(cosine(x_data, num_class, class_size))
        # print(model.x[2*3+1])
        # print(model.lastlayer.weight[3])
        if _is_record_step(step):
            linear_weights = model.lastlayer.weight.cpu().data.numpy()
            w_norm_variation[index], w_cos_mean[index], w_cos_max[index] = analyze_collapse_new(
                linear_weights, option='weights')
            class_features, with_in[index] = get_features(x_data, class_size=class_size, num_class=num_class)
            h_norm_variation[index], h_cos_mean[index], h_cos_max[index] = analyze_collapse_new(
                class_features, option='features')
            dual[index] = analyze_dual(linear_weights, class_features)
            print('w norm:', w_norm_variation[index], 'w cos:', w_cos_mean[index])
            print('h norm:', h_norm_variation[index], 'h cos:', h_cos_mean[index])
            print('with_in:', with_in[index])
            print('dual:', dual[index])
            # Print the scalar loss rather than the tensor repr with its grad_fn.
            print('epoch:', step, 'loss:', loss.item())
            index += 1

    cos_data = cosine(x_data, num_class, class_size)
    print(cos_data)
    # Each classifier row is treated as its own class (class_size = 1).
    cos_weight = cosine(model.lastlayer.weight, num_class, 1)
    print(cos_weight)
    # print(torch.svd(model.lastlayer.weight)[1])

    return w_norm_variation, w_cos_mean, h_norm_variation, h_cos_mean, dual, with_in

if __name__ == '__main__':
    # Each returned list holds one value per recorded checkpoint (recording
    # happens on selected steps, not on every epoch).
    # w_norm_variation: std / mean of the classifier row norms
    # w_cos_mean: mean of |cos + 1/(C-1)| over pairs of distinct classifier rows
    # h_norm_variation: std / mean of the norms of the centered class-mean features
    # h_cos_mean: mean of |cos + 1/(C-1)| over pairs of distinct centered class means
    # dual: squared distance between normalized classifier and normalized centered last-layer class means
    # with_in: within-class variation of last-layer features, relative to the average class-mean norm
    # Arguments are (layer, width, num_class, class_size).
    w_norm_variation, w_cos_mean, h_norm_variation, h_cos_mean, dual, with_in = train(1, 20, 5, 10)