from grok_adversarial.dataloaders import get_LC_samples
from ml_collections import config_dict
from grok_adversarial.configs import config_resnet18_cifar10
from grok_adversarial.utils import add_hooks_preact_resnet18
from grok_adversarial.local_complexity import get_intersections_for_hulls
import torch
import torch.nn.functional as F
import torch
import sys

import matplotlib.pyplot as plt
import os
import gc
from torch.utils.data import Dataset, DataLoader

class RandomMatrixDataset(Dataset):
    """Dataset of freshly drawn uniform-random tensors paired with a dummy label.

    Each ``__getitem__`` call draws a brand-new tensor with ``torch.rand``; the
    requested index is ignored, so reading the same index twice yields two
    different samples. The label is always ``-1`` so that items have the same
    ``(input, label)`` structure as a regular labelled dataset.
    """

    def __init__(self, size, normalize, dmin=0, dmax=1, length=100):
        """
        Store the sampling configuration.

        :param size: shape of one generated tensor (anything ``torch.rand``
            accepts). A DataLoader adds the batch dimension when collating, so
            this is the shape of a single sample.
        :param normalize: bool; when true, rescale samples from [0, 1) to
            [dmin, dmax).
        :param dmin: lower bound of the rescaled range (used only if normalize).
        :param dmax: upper bound of the rescaled range (used only if normalize).
        :param length: number of items the dataset reports via ``len()``.
        """
        self.size, self.normalize = size, normalize
        self.dmin, self.dmax = dmin, dmax
        self.length = length

    def __len__(self):
        """
        Report the configured dataset length.
        """
        return self.length

    def __getitem__(self, idx):
        """
        Draw one random sample.

        :param idx: int, ignored; every call produces a new random tensor.
        :return: tuple ``(tensor, -1)``, where the tensor has shape ``self.size``.
        """
        sample = torch.rand(self.size)  # uniform in [0, 1)
        if not self.normalize:
            return sample, -1
        # Affine map [0, 1) -> [dmin, dmax).
        span = self.dmax - self.dmin
        return sample * span + self.dmin, -1

def compute_LC(model, train_loader, test_loader):
    """Compute local-complexity (LC) statistics for ``model`` on test and random inputs.

    Configuration comes from ``config_resnet18_cifar10()``. Forward hooks are
    attached with ``add_hooks_preact_resnet18``, and ``get_intersections_for_hulls``
    is run once per loader:

    * ``'test'``: ``test_loader``.
    * ``'rand'``: uniform-random inputs shaped like one test sample, generated by
      ``RandomMatrixDataset``.

    The hooks are always removed before returning, including when an exception
    is raised, so the caller's model is left without the added hooks.

    :param model: network to analyse (passed to ``add_hooks_preact_resnet18``).
    :param train_loader: currently unused; kept for interface compatibility.
    :param test_loader: iterable of batches whose first element is the input
        tensor; one batch is read to determine the per-sample input shape.
    :return: dict mapping ``'<loader key>_LC'`` (``'test_LC'``, ``'rand_LC'``)
        to the first value returned by ``get_intersections_for_hulls``, moved
        to CPU.
    """
    config = config_resnet18_cifar10()

    # Peek at one test batch only to learn the shape of a single input sample.
    peek_iterator = iter(test_loader)
    one_batch = next(peek_iterator)[0]
    if hasattr(peek_iterator, 'close'):
        peek_iterator.close()
    # Drop the iterator now so any loader workers it started can shut down.
    del peek_iterator

    # Random inputs with the same per-sample shape as the test data.
    rand_dataset = RandomMatrixDataset(
        size=one_batch.shape[1:],
        normalize=config.normalize,
        dmin=config.dmin,
        dmax=config.dmax,
        length=config.approx_n + config.LC_batch_size,
    )
    rand_dataloader = DataLoader(rand_dataset, batch_size=config.LC_batch_size, shuffle=True)

    # The keys of this dict are used as logging prefixes. Extra hulls (e.g.
    # per-class loaders) can be added here to track class-wise statistics.
    LC_loaders = {
        'test': test_loader,
        'rand': rand_dataloader,
    }

    model, layer_names, activation_buffer, hook_handles = add_hooks_preact_resnet18(model, config)

    stats = {}
    try:
        for key, loader in LC_loaders.items():
            # Count neurons whose hyperplanes intersect the hulls, using the
            # activations captured by the hooks.
            with torch.no_grad():
                n_inters, _ = get_intersections_for_hulls(
                    loader,
                    r=config.r_frame,
                    nsamples=config.approx_n,
                    n_frame=config.n_frame,
                    model=model,
                    layer_names=layer_names,
                    activation_buffer=activation_buffer,
                )
            stats[key + '_LC'] = n_inters.cpu()
    finally:
        # Always detach the hooks; otherwise a failure above would leave them
        # registered on the caller's model and affect every later forward pass.
        for hook_handle in hook_handles:
            hook_handle.remove()

    return stats

def _as_plottable(value):
    """Return ``value`` in a form matplotlib can consume.

    Torch tensors are detached from the autograd graph and copied to host
    memory as NumPy arrays; any other object is returned unchanged.
    """
    if isinstance(value, torch.Tensor):
        return value.detach().cpu().numpy()
    return value


def plot_errors(test_model, true_model, ax, meshgrid):
    """Scatter-plot the pointwise absolute error between two models on ``meshgrid``.

    :param test_model: callable whose predictions are evaluated.
    :param true_model: callable producing the reference outputs.
    :param ax: matplotlib Axes (anything with a ``scatter(x, y)`` method).
    :param meshgrid: inputs at which both models are evaluated; also used as the
        x-coordinates of the scatter. NOTE(review): presumably 1-D inputs, so that
        ``meshgrid`` and the error have compatible shapes for plotting. Confirm.
    """
    # Gradients are not needed for plotting. Without no_grad, the error of a
    # model with trainable parameters has requires_grad=True, and matplotlib's
    # NumPy conversion raises on it.
    with torch.no_grad():
        true_output = true_model(meshgrid)
        predicted_output = test_model(meshgrid)

    error = (true_output - predicted_output).abs()
    # matplotlib cannot read CUDA tensors or tensors attached to the graph.
    ax.scatter(_as_plottable(meshgrid), _as_plottable(error))