import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import copy


def quantize_model(model, gamma=1.0, num_of_code_words=8, dithered=False):
    """
    Quantizes the weights of a given PyTorch model in place.

    Without dithering, a uniform mid-rise quantizer with ``num_of_code_words``
    reconstruction levels spanning ``[-gamma, gamma]`` is applied. If
    'dithered' is True, subtractive-dither mid-tread quantization is used:
    uniform dither is added, the result is rounded to a multiple of the step,
    clipped, and the same dither is subtracted again.

    Parameters:
    model (torch.nn.Module): The PyTorch model whose weights will be quantized.
        Its parameters are overwritten in place.
    gamma (float): Dynamic range for quantization; must be positive.
    num_of_code_words (int): Number of quantization levels; must be >= 1.
    dithered (bool): Flag to enable dithered quantization.

    Returns:
    torch.nn.Module: The same ``model`` object, now holding quantized weights.
    Each parameter keeps its original dtype.

    Raises:
    ValueError: If ``gamma`` is not positive or ``num_of_code_words`` < 1.
    """
    if gamma <= 0:
        raise ValueError(f"gamma must be positive, got {gamma}")
    if num_of_code_words < 1:
        raise ValueError(f"num_of_code_words must be >= 1, got {num_of_code_words}")

    # Number of levels L. Use it directly instead of 2 ** log2(L), which can
    # round off for non-powers of two.
    num_levels = float(num_of_code_words)

    # Step size, and the outermost mid-rise reconstruction level.
    delta = 2 * gamma / num_levels
    edge = gamma - delta / 2

    def quantization(x):
        """Quantize one tensor; the result has the same shape and dtype as ``x``."""
        if dithered:
            # Dither uniformly distributed in [-delta/2, delta/2].
            dither = torch.zeros_like(x, dtype=torch.float32).uniform_(-delta / 2, delta / 2)

            # Mid-tread quantization of the dithered signal.
            quantized_signal = delta * torch.round((x + dither) / delta)

            # NOTE(review): edge = gamma - delta/2 is not a multiple of delta when L
            # is even, so saturated values land between grid points. Confirm intended.
            quantized_signal = torch.clamp(quantized_signal, -edge, edge)

            # Subtractive dither: remove the same dither after quantization.
            quantized_signal = quantized_signal - dither
        else:
            # Mid-rise quantization: levels at -gamma + delta/2 + k*delta.
            quantized_signal = torch.floor((x + gamma) / delta) * delta - gamma + delta / 2
            quantized_signal = torch.clamp(quantized_signal, -edge, edge)

        # The float32 dither promotes half/bfloat16 inputs; cast back so the
        # parameter dtype does not silently change.
        return quantized_signal.to(x.dtype)

    with torch.no_grad():
        for param in model.parameters():
            param.data = quantization(param.data)

    return model


def plot_model_params(model, bins=1090):
    """
    Plots a histogram of all trainable parameters of a model (before quantization).

    Parameters:
    model (torch.nn.Module): The PyTorch model whose parameters will be plotted.
        Only parameters with ``requires_grad=True`` are included.
    bins (int): Number of histogram bins (default 1090, as before).

    Raises:
    ValueError: If the model has no trainable parameters.
    """
    # Upcast to float32 first: bfloat16 tensors cannot be converted to NumPy directly.
    params = [
        param.detach().float().cpu().numpy().ravel()
        for param in model.parameters()
        if param.requires_grad
    ]
    if not params:
        raise ValueError("model has no trainable parameters to plot")

    # Create a single array of all parameters
    all_params = np.concatenate(params)

    # Plot histogram
    plt.figure(figsize=(10, 6))
    plt.hist(all_params, bins=bins, color='blue', edgecolor='black', alpha=0.7)
    plt.title("Distribution of Model Parameters Before Quantization")
    plt.xlabel("Parameter Values")
    plt.ylabel("Frequency")
    plt.grid(True)
    plt.show()


def calculate_std_of_model_parameters(model):
    """
    Calculate the standard deviation of the trainable parameters of a given model.

    All parameters with ``requires_grad=True`` are flattened, concatenated and
    passed to ``torch.std`` (default, unbiased estimator).

    Args:
    - model (torch.nn.Module): The model whose parameters' standard deviation is to be calculated.

    Returns:
    - float: The standard deviation of the model parameters.

    Raises:
    - RuntimeError: Propagated from ``torch.cat`` when the model has no trainable parameters.
    """
    with torch.no_grad():
        # reshape (not view) also works for non-contiguous parameters, e.g.
        # channels_last conv weights. detach avoids building an autograd graph.
        flat_params = [
            param.detach().reshape(-1)
            for param in model.parameters()
            if param.requires_grad
        ]
        all_parameters = torch.cat(flat_params)
        return torch.std(all_parameters).item()


def lattice_quantize_model(model, gamma=1.0, num_of_code_words=8, dithered=False, device='cpu'):
    """
    Quantizes the weights of a given PyTorch model using 2-D hexagonal lattice quantization.

    Each parameter is flattened, zero-padded to an even length, and read as
    consecutive pairs of 2-D points. Each point is mapped to lattice
    coordinates with the inverse generating matrix, rounded to a grid of step
    ``delta`` (clipped to ``[-edge, edge]``) and mapped back. If 'dithered' is
    True, dithered lattice quantization is performed: a dither drawn uniformly
    in lattice coordinates is transformed to the original space, added before
    quantization and subtracted afterwards. This follows the logic of the
    LatticeQuantization class.

    Parameters:
    model (torch.nn.Module): The PyTorch model whose weights will be quantized (left unmodified)
    gamma (float): Dynamic range for quantization; must be positive
    num_of_code_words (int): Number of 2-D codewords; each lattice coordinate gets
        sqrt(num_of_code_words) levels. Must be >= 1.
    dithered (bool): Flag to enable dithering
    device (str or torch.device): Device to perform computations on

    Returns:
    torch.nn.Module: New model instance with quantized weights. Every parameter
    keeps its original shape, dtype and device.

    Raises:
    ValueError: If ``gamma`` is not positive or ``num_of_code_words`` < 1.
    """
    if gamma <= 0:
        raise ValueError(f"gamma must be positive, got {gamma}")
    if num_of_code_words < 1:
        raise ValueError(f"num_of_code_words must be >= 1, got {num_of_code_words}")

    # Create a deep copy of the model to avoid modifying the original
    quantized_model = copy.deepcopy(model)

    # Two weights share one codeword, so each coordinate gets sqrt(N) levels
    # (equivalently 2 ** (0.5 * log2(N))).
    levels_per_dim = float(np.sqrt(num_of_code_words))

    # Calculate step size
    delta = 2 * gamma / levels_per_dim
    edge = gamma - delta / 2

    # Hexagonal lattice generating matrix, kept in float64 and cast per
    # parameter so float64 models do not hit a matmul dtype mismatch.
    hex_mat = np.array([[np.sqrt(3) / 2, 0], [1 / 2, 1]])
    # NOTE(review): dividing by det (not sqrt(det)) gives det(gen_mat) = 2/sqrt(3),
    # not 1. Confirm this normalisation is intended.
    gen_np = hex_mat / np.linalg.det(hex_mat)
    gen_inv_np = np.linalg.inv(gen_np)

    def lattice_quantization(x):
        """Lattice-quantize one tensor; the result has x's shape, dtype and device."""
        # Half/bfloat16 are computed in float32; float64 stays float64.
        compute_dtype = torch.promote_types(x.dtype, torch.float32)
        gen_mat = torch.as_tensor(gen_np, dtype=compute_dtype, device=x.device)
        gen_mat_inv = torch.as_tensor(gen_inv_np, dtype=compute_dtype, device=x.device)

        numel = x.numel()
        # reshape (not view) also accepts non-contiguous parameters.
        x_flat = x.reshape(-1).to(compute_dtype)

        # Pad with one zero if necessary to make an even number of elements
        if numel % 2 != 0:
            x_flat = torch.cat([x_flat, x_flat.new_zeros(1)])

        # Pairs as columns: shape [2, N_pairs]
        x_pairs = x_flat.view(-1, 2).t()

        if dithered:
            # Uniform dither in lattice coordinates, moved to the original space
            dither_orthog = torch.zeros_like(x_pairs).uniform_(-delta / 2, delta / 2)
            dither = torch.matmul(gen_mat, dither_orthog)
            x_pairs = x_pairs + dither

        # Quantize in lattice coordinates, clip, then map back
        orthogonal_space = torch.matmul(gen_mat_inv, x_pairs)
        q_orthogonal_space = delta * torch.round(orthogonal_space / delta)
        q_orthogonal_space = torch.clamp(q_orthogonal_space, -edge, edge)
        x_quantized_pairs = torch.matmul(gen_mat, q_orthogonal_space)

        if dithered:
            # Subtract dither from quantized data in original space
            x_quantized_pairs = x_quantized_pairs - dither

        # Undo the pairing, drop the padding element, restore shape and dtype
        x_quantized_flat = x_quantized_pairs.t().reshape(-1)[:numel]
        return x_quantized_flat.reshape(x.shape).to(x.dtype)

    # Apply quantization to each parameter in the model
    with torch.no_grad():
        for param in quantized_model.parameters():
            param_quantized = lattice_quantization(param.data.to(device))
            param.data = param_quantized.to(param.data.device)

    return quantized_model


def prune_model_topk_weights(model, K):
    """
    Prunes the weights of a given PyTorch model by keeping only the K largest weights
    (by absolute value, across all parameters) and setting all other weights to zero.

    Exactly K weights are kept even when several weights share the K-th largest
    absolute value; which tied weights survive is decided by ``torch.topk``.

    Parameters:
    model (torch.nn.Module): The PyTorch model to be pruned (left unmodified).
    K (int): The number of largest weights to retain. ``K == 0`` zeroes every weight.

    Returns:
    torch.nn.Module: A new model instance with pruned weights.

    Raises:
    ValueError: If ``K`` is negative.
    """
    if K < 0:
        raise ValueError(f"K must be non-negative, got {K}")

    # Create a deep copy of the model to avoid modifying the original
    pruned_model = copy.deepcopy(model)
    params = list(pruned_model.parameters())

    # Check if K is less than the total number of weights
    total_weights = sum(param.numel() for param in params)
    if K >= total_weights:
        print("K is greater than or equal to the total number of weights. No pruning performed.")
        return pruned_model

    with torch.no_grad():
        # reshape (not view) also accepts non-contiguous parameters.
        abs_weights = torch.cat([param.detach().reshape(-1).abs() for param in params])

        # Mark exactly the K largest entries. A ">= threshold" test would keep
        # extra weights on ties.
        keep = torch.zeros(total_weights, dtype=torch.bool, device=abs_weights.device)
        keep[torch.topk(abs_weights, K, largest=True).indices] = True

        # Slice the global mask back per parameter (same flattening order as above)
        start = 0
        for param in params:
            param_numel = param.numel()
            param_mask = keep[start:start + param_numel].view(param.size()).to(param.device)
            param.data.mul_(param_mask)
            start += param_numel

    return pruned_model


def prune_model_random_weights(model, K, generator=None):
    """
    Prunes the weights of a given PyTorch model by randomly keeping K weights
    (chosen uniformly across all parameters) and setting all other weights to zero.

    Parameters:
    model (torch.nn.Module): The PyTorch model to be pruned (left unmodified).
    K (int): The number of weights to retain.
    generator (torch.Generator, optional): RNG passed to ``torch.randperm``
        for reproducible selection. The default ``None`` uses the global RNG,
        as before.

    Returns:
    torch.nn.Module: A new model instance with pruned weights.

    Raises:
    ValueError: If ``K`` is negative. Previously a negative K silently sliced
        from the end of the permutation.
    """
    if K < 0:
        raise ValueError(f"K must be non-negative, got {K}")

    # Create a deep copy of the model to avoid modifying the original
    pruned_model = copy.deepcopy(model)
    params = list(pruned_model.parameters())

    # Only the count is needed, so avoid concatenating a copy of every weight
    total_weights = sum(param.numel() for param in params)
    if K >= total_weights:
        print("K is greater than or equal to the total number of weights. No pruning performed.")
        return pruned_model

    with torch.no_grad():
        # Randomly select K flat indices to keep
        keep_indices = torch.randperm(total_weights, generator=generator)[:K]
        mask = torch.zeros(total_weights, dtype=torch.bool)
        mask[keep_indices] = True

        # Apply the (CPU) mask per parameter, moved to that parameter's device
        start = 0
        for param in params:
            param_numel = param.numel()
            param_mask = mask[start:start + param_numel].view(param.size()).to(param.device)
            param.data.mul_(param_mask)
            start += param_numel

    return pruned_model

def calculate_SNR(original_unlearned_model, quantized_model):
    '''
    Calculates the SNR (in dB) of the quantized model relative to the original unlearned model.

    SNR = 10 * log10(var(w) / var(w - w_q)), where w and w_q are the
    concatenated trainable (requires_grad=True) parameters of the two models,
    taken in ``parameters()`` order. ``torch.var`` uses its default unbiased
    estimator.

    Returns:
    float: The SNR in dB, or ``float('inf')`` when the two parameter vectors
    have zero difference variance (e.g. identical models). Previously this
    case raised ZeroDivisionError.
    '''
    with torch.no_grad():
        # reshape (not view) also accepts non-contiguous parameters; detach
        # avoids building an autograd graph for a pure metric.
        all_parameters = torch.cat([
            param.detach().reshape(-1)
            for param in original_unlearned_model.parameters()
            if param.requires_grad
        ])
        all_quantized_parameters = torch.cat([
            param.detach().reshape(-1)
            for param in quantized_model.parameters()
            if param.requires_grad
        ])

        signal_power = torch.var(all_parameters).item()
        noise_power = torch.var(all_parameters - all_quantized_parameters).item()

    if noise_power == 0.0:
        return float('inf')
    return 10.0 * np.log10(signal_power / noise_power)