import json

import torch # type: ignore
import torch.nn as nn
import torch.nn.functional as F

from difflogic import LogicLayer, GroupSum, PackBitsTensor, CompiledLogicNet

from library import configs
from library import misc

def input_dim_of_dataset(dataset: configs.Dataset, thresholds: int = None, config=None) -> int:
    """Return the flattened input width the first layer should expect for ``dataset``.

    Args:
        dataset: A ``configs.Dataset`` member.
        thresholds: Multiplier applied to the 3*32*32 pixel count of the colour-image
            datasets (CIFAR10, CIFAR100, IMAGENET32, CUSTOM_IMAGENET); must not be
            None for those datasets and is ignored for all others.
        config: Experiment config. For SYNTHETIC, ``config.data_config.input_size`` is
            returned (falling back to 784 with a warning if it cannot be read). For
            other datasets, a non-zero ``config.data_config.upscale_input`` overrides
            the table below with ``upscale_input ** 2``. May be None, in which case
            no override is applied.

    Returns:
        The input dimension as an int.

    Raises:
        AssertionError: if ``thresholds`` is None for a colour-image dataset.
        KeyError: if ``dataset`` is not in the lookup table.
    """
    colour_image_datasets = (
        configs.Dataset.CIFAR10,
        configs.Dataset.CIFAR100,
        configs.Dataset.IMAGENET32,
        configs.Dataset.CUSTOM_IMAGENET,
    )
    if dataset in colour_image_datasets:
        assert thresholds is not None
    else:
        # Neutral value so the (eagerly built) table below stays well-defined.
        thresholds = 0

    if dataset == configs.Dataset.SYNTHETIC:
        try:
            return config.data_config.input_size
        except AttributeError as e:
            # Best-effort fallback: missing config / missing input_size.
            print(f"Warning: input_size not found, defaulting to 784. Exception: {e}")
            return 784

    # `config` is optional in the signature, so only consult the upscale override
    # when one was actually supplied.
    if config is not None and config.data_config.upscale_input != 0:
        return config.data_config.upscale_input * config.data_config.upscale_input

    return {
        configs.Dataset.ADULT: 116,
        configs.Dataset.BREAST_CANCER: 51,
        configs.Dataset.MONK1: 17,
        configs.Dataset.MONK2: 17,
        configs.Dataset.MONK3: 17,
        configs.Dataset.MNIST: 784,
        configs.Dataset.FMNIST: 784,
        configs.Dataset.KMNIST: 784,
        configs.Dataset.QMNIST: 784,
        configs.Dataset.EMNIST_BALANCED: 784,
        configs.Dataset.EMNIST_LETTERS: 784,
        configs.Dataset.CUSTOM: 784,
        configs.Dataset.SYNTHETIC: 784,
        configs.Dataset.MNIST20x20: 400,
        configs.Dataset.CIFAR10: 3 * 32 * 32 * thresholds,
        configs.Dataset.CIFAR100: 3 * 32 * 32 * thresholds,
        configs.Dataset.IMAGENET32: 3 * 32 * 32 * thresholds,
        configs.Dataset.CUSTOM_IMAGENET: 3 * 32 * 32 * thresholds,
    }[dataset]


def num_classes_of_dataset(data_config) -> int:
    """Return the number of output classes for ``data_config.dataset``.

    For CUSTOM, CUSTOM_IMAGENET and SYNTHETIC the count comes from
    ``data_config.num_classes``; every other dataset uses a fixed table. The
    ``num_classes`` attribute is only read for those three datasets, so configs for
    built-in datasets do not need to define it.

    Raises:
        KeyError: if the dataset is not known.
    """
    dataset = data_config.dataset

    user_defined = (
        configs.Dataset.CUSTOM,
        configs.Dataset.CUSTOM_IMAGENET,
        configs.Dataset.SYNTHETIC,
    )
    if dataset in user_defined:
        return data_config.num_classes

    return {
        configs.Dataset.ADULT: 2,
        configs.Dataset.BREAST_CANCER: 2,
        configs.Dataset.MONK1: 2,
        configs.Dataset.MONK2: 2,
        configs.Dataset.MONK3: 2,
        configs.Dataset.MNIST: 10,
        configs.Dataset.FMNIST: 10,
        configs.Dataset.KMNIST: 10,
        configs.Dataset.QMNIST: 10,
        configs.Dataset.EMNIST_BALANCED: 47,
        configs.Dataset.EMNIST_LETTERS: 26,
        configs.Dataset.MNIST20x20: 10,
        configs.Dataset.CIFAR10: 10,
        configs.Dataset.CIFAR100: 100,
        configs.Dataset.IMAGENET32: 1000,
    }[dataset]

class SoftToHardANDLinear(nn.Module):
    """Per-output sum of ``sigmoid(x * w_o)`` with integer-rounded weights at eval time.

    Args:
        input_dim: number of input features.
        output_dim: number of outputs (one weight row per output).
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # Small init keeps every sigmoid close to 0.5 at the start of training.
        self.w = nn.Parameter(torch.randn(output_dim, input_dim) * 0.01)

    def forward(self, x):
        """
        x: (batch_size, input_dim)

        Returns a (batch_size, output_dim) tensor whose column ``o`` is
        ``sigmoid(x * w[o]).sum(dim=1)``. In training mode ``w`` is used as-is;
        in eval mode it is rounded to the nearest integer first.
        """
        weights = self.w if self.training else torch.round(self.w)
        batch_size = x.size(0)
        output_dim = weights.size(0)

        # Allocate with the promoted dtype of x * w so float64 (or other) inputs
        # are not silently downcast to the float32 default of torch.empty.
        out = torch.empty(
            batch_size,
            output_dim,
            device=x.device,
            dtype=torch.result_type(x, weights),
        )

        # One output column at a time: peak temporary memory is (batch, input_dim)
        # instead of a full (batch, output_dim, input_dim) broadcast.
        for o in range(output_dim):
            out[:, o] = torch.sigmoid(x * weights[o]).sum(dim=1)

        return out

class TreeNetwork(nn.Module):
    """Stack of LogicLayers whose width roughly halves per layer, followed by a sum.

    Args:
        input_size: width of the input features.
        n_layers: number of LogicLayers to stack.

    Attributes:
        layers: ``nn.Sequential`` of the LogicLayers.
        final_output_size: width of the last LogicLayer (``input_size`` when
            ``n_layers`` is 0).

    NOTE(review): LogicLayer is built with its library defaults; grad_factor and
    connections from the model config are not forwarded here.
    """

    def __init__(self, input_size, n_layers):
        super().__init__()
        layers = []
        in_features = input_size

        for out_features in self._halving_widths(input_size, n_layers):
            layers.append(LogicLayer(in_features, out_features))
            in_features = out_features

        self.layers = nn.Sequential(*layers)
        self.final_output_size = in_features

    @staticmethod
    def _halving_widths(input_size, n_layers):
        """Return the output width of each of the ``n_layers`` layers.

        Uses ceil(in / 2) so every layer satisfies ``2 * out >= in``: with two
        inputs per logic gate, floor halving of an odd width would leave an input
        that can never be wired. ``max(1, ...)`` prevents zero-sized layers.
        """
        widths = []
        current = input_size
        for _ in range(n_layers):
            current = max(1, (current + 1) // 2)
            widths.append(current)
        return widths

    def forward(self, x):
        """Run the LogicLayers and sum the final layer's outputs over dim 1."""
        x = self.layers(x)
        return x.sum(dim=1)

class TreeClassificationLayer(nn.Module):
    """Classification head that scores each class with its own TreeNetwork.

    Args:
        total_input_size: width of the incoming feature vector.
        k_classes: number of classes; one TreeNetwork is created per class.
        n_layers_per_class: depth of every per-class TreeNetwork.
        tau: divisor applied to the stacked class scores.
        full_output: if True every TreeNetwork sees the whole input; otherwise the
            input is cut into ``k_classes`` equal contiguous slices and class ``i``
            sees slice ``i`` (requires ``total_input_size`` divisible by ``k_classes``).
    """

    def __init__(self, total_input_size, k_classes, n_layers_per_class, tau=10, full_output=False):
        super().__init__()
        self.k = k_classes
        self.tau = tau
        self.full_output = full_output

        if full_output:
            per_class_width = total_input_size
        else:
            per_class_width = total_input_size // k_classes
            assert total_input_size % k_classes == 0, \
                "Input size must be divisible by number of classes when full_output=False"

        # Slice width per class; only consulted in forward when full_output is False.
        self.chunk_size = per_class_width

        self.class_nets = nn.ModuleList(
            [TreeNetwork(per_class_width, n_layers_per_class) for _ in range(k_classes)]
        )

    def forward(self, x):
        """Return the per-class scores stacked along dim 1, divided by ``tau``."""
        if self.full_output:
            inputs = [x] * self.k
        else:
            step = self.chunk_size
            inputs = [x[:, j * step:(j + 1) * step] for j in range(self.k)]

        scores = [self.class_nets[j](part) for j, part in enumerate(inputs)]
        return torch.stack(scores, dim=1) / self.tau

class DistanceLayer(nn.Module):
    """Score classes by negative Euclidean distance to a random binary codebook.

    Args:
        input_dim: width of the incoming features.
        num_classes: number of classes (one code word each).
        code_dim: length of each code word; defaults to ``input_dim``. It must equal
            the last dimension of the forward input for ``torch.cdist`` to apply.
        tau: temperature; distances are divided by it.
    """

    def __init__(self, input_dim, num_classes, code_dim=None, tau=1.0):
        super(DistanceLayer, self).__init__()
        self.input_dim = input_dim
        self.num_classes = num_classes
        self.code_dim = input_dim if code_dim is None else code_dim
        self.tau = tau

        # One random {0, 1} code word per class, stored compactly as uint8. As a
        # buffer it follows .to()/.cuda() and is saved in state_dict, but is not trained.
        codebook_init = torch.randint(0, 2, (num_classes, self.code_dim), dtype=torch.uint8)
        self.register_buffer("codebook_binary", codebook_init)

    def forward(self, x):
        """Return ``-cdist(x, codebook) / tau``: one score per class, higher is closer."""
        # torch.cdist requires both operands to share a dtype, so follow a
        # floating-point input instead of always forcing float32.
        target_dtype = x.dtype if x.is_floating_point() else torch.float32
        codebook = self.codebook_binary.to(dtype=target_dtype)
        dists = torch.cdist(x, codebook)
        return -dists / self.tau

class SplitReduceLayer(nn.Module):
    """Pool a flat feature vector down to ``l`` values over equal contiguous chunks.

    Args:
        l: number of output values (chunks); must divide ``input_dim``.
        reduce: ``'mean'`` or ``'sum'`` pooling within each chunk.
        input_dim: expected width of the input's second dimension (default 64000).
    """

    def __init__(self, l: int, reduce: str = 'mean', input_dim: int = 64000):
        super().__init__()
        assert input_dim % l == 0, f"{input_dim} must be divisible by l"
        self.l = l
        self.input_dim = input_dim
        self.chunk_size = input_dim // l
        assert reduce in ['mean', 'sum'], "reduce must be 'mean' or 'sum'"
        self.reduce = reduce

    def forward(self, x):
        """Map ``[batch_size, input_dim]`` to ``[batch_size, l]``."""
        # reshape (not view) so non-contiguous inputs are accepted as well.
        x = x.reshape(x.size(0), self.l, self.chunk_size)  # [batch_size, l, chunk_size]
        if self.reduce == 'mean':
            return x.mean(dim=2)
        return x.sum(dim=2)

class DropoutGroupSum(torch.nn.Module):
    """
    The GroupSum module with a dropout-like mask.

    The last dimension is split into ``k`` equal groups; each group is summed and
    divided by ``tau``. In training mode each neuron is first zeroed independently
    with probability ``dropout_percentage``. Unlike ``torch.nn.Dropout``, the kept
    values are NOT rescaled by ``1 / (1 - p)``, so training-mode sums are smaller in
    expectation than eval-mode sums.
    """

    def __init__(self, k: int, tau: float = 1., dropout_percentage=1.0, device='cuda'):
        """
        :param k: number of intended real valued outputs, e.g., number of classes
        :param tau: the (softmax) temperature tau. The summed outputs are divided by tau.
        :param dropout_percentage: probability that each neuron is zeroed in training
            mode. NOTE(review): the default of 1.0 zeroes every neuron during
            training; pass an explicit value.
        :param device: stored as ``self.device``; not otherwise used by this module.
        """
        super().__init__()
        self.k = k
        self.tau = tau
        self.device = device
        self.dp = dropout_percentage

    def forward(self, x):
        if isinstance(x, PackBitsTensor):
            # Packed-bit path: delegated entirely to PackBitsTensor.group_sum
            # (no mask and no division by tau here).
            return x.group_sum(self.k)

        assert x.shape[-1] % self.k == 0, (x.shape, self.k)

        # (..., k * n) -> (..., k, n): one row of n neurons per output group.
        x_reshaped = x.reshape(*x.shape[:-1], self.k, x.shape[-1] // self.k)

        if self.training:
            # Keep each neuron with probability (1 - dp). The mask is cast to the
            # activation dtype so low-precision inputs are not promoted to float32.
            keep = torch.rand_like(x_reshaped) > self.dp
            x_reshaped = x_reshaped * keep.to(x_reshaped.dtype)

        return x_reshaped.sum(-1) / self.tau

def _build_ffn_head(in_features, hidden_sizes, class_count):
    """Return ``[Linear, LeakyReLU(0.1)] * len(hidden_sizes) + [Linear]`` ending in ``class_count``."""
    widths = [in_features, *hidden_sizes]
    layers = []
    for fan_in, fan_out in zip(widths, widths[1:]):
        layers.append(torch.nn.Linear(fan_in, fan_out))
        layers.append(torch.nn.LeakyReLU(negative_slope=0.1))
    layers.append(torch.nn.Linear(widths[-1], class_count))
    return layers


def _build_full_ffn(in_dim, hidden, class_count):
    """Return the layers of the plain MLP baseline: Flatten, 3x (Linear, BatchNorm1d, ReLU), Linear."""
    layers = [nn.Flatten()]
    fan_in = in_dim
    for _ in range(3):
        layers += [nn.Linear(fan_in, hidden), nn.BatchNorm1d(hidden), nn.ReLU()]
        fan_in = hidden
    layers.append(nn.Linear(hidden, class_count))
    return layers


def _print_logic_layer_stats(layers):
    """Print neuron/weight totals over the top-level LogicLayers in ``layers``, if any."""
    logic = [layer for layer in layers if isinstance(layer, LogicLayer)]
    if not logic:
        return
    print(f'total_num_neurons={sum(layer.num_neurons for layer in logic)}')
    print(f'total_num_weights={sum(layer.num_weights for layer in logic)}')


def create_model(config: configs.DifflogicConfig) -> torch.nn.Module:
    """Build the network described by ``config``.

    Only ``model_config.architecture == 'randomly_connected'`` is supported:
    Flatten -> ``num_layers`` LogicLayers (the last one with ``last_layer_neurons``
    outputs) -> exactly one head selected by the flags ``use_groupsum``,
    ``use_mygroupsum``, ``tree_classification``, ``full_ffn``, ``distanceLayer``,
    ``distanceLayer2``. ``use_ffn`` is not part of that exclusivity check: it appends
    an MLP head, replacing a GroupSum/DropoutGroupSum head if one was added.
    ``full_ffn`` discards the logic layers and returns a plain MLP baseline.

    Side effects: seeds the global RNG via ``misc.set_seed``; if
    ``model_config.last_layer_neurons`` is 0 it is overwritten with ``num_neurons``.
    The model is moved to CUDA when ``model_config.device == 'cuda'``.

    Raises:
        AssertionError: if ``num_layers == 1`` or not exactly one head flag is set.
        NotImplementedError: for any other architecture.
    """
    data_config = config.data_config
    model_config = config.model_config
    # Seed first: every layer below draws its initialisation from the global RNG,
    # so construction order determines the seeded weights.
    misc.set_seed(model_config.seed)

    # NOTE(review): no device argument is forwarded to LogicLayer; confirm its own
    # default device suits non-'cuda' configurations.
    llkw = dict(grad_factor=model_config.grad_factor, connections=model_config.connections)

    in_dim = input_dim_of_dataset(data_config.dataset, data_config.image_thresholds, config=config)
    print(f"In Dimension: {in_dim}")
    class_count = num_classes_of_dataset(data_config)

    arch = model_config.architecture
    k = model_config.num_neurons
    l = model_config.num_layers
    assert l != 1, "Support for l == 1 deprecated. Use more than one layer: l >= 2"

    if arch != 'randomly_connected':
        raise NotImplementedError(arch)

    print('randomly_connected', in_dim, k, llkw, l)
    logic_layers = [torch.nn.Flatten(), LogicLayer(in_dim=in_dim, out_dim=k, **llkw)]
    logic_layers.extend(LogicLayer(in_dim=k, out_dim=k, **llkw) for _ in range(l - 1))

    if model_config.last_layer_neurons == 0:
        print(f"Warning: Outputlayer size not set, using model_config.num_neurons ({model_config.num_neurons}) instead.")
        model_config.last_layer_neurons = model_config.num_neurons
    # The last hidden layer is rebuilt with the output width. The layer being replaced
    # has already consumed random draws; keep this order so seeded inits stay reproducible.
    logic_layers[-1] = LogicLayer(in_dim=k, out_dim=model_config.last_layer_neurons, **llkw)

    assert sum([model_config.use_groupsum,
                model_config.use_mygroupsum,
                model_config.tree_classification,
                model_config.full_ffn,
                model_config.distanceLayer,
                model_config.distanceLayer2]) == 1, "Only one should be true at a time."

    if model_config.use_groupsum:
        logic_layers.append(GroupSum(class_count, model_config.tau))
    if model_config.use_mygroupsum:
        print("Warning: model_config.use_mygroupsum is now DropoutGroupSum")
        logic_layers.append(DropoutGroupSum(class_count, tau=model_config.tau,
                                            dropout_percentage=model_config.dropout_percentage))
    if model_config.tree_classification:
        logic_layers.append(TreeClassificationLayer(total_input_size=model_config.last_layer_neurons,
                                                    k_classes=class_count,
                                                    n_layers_per_class=model_config.tree_layers,
                                                    tau=model_config.tau,
                                                    full_output=model_config.full_tree_output))

    # MLP head on top of the logic layers; it replaces a group-sum head if present.
    if model_config.use_ffn:
        if model_config.use_groupsum or model_config.use_mygroupsum:
            logic_layers.pop()
        logic_layers.extend(_build_ffn_head(model_config.last_layer_neurons,
                                            model_config.used_ffn,
                                            class_count))

    # Plain MLP baseline: discards the logic layers built above (built anyway so the
    # RNG stream matches the other configurations).
    if model_config.full_ffn:
        logic_layers = _build_full_ffn(in_dim, model_config.ffn_layer_size, class_count)

    if model_config.distanceLayer:
        logic_layers.append(DistanceLayer(model_config.last_layer_neurons, class_count, tau=model_config.tau))
    if model_config.distanceLayer2:
        logic_layers.append(SplitReduceLayer(l=model_config.distance_dimension))
        logic_layers.append(DistanceLayer(model_config.distance_dimension, class_count, tau=model_config.tau))

    model = torch.nn.Sequential(*logic_layers)

    _print_logic_layer_stats(logic_layers)

    if model_config.device == 'cuda':
        model = model.to('cuda')

    return model
