import copy

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torchvision import datasets
from tqdm import tqdm
from torch.optim import Adam
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor, Normalize, Lambda, transforms
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler
import torch.nn.functional as F

from MNIST import MNIST_loaders


class MNIST_Dataset(Dataset):
    """Map-style dataset that pairs ``data[i]`` with ``labels[i]``.

    Both containers are stored as given (no copy, no validation); the dataset
    length is the length of ``data``.
    """

    def __init__(self, data, labels):
        self.data, self.labels = data, labels

    def __len__(self):
        n_samples = len(self.data)
        return n_samples

    def __getitem__(self, idx):
        sample = self.data[idx]
        target = self.labels[idx]
        return sample, target


def make_balanced_loader(x, y, batch_size=64):
    """Build a DataLoader over ``(x, y)`` that samples classes evenly.

    Each sample is weighted by the inverse frequency of its label in ``y`` and
    drawn with replacement, ``len(y)`` draws per pass over the loader.

    Args:
        x: indexable collection of samples.
        y: 1-D tensor of non-negative integer labels (required by
            ``torch.bincount``).
        batch_size: batch size of the returned loader.

    Returns:
        A ``DataLoader`` wrapping ``MNIST_Dataset(x, y)`` with a
        ``WeightedRandomSampler``.
    """
    # Inverse class frequency, looked up per sample through its label.
    per_class_weight = torch.bincount(y).float().reciprocal()
    per_sample_weight = per_class_weight[y]

    balanced_sampler = WeightedRandomSampler(
        weights=per_sample_weight,
        num_samples=len(per_sample_weight),
        replacement=True,
    )
    return DataLoader(
        MNIST_Dataset(x, y),
        batch_size=batch_size,
        sampler=balanced_sampler,
    )


def overlay_y_on_x(x, y):
    """Write a one-hot encoding of label ``y`` into the first 10 columns of ``x``.

    Works on a copy: the first 10 columns of every row are multiplied by zero,
    then column ``y[b]`` of row ``b`` is set to the maximum value found
    anywhere in the original ``x``. The input tensor is left untouched.
    """
    peak = x.max()
    batch_rows = range(x.shape[0])

    overlaid = x.clone()
    overlaid[:, :10] = overlaid[:, :10] * 0.0
    overlaid[batch_rows, y] = peak
    return overlaid


# -----------------------------
# Custom Layer
# -----------------------------
class Layer(nn.Linear):
    """Linear + ReLU layer whose output neurons are partitioned into K class supports.

    ``support_index[k]`` lists the output neurons assigned to class ``k`` and
    ``support_index_neg[k]`` lists every other output neuron. The "goodness"
    of a class is the mean (squared) activation over its support. The layer
    owns its own Adam optimizer (``self.opt``) and can change its width at
    run time through :meth:`growth` and :meth:`delete`.

    Attributes kept for backward compatibility:
        pos_all / neg_all: index tensors describing the *initial* equal-sized
            layout. They are not updated by ``growth``/``delete``; the loss
            uses membership masks that are rebuilt after every structural
            change instead.
    """

    def __init__(self, in_features, out_features, K, threshold=2.0, device="cuda:0",
                 bias=True, dtype=None):
        super().__init__(in_features, out_features, bias, device, dtype)
        self.relu = torch.nn.ReLU()
        self.num_sup = out_features // K
        # Class k initially owns the contiguous block [k * num_sup, (k + 1) * num_sup).
        self.support_index = [
            [i + self.num_sup * k for i in range(self.num_sup)] for k in range(K)
        ]
        self.support_index_neg = []
        for k in range(K):
            own = set(self.support_index[k])
            self.support_index_neg.append(
                [i for i in range(out_features) if i not in own]
            )
        self.opt = Adam(self.parameters(), lr=0.001)
        self.threshold = threshold
        self.num_epochs = 100
        self.device = device
        pos_indices = torch.stack([torch.tensor(self.support_index[i]) for i in range(K)])
        neg_indices = torch.stack([torch.tensor(self.support_index_neg[i]) for i in range(K)])
        self.pos_all = pos_indices.to(self.device)
        self.neg_all = neg_indices.to(self.device)
        self.K = K
        self.output_features = out_features
        self._refresh_masks()

    # -----------------------------
    # Internal helpers
    # -----------------------------
    def _refresh_masks(self):
        """Rebuild the [K, n_out] 0/1 membership masks from the support lists."""
        n_out = self.weight.size(0)
        pos = torch.zeros(self.K, n_out, device=self.weight.device, dtype=self.weight.dtype)
        neg = torch.zeros_like(pos)
        for k in range(self.K):
            if self.support_index[k]:
                pos[k, self.support_index[k]] = 1.0
            if self.support_index_neg[k]:
                neg[k, self.support_index_neg[k]] = 1.0
        self._pos_mask = pos
        self._neg_mask = neg

    def _sync_structure(self):
        """Bring bookkeeping in line with the current weight matrix.

        Called after rows were added or removed: updates the feature counts,
        rebuilds the loss masks and recreates the optimizer, because the old
        one still references the previous parameters / state shapes.
        """
        self.out_features = self.weight.size(0)
        self.output_features = self.out_features
        self._refresh_masks()
        lr = self.opt.param_groups[0]["lr"]
        self.opt = Adam(self.parameters(), lr=lr)

    def _class_scores(self, act):
        """Return a [batch, K] tensor: mean of ``act`` over each class support."""
        scores = torch.zeros(act.shape[0], self.K, device=act.device)
        for k in range(self.K):
            scores[:, k] = act[:, self.support_index[k]].mean(1)
        return scores

    # -----------------------------
    # Forward / training
    # -----------------------------
    def forward(self, x):
        """L2-normalise each row of ``x`` (dim 1), then apply linear + ReLU."""
        x_direction = x / (x.norm(2, 1, keepdim=True) + 1e-4)
        return self.relu(F.linear(x_direction, self.weight, self.bias))

    def train_layer(self, x, y):
        """Train this layer alone with the 'PvN' loss for ``self.num_epochs`` epochs.

        Returns:
            The layer's activations for the full ``x``, detached from autograd.
        """
        dataset = MNIST_Dataset(x, y)
        dataloader = DataLoader(dataset, batch_size=128, shuffle=True)

        for _ in tqdm(range(self.num_epochs)):
            for inputs, labels in dataloader:
                loss = self.loss(inputs, labels, L='PvN')

                self.opt.zero_grad()
                loss.backward()
                self.opt.step()

        # No graph is needed for the activations handed to the next layer.
        with torch.no_grad():
            return self.forward(x)

    def get_neuron(self, x, y):
        """Return per-sample support lists and mean squared activations.

        Returns:
            ``[sup, sup_neg, g_pos, g_neg]`` where ``sup[b]`` / ``sup_neg[b]``
            are the support / non-support index lists for label ``y[b]`` and
            ``g_pos[b]`` / ``g_neg[b]`` the mean squared activation over them.
        """
        g = self.forward(x).pow(2)
        sup = [self.support_index[int(i)] for i in y]
        sup_neg = [self.support_index_neg[int(i)] for i in y]
        g_pos = torch.stack([g[b, idx].mean() for b, idx in enumerate(sup)])
        g_neg = torch.stack([g[b, idx].mean() for b, idx in enumerate(sup_neg)])
        return [sup, sup_neg, g_pos, g_neg]

    def loss(self, inputs, labels, L="CWC"):
        """Compute the training loss on one batch.

        Args:
            inputs: batch fed to :meth:`forward`.
            labels: integer class labels, one per row of ``inputs``.
            L: ``"CWC"`` for a softmax-style loss over per-class goodness;
                any other value selects the positive-vs-negative ("PvN")
                thresholded loss.

        Returns:
            A scalar loss tensor.
        """
        g = self.forward(inputs).pow(2)

        if L == "CWC":
            # Per-class goodness from the *current* support lists (an empty
            # support contributes a zero column).
            columns = []
            for ind in self.support_index:
                if len(ind) > 0:
                    columns.append(g[:, ind].mean(1))
                else:
                    columns.append(g.new_zeros(g.size(0)))
            response = torch.stack(columns, dim=1)
            response = torch.clamp(response, min=-50, max=50)
            chosen = response.gather(1, labels.unsqueeze(1)).squeeze(1)

            eps = 1e-9
            return -torch.log((torch.exp(chosen) + eps) /
                              (torch.exp(response).sum(1) + eps)).mean()

        # Flatten the output except batch dimension
        g_flat = g.reshape(g.size(0), -1)

        # Masked means instead of gathers: supports may have unequal sizes
        # once neurons have been grown or deleted.
        pos_mask = self._pos_mask.to(g_flat.device)[labels]
        neg_mask = self._neg_mask.to(g_flat.device)[labels]
        g_pos = (g_flat * pos_mask).sum(1) / pos_mask.sum(1).clamp(min=1.0)
        g_neg = (g_flat * neg_mask).sum(1) / neg_mask.sum(1).clamp(min=1.0)

        # softplus(z) == log(1 + exp(z)) but does not overflow for large z.
        return F.softplus(torch.cat([
            -g_pos + self.threshold,
            g_neg - self.threshold])).mean()

    # -----------------------------
    # Growth function
    # -----------------------------
    def growth(self, x, y, m, epochs=5, lr=1e-3, batch_size=128, train=True):
        """Add up to ``m`` neurons to the support of the weakest class.

        The weakest class is the one with the lowest accuracy on ``(x, y)``
        (scored on squared activations) among the classes that occur in ``y``.
        The ``m`` support neurons of that class with the highest average
        goodness on its samples are cloned and appended. When ``train`` is
        true, only the new rows are then fine-tuned so that they respond
        strongly to the misclassified samples of the weakest class and weakly
        to an equally sized random set of other-class samples.

        The call is a no-op (with a printed message) when the weakest class
        has no samples, no misclassified samples, or no support neurons.

        Args:
            x, y: data and integer labels used for selection and fine-tuning.
            m: number of neurons to add (capped at the support size).
            epochs: passes of the fine-tuning loop. This argument used to be
                ignored in favour of a hard-coded 5, which is why 5 is the
                default.
            lr: learning rate of the fine-tuning optimizer.
            batch_size: batch size of the fine-tuning loader.
            train: whether to fine-tune the new neurons.
        """
        self.eval()
        with torch.no_grad():
            # ---- Step 1: select the weakest class ----
            h = self.forward(x).pow(2)
            preds = self._class_scores(h).argmax(1)
            per_class_acc = []
            for k in range(self.K):
                in_class = (y == k)
                if in_class.any():
                    per_class_acc.append(preds[in_class].eq(k).float().mean().item())
                else:
                    # Classes without samples must not win the argmin.
                    per_class_acc.append(float("inf"))
            weakest_class = int(torch.argmin(torch.tensor(per_class_acc)))

            # ---- Step 2: samples of the weakest class ----
            mask = (y == weakest_class)
            x_class = x[mask]
            y_class = y[mask]
            if x_class.size(0) == 0:
                print(f"No samples of class {weakest_class}, skipping grow()")
                return

            mis_mask = (preds[mask] != weakest_class)
            x_mis = x_class[mis_mask]
            y_mis = y_class[mis_mask]
            if x_mis.size(0) == 0:
                print(f"No misclassified samples of class {weakest_class}, skipping grow()")
                return

            # only consider support neurons for this class
            support_ids = self.support_index[weakest_class]
            m = min(m, len(support_ids))
            if m < 1:
                print(f"Class {weakest_class} has no support neurons to clone, skipping grow()")
                return
            avg_goodness = h[mask].mean(dim=0)       # [neurons]
            support_goodness = avg_goodness[support_ids]

            # ---- Step 3: Select strongest-goodness neurons to clone ----
            topk = torch.topk(support_goodness, m)
            clone_idx = [support_ids[i] for i in topk.indices.tolist()]

            # ---- Step 4/5: Clone rows and expand parameters ----
            # New rows are appended, so their indices start at the current
            # row count (not at the largest support index).
            start_idx = self.weight.size(0)
            new_weight = self.weight[clone_idx].clone().detach()
            self.weight = nn.Parameter(torch.cat([self.weight, new_weight], dim=0))
            if self.bias is not None:
                new_bias = self.bias[clone_idx].clone().detach()
                self.bias = nn.Parameter(torch.cat([self.bias, new_bias], dim=0))

            # ---- Step 6: Update support index ----
            new_indices = list(range(start_idx, start_idx + m))
            self.support_index[weakest_class].extend(new_indices)
            for k in range(self.K):
                # The new neurons are negatives for every other class.
                if k != weakest_class:
                    self.support_index_neg[k].extend(new_indices)
            self._sync_structure()

            print(f"Grew {m} neurons (cloned strongest from class {weakest_class}), "
                  f"new total neurons = {self.weight.size(0)}")

        if not train:
            return

        # Fine-tune the new neurons on the misclassified samples of the weakest
        # class plus an equally sized random sample from the other classes, so
        # that their goodness on the former surpasses that on the latter.
        self.train()
        non_mask = (y != weakest_class)
        x_non = x[non_mask]
        y_non = y[non_mask]
        if x_non.size(0) == 0:
            return

        idx = torch.randperm(x_non.size(0))[:x_mis.size(0)]
        x_train = torch.cat([x_mis, x_non[idx]], dim=0)
        y_train = torch.cat([y_mis, y_non[idx]], dim=0)

        params = [self.weight] if self.bias is None else [self.weight, self.bias]
        opt_new = Adam(params, lr=lr)
        dataloader = DataLoader(MNIST_Dataset(x_train, y_train),
                                batch_size=batch_size, shuffle=True)

        # 1 for the new rows, 0 elsewhere: used to zero gradients of old neurons.
        row_mask = torch.zeros(self.weight.size(0), device=self.weight.device,
                               dtype=self.weight.dtype)
        row_mask[new_indices] = 1.0

        for _ in range(epochs):
            for inputs, labels in dataloader:
                # focus only on new neurons
                g_new = self.forward(inputs).pow(2)[:, new_indices].mean(1)

                # target: high on misclassified class, low on others
                mask_pos = (labels == weakest_class)
                mask_neg = ~mask_pos
                loss = 0
                if mask_pos.any():
                    loss = loss - g_new[mask_pos].mean()
                if mask_neg.any():
                    loss = loss + g_new[mask_neg].mean()

                opt_new.zero_grad()
                loss.backward()
                with torch.no_grad():
                    self.weight.grad *= row_mask.unsqueeze(1)
                    if self.bias is not None:
                        self.bias.grad *= row_mask
                opt_new.step()

    def delete(self, x, y, m=1):
        """Remove up to ``m`` neurons, weakest first.

        In each round the neuron with the lowest average squared activation on
        the samples of its own class is removed, and all support lists are
        re-indexed. Classes without samples or without support neurons are
        skipped; if no candidate is left the loop stops early.

        Returns:
            List of ``(class, neuron_index)`` tuples, one per removed neuron
            (indices refer to the layout at the time of each removal).
        """
        deleted_indices = []
        with torch.no_grad():
            g = self.forward(x).pow(2)
            for _ in range(m):
                min_goodness = float("inf")
                min_neuron, min_class = None, None
                for k in range(self.K):
                    idxs = self.support_index[k]
                    class_mask = (y == k)
                    if not idxs or class_mask.sum() == 0:
                        continue
                    class_goodness = g[class_mask][:, idxs].mean(0)
                    val, pos = torch.min(class_goodness, dim=0)
                    if val.item() < min_goodness:
                        min_goodness = val.item()
                        min_neuron = idxs[pos.item()]
                        min_class = k

                if min_neuron is None:
                    print("No deletable neuron found, stopping delete()")
                    break

                deleted_indices.append((min_class, min_neuron))
                self.weight = nn.Parameter(torch.cat([self.weight[:min_neuron],
                                                      self.weight[min_neuron + 1:]], dim=0))
                if self.bias is not None:
                    self.bias = nn.Parameter(torch.cat([self.bias[:min_neuron],
                                                        self.bias[min_neuron + 1:]], dim=0))
                # Keep the cached activations aligned with the shifted indices.
                g = torch.cat([g[:, :min_neuron], g[:, min_neuron + 1:]], dim=1)

                for k in range(self.K):
                    self.support_index[k] = [
                        i if i < min_neuron else i - 1
                        for i in self.support_index[k] if i != min_neuron
                    ]
                    self.support_index_neg[k] = [
                        i if i < min_neuron else i - 1
                        for i in self.support_index_neg[k] if i != min_neuron
                    ]

        if deleted_indices:
            self._sync_structure()
        print("Deleted neurons from support sets:", deleted_indices)
        return deleted_indices

    def predict(self, x, y):
        """Classify ``x`` by highest mean support activation and score against ``y``.

        Returns:
            ``(overall_accuracy, per_class_accuracy)``; the per-class entry is
            NaN for classes that do not occur in ``y``.
        """
        # NOTE: scores use the raw activations here, whereas growth() ranks
        # classes on squared activations.
        with torch.no_grad():
            preds = self._class_scores(self.forward(x)).argmax(1)
            acc_overall = preds.eq(y).float().mean().item()
            acc_class9 = [
                preds[y == k].eq(y[y == k]).float().mean().item()
                for k in range(self.K)
            ]
        return acc_overall, acc_class9

# -----------------------------
# Net
# -----------------------------
class Net(nn.Module):
    """Stack of ``Layer`` modules trained one after another.

    Args:
        K: number of classes; forwarded to every layer added via ``add_layer``.
        threshold: goodness threshold forwarded to every added layer.
        device: device string forwarded to every added layer.
    """

    def __init__(self, K, threshold=2.0, device="cuda:0"):
        super().__init__()
        self.layers = nn.ModuleList()
        self.K = K
        self.threshold = threshold
        self.device = device

    def predict(self, x, y):
        """Classify ``x`` from the last layer's activations and score against ``y``.

        The input is passed through every layer in order; the predicted class
        is the one whose support neurons in the last layer have the highest
        mean activation.

        Returns:
            ``(overall_accuracy, per_class_accuracy)``; the per-class entry is
            NaN for classes that do not occur in ``y``.

        Raises:
            ValueError: if the net has no layers.
        """
        if len(self.layers) == 0:
            raise ValueError("Net has no layers; call add_layer() before predict()")

        h = x
        for layer in self.layers:
            h = layer(h)

        # Only the final layer decides the prediction, so class scores are
        # computed once rather than after every layer.
        last_layer = self.layers[-1]
        goodness = torch.zeros(x.shape[0], self.K, device=x.device)
        for k in range(self.K):
            goodness[:, k] = h[:, last_layer.support_index[k]].mean(1)
        preds = goodness.argmax(1)

        acc_overall = preds.eq(y).float().mean().item()
        acc_class9 = [
            preds[y == k].eq(y[y == k]).float().mean().item()
            for k in range(self.K)
        ]
        return acc_overall, acc_class9

    def train_net(self, x, y):
        """Train all layers in order, feeding each one the previous layer's output.

        Returns:
            The value returned by the last layer's ``train_layer`` (``x``
            itself when there are no layers).
        """
        h = x
        for i, layer in enumerate(self.layers):
            print(f"Training layer {i} ...")
            h = layer.train_layer(h, y)
        return h

    def add_layer(self, input_dim, out_dim):
        """Append a new ``Layer`` mapping ``input_dim`` -> ``out_dim`` features."""
        new_layer = Layer(input_dim, out_dim, K=self.K,
                          threshold=self.threshold, device=self.device)
        self.layers.append(new_layer.to(self.device))

# Prefer the second GPU when the host has one; otherwise use the first GPU,
# and fall back to the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    device = "cuda:1" if torch.cuda.device_count() > 1 else "cuda:0"
else:
    device = "cpu"

# Load MNIST
# NOTE: `transform` is defined here but is not passed to MNIST_loaders below.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
    transforms.Lambda(lambda x: torch.flatten(x))])

train_loader, test_loader = MNIST_loaders(train_batch_size=50000)


def _layer_accuracy(layer, inputs, targets):
    """Overall accuracy of a single layer, computed without building a graph."""
    with torch.no_grad():
        return layer.predict(inputs, targets)[0]


def _prune_one(layer, inputs, targets):
    """Remove the single weakest neuron of ``layer``."""
    with torch.no_grad():
        layer.delete(inputs, targets, m=1)


def _grow_one(layer, inputs, targets):
    """Add one neuron to the weakest class of ``layer``."""
    layer.growth(inputs, targets, m=1)


def _refine_layer(net, index, inputs, targets, mutate):
    """Keep applying ``mutate`` to layer ``index`` while accuracy does not drop.

    Each round mutates a deep copy of the current layer and adopts the copy
    only if its accuracy on ``(inputs, targets)`` is at least the current one.
    The loop also stops when ``mutate`` leaves the number of neurons unchanged
    (e.g. growth found nothing to fix); otherwise an identical copy would tie
    with the current layer forever.
    """
    best_acc = _layer_accuracy(net.layers[index], inputs, targets)
    while True:
        candidate = copy.deepcopy(net.layers[index])
        rows_before = candidate.weight.size(0)
        mutate(candidate, inputs, targets)
        if candidate.weight.size(0) == rows_before:
            break
        candidate_acc = _layer_accuracy(candidate, inputs, targets)
        if not candidate_acc >= best_acc:
            break
        net.layers[index] = candidate
        best_acc = candidate_acc


# Build net (two layers, each trained, pruned and grown in turn)
dim = [784, 100, 100]

for _ in range(5):
    x, y = next(iter(train_loader))
    x_train, y_train = x.to(device), y.to(device)

    x_te, y_te = next(iter(test_loader))
    x_test, y_test = x_te.to(device), y_te.to(device)
    net = Net(K=10, threshold=2.0, device=device)

    for i in range(len(dim) - 1):
        # Later layers take whatever width the previous layer ended up with
        # after pruning/growth.
        input_dim = dim[i] if i == 0 else x_train.shape[-1]
        net.add_layer(input_dim=input_dim, out_dim=dim[i + 1])

        print(f"Training layer {i} ...")
        net.layers[i].train_layer(x_train, y_train)
        with torch.no_grad():
            acc_all, _ = net.predict(x_test, y_test)
            print(f"Last layer -> Overall Acc: {acc_all:.4f}")

        # Shrink first, then grow, accepting each step only if the training
        # accuracy of the layer does not decrease.
        _refine_layer(net, i, x_train, y_train, _prune_one)
        _refine_layer(net, i, x_train, y_train, _grow_one)

        with torch.no_grad():
            x_train = net.layers[i].forward(x_train)
        print(x_train.shape[1])
        with torch.no_grad():
            acc_all, _ = net.predict(x_test, y_test)
            print(f"Last layer -> Overall Acc: {acc_all:.4f}")
