import sys
sys.path.append("../../src")
import acquire
import math

import random
import numpy
import torch
from matplotlib import pyplot, gridspec
from scipy.stats import multivariate_normal

# Selects which figure the script draws:
#   True  -> the per-round beam-search figure (saved as results/beam<round>.png
#            from Demonstration.visualize_beams).
#   False -> a grid of panels, one column per round, comparing the greedy and
#            lc-beam demonstrations.
BEAM_FIGURES = True

# Seed the Python, NumPy and torch RNGs so runs are reproducible.
# NOTE: the order of RNG calls below (torch.rand, then torch.randn) determines
# the data, so do not reorder these statements.
seed = 5
random.seed(seed)
numpy.random.seed(seed)
torch.manual_seed(seed)

# Serif font for every figure.
pyplot.rcParams["font.family"] = "Liberation Serif"

# Size of the unlabelled pool.
N = 200

# Unlabelled pool: N points drawn uniformly from [-1, 1)^2.
Xu = torch.rand(N, 2).float() * 2 - 1
# Initial labelled pool: two points from a normal with std 0.1 per
# coordinate, i.e. clustered near the origin.
Xl = torch.randn(2, 2) * 0.1

# Number of beams kept by the lc-beam core-set search.
beam_width = 15
# Points acquired per round (also how many leading rows plot() highlights).
budget = 5
# Passed to the acquire.* constructors.
batch_size = 100
device = "cpu"
# Grid points per axis for the confidence heat map.
resolution = 50
# Number of acquisition rounds.
repeats = 3
# Second warp centre used by the lc-beam demonstration's perspective plots.
warp_target = (-1, 0)
if BEAM_FIGURES:
    # One row of five panels: four beam scatter plots that share axes with
    # ax0, plus an independent panel (ax4) for the beam-score curve.
    fig = pyplot.figure(figsize=(12, 3))
    gs = gridspec.GridSpec(1, 5)
    ax0 = fig.add_subplot(gs[0])
    ax1, ax2, ax3 = (
        fig.add_subplot(gs[k], sharex=ax0, sharey=ax0) for k in range(1, 4)
    )
    ax4 = fig.add_subplot(gs[4])
else:
    # Regular resolution x resolution grid over [-1, 1]^2 on which the
    # model's confidence is evaluated, flattened to one (x, y) row per cell.
    mesh_x = mesh_y = numpy.linspace(-1, 1, resolution)
    xv, yv = numpy.meshgrid(mesh_x, mesh_y)
    space = torch.from_numpy(numpy.column_stack((xv.ravel(), yv.ravel())))
    assert space.size() == (resolution**2, 2)

    # Five rows (one per kind of view), one column per acquisition round.
    _, axes = pyplot.subplots(
        nrows=5, ncols=repeats + 1, sharex=True, sharey=True, figsize=(9, 12)
    )
    row_titles = (
        "Core-set, greedy",
        "Core-set, inconfident regions",
        "Confidence",
        "Perspective from origin",
        f"Perspective from {warp_target}",
    )
    for row, title in enumerate(row_titles):
        axes[row, 0].set_ylabel(title)
        axes[row, 0].set_yticks([-1, 0, 1])

# Quadrant colours, indexed by the quadrant label used throughout the script.
colors = ["tab:blue", "tab:orange", "tab:green", "tab:red"]

def plot(pyplot, X, oX, fmt, label, color_quadrants=False):
    """Scatter the rows of ``X`` onto ``pyplot`` (any object with ``scatter``).

    Each point is coloured by the quadrant of the matching row of ``oX``
    (label ``(x >= 0) + 2 * (y < 0)``, used to index the module-level
    ``colors``). This lets a warped ``X`` keep the colour of where each
    point originally was.

    With ``color_quadrants`` true, every point is drawn small and faint.
    Otherwise, quadrant by quadrant, rows from index ``budget`` onwards are
    drawn translucent and the first ``budget`` rows are drawn opaque.

    Args:
        pyplot: target axes object (the name shadows the matplotlib module
            inside this function).
        X: numpy array of plotted positions, one 2-D point per row.
        oX: numpy array of original positions, row-aligned with ``X``.
        fmt: marker passed to ``scatter``.
        label: label passed to every ``scatter`` call.
        color_quadrants: selects the drawing mode described above.
    """
    xs, ys = X[:, 0], X[:, 1]
    quadrant = (oX[:, 0] >= 0).astype(int) + 2 * (oX[:, 1] < 0).astype(int)

    if color_quadrants:
        for q in range(4):
            in_q = quadrant == q
            pyplot.scatter(xs[in_q], ys[in_q], s=5, alpha=0.1, marker=fmt, label=label, color=colors[q])
        return

    for q in range(4):
        colour = colors[q]
        in_q = quadrant == q
        tail, head = in_q[budget:], in_q[:budget]
        # Translucent rows first, then the opaque leading rows on top.
        pyplot.scatter(xs[budget:][tail], ys[budget:][tail], alpha=0.2, s=20, marker=fmt, label=label, color=colour)
        pyplot.scatter(xs[:budget][head], ys[:budget][head], alpha=1, s=20, marker=fmt, label=label, color=colour)



def visualize(ax, Xu, oXu, Xl, oXl, i, name, warp):
    """Scatter one unlabelled and one labelled point set into ``axes[ax, i]``.

    ``Xu``/``Xl`` are the plotted positions; ``oXu``/``oXl`` are the original
    positions that decide quadrant colours (identical when nothing is
    warped). All are tensors converted with ``.numpy()``. Unlabelled points
    are always coloured by quadrant. Labelled points use quadrant
    colouring only when ``warp`` is truthy; otherwise they use ``plot``'s
    translucent/opaque mode. ``name`` is accepted but not used.
    """
    unlabelled = (Xu.numpy(), oXu.numpy())
    labelled = (Xl.numpy(), oXl.numpy())

    target = axes[ax, i]
    plot(target, *unlabelled, ".", "unlabelled", color_quadrants=True)
    plot(target, *labelled, "s", "labelled", color_quadrants=warp)


class Model(torch.nn.Module):
    """Hand-built toy 4-class classifier over 2-D inputs.

    The score for each label is a weighted sum of isotropic Gaussian
    densities (``cov=0.1``) belonging to that label. Initially there is one
    unit-weight Gaussian per corner of [-1, 1]^2, and corner ``k`` belongs to
    label ``k``. ``train_model`` appends one Gaussian per labelled point.
    ``extract_features`` is the identity, so the feature space is the input
    space.
    """

    def __init__(self):
        super().__init__()
        # Row k is the prototype for label k (matches self._labels below).
        self._corners = torch.Tensor([
            [-1, 1], [1, 1], [-1, -1], [1, -1]
        ])

        self._mvns = self.create_mvns(self._corners)
        self._labels = [0, 1, 2, 3]
        self._weights = [1, 1, 1, 1]

    def train_model(self, datapool):
        """Add one Gaussian per ``(index, point, label)`` item in *datapool*.

        The new Gaussian gets a random weight, scaled by
        ``1 - d / sqrt(2)``, where ``d`` is the point's distance to its
        nearest corner. Points close to a corner can therefore receive
        larger weights. Consumes one ``random.random()`` call per item.
        """
        for _, X, y in datapool:
            self._mvns.append(self.create_mvn(X))
            self._labels.append(y)

            dists = (self._corners - X.unsqueeze(0)).norm(p=2, dim=1)
            min_dist = dists.min()
            self._weights.append(random.random() * (1-min_dist/math.sqrt(2)))

    def create_mvns(self, means):
        """Return one Gaussian per row of *means*."""
        return [
            self.create_mvn(mean)
            for mean in means
        ]

    def create_mvn(self, mean):
        """Return an isotropic Gaussian centred at *mean* with covariance 0.1."""
        return multivariate_normal(mean=mean, cov=0.1)

    def forward(self, z):
        """Return per-label scores of shape ``(len(z), 4)`` for the rows of *z*."""
        return self.classify_hidden_features(self.extract_features(z))

    def extract_features(self, z):
        """Identity feature map."""
        return z

    def classify_hidden_features(self, z):
        """Sum each label's weighted Gaussian densities at every row of *z*."""
        pdfs = [self._predict(mvn, z) for mvn in self._mvns]
        output = torch.zeros(len(z), 4)
        for i, label in enumerate(self._labels):
            output[:, label] += pdfs[i] * self._weights[i]
        return output

    def _predict(self, mvn, z):
        """Evaluate *mvn*'s density at every row of *z* as a 1-D float tensor."""
        pdf = mvn.pdf(z.detach().cpu().numpy())
        # scipy squeezes the result, so a single-row z yields a numpy scalar,
        # which torch.from_numpy rejects. Restore the per-row axis.
        return torch.from_numpy(numpy.atleast_1d(pdf)).float()


class Demonstration:
    """Run one acquisition strategy on the toy data and plot every round.

    On construction the strategy's model is trained on the initial labelled
    pool and the starting state is drawn. Each ``update`` call does four
    things in order:

    1. Acquires a batch.
    2. Moves it from the unlabelled pool to the front of the labelled pool.
    3. Redraws the figure.
    4. Trains the model on the newly labelled points.

    Args:
        Xu: unlabelled points, one per row.
        Xl: labelled points, one per row.
        coreset: acquisition object; its ``_model``, ``_budget`` and (for
            "lc-beam" strategies) ``_plotting_data`` attributes are read.
        name: label for this demonstration.
        warp: sequence of 2-D warp centres for the "perspective" rows of the
            grid figure; ``None`` means no warps.
    """

    def __init__(self, Xu, Xl, coreset, name, warp=None):
        self._Xu = Xu
        self._Xl = Xl
        self._coreset = coreset
        # Current round; also the grid column and beam-figure file index.
        self._id = 0
        self._name = name
        # visualize() iterates over the warps, so treat None as "no warps".
        self._warp = [] if warp is None else warp
        self._coreset._model.train_model(Datapool(self._Xl))
        self.visualize()

    def update(self):
        """Acquire one batch, move it into the labelled pool and retrain."""
        select = self._coreset.acquire_unlabelled_indices(
            self._coreset._budget,
            Datapool(self._Xu),
            Datapool(self._Xl)
        )
        if self._coreset.get_name().startswith("lc-beam"):
            self.visualize_beams(self._Xu, **self._coreset._plotting_data)

        mask = torch.zeros(len(self._Xu), dtype=torch.bool)
        mask[select] = True
        # Count what was actually acquired (duplicates collapse in the mask),
        # rather than assuming exactly `budget` points.
        n_new = int(mask.sum())
        # New points are prepended, so the first n_new labelled rows are
        # exactly this round's acquisitions.
        self._Xl = torch.cat([self._Xu[mask], self._Xl], dim=0)
        self._Xu = self._Xu[~mask]
        self.visualize()

        self._coreset._model.train_model(Datapool(self._Xl[:n_new]))
        self._id += 1

    def visualize_beams(self, X, indices, beam_scores):
        """Draw up to four beams and their scores, then save the figure.

        Only active when BEAM_FIGURES is true. Every third beam is shown, at
        most four (one per scatter panel). Each panel plots ``X`` with that
        beam's indices removed; ax4 plots the corresponding scores.
        NOTE(review): assumes ``indices`` holds one index set per beam and
        ``beam_scores`` is a 1-D tensor aligned with it (both come from the
        coreset's ``_plotting_data``) — confirm against ``acquire``.
        NOTE(review): panel titles say "Beam rank 1..4" although the beams
        shown are every third one — confirm this labelling is intended.
        """
        if not BEAM_FIGURES:
            return
        import os

        indices = indices[::3][:4]
        beam_scores = beam_scores[::3][:4]
        panels = [ax0, ax1, ax2, ax3]

        ax4.cla()
        # Size the x-axis from the data so fewer than four beams still work.
        ranks = numpy.arange(1, len(beam_scores) + 1)
        ax4.plot(ranks, beam_scores.numpy(), ".:")
        ax4.yaxis.set_label_position("right")
        ax4.yaxis.tick_right()
        ax4.set_ylabel("Negative log confidence")
        ax4.set_xlabel("Beam rank")
        ax4.set_xticks(ranks)
        ax4.grid(True, alpha=0.2)

        for ax in panels:
            ax.cla()

        for rank, (ax, beam) in enumerate(zip(panels, indices), 1):
            ax.set_xlabel("$x_2$")
            ax.set_title(f"Beam rank {rank}")
            ax.grid(True, alpha=0.2)

            # Drop the points this beam selects and plot the rest.
            mask = torch.ones(len(X), dtype=torch.bool)
            mask[beam] = False
            xi = X[mask].numpy()
            plot(ax, xi, xi, "s", "unlabelled", False)

        ax0.set_ylabel("$x_1$")

        os.makedirs("results", exist_ok=True)
        pyplot.savefig(f"results/beam{self._id}.png", bbox_inches="tight")

    def visualize(self):
        """Draw the current round into column ``self._id`` of the grid figure.

        Only active when BEAM_FIGURES is false. For warp ``i``, the warped
        points go in row ``3 + i``; with the first warp the model's log
        confidence over ``space`` is also drawn in row 2. The unwarped
        points go in row 0 (no warps) or row 1 (has warps).
        NOTE(review): the grid has 5 rows, so at most two warps fit.
        """
        if BEAM_FIGURES:
            return

        # Loop-invariant: both pools stacked, unlabelled rows first.
        n_unlabelled = len(self._Xu)
        X = torch.cat([self._Xu, self._Xl], dim=0)
        for i, warp in enumerate(self._warp):
            wX = self.warp_x(X, warp)
            wXu = wX[:n_unlabelled]
            wXl = wX[n_unlabelled:]
            visualize(3+i, wXu, self._Xu, wXl, self._Xl, self._id, f"{self._name}-warped", False)

            if i == 0:
                logits = self._coreset._model(space)
                assert logits.size() == (resolution**2, 4)
                conf, _ = torch.log_softmax(logits, dim=1).max(dim=1)
                visualize_confidence_heat(2, self._id, mesh_x, mesh_y, conf.view(resolution, resolution), f"{self._name}-warped-{self._id}")

        visualize(int(bool(self._warp)), self._Xu, self._Xu, self._Xl, self._Xl, self._id, self._name, False)

    def warp_x(self, X, warp):
        """Pull each row of *X* towards *warp* in proportion to model confidence.

        Returns ``warp + (X - warp) * (1 - conf)``, where ``conf`` is the
        max softmax probability per row. Confident points collapse onto the
        warp centre; uncertain points stay close to where they are.
        """
        warp = torch.Tensor(warp).unsqueeze(0)
        logits = self._coreset._model(X)
        conf, _ = torch.softmax(logits, dim=1).max(dim=1)
        return (X - warp) * (1-conf.unsqueeze(-1)) + warp

class Datapool(torch.utils.data.Dataset):
    """Map-style dataset over the rows of ``X``, labelled by quadrant.

    Items are ``(index, point, label)`` triples. The label is computed from
    the point itself by ``determine_label``, so no labels are stored.
    """

    def __init__(self, X):
        super().__init__()
        self.X = X

    def __len__(self):
        return len(self.X)

    def __getitem__(self, i):
        point = self.X[i]
        return i, point, self.determine_label(point)

    def determine_label(self, z):
        """Return the quadrant label ``(x >= 0) + 2 * (y < 0)`` of point *z*.

        0: x < 0, y >= 0;  1: x >= 0, y >= 0;
        2: x < 0, y < 0;   3: x >= 0, y < 0.
        """
        horizontal, vertical = z
        column = 1 if horizontal >= 0 else 0
        row = 2 if vertical < 0 else 0
        return column + row

def visualize_confidence_heat(ax, i, x, y, conf, name):
    """Draw *conf* as a "hot" heat map over the ``x``/``y`` grid in ``axes[ax, i]``.

    ``conf`` is a tensor converted with ``.numpy()``; ``Demonstration``
    passes a ``(resolution, resolution)`` view. ``name`` is not used.
    """
    heat = conf.numpy()
    target = axes[ax, i]
    target.pcolormesh(x, y, heat, cmap="hot")

def main():
    """Run the greedy and lc-beam p-weighted core-set demonstrations.

    Builds both acquirers, runs ``repeats`` acquisition rounds of each in
    lock-step, and saves the current figure to ``results/visual.png``.
    """
    import os

    # savefig (here and in Demonstration.visualize_beams) needs the output
    # directory to exist.
    os.makedirs("results", exist_ok=True)

    # Each strategy gets its own model. Demonstration trains the model in
    # place (initial pool plus every acquired batch), so a shared instance
    # would mix one strategy's acquisitions, and a duplicate copy of the
    # initial labelled pool, into the other strategy's classifier.
    gcoreset = acquire.GreedyCoreset(
        budget=budget,
        batch_size=batch_size,
        device=device
    ).init(Model())

    lc_beam_pweighted_coreset = acquire.LcBeamPWeightedCoreset(
        budget=budget,
        batch_size=batch_size,
        device=device,
        beam_width=beam_width
    ).init(Model())

    d1 = Demonstration(Xu, Xl, gcoreset, "greedy-coreset", warp=[])
    d2 = Demonstration(Xu, Xl, lc_beam_pweighted_coreset, "lc-beam-pweighted-coreset", warp=[[0, 0], warp_target])

    for _ in range(repeats):
        d1.update()
        d2.update()

    pyplot.savefig("results/visual.png", bbox_inches="tight")


if __name__ == "__main__":
    # Run the demonstration only when executed as a script, not on import.
    main()