import numpy as np
import matplotlib.pyplot as plt

import seaborn as sns

from mpl_toolkits.basemap import Basemap

import torch

# Marker colors for correctly / incorrectly classified points.
# This blue/red combination is meant to be perceived with equal contrast by
# people with and without color blindness, according to
# https://davidmathlogic.com/colorblind/#%23D81B60-%231E88E5-%23FFC107-%23004D40
# NOTE(review): the hex codes encoded in the URL above differ from the two
# values assigned below -- confirm which palette these values were taken from.
correct_color = "#005AB5"
wrong_color = "#DC3220"

def evaluate_classification(model, trainer):
    """Predict on the test set and score the jitter-aggregated predictions.

    ``trainer.predict`` is called on ``trainer.datamodule.test_dataloader()``
    and must yield one ``(logits, coordinates, labels)`` tuple per batch.

    Assumed layout (not verifiable from this function alone -- confirm
    against the model's ``predict_step``):

    * ``logits``: ``(n_jitter, batch, n_classes)``
    * ``coordinates``: ``(batch, n_coords)``
    * ``labels``: ``(batch,)``

    For every sample each jittered copy votes for its arg-max class and the
    most frequent class wins (majority vote; tie-breaking is whatever
    ``torch.mode`` does).

    Returns:
        Tuple ``(predictions, coordinates, labels, correct)`` where
        ``predictions`` holds the voted class per sample and ``correct`` is
        the boolean tensor ``predictions == labels``.
    """
    outputs = trainer.predict(model, dataloaders=trainer.datamodule.test_dataloader())
    logits, coordinates, labels = zip(*outputs)

    # Batches are stitched together along the batch axis of each tensor.
    logits = torch.cat(logits, dim=1)
    coordinates = torch.vstack(coordinates)
    labels = torch.hstack(labels)

    # Vote on class indices, not on raw logits: an element-wise mode over
    # continuous logits almost never finds a repeated value and therefore
    # does not implement a majority vote.
    jittered_predictions = logits.argmax(-1)
    predictions = jittered_predictions.mode(0).values

    correct = predictions == labels

    return predictions, coordinates, labels, correct

def accuracy_bins(coords, correct, mn_val, mx_val, N_bins):
    """Compute the classification accuracy inside equal-width coordinate bins.

    Args:
        coords: 1-D coordinate value per sample (anything ``np.histogram``
            accepts).
        correct: per-sample correctness flags; must provide ``.to(int)``
            (e.g. a boolean torch tensor).
        mn_val: lower edge of the first bin.
        mx_val: upper edge of the last bin.
        N_bins: number of equal-width bins between ``mn_val`` and ``mx_val``.

    Returns:
        Tuple ``(bins_accuracy, bins_total, bin_edges, bin_width)``:
        the fraction of correct samples per bin (NaN for bins that contain
        no samples), the sample count per bin, the ``N_bins + 1`` bin edges
        and the width of one bin.
    """
    bin_edges = np.linspace(mn_val, mx_val, N_bins + 1)
    bin_width = np.diff(bin_edges)[0]
    bins_correct, _ = np.histogram(coords, bins=bin_edges, weights=correct.to(int))
    bins_total, _ = np.histogram(coords, bins=bin_edges)

    # Divide only where a bin has samples; empty bins stay NaN. This gives the
    # same values as a plain division but avoids the 0/0 RuntimeWarning.
    bins_accuracy = np.full(bins_total.shape, np.nan)
    np.divide(bins_correct, bins_total, out=bins_accuracy, where=bins_total > 0)

    return bins_accuracy, bins_total, bin_edges, bin_width

def draw_bins(bin_vals, bin_edges, bin_width, x_axis_label, y_axis_label, plot_title, ax):
    """Draw one horizontal bar per bin on ``ax``.

    Args:
        bin_vals: bar length for each bin.
        bin_edges: bin edges; the last edge only closes the final bin.
        bin_width: width of a single bin along the y axis.
        x_axis_label: label for the x axis.
        y_axis_label: label for the y axis.
        plot_title: title of the axes.
        ax: axes-like object to draw on.
    """
    # Bars and their ticks both sit at the midpoint of each bin.
    bin_centers = bin_edges[:-1] + bin_width/2
    # Bars are slightly thinner than a bin so neighbours do not touch.
    bar_thickness = bin_width * 0.9

    ax.barh(bin_centers, bin_vals, height=bar_thickness, align='center')
    ax.set_yticks(bin_centers)

    ax.set_title(plot_title)
    ax.set_xlabel(x_axis_label)
    ax.set_ylabel(y_axis_label)

def plot_classification_map(ax, plot_points, correct, title=None):
    """Scatter classified points on a world map drawn into ``ax``.

    Args:
        ax: matplotlib axes that receives the map, labels, title and legend.
        plot_points: 2-D array/tensor whose column 0 is used as the x
            (longitude) and column 1 as the y (latitude) value, or ``None``
            to draw no points.
        correct: boolean mask over the rows of ``plot_points``; must support
            ``~`` for inversion.
        title: optional axes title; defaults to the original fixed title.

    Returns:
        The ``ax`` that was passed in.
    """
    if title is None:
        title = "Mapped Classification Results - After Forgetting Time Coordinate"
    # Lower-left lon/lat followed by upper-right lon/lat: the whole globe.
    world_bounds = [-180, -90, 180, 90]

    # Bind the map to ``ax``; without it Basemap draws on the *current*
    # pyplot axes, which need not be the axes labelled and returned here.
    world_map = Basemap(*world_bounds, ax=ax)

    if plot_points is not None:
        lon, lat = plot_points[:, 0], plot_points[:, 1]
        world_map.scatter(lon[correct], lat[correct], c=correct_color, alpha=.8, s=.5, label="Correct")
        world_map.scatter(lon[~correct], lat[~correct], c=wrong_color, alpha=.8, s=.5, label="Incorrect")
        # Only request a legend when labelled artists exist.
        ax.legend(loc=(1.1, 0.8))

    ax.set_xlabel("longitude")
    ax.set_ylabel("latitude")

    ax.set_title(title)

    return ax

def plot_accuracy_per_timestep(ax, correct, times, train_fraction, test_fraction, num_timesteps, mode):
    """Plot the mean accuracy per time bin as a bar chart on ``ax``.

    Times are binned into ``num_timesteps`` equal-width bins over ``[-1, 1]``.
    Bins without samples yield NaN accuracy.

    When ``mode == "forecast_uniform"`` the function additionally prints and
    annotates the accuracy on the "training" time range
    (``(times + 1) / 2 <= train_fraction``) and on the "forecast" time range
    (``(times + 1) / 2 > test_fraction``), and marks both fractions with
    vertical lines plus a legend.

    Args:
        ax: matplotlib axes to draw on.
        correct: per-sample correctness flags; must provide ``.to(int)``
            (e.g. a boolean torch tensor).
        times: per-sample time coordinate, presumably normalised to
            ``[-1, 1]`` -- confirm against the data module.
        train_fraction: fraction of the time range treated as seen; only
            used in ``"forecast_uniform"`` mode.
        test_fraction: fraction of the time range after which timesteps are
            treated as unseen; only used in ``"forecast_uniform"`` mode.
        num_timesteps: number of histogram bins.
        mode: evaluation mode name.

    Returns:
        The ``ax`` that was passed in.
    """
    palette = sns.color_palette("colorblind", 3)
    is_forecast_mode = mode == "forecast_uniform"

    if is_forecast_mode:
        # Rescale times from [-1, 1] to [0, 1] so they compare to the fractions.
        unit_times = (times+1) / 2
        test_prediction = unit_times > test_fraction
        training_prediction = unit_times <= train_fraction

        # Vectorised reductions instead of the Python-level builtin ``sum``.
        forecast_accuracy = correct[test_prediction].sum() / test_prediction.sum()
        training_accuracy = correct[training_prediction].sum() / training_prediction.sum()

        print("Test Accuracy on Forecast Timesteps (unseen):", forecast_accuracy)
        print("Test Accuracy on Training Timesteps (seen):", training_accuracy)
        ax.text(x=0, y=1.4, s= f"Test Accuracy on Training Timesteps (seen): {training_accuracy:.3f};    Test Accuracy on Forecast Timesteps (unseen): {forecast_accuracy:.3f}")

    N_bins = num_timesteps
    N_ticks = 10

    # Same binning as everywhere else in this module.
    hist_accuracy, _, bin_edges, bin_width = accuracy_bins(times, correct, -1, 1, N_bins)

    ax.bar(x=bin_edges[:-1] + bin_width/2, width=np.diff(bin_edges), height=hist_accuracy, align='center', color=palette[0])
    if is_forecast_mode:
        ax.plot([2*train_fraction - 1]*2, np.linspace(0, 1, num=2), c=palette[1], label="train fraction")
        ax.plot([2*test_fraction - 1]*2, np.linspace(0, 1, num=2), c=palette[2], label="test fraction")
        # The labelled lines only exist in this mode, so only then add a legend.
        ax.legend(loc=(1.1, 0.9))

    ax.set_xlabel("time")
    ax.set_ylabel("mean accuracy")

    # Roughly N_ticks ticks, each centred on its group of ``tick_step`` bins.
    # The step is clamped to 1 so fewer than N_ticks bins no longer produces a
    # zero slice step (ValueError).
    tick_step = max(1, N_bins // N_ticks)
    ax.set_xticks(bin_edges[:-1:tick_step] + tick_step*bin_width/2)

    # Title the axes that were passed in, not pyplot's current axes.
    ax.set_title("Accuracy by Time")
    return ax