import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from networks import feature_extractors, utils

class FullNet_RidgeRegressor(nn.Module):
    """Backbone followed by a ridge-regression head fitted on every forward pass.

    The backbone embeds a support set ``xs``. A linear map ``W`` from those
    features to the targets ``ys`` is solved in closed form (L2-regularised
    least squares) and applied to the backbone features of the query ``x``.
    The head itself holds no parameters; only ``backbone`` is a submodule.

    Args:
        backbone: Callable/module mapping a batch of inputs to 2-D features
            ``(batch, d)``.
        reg: Default regularisation strength added to the diagonal of the
            normal equations.
    """

    def __init__(self, backbone, reg=1e-6):
        super().__init__()
        self.backbone = backbone
        self.reg = reg

    def forward(self, x, xs, ys, reg=None):
        """Fit ridge regression on ``(xs, ys)`` and return predictions for ``x``.

        Args:
            x: Query inputs passed through ``self.backbone``.
            xs: Support inputs. Their features ``fs`` must be 2-D ``(n, d)``.
            ys: Targets with ``n`` rows, e.g. ``(n, c)``. Presumably one-hot
                labels; TODO confirm with callers.
            reg: Regularisation override. ``None`` means use ``self.reg``.

        Returns:
            ``f @ W``, where ``f = self.backbone(x)`` and
            ``W = (fs^T fs + reg * I_d)^{-1} fs^T ys``.
        """
        if reg is None:
            reg = self.reg
        fs = self.backbone(xs)
        n, d = fs.shape
        if n < d:
            # Push-through identity: (F^T F + r I_d)^{-1} F^T = F^T (F F^T + r I_n)^{-1}.
            # With fewer support samples than feature dims, this solves a small
            # n x n system instead of a d x d one. The d x d system is
            # rank-deficient apart from the tiny `reg` term.
            gram = fs @ fs.t()
            eye = torch.eye(n, dtype=gram.dtype, device=gram.device)
            W = fs.t() @ torch.linalg.solve(gram + reg * eye, ys)
        else:
            gram = fs.t() @ fs
            eye = torch.eye(d, dtype=gram.dtype, device=gram.device)
            # solve(A, B) avoids forming inverse(A) explicitly: faster and more accurate.
            W = torch.linalg.solve(gram + reg * eye, fs.t() @ ys)
        f = self.backbone(x)
        return f @ W

def get_full_network(feature, head, opts):
    """Build a backbone + head network.

    Args:
        feature: Backbone identifier forwarded to
            ``feature_extractors.get_backbone``.
        head: Head type. Only ``'ridge'`` is supported.
        opts: Options object. For the ridge head, ``opts.ridge_reg`` is used
            as the default regularisation strength.

    Returns:
        ``(full_net, derived_head)``. ``derived_head`` is ``True`` for the
        ridge head, whose weights are solved from the support set inside
        ``forward`` rather than stored as parameters. Presumably callers use
        it to skip head-specific training; confirm at call sites.

    Raises:
        NotImplementedError: If ``head`` is not a supported head type.
    """
    # Validate before building the backbone so an unsupported head fails fast
    # without paying for backbone construction.
    if head != 'ridge':
        raise NotImplementedError(
            f"Unsupported head {head!r}; only 'ridge' is implemented."
        )
    feat_ext = feature_extractors.get_backbone(feature)
    full_net = FullNet_RidgeRegressor(feat_ext, opts.ridge_reg)
    derived_head = True
    return full_net, derived_head

