import time
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from tqdm import tqdm
from torch.optim import Adam
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F
import torchvision

def Cifar_loaders(train_batch_size=50000, test_batch_size=10000):
    """Build the CIFAR-10 train and test DataLoaders.

    The dataset is downloaded into ``./data`` if it is not already there.
    Images go through ``ToTensor`` and are then normalized per channel with
    mean 0.5 / std 0.5, which maps the [0, 1] tensor range to [-1, 1].

    Args:
        train_batch_size: batch size of the shuffled training loader. The
            default (50000) presumably yields the whole training split as a
            single batch.
        test_batch_size: batch size of the unshuffled test loader.

    Returns:
        ``(trainloader, testloader)``.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    trainset = torchvision.datasets.CIFAR10(
        root='./data',
        train=True,
        download=True,
        transform=transform,
    )
    # Honour the requested batch sizes (they used to be hard-coded and ignored).
    trainloader = torch.utils.data.DataLoader(
        trainset,
        batch_size=train_batch_size,
        shuffle=True,
        num_workers=2,
    )

    testset = torchvision.datasets.CIFAR10(
        root='./data',
        train=False,
        download=True,
        transform=transform,
    )
    testloader = torch.utils.data.DataLoader(
        testset,
        batch_size=test_batch_size,
        shuffle=False,
        num_workers=2,
    )
    return trainloader, testloader


def overlay_y_on_x(x, y):
    """Return a copy of ``x`` whose first 10 columns one-hot encode ``y``.

    In the copy, columns 0..9 of every row are zeroed, and then entry
    ``(row, y[row])`` is set to the maximum value of the original ``x``.
    ``x`` itself is left untouched. Presumably ``x`` is (batch, features) and
    ``y`` holds ints in [0, 10) — TODO confirm with callers.
    """
    marked = x.clone()
    marked[:, :10].mul_(0.0)
    peak = x.max()
    rows = range(x.shape[0])
    marked[rows, y] = peak
    return marked


def get_subset(index, subindexs):
    """Complement each group in ``subindexs`` with respect to ``index``.

    Returns one list per group: the distinct members of ``index`` that are not
    in that group, in ascending order.
    """
    universe = set(index)
    return [sorted(universe.difference(group)) for group in subindexs]

class Cifar_Dataset(Dataset):
    """Map-style dataset that pairs ``data[i]`` with ``labels[i]``.

    Both sequences are stored as given (no copy, no transform); the length is
    taken from ``data``.
    """

    def __init__(self, data, labels):
        self.data, self.labels = data, labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        target = self.labels[idx]
        return sample, target


class Net(torch.nn.Module):
    """Stack of conv ``Layer``s trained greedily, one layer at a time.

    Each layer is optimized with its own ``loss`` and ``opt`` on the output of
    the already-trained layers, which are run under ``torch.no_grad``. Between
    layers, activations pass through the previous layer's optional max-pool
    and then its BatchNorm (``conv_bn``). Every layer makes its own class
    prediction: the class whose channel group has the largest mean squared
    activation.
    """

    def __init__(self, dims, ismaxpool, K, kernels, strides, threshold=2.0, device='cuda'):
        """
        Args:
            dims: channel counts; layer d maps dims[d] -> dims[d + 1] channels.
            ismaxpool: per-layer flags; layer d is followed by its max-pool if true.
            K: number of classes (channel groups per layer).
            kernels: per-layer conv kernel sizes.
            strides: per-layer conv strides.
            threshold: margin forwarded to every ``Layer``.
            device: device the layers are created on.
        """
        super().__init__()
        layers = []
        self.K = K
        self.dims = dims
        self.layers_name = []
        self.device = device
        self.ismaxpool = ismaxpool
        self.threshold = threshold  # stored so save() records the real value
        for d in range(len(dims) - 1):
            layers.append(Layer(dims[d], dims[d + 1], maxpool=self.ismaxpool[d], K=self.K,
                                kernel_size=kernels[d], stride=strides[d],
                                threshold=threshold, device=device))
            self.layers_name.append('conv')
        # Shape sanity check: push a dummy 3x32x32 image through every conv
        # (and its optional pool) so an incompatible config fails right here.
        with torch.no_grad():
            dummy = torch.zeros(1, 3, 32, 32, device=device)
            for layer in layers:
                dummy = layer(dummy)
                if layer.ismaxpool:
                    dummy = layer.maxpool(dummy)
        self.layers = nn.ModuleList(layers)
        self.kernels = kernels
        self.strides = strides

    @staticmethod
    def _to_next_input(layer, h):
        """Turn ``layer``'s output into the next layer's input (pool, then BatchNorm)."""
        if not isinstance(layer, Linear_Layer):
            if layer.ismaxpool:
                h = layer.maxpool(h)
            h = layer.conv_bn(h)
        return h

    def _compute_goodness(self, h, layer, labels=None):
        """Predict a class per sample from ``h``, the output of ``layer``.

        The predicted class is the index k whose unit/channel group
        ``layer.support_index[k]`` has the largest mean squared activation.
        ``labels`` is accepted for interface compatibility but unused.
        """
        if isinstance(layer, Linear_Layer):
            h = h.view(h.size(0), -1)
            g = h.pow(2)
            goodness = torch.stack([
                g[:, layer.support_index[k]].mean(dim=1) for k in range(self.K)
            ], dim=1)
        else:  # Conv layer: average over group channels and all spatial positions
            g = h.pow(2)
            goodness = torch.stack([
                g[:, layer.support_index[k], :, :].mean(dim=(1, 2, 3))
                for k in range(self.K)
            ], dim=1)
        return goodness.argmax(dim=1)

    def predict(self, x, y):
        """Evaluate accuracy layer by layer and save an accuracy plot.

        BatchNorm layers are switched to eval mode (running statistics, no
        stat updates) for the duration of the call, and the previous mode is
        restored afterwards. All layers are evaluated in one pass per batch.

        Returns:
            Float tensor of shape ``(len(self.layers),)`` on ``self.device``
            holding each layer's accuracy.
        """
        dataloader = DataLoader(Cifar_Dataset(x, y), batch_size=128, shuffle=False)
        correct = [0] * len(self.layers)
        total = 0

        was_training = self.training
        super().train(False)
        try:
            with torch.no_grad():
                for inputs, labels in dataloader:
                    h = inputs.to(self.device)
                    labels = labels.to(self.device)
                    prev = None
                    for layer_idx, layer in enumerate(self.layers):
                        if prev is not None:
                            h = self._to_next_input(prev, h)
                        h = layer(h)
                        preds = self._compute_goodness(h, layer)
                        correct[layer_idx] += preds.eq(labels).sum().item()
                        prev = layer
                    total += labels.size(0)
        finally:
            super().train(was_training)

        layer_acc = torch.tensor([c / total for c in correct], device=self.device)
        self._plot_acc_curve(layer_acc.cpu().numpy())
        return layer_acc

    def train(self, x=True, y=None):
        """Greedily train each layer on ``(x, y)``, or set the module mode.

        This name shadows ``nn.Module.train(mode)``. When called with a bool
        (as ``nn.Module.eval()`` and ``.train()`` do), the call is delegated to
        ``nn.Module.train`` so those keep working. Otherwise ``x`` is the input
        data and ``y`` the integer labels. Layer i trains for
        ``layer.num_epochs`` epochs on the frozen output of layers 0..i-1.
        """
        if isinstance(x, bool):
            return super().train(x)
        if y is None:
            raise TypeError("train() missing labels 'y' for data 'x'")

        dataset = Cifar_Dataset(x, y)
        dataloader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=0)
        previous_layers = []
        for i, layer in enumerate(self.layers):
            print('training layer', i, '...')

            for _ in tqdm(range(layer.num_epochs)):
                for inputs, labels in dataloader:
                    inputs, labels = inputs.to(self.device), labels.to(self.device)
                    h = inputs

                    # Already-trained layers are frozen; only `layer` gets gradients.
                    with torch.no_grad():
                        for pre_layer in previous_layers:
                            h = self._to_next_input(pre_layer, pre_layer(h))
                    loss = layer.loss(h, labels, L='CWC')

                    layer.opt.zero_grad()
                    loss.backward()
                    layer.opt.step()
            torch.cuda.empty_cache()
            previous_layers.append(layer)

    def save(self, path):
        """Save the state_dict plus the constructor config to ``path``."""
        torch.save({
            'state_dict': self.state_dict(),
            'config': {
                'dims': self.dims,
                'ismaxpool': self.ismaxpool,
                'K': self.K,
                'kernels': self.kernels,
                'strides': self.strides,
                'threshold': self.threshold,
            }
        }, path)
        print(f"Model saved to {path}")

    def load(self, path):
        """Load a checkpoint written by ``save`` onto ``self.device``."""
        # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
        checkpoint = torch.load(path, map_location=self.device)
        self.load_state_dict(checkpoint['state_dict'])
        print(f"Model loaded from {path}")

    def _plot_acc_curve(self, accuracies):
        """Plot per-layer accuracies to ``conv_acc_curve.png`` and print them."""
        plt.figure(figsize=(5, 5))
        plt.plot(range(len(self.layers)), accuracies, label='Accuracy', color='blue')
        plt.title('Accuracy by Layer')
        plt.xlabel('Layer Index')
        plt.ylabel('Accuracy')
        plt.ylim(0, 1.0)

        output_path = 'conv_acc_curve.png'
        plt.savefig(output_path)
        plt.close()
        print(f"Accuracy curve saved to {output_path}")
        print("Layer accuracies:", [f"{acc:.4f}" for acc in accuracies])


class Layer(nn.Conv2d):
    """Conv2d + ReLU block trained with a local, label-aware objective.

    The output channels are split into ``K`` equal, contiguous groups, one per
    class: class k owns channels ``[k*num_sup, (k+1)*num_sup)``. A group's
    "goodness" is the mean squared activation over its channels and all
    spatial positions. ``maxpool`` and ``conv_bn`` are NOT applied in
    ``forward``; the caller applies them between layers.
    """

    def __init__(self, in_channels, out_channels, kernel_size, K, stride, padding=1,
                 threshold=2.0, maxpool=False, groups=1, droprate=0, device="cuda", bias=True, dtype=None):
        # Forward ``groups`` so the weight shape matches the grouping used by
        # ``_conv_forward``. stride/padding/groups are kept in nn.Conv2d's
        # normalized (tuple) form; overwriting them with raw ints broke repr().
        super().__init__(in_channels, out_channels, kernel_size, stride=stride,
                         padding=padding, groups=groups, bias=bias, device=device, dtype=dtype)
        self.relu = torch.nn.ReLU()
        self.K = K
        self.threshold = threshold  # margin used by the "PvN" loss
        self.num_epochs = 100
        # Created before conv_bn is registered, so only the conv weight/bias are optimized.
        self.opt = Adam(self.parameters(), lr=0.001)
        self.device = device
        self.ismaxpool = maxpool
        self.conv_bn = nn.BatchNorm2d(out_channels, eps=1e-4).to(device)
        self.dropout = torch.nn.Dropout(p=droprate)  # not used by forward()

        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2).to(device)

        # Class-to-channel assignment (computed eagerly here).
        self.output_features = out_channels
        self.num_sup = self.output_features // K
        self.index = list(range(0, self.num_sup * K))
        self.support_index = [[i + self.num_sup * k for i in range(self.num_sup)] for k in range(self.K)]
        self.support_index_neg = get_subset(self.index, self.support_index)

        # Non-persistent buffers: they follow the module through .to()/.cuda()
        # but stay out of state_dict, so existing checkpoints still load.
        self.register_buffer(
            'pos_indices',
            torch.tensor(self.support_index, dtype=torch.long, device=device),
            persistent=False)
        self.register_buffer(
            'neg_indices',
            torch.tensor(self.support_index_neg, dtype=torch.long, device=device),
            persistent=False)

    def forward(self, x):
        """Convolve ``x`` (moved to the weight's device) and apply ReLU."""
        x = x.to(self.weight.device)
        x = self._conv_forward(x, self.weight, self.bias)
        x = F.relu(x, inplace=True)
        return x

    def _class_response(self, g):
        """Mean of ``g`` over each class's channels and all spatial positions -> (batch, K)."""
        # g[:, pos_indices] has shape (batch, K, num_sup, H, W).
        return g[:, self.pos_indices].mean(dim=(2, 3, 4))

    def loss(self, inputs, labels, L="CWC"):
        """Local training loss of this layer on ``inputs`` with integer ``labels``.

        L="CWC": softmax cross-entropy over per-class goodness, clamped to [-50, 50].
        L="CE":  ``F.cross_entropy`` over per-class goodness (unclamped).
        other:   positive/negative margin loss around ``self.threshold``.
        """
        g = self.forward(inputs).pow(2)
        labels = labels.to(g.device)

        if L == "CWC":
            response = torch.clamp(self._class_response(g), min=-50, max=50)
            chosen = response.gather(1, labels.unsqueeze(1)).squeeze(1)

            eps = 1e-9
            loss = -torch.log((torch.exp(chosen) + eps) / (torch.exp(response).sum(1) + eps)).mean()

        elif L == "CE":
            loss = F.cross_entropy(self._class_response(g), labels)

        else:
            sup_channel = self.pos_indices[labels]
            sup_neg_channel = self.neg_indices[labels]
            sup_channel = sup_channel[:, :, None, None].expand(-1, -1, g.shape[-2], g.shape[-1])
            sup_neg_channel = sup_neg_channel[:, :, None, None].expand(-1, -1, g.shape[-2], g.shape[-1])

            g_pos = g.gather(1, sup_channel).mean((1, 2, 3))
            g_neg = g.gather(1, sup_neg_channel).mean((1, 2, 3))

            # softplus(z) == log(1 + exp(z)) without overflowing to inf for
            # large squared activations (which produced inf loss / NaN grads).
            loss = F.softplus(torch.cat([
                -g_pos + self.threshold,
                g_neg - self.threshold])).mean()

        return loss


class Linear_Layer(nn.Linear):
    """Fully connected + ReLU layer trained with a local, label-aware objective.

    Inputs with more than 2 dims are flattened per sample and L2-normalized
    before the projection. The ``out_features`` units are split into ``K``
    contiguous groups of ``out_features // K`` units, one per class. A group's
    "goodness" is the mean squared activation of its units.
    """

    def __init__(self, in_features, out_features, K, threshold=2.0, device="cuda",
                 bias=True, dtype=None):
        super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype)
        self.relu = torch.nn.ReLU()
        self.num_sup = out_features // K
        # Class k owns units [k*num_sup, (k+1)*num_sup); its negatives are all other units.
        self.support_index = [[i + self.num_sup * k for i in range(self.num_sup)] for k in range(K)]
        self.support_index_neg = []
        for positives in self.support_index:
            excluded = set(positives)  # O(1) membership instead of scanning a list
            self.support_index_neg.append([i for i in range(out_features) if i not in excluded])
        self.opt = Adam(self.parameters(), lr=0.001)
        self.threshold = threshold  # margin used by the "PvN" loss
        self.num_epochs = 100
        self.device = device
        # Non-persistent buffers: follow the module through .to()/.cuda() but
        # stay out of state_dict, so existing checkpoints still load.
        self.register_buffer(
            "pos_all",
            torch.tensor(self.support_index, dtype=torch.long, device=device),
            persistent=False)
        self.register_buffer(
            "neg_all",
            torch.tensor(self.support_index_neg, dtype=torch.long, device=device),
            persistent=False)
        self.K = K

    def forward(self, x):
        """ReLU(W @ (x / ||x||) + b), with ``x`` flattened per sample."""
        x = x.to(self.weight.device)
        if x.ndim > 2:
            x = x.reshape(x.size(0), -1)
        x_direction = x / (x.norm(2, 1, keepdim=True) + 1e-4)
        # F.linear handles bias=None (self.bias.unsqueeze crashed in that case).
        return self.relu(F.linear(x_direction, self.weight, self.bias))

    def _flat_goodness(self, inputs):
        """Squared activations reshaped to (batch, out_features)."""
        g = self.forward(inputs).pow(2)
        return g.reshape(g.size(0), -1)

    def loss(self, inputs, labels, L="CWC"):
        """Local training loss of this layer on ``inputs`` with integer ``labels``.

        L="CWC" / "CE": softmax cross-entropy over per-class goodness.
        other:          positive/negative margin loss around ``self.threshold``.
        """
        g_flat = self._flat_goodness(inputs)
        labels = labels.to(g_flat.device)

        if L in ("CWC", "CE"):
            # (batch, K, num_sup) -> (batch, K): mean goodness of each class group.
            response = g_flat[:, self.pos_all].mean(2)
            # CWC used to be -log((exp(chosen)+eps)/(sum(exp)+eps)) with eps=1e-9.
            # response >= 0, so exp(.) >= 1 and that eps is negligible; the value
            # is softmax cross-entropy. F.cross_entropy computes it via
            # log-sum-exp, so large responses no longer overflow into NaN.
            loss = F.cross_entropy(response, labels)

        else:
            sup_tensor = self.pos_all[labels]
            sup_neg_tensor = self.neg_all[labels]

            g_pos = g_flat.gather(1, sup_tensor).mean(1)
            g_neg = g_flat.gather(1, sup_neg_tensor).mean(1)

            # softplus(z) == log(1 + exp(z)) without overflowing to inf.
            loss = F.softplus(torch.cat([
                -g_pos + self.threshold,
                g_neg - self.threshold])).mean()

        return loss


def visualize_sample(data, name='', idx=0, shape=(28, 28)):
    """Display sample ``data[idx]`` as an image titled ``name``.

    If the sample can be reshaped to ``shape`` (default 28x28, the original
    MNIST-style behaviour), it is shown that way in grayscale. Otherwise a
    3-D sample is treated as channels-first ``(C, H, W)`` — e.g. a CIFAR image —
    and shown as ``(H, W, C)``, min-max rescaled to [0, 1].
    """
    sample = data[idx].detach().cpu()
    if sample.ndim == 3 and sample.numel() != torch.Size(shape).numel():
        image = sample.permute(1, 2, 0)
        if image.shape[-1] == 1:
            image = image.squeeze(-1)
        # Normalized inputs (e.g. in [-1, 1]) would otherwise be clipped by imshow.
        lo, hi = image.min(), image.max()
        image = (image - lo) / (hi - lo).clamp_min(1e-12)
    else:
        image = sample.reshape(shape)
    plt.figure(figsize=(4, 4))
    plt.title(name)
    plt.imshow(image, cmap="gray")
    plt.show()

if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_loader, test_loader = Cifar_loaders()
    torch.cuda.empty_cache()

    #net = Net([3, 20, 80, 240, 480], ismaxpool=[True, True, True, True], K=10, kernels=[3, 3, 3, 3], strides=[1, 1, 1, 1], device=device)
    for i in range(5):
        net = Net([3, 20, 80, 240, 480], ismaxpool=[False, True, False, True], K=10, kernels=[3, 3, 3, 3], strides=[1, 1, 1, 1], device=device)


        x, y = next(iter(train_loader))
        x, y = x.to(device), y.to(device)

        net.train(x, y)
        x_te, y_te = next(iter(test_loader))
        x_te, y_te = x_te.to(device), y_te.to(device)
        net.predict(x_te, y_te)