import torch
import torch.nn as nn
from copy import deepcopy
import clip

class Net(nn.Module):
    """Linear classifier whose output layer can grow as new classes appear.

    The model is a single ``nn.Linear`` (``self.fc``). ``make_head`` rebuilds
    that layer with one output per registered class name while keeping the
    previously learned weights.
    """

    def __init__(self, in_dim, out_dim, bias=True):
        """Create the initial linear layer.

        Args:
            in_dim: Size of the input feature vector.
            out_dim: Initial number of outputs of ``self.fc``.
            bias: Whether the linear layer (and every rebuilt one) has a bias.
        """
        super(Net, self).__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.bias = bias

        self.fc = nn.Linear(in_dim, out_dim, bias=bias)

        # Class names registered through make_head, in insertion order.
        self.seen_classes = []

    def forward(self, x, normalize=False):
        """Compute class scores for ``x``.

        Args:
            x: Input features whose last dimension is ``in_dim``.
            normalize: If True, return 100 times the cosine similarity between
                ``x`` and each weight row; the bias is not applied in this
                mode. If False, return the plain linear layer output.

        Returns:
            Tensor of scores with one entry per output unit of ``self.fc``.
        """
        if not normalize:
            return self.fc(x)

        unit_x = x / x.norm(dim=-1, keepdim=True)
        unit_w = self.fc.weight / self.fc.weight.norm(dim=-1, keepdim=True)
        return 100 * unit_x @ unit_w.T

    def make_head(self, new_dim, c, clip_init=None):
        """Register class ``c`` and rebuild ``self.fc`` to cover all seen classes.

        Does nothing when ``c`` has already been registered. Otherwise the new
        layer gets ``len(self.seen_classes)`` outputs, the previous weights
        (and bias, if enabled) are copied into its first ``self.out_dim``
        rows, and the remaining entries keep ``nn.Linear``'s default
        initialisation unless ``clip_init`` is given.

        Args:
            new_dim: Unused; kept for interface compatibility. The original
                note on this method says it must be 1.
            c: Class name to register.
            clip_init: Optional model exposing ``encode_text``. When given,
                the last weight row is set to the text feature of
                ``"a photo of a {c}"``, so that feature must have ``in_dim``
                elements.
        """
        if c in self.seen_classes:
            return

        self.seen_classes.append(c)
        self.total_dim = len(self.seen_classes)

        # A plain local reference is enough to read the old parameters; there
        # is no need to deep-copy the layer or register it as a submodule.
        old_fc = self.fc
        old_out = self.out_dim
        device = old_fc.weight.device
        # Keep the dtype as well as the device so a model converted with
        # .double()/.half() stays usable after the head has been rebuilt.
        dtype = old_fc.weight.dtype

        new_fc = nn.Linear(self.in_dim, self.total_dim, self.bias)
        new_fc = new_fc.to(device=device, dtype=dtype)

        # All writes below are initialisation, not part of any training graph.
        with torch.no_grad():
            new_fc.weight[:old_out, :] = old_fc.weight

            if clip_init is not None:
                tokens = clip.tokenize(f"a photo of a {c}").to(device)
                # Cast and move in one step (no detour through the CPU); under
                # no_grad no graph through the text encoder is recorded.
                text_feature = clip_init.encode_text(tokens).to(
                    device=device, dtype=dtype
                )
                new_fc.weight[-1, :] = text_feature

            if self.bias:
                new_fc.bias[:old_out] = old_fc.bias

        self.fc = new_fc
        self.out_dim = self.total_dim


