import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicBlock_s(nn.Module):
    """Residual block made of two 3x3 conv + batch-norm stages.

    The main branch is ``conv1 -> bn1 -> ReLU -> conv2 -> bn2``. Its result is
    added to a shortcut of the input and passed through a final ReLU.

    Args:
        in_planes: Number of input channels.
        planes: Number of channels produced by both convolutions.
        stride: Stride of the first convolution (and of the projection
            shortcut, when one is needed).
    """

    # Ratio between the block's output channels and ``planes``.
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        # Main branch: only the first convolution may change the resolution.
        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)

        # The shortcut is the identity (an empty Sequential) when the input
        # already has the output's shape; otherwise a 1x1 conv + batch-norm
        # projects it to the right channel count and resolution.
        shape_preserved = stride == 1 and in_planes == out_planes
        if shape_preserved:
            shortcut_layers = []
        else:
            shortcut_layers = [
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            ]
        self.shortcut = nn.Sequential(*shortcut_layers)

    def forward(self, x):
        """Return ``ReLU(main_branch(x) + shortcut(x))``."""
        hidden = self.conv1(x)
        hidden = F.relu(self.bn1(hidden))
        residual = self.bn2(self.conv2(hidden))
        residual += self.shortcut(x)
        return F.relu(residual)

class ResNet_s(nn.Module):
    """ResNet-style classifier with a 3x3 stem and four residual stages.

    Args:
        block: Callable ``block(in_planes, planes, stride)`` returning an
            ``nn.Module``; it must expose an integer ``expansion`` attribute
            giving its output-channel multiplier.
        num_blocks: Sequence of four ints, the number of blocks per stage.
        num_classes: Size of the final linear layer's output.

    Note:
        The head applies ``F.avg_pool2d(out, 4)`` and feeds the flattened
        result to a linear layer with ``512 * block.expansion`` inputs, so the
        last stage must produce a 4x4 map (presumably 32x32 inputs with blocks
        that downsample by ``stride`` -- confirm against the blocks used).
    """

    def __init__(self, block, num_blocks, num_classes=10):
        super().__init__()
        # Running channel count, advanced by _make_layer as stages are built.
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512 * block.expansion, num_classes)

        # Not used inside this class; kept for external code that may set it.
        self.gradients = None

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage: the first block uses ``stride``, the rest use 1."""
        layers = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            layers.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def _pooled_features(self, x):
        """Run stem, stages and average pooling; return flattened features."""
        out = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)
        out = F.avg_pool2d(out, 4)
        return out.view(out.size(0), -1)

    def forward(self, x, return_feat=False):
        """Compute class logits for ``x``.

        Args:
            x: Input image batch.
            return_feat: When true, also return a detached copy of the
                pre-classifier features.

        Returns:
            ``logits`` or, if ``return_feat`` is true, ``(logits, features)``.
        """
        feat = self._pooled_features(x)
        out = self.linear(feat)
        if return_feat:
            return out, feat.detach().clone()
        return out

    @torch.no_grad()
    def get_features(self, x):
        """Return the flattened pre-classifier features, without autograd."""
        return self._pooled_features(x)

    def get_ntk_features(self, x, y, proj_dim=64, microbatch_size=8,
                         normalize_per_param=True, seed=0):
        """Random projections of per-sample gradients of the true-class logit.

        For each sample, the gradient of its true-class logit with respect to
        every trainable parameter is flattened and multiplied by a random
        ``(proj_dim, n_params)`` matrix. One backward pass is run per sample.
        The model is put in eval mode for the computation and its previous
        mode is restored afterwards, including when an error is raised.

        Args:
            x: Input batch; moved to the parameters' device per microbatch.
            y: Integer class labels, one per sample (any shape, flattened).
            proj_dim: Number of random projection directions.
            microbatch_size: Samples per forward pass.
            normalize_per_param: When true, each direction is scaled to unit
                L2 norm separately within every parameter tensor's segment.
            seed: Seed for the projection directions; ``None`` draws them
                from the global RNG instead.

        Returns:
            A ``(B, proj_dim)`` float32 tensor on the CPU. All zeros when the
            model has no trainable parameters.

        Raises:
            ValueError: If the number of labels differs from the batch size.
        """
        batch = x.size(0)
        y = y.reshape(-1).long()

        # Checked before touching the module mode so this path has no side effects.
        params_used = [p for p in self.parameters() if p.requires_grad]
        if not params_used:
            return torch.zeros(batch, proj_dim, dtype=torch.float32)

        if y.numel() != batch:
            # A short `y` would otherwise leave trailing output rows uninitialised.
            raise ValueError(
                f"expected {batch} labels to match the batch size, got {y.numel()}"
            )

        device = params_used[0].device
        dtype = params_used[0].dtype
        sizes = [p.numel() for p in params_used]
        total = sum(sizes)

        # The generator must be passed to the sampler for `seed` to have any effect.
        gen = None
        if seed is not None:
            gen = torch.Generator(device=device)
            gen.manual_seed(int(seed))
        proj = torch.randn(proj_dim, total, generator=gen, device=device, dtype=dtype)

        if normalize_per_param:
            off = 0
            for n in sizes:
                seg = proj[:, off:off + n]
                norms = torch.linalg.vector_norm(seg, dim=1, keepdim=True)
                # Leave all-zero segments untouched instead of dividing by zero.
                seg.div_(norms.masked_fill(norms == 0, 1.0))
                off += n

        # Convert once; the projection itself is always done in float32.
        proj = proj.to(torch.float32)

        feats_dev = torch.empty(batch, proj_dim, device=device, dtype=torch.float32)
        g_flat = torch.empty(total, device=device, dtype=torch.float32)  # scratch

        was_training = self.training
        self.eval()
        try:
            # enable_grad lets this work even when the caller is under no_grad.
            with torch.enable_grad():
                for start in range(0, batch, microbatch_size):
                    end = min(start + microbatch_size, batch)
                    xb = x[start:end].to(device, non_blocking=True)
                    yb = y[start:end].to(device, non_blocking=True)

                    # One forward pass per microbatch, one backward per sample.
                    logits = self.forward(xb)
                    true_logit = logits.gather(1, yb[:, None]).squeeze(1)
                    mb = true_logit.size(0)

                    for i in range(mb):
                        grads = torch.autograd.grad(
                            outputs=true_logit[i],
                            inputs=params_used,
                            # The graph is shared by the microbatch; free it on the last sample.
                            retain_graph=(i < mb - 1),
                            create_graph=False,
                            allow_unused=True,
                        )

                        with torch.no_grad():
                            off = 0
                            for grad, n in zip(grads, sizes):
                                seg = g_flat[off:off + n]
                                if grad is None:
                                    # Parameter did not contribute to this logit.
                                    seg.zero_()
                                else:
                                    seg.copy_(grad.reshape(-1))
                                off += n
                            feats_dev[start + i] = torch.mv(proj, g_flat)

                    del logits, true_logit
        finally:
            self.train(was_training)

        return feats_dev.cpu()

    def classifier(self, feat):
        """Apply only the final linear layer to precomputed features."""
        return self.linear(feat)
