import torch
import numpy as np
import matplotlib.pyplot as plt

def debug_plot(num_classes, data, filename='debug.pdf', scale_factor=20):
    """Save a bar chart with one bar per class to ``filename``.

    Args:
        num_classes: Number of bars; bars are drawn at x = 0 .. num_classes - 1,
            each with its own x tick.
        data: ``num_classes`` bar heights (sequence, array or tensor).
        filename: Output path (format inferred by matplotlib from the
            extension); an existing file is overwritten.
        scale_factor: Figure width and height, in inches.
    """
    if torch.is_tensor(data):
        # matplotlib cannot convert tensors that require grad or live on an
        # accelerator; hand it a detached CPU NumPy copy instead.
        data = data.detach().cpu().numpy()

    fig, ax = plt.subplots(1, 1, figsize=(scale_factor, scale_factor))
    try:
        positions = np.arange(num_classes)
        ax.bar(positions, data)
        ax.set_xticks(positions)
        # Operate on this figure explicitly rather than pyplot's "current" one.
        fig.tight_layout()
        fig.savefig(filename)
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)

def normalize_class_weights(class_weights, model_confidences):
    """Rescale ``class_weights`` so the per-sample weights sum to the sample count.

    Each sample takes the weight of its predicted class, i.e. the argmax over
    dim 1 of ``model_confidences``. The result is ``lam * class_weights``, with
    ``lam`` chosen so that ``sum_i weight[pred_i] == num_samples``. In other
    words, the mean per-sample weight becomes 1.

    Args:
        class_weights: One weight per class, indexable by class id and
            supporting scalar multiplication (e.g. a 1-D tensor).
        model_confidences: Scores with samples along dim 0 and classes along
            dim 1.

    Returns:
        The rescaled class weights.

    Raises:
        ValueError: If the weights of the predicted samples sum to zero (no
            samples, or every predicted class has weight 0). No finite scale
            exists in that case.
    """
    _, predictions = torch.max(model_confidences, dim=1)
    num_classes = len(class_weights)
    num_samples = predictions.shape[0]

    # One counting pass instead of a full equality scan (plus .item() sync) per
    # class. Counts for ids >= num_classes are dropped, as a per-class scan would.
    counts = torch.bincount(predictions.reshape(-1), minlength=num_classes)[:num_classes].tolist()
    weights_sum = sum(class_weights[class_idx] * count for class_idx, count in enumerate(counts))
    if weights_sum == 0:
        raise ValueError('cannot normalize class weights: the weights of the predicted samples sum to zero')

    lam = num_samples / weights_sum
    print(f'weight sum {weights_sum} - lambda {lam}')

    return lam * class_weights

def _prediction_class_frequencies(predictions, num_classes):
    """Return the fraction of ``predictions`` equal to each id in ``range(num_classes)``.

    The result is a float32 CPU tensor. The denominator is the number of
    predictions actually given, not a requested sample count.
    """
    flat_predictions = predictions.reshape(-1).cpu()
    counts = torch.bincount(flat_predictions, minlength=num_classes)[:num_classes]
    # Divide in float64 and cast once: same rounding as Python's int / int.
    return (counts.to(torch.float64) / flat_predictions.numel()).to(torch.float32)


def _smooth_towards_uniform(distribution, smoothing_factor):
    """Blend a per-class distribution with the uniform one: ``(1 - s) * p + s / C``."""
    num_classes = distribution.shape[0]
    return (1 - smoothing_factor) * distribution \
        + smoothing_factor * (1 / num_classes) * torch.ones_like(distribution, dtype=torch.float)


def calculate_class_od_weights(model_confidences, bias_k, R, smoothing_factor, save_debug_plot=True):
    """Compute per-class weights from how often each class is predicted among low-confidence samples.

    With ``C`` classes:

    * ``b`` (class biases): distribution of predicted classes over the
      ``bias_k`` samples with the lowest max confidence. It is smoothed
      towards uniform by ``smoothing_factor``.
    * ``p``: distribution of predicted classes over all samples, smoothed the
      same way.
    * ``d``: ``p - (1 - R) * b``. Entries at or below ``(1/C)**2`` are raised
      to the smallest entry above that floor, then ``d`` is renormalized to
      sum to 1.

    The raw weight of class ``c`` is
    ``(1 - R) * b_c / ((1 - R) * b_c + R * d_c)``. The weights are then
    rescaled with ``normalize_class_weights`` so that the mean weight over all
    samples is 1 (each sample takes its predicted class's weight).

    NOTE(review): the mixture form suggests ``R`` is the assumed share of
    "regular" samples, and the weight is the share of each predicted class
    attributed to the low-confidence component. This is inferred from the
    formulas only; confirm with the author.

    Args:
        model_confidences: Scores with samples along dim 0 and classes along
            dim 1.
        bias_k: Number of least-confident samples used to estimate ``b``. If it
            exceeds the number of samples, all samples are used.
        R: Mixing coefficient in the formulas above (presumably in (0, 1);
            TODO confirm).
        smoothing_factor: Weight of the uniform distribution in the smoothing
            step.
        save_debug_plot: If True (the default, as before), plot the weights via
            ``debug_plot``, which writes ``debug.pdf``.

    Returns:
        1-D CPU tensor of ``C`` class weights.

    Raises:
        ValueError: If there are no samples, ``bias_k`` is not positive, or no
            class keeps a bias-cleared density above ``(1/C)**2``.
    """
    num_classes = model_confidences.shape[1]
    max_confidences, predictions = torch.max(model_confidences, dim=1)
    num_samples = predictions.shape[0]
    if num_samples == 0:
        raise ValueError('model_confidences contains no samples')
    if bias_k <= 0:
        raise ValueError(f'bias_k must be a positive number of samples, got {bias_k}')

    print(f'Calculating class biases on {bias_k} datapoints - Smoothing {smoothing_factor}')
    # The bias_k samples with the lowest max confidence define the bias distribution.
    least_confident_idcs = torch.argsort(max_confidences, descending=False)[:bias_k]
    class_biases = _smooth_towards_uniform(
        _prediction_class_frequencies(predictions[least_confident_idcs], num_classes), smoothing_factor)

    print(f'Calculating class weights on {num_samples} datapoints')
    smoothed_class_distribution = _smooth_towards_uniform(
        _prediction_class_frequencies(predictions, num_classes), smoothing_factor)

    # Subtract the (1 - R)-scaled bias component from the observed distribution.
    bias_cleared_class_densities = smoothed_class_distribution - (1 - R) * class_biases

    # Raise classes at or below (1/C)^2 (including negative values) to the
    # smallest density above that floor, then renormalize to sum to 1.
    density_floor = (1 / num_classes) ** 2
    above_floor = bias_cleared_class_densities > density_floor
    if not torch.any(above_floor):
        raise ValueError('no class keeps a bias-cleared density above (1/num_classes)**2; '
                         'check bias_k, R and smoothing_factor')
    fill_value = torch.min(bias_cleared_class_densities[above_floor])
    bias_cleared_class_densities[bias_cleared_class_densities <= density_floor] = fill_value
    bias_cleared_class_densities = bias_cleared_class_densities / torch.sum(bias_cleared_class_densities)

    # Algebraically equal to (1 - R) * b / ((1 - R) * b + R * d).
    class_weights = 1 - (R * bias_cleared_class_densities) / ((1 - R) * class_biases + R * bias_cleared_class_densities)

    # Rescale so the mean per-sample weight (weight of each sample's predicted class) is 1.
    class_weights = normalize_class_weights(class_weights, model_confidences)

    if save_debug_plot:
        debug_plot(num_classes, class_weights)
    return class_weights