import torch
import copy
import torch.nn as nn
from torch.utils.data import WeightedRandomSampler
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from tqdm import tqdm
#%%
# -----------------------------
# MNIST dataset wrapper
# -----------------------------
class MNIST_Dataset(Dataset):
    """Pair two index-aligned containers as a map-style dataset.

    ``x`` holds the samples and ``y`` the matching labels; both are stored
    as given and only need to support ``len()`` (for ``x``) and indexing.
    """

    def __init__(self, x, y):
        self.x, self.y = x, y

    def __len__(self):
        """Return the number of samples, taken from ``x``."""
        sample_count = len(self.x)
        return sample_count

    def __getitem__(self, idx):
        """Return the ``(sample, label)`` pair stored at position ``idx``."""
        sample = self.x[idx]
        label = self.y[idx]
        return sample, label


def make_balanced_loader(x, y, batch_size=64):
    """Return a DataLoader over ``(x, y)`` that samples classes evenly.

    Each sample is weighted by the inverse of its class frequency and drawn
    with replacement, ``len(y)`` draws per pass over the loader.

    Args:
        x: Samples, indexable in step with ``y``.
        y: 1-D tensor of non-negative integer class labels.
        batch_size: Number of samples per batch.
    """
    # Inverse class frequency, looked up per sample through its label.
    inverse_frequency = 1.0 / torch.bincount(y).float()
    per_sample_weight = inverse_frequency[y]

    balanced_sampler = WeightedRandomSampler(
        weights=per_sample_weight,
        num_samples=len(per_sample_weight),
        replacement=True,
    )
    return DataLoader(
        MNIST_Dataset(x, y),
        batch_size=batch_size,
        sampler=balanced_sampler,
    )

# -----------------------------
# Custom Layer
# -----------------------------
class Layer(nn.Linear):
    """Linear + ReLU layer whose output neurons are split into per-class sets.

    The ``out_features`` neurons are partitioned into ``K`` support sets
    (``support_index[k]``); ``support_index_neg[k]`` holds every other
    neuron. "Goodness" is the squared ReLU activation. The layer owns its
    optimizer (``self.opt``) and can be trained on its own, grown
    (``growth``) or pruned (``delete``).

    Args:
        in_features: Size of each input sample.
        out_features: Initial number of neurons.
        K: Number of classes / support sets.
        threshold: Goodness threshold used by the non-"CWC" loss.
        device: Device passed to ``nn.Linear``.
        bias: Whether the layer has a bias vector.
        dtype: Parameter dtype passed to ``nn.Linear``.
    """

    def __init__(self, in_features, out_features, K, threshold=2.0, device="cuda:0",
                 bias=True, dtype=None):
        super().__init__(in_features, out_features, bias, device, dtype)
        self.relu = nn.ReLU()
        self.K = K
        self.out_features = out_features
        self.device = device

        # Partition neurons into K contiguous support sets; the last class
        # also takes the remainder when out_features is not divisible by K.
        base = out_features // K
        self.support_index = [list(range(k * base, (k + 1) * base)) for k in range(K)]
        self.support_index[-1] = list(range((K - 1) * base, out_features))

        # Complement of every support set (set lookup keeps this linear).
        self.support_index_neg = []
        for own in self.support_index:
            own_set = set(own)
            self.support_index_neg.append(
                [i for i in range(out_features) if i not in own_set]
            )

        self.opt = Adam(self.parameters(), lr=0.001)
        self.threshold = threshold
        self.num_epochs = 1  # keep small for demo

    # -----------------------------
    # Internal helpers
    # -----------------------------
    def _replace_params(self, weight, bias):
        """Install resized parameters and rebuild the optimizer around them.

        The optimizer must be rebuilt because it holds references to (and
        per-parameter state shaped like) the old tensors; its accumulated
        state is therefore reset. The current learning rate is preserved.
        """
        lr = self.opt.param_groups[0]["lr"]
        self.weight = nn.Parameter(weight.detach())
        if bias is not None:
            self.bias = nn.Parameter(bias.detach())
        self.out_features = self.weight.size(0)
        self.opt = Adam(self.parameters(), lr=lr)

    @staticmethod
    def _drop(t, j):
        """Return ``t`` without entry ``j`` along dimension 0."""
        return torch.cat([t[:j], t[j + 1:]], dim=0)

    def _membership(self, index_sets, g):
        """Return a ``[len(index_sets), neurons]`` 0/1 matrix of set membership."""
        member = g.new_zeros(len(index_sets), g.size(1))
        for k, idxs in enumerate(index_sets):
            if idxs:
                member[k, idxs] = 1.0
        return member

    def _pos_neg_goodness(self, g, labels):
        """Mean goodness per sample over its own and its complementary set.

        Vectorised form of ``g[b, support_index[label_b]].mean()`` (and the
        same for ``support_index_neg``); an empty set yields NaN, exactly
        like the mean of an empty selection.
        """
        labels = torch.as_tensor(labels, device=g.device).long()
        pos_rows = self._membership(self.support_index, g)[labels]
        neg_rows = self._membership(self.support_index_neg, g)[labels]
        g_pos = (g * pos_rows).sum(1) / pos_rows.sum(1)
        g_neg = (g * neg_rows).sum(1) / neg_rows.sum(1)
        return g_pos, g_neg

    def _class_goodness(self, g):
        """Return ``[batch, K]`` mean goodness of each class's support set."""
        return torch.stack([g[:, idxs].mean(1) for idxs in self.support_index], dim=1)

    # -----------------------------
    # Forward / training
    # -----------------------------
    def forward(self, x):
        """Apply the linear map and ReLU to the L2-normalised rows of ``x``."""
        x_direction = x / (x.norm(2, 1, keepdim=True) + 1e-4)
        return self.relu(F.linear(x_direction, self.weight, self.bias))

    def train_layer(self, x, y):
        """Train on ``(x, y)`` for ``num_epochs`` epochs with the threshold loss.

        Returns:
            The layer's activations for ``x`` after training, detached from
            the autograd graph.
        """
        dataset = MNIST_Dataset(x, y)
        dataloader = DataLoader(dataset, batch_size=128, shuffle=True)

        for _ in tqdm(range(self.num_epochs)):
            for inputs, labels in dataloader:
                # Any value other than "CWC" selects the threshold loss.
                loss = self.loss(inputs, labels, L=0)

                self.opt.zero_grad()
                loss.backward()
                self.opt.step()

        # no_grad avoids building a graph for the whole dataset.
        with torch.no_grad():
            return self.forward(x)

    def get_neuron(self, x, y):
        """Return ``[sup, sup_neg, g_pos, g_neg]`` for the samples in ``x``.

        ``sup`` / ``sup_neg`` are the per-sample index lists selected by the
        labels ``y``; ``g_pos`` / ``g_neg`` are the mean goodness over them.
        """
        g = self.forward(x).pow(2)
        sup = [self.support_index[int(i)] for i in y]
        sup_neg = [self.support_index_neg[int(i)] for i in y]
        g_pos, g_neg = self._pos_neg_goodness(g, y)
        return [sup, sup_neg, g_pos, g_neg]

    def loss(self, inputs, labels, L="CWC"):
        """Compute the training loss for one batch.

        Args:
            inputs: Batch of input rows.
            labels: Integer class label per row.
            L: ``"CWC"`` selects a softmax-style loss over the per-class mean
                goodness; any other value selects the threshold loss, which
                pushes own-class goodness above ``threshold`` and the
                complement below it.

        Returns:
            Scalar loss tensor.
        """
        g = self.forward(inputs).pow(2)
        if L == "CWC":
            # Empty support sets contribute a response of 0.
            columns = [
                g[:, idxs].mean(1) if idxs else g.new_zeros(g.size(0))
                for idxs in self.support_index
            ]
            response = torch.clamp(torch.stack(columns, dim=1), min=-50, max=50)
            chosen = response.gather(1, labels.unsqueeze(1)).squeeze(1)

            eps = 1e-9
            loss = -torch.log((torch.exp(chosen) + eps) /
                              (torch.exp(response).sum(1) + eps)).mean()
        else:
            g_pos, g_neg = self._pos_neg_goodness(g, labels)
            # softplus(z) == log(1 + exp(z)) but does not overflow for large z.
            loss = F.softplus(torch.cat([
                self.threshold - g_pos,
                g_neg - self.threshold,
            ])).mean()

        return loss

    # -----------------------------
    # Growth function
    # -----------------------------
    def growth(self, x, y, m, epochs=3, lr=1e-3, batch_size=128, train=True):
        """Add neurons to the class this layer currently classifies worst.

        The weakest class is the one with the lowest accuracy among classes
        that have samples in ``y``. Its ``m`` support neurons with the highest
        mean goodness on that class are cloned and appended; ``m`` is capped
        at the size of that support set. The new neurons join the weakest
        class's support set and every other class's complementary set.

        Args:
            x: Input rows.
            y: Integer label tensor aligned with ``x``.
            m: Number of neurons to add.
            epochs: Passes over the fine-tuning set when ``train`` is true.
            lr: Learning rate for fine-tuning the new neurons.
            batch_size: Batch size for fine-tuning.
            train: If true, fine-tune only the new neurons so their goodness
                is high on the misclassified weakest-class samples and low on
                an equally sized random draw from the other classes.

        Returns:
            None. Prints a message and returns early when nothing can grow.
        """
        self.eval()
        with torch.no_grad():
            # ---- Step 1: find the weakest class among those present ----
            g = self.forward(x).pow(2)
            preds = self._class_goodness(g).argmax(1)

            weakest_class, worst_acc = None, float("inf")
            for k in range(self.K):
                members = (y == k)
                if not members.any():
                    continue  # an absent class has no defined accuracy
                acc = preds[members].eq(k).float().mean().item()
                if acc < worst_acc:
                    weakest_class, worst_acc = k, acc
            if weakest_class is None:
                print("No labelled samples, skipping grow()")
                return

            in_class = (y == weakest_class)
            missed = in_class & (preds != weakest_class)
            x_mis, y_mis = x[missed], y[missed]
            if x_mis.size(0) == 0:
                print(f"No misclassified samples of class {weakest_class}, skipping grow()")
                return

            # ---- Step 2: average goodness of the class's support neurons ----
            support_ids = self.support_index[weakest_class]
            m = min(m, len(support_ids))  # topk cannot return more than exist
            if m <= 0:
                print(f"No support neurons for class {weakest_class}, skipping grow()")
                return
            support_goodness = g[in_class].mean(dim=0)[support_ids]

            # ---- Step 3: select strongest-goodness neurons to clone ----
            top = torch.topk(support_goodness, m).indices.tolist()
            clone_idx = [support_ids[i] for i in top]

            # ---- Step 4/5: clone rows and expand the parameters ----
            old_count = self.weight.size(0)
            new_weight = torch.cat([self.weight, self.weight[clone_idx]], dim=0)
            new_bias = None
            if self.bias is not None:
                new_bias = torch.cat([self.bias, self.bias[clone_idx]], dim=0)
            self._replace_params(new_weight, new_bias)

            # ---- Step 6: update support / complement index sets ----
            new_indices = list(range(old_count, old_count + m))
            self.support_index[weakest_class].extend(new_indices)
            for k in range(self.K):
                if k != weakest_class:
                    self.support_index_neg[k].extend(new_indices)

            print(f"Grew {m} neurons (cloned strongest from class {weakest_class}), "
                  f"new total neurons = {self.weight.size(0)}")

        if not train:
            return

        # Fine-tune only the new neurons on the misclassified samples plus an
        # equally sized random draw from the other classes.
        self.train()
        others = (y != weakest_class)
        x_non, y_non = x[others], y[others]
        if x_non.size(0) == 0:
            return

        pick = torch.randperm(x_non.size(0))[:x_mis.size(0)]
        x_train = torch.cat([x_mis, x_non[pick]], dim=0)
        y_train = torch.cat([y_mis, y_non[pick]], dim=0)

        trainable = [self.weight] if self.bias is None else [self.weight, self.bias]
        opt_new = Adam(trainable, lr=lr)

        dataset = MNIST_Dataset(x_train, y_train)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

        # 1 for rows of new neurons, 0 elsewhere; built once, reused per step.
        row_mask = torch.zeros(self.weight.size(0), 1,
                               device=self.weight.device, dtype=self.weight.dtype)
        row_mask[new_indices] = 1.0

        for _ in range(epochs):
            for inputs, labels in dataloader:
                # focus only on the new neurons
                g_new = self.forward(inputs).pow(2)[:, new_indices].mean(1)

                # target: high on the weakest class, low on the others
                is_pos = (labels == weakest_class)
                loss = 0
                if is_pos.any():
                    loss = loss + (-g_new[is_pos]).mean()
                if (~is_pos).any():
                    loss = loss + g_new[~is_pos].mean()

                opt_new.zero_grad()
                loss.backward()
                # Safeguard: keep the gradients of pre-existing neurons at zero.
                with torch.no_grad():
                    self.weight.grad *= row_mask
                    if self.bias is not None:
                        self.bias.grad *= row_mask.squeeze(1)
                opt_new.step()

    def delete(self, x, y, m=1):
        """Remove up to ``m`` neurons with the lowest own-class mean goodness.

        On each round, every class with samples in ``y`` and a non-empty
        support set proposes its support neuron with the lowest mean goodness
        on that class's samples; the overall minimum is removed and all index
        sets are renumbered. Stops early when no candidate is left. A class
        may lose its last neuron, leaving its support set empty.

        Args:
            x: Input rows.
            y: Integer label tensor aligned with ``x``.
            m: Number of neurons to remove.

        Returns:
            List of ``(class, neuron_index)`` tuples in removal order; each
            index refers to the numbering at the moment of that removal.
        """
        deleted_indices = []
        with torch.no_grad():
            g = self.forward(x).pow(2)
            # Per-class mean goodness of every neuron. Neurons are independent,
            # so after a removal the matching entry is simply dropped instead
            # of re-running the forward pass.
            class_mean = {}
            for k in range(self.K):
                members = (y == k)
                if members.any():
                    class_mean[k] = g[members].mean(0)

            for _ in range(m):
                min_goodness = float("inf")
                min_neuron, min_class = None, None
                for k, mean_g in class_mean.items():
                    idxs = self.support_index[k]
                    if not idxs:
                        continue
                    val, pos = torch.min(mean_g[idxs], dim=0)
                    if val.item() < min_goodness:
                        min_goodness = val.item()
                        min_neuron = idxs[pos.item()]
                        min_class = k

                if min_neuron is None:
                    break  # nothing left that can be removed

                deleted_indices.append((min_class, min_neuron))
                new_bias = None if self.bias is None else self._drop(self.bias, min_neuron)
                self._replace_params(self._drop(self.weight, min_neuron), new_bias)
                class_mean = {k: self._drop(v, min_neuron) for k, v in class_mean.items()}

                # Renumber: drop the removed index, shift higher ones down by one.
                for k in range(self.K):
                    self.support_index[k] = [
                        i if i < min_neuron else i - 1
                        for i in self.support_index[k] if i != min_neuron
                    ]
                    self.support_index_neg[k] = [
                        i if i < min_neuron else i - 1
                        for i in self.support_index_neg[k] if i != min_neuron
                    ]

        print("Deleted neurons from support sets:", deleted_indices)
        return deleted_indices

# -----------------------------
# Net
# -----------------------------
class Net(nn.Module):
    """Stack of ``Layer`` modules trained greedily, one layer at a time.

    Args:
        dims: Layer widths; ``dims[i] -> dims[i + 1]`` defines layer ``i``.
        K: Number of classes, shared by every layer.
        threshold: Goodness threshold passed to every layer.
        device: Device every layer is created on.
    """

    def __init__(self, dims, K, threshold=2.0, device="cuda:0"):
        super().__init__()
        self.layers = nn.ModuleList()
        self.K = K
        for d in range(len(dims) - 1):
            self.layers.append(Layer(dims[d], dims[d + 1], K=self.K,
                                     threshold=threshold, device=device).to(device))

    def predict(self, x, y):
        """Classify ``x`` with the last layer and score against ``y``.

        The predicted class is the one whose support set in the last layer
        has the highest mean goodness (squared activation), the same quantity
        the layers are trained on.

        Args:
            x: Input rows.
            y: Integer label tensor aligned with ``x``.

        Returns:
            ``(overall_accuracy, per_class_accuracy)`` where the second item
            is a list of ``K`` floats; a class with no samples in ``y`` gets
            NaN.
        """
        # Evaluation only: no autograd graph is needed.
        with torch.no_grad():
            h = x
            for layer in self.layers:
                h = layer(h)
            last = self.layers[-1]
            g = h.pow(2)
            goodness = torch.stack(
                [g[:, last.support_index[k]].mean(1) for k in range(self.K)],
                dim=1,
            )
            preds = goodness.argmax(1)

        acc_overall = preds.eq(y).float().mean().item()
        acc_per_class = [
            preds[y == k].eq(k).float().mean().item() for k in range(self.K)
        ]
        return acc_overall, acc_per_class

    def train_net(self, x, y):
        """Train each layer in order, feeding it the previous layer's output.

        Returns:
            The output of the last layer's ``train_layer`` call.
        """
        h = x
        for i, layer in enumerate(self.layers):
            print(f"Training layer {i} ...")
            h = layer.train_layer(h, y)
        return h
#%%
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load MNIST
# NOTE(review): `transform` is handed to the datasets, but the tensors below
# are read straight from `.data`, which presumably bypasses it (raw pixel
# values, no Normalize) -- confirm this is intended.
transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
        transforms.Lambda(lambda x: torch.flatten(x))])

trainset = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
testset = datasets.MNIST(root="./data", train=False, download=True, transform=transform)

# Flatten every image to a 784-vector and move the data to the device.
x_train = trainset.data.view(-1, 784).float().to(device)
y_train = trainset.targets.to(device)

x_test = testset.data.view(-1, 784).float().to(device)
y_test = testset.targets.to(device)

# Build net (one layer)
net = Net([784, 50], K=10, threshold=2.0, device=device)
net.train_net(x_train, y_train)

# Baseline accuracy before any neuron is removed
acc_all, acc_class9 = net.predict(x_test, y_test)
print(f"Baseline -> Overall Acc: {acc_all:.4f}, Per-class Acc: {acc_class9}")


# Remove one neuron at a time and re-evaluate after each removal.
layer1 = net.layers[0]
for i in range(2):
    layer1.delete(x_train, y_train, m=1)
    acc_all, acc_class9 = net.predict(x_test, y_test)
    print(f"After deletion {i + 1} -> Overall Acc: {acc_all:.4f}, Per-class Acc: {acc_class9}")