import torch
import torch.nn.functional as F
from .shared import default_conv_kwargs, hyp
from torch import nn

#############################################
#            Network Components             #
#############################################

# BatchNorm2d with a fixed identity start and per-parameter trainability switches
class BatchNorm(nn.BatchNorm2d):
    """``nn.BatchNorm2d`` configured for this network.

    The layer is built with ``track_running_stats=False``, so no running
    mean/variance is kept. The affine scale starts at 1 and the shift at 0.

    Args:
        num_features: Number of channels, forwarded to ``nn.BatchNorm2d``.
        eps: Value forwarded as the ``eps`` argument of ``nn.BatchNorm2d``.
        momentum: Forwarded to ``nn.BatchNorm2d``. The default is read from
            ``hyp["net"]["batch_norm_momentum"]`` once, when the class is defined.
        weight: Whether the affine scale receives gradients.
        bias: Whether the affine shift receives gradients.
    """

    def __init__(
        self,
        num_features,
        eps=1e-12,
        momentum=hyp["net"]["batch_norm_momentum"],
        weight=False,
        bias=True,
    ):
        super().__init__(
            num_features,
            eps=eps,
            momentum=momentum,
            track_running_stats=False,
        )
        # Start from the identity affine transform (scale 1, shift 0).
        nn.init.ones_(self.weight)
        nn.init.zeros_(self.bias)
        # The flags only control trainability; both parameters always exist.
        self.weight.requires_grad_(weight)
        self.bias.requires_grad_(bias)


# Conv2d that falls back to the project-wide default keyword arguments
class Conv(nn.Conv2d):
    """``nn.Conv2d`` whose keyword arguments default to ``default_conv_kwargs``.

    Keyword arguments given by the caller override the shared defaults;
    positional arguments are passed through untouched. The merged dictionary
    is not kept on the instance.
    """

    def __init__(self, *args, **kwargs):
        merged_kwargs = dict(default_conv_kwargs)
        merged_kwargs.update(kwargs)  # caller-supplied values win
        super().__init__(*args, **merged_kwargs)


class Linear(nn.Linear):
    """``nn.Linear`` with an optional constant multiplier on the weight matrix.

    Args:
        *args: Positional arguments forwarded to ``nn.Linear``.
        temperature: Optional scalar. When given, it is stored as a float16
            buffer named ``temperature`` and the weight is multiplied by it on
            every forward pass. When omitted, ``self.temperature`` is ``None``
            and the weight is used as is.
        **kwargs: Keyword arguments forwarded to ``nn.Linear``.
    """

    def __init__(self, *args, temperature=None, **kwargs):
        super().__init__(*args, **kwargs)
        # Only register a buffer when a temperature is actually used.
        if temperature is not None:
            self.register_buffer("temperature", torch.tensor(temperature, dtype=torch.float16))
        else:
            self.temperature = None

    def forward(self, x):
        """Apply the (optionally temperature-scaled) linear map to ``x``."""
        if self.temperature is not None:
            weight = self.weight * self.temperature
        else:
            weight = self.weight
        # Pass the bias through: it is ``None`` when the layer was built with
        # ``bias=False``, and otherwise it must not be silently dropped.
        # The temperature scales the weight only, not the bias.
        return F.linear(x, weight, self.bias)


# Conv block: conv -> max-pool -> norm -> GELU -> conv -> norm -> (optional skip) -> GELU
class ConvGroup(nn.Module):
    """Two convolutions with a 2x max-pool after the first one.

    Pooling and activations are applied functionally, so the only submodules
    are ``conv1``, ``conv2``, ``norm1`` and ``norm2``.

    Args:
        channels_in: Input channel count of ``conv1``.
        channels_out: Output channel count of both convolutions.
        use_residual: Request a skip connection. It is only kept enabled when
            ``channels_in == channels_out``.
    """

    def __init__(self, channels_in, channels_out, use_residual=True):
        super().__init__()
        self.channels_in = channels_in
        self.channels_out = channels_out
        # A skip connection is only considered when the channel counts agree.
        self.use_residual = use_residual and (channels_in == channels_out)

        self.conv1 = Conv(channels_in, channels_out)
        self.conv2 = Conv(channels_out, channels_out)

        self.norm1 = BatchNorm(channels_out)
        self.norm2 = BatchNorm(channels_out)

    def forward(self, x):
        """Run the block on ``x`` and return the activated output."""
        skip = x

        out = self.conv1(x)
        out = F.gelu(self.norm1(F.max_pool2d(out, 2)))
        out = self.norm2(self.conv2(out))

        # The skip is added only if the input still has the output's spatial size.
        # NOTE(review): max_pool2d(…, 2) shrinks the spatial dims, so unless the
        # convolutions change the size this condition looks unreachable - confirm
        # against default_conv_kwargs.
        if self.use_residual and skip.shape[2:] == out.shape[2:]:
            missing_channels = out.shape[1] - skip.shape[1]
            if missing_channels > 0:
                # Zero-pad the skip tensor along the channel axis to match.
                zeros = skip.new_zeros(
                    (skip.shape[0], missing_channels, skip.shape[2], skip.shape[3]),
                )
                skip = torch.cat([skip, zeros], dim=1)
            out = out + skip

        return F.gelu(out)


# Global max pooling over the two trailing (spatial) dimensions
def global_max_pooling(x):
    """Return the maximum of ``x`` over dims 2 and 3; those dims are removed."""
    return x.amax(dim=(2, 3))


#############################################
#          Init Helper Functions            #
#############################################

def get_patches(x, patch_shape=(3, 3), dtype=torch.float32):
    """Extract every stride-1 sliding patch from a batch of images.

    Args:
        x: 4-D tensor; dim 1 is treated as channels and dims 2/3 as the
            spatial axes the window slides over.
        patch_shape: ``(h, w)`` size of each patch.
        dtype: dtype of the returned tensor.

    Returns:
        Tensor of shape ``(num_patches, c, h, w)`` in ``dtype``. Patches are
        ordered by image, then by window position along dim 3, then along
        dim 2. An empty batch yields an empty ``(0, c, h, w)`` tensor.
        The result may share memory with ``x`` when no copy is needed.
    """
    c, (h, w) = x.shape[1], patch_shape

    # reshape(-1, ...) is ambiguous for zero elements, so answer directly.
    if x.shape[0] == 0:
        return x.new_empty((0, c, h, w), dtype=dtype)

    # unfold() only creates views, so the single copy happens in reshape().
    # Casting the (much smaller) input first avoids converting every
    # overlapping patch element separately; the values are identical.
    return (
        x.to(dtype)
        .unfold(2, h, 1).unfold(3, w, 1)
        .transpose(1, 3).reshape(-1, c, h, w)
    )


def get_whitening_parameters(patches, max_samples=10000):
    """Eigendecompose the sample covariance of flattened patches.

    Args:
        patches: Tensor of shape ``(n, c, h, w)``.
        max_samples: If ``n`` exceeds this, a random subset of this many
            patches is used instead (drawn with ``torch.randperm``).

    Returns:
        ``(eigenvalues, eigenvectors)`` in float32: eigenvalues with shape
        ``(c*h*w, 1, 1, 1)`` and eigenvectors with shape ``(c*h*w, c, h, w)``.
        Both are reversed relative to ``torch.linalg.eigh``'s output order.
    """
    # Random subsample to bound the cost of the covariance estimate.
    if patches.shape[0] > max_samples:
        keep = torch.randperm(patches.shape[0], device=patches.device)[:max_samples]
        patches = patches[keep]

    n, c, h, w = patches.shape
    dim = c * h * w

    # Work in float32 regardless of the input dtype.
    flat = patches.reshape(n, dim).to(torch.float32)

    step = min(1000, n)
    scatter = torch.zeros((dim, dim), device=patches.device, dtype=torch.float32)
    center = flat.mean(0, keepdim=True)

    # Accumulate the scatter matrix slice by slice.
    for start in range(0, n, step):
        centered = flat[start:start + step] - center
        scatter += centered.t() @ centered

    est_covariance = scatter / (n - 1)

    # Drop the large intermediates before the eigendecomposition.
    del flat, centered, scatter
    torch.cuda.empty_cache()

    eigenvalues, eigenvectors = torch.linalg.eigh(est_covariance, UPLO="U")

    # Each row of the transposed matrix is one eigenvector, reshaped to a filter.
    eigenvalues = eigenvalues.flip(0).view(-1, 1, 1, 1)
    eigenvectors = eigenvectors.t().reshape(dim, c, h, w).flip(0)
    return eigenvalues, eigenvectors


# Whitening initialization for a convolution layer
def init_whitening_conv(
    layer,
    train_set=None,
    num_examples=None,
    previous_block_data=None,
    pad_amount=None,
    freeze=True,
    whiten_splits=None,
):
    """Set ``layer``'s weights from patch statistics and return its output.

    Args:
        layer: Convolution whose ``weight`` is overwritten; its kernel size
            (``weight.shape[2:]``) is used as the patch shape.
        train_set: Image tensor to sample from. Only used when
            ``previous_block_data`` is not given.
        num_examples: Number of randomly chosen images to take from
            ``train_set``; ``None`` keeps all of them (in random order).
        previous_block_data: Data to whiten on directly. Takes precedence over
            ``train_set``.
        pad_amount: Border width cropped from both spatial dims of the sampled
            images. ``None`` or a non-positive value means no cropping.
        freeze: Forwarded to ``set_whitening_conv``.
        whiten_splits: Number of images per chunk when estimating the
            statistics. ``None`` or a non-positive value falls back to 500.

    Returns:
        The detached output of ``layer`` applied to the data used for
        whitening.

    Raises:
        ValueError: If neither ``train_set`` nor ``previous_block_data`` is given.
    """
    if previous_block_data is None:
        if train_set is None:
            raise ValueError(
                "init_whitening_conv needs either `train_set` or `previous_block_data`"
            )
        indices = torch.randperm(train_set.shape[0], device=train_set.device)[:num_examples]
        # `pad_amount` defaults to None, which must behave like "no crop".
        if pad_amount is not None and pad_amount > 0:
            crop = slice(pad_amount, -pad_amount)
        else:
            crop = slice(None)
        previous_block_data = train_set[indices, :, crop, crop]

    if whiten_splits is None or whiten_splits <= 0:
        whiten_splits = 500

    patch_shape = layer.weight.data.shape[2:]
    eigenvalue_chunks, eigenvector_chunks = [], []

    for data_split in previous_block_data.split(whiten_splits, dim=0):
        # Patches are extracted in get_patches' default dtype (float32).
        patches = get_patches(data_split, patch_shape=patch_shape)
        eigenvalues, eigenvectors = get_whitening_parameters(patches, max_samples=5000)
        eigenvalue_chunks.append(eigenvalues)
        eigenvector_chunks.append(eigenvectors)

        # Release the per-chunk tensors before the next chunk is processed.
        del patches, eigenvalues, eigenvectors
        torch.cuda.empty_cache()

    # Average the per-chunk decompositions.
    # NOTE(review): eigenvectors are only defined up to sign, so averaging them
    # across chunks may partially cancel - confirm this is intended.
    eigenvalues = torch.stack(eigenvalue_chunks, dim=0).mean(0)
    eigenvectors = torch.stack(eigenvector_chunks, dim=0).mean(0)

    del eigenvalue_chunks, eigenvector_chunks
    torch.cuda.empty_cache()

    set_whitening_conv(
        layer,
        eigenvalues.to(dtype=layer.weight.dtype),
        eigenvectors.to(dtype=layer.weight.dtype),
        freeze=freeze,
    )

    # Run the freshly initialised layer over the data in chunks of 512 images.
    output_chunks = []
    for chunk in previous_block_data.split(512, dim=0):
        chunk = chunk.to(dtype=layer.weight.dtype)
        output_chunks.append(layer(chunk).detach())
        del chunk
        torch.cuda.empty_cache()

    data = torch.cat(output_chunks, dim=0)

    del output_chunks
    torch.cuda.empty_cache()

    return data


def set_whitening_conv(conv_layer, eigenvalues, eigenvectors, eps=1e-2, freeze=True):
    """Overwrite ``conv_layer.weight`` with scaled eigenvector filters.

    Each eigenvector is divided by ``sqrt(eigenvalue + eps)``, and the first
    ``out_channels`` resulting filters are copied into the layer.

    Args:
        conv_layer: Convolution whose weight is replaced in place.
        eigenvalues: Tensor broadcastable against ``eigenvectors``, ordered
            from largest to smallest (the order ``get_whitening_parameters``
            returns).
        eigenvectors: Filters of shape ``(k, c, h, w)`` in the same order.
        eps: Added to the eigenvalues before the square root.
        freeze: If true, the weight is marked as not requiring gradients.
    """
    out_channels = conv_layer.weight.data.shape[0]

    whitening_filters = eigenvectors / torch.sqrt(eigenvalues + eps)

    # The inputs are sorted most-significant first, so the "top" filters are the
    # leading rows. (A trailing slice would keep the smallest-eigenvalue
    # components whenever the layer has fewer channels than eigenvectors.)
    top_filters = whitening_filters[:out_channels, :, :, :]

    with torch.no_grad():
        conv_layer.weight.copy_(top_filters)

    if freeze:
        conv_layer.weight.requires_grad = False


#############################################
#            Network Definition             #
#############################################


class SpeedyConvNet(nn.Module):
    """Network assembled from a dictionary of pre-built blocks.

    ``network_dict`` must provide ``["initial_block"]["whiten"]``,
    ``"conv_group_1"``, ``"conv_group_2"``, ``"conv_group_3"`` and ``"linear"``.
    Outside training mode, and while ``enable_eval_tta`` is true, the batch is
    evaluated together with its mirror image along the last axis and the two
    predictions are averaged.
    """

    def __init__(self, network_dict):
        super().__init__()
        self.net_dict = network_dict
        # Cache for the reversed index vector used by the flip; not saved in state_dict.
        self.register_buffer("flipped_indices", None, persistent=False)
        # Set to False to turn off eval-time flipping (avoids 2N x 2N influence scores).
        self.enable_eval_tta = True

    def forward(self, x):
        """Return the output of the ``linear`` block for input batch ``x``."""
        use_tta = (not self.training) and self.enable_eval_tta

        if use_tta:
            width = x.shape[-1]
            cached = self.flipped_indices
            # (Re)build the reversed index vector when missing or of the wrong length.
            if cached is None or cached.shape[0] != width:
                self.flipped_indices = torch.arange(width - 1, -1, -1, device=x.device)
            mirrored = x.index_select(-1, self.flipped_indices)
            x = torch.cat((x, mirrored), dim=0)

        blocks = self.net_dict

        # Whitening convolution, then GELU of the response and of its negation,
        # stacked along the channel axis (this doubles the channel count).
        x = blocks["initial_block"]["whiten"](x)
        x = torch.cat([F.gelu(x), F.gelu(-x)], dim=1)

        for name in ("conv_group_1", "conv_group_2", "conv_group_3"):
            x = blocks[name](x)
        x = global_max_pooling(x)
        x = blocks["linear"](x)

        if use_tta:
            # First half: original inputs; second half: mirrored copies.
            half = x.shape[0] // 2
            x = 0.5 * x[:half] + 0.5 * x[half:]

        return x


def make_net(images_train: torch.Tensor, device: torch.device, base_depth: int) -> SpeedyConvNet:
    """Build, place and initialise a ``SpeedyConvNet``.

    The network is moved to ``device``, switched to channels-last memory
    format, put in training mode and converted to half precision. The whitening
    convolution is then initialised from ``images_train`` and the first
    convolution of every conv group is biased towards the identity.

    Reads ``hyp["net"]["whitening"]["kernel_size"]``,
    ``hyp["net"]["whitening"]["num_examples"]``, ``hyp["net"]["pad_amount"]`` and
    ``hyp["opt"]["scaling_factor"]``.

    Args:
        images_train: Training images used to estimate the whitening filters.
        device: Device the network is moved to.
        base_depth: Channel count of the first conv group; later groups are
            scaled from it.

    Returns:
        The initialised network.
    """
    scaler = 1.6  # Reduced from 2.0
    depths = {
        "init": round(scaler**-1 * base_depth),
        "block1": round(scaler**0 * base_depth),
        "block2": round(scaler**2 * base_depth),
        "block3": round(scaler**3 * base_depth),
        "num_classes": 10,
    }

    # The forward pass concatenates GELU(x) and GELU(-x) after the whitening
    # conv, so the first conv group sees twice the whitening conv's channels.
    kernel_size = hyp["net"]["whitening"]["kernel_size"]
    whiten_conv_depth = 3 * kernel_size ** 2
    whiten_output_depth = 2 * whiten_conv_depth

    network_dict = nn.ModuleDict(
        {
            "initial_block": nn.ModuleDict(
                {
                    "whiten": Conv(
                        3,
                        whiten_conv_depth,
                        kernel_size=kernel_size,
                        padding=0,
                    ),
                },
            ),
            "conv_group_1": ConvGroup(whiten_output_depth, depths["block1"], use_residual=False),
            "conv_group_2": ConvGroup(depths["block1"], depths["block2"], use_residual=True),
            "conv_group_3": ConvGroup(depths["block2"], depths["block3"], use_residual=True),
            "linear": Linear(
                depths["block3"],
                depths["num_classes"],
                bias=False,
                temperature=hyp["opt"]["scaling_factor"],
            ),
        },
    )

    net = SpeedyConvNet(network_dict)
    net = net.to(device)
    net = net.to(memory_format=torch.channels_last)
    net.train()
    net.half()

    with torch.no_grad():
        # Draw the pool of images the whitening statistics are estimated from.
        num_examples = hyp["net"]["whitening"]["num_examples"]
        random_indices = torch.randperm(
            images_train.shape[0],
            device=images_train.device,
        )[:num_examples]

        init_whitening_conv(
            net.net_dict["initial_block"]["whiten"],
            images_train.index_select(0, random_indices),
            num_examples=min(5000, num_examples),
            pad_amount=hyp["net"]["pad_amount"],
            whiten_splits=500,
        )

        for layer_name, group in net.net_dict.items():
            if "conv_group" not in layer_name:
                continue

            weight = group.conv1.weight
            std_orig, mean_orig = torch.std_mean(weight.data)
            channels_out, channels_in, kernel_h, kernel_w = weight.shape

            # Dirac (identity) bias: add 1 only at the centre tap of the kernel
            # for the channels that input and output have in common. Adding it
            # to every tap would give a box filter rather than an identity.
            min_channels = min(channels_in, channels_out)
            weight.data[:min_channels, :min_channels, kernel_h // 2, kernel_w // 2] += torch.eye(
                min_channels,
                device=weight.device,
            )

            # Renormalise so the weights keep their original mean and std.
            std_new, mean_new = torch.std_mean(weight.data)
            weight.data.sub_(mean_new).div_(std_new).mul_(std_orig).add_(mean_orig)

            torch.nn.init.kaiming_normal_(group.conv2.weight, mode="fan_out", nonlinearity="relu")

    torch.cuda.empty_cache()
    return net
