import torch
import numpy as np

from library import configs
from library import models

def get_accuracy(model: torch.nn.Module, loader: torch.utils.data.DataLoader, config: configs.DifflogicConfig, train_mode: bool=True) -> float:
    """Return the top-1 accuracy of ``model`` over every batch in ``loader``.

    Args:
        model: Network whose forward pass yields per-class scores; the
            predicted class is the arg-max along dimension 1.
        loader: Iterable of ``(inputs, targets)`` batches.
        config: Supplies the target device via ``config.model_config.device``.
        train_mode: Only the literal ``True`` puts the model in training mode;
            any other value switches it to evaluation mode. The chosen mode is
            left in place after the call returns.

    Returns:
        Number of correct predictions divided by the number of samples seen.
        An empty loader raises ``ZeroDivisionError``.
    """
    device = config.model_config.device
    # Identity check on purpose: truthy non-True values still select eval().
    (model.train if train_mode is True else model.eval)()

    n_seen = 0
    n_right = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            predictions = model(inputs).max(1).indices

            n_seen += targets.size(0)
            n_right += (predictions == targets).sum().item()

    return n_right / n_seen

def get_accuracy_no_groupsum(model: torch.nn.Module, loader: torch.utils.data.DataLoader, config: configs.DifflogicConfig, train_mode: bool) -> float:
    """Sample-weighted accuracy of ``model`` over ``loader`` with block-encoded targets.

    Each batch is pushed through the sub-modules of ``model`` in index order
    and scored with ``accuracy_no_groupsum``. The per-batch accuracies are
    then averaged, weighted by batch size.

    Args:
        model: Indexable sequence of layers (supports ``len`` and ``model[i]``).
        loader: Iterable of ``(x, y)`` batches, or ``None``.
        config: Supplies ``data_config.device`` and ``train_config`` (the
            latter is forwarded to ``accuracy_no_groupsum``).
        train_mode: Mode passed to ``model.train(mode=...)`` for the duration
            of the evaluation. The model's previous mode is always restored.

    Returns:
        The weighted mean accuracy as a Python float, or ``-1`` when
        ``loader`` is ``None``.

    Raises:
        ZeroDivisionError: From ``np.average`` when ``loader`` yields no batches.
    """
    # NOTE(review): sibling functions read config.model_config.device; confirm
    # data_config.device is the intended source here.
    device = config.data_config.device
    if loader is None:
        return -1

    orig_mode = model.training
    model.train(mode=train_mode)
    try:
        with torch.no_grad():
            accs = []
            weights = []
            for x, y in loader:
                y = y.to(device=device)
                x = x.to(device=device)
                for i in range(len(model)):
                    x = model[i](x)
                accs.append(accuracy_no_groupsum(x, y, config.train_config))
                weights.append(len(y))
            res = np.average(accs, weights=weights)
    finally:
        # Restore the caller's mode even when a batch or the averaging raises.
        model.train(mode=orig_mode)

    return res.item()

def accuracy_no_groupsum(y_pred: torch.Tensor, y: torch.Tensor, train_config) -> float:
    """Classification accuracy from block-encoded predictions and targets.

    The last dimension of each tensor is split into ``train_config.num_classes``
    contiguous, equal-width blocks. A sample's class is the index of the block
    with the largest sum. Predictions and targets may use different block
    widths (for example, one-hot targets against multi-bit outputs).

    Args:
        y_pred: Model outputs whose last dimension is ``num_classes * k_pred``.
        y: Targets whose last dimension is ``num_classes * k_true``.
        train_config: Object exposing an integer ``num_classes``.

    Returns:
        Fraction of samples whose predicted block matches the target block,
        as a Python float.

    Raises:
        ValueError: If either tensor's last dimension is not a multiple of
            ``num_classes``. Without this check, the reshape could silently
            regroup elements across samples.
    """
    num_classes = train_config.num_classes

    def block_argmax(t: torch.Tensor, name: str) -> torch.Tensor:
        # Split the last dim into `num_classes` equal blocks and pick the
        # block with the largest sum for every sample.
        width = t.shape[-1]
        if width % num_classes != 0:
            raise ValueError(
                f"{name} has width {width}, which is not divisible by num_classes={num_classes}"
            )
        return t.reshape(-1, num_classes, width // num_classes).sum(dim=2).argmax(dim=1)

    pred_class = block_argmax(y_pred, "y_pred")
    true_class = block_argmax(y, "y")
    return (pred_class == true_class).float().mean().item()

def store_logit_stats(label, output, num_classes, num_logits):
    """Return per-logit confusion counts (TP, FP, TN, FN) for a single sample.

    ``output`` is treated as ``num_classes`` consecutive, equal-width blocks.
    Logits in block number ``label`` are the positives and every other logit
    is a negative. Each logit is predicted positive when ``>= 0.5`` and
    negative when ``< 0.5``. A NaN logit satisfies neither test and is not
    counted anywhere.

    Args:
        label: Class index of the sample.
        output: Sequence of the sample's logit values (e.g. a list of floats).
        num_classes: Number of class blocks in ``output``.
        num_logits: Number of rows in the returned array. It must be at least
            ``len(output)``, otherwise ``IndexError`` is raised.

    Returns:
        ``np.ndarray`` of shape ``(num_logits, 4)``. Row ``i`` holds a single
        count in the column for logit ``i``'s outcome: ``[TP, FP, TN, FN]``.
        Rows past ``len(output)`` stay zero.
    """
    logit_stats = np.zeros((num_logits, 4))

    values = np.asarray(output, dtype=float)
    n = len(output)

    # Keep the original float arithmetic so block boundaries truncate identically.
    start_index = label * n / num_classes
    end_index = start_index + n / num_classes
    is_target = np.zeros(n, dtype=bool)
    # Integer-array indexing raises IndexError on out-of-range positions,
    # just as the former list assignment did.
    is_target[np.arange(int(start_index), int(end_index))] = True

    # Two separate comparisons (not one negated) so NaN falls in neither bucket.
    predicted_pos = values >= 0.5
    predicted_neg = values < 0.5

    rows = np.arange(n)  # fancy indexing -> IndexError if n > num_logits
    logit_stats[rows, 0] += is_target & predicted_pos    # True Positive (TP)
    logit_stats[rows, 1] += ~is_target & predicted_pos   # False Positive (FP)
    logit_stats[rows, 2] += ~is_target & predicted_neg   # True Negative (TN)
    logit_stats[rows, 3] += is_target & predicted_neg    # False Negative (FN)

    return logit_stats

def get_logit_stats(model: torch.nn.Module, loader: torch.utils.data.DataLoader, config: configs.DifflogicConfig, train_mode: bool=True, num_logits=None) -> np.ndarray:
    """Accumulate per-logit TP/FP/TN/FN counts at the model's penultimate layer.

    Every batch runs through the sub-modules of ``model`` one at a time. Right
    after layer ``len(model) - 2`` executes, each sample's activations and
    label are passed to ``store_logit_stats`` and the returned counts are
    summed. The final layer still runs, but its output is not used.

    Args:
        model: Indexable sequence of layers (supports ``len`` and ``model[i]``).
        loader: Iterable of ``(x, y)`` batches. Labels are read one scalar per
            sample via ``.item()``.
        config: Supplies ``model_config.device`` and ``data_config.dataset``;
            the dataset determines the number of classes.
        train_mode: The literal ``True`` selects training mode and anything
            else selects evaluation mode. The mode is not restored afterwards.
        num_logits: Number of logit rows to allocate. When ``None``, it is
            taken from dimension 1 of the first tapped activation batch.

    Returns:
        ``np.ndarray`` of shape ``(num_logits, 4)`` with columns
        ``[TP, FP, TN, FN]``. When ``num_logits`` is ``None`` and nothing was
        tapped (empty loader, or fewer than two layers), an empty ``(0, 4)``
        array is returned.
    """
    device = config.model_config.device
    # Loop-invariant: resolve the class count once instead of once per sample.
    num_classes = models.num_classes_of_dataset(config.data_config.dataset)
    tap_index = len(model) - 2

    logit_stats = None if num_logits is None else np.zeros((num_logits, 4))

    if train_mode is True:
        model.train()
    else:
        model.eval()
    with torch.no_grad():
        for x, y in loader:
            y = y.to(device=device)
            x = x.to(device=device)
            for layer_idx in range(len(model)):
                x = model[layer_idx](x)
                if layer_idx != tap_index:
                    continue

                if logit_stats is None:
                    num_logits = x.shape[1]
                    logit_stats = np.zeros((num_logits, 4))

                # One host transfer per batch rather than one per sample.
                rows = x.tolist()
                for sample_idx in range(y.size(0)):
                    logit_stats += store_logit_stats(
                        y[sample_idx].item(),
                        rows[sample_idx],
                        num_classes=num_classes,
                        num_logits=num_logits,
                    )

    return logit_stats if logit_stats is not None else np.zeros((0, 4))

def get_accuracy_for_logits(model: torch.nn.Module, loader: torch.utils.data.DataLoader, config: configs.DifflogicConfig, z: int, train_mode: bool=True) -> float:
    """Top-1 accuracy when each class score sums only its first ``z`` output logits.

    Every layer except the last is applied. This matches the sibling helpers,
    which read activations right after layer ``len(model) - 2``; the last
    layer is presumably the GroupSum head (TODO confirm). The resulting
    activations are split into ``num_classes`` equal blocks along dimension 1.
    Each class score is the sum of the first ``z`` entries of its block, and
    the prediction is the arg-max over classes.

    Args:
        model: Indexable sequence of layers (supports ``len`` and ``model[i]``).
        loader: Iterable of ``(x, y)`` batches. Any batch size is accepted.
        config: Supplies ``model_config.device``, ``model_config.num_neurons``
            and ``data_config.dataset`` (which gives the class count).
        z: Number of leading logits per class to keep.
        train_mode: The literal ``True`` selects training mode and anything
            else selects evaluation mode. The mode is not restored afterwards.

    Returns:
        Correct predictions divided by samples seen. An empty loader raises
        ``ZeroDivisionError``.

    Raises:
        ValueError: If ``z`` exceeds ``num_neurons / num_classes``.
    """
    total_samples = 0
    total_correct = 0
    device = config.model_config.device
    num_classes = models.num_classes_of_dataset(config.data_config.dataset)

    # Loop-invariant check; a real exception so it survives `python -O`.
    logits_per_class = config.model_config.num_neurons / num_classes
    if z > logits_per_class:
        raise ValueError(f"z={z} exceeds the {logits_per_class:g} logits available per class")

    if train_mode is True:
        model.train()
    else:
        model.eval()

    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device)

            # All layers but the last (the former code hard-coded 7, i.e. an 8-layer model).
            for i in range(len(model) - 1):
                x = model[i](x)

            # GroupSum restricted to the first z logits of each class block;
            # shapes come from the tensor so partial batches work.
            logits = x.reshape(x.shape[0], num_classes, -1)[:, :, :z].sum(dim=2)
            _, predicted = logits.max(1)

            total_samples += y.size(0)
            total_correct += (predicted == y).sum().item()

    return total_correct / total_samples

def masked_accuracy(model: torch.nn.Module, loader: torch.utils.data.DataLoader, config: configs.DifflogicConfig, train_mode: bool=True, mask=None) -> float:
    """Top-1 accuracy after element-wise masking of the penultimate activations.

    Each batch runs through ``model`` layer by layer. The output of layer
    ``len(model) - 2`` is multiplied by ``mask`` before the final layer runs,
    so zero entries in ``mask`` remove the corresponding activations from the
    final layer's input. The prediction is the arg-max along dimension 1 of
    the final layer's output.

    Args:
        model: Indexable sequence of layers (supports ``len`` and ``model[i]``).
        loader: Iterable of ``(x, y)`` batches.
        config: Supplies the target device via ``config.model_config.device``.
        train_mode: The literal ``True`` selects training mode and anything
            else selects evaluation mode. The mode is not restored afterwards.
        mask: Tensor broadcastable against the penultimate activations.
            ``None`` applies no masking.

    Returns:
        Correct predictions divided by samples seen. An empty loader raises
        ``ZeroDivisionError``.
    """
    total_samples = 0
    total_correct = 0
    device = config.model_config.device
    tap_index = len(model) - 2

    if train_mode is True:
        model.train()
    else:
        model.eval()
    with torch.no_grad():
        for x, y in loader:
            y = y.to(device=device)
            x = x.to(device=device)
            for layer_idx in range(len(model)):
                x = model[layer_idx](x)
                if layer_idx == tap_index and mask is not None:
                    # Move the mask only when needed; later batches reuse the copy.
                    if mask.device != x.device:
                        mask = mask.to(x.device)
                    x = x * mask
            _, predicted = x.max(1)
            total_samples += y.size(0)
            total_correct += (predicted == y).sum().item()

    return total_correct / total_samples