import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
import os
import pandas as pd
import numpy as np

def unflatten_weights(flat_weights_np, model_layout, device):
    """Split a flat weight vector into per-layer kernels and biases.

    Each entry of ``model_layout`` is a dict with keys ``'varname'``, ``'shape'``,
    ``'start_idx'`` and ``'end_idx'``. The slice
    ``flat_weights_np[start_idx:end_idx]`` is reshaped to ``shape`` and moved to
    ``device``. Variables are then grouped by name prefix (everything before the
    last ``'/'``). The trailing name component, with any ``':N'`` suffix stripped,
    decides whether a variable is a ``kernel`` or a ``bias``.

    Kernel layout conversion (assumes TF-style storage -- confirm against the
    exporter that produced the layout):
      * 4-D ``[H, W, Cin, Cout]`` -> ``[Cout, Cin, H, W]``
      * 2-D ``[Cin, Cout]``       -> ``[Cout, Cin]``
    Kernels of any other rank are returned unchanged.

    Args:
        flat_weights_np: 1-D array-like holding all parameters back to back.
        model_layout: iterable of layout dicts as described above.
        device: torch device for the returned tensors.

    Returns:
        ``(raw_weights, raw_biases)``: parallel lists with one entry per prefix
        that has a kernel, in first-appearance order. ``raw_biases[i]`` is
        ``None`` when that layer has no bias. Prefixes that have only a bias, or
        only unrecognised parameter types, are skipped (with a printed warning).

    A slice whose size does not match its ``shape`` is replaced by zeros (with a
    printed error) instead of raising. Returned tensors never share memory with
    ``flat_weights_np``.
    """
    # Accept any array-like, and materialise the layout because it is walked twice
    # (a generator would otherwise be exhausted after the first loop).
    flat_weights_np = np.asarray(flat_weights_np)
    model_layout = list(model_layout)

    param_dict = {}
    print("\n--- Un-flattening weights ---")
    for layer_info in model_layout:
        varname = layer_info['varname']
        shape = layer_info['shape']
        start_idx = layer_info['start_idx']
        end_idx = layer_info['end_idx']

        param_np = flat_weights_np[start_idx:end_idx]
        expected_elements = np.prod(shape)
        if param_np.size != expected_elements:
            print(f"Warning: Size mismatch for {varname}. Expected {expected_elements}, got {param_np.size}. Reshaping might fail.")
        try:
            # torch.from_numpy shares memory with the numpy buffer, and .to(device)
            # is a no-op when the tensor is already there. copy=True guarantees the
            # result is independent of the caller's array.
            param_tensor = torch.from_numpy(param_np).view(shape).to(device=device, copy=True)
            param_dict[varname] = param_tensor
            print(f"  Extracted {varname}, shape: {param_tensor.shape}")
        except RuntimeError as e:
            # Best effort: keep going with a zero tensor so one bad entry does not
            # abort the whole conversion.
            print(f"ERROR reshaping {varname} with shape {shape} from slice size {param_np.size}: {e}")
            param_dict[varname] = torch.zeros(shape, device=device)

    layer_kernels = {}
    layer_biases = {}
    ordered_prefixes = []
    processed_prefixes = set()

    for layer_info in model_layout:
        varname = layer_info['varname']
        parts = varname.rsplit('/', 1)
        if len(parts) == 2:
            prefix, param_type_raw = parts
            # Strip a TF-style output index such as ':0'.
            param_type = param_type_raw.split(':')[0]
        else:
            prefix = varname
            param_type = 'unknown'
            if 'bias' in varname.lower():
                param_type = 'bias'
            elif 'kernel' in varname.lower() or 'weight' in varname.lower():
                param_type = 'kernel'
            print(f"Warning: Could not determine prefix/type reliably for {varname}. Assuming prefix='{prefix}', type='{param_type}'")

        if prefix not in processed_prefixes:
            ordered_prefixes.append(prefix)
            processed_prefixes.add(prefix)

        # Every varname was stored by the first loop (real tensor or zero fallback).
        tensor = param_dict[varname]
        if param_type == 'kernel':
            if tensor.ndim == 4:  # Conv kernel [H, W, Cin, Cout] -> [Cout, Cin, H, W]
                tensor = tensor.permute(3, 2, 0, 1).contiguous()
                print(f"    Transposed Conv Kernel {prefix}: {tensor.shape}")
            elif tensor.ndim == 2:  # Dense kernel [Cin, Cout] -> [Cout, Cin]
                tensor = tensor.T.contiguous()
                print(f"    Transposed Dense Kernel {prefix}: {tensor.shape}")
            layer_kernels[prefix] = tensor
        elif param_type == 'bias':
            layer_biases[prefix] = tensor
        else:
            print(f"Warning: Unrecognized parameter type '{param_type}' for {varname}. Ignoring.")

    raw_weights = []
    raw_biases = []
    for prefix in ordered_prefixes:
        if prefix in layer_kernels:
            raw_weights.append(layer_kernels[prefix])
            bias = layer_biases.get(prefix, None)
            raw_biases.append(bias)
            if bias is None:
                print(f"Warning: Kernel found for {prefix} but no bias.")
        elif prefix in layer_biases:
            print(f"Warning: Bias found for {prefix} but no kernel. Bias will be ignored unless handled specifically.")

    return raw_weights, raw_biases

class DifferentiableCNN(nn.Module):
    """
    A small CNN whose weights and biases are supplied on every call instead of
    being stored as module parameters, so gradients flow back into whatever
    produced them.

    Fixed architecture (set in ``__init__``):
        conv(1->16, 3x3, stride 2, no padding) -> activation
        -> conv(16->16, 3x3, stride 2, no padding) -> activation
        -> conv(16->16, 3x3, stride 2, no padding) -> activation
        -> global average pool -> linear(16->10)

    Parameters are sliced out of batched node/edge feature tensors (see
    ``_parse_params``). ``forward`` runs the network once per parameter set with
    ``torch.func.vmap``.

    Args:
        activation: callable applied after each convolution (default ``F.relu``).
    """
    def __init__(self, activation=F.relu):
        super().__init__()
        # Default activation function is ReLU, can be changed at initialization
        self.activation = activation

        # Define CNN architecture components
        self.c1_in_channels = 1
        self.c1_out_channels = 16
        self.c1_kernel_size_hw = (3, 3)
        self.c1_stride = 2
        self.c1_padding = 0

        self.c2_in_channels = 16
        self.c2_out_channels = 16
        self.c2_kernel_size_hw = (3, 3)
        self.c2_stride = 2
        self.c2_padding = 0

        self.c3_in_channels = 16
        self.c3_out_channels = 16
        self.c3_kernel_size_hw = (3, 3)
        self.c3_stride = 2
        self.c3_padding = 0

        self.fc_in_features = 16
        self.fc_out_features = 10

        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)

    def _parse_params(self, node_features, edge_features, input_dim_nodes_skip=1):
        """Slice batched CNN weights and biases out of node and edge features.

        Both tensors are indexed as ``[batch, item, feature]``.

        Biases are read from feature column 0 of ``node_features``. After skipping
        the first ``input_dim_nodes_skip`` nodes, consecutive nodes hold the conv1,
        conv2, conv3 and fc biases, in that order.

        Weights are read from consecutive runs of ``edge_features``:
          * Each conv layer uses ``in_channels * out_channels`` edges. The first
            ``kh * kw`` features of each edge are one kernel. Edges are ordered
            input-channel-major (reshaped as ``(in, out, kh, kw)``) and then
            permuted to the ``(out, in, kh, kw)`` layout ``F.conv2d`` expects.
          * The fc layer uses ``fc_in * fc_out`` edges, reading feature column 0.
            They are ordered input-feature-major and permuted to ``(out, in)``
            for ``F.linear``.

        Any further nodes, edges or feature columns are ignored.

        Returns:
            ``(w_c1, b_conv1, w_c2, b_conv2, w_c3, b_conv3, w_fc, b_fc)``, each
            with the batch on dim 0.
        """
        batch_size = node_features.shape[0]

        # --- Parse Biases ---
        current_bias_offset = input_dim_nodes_skip
        b_conv1 = node_features[:, current_bias_offset : current_bias_offset + self.c1_out_channels, 0]
        current_bias_offset += self.c1_out_channels
        b_conv2 = node_features[:, current_bias_offset : current_bias_offset + self.c2_out_channels, 0]
        current_bias_offset += self.c2_out_channels
        b_conv3 = node_features[:, current_bias_offset : current_bias_offset + self.c3_out_channels, 0]
        current_bias_offset += self.c3_out_channels
        b_fc = node_features[:, current_bias_offset : current_bias_offset + self.fc_out_features, 0]

        # --- Parse Weights ---
        current_edge_idx = 0

        c1_kh, c1_kw = self.c1_kernel_size_hw
        num_kernels_c1 = self.c1_in_channels * self.c1_out_channels
        kernels_c1_flat = edge_features[:, current_edge_idx : current_edge_idx + num_kernels_c1, :c1_kh * c1_kw]
        w_c1 = kernels_c1_flat.reshape(batch_size, self.c1_in_channels, self.c1_out_channels, c1_kh, c1_kw).permute(0, 2, 1, 3, 4)
        current_edge_idx += num_kernels_c1

        c2_kh, c2_kw = self.c2_kernel_size_hw
        num_kernels_c2 = self.c2_in_channels * self.c2_out_channels
        kernels_c2_flat = edge_features[:, current_edge_idx : current_edge_idx + num_kernels_c2, :c2_kh * c2_kw]
        w_c2 = kernels_c2_flat.reshape(batch_size, self.c2_in_channels, self.c2_out_channels, c2_kh, c2_kw).permute(0, 2, 1, 3, 4)
        current_edge_idx += num_kernels_c2

        c3_kh, c3_kw = self.c3_kernel_size_hw
        num_kernels_c3 = self.c3_in_channels * self.c3_out_channels
        kernels_c3_flat = edge_features[:, current_edge_idx : current_edge_idx + num_kernels_c3, :c3_kh * c3_kw]
        w_c3 = kernels_c3_flat.reshape(batch_size, self.c3_in_channels, self.c3_out_channels, c3_kh, c3_kw).permute(0, 2, 1, 3, 4)
        current_edge_idx += num_kernels_c3

        num_weights_fc = self.fc_in_features * self.fc_out_features
        weights_fc_flat = edge_features[:, current_edge_idx : current_edge_idx + num_weights_fc, 0]
        w_fc = weights_fc_flat.reshape(batch_size, self.fc_in_features, self.fc_out_features).permute(0, 2, 1)

        return w_c1, b_conv1, w_c2, b_conv2, w_c3, b_conv3, w_fc, b_fc

    def _single_cnn_pass(self, imgs, w1, b1, w2, b2, w3, b3, wf, bf):
        """Run the CNN for one parameter set.

        Called under ``vmap``, so the parameters here carry no batch dimension.
        ``imgs`` is cast to the bias dtype so mixed-precision inputs line up with
        the supplied parameters.
        """
        imgs = imgs.to(b1.dtype)
        y = self.activation(F.conv2d(imgs, w1, b1, stride=self.c1_stride, padding=self.c1_padding))
        y = self.activation(F.conv2d(y, w2, b2, stride=self.c2_stride, padding=self.c2_padding))
        y = self.activation(F.conv2d(y, w3, b3, stride=self.c3_stride, padding=self.c3_padding))
        y = self.global_avg_pool(y)
        y = torch.flatten(y, 1)
        y = F.linear(y, wf, bf)
        return y

    def forward(self, x: torch.Tensor, node_features: torch.Tensor, edge_features: torch.Tensor, input_dim_nodes_skip: int = 1) -> torch.Tensor:
        """
        Performs a batched forward pass using torch.func.vmap.

        Args:
            x: images of shape ``(N, 1, H, W)``, shared by every parameter set.
                With three unpadded 3x3 stride-2 convolutions, ``H`` and ``W``
                must each be at least 15.
            node_features, edge_features, input_dim_nodes_skip: parameter
                sources (see ``_parse_params``), each with batch size ``B``.

        Returns:
            Logits of shape ``(B * N, fc_out_features)``, ordered
            parameter-set-major (all ``N`` images for set 0, then set 1, ...).
        """
        # Parse all weights and biases in a batched manner
        params = self._parse_params(node_features, edge_features, input_dim_nodes_skip)

        # Use vmap to apply the CNN pass for each set of parameters in the batch.
        # The images are not mapped (in_dims=None); every parameter tensor is mapped on dim 0.
        output = torch.func.vmap(self._single_cnn_pass, in_dims=(None, 0, 0, 0, 0, 0, 0, 0, 0))(x, *params)

        # Reshape the output from (graph_batch_size, img_batch_size, num_classes)
        # to (graph_batch_size * img_batch_size, num_classes) for the loss function.
        return output.reshape(-1, self.fc_out_features)

    def sum_abs_params(self, node_features: torch.Tensor, edge_features: torch.Tensor, input_dim_nodes_skip: int = 1) -> torch.Tensor:
        """
        Computes the sum of absolute values of all CNN parameters for each item in the batch.

        Only the entries that ``forward`` actually uses as parameters are counted.
        The slicing is shared with ``_parse_params``, so padding feature columns
        (e.g. columns 1.. of the fc edges) and edges past the last fc weight do
        not contribute.

        Returns:
            Tensor of shape ``(batch,)`` in the default floating dtype, on
            ``node_features.device``.
        """
        batch_size = node_features.shape[0]
        total_sum_per_item = torch.zeros(batch_size, device=node_features.device)

        for param in self._parse_params(node_features, edge_features, input_dim_nodes_skip):
            # Every parsed tensor has the batch on dim 0; reduce over the rest.
            total_sum_per_item += param.abs().flatten(start_dim=1).sum(dim=1)

        return total_sum_per_item