import json

import torch # type: ignore
import torch.nn as nn
import torch.nn.functional as F

from difflogic import LogicLayer, GroupSum, PackBitsTensor, CompiledLogicNet

from library import configs
from library import misc

def input_dim_of_dataset(dataset: configs.Dataset, thresholds: int = None, config=None) -> int:
    """Return the flattened input width for ``dataset``.

    :param dataset: member of ``configs.Dataset`` identifying the dataset.
    :param thresholds: multiplier applied to the ``3 * 32 * 32`` base width of
        CIFAR10, CIFAR100, IMAGENET32 and CUSTOM_IMAGENET. Required for those
        four datasets and ignored for every other one.
    :param config: optional object whose ``data_config.input_size`` supplies
        the width for ``SYNTHETIC``; it is only read for that dataset.
    :return: the input dimension.
    :raises AssertionError: if ``thresholds`` is ``None`` for one of the four
        thresholded datasets.
    :raises KeyError: if ``dataset`` has no entry in the lookup table.
    """
    thresholded_datasets = (
        configs.Dataset.CIFAR10,
        configs.Dataset.CIFAR100,
        configs.Dataset.IMAGENET32,
        configs.Dataset.CUSTOM_IMAGENET,
    )
    if dataset in thresholded_datasets:
        assert thresholds is not None, "thresholds is required for this dataset"
        return 3 * 32 * 32 * thresholds

    if dataset == configs.Dataset.SYNTHETIC:
        # Best effort: fall back to 784 when no config (or no input_size
        # attribute) is available. Only the attribute lookup failure is
        # handled so that unrelated errors and interrupts still propagate.
        try:
            return config.data_config.input_size
        except AttributeError:
            return 784

    fixed_widths = {
        configs.Dataset.ADULT: 116,
        configs.Dataset.BREAST_CANCER: 51,
        configs.Dataset.MONK1: 17,
        configs.Dataset.MONK2: 17,
        configs.Dataset.MONK3: 17,
        configs.Dataset.MNIST: 784,
        configs.Dataset.FMNIST: 784,
        configs.Dataset.KMNIST: 784,
        configs.Dataset.QMNIST: 784,
        configs.Dataset.EMNIST_BALANCED: 784,
        configs.Dataset.EMNIST_LETTERS: 784,
        configs.Dataset.CUSTOM: 784,
        configs.Dataset.MNIST20x20: 400,
    }
    return fixed_widths[dataset]


def num_classes_of_dataset(data_config) -> int:
    """Return the number of target classes for ``data_config.dataset``.

    CUSTOM, CUSTOM_IMAGENET and SYNTHETIC take their class count from
    ``data_config.num_classes``; every other dataset uses a fixed value.
    ``data_config.num_classes`` is only read for those three datasets, so a
    config for a fixed dataset does not need to define it.

    :param data_config: object exposing ``dataset`` (a ``configs.Dataset``
        member) and, for the three configurable datasets, ``num_classes``.
    :return: the class count.
    :raises KeyError: if ``data_config.dataset`` has no entry in the table.
    """
    dataset = data_config.dataset

    configurable = (
        configs.Dataset.CUSTOM,
        configs.Dataset.CUSTOM_IMAGENET,
        configs.Dataset.SYNTHETIC,
    )
    if dataset in configurable:
        return data_config.num_classes

    fixed_counts = {
        configs.Dataset.ADULT: 2,
        configs.Dataset.BREAST_CANCER: 2,
        configs.Dataset.MONK1: 2,
        configs.Dataset.MONK2: 2,
        configs.Dataset.MONK3: 2,
        configs.Dataset.MNIST: 10,
        configs.Dataset.FMNIST: 10,
        configs.Dataset.KMNIST: 10,
        configs.Dataset.QMNIST: 10,
        configs.Dataset.EMNIST_BALANCED: 47,
        configs.Dataset.EMNIST_LETTERS: 26,
        configs.Dataset.MNIST20x20: 10,
        configs.Dataset.CIFAR10: 10,
        configs.Dataset.CIFAR100: 100,
        configs.Dataset.IMAGENET32: 1000,
    }
    return fixed_counts[dataset]

class SoftToHardANDLinear(nn.Module):
    """Per-output sum of sigmoids of element-wise input/weight products.

    In training mode the learned weights are used as they are; in eval mode
    they are rounded with ``torch.round`` before being applied.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # Small random initialisation: one weight row per output unit.
        self.w = nn.Parameter(torch.randn(output_dim, input_dim) * 0.01)

    def forward(self, x):
        """
        x: (batch_size, input_dim)
        returns: (batch_size, output_dim)
        """
        # Choose the weight matrix once; the arithmetic is shared by both modes.
        weights = self.w if self.training else torch.round(self.w)

        n_rows = x.size(0)
        n_units = self.w.size(0)
        result = torch.empty(n_rows, n_units, device=x.device)

        for unit in range(n_units):
            # (batch_size, input_dim) -> sigmoid -> reduce over the input axis
            activations = torch.sigmoid(x * weights[unit])
            result[:, unit] = activations.sum(dim=1)

        return result

class TreeNetwork(nn.Module):
    """Stack of ``LogicLayer`` modules whose width halves at every level.

    ``final_output_size`` records the width after the last level; ``forward``
    sums the last level's outputs over dimension 1.
    """

    def __init__(self, input_size, n_layers):
        super().__init__()

        # widths[i] is the input width of level i; widths[-1] is the final width.
        widths = [input_size]
        for _ in range(n_layers):
            # Integer halving, floored at 1 so no level ends up with zero width.
            widths.append(max(1, widths[-1] // 2))

        levels = [
            LogicLayer(wide, narrow)
            for wide, narrow in zip(widths, widths[1:])
        ]

        self.layers = nn.Sequential(*levels)
        self.final_output_size = widths[-1]

    def forward(self, x):
        reduced = self.layers(x)
        # Collapse the remaining neurons into one value per row.
        return reduced.sum(dim=1)
"""
class TreeClassificationLayer(nn.Module):
    # Main custom layer.
    def __init__(self, total_input_size, k_classes, n_layers_per_class, tau=10):
        super().__init__()
        assert total_input_size % k_classes == 0, "Input size must be divisible by number of classes"

        self.k = k_classes
        self.tau = tau
        self.chunk_size = total_input_size // k_classes

        self.class_nets = nn.ModuleList([
            TreeNetwork(self.chunk_size, n_layers_per_class)
            for _ in range(k_classes)
        ])

    def forward(self, x):
        # x: Tensor of shape (batch_size, total_input_size)
        outputs = []
        for i in range(self.k):
            x_i = x[:, i * self.chunk_size:(i + 1) * self.chunk_size]
            out_i = self.class_nets[i](x_i)
            outputs.append(out_i.unsqueeze(1))  # shape: [batch_size, 1]

        return torch.cat(outputs, dim=1) / self.tau  # shape: [batch_size, k_classes]
"""

class TreeClassificationLayer(nn.Module):
    """One ``TreeNetwork`` per class, producing one score per class.

    With ``full_output=False`` the input's columns are cut into ``k_classes``
    equal chunks and class ``i`` sees only chunk ``i``. With
    ``full_output=True`` every class network receives the whole input.
    Scores are divided by ``tau``.
    """

    def __init__(self, total_input_size, k_classes, n_layers_per_class, tau=10, full_output=False):
        super().__init__()
        self.k = k_classes
        self.tau = tau
        self.full_output = full_output

        # Width handed to each per-class network.
        if full_output:
            per_class_width = total_input_size
        else:
            per_class_width = total_input_size // k_classes
            assert total_input_size % k_classes == 0, \
                "Input size must be divisible by number of classes when full_output=False"

        # Used for slicing in forward(); irrelevant when full_output=True.
        self.chunk_size = per_class_width

        trees = []
        for _ in range(k_classes):
            trees.append(TreeNetwork(per_class_width, n_layers_per_class))
        self.class_nets = nn.ModuleList(trees)

    def _input_for_class(self, x, class_idx):
        """Return the part of ``x`` that class ``class_idx`` operates on."""
        if self.full_output:
            return x
        start = class_idx * self.chunk_size
        return x[:, start:start + self.chunk_size]

    def forward(self, x):
        per_class_scores = [
            self.class_nets[class_idx](self._input_for_class(x, class_idx))
            for class_idx in range(self.k)
        ]
        # Stacking along dim 1 yields [batch_size, k_classes].
        return torch.stack(per_class_scores, dim=1) / self.tau

"""
class DistanceLayer(nn.Module):
    def __init__(self, input_dim, num_classes, code_dim=None, tau=1.0, binary=False):
        super(DistanceLayer, self).__init__()
        self.input_dim = input_dim
        self.num_classes = num_classes
        self.code_dim = input_dim
        self.tau = tau

        codebook_init = torch.randint(0, 2, (num_classes, self.code_dim)).float()
        self.register_buffer("codebook", codebook_init)

    def forward(self, x):
        dists = torch.cdist(x, self.codebook)
        return -dists / self.tau
"""
class DistanceLayer(nn.Module):
    """Score inputs by negative distance to a fixed random binary codebook.

    The codebook holds one row of 0/1 entries per class and is registered as
    the buffer ``codebook_binary`` (dtype ``uint8``), so it is not a trainable
    parameter. ``forward`` returns ``-cdist(x, codebook) / tau``.
    """

    def __init__(self, input_dim, num_classes, code_dim=None, tau=1.0):
        super().__init__()
        self.input_dim = input_dim
        self.num_classes = num_classes
        # Code length defaults to the input width when not given explicitly.
        self.code_dim = code_dim if code_dim is not None else input_dim
        self.tau = tau

        # Kept as uint8 for compactness; converted to float at use time.
        codebook_shape = (num_classes, self.code_dim)
        random_codes = torch.randint(0, 2, codebook_shape, dtype=torch.uint8)
        self.register_buffer("codebook_binary", random_codes)

    def forward(self, x):
        float_codes = self.codebook_binary.to(dtype=torch.float32)
        distances = torch.cdist(x, float_codes)
        return -distances / self.tau

class MyGroupSum(nn.Module):
    """Group sum in which every neuron is scaled by a learned sigmoid gate."""

    def __init__(self, k: int, num_outputs: int, tau: float = 1., device='cuda'):
        """
        :param k: number of real-valued outputs (e.g., classes)
        :param num_outputs: total number of output neurons (must be divisible by k)
        :param tau: temperature; the grouped sums are divided by it
        :param device: device on which the gate parameter is created
        """
        super().__init__()
        self.k = k
        self.tau = tau
        self.device = device

        assert num_outputs % k == 0, f"num_outputs must be divisible by k. Got {num_outputs}, {k}"
        self.num_outputs = num_outputs
        self.group_size = num_outputs // k

        # One pre-sigmoid gate value per neuron; zero logits give a gate of 0.5.
        initial_logits = torch.zeros(num_outputs, device=device)
        self.conf_logits = nn.Parameter(initial_logits)

    def forward(self, x):
        """
        :param x: shape [..., num_outputs], where the last dim must equal `num_outputs`
        :return: shape [..., k]
        """
        if isinstance(x, PackBitsTensor):
            raise NotImplementedError("PackBitsTensor support not implemented for weighted group sum.")

        assert x.shape[-1] == self.num_outputs, f"Expected input with last dim {self.num_outputs}, got {x.shape[-1]}"

        # Gate in (0, 1), reshaped to [1, ..., 1, num_outputs] so it lines up
        # with the last axis of x.
        gate = torch.sigmoid(self.conf_logits)
        broadcast_shape = [1] * (x.ndim - 1) + [self.num_outputs]
        gated = x * gate.view(broadcast_shape)

        # Split the last axis into (k, group_size) and reduce each group.
        per_class = gated.view(*x.shape[:-1], self.k, self.group_size)
        return per_class.sum(dim=-1) / self.tau

class DropoutGroupSum(torch.nn.Module):
    """
    GroupSum variant that zeroes a random subset of neurons while training.
    """
    def __init__(self, k: int, tau: float = 1., dropout_percentage=1.0, device='cuda'):
        """
        :param k: number of intended real valued outputs, e.g., number of classes
        :param tau: the (softmax) temperature tau. The summed outputs are divided by tau.
        :param dropout_percentage: threshold compared against uniform random
            samples in training mode; elements whose sample does not exceed
            it are zeroed.
        :param device: stored on the module; not otherwise used by this class.
        """
        super().__init__()
        self.k = k
        self.tau = tau
        self.device = device
        self.dp = dropout_percentage

    def forward(self, x):
        # Packed inputs are delegated to their own group_sum; neither the mask
        # nor tau is applied on this path.
        if isinstance(x, PackBitsTensor):
            return x.group_sum(self.k)

        width = x.shape[-1]
        assert width % self.k == 0, (x.shape, self.k)

        # [..., width] -> [..., k, width // k]
        grouped = x.reshape(*x.shape[:-1], self.k, width // self.k)

        if self.training:
            # Keep an element only when its uniform sample is above self.dp.
            # Survivors are not rescaled.
            # NOTE(review): rand_like samples from [0, 1), so the default
            # dropout_percentage of 1.0 zeroes every element in training mode
            # — confirm that default is intended.
            keep = torch.rand_like(grouped) > self.dp
            grouped = grouped * keep.float()

        # Reduce each class group and apply the temperature.
        return grouped.sum(-1) / self.tau

def create_model(config: configs.DifflogicConfig) -> torch.nn.Module:
    """Assemble the network described by ``config`` as a ``torch.nn.Sequential``.

    The seed from ``config.model_config.seed`` is applied first; the layer
    stack is then built according to ``model_config.architecture``:

    * ``'randomly_connected'``: ``Flatten`` followed by a stack of logic
      layers, extended or replaced according to the ``use_groupsum``,
      ``use_mygroupsum``, ``tree_classification``, ``use_ffn``, ``full_ffn``
      and ``distanceLayer`` flags.
    * ``'custom_layers'`` (only when ``custom_layer_sizes`` is set):
      ``Flatten`` followed by logic layers of the listed sizes and a final
      logic layer of width 784, with optional ``use_groupsum`` /
      ``use_ffbinary`` heads.

    :param config: experiment configuration exposing ``data_config`` and
        ``model_config``.
    :return: the assembled model, moved to CUDA when
        ``model_config.device == 'cuda'``.
    :raises NotImplementedError: for any other architecture, including
        ``'custom_layers'`` without ``custom_layer_sizes``.
    """
    data_config = config.data_config
    model_config = config.model_config
    # NOTE(review): layers are constructed after seeding, so the number and
    # order of layer constructions below presumably affects seeded
    # reproducibility — keep it stable unless confirmed otherwise.
    misc.set_seed(model_config.seed)

    # Keyword arguments shared by every LogicLayer built here.
    llkw = dict(grad_factor=model_config.grad_factor, connections=model_config.connections)

    in_dim = input_dim_of_dataset(data_config.dataset, data_config.image_thresholds, config=config)
    print(f"In Dimension: {in_dim}")
    class_count = num_classes_of_dataset(data_config)

    logic_layers = []

    arch = model_config.architecture
    k = model_config.num_neurons
    l = model_config.num_layers

    ####################################################################################################################
    if arch == 'randomly_connected':
        print('randomly_connected', in_dim, k, llkw, l)
        logic_layers.append(torch.nn.Flatten())
        logic_layers.append(LogicLayer(in_dim=in_dim, out_dim=k, **llkw))
        for _ in range(l - 1):
            logic_layers.append(LogicLayer(in_dim=k, out_dim=k, **llkw))

        # Replace the last logic layer so that it emits last_layer_neurons outputs.
        # NOTE(review): when l == 1 the layer replaced here is the one fed by
        # Flatten, yet the replacement is built with in_dim=k. Only the
        # use_groupsum and tree_classification branches below rebuild it with
        # in_dim=in_dim — confirm the other heads are never used with l == 1.
        logic_layers[-1] = LogicLayer(in_dim=k, out_dim=config.model_config.last_layer_neurons, **llkw)

        # assert int(model_config.use_groupsum) + int(model_config.use_ffn) + int(model_config.use_ffbinary) == 1

        if model_config.use_groupsum:
            if l == 1:
                logic_layers[-1] = LogicLayer(in_dim=in_dim, out_dim=config.model_config.last_layer_neurons, **llkw)
            else:
                logic_layers[-1] = LogicLayer(in_dim=k, out_dim=config.model_config.last_layer_neurons, **llkw)
            logic_layers.append(GroupSum(class_count, model_config.tau))

        if model_config.use_mygroupsum:
            # logic_layers.append(MyGroupSum(class_count, k, model_config.tau))
            logic_layers.append(DropoutGroupSum(class_count, tau=model_config.tau, dropout_percentage=model_config.dropout_percentage))

        if model_config.tree_classification:
            print(f"Tree Classification...")
            # NOTE(review): logic_layers[-1] is whatever was appended last; if
            # a group-sum head was added above, it is that head that gets
            # replaced here — confirm these flags are mutually exclusive.
            if l == 1:
                logic_layers[-1] = LogicLayer(in_dim=in_dim, out_dim=config.model_config.last_layer_neurons, **llkw)
            else:
                logic_layers[-1] = LogicLayer(in_dim=k, out_dim=config.model_config.last_layer_neurons, **llkw)
            logic_layers.append(TreeClassificationLayer(total_input_size=model_config.last_layer_neurons,
                                                        k_classes=class_count,
                                                        n_layers_per_class=model_config.tree_layers,
                                                        tau=model_config.tau,
                                                        full_output=model_config.full_tree_output))

        if model_config.use_ffn:
            # A GroupSum head would reduce the width to class_count; drop it so
            # the first Linear layer sees last_layer_neurons features.
            # NOTE(review): pop() removes the most recently appended module,
            # which is the GroupSum only if no other head was appended after it.
            if model_config.use_groupsum:
                logic_layers.pop()
            # First layer: from last_layer_neurons to first hidden layer
            layer_sizes = model_config.used_ffn
            logic_layers.append(torch.nn.Linear(config.model_config.last_layer_neurons, layer_sizes[0]))
            logic_layers.append(torch.nn.LeakyReLU(negative_slope=0.1))

            # Intermediate hidden layers
            for i in range(len(layer_sizes) - 1):
                logic_layers.append(torch.nn.Linear(layer_sizes[i], layer_sizes[i + 1]))
                logic_layers.append(torch.nn.LeakyReLU(negative_slope=0.1))

            # Final layer: from last hidden layer to class_count
            logic_layers.append(torch.nn.Linear(layer_sizes[-1], class_count))

        if model_config.full_ffn:
            # Discards everything built so far in favour of a plain MLP.
            logic_layers = [nn.Flatten(),
                nn.Linear(in_dim, 256),
                nn.BatchNorm1d(256),
                nn.ReLU(),

                nn.Linear(256, 256),
                nn.BatchNorm1d(256),
                nn.ReLU(),

                nn.Linear(256, 256),
                nn.BatchNorm1d(256),
                nn.ReLU(),

                nn.Linear(256, class_count)]

        # if model_config.use_ffbinary:
        #     logic_layers.append(SoftToHardANDLinear(k, class_count))

        if model_config.distanceLayer:
            # NOTE(review): the codebook width is k (num_neurons) while the
            # last logic layer emits last_layer_neurons outputs — confirm the
            # two are expected to be equal.
            logic_layers.append(DistanceLayer(k, class_count, tau=model_config.tau))

        model = torch.nn.Sequential(*logic_layers)
    elif arch == 'custom_layers' and model_config.custom_layer_sizes is not None:
        layers = model_config.custom_layer_sizes # List of layer Sizes
        print('custom_layers', in_dim, layers, llkw, len(layers))
        logic_layers.append(torch.nn.Flatten())

        logic_layers.append(LogicLayer(in_dim=in_dim, out_dim=layers[0], **llkw))
        for i in range(len(layers) - 1):
            logic_layers.append(LogicLayer(in_dim=layers[i], out_dim=layers[i+1], **llkw))
        # NOTE(review): the output width is hard-coded to 784 here.
        logic_layers.append(LogicLayer(in_dim=layers[-1], out_dim=784, **llkw))

        if model_config.use_groupsum:
            logic_layers.append(GroupSum(class_count, model_config.tau))
        if model_config.use_ffbinary:
            # NOTE(review): sized for layers[-1] inputs although the preceding
            # logic layer emits 784 outputs — confirm.
            logic_layers.append(SoftToHardANDLinear(layers[-1], class_count))
        model = torch.nn.Sequential(*logic_layers)


    ####################################################################################################################
    else:
        raise NotImplementedError(arch)

    ####################################################################################################################
    # Best-effort size report over logic_layers[1:-1], i.e. everything except
    # the first and last entries. Modules without num_neurons / num_weights
    # simply end the report; other errors are no longer silently swallowed.
    try:
        total_num_neurons = sum(map(lambda x: x.num_neurons, logic_layers[1:-1]))
        print(f'total_num_neurons={total_num_neurons}')
        total_num_weights = sum(map(lambda x: x.num_weights, logic_layers[1:-1]))
        print(f'total_num_weights={total_num_weights}')
    except AttributeError:
        pass

    if model_config.device == 'cuda':
        model = model.to('cuda')

    return model
