import torch
import matplotlib.pyplot as plt
import copy
import torch.nn as nn
from torch.utils.data import WeightedRandomSampler
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
import torchvision
from tqdm import tqdm
#%%
def Cifar_loaders(train_batch_size=50000, test_batch_size=10000, root='./data', num_workers=2):
    """Build CIFAR-10 train/test DataLoaders with pixels normalized to [-1, 1].

    ``ToTensor`` maps pixels to [0, 1] and ``Normalize(0.5, 0.5)`` maps them to
    [-1, 1]. With the default batch sizes each split is delivered as a single
    batch (all 50k train / 10k test images), which is how the script consumes
    them via ``next(iter(loader))``.

    Args:
        train_batch_size: batch size of the (shuffled) training loader.
        test_batch_size: batch size of the (unshuffled) test loader.
        root: directory CIFAR-10 is downloaded to / read from.
        num_workers: worker processes per DataLoader.

    Returns:
        (trainloader, testloader)
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    trainset = torchvision.datasets.CIFAR10(
        root=root,
        train=True,
        download=True,
        transform=transform,
    )
    trainloader = torch.utils.data.DataLoader(
        trainset,
        batch_size=train_batch_size,
        shuffle=True,
        num_workers=num_workers,
    )

    testset = torchvision.datasets.CIFAR10(
        root=root,
        train=False,
        download=True,
        transform=transform,
    )
    testloader = torch.utils.data.DataLoader(
        testset,
        batch_size=test_batch_size,
        shuffle=False,
        num_workers=num_workers,
    )
    return trainloader, testloader


def overlay_y_on_x(x, y):
    """Return a copy of [x] whose first 10 columns one-hot encode label [y].

    The first 10 entries of every row are zeroed, then the entry at column
    ``y[row]`` is set to the global maximum of the original ``x``. The input
    tensor itself is left untouched.
    """
    peak = x.max()
    encoded = x.clone()
    encoded[:, :10] *= 0.0
    rows = range(x.shape[0])
    encoded[rows, y] = peak
    return encoded

def get_subset(index, subindexs):
    """For each group in ``subindexs``, list the members of ``index`` outside it.

    Returns one sorted list per group, in the same order as ``subindexs``
    (duplicates in ``index`` collapse because a set difference is used).
    """
    universe = set(index)
    return [sorted(universe.difference(group)) for group in subindexs]

class Cifar_Dataset(Dataset):
    """Map-style dataset pairing ``data[idx]`` with ``labels[idx]``.

    Works with any indexable containers (tensors, lists); its length is the
    length of ``data``.
    """

    def __init__(self, data, labels):
        self.data, self.labels = data, labels

    def __getitem__(self, idx):
        sample = self.data[idx]
        target = self.labels[idx]
        return sample, target

    def __len__(self):
        return len(self.data)



#%%
class Layer(nn.Conv2d):
    """Conv layer trained locally with a class-wise "goodness" objective.

    The first ``(out_channels // K) * K`` output channels are split into ``K``
    contiguous support sets (``support_index[k]``); any remainder channels
    belong to no class. The goodness of class ``k`` for a sample is the mean
    squared activation over class ``k``'s channels and all spatial positions,
    and the predicted class is the argmax goodness. After training the layer
    can be pruned (``delete``) or widened (``growth``).

    ``forward`` is conv + ReLU only: the optional ``maxpool`` (when
    ``ismaxpool``) and the batch norm added by ``add_bn`` are applied by the
    caller between layers.
    """

    def __init__(self, in_channels, out_channels, kernel_size, K, stride, padding=1,
                 threshold=2.0, maxpool=False, groups=1, droprate=0, device="cuda", bias=True, dtype=None):
        # ``groups`` is forwarded so the weight shape agrees with ``self.groups``.
        super().__init__(in_channels, out_channels, kernel_size, stride=stride,
                         padding=padding, groups=groups, bias=bias, device=device, dtype=dtype)
        self.relu = torch.nn.ReLU()
        self.K = K
        # Stored as given (not nn.Conv2d's tuples); F.conv2d accepts either form.
        self.stride = stride
        self.padding = padding
        self.threshold = threshold
        self.num_epochs = 100
        self.opt = Adam(self.parameters(), lr=0.001)
        self.device = device
        # NOTE(review): created but never applied in forward().
        self.dropout = torch.nn.Dropout(p=droprate)
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2).to(device)
        self.ismaxpool = maxpool

        self.output_features = out_channels
        self.num_sup = self.output_features // K
        # Channels owned by some class; class k owns [k*num_sup, (k+1)*num_sup).
        self.index = list(range(0, self.num_sup * K))
        self.support_index = [[i + self.num_sup * k for i in range(self.num_sup)] for k in range(self.K)]
        # Negatives of class k: every class-owned channel not owned by k.
        assigned = set(self.index)
        self.support_index_neg = [sorted(assigned - set(pos)) for pos in self.support_index]
        self._rebuild_index_tensors()

    # -----------------------------
    # Bookkeeping helpers
    # -----------------------------

    def _rebuild_index_tensors(self):
        """Re-derive tensor views of ``support_index`` / ``support_index_neg``.

        Sets ``pos_mask`` / ``neg_mask`` ([K, C] 0/1 matrices used by the loss
        and goodness code) and ``pos_indices`` / ``neg_indices`` (a [K, n] long
        tensor when all sets have equal size, otherwise a list of 1-D long
        tensors). Must run whenever the channel layout changes; otherwise the
        loss would index stale or out-of-range channels.
        """
        n_channels = self.weight.size(0)
        mask_kwargs = dict(device=self.weight.device, dtype=self.weight.dtype)
        self.pos_mask = torch.zeros(self.K, n_channels, **mask_kwargs)
        self.neg_mask = torch.zeros(self.K, n_channels, **mask_kwargs)
        for k in range(self.K):
            if self.support_index[k]:
                self.pos_mask[k, self.support_index[k]] = 1.0
            if self.support_index_neg[k]:
                self.neg_mask[k, self.support_index_neg[k]] = 1.0

        def as_index(groups):
            rows = [torch.tensor(group, dtype=torch.long) for group in groups]
            if len({len(group) for group in groups}) <= 1:
                return torch.stack(rows).to(self.device)
            return [row.to(self.device) for row in rows]

        self.pos_indices = as_index(self.support_index)
        self.neg_indices = as_index(self.support_index_neg)

    def _after_structure_change(self):
        """Resync bookkeeping after output channels were added or removed."""
        n_channels = self.weight.size(0)
        self.out_channels = n_channels
        self.output_features = n_channels
        self._rebuild_index_tensors()
        # growth() replaces the Parameters and delete() reshapes them, so the
        # old optimizer's references / Adam state are invalid: rebuild it.
        self.opt = Adam(self.parameters(), lr=self.opt.param_groups[0]["lr"])

    @staticmethod
    def _iter_batches(x, y, batch_size, shuffle=False):
        """Yield ``(x, y)`` mini-batches in order, or in random order if ``shuffle``."""
        if shuffle:
            for idx in torch.randperm(x.size(0)).split(batch_size):
                yield x[idx], y[idx]
        else:
            yield from zip(x.split(batch_size), y.split(batch_size))

    def _class_goodness(self, g):
        """Per-class goodness from squared activations ``g`` ([B, C, H, W]) -> [B, K].

        Entry [b, k] is the mean of ``g[b]`` over class k's support channels and
        all spatial positions; a class with an empty support set scores 0.
        """
        per_channel = g.mean(dim=(2, 3))
        counts = self.pos_mask.sum(dim=1).clamp(min=1.0)
        return per_channel @ self.pos_mask.t() / counts

    def _mean_channel_goodness(self, x, batch_size=128):
        """Mean squared activation per output channel over all of ``x`` -> [C]."""
        total = torch.zeros(self.weight.size(0), device=self.device)
        for inputs in x.split(batch_size):
            g = self.forward(inputs.to(self.device)).pow(2)
            total += g.mean(dim=(2, 3)).sum(dim=0)
        return total / max(x.size(0), 1)

    # -----------------------------
    # Forward / loss
    # -----------------------------

    def forward(self, x):
        """Convolve ``x`` with this layer's weights and apply ReLU (no pool/BN)."""
        x = x.to(self.weight.device)
        x = self._conv_forward(x, self.weight, self.bias)
        return F.relu(x, inplace=True)

    def loss(self, inputs, labels, L="CWC"):
        """Local training loss for one batch.

        Args:
            inputs: batch fed to ``forward``.
            labels: integer class ids on ``self.device``.
            L: ``"CWC"`` - softmax-style loss on clamped class goodness;
               ``"CE"`` - ``F.cross_entropy`` on class goodness;
               anything else - softplus loss pushing the true class's goodness
               above ``threshold`` and the other classes' goodness below it.
        """
        g = self.forward(inputs).pow(2)
        if L == "CWC":
            # The clamp keeps exp() finite.
            response = torch.clamp(self._class_goodness(g), min=-50, max=50)
            chosen = response.gather(1, labels.unsqueeze(1)).squeeze(1)
            eps = 1e-9
            return -torch.log((torch.exp(chosen) + eps) / (torch.exp(response).sum(1) + eps)).mean()

        if L == "CE":
            return F.cross_entropy(self._class_goodness(g), labels)

        per_channel = g.mean(dim=(2, 3))
        pos_w = self.pos_mask[labels]
        neg_w = self.neg_mask[labels]
        g_pos = (per_channel * pos_w).sum(1) / pos_w.sum(1).clamp(min=1.0)
        g_neg = (per_channel * neg_w).sum(1) / neg_w.sum(1).clamp(min=1.0)
        # softplus(z) == log(1 + exp(z)), without overflow for large z.
        return F.softplus(torch.cat([-g_pos + self.threshold, g_neg - self.threshold])).mean()

    # -----------------------------
    # Growth function
    # -----------------------------

    def growth(self, x, y, m, epochs=3, lr=1e-3, batch_size=128, train=True):
        """Widen the layer with clones of the weakest class's strongest neurons.

        1. Find the class with the lowest per-class accuracy on ``(x, y)``.
        2. Rank that class's support channels by mean goodness on the class's
           own samples and clone the ``m`` strongest (clamped to the size of
           the support set) as new channels appended after the existing ones.
        3. If ``train``, fine-tune ONLY the new channels for ``epochs`` passes
           so their goodness is high on the misclassified samples of the weak
           class and low on an equally sized random sample of other classes.

        Returns early (printing a message, layer unchanged) if the weak class
        has no samples, no misclassified samples, or no support channels.
        """
        if m <= 0 or y.numel() == 0:
            return
        x, y = x.cpu(), y.cpu()
        self.eval()
        with torch.no_grad():
            # ---- Step 1: weakest class ----
            weakest_class, preds_all = self._weakest_class(x, y, batch_size)
            cls = int(weakest_class)

            mask = y == weakest_class
            x_class, y_class = x[mask], y[mask]
            if x_class.size(0) == 0:
                print(f"No samples of class {weakest_class}, skipping grow()")
                return

            mis_mask = preds_all[mask] != weakest_class
            x_mis, y_mis = x_class[mis_mask], y_class[mis_mask]
            if x_mis.size(0) == 0:
                print(f"No misclassified samples of class {weakest_class}, skipping grow()")
                return

            support_ids = self.support_index[cls]
            n_new = min(m, len(support_ids))
            if n_new == 0:
                print(f"Class {weakest_class} has no support neurons, skipping grow()")
                return

            # ---- Step 2: strongest support neurons of that class ----
            avg_goodness = self._mean_channel_goodness(x_class, batch_size)
            top = torch.topk(avg_goodness[support_ids], n_new).indices.tolist()
            clone_idx = [support_ids[i] for i in top]

            # ---- Step 3: append clones as new output channels ----
            start_idx = self.weight.size(0)
            new_indices = list(range(start_idx, start_idx + n_new))
            self.weight = nn.Parameter(torch.cat([self.weight, self.weight[clone_idx].clone()], dim=0))
            if self.bias is not None:
                self.bias = nn.Parameter(torch.cat([self.bias, self.bias[clone_idx].clone()], dim=0))
            self.support_index[cls].extend(new_indices)
            for k in range(self.K):
                if k != cls:
                    self.support_index_neg[k].extend(new_indices)
            self._after_structure_change()

            print(f"Grew {n_new} neurons (cloned strongest from class {weakest_class}), "
                  f"new total neurons = {self.weight.size(0)}")

        if train:
            self._finetune_new_neurons(x, y, x_mis, y_mis, weakest_class, new_indices,
                                       epochs, lr, batch_size)

    def _weakest_class(self, x, y, batch_size):
        """Return (lowest-accuracy class as a 0-d CPU tensor, CPU predictions for ``x``).

        Expected to run under ``torch.no_grad()``. Classes absent from ``y`` get
        accuracy +inf so they are never selected.
        """
        correct_class = torch.zeros(self.K, device=self.device)
        total_class = torch.zeros(self.K, device=self.device)
        preds_all = []
        for inputs, labels in self._iter_batches(x, y, batch_size):
            inputs, labels = inputs.to(self.device), labels.to(self.device)
            preds = self._compute_goodness(self.forward(inputs))
            preds_all.append(preds.cpu())
            correct_class += torch.bincount(labels[preds == labels], minlength=self.K)[: self.K].float()
            total_class += torch.bincount(labels, minlength=self.K)[: self.K].float()
        acc = torch.where(total_class > 0,
                          correct_class / total_class.clamp(min=1.0),
                          torch.full_like(total_class, float("inf")))
        return torch.argmin(acc).cpu(), torch.cat(preds_all)

    def _finetune_new_neurons(self, x, y, x_mis, y_mis, weakest_class, new_indices,
                              epochs, lr, batch_size):
        """Train only ``new_indices``: raise their goodness on ``x_mis``, lower it on other classes."""
        self.train()
        non_mask = y != weakest_class
        x_non, y_non = x[non_mask], y[non_mask]
        if x_non.size(0) == 0:
            return
        pick = torch.randperm(x_non.size(0))[: x_mis.size(0)]
        x_train = torch.cat([x_mis, x_non[pick]], dim=0)
        y_train = torch.cat([y_mis, y_non[pick]], dim=0)

        params = [p for p in (self.weight, self.bias) if p is not None]
        opt_new = Adam(params, lr=lr)
        # Zero gradients on old channels; Adam (no weight decay) then leaves
        # those entries unchanged because its moment estimates stay zero.
        grad_masks = []
        for p in params:
            grad_mask = torch.zeros_like(p)
            grad_mask[new_indices] = 1.0
            grad_masks.append(grad_mask)
        target = weakest_class.to(self.device)

        for _ in range(epochs):
            for inputs, labels in self._iter_batches(x_train, y_train, batch_size, shuffle=True):
                inputs, labels = inputs.to(self.device), labels.to(self.device)
                g_new = self.forward(inputs).pow(2)[:, new_indices].mean((1, 2, 3))
                is_pos = labels == target
                loss = g_new.new_zeros(())
                if is_pos.any():
                    loss = loss - g_new[is_pos].mean()
                if (~is_pos).any():
                    loss = loss + g_new[~is_pos].mean()

                opt_new.zero_grad()
                loss.backward()
                with torch.no_grad():
                    for p, grad_mask in zip(params, grad_masks):
                        p.grad *= grad_mask
                opt_new.step()

    # -----------------------------
    # Pruning
    # -----------------------------

    def delete(self, x, y, m=1):
        """Prune up to ``m`` output channels with the lowest mean goodness on ``x``.

        Channels are considered in increasing order of mean squared activation
        (ties broken by index); a channel is skipped if removing it would leave
        its class with an empty support set. Weights, biases and both
        support-index lists are renumbered. ``y`` is accepted for API
        compatibility and is not used.

        Returns:
            The removed channel indices (numbering before deletion) as
            one-element long tensors on ``self.device``; empty if nothing could
            be removed.
        """
        self.eval()
        doomed = []
        with torch.no_grad():
            avg = self._mean_channel_goodness(x, 128)
            order = torch.sort(avg, stable=True).indices.tolist()

            owner = {ch: k for k, chans in enumerate(self.support_index) for ch in chans}
            remaining = [len(chans) for chans in self.support_index]
            for ch in order:
                if len(doomed) >= m:
                    break
                k = owner.get(ch)
                if k is not None:
                    if remaining[k] <= 1:
                        continue  # never empty a class's support set
                    remaining[k] -= 1
                doomed.append(ch)

            if doomed:
                doomed_set = set(doomed)
                keep = [ch for ch in range(self.weight.size(0)) if ch not in doomed_set]
                remap = {old: new for new, old in enumerate(keep)}
                keep_t = torch.tensor(keep, dtype=torch.long, device=self.weight.device)
                self.weight.data = self.weight.data[keep_t]
                self.weight.grad = None  # old gradient has the old shape
                if self.bias is not None:
                    self.bias.data = self.bias.data[keep_t]
                    self.bias.grad = None
                for k in range(self.K):
                    self.support_index[k] = [remap[i] for i in self.support_index[k] if i in remap]
                    self.support_index_neg[k] = [remap[i] for i in self.support_index_neg[k] if i in remap]
                self._after_structure_change()

        deleted_indices = [torch.tensor([ch], device=self.device) for ch in doomed]
        print("Deleted neurons from support sets:", deleted_indices)
        return deleted_indices

    # -----------------------------
    # Evaluation / training
    # -----------------------------

    def _compute_goodness(self, h):
        """Predicted class per sample (argmax class goodness) from activations ``h``."""
        return self._class_goodness(h.pow(2)).argmax(dim=1)

    def predict(self, x, y):
        """Return this layer's accuracy on ``(x, y)`` as a Python float."""
        with torch.no_grad():
            correct, total = 0, 0
            for inputs, labels in self._iter_batches(x, y, 128):
                inputs, labels = inputs.to(self.device), labels.to(self.device)
                preds = self._compute_goodness(self.forward(inputs))
                correct += preds.eq(labels).sum().item()
                total += labels.size(0)
        return correct / total

    def train_layer(self, x, y):
        """Train with the CWC loss for ``num_epochs`` epochs of shuffled batches of 128."""
        for _ in tqdm(range(self.num_epochs)):
            for inputs, labels in self._iter_batches(x, y, 128, shuffle=True):
                inputs, labels = inputs.to(self.device), labels.to(self.device)
                loss = self.loss(inputs, labels, L="CWC")
                self.opt.zero_grad()
                loss.backward()
                self.opt.step()

    def add_bn(self, out_dim):
        """Attach the BatchNorm2d that callers apply to this layer's (pooled) output."""
        self.conv_bn = nn.BatchNorm2d(out_dim, eps=1e-4).to(self.device)


class Linear_Layer(nn.Linear):
    """Fully connected goodness layer: ``ReLU(normalize(x) @ W.T + b)``.

    Output unit ``i`` belongs to class ``k`` when it lies in the contiguous
    block ``[k*num_sup, (k+1)*num_sup)``; the negatives of class ``k`` are all
    other output units (including any remainder units owned by no class).
    """

    def __init__(self, in_features, out_features, K, threshold=2.0, device="cuda",
                 bias=True, dtype=None):
        super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype)
        self.relu = torch.nn.ReLU()
        self.K = K
        self.num_sup = out_features // K
        self.support_index = [[i + self.num_sup * k for i in range(self.num_sup)] for k in range(K)]
        self.support_index_neg = []
        for pos in self.support_index:
            pos_set = set(pos)
            self.support_index_neg.append([i for i in range(out_features) if i not in pos_set])
        self.opt = Adam(self.parameters(), lr=0.001)
        self.threshold = threshold
        self.num_epochs = 100
        self.device = device
        # [K, num_sup] and [K, out_features - num_sup] long index tensors.
        pos_indices = torch.stack([torch.tensor(self.support_index[i]) for i in range(K)])
        neg_indices = torch.stack([torch.tensor(self.support_index_neg[i]) for i in range(K)])
        self.pos_all = pos_indices.to(self.device)
        self.neg_all = neg_indices.to(self.device)

    def forward(self, x):
        """Flatten 4-D input, L2-normalize each row, then apply the affine map and ReLU."""
        x = x.to(self.weight.device)
        if x.ndim == 4:
            x = x.reshape(x.size(0), -1)
        # Only the direction of the input is passed on; 1e-4 avoids division by zero.
        x_direction = x / (x.norm(2, 1, keepdim=True) + 1e-4)
        return self.relu(F.linear(x_direction, self.weight, self.bias))

    def _class_responses(self, g_flat):
        """Mean goodness over each class's support units: [B, out] -> [B, K]."""
        return g_flat[:, self.pos_all].mean(dim=2)

    def loss(self, inputs, labels, L="CWC"):
        """Local training loss for one batch.

        ``L="CWC"``: softmax-style loss on clamped class responses; ``"CE"``:
        ``F.cross_entropy`` on class responses; anything else: softplus loss
        pushing the true class's goodness above ``threshold`` and the
        negatives' goodness below it.
        """
        g = self.forward(inputs).pow(2)
        g_flat = g.reshape(g.size(0), -1).to(self.device)

        if L == "CWC":
            # Clamp (as in the conv Layer) so exp() cannot overflow to inf/NaN.
            response = torch.clamp(self._class_responses(g_flat), min=-50, max=50)
            chosen = response.gather(1, labels.unsqueeze(1)).squeeze(1)
            eps = 1e-9
            return -torch.log((torch.exp(chosen) + eps) / (torch.exp(response).sum(1) + eps)).mean()

        if L == "CE":
            return F.cross_entropy(self._class_responses(g_flat), labels)

        g_pos = g_flat.gather(1, self.pos_all[labels]).mean(1)
        g_neg = g_flat.gather(1, self.neg_all[labels]).mean(1)
        # softplus(z) == log(1 + exp(z)), without overflow for large z.
        return F.softplus(torch.cat([-g_pos + self.threshold, g_neg - self.threshold])).mean()

#%%
class Net(torch.nn.Module):
    """Greedy layer-wise stack of goodness-trained layers.

    Layers are appended with ``add_layer``. Between consecutive conv layers,
    the earlier layer's optional max-pool and its ``conv_bn`` are applied.
    """

    def __init__(self, K, threshold=2.0, device='cuda'):
        super().__init__()
        self.K = K
        self.threshold = threshold
        self.device = device
        # Architecture record filled in by add_layer(); written out by save().
        self.dims = []
        self.kernels = []
        self.strides = []
        self.maxpools = []
        self.layers_name = []
        self.layers = nn.ModuleList()

    def add_layer(self, input_dim, out_dim, ismaxpool, kernels, strides):
        """Append a conv ``Layer`` (``input_dim`` -> ``out_dim`` channels) and record its config."""
        layer = Layer(input_dim, out_dim, maxpool=ismaxpool, K=self.K, kernel_size=kernels,
                      stride=strides, threshold=self.threshold, device=self.device)
        self.layers.append(layer.to(self.device))
        if not self.dims:
            self.dims.append(input_dim)
        self.dims.append(out_dim)
        self.kernels.append(kernels)
        self.strides.append(strides)
        self.maxpools.append(ismaxpool)
        self.layers_name.append('conv')

    def _compute_goodness(self, h, layer, labels=None):
        """Predicted class per sample (argmax of per-class mean squared activation).

        ``labels`` is accepted for API compatibility and unused.
        """
        if isinstance(layer, Linear_Layer):
            h = h.view(h.size(0), -1)
            g = h.pow(2)
            goodness = torch.stack([
                g[:, layer.support_index[k]].mean(dim=1) for k in range(self.K)
            ], dim=1)
        else:  # Conv layer
            g = h.pow(2)
            goodness = torch.stack([
                g[:, layer.support_index[k], :, :].mean(dim=(1, 2, 3))
                for k in range(self.K)
            ], dim=1)
        return goodness.argmax(dim=1)

    @staticmethod
    def _between_layers(layer, h):
        """Apply the pool / batch norm that follows a conv ``layer`` (identity for linear)."""
        if isinstance(layer, Linear_Layer):
            return h
        if layer.ismaxpool:
            h = layer.maxpool(h)
        return layer.conv_bn(h)

    def predict(self, x, y):
        """Evaluate every layer's goodness-argmax accuracy on ``(x, y)``.

        Each batch (128 samples, taken in order) is fed once through the whole
        stack and scored at every layer; pool/BN is applied between layers but
        not after the last one. Any ``conv_bn`` still in training mode
        normalizes with per-batch statistics, so the fixed order keeps results
        deterministic.

        Returns:
            (last layer's accuracy as a 0-d tensor,
             [num_layers, K] tensor of per-class accuracies; NaN for classes
             absent from ``y``).
        """
        n_layers = len(self.layers)
        correct = torch.zeros(n_layers, device=self.device)
        correct_class = torch.zeros(n_layers, self.K, device=self.device)
        total_class = torch.zeros(self.K, device=self.device)
        total = 0

        with torch.no_grad():
            for inputs, labels in zip(x.split(128), y.split(128)):
                inputs, labels = inputs.to(self.device), labels.to(self.device)
                total += labels.size(0)
                total_class += torch.bincount(labels, minlength=self.K)[: self.K].float()
                h = inputs
                for layer_idx, layer in enumerate(self.layers):
                    h = layer(h)
                    hit = self._compute_goodness(h, layer) == labels
                    correct[layer_idx] += hit.sum()
                    correct_class[layer_idx] += torch.bincount(
                        labels[hit], minlength=self.K)[: self.K].float()
                    if layer_idx < n_layers - 1:
                        h = self._between_layers(layer, h)

        layer_acc = correct / total
        layer_class = correct_class / total_class
        return layer_acc[-1], layer_class

    def train(self, x=True, y=None):
        """Greedily train every layer, or set the module's training mode.

        ``net.train(x, y)`` trains the layers one after another with the CE
        loss, feeding each one the frozen outputs of the layers before it.
        Called with a single bool (as ``nn.Module.eval()`` / ``.train()`` do),
        it behaves like ``nn.Module.train(mode)`` and returns ``self``.
        """
        if y is None:
            return super().train(x)

        dataset = Cifar_Dataset(x, y)
        dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=0)
        previous_layers = []
        for i, layer in enumerate(self.layers):
            print('training layer', i, '...')

            for _ in tqdm(range(layer.num_epochs)):
                for inputs, labels in dataloader:
                    inputs, labels = inputs.to(self.device), labels.to(self.device)
                    h = inputs
                    with torch.no_grad():
                        for pre_layer in previous_layers:
                            h = self._between_layers(pre_layer, pre_layer(h))

                    loss = layer.loss(h, labels, L='CE')

                    layer.opt.zero_grad()
                    loss.backward()
                    layer.opt.step()
            torch.cuda.empty_cache()
            previous_layers.append(layer)

    def save(self, path):
        """Save ``state_dict`` plus the recorded architecture config to ``path``.

        ``widths`` holds the current output width of every layer, which
        pruning/growth may have changed from the requested ``dims``.
        """
        torch.save({
            'state_dict': self.state_dict(),
            'config': {
                'dims': self.dims,
                'K': self.K,
                'kernels': self.kernels,
                'strides': self.strides,
                'maxpool': self.maxpools,
                'widths': [layer.weight.size(0) for layer in self.layers],
                'threshold': self.threshold,
            }
        }, path)
        print(f"Model saved to {path}")

    def load(self, path):
        """Load a ``state_dict`` written by ``save`` (layer widths must already match)."""
        checkpoint = torch.load(path, map_location=self.device)
        self.load_state_dict(checkpoint['state_dict'])
        print(f"Model loaded from {path}")

    def _plot_acc_curve(self, accuracies):
        """Plot per-layer accuracy and save it to ``conv_acc_curve.png``."""
        accs = [float(acc) for acc in accuracies]  # accepts lists or (CUDA) tensors
        plt.figure(figsize=(5, 5))
        plt.plot(range(len(self.layers)), accs, label='Accuracy', color='blue')
        plt.title('Accuracy by Layer')
        plt.xlabel('Layer Index')
        plt.ylabel('Accuracy')
        plt.ylim(0, 1.0)

        output_path = 'conv_acc_curve.png'
        plt.savefig(output_path)
        plt.close()
        print(f"Accuracy curve saved to {output_path}")
        print("Layer accuracies:", [f"{acc:.4f}" for acc in accs])


device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")

train_loader, test_loader = Cifar_loaders()

# Per-conv-layer settings. dims[0] is the RGB input; dims[i + 1] is the
# initial width of conv layer i (pruning/growth may change it afterwards).
maxpool = [False, True, False, True]
kernels = [3, 3, 3, 3]
strides = [1, 1, 1, 1]
dims = [3, 20, 80, 240, 480]


def _greedy_refine(net, i, x, y, method):
    """Repeatedly apply ``layer.<method>(x, y, m=1)`` to layer ``i`` while train accuracy holds.

    Each step edits a deep copy, which replaces the layer only if its accuracy
    on ``(x, y)`` is >= the current one. The loop also stops when the edit
    leaves the layer width unchanged (nothing deleted / growth skipped);
    otherwise an unchanged copy would tie with the current layer forever.
    """
    current_acc = net.layers[i].predict(x, y)
    while True:
        candidate = copy.deepcopy(net.layers[i])
        getattr(candidate, method)(x, y, m=1)
        if candidate.weight.size(0) == net.layers[i].weight.size(0):
            break
        candidate_acc = candidate.predict(x, y)
        if candidate_acc < current_acc:
            break
        net.layers[i] = candidate
        current_acc = candidate_acc


for run in range(5):
    x_train, y_train = next(iter(train_loader))
    x_test, y_test = next(iter(test_loader))
    net = Net(K=10, threshold=2.0, device=device)
    for i in range(len(dims) - 1):
        # Later layers consume the (possibly resized) features of the previous one.
        in_dim = dims[i] if i == 0 else x_train.shape[1]
        net.add_layer(input_dim=in_dim, out_dim=dims[i + 1], ismaxpool=maxpool[i],
                      kernels=kernels[i], strides=strides[i])
        print(f"Training layer {i} ...")
        net.layers[i].train_layer(x_train, y_train)
        with torch.no_grad():
            acc_all, _ = net.predict(x_test, y_test)
            print(f"Last layer -> Overall Acc: {acc_all:.4f}")

        _greedy_refine(net, i, x_train, y_train, "delete")
        _greedy_refine(net, i, x_train, y_train, "growth")

        # Replace x_train by this layer's pooled + batch-normed features.
        layer = net.layers[i]
        layer.add_bn(layer.weight.size(0))
        features = []
        with torch.no_grad():
            for inputs in x_train.split(128):
                h = layer.forward(inputs.to(device))
                if layer.ismaxpool:
                    h = layer.maxpool(h)
                features.append(layer.conv_bn(h).cpu())
        x_train = torch.cat(features, dim=0)
        print(x_train.shape[1])
        with torch.no_grad():
            acc_all, _ = net.predict(x_test, y_test)
            print(f"Last layer -> Overall Acc: {acc_all:.4f}")
