import os

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.nn.functional as F
from torch import nn


class Monitor:
    """Track losses and feature representations of a model over training.

    Each call to :meth:`record` evaluates ``model`` on the first batch of a
    test and a train dataloader (inputs are first passed through ``base``) and
    appends the mean cross-entropy losses plus the ``model.feature`` output of
    the first test sample. :meth:`print_all` plots these histories and saves
    figures under ``<saving_dir>figs/`` and raw tensors under
    ``<saving_dir>tensor/``.

    Args:
        device: Device every input/label batch is moved to.
        base: Callable applied under ``torch.no_grad`` to each input batch
            before it is fed to ``model``.
        name: Prefix used for the saved file names.
        saving_dir: Output directory. Paths are built by plain string
            concatenation, so it should end with a path separator.
    """

    def __init__(self, device, base, name, saving_dir="./store/"):
        self.test_loss = []
        self.train_loss = []
        self.representation = []
        self.criterion = nn.CrossEntropyLoss(reduction="mean")
        self.saving_dir = saving_dir
        self.device = device
        self.base = base
        self.name = name
        # Set by calc_pairwise_similarity().
        self.similarity = None
        os.makedirs(saving_dir + "figs", exist_ok=True)
        os.makedirs(saving_dir + "tensor", exist_ok=True)

    def _first_batch(self, dataloader):
        """Return the first ``(inputs, labels)`` batch on ``self.device``,
        with ``self.base`` applied to the inputs without tracking gradients."""
        inputs, labels = next(iter(dataloader))
        inputs, labels = inputs.to(self.device), labels.to(self.device)
        with torch.no_grad():
            inputs = self.base(inputs)
        return inputs, labels

    def record(self, test_dataloader, train_dataloader, model):
        """Record test loss, train loss and one representation snapshot.

        Only the first batch of each dataloader is used. ``model`` is put in
        eval mode for the measurement, and afterwards every submodule is
        restored to the train/eval mode it had on entry, so calling this
        inside a training loop does not leave dropout/batch-norm disabled.
        """
        # Remember each submodule's flag so mixed setups (e.g. frozen
        # batch-norm layers inside a training model) are restored exactly.
        saved_modes = [(module, module.training) for module in model.modules()]
        # NOTE(review): the test batch passes through ``base`` before the
        # switch to eval mode and the train batch after it; this only matters
        # if ``base`` shares modules with ``model``.
        inputs, labels = self._first_batch(test_dataloader)
        model.eval()
        try:
            self.record_test_loss(inputs, labels, model)
            # Snapshot the representation of a single (the first) test sample.
            self.record_representation(inputs[0].unsqueeze(0), model)

            inputs_t, labels_t = self._first_batch(train_dataloader)
            self.record_test_loss(inputs_t, labels_t, model, False)
        finally:
            for module, mode in saved_modes:
                module.training = mode

    def record_test_loss(self, inputs, labels, model, is_test=True):
        """Append the mean cross-entropy of ``model(inputs)`` vs ``labels``
        to ``self.test_loss`` (``is_test``) or ``self.train_loss``."""
        with torch.no_grad():
            test_loss = self.criterion(model(inputs), labels)
        if is_test:
            self.test_loss.append(test_loss.item())
        else:
            self.train_loss.append(test_loss.item())

    def record_representation(self, inputs, model):
        """Append ``model.feature(inputs)`` (computed without gradients) to
        ``self.representation``."""
        with torch.no_grad():
            self.representation.append(model.feature(inputs))

    def calc_pairwise_similarity(self, num_points=10):
        """Compute cosine similarities between evenly spaced snapshots.

        Selects ``num_points`` snapshots from ``self.representation`` at
        indices ``0, m, 2m, ...`` with
        ``m = len(self.representation) // num_points`` and stores their
        symmetric ``(num_points, num_points)`` cosine-similarity matrix in
        ``self.similarity``. When fewer than ``num_points`` snapshots exist, a
        step of 1 is used and indices past the end are clamped to the last
        snapshot.

        Raises:
            ValueError: if no representation has been recorded.
        """
        n = len(self.representation)
        if n == 0:
            raise ValueError("no representations recorded; call record() first")
        step = max(n // num_points, 1)
        # For n >= num_points the clamp never triggers: i * step <= n - step.
        idx = [min(i * step, n - 1) for i in range(num_points)]
        self.similarity = torch.zeros((num_points, num_points))
        for i in range(num_points):
            for j in range(i, num_points):
                sim_ = F.cosine_similarity(
                    self.representation[idx[i]], self.representation[idx[j]]
                )
                self.similarity[i, j], self.similarity[j, i] = sim_, sim_

    def _finish_figure(self, suffix, show):
        """Save the current figure as ``figs/<name><suffix>.png`` at 300 dpi,
        then show it if ``show`` is true, otherwise close it."""
        plt.savefig(
            self.saving_dir + "figs/" + self.name + suffix + ".png",
            format="png",
            dpi=300,
        )
        if show:
            plt.show()
        else:
            plt.close()

    def _plot_loss(self, values, ylabel, fig_suffix, tensor_suffix, show):
        """Line-plot a loss history (y-axis from 0), save the figure and
        save the history tensor as ``tensor/<name><tensor_suffix>``."""
        loss = torch.tensor(values).view(-1)
        sns.lineplot(x=torch.arange(len(loss)), y=loss)
        plt.title("")
        plt.xlabel("Time")
        plt.ylabel(ylabel)
        plt.ylim(bottom=0)
        self._finish_figure(fig_suffix, show)
        torch.save(loss, self.saving_dir + "tensor/" + self.name + tensor_suffix)

    def print_all(self, show=False):
        """Plot and save every recorded history.

        Saves five figures to ``<saving_dir>figs/``: test loss, training loss,
        first three feature dimensions, the snapshot similarity matrix, and
        similarity to a late reference snapshot. The underlying tensors are
        saved to ``<saving_dir>tensor/``.

        Args:
            show: If true, call ``plt.show()`` after saving each figure;
                otherwise close it.
        """
        self._plot_loss(self.test_loss, "Test loss", "_testloss", "_loss.pt", show)
        self._plot_loss(
            self.train_loss, "Training loss", "_trainloss", "_train_loss.pt", show
        )

        # All snapshots stacked along dim 0, moved to CPU once and reused below.
        representation = torch.cat(self.representation, 0).cpu()

        # plot representation: first three feature dims of every 50th snapshot
        timestamp = np.arange(0, len(self.representation), 50)
        repre = representation[timestamp, :3]
        for i in range(3):
            sns.lineplot(x=timestamp, y=repre[:, i], label=f"Embedding {i + 1}")
        plt.title("Embeddings on first three dimensions")
        plt.xlabel("Time")
        plt.ylabel("Value")
        self._finish_figure("_embedding", show)
        # NOTE(review): unlike every other artifact, this file name is not
        # prefixed with self.name, so monitors sharing saving_dir overwrite it.
        torch.save(repre, self.saving_dir + "tensor/repre.pt")

        # plot similarity matrix
        self.calc_pairwise_similarity()
        sns.heatmap(self.similarity, cmap="viridis_r")
        plt.title("Similarity Matrix")
        self._finish_figure("_similarity", show)
        torch.save(
            self.similarity, self.saving_dir + "tensor/" + self.name + "_similarity.pt"
        )

        # Most recent ~70% of snapshots (when int(0.7 * len) == 0 the `-0:`
        # slice keeps everything).
        torch.save(
            representation[-int(0.7 * len(representation)) :],
            self.saving_dir + "tensor/" + self.name + "_representation.pt",
        )

        # Despite the "cos_dist" / "Magnitude of Change" naming, these are
        # cosine *similarities* between the snapshot 90% of the way through
        # the recording and every 5th snapshot from there on.
        start = int(0.9 * representation.shape[0])
        cos_dist = torch.tensor(
            [
                F.cosine_similarity(
                    self.representation[start], self.representation[i]
                )
                for i in np.arange(start, representation.shape[0], 5)
            ]
        )
        torch.save(cos_dist, self.saving_dir + "tensor/" + self.name + "_cos_dist.pt")

        plt.plot(cos_dist)
        plt.title("Magnitude of Change Over Time")
        plt.xlabel("Time")
        plt.ylabel("Magnitude of Change")
        self._finish_figure("_change", show)


def k_max_elements_to_1(tensor, k):
    """Return a 0/1 mask marking the ``k`` largest entries of each row.

    Args:
        tensor: 2-D tensor; each row (dim 0 index) is ranked independently
            along ``dim=1``.
        k: Number of entries to mark per row. Selection uses Python slice
            semantics on the descending sort order, so ``k`` larger than the
            row length marks every entry and ``k == 0`` marks none.

    Returns:
        Tensor with the same shape, dtype and device as ``tensor``, holding 1
        at each row's top-``k`` positions and 0 elsewhere. Ties are broken by
        ``torch.sort``'s ordering.
    """
    # Column indices of each row, largest value first.
    _, indices = torch.sort(tensor, descending=True, dim=1)

    mask = torch.zeros_like(tensor)
    # Write 1 at every row's top-k columns in one vectorized call instead of
    # a Python loop issuing one indexing operation per row.
    mask.scatter_(1, indices[:, :k], 1)
    return mask
