import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from tqdm import tqdm
from torch.optim import Adam
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor, Normalize, Lambda
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F

from FF_MNIST import count_parameters


class MNIST_Dataset(Dataset):
    """Index-aligned pairing of a sample container with a label container.

    Item ``i`` is the tuple ``(data[i], labels[i])``; the dataset length is
    the length of ``data``.
    """

    def __init__(self, data, labels):
        self.data, self.labels = data, labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        target = self.labels[idx]
        return sample, target
def MNIST_loaders(train_batch_size=50000, test_batch_size=10000):
    """Build the MNIST train/test ``DataLoader`` pair.

    Every image is converted to a tensor, normalised with mean 0.1307 and
    std 0.3081, and flattened to one dimension. The training loader
    shuffles; the test loader does not. The data is stored under
    ``./data/`` and downloaded if missing.

    Returns:
        ``(train_loader, test_loader)``
    """
    preprocessing = Compose([
        ToTensor(),
        Normalize((0.1307,), (0.3081,)),
        Lambda(torch.flatten),
    ])

    def build_loader(is_train, batch_size):
        # Only the training split is shuffled.
        split = MNIST('./data/', train=is_train, download=True,
                      transform=preprocessing)
        return DataLoader(split, batch_size=batch_size, shuffle=is_train)

    train_loader = build_loader(True, train_batch_size)
    test_loader = build_loader(False, test_batch_size)
    return train_loader, test_loader


def overlay_y_on_x(x, y):
    """Return a copy of [x] whose first 10 columns one-hot encode label [y].

    The first 10 entries of every row are zeroed, then entry ``y[i]`` of
    row ``i`` is set to the maximum value found anywhere in [x]. The input
    tensor itself is left untouched.
    """
    overlaid = x.clone()
    # In-place multiply on the slice view clears the label slots of the copy.
    overlaid[:, :10].mul_(0.0)
    hot_value = x.max()
    row_ids = range(x.shape[0])
    overlaid[row_ids, y] = hot_value
    return overlaid


class Net(torch.nn.Module):
    """Stack of ``Layer`` modules that are trained one after another.

    Args:
        dims: sequence of layer widths; each consecutive pair
            ``(dims[i], dims[i + 1])`` defines one ``Layer``.
        K: number of neuron groups (classes) per layer.
        threshold: forwarded unchanged to every ``Layer``.
        device: device the layers are created on. Defaults to ``"cuda"``,
            which matches the previously hard-coded ``.cuda()`` call.
    """

    def __init__(self, dims, K, threshold=2.0, device="cuda"):
        super().__init__()
        self.K = K
        # A ModuleList (rather than a plain list) registers the layers as
        # sub-modules, so parameters(), state_dict() and .to() can see them.
        self.layers = nn.ModuleList([
            Layer(d_in, d_out, K=self.K, threshold=threshold, device=device)
            for d_in, d_out in zip(dims[:-1], dims[1:])
        ])

    def predict(self, x, y, output_path='PvN_acc_curve.png'):
        """Report the classification accuracy obtained after every layer.

        For each layer the predicted class is the group whose support
        neurons have the highest mean activation. The per-layer accuracies
        are plotted, saved to ``output_path`` and printed.

        Args:
            x: input batch fed to the first layer.
            y: ground-truth class indices, one per row of ``x``.
            output_path: where the accuracy curve image is written.

        Returns:
            List of per-layer accuracies (floats), first layer first.
        """
        acc = []
        h = x
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            for layer in self.layers:
                h = layer(h)
                # NOTE(review): the training losses score squared
                # activations while this uses raw ones -- confirm intended.
                goodness = torch.stack(
                    [h[:, idx].mean(1) for idx in layer.support_index], dim=1)
                label = goodness.argmax(1)
                acc.append(label.eq(y).float().mean().item())

        fig, ax = plt.subplots(figsize=(5, 5))
        ax.plot(list(range(len(acc))), acc, label='Accuracy', color='blue')
        ax.set_title('Acc Curve')
        ax.set_xlabel('Layer')
        ax.set_ylabel('Accuracy')
        fig.savefig(output_path)
        # Close the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
        print(f"Accuracy curve saved to {output_path}")
        print(acc)
        return acc

    def train(self, x, y):
        """Train the layers in order; each layer's output feeds the next."""
        h = x
        for i, layer in enumerate(self.layers):
            print('training layer', i, '...')
            h = layer.train(h, y)

    def get_neuron(self, x):
        """Collect ``layer.get_neuron(...)`` for every layer.

        Each layer is queried with the input it receives, i.e. the output
        of the preceding layer (``x`` for the first one).

        Returns:
            List with one ``get_neuron`` result per layer.
        """
        results = []
        layer_input = x
        for layer in self.layers:
            h = layer(layer_input)
            results.append(layer.get_neuron(layer_input))
            layer_input = h.clone()
        return results


class Layer(nn.Linear):
    """Linear + ReLU layer trained with its own optimizer on a local loss.

    The ``out_features`` units are split into ``K`` contiguous groups of
    ``out_features // K`` units. Group ``k`` is the "support" of class
    ``k``; all remaining units form that class's negative set.

    Args:
        in_features, out_features, bias, device, dtype: as for ``nn.Linear``.
        K: number of groups / classes.
        threshold: margin used by the default ("PvN") loss.
        num_epochs: passes over the data made by ``train``.
        lr: Adam learning rate.
        batch_size: mini-batch size used by ``train``.
        loss_type: value passed as ``L`` to ``loss`` during training.

    Raises:
        ValueError: if ``K`` is not positive or exceeds ``out_features``
            (every group needs at least one unit).
    """

    def __init__(self, in_features, out_features, K, threshold=2.0, device="cuda",
                 bias=True, dtype=None, num_epochs=100, lr=0.001, batch_size=64,
                 loss_type="PvN"):
        if K <= 0 or out_features < K:
            raise ValueError(
                f"need 1 <= K <= out_features, got K={K}, "
                f"out_features={out_features}")
        super().__init__(in_features, out_features, bias, device, dtype)
        self.relu = torch.nn.ReLU()
        self.K = K
        self.num_sup = out_features // K
        # Group k owns the contiguous index range [k*num_sup, (k+1)*num_sup).
        self.support_index = [
            list(range(k * self.num_sup, (k + 1) * self.num_sup))
            for k in range(K)
        ]
        # Everything outside group k, in ascending order.
        self.support_index_neg = [
            list(range(0, k * self.num_sup))
            + list(range((k + 1) * self.num_sup, out_features))
            for k in range(K)
        ]
        self.opt = Adam(self.parameters(), lr=lr)
        self.threshold = threshold
        self.num_epochs = num_epochs
        self.batch_size = batch_size
        self.loss_type = loss_type
        self.device = device
        # Non-persistent buffers: they follow .to()/.cuda() together with
        # the weights but are kept out of state_dict.
        self.register_buffer(
            "pos_all",
            torch.tensor(self.support_index, dtype=torch.long, device=device),
            persistent=False)
        self.register_buffer(
            "neg_all",
            torch.tensor(self.support_index_neg, dtype=torch.long, device=device),
            persistent=False)

    def forward(self, x):
        """Apply the layer to the L2-normalised rows of ``x``, then ReLU."""
        x_direction = x / (x.norm(2, 1, keepdim=True) + 1e-4)
        # F.linear also handles bias=None (layer built with bias=False).
        return self.relu(F.linear(x_direction, self.weight, self.bias))

    def train(self, x=True, y=None):
        """Fit this layer on ``(x, y)`` and return its detached output for ``x``.

        When called with a single boolean and no labels (which is what
        ``nn.Module.eval()`` does), this defers to ``nn.Module.train`` and
        only switches the training-mode flag.

        Args:
            x: input batch, or a bool training-mode flag when ``y`` is None.
            y: class indices, one per row of ``x``.

        Returns:
            ``forward(x)`` computed without gradient tracking, or ``self``
            when used as the mode switch.
        """
        if y is None:
            return super().train(x)

        self.dataset = MNIST_Dataset(x, y)
        self.dataloader = DataLoader(self.dataset, batch_size=self.batch_size,
                                     shuffle=True, num_workers=0)
        for _ in tqdm(range(self.num_epochs)):
            for inputs, labels in self.dataloader:
                loss = self.loss(inputs, labels, L=self.loss_type)
                self.opt.zero_grad()
                # This backward pass only produces gradients for this
                # layer's own parameters (the optimizer holds nothing else).
                loss.backward()
                self.opt.step()
        with torch.no_grad():
            return self.forward(x)

    def _class_response(self, g):
        """Return the (batch, K) mean of ``g`` over each class's support units."""
        return g[:, self.pos_all].mean(2)

    def loss(self, inputs, labels, L="CWC"):
        """Compute the local training loss on squared activations.

        Args:
            inputs: input batch.
            labels: class indices (long tensor), one per row of ``inputs``.
            L: ``"CWC"`` -- negative log-softmax of the true class's group
               response; ``"CE"`` -- ``F.cross_entropy`` on the group
               responses; anything else -- the threshold loss that pushes
               the true group's mean above ``threshold`` and the mean of
               all other units below it.

        Returns:
            Scalar loss tensor.
        """
        g = self.forward(inputs).pow(2)

        if L == "CWC":
            response = self._class_response(g)
            chosen = response.gather(1, labels.unsqueeze(1)).squeeze(1)
            # -log(exp(chosen) / sum(exp(response))) written via logsumexp,
            # so large responses no longer overflow exp() and yield NaN.
            return (torch.logsumexp(response, dim=1) - chosen).mean()

        if L == "CE":
            return F.cross_entropy(self._class_response(g), labels)

        g_pos = g.gather(1, self.pos_all[labels]).mean(1)
        g_neg = g.gather(1, self.neg_all[labels]).mean(1)
        # softplus(z) == log(1 + exp(z)) but stays finite for large z.
        return F.softplus(torch.cat([
            -g_pos + self.threshold,
            g_neg - self.threshold])).mean()

    def get_neuron(self, x, y=None):
        """Return support indices and mean squared activations per sample.

        Args:
            x: input batch.
            y: class indices, one per row of ``x``. If omitted, a
               module-level variable named ``y`` is used when present
               (this is what the method implicitly relied on before).

        Returns:
            ``[sup, sup_neg, g_pos, g_neg]``: per-sample support / negative
            index lists and the mean squared activation over each.

        Raises:
            ValueError: if no labels are given and no module-level ``y``
                exists.
        """
        if y is None:
            y = globals().get("y")
        if y is None:
            raise ValueError("get_neuron() requires the class labels `y`")

        labels = torch.as_tensor(y, device=self.pos_all.device).long()
        g = self.forward(x).pow(2)
        label_list = labels.tolist()
        sup = [self.support_index[i] for i in label_list]
        sup_neg = [self.support_index_neg[i] for i in label_list]
        g_pos = g.gather(1, self.pos_all[labels]).mean(1)
        g_neg = g.gather(1, self.neg_all[labels]).mean(1)
        return [sup, sup_neg, g_pos, g_neg]





def visualize_sample(data, name='', idx=0):
    """Display entry ``idx`` of ``data`` as a 28x28 grayscale image.

    The entry is moved to the CPU and reshaped to 28x28 before plotting;
    ``name`` is used as the figure title.
    """
    side = 28
    image = data[idx].cpu().reshape(side, side)
    plt.figure(figsize=(4, 4))
    plt.title(name)
    plt.imshow(image, cmap="gray")
    plt.show()



if __name__ == "__main__":
    # Every pass is a self-contained run: new loaders, a newly initialised
    # network, layer-wise training, then evaluation on train and test data.
    epochs = 5
    for epoch in range(epochs):
        train_loader, test_loader = MNIST_loaders(train_batch_size=50000)
        net = Net([784, 100, 100], K=10)

        # The first shuffled training batch is used as the training set.
        train_batch = next(iter(train_loader))
        x, y = train_batch[0].cuda(), train_batch[1].cuda()
        net.train(x, y)
        net.predict(x, y)

        # Evaluate on the first (unshuffled) test batch.
        test_batch = next(iter(test_loader))
        x_te, y_te = test_batch[0].cuda(), test_batch[1].cuda()
        net.predict(x_te, y_te)

