from typing import Generator, Dict

import numpy as np
import torch.nn
import matplotlib.pyplot as plt

from .analysis_method import _SingleAnalysisMethod
from path_learning.utils.result import TaskResult


def plot_decision_boundary(logdir, xx, yy, X, Y, Z, uid, class_str):
    """Save a filled-contour plot of ``Z`` over a mesh with data points overlaid.

    The file is written to
    ``logdir / f"decision_boundary_plot_cl{class_str}_{uid}.pdf"``.

    Args:
        logdir: Output directory; must support the ``/`` operator
            (e.g. a ``pathlib.Path``).
        xx: Mesh x-coordinates, as returned by ``np.meshgrid``.
        yy: Mesh y-coordinates, same shape as ``xx``.
        X: Data points; columns 0 and 1 are used as scatter x/y coordinates.
        Y: Per-point values used to colour the scatter markers.
        Z: One value per mesh node (any array-like); reshaped to ``xx.shape``.
        uid: Identifier appended to the file name.
        class_str: Text inserted right after the literal ``cl`` in the file
            name, so pass e.g. ``"0"`` to get ``..._cl0_...``.
    """
    Z = np.asarray(Z).reshape(xx.shape)
    fig, ax = plt.subplots()
    try:
        # Contour of the mesh values (the model output) as the background.
        ax.contourf(xx, yy, Z, cmap="coolwarm")
        # NOTE(review): turning the axis off also hides the x/y labels set
        # below; only the title remains visible. Kept to preserve output.
        ax.axis('off')

        # Overlay the data points.
        ax.scatter(X[:, 0], X[:, 1], c=Y, cmap="Greens")
        ax.set_title("Decision boundary visualization")
        ax.set_xlabel("x-axis")
        ax.set_ylabel("y-axis")
        fig.savefig(logdir / f"decision_boundary_plot_cl{class_str}_{uid}.pdf")
    finally:
        # Release the figure even when plotting/saving fails, so repeated
        # calls do not accumulate open figures.
        plt.close(fig)


class DecisionBoundaryAnalysis(_SingleAnalysisMethod):
    """Plot a model's class-0 probability over a 2-D grid spanning the task data.

    Intended for toy examples whose inputs have exactly two features. The plot
    is written as a PDF into ``self.logdir / "tmp"``.
    """

    name = "decision_boundary_plotting"

    # Step size (in input units) between neighbouring mesh points.
    grid_step = 0.1
    # Padding added on each side of the observed data range when building the mesh.
    grid_margin = 1.0

    def __init__(self, *args, **kwargs):
        # NOTE(review): keyword arguments are accepted but not forwarded to the
        # base class — confirm this is intentional.
        super().__init__(*args)

    @staticmethod
    def _collect_points(dataloader):
        """Stack all ``(feature, target)`` batches into float arrays ``(X, Y)``.

        ``X`` has shape ``(N, 2)`` and ``Y`` has shape ``(N,)``; both are empty
        when the dataloader yields nothing.
        """
        def _to_numpy(value):
            # Detach/move tensors so GPU or grad-requiring batches convert cleanly.
            if isinstance(value, torch.Tensor):
                value = value.detach().cpu().numpy()
            return np.asarray(value, dtype=float)

        features, targets = [], []
        for feature, target in dataloader:
            features.append(_to_numpy(feature))
            targets.append(_to_numpy(target))
        if not features:
            return np.zeros((0, 2)), np.zeros((0,))
        # One concatenation instead of np.append per batch (which is quadratic).
        return np.concatenate(features, axis=0), np.concatenate(targets, axis=0)

    @staticmethod
    def _predict_class0_proba(model, grid):
        """Return the softmax probability of class 0 for every row of ``grid``.

        Runs in eval mode without autograd and restores the model's previous
        train/eval mode afterwards.
        """
        param = next(model.parameters(), None)
        device = param.device if param is not None else torch.device("cpu")
        was_training = model.training
        # Eval mode makes dropout/batch-norm deterministic for the plot.
        model.eval()
        try:
            with torch.no_grad():
                logits = model(torch.from_numpy(grid).float().to(device))
                proba = torch.nn.functional.softmax(logits, dim=1)[:, 0]
        finally:
            model.train(was_training)
        return proba.cpu().numpy()

    def analyze_model(self, task_result: TaskResult, model: torch.nn.Module) -> None:
        """Evaluate ``model`` on a mesh around the task data and save the plot.

        Args:
            task_result: Task whose data (via ``self.generate_dataloader``) and
                ``uid`` are used.
            model: Classifier mapping ``(M, 2)`` inputs to ``(M, C)`` logits.
        """
        logdir = self.logdir / "tmp"
        logdir.mkdir(parents=True, exist_ok=True)
        self.logger.info("Creating decision boundary plot for toy example")
        dataloader = self.generate_dataloader(task_result)
        uid = str(task_result.uid)

        X, Y = self._collect_points(dataloader)
        if X.shape[0] == 0:
            self.logger.info("No samples in dataloader; skipping decision boundary plot")
            return

        margin = self.grid_margin
        x_min, x_max = X[:, 0].min() - margin, X[:, 0].max() + margin
        y_min, y_max = X[:, 1].min() - margin, X[:, 1].max() + margin
        xx, yy = np.meshgrid(np.arange(x_min, x_max, self.grid_step),
                             np.arange(y_min, y_max, self.grid_step))

        # Each mesh node becomes one 2-D input row for the model.
        grid = np.c_[xx.ravel(), yy.ravel()]
        Z0 = self._predict_class0_proba(model, grid)

        # plot_decision_boundary already prefixes "cl" in the file name, so
        # pass only the class index (passing "cl0" produced "clcl0").
        plot_decision_boundary(logdir, xx, yy, X, Y, Z0, uid, "0")




