import torch
import torch.nn as nn
import torch.optim as optim
from torch import linalg as LA

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from utils import *


# ---- Problem dimensions ------------------------------------------------------
dc, ds = 32, 16         # number of core / background raw features
id = 256                # width after input_transform maps the dc + ds raw features (NOTE: shadows builtin `id`)
hd, od = 128, 1         # hidden width / output dimension of TwoLayerReLUNet

# ---- Model and task switches ---------------------------------------------------
linear = False          # True: identity activation instead of ReLU
freeze_2nd = True       # True: second layer fixed, only the first layer(s) are optimised
classification = True   # True: +/-1 labels and two sub-networks; False: regression with MSE
hinge = True            # hinge loss (else logistic) and 1/hd scaling of the frozen second layer

# ---- Logging and plotting ------------------------------------------------------
print_freq = 10             # record metrics every print_freq iterations
plot_corr_sum = True        # plot class-wise correlations of 10 randomly sampled neurons
plot_corr_sum_bg = True     # (when not plot_corr_sum) per-neuron sums instead of every single feature
plot_random_neurons = True  # must be True whenever plot_corr_sum is True (asserted when plotting)

# ---- Optimisation --------------------------------------------------------------
optim_name = 'SGD'      # alternatives: 'Adam', 'AdamW'
N = 10000               # number of generated examples
GID = 0                 # CUDA device index
ITERS = 20000           # training iterations
LR, WD = 1e-3, 0.001    # learning rate, weight decay


class TwoLayerReLUNet(nn.Module):
    """Two-layer network ``fc2(act(fc1(x)))``, optionally with a second branch.

    With classification enabled the output is the sum of two branches,
    ``fc2(act(fc1(x))) + fc2b(act(fc1b(x)))``. When the second layer is frozen,
    ``fc2`` is filled with ``+c`` and ``fc2b`` with ``-c``, where
    ``c = 1 / hidden_dim`` in the hinge setting and ``c = 1`` otherwise. The
    frozen layers are excluded from autograd, so only the first layer(s) learn.

    The keyword-only switches default to the module-level ``classification``,
    ``linear``, ``freeze_2nd`` and ``hinge`` constants, so existing positional
    calls behave as before.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, bias=False, *,
                 is_classification=None, is_linear=None, freeze_second=None, use_hinge=None):
        super().__init__()
        # Resolve switches lazily so explicit arguments never touch module globals.
        if is_classification is None:
            is_classification = classification
        if is_linear is None:
            is_linear = linear
        if freeze_second is None:
            freeze_second = freeze_2nd
        if use_hinge is None:
            use_hinge = hinge
        # Fixed at construction so forward() always matches the submodules built here.
        self.classification = bool(is_classification)

        # Registration order (fc1, fc1b, relu, fc2, fc2b) matches the original, so the
        # random initialisation and parameter order are unchanged.
        self.fc1 = nn.Linear(input_dim, hidden_dim, bias=bias)
        if self.classification:
            self.fc1b = nn.Linear(input_dim, hidden_dim, bias=bias)

        self.relu = nn.Identity() if is_linear else nn.ReLU()

        self.fc2 = nn.Linear(hidden_dim, output_dim, bias=bias)
        if self.classification:
            self.fc2b = nn.Linear(hidden_dim, output_dim, bias=bias)

        if freeze_second:
            scale = 1.0 / hidden_dim if use_hinge else 1.0
            self._freeze_readout(self.fc2, scale)
            if self.classification:
                # The second branch votes for the negative class. Previously fc2b kept its
                # random init when hinge was off, so this branch had no fixed sign.
                self._freeze_readout(self.fc2b, -scale)

    @staticmethod
    def _freeze_readout(layer, value):
        """Fill ``layer.weight`` with a constant and stop autograd from tracking the layer."""
        with torch.no_grad():
            layer.weight.fill_(value)
        # Without this, gradients for the frozen layer are still computed and accumulate,
        # because the optimizer (which excludes it) never zeroes them.
        layer.requires_grad_(False)

    def forward(self, x):
        """Return the network output for input rows ``x``."""
        out = self.fc2(self.relu(self.fc1(x)))
        if not self.classification:
            return out
        return out + self.fc2b(self.relu(self.fc1b(x)))


def generate_train_bg_features(n, d):
    """Draw ``n`` training background-feature vectors of size ``d`` from U[0, 1)."""
    shape = (n, d)
    return torch.rand(shape)


def generate_test_bg_features(n, d):
    """Draw ``n`` OOD test background-feature vectors of size ``d`` from (-1, 0]."""
    samples = torch.rand(n, d)
    return samples.neg()


def randomize_train_bg_features(train_x, input_transform):
    """Return a copy of ``train_x`` with its background features freshly resampled.

    ``train_x`` is mapped back to the raw feature space with ``input_transform.T``.
    This is an exact inverse when ``input_transform`` has orthonormal rows. The raw
    columns from ``dc`` onward (the ``ds`` background features) are replaced with new
    draws from ``generate_train_bg_features``, and the result is mapped forward again.
    ``train_x`` itself is not modified.

    The number of rows and the device are taken from ``train_x``. Previously the
    global ``N`` and ``.cuda(GID)`` were hard-coded, which broke for any other batch.
    """
    transform = input_transform.to(train_x.device)  # moved once, reused both ways
    raw_x = train_x @ transform.T
    raw_x[:, dc:] = generate_train_bg_features(raw_x.size(0), ds).to(raw_x.device)
    return raw_x @ transform


def generate_toy_data(n, input_transform=None, output_transform=None, sigma=1.0, nonlinear=False):
    """
    Generate the synthetic train / OOD-test data for the toy experiment.

    Every example is the concatenation of ``dc`` core features and ``ds`` background
    features (module-level constants). Train and test inputs share the same core
    features and differ only in their background features. The module-level
    ``classification`` flag selects the task.

    :param n: number of data points
    :param input_transform: optional matrix with ``dc + ds`` rows; train and test inputs
        are right-multiplied by it
    :param output_transform: optional matrix applied to the regression targets
        (ignored for classification)
    :param sigma: noise scale of the regression data (ignored for classification)
    :param nonlinear: accepted for backward compatibility; currently unused
    :return: ``(train_x, test_x, y)``
    """
    if classification:
        return _generate_classification_data(n, input_transform)
    return _generate_regression_data(n, input_transform, output_transform, sigma)


def _generate_classification_data(n, input_transform):
    """Balanced +1/-1 data whose core features carry the sign of the label."""
    # Uses the ``n`` argument; the previous version silently used the global ``N``.
    n_pos = n // 2
    n_neg = n - n_pos
    y = torch.cat([torch.ones(n_pos), -torch.ones(n_neg)], dim=0)
    # Core features: U[0, 1) for the positive class, (-1, 0] for the negative class.
    Xc_pos = 1 * torch.rand(n_pos, dc)
    Xc_neg = -1 * torch.rand(n_neg, dc)
    Xc = torch.cat([Xc_pos, Xc_neg], dim=0)
    Xs_train = generate_train_bg_features(n, ds)
    Xs_test = generate_test_bg_features(n, ds)
    train_x = torch.cat([Xc, Xs_train], dim=-1)
    test_x = torch.cat([Xc, Xs_test], dim=-1)
    print('Train Xc:', Xc, 'Train Xs:', Xs_train, 'Test Xs:', Xs_test, 'y:', y)
    print('Max:', train_x.max(dim=0)[0], 'Min:', train_x.min(dim=0)[0], 'Mean:', train_x.mean(dim=0), 'STD:', train_x.std(dim=0))
    if input_transform is not None:
        train_x = train_x @ input_transform
        test_x = test_x @ input_transform
    return train_x, test_x, y


def _generate_regression_data(n, input_transform, output_transform, sigma):
    """Regression data: targets are the core features, optionally mapped by ``output_transform``."""
    Xc = sigma * torch.randn(n, dc)
    print('Xc: ', Xc)
    # Train background features are shifted (mean 3, std 0.5 * sigma); test ones are centred (std sigma).
    Xs_train = 3 + 0.5 * sigma * torch.randn(n, ds)
    Xs_test = sigma * torch.randn(n, ds)
    print('Train: ', Xs_train)
    print('Test: ', Xs_test)
    print(Xs_train.min(dim=0)[0], Xs_train.max(dim=0)[0])
    train_x = torch.cat([Xc, Xs_train], dim=-1)
    test_x = torch.cat([Xc, Xs_test], dim=-1)
    if input_transform is not None:
        train_x = train_x @ input_transform
        print('Max:', train_x.max(dim=0)[0], 'Min:', train_x.min(dim=0)[0], 'Mean:', train_x.mean(dim=0), 'STD:', train_x.std(dim=0))
        test_x = test_x @ input_transform
    y = Xc
    if output_transform is not None:
        y = y @ output_transform
    print('Singular values of training data:', LA.svd(train_x, full_matrices=False)[1])
    print('Y: ', y)
    return train_x, test_x, y


def random_orthogonal_transform(input_size, output_size):
    """Return a random ``(input_size, output_size)`` matrix with orthonormal columns or rows.

    Takes the thin SVD of a standard-normal matrix. When ``input_size >= output_size``
    the left factor ``U`` (orthonormal columns) is returned; otherwise ``Vh``
    (orthonormal rows) is returned.
    """
    left, _, right_h = LA.svd(torch.randn(input_size, output_size), full_matrices=False)
    return left if input_size >= output_size else right_h


def random_nonnegative_transform(input_size, output_size):
    """Return an ``(input_size, output_size)`` matrix with entries drawn from U[0, 0.2)."""
    return torch.rand(input_size, output_size).div(5)


def nonlinear_transform(x):
    """Element-wise signed cube root: ``sign(x) * |x| ** (1/3)``."""
    magnitude = torch.abs(x).pow(1 / 3)
    return torch.sign(x) * magnitude


def project(x, target):
    """Project every row of ``x`` onto the direction of ``target``.

    Computes ``<row, target> * target / ||target||^2`` row-wise, reducing over the last
    dimension.
    """
    dot = (x * target).sum(dim=-1, keepdim=True)
    squared_norm = target.norm(dim=-1, keepdim=True) ** 2
    return dot * target / squared_norm


def print_function(model, fc1_name, fc2_name):
    """Print the network ``fc1_name -> ReLU -> fc2_name`` as explicit formulas.

    Emits one ``z{i} = ReLU(...)`` line per row of the first layer's weight, a blank
    line, then one ``y{i} = ...`` line per row of the second layer's weight.
    Coefficients use two decimals; biases are not shown.
    """
    def linear_combination(coefficients, var):
        return ' + '.join(f'{coef:.2f} * {var}{j}' for j, coef in enumerate(coefficients))

    hidden_w, out_w = (getattr(model, name).weight.detach().cpu().numpy() for name in (fc1_name, fc2_name))
    for i, row in enumerate(hidden_w):
        print(f'z{i} = ReLU({linear_combination(row, "x")})')
    print('')
    for i, row in enumerate(out_w):
        print(f'y{i} = {linear_combination(row, "z")}')


class LogisticLoss(nn.Module):
    """Mean logistic loss ``log(1 + exp(-y * y_pred))``.

    Computed as ``softplus(-y * y_pred)``, which is numerically stable. The naive
    ``log(1 + exp(.))`` form overflows to ``inf`` in float32 once the negative margin
    exceeds roughly 88, which would poison the logged losses and the gradients.
    """

    def forward(self, y_pred, y):
        """Return the scalar mean loss for predictions ``y_pred`` and labels ``y``."""
        return nn.functional.softplus(-y * y_pred).mean()


class HingeLoss(nn.Module):
    """Mean hinge loss ``max(0, 1 - y * y_pred)``."""

    def forward(self, y_pred, y):
        """Return the scalar mean hinge loss for predictions ``y_pred`` and labels ``y``."""
        margins = y * y_pred
        # clamp (not relu) keeps the original gradient behaviour at margin exactly 1.
        return (1 - margins).clamp(min=0).mean()


def main():
    """Run the toy experiment: generate data, train the network, log metrics and plot them.

    All settings come from the module-level constants. At every ``print_freq``-th
    iteration (and at the last one) it records:
      - the train loss and the OOD test loss;
      - the effect of resampling the training background features;
      - the mean per-neuron L2 norm of the first-layer weights on the core (first ``dc``)
        and background (remaining) raw features.
    With a frozen second layer in the classification setting it also records per-neuron
    correlations for plotting.
    """
    # set_seed(0)

    input_transform = random_orthogonal_transform(dc + ds, id)
    if classification:
        output_transform = None
    else:
        # Regression targets are y = Xc @ output_transform. This name used to be left
        # unassigned on this path, so the print below raised UnboundLocalError.
        output_transform = random_orthogonal_transform(dc, od)
        print(f'Ground truth: y = x @ {output_transform.t()}')

    x_train, x_test, y = generate_toy_data(N, input_transform=input_transform, output_transform=output_transform,
                                           sigma=1.0, nonlinear=False)
    x_train, x_test, y = x_train.cuda(GID), x_test.cuda(GID), y.cuda(GID)

    model = TwoLayerReLUNet(id, hd, od).cuda(GID)
    optimizer = _build_optimizer(_trainable_params(model))

    if classification:
        criterion = HingeLoss() if hinge else LogisticLoss()
    else:
        criterion = nn.MSELoss()

    model.train()

    train_errs, eval_errs = [], []
    randomized_errs, randomized_output_gaps = [], []
    dc_weight_norms, ds_weight_norms = [], []

    # Correlation tracking needs both sub-networks (fc1 and fc1b), which only exist in
    # the classification model.
    track_correlations = freeze_2nd and classification
    correlations = [] if track_correlations else None
    corr_pos_examples = [] if track_correlations else None
    corr_neg_examples = [] if track_correlations else None

    for i in range(ITERS):
        y_pred = model(x_train)
        if classification:
            y_pred = y_pred.squeeze()
        loss = criterion(y_pred, y)

        if i % print_freq == 0 or i == ITERS - 1:
            train_loss = loss.item()
            randomized_train_loss, y_randomized = randomize_bg_test(model, x_train, y, input_transform, criterion)
            eval_loss = evaluate(model, x_test, y, criterion)
            train_errs.append(train_loss)
            eval_errs.append(eval_loss)
            randomized_errs.append(np.abs(randomized_train_loss - train_loss))
            randomized_output_gaps.append(torch.abs(y_pred.detach() - y_randomized).mean().item())

            # First-layer weights expressed in the raw (dc + ds)-dimensional feature space.
            weight = model.fc1.weight.cpu().detach()
            if input_transform is not None:
                weight = weight @ input_transform.T
                if track_correlations:
                    # fc1b must be mapped back exactly like fc1 before slicing core /
                    # background columns. The old code sliced the untransformed fc1b weights.
                    weightb = model.fc1b.weight.cpu().detach() @ input_transform.T
                    correlation, corr_pos_example, corr_neg_example = _neuron_correlations(weight, weightb)
                    correlations.append(correlation)
                    corr_pos_examples.append(corr_pos_example)
                    corr_neg_examples.append(corr_neg_example)
            dc_weight_norms.append(LA.norm(weight[:, :dc], dim=-1).mean().item())
            ds_weight_norms.append(LA.norm(weight[:, dc:], dim=-1).mean().item())

            model.train()  # evaluate / randomize_bg_test leave the model in eval mode
            print(f'Iteration {i}: train loss = {train_loss:.4f}, eval loss = {eval_loss:.4f}')
            if i == 0:
                print(f'Init core weights: {weight[:10, :dc]}')
                print(f'Init background weights: {weight[:10, dc:]}')
            elif i == ITERS - 1:
                print(f'Final core weights: {weight[:10, :dc]}')
                print(f'Final background weights: {weight[:10, dc:]}')
                print(f'Top-10 singular values of the weight matrix: {LA.svd(weight, full_matrices=False)[1][:10]}')
                print(f'Training error: {train_loss:.4f}, Eval error: {eval_loss:.4f}')

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    if correlations is not None:
        correlations = torch.stack(correlations, dim=0)
        corr_pos_examples = torch.stack(corr_pos_examples, dim=0)
        corr_neg_examples = torch.stack(corr_neg_examples, dim=0)

    plot_loss_and_weights(dc_weight_norms, ds_weight_norms, train_errs, eval_errs, randomized_errs, randomized_output_gaps,
                          correlations, corr_pos_examples, corr_neg_examples)


def _trainable_params(model):
    """Parameters given to the optimizer: only the first layer(s) when the second layer is frozen."""
    if not freeze_2nd:
        return model.parameters()
    params = list(model.fc1.parameters())
    if classification:  # fc1b only exists in the two-branch classification model
        params += list(model.fc1b.parameters())
    return params


def _build_optimizer(params):
    """Create the optimizer selected by ``optim_name`` ('SGD', 'Adam' or 'AdamW', case-insensitive)."""
    name = optim_name.lower()
    if name == 'sgd':
        return optim.SGD(params, lr=LR, momentum=0.9, weight_decay=WD)
    if name == 'adam':
        return optim.Adam(params, lr=LR, weight_decay=WD, eps=1e-4)
    if name == 'adamw':
        return optim.AdamW(params, lr=LR, weight_decay=WD, eps=1e-3)
    raise ValueError(f'Unsupported optimizer: {optim_name!r}')


def _neuron_correlations(weight, weightb):
    """Stack both sub-networks' raw-space weights and their per-class example correlations.

    ``weight`` / ``weightb`` are the first-layer weights of the fc1 / fc1b sub-networks,
    already mapped back to the raw feature space (one row per hidden neuron, ``dc + ds``
    columns). The per-class values are ``+-0.5 * sum(core columns) + 0.5 * sum(background
    columns)``. Under the current data generators this is each neuron's dot product with
    the mean training example of class +1 / -1.

    Returns ``(correlation, corr_pos, corr_neg)`` with shapes ``(2, hd, dc + ds)``,
    ``(2, hd)`` and ``(2, hd)``.
    """
    correlation = torch.stack([weight, weightb], dim=0)
    corr_pos, corr_neg = [], []
    for w in (weight, weightb):
        core = w[:, :dc].sum(dim=-1)
        background = w[:, dc:].sum(dim=-1)
        corr_pos.append(0.5 * core + 0.5 * background)
        corr_neg.append(-0.5 * core + 0.5 * background)
    return correlation, torch.stack(corr_pos, dim=0), torch.stack(corr_neg, dim=0)


def evaluate(model, x_test, y_test, criterion):
    """Return ``criterion`` on ``(x_test, y_test)`` as a Python float, computed without gradients.

    Switches ``model`` to eval mode and leaves it there; the caller switches back to
    train mode. In the classification setting the output is squeezed before the loss.
    """
    model.eval()
    with torch.no_grad():
        predictions = model(x_test)
        predictions = predictions.squeeze() if classification else predictions
        return criterion(predictions, y_test).item()


def randomize_bg_test(model, x_train, y_train, input_transform, criterion):
    """Evaluate ``model`` on training inputs whose background features were resampled.

    Returns ``(loss_value, predictions)``, where ``loss_value`` is a Python float and
    ``predictions`` is the model output, squeezed in the classification setting. Runs
    without gradients and leaves ``model`` in eval mode.
    """
    model.eval()
    with torch.no_grad():
        resampled = randomize_train_bg_features(x_train, input_transform)
        preds = model(resampled)
        if classification:
            preds = preds.squeeze()
        loss_value = criterion(preds, y_train).item()
    return loss_value, preds


def plot_loss_and_weights(dc_weight_norms, ds_weight_norms, train_errs, eval_errs, randomized_errs, randomized_output_gaps,
                          correlations=None, corr_pos_examples=None, corr_neg_examples=None):
    """Save the loss / weight-norm figure and, optionally, the neuron-correlation figure as PDFs.

    All curves are indexed by logging step, and the x axis is ``print_freq * step``.
    ``randomized_errs`` and ``randomized_output_gaps`` are accepted but not plotted.
    The correlation figure is produced only when ``correlations`` is not None. It
    expects, as built by ``main``:
      - ``correlations`` indexed ``[step, sub_network, neuron, feature]``;
      - ``corr_*_examples`` indexed ``[step, sub_network, neuron]``.
    Files are written under ``results/toy/classification/`` or ``results/toy/``
    (created if missing) and named after the module-level settings.
    """
    import os  # local import: only needed to create the output directory

    save_dir = 'results/toy/classification/' if classification else 'results/toy/'
    # plt.savefig does not create missing directories.
    os.makedirs(save_dir, exist_ok=True)
    file_stem = save_dir + (f'{"linear_" if linear else "relu_"}dc{dc}_ds{ds}_id{id}_od{od}_hd{hd}_wd{WD}_'
                            f'{optim_name.lower()}{"_freeze2nd" if freeze_2nd else ""}_iter{ITERS}')

    f, (ax1, ax2) = plt.subplots(2, sharex=True, figsize=(6, 7))
    plt.suptitle('Two-layer ReLU network')

    # Top panel: losses. train_errs are losses on the training (in-distribution) inputs.
    ax1.plot(print_freq * np.arange(len(train_errs)), train_errs, label='ID test')
    ax1.plot(print_freq * np.arange(len(eval_errs)), eval_errs, label='OOD test')
    ax1.set_ylim((-0.0, max(np.max(train_errs), np.max(eval_errs)) + 0.1))
    ax1.set_ylabel('Classification loss' if classification else 'MSE')
    ax1.legend()
    ax1.set_title('Risk')
    ax1.grid()

    # Bottom panel: mean L2 norm of first-layer weights on core vs background features.
    ax2.plot(print_freq * np.arange(len(dc_weight_norms)), dc_weight_norms, label='Core')
    ax2.plot(print_freq * np.arange(len(ds_weight_norms)), ds_weight_norms, label='Background')
    ax2.set_ylim((0., max(np.max(dc_weight_norms), np.max(ds_weight_norms)) + 0.05))
    ax2.set_xlabel('Iterations')
    ax2.set_ylabel('Mean weight l2 norm')
    ax2.legend()
    ax2.set_title('Weight norm')

    plt.tight_layout()
    plt.savefig(file_stem + '.pdf')
    plt.close(f)  # release the figure's memory

    if correlations is None:
        return

    fig = plt.figure(figsize=(40, 10))
    plt.axis('off')
    if plot_corr_sum:
        _plot_sampled_neuron_corrs(fig, corr_pos_examples, corr_neg_examples)
    else:
        _plot_per_neuron_feature_corrs(fig, correlations)
    plt.savefig(file_stem + f'_CORR{"_SUM" if plot_corr_sum else ""}.pdf')
    plt.close(fig)


def _plot_sampled_neuron_corrs(fig, corr_pos_examples, corr_neg_examples):
    """Plot 10 random neurons' correlations with class +1 (top row) and class -1 (bottom row)."""
    assert plot_random_neurons
    outer = fig.add_gridspec(1, 1, wspace=0.05, hspace=0.2, left=0.02, right=0.98, bottom=0.15, top=0.97)
    inner = outer[0].subgridspec(nrows=2, ncols=10, wspace=0.08, hspace=0.25)
    axs = inner.subplots(sharey=True)
    # Neuron ids lie in [0, 2 * hd); ids >= hd belong to the second (fc1b) sub-network.
    random_neurons = np.random.choice(2*hd, size=10, replace=False)
    print(random_neurons)
    for k, neuron in enumerate(random_neurons):
        sub_net, idx = int(neuron >= hd), neuron % hd
        axs[0, k].plot(print_freq * np.arange(corr_pos_examples.size(0)), corr_pos_examples[:, sub_net, idx].numpy())
        axs[1, k].plot(print_freq * np.arange(corr_neg_examples.size(0)), corr_neg_examples[:, sub_net, idx].numpy())
        for j in range(2):
            if k == 0:
                axs[j, k].set_ylabel(f'Corr. w/ class {1 if j == 0 else -1} examples', fontsize=18)
            axs[j, k].grid()


def _plot_per_neuron_feature_corrs(fig, correlations):
    """Plot every neuron of both sub-networks: feature sums, or each feature separately."""
    outer = fig.add_gridspec(2, 1, wspace=0.05, hspace=0.2, left=0.03, right=0.98, bottom=0.03, top=0.85)
    steps = print_freq * np.arange(correlations.size(0))
    for i in range(2):  # sub-network index
        if plot_corr_sum_bg:
            inner = outer[i].subgridspec(nrows=2, ncols=hd, wspace=0.08, hspace=0.5)
            axs = inner.subplots(sharey=True)
            for k in range(hd):
                # Row 0: first dc // 2 core features + background features.
                # Row 1: every feature from dc // 2 on (remaining core + background).
                pos_corrs = correlations[:, i, k, :dc//2].sum(dim=-1) + correlations[:, i, k, dc:].sum(dim=-1)
                neg_corrs = correlations[:, i, k, dc//2:].sum(dim=-1)
                axs[0, k].plot(steps, pos_corrs.numpy())
                axs[1, k].plot(steps, neg_corrs.numpy())
                for j in range(2):
                    axs[j, k].set_title(f'{i}th sub-network, {k}th neuron, {"positive" if j == 0 else "negative"}')
                    if k == 0:
                        axs[j, k].set_ylabel('Correlation')
                    axs[j, k].grid()
        else:  # correlation between each neuron and each individual raw feature
            inner = outer[i].subgridspec(nrows=dc+ds, ncols=hd, wspace=0.08, hspace=0.5)
            axs = inner.subplots(sharey=True)
            for j in range(dc + ds):
                for k in range(hd):
                    axs[j, k].plot(steps, correlations[:, i, k, j].numpy())
                    axs[j, k].set_title(f'{i}-th sub-network, {j}-th feature, {k}-th neuron')
                    if k == 0:
                        axs[j, k].set_ylabel('Correlation')
                    axs[j, k].grid()


if __name__ == '__main__':
    # Run the full experiment only when executed as a script, not when imported.
    main()
