import math
import numpy as np
import torch
import matplotlib.pyplot as plt
from tqdm import tqdm
from matplotlib.colors import ListedColormap
from scipy.special import softmax


def get_loss_landscape_cond(model, theta, c_map, x, y, n, alpha):
    """Evaluate the model loss on a 2-D grid of perturbations around ``c_map``.

    A square grid of ``n x n`` offsets spanning ``[-3 * alpha, 3 * alpha]`` on
    each axis is added to ``c_map``. For every grid point, the resulting
    conditioning row is tiled once per example in ``x`` and the loss is
    evaluated.

    Args:
        model: Callable as ``model(x, theta, c)`` and exposing
            ``model.loss(prediction, y)``. Element ``[1]`` of that result is
            taken as the scalar loss (presumably a tuple of loss terms; confirm).
        theta: Passed through unchanged to ``model``.
        c_map: Base conditioning tensor. Assumes it broadcasts against an
            ``(n**2, 2)`` tensor, e.g. shape ``(2,)`` or ``(1, 2)`` (TODO confirm).
        x: Input batch. Only ``x.shape[0]`` is used directly, to tile ``c``.
        y: Targets forwarded to ``model.loss``.
        n: Number of grid points per axis.
        alpha: The grid half-width is ``3 * alpha``.

    Returns:
        ``(loss_landscape, offset_x, offset_y)``:
        - ``loss_landscape`` is an ``(n, n)`` NumPy array.
        - ``offset_x`` / ``offset_y`` are the matching ``(n, n)`` CPU offset
          grids built with ``'ij'`` indexing.
    """
    offset = torch.linspace(-alpha * 3, alpha * 3, n)
    offset_x, offset_y = torch.meshgrid(offset, offset, indexing='ij')
    perturbations = torch.stack([offset_x.flatten(), offset_y.flatten()], dim=1)
    # Put the perturbations on c_map's device so the sum also works when c_map
    # is on a GPU; the offset grids returned to the caller stay on the CPU.
    if isinstance(c_map, torch.Tensor):
        perturbations = perturbations.to(c_map.device)
    all_cond = c_map + perturbations

    loss_landscape = []
    # Only detached Python scalars are kept, so no autograd graph is needed.
    with torch.no_grad():
        for i in tqdm(range(n ** 2)):
            # Grid point i, repeated once per example in the batch.
            c = all_cond[i:i + 1].repeat(x.shape[0], 1)
            loss = model.loss(model(x, theta, c), y)[1]
            loss_landscape.append(loss.detach().cpu().item())
    # The flatten above is row-major over the 'ij' grid, so reshape restores
    # the [ix, iy] layout matching offset_x / offset_y.
    loss_landscape = np.array(loss_landscape).reshape(n, n)

    return loss_landscape, offset_x, offset_y

def plot_hypothesis_reg(all_x, py_opt, loader, name=None, color='tab:green', save=False, ylim=(-5, 5)):
    """Plot individual regression function samples over the training data.

    Drawing happens on the current pyplot figure.

    Args:
        all_x: Input tensor the samples were evaluated on; column 0 is the x-axis.
        py_opt: Iterable of sample tensors; column 0 of each is plotted as one curve.
        loader: Provides ``loader.dataset.x_train`` / ``y_train`` tensors
            (column 0 of each is scattered).
        name: Plot title, and the file stem when ``save`` is true.
        color: Colour of the sample curves.
        save: If true, write ``./imgs/<name>.pdf`` and close the figure;
            otherwise call ``plt.show()``.
        ylim: y-axis limits.

    Raises:
        ValueError: If ``save`` is true but ``name`` is ``None``.
    """
    # Fail before drawing anything rather than on string concatenation later.
    if save and name is None:
        raise ValueError("`name` is required when save=True (it is used as the file name).")

    x_train = loader.dataset.x_train.detach().cpu().numpy()[:, 0]
    y_train = loader.dataset.y_train.detach().cpu().numpy()[:, 0]
    # Convert the shared x-axis once instead of on every loop iteration.
    x_plot = all_x.detach().cpu().numpy()[:, 0]

    # Data points sit above the sample curves.
    plt.scatter(x_train, y_train, zorder=2)
    for y_sample in py_opt:
        plt.plot(x_plot, y_sample.detach().cpu().numpy()[:, 0], color=color, alpha=0.5, zorder=1)
    plt.ylim(*ylim)
    if name is not None:
        plt.title(name)
    if save:
        plt.savefig('./imgs/' + name + '.pdf')
        plt.close()
    else:
        plt.show()

def plot_var_reg(all_x, py_opt, loader, color='tab:green', name=None, save=False, ylim=(-5, 5)):
    """Plot the sample mean +/- one standard deviation of regression predictions.

    Drawing happens on the current pyplot figure, together with a scatter
    of the training data.

    Args:
        all_x: Input tensor the samples were evaluated on; column 0 is the x-axis.
        py_opt: Tensor of predictions stacked along dim 0. Mean and std (torch
            default, i.e. unbiased) are taken over dim 0 and output column 0
            is plotted.
        loader: Provides ``loader.dataset.x_train`` / ``y_train`` tensors.
        color: Colour of the mean curve and the shaded band.
        name: Plot title, and the file stem when ``save`` is true.
        save: If true, write ``./imgs/<name>.pdf`` and close the figure;
            otherwise call ``plt.show()``.
        ylim: y-axis limits.

    Raises:
        ValueError: If ``save`` is true but ``name`` is ``None``.
    """
    # Fail before drawing anything rather than on string concatenation later.
    if save and name is None:
        raise ValueError("`name` is required when save=True (it is used as the file name).")

    y_mu = py_opt.mean(0).detach().cpu().numpy()[:, 0]
    y_std = py_opt.std(0).detach().cpu().numpy()[:, 0]
    x_train = loader.dataset.x_train.detach().cpu().numpy()[:, 0]
    y_train = loader.dataset.y_train.detach().cpu().numpy()[:, 0]
    # Convert the shared x-axis once; it is used by two calls below.
    x_plot = all_x.detach().cpu().numpy()[:, 0]

    plt.scatter(x_train, y_train)
    plt.plot(x_plot, y_mu, linewidth=3.0, color=color)
    plt.fill_between(x_plot, y_mu - y_std, y_mu + y_std, alpha=0.3, color=color)
    plt.ylim(*ylim)
    if name is not None:
        plt.title(name)
    if save:
        plt.savefig('./imgs/' + name + '.pdf')
        plt.close()
    else:
        plt.show()


def plot_var_class(all_x, py_opt, loader, name=None, color='tab:green', save=False):
    """Plot per-point disagreement of sampled binary classifiers over a 2-D grid.

    A 100 x 100 grid is rebuilt locally. It spans the bounding box of
    ``loader.dataset.x_test`` padded by 2 on each side, with ``np.meshgrid``
    default ('xy') ordering. ``py_opt`` is reshaped onto that grid, so it must
    have been evaluated on the same grid in the same raveled order (presumably
    as ``get_banana_fig`` builds it; confirm at the call site).

    For each grid point, every sample votes for class 0 when output 0 exceeds
    output 1. The plotted uncertainty is the std of those 0/1 votes across
    samples.

    Args:
        all_x: Unused; kept for signature compatibility.
        py_opt: Tensor reshapeable to ``(num_samples, 100, 100, 2)``.
        loader: Provides ``loader.dataset.x_test`` (at least 2 columns) and
            ``y_test`` tensors.
        name: Plot title (defaults to "Model Uncertainty"), and the file stem
            when ``save`` is true.
        color: Unused; kept for signature compatibility with the other plot_* helpers.
        save: If true, write ``./imgs/<name>.pdf`` and close the figure;
            otherwise call ``plt.show()``.

    Raises:
        ValueError: If ``save`` is true but ``name`` is ``None``.
    """
    # Fail before drawing anything rather than on string concatenation later.
    if save and name is None:
        raise ValueError("`name` is required when save=True (it is used as the file name).")

    x_test = loader.dataset.x_test.detach().cpu().numpy()
    y_test = loader.dataset.y_test.detach().cpu().numpy()

    N_grid = 100
    offset = 2
    x1min = x_test[:, 0].min() - offset
    x1max = x_test[:, 0].max() + offset
    x2min = x_test[:, 1].min() - offset
    x2max = x_test[:, 1].max() + offset

    x_grid = np.linspace(x1min, x1max, N_grid)
    y_grid = np.linspace(x2min, x2max, N_grid)
    XX1, XX2 = np.meshgrid(x_grid, y_grid)

    py = torch.reshape(py_opt, (-1, N_grid, N_grid, 2)).detach().cpu().numpy()
    # Hard 0/1 vote per sample; the std across samples is 0 where all agree.
    py_class = (py[:, :, :, 0] > py[:, :, :, 1]) * 1.
    unc = np.std(py_class, 0)

    contour_unc = plt.contourf(
        XX1,
        XX2,
        unc,
        alpha=0.8,
        antialiased=True,
        cmap="Blues",
        levels=np.arange(np.min(unc) - 0.1, np.max(unc) + 0.1, 0.1),
    )
    plt.colorbar(contour_unc, orientation='vertical')
    plt.scatter(x_test[:, 0], x_test[:, 1],
                c=y_test, cmap=ListedColormap(["#f2cc8f", "tab:orange"]),
                edgecolor='black', linewidth=0.2, s=15, zorder=1, alpha=1.0)
    # Set the final title before tight_layout so the layout accounts for it.
    plt.title(name if name is not None else "Model Uncertainty")
    plt.tight_layout()
    if save:
        plt.savefig('./imgs/' + name + '.pdf')
        plt.close()
    else:
        plt.show()


def get_regression_fig(model, loader, device):
    """Build a figure summarising a 1-D regression model's posterior.

    The model is queried on 100 evenly spaced inputs in ``[-1, 7]``. The figure
    shows:
    - the train/test data;
    - the MAP prediction;
    - the posterior mean with a +/- one-std band;
    - every posterior sample, drawn faintly.

    Args:
        model: Object whose ``posterior(x, loader)`` returns
            ``(y_map, y_mu, y_std, py)``. The first three are indexed as
            ``[:, 0]`` and ``py`` as ``[sample, :, 0]`` here.
        loader: Provides ``loader.dataset`` with ``x_train``, ``y_train``,
            ``x_test`` and ``y_test`` tensors (column 0 is plotted).
        device: Device the query inputs are moved to before calling the model.

    Returns:
        The newly created ``matplotlib.figure.Figure``.
    """
    query = torch.linspace(-1, 7, 100).float().to(device).view(-1, 1)
    y_map, y_mu, y_std, py = model.posterior(query, loader)

    dataset = loader.dataset
    x_train = dataset.x_train.detach().cpu().numpy()[:, 0]
    y_train = dataset.y_train.detach().cpu().numpy()[:, 0]
    x_test = dataset.x_test.detach().cpu().numpy()[:, 0]
    y_test = dataset.y_test.detach().cpu().numpy()[:, 0]
    all_x = query.detach().cpu().numpy()[:, 0]
    y_mu = y_mu.detach().cpu().numpy()[:, 0]
    y_std = y_std.detach().cpu().numpy()[:, 0]
    y_map = y_map.detach().cpu().numpy()[:, 0]
    # One device-to-host transfer for all samples: shape (num_samples, len(all_x)).
    samples = py.detach().cpu().numpy()[:, :, 0]

    fig = plt.figure()

    plt.scatter(x_train, y_train, color='tab:blue', label='D_train')
    plt.scatter(x_test, y_test, color='tab:red', label='D_test')
    plt.plot(all_x, y_map, color='tab:green', linewidth=3, label='MAP')
    plt.plot(all_x, y_mu, color='tab:orange', linewidth=3, label='Posterior Mean')
    plt.fill_between(all_x, y_mu - y_std, y_mu + y_std, alpha=0.3, color='tab:green')
    # Draw each sample exactly once; only the first carries the legend label.
    # (Sample 0 used to be drawn twice, doubling its opacity.)
    if len(samples):
        plt.plot(all_x, samples[0], color='tab:orange', alpha=0.1, label='Posterior Samples')
        for y_sample in samples[1:]:
            plt.plot(all_x, y_sample, color='tab:orange', alpha=0.1)
    plt.ylim(-4, 4)
    # Reserve the bottom 15% of the figure for the legend placed below the axes.
    plt.tight_layout(rect=[0, 0.15, 1, 1])
    plt.legend(loc='upper center', bbox_to_anchor=(0.5, -0.1), ncol=3)

    return fig



def get_banana_fig(model, loader, device):
    """Create a two-panel uncertainty figure for a 2-D binary classifier.

    The model is evaluated on a 100 x 100 grid. The grid covers the bounding
    box of the test inputs padded by 2 on every side.

    Left panel: softmax of the MAP output, class-0 probability.
    Right panel: per-point std of the hard 0/1 votes of the posterior samples.
    Both panels overlay the test points coloured by label.

    Args:
        model: Object whose ``posterior(x, loader)`` returns
            ``(y_map, y_mu, y_std, py)``. ``y_map`` must reshape to
            ``(100, 100, 2)`` and ``py`` to ``(-1, 100, 100, 2)``.
        loader: Provides ``loader.dataset.x_test`` (at least 2 columns) and
            ``y_test`` tensors.
        device: Device the grid points are moved to before calling the model.

    Returns:
        The created ``matplotlib.figure.Figure``.
    """
    dataset = loader.dataset
    x_test = dataset.x_test.detach().cpu().numpy()
    y_test = dataset.y_test.detach().cpu().numpy()

    n_grid, pad = 100, 2
    lower = x_test[:, :2].min(axis=0) - pad
    upper = x_test[:, :2].max(axis=0) + pad
    grid_x1, grid_x2 = np.meshgrid(
        np.linspace(lower[0], upper[0], n_grid),
        np.linspace(lower[1], upper[1], n_grid),
    )
    flat_points = np.column_stack((grid_x1.ravel(), grid_x2.ravel()))
    query = torch.from_numpy(flat_points).float().to(device)

    map_out, _, _, sample_out = model.posterior(query, loader)

    map_probs = softmax(torch.reshape(map_out, (n_grid, n_grid, 2)).detach().cpu().numpy(), -1)
    sample_logits = torch.reshape(sample_out, (-1, n_grid, n_grid, 2)).detach().cpu().numpy()
    # 1.0 where a sample prefers class 0, else 0.0; the spread measures disagreement.
    votes = (sample_logits[..., 0] > sample_logits[..., 1]) * 1.
    disagreement = np.std(votes, 0)

    # (field, colormap, contour levels, title) for each panel, left to right.
    panel_specs = (
        (map_probs[:, :, 0], "PuOr", np.arange(-0.3, 1.3, 0.1), "Likelihood Uncertainty"),
        (
            disagreement,
            "Blues",
            np.arange(np.min(disagreement) - 0.1, np.max(disagreement) + 0.1, 0.01),
            "Model Uncertainty",
        ),
    )

    fig, axes = plt.subplots(1, 2, figsize=(14, 6))
    for ax, (field, cmap_name, levels, title) in zip(axes, panel_specs):
        filled = ax.contourf(
            grid_x1,
            grid_x2,
            field,
            alpha=0.8,
            antialiased=True,
            cmap=cmap_name,
            levels=levels,
        )
        plt.colorbar(filled, ax=ax, orientation='vertical')
        ax.scatter(
            x_test[:, 0],
            x_test[:, 1],
            c=y_test,
            cmap=ListedColormap(["tab:purple", "tab:orange"]),
            edgecolor='black',
            linewidth=0.01,
            s=15,
            zorder=1,
            alpha=1.0,
        )
        ax.set_title(title)
    plt.tight_layout()

    return fig



















