import torch.nn as nn
from torch.nn import functional as F
import torch


class network(nn.Module):
    """Backbone plus a bias-free linear classifier head that can be grown per task.

    Args:
        numclass: initial number of output classes of the head.
        feature_extractor: backbone module. It must expose ``fc.in_features``;
            the head and ``encoder`` are sized from it, so ``forward`` requires
            the backbone's output to have that many features in its last dim.
        args: run configuration; only ``args.task_classes[0]`` is read here.
    """

    def __init__(self, numclass, feature_extractor, args):
        super(network, self).__init__()
        self.feature = feature_extractor
        # Number of classes kept from earlier tasks; None until the head is
        # first grown by Incremental_learning.
        self.old_class_num = None
        # Class count of the first task, taken from the run configuration.
        self.known_class_num = args.task_classes[0]
        self.fc = nn.Linear(feature_extractor.fc.in_features, numclass, bias=False)
        # Square projection in feature space, applied by fea_encode.
        self.encoder = nn.Linear(feature_extractor.fc.in_features, feature_extractor.fc.in_features)

    def forward(self, input):
        """Return classifier logits: ``fc(feature(input))``."""
        x = self.feature(input)
        x = self.fc(x)
        return x

    @staticmethod
    def _as_center_tensor(glob_centers):
        """Convert a non-empty ``glob_centers`` to a tensor, stacking a sequence of tensors."""
        if isinstance(glob_centers, torch.Tensor):
            return glob_centers.detach()
        if isinstance(glob_centers[0], torch.Tensor):
            # torch.tensor() cannot build a tensor from a list of multi-element tensors.
            return torch.stack([c.detach() for c in glob_centers])
        return torch.as_tensor(glob_centers)

    def Incremental_learning(self, args, numclass, glob_centers):
        """Grow the classifier head to ``numclass`` outputs.

        The first ``numclass - len(glob_centers)`` rows are copied from the current
        head. The remaining rows are initialised from ``glob_centers``, one row per
        new class.

        Args:
            args: unused here; kept for interface compatibility.
            numclass: total number of classes after growing the head.
            glob_centers: rows for the new classes. This may be a 2-D tensor, a
                nested array-like of shape ``(len(glob_centers), in_features)``, or
                a sequence of 1-D tensors. It may be empty.
        """
        old_weight = self.fc.weight.detach()
        in_feature = self.fc.in_features
        self.old_class_num = numclass - len(glob_centers)

        # Create the new head on the old head's device/dtype so a model that was
        # already moved (e.g. to GPU) does not silently end up with a CPU head.
        self.fc = nn.Linear(in_feature, numclass, bias=False).to(
            device=old_weight.device, dtype=old_weight.dtype)
        with torch.no_grad():
            self.fc.weight[:self.old_class_num].copy_(old_weight[:self.old_class_num])
            if len(glob_centers) > 0:
                # copy_ also handles dtype/device conversion of the centers.
                self.fc.weight[self.old_class_num:].copy_(self._as_center_tensor(glob_centers))
        # nn.Module has no ``requires_grad`` attribute; set the flag on its parameters.
        self.fc.requires_grad_(True)

    def feature_extractor(self, inputs):
        """Return backbone features for ``inputs`` (without the classifier head)."""
        fea = self.feature(inputs)
        return fea

    def predict(self, fea_input):
        """Apply only the classifier head to precomputed features."""
        return self.fc(fea_input)

    def old_head(self, out):
        """Return the logit columns of previously learned classes.

        While ``old_class_num`` is None (head never grown), this is every column.
        """
        return out[:, :self.old_class_num]

    def new_head(self, out):
        """Return the logit columns of the newly added classes.

        While ``old_class_num`` is None (head never grown), this is every column.
        """
        return out[:, self.old_class_num:]

    def fea_encode(self, out):
        """Project features through the ``encoder`` linear layer."""
        return self.encoder(out)


class LeNet(nn.Module):
    """Small sigmoid CNN: three 5x5 convolutions feeding one linear classifier.

    Each convolution outputs 12 channels with padding 2. The strides are 2, 2
    and 1, and every convolution is followed by a sigmoid. The flattened
    activations go to ``fc``.

    Args:
        channel: number of input channels.
        hideen: (sic) flattened size of the conv output, i.e. the input width of
            ``fc``. The default 768 = 12 * 8 * 8, which is what a 32x32 input yields.
        num_classes: number of output logits.
    """

    def __init__(self, channel=3, hideen=768, num_classes=10):
        super().__init__()
        # (input channels, stride) for each conv stage, in order.
        stages = ((channel, 2), (12, 2), (12, 1))
        layers = []
        for in_channels, step in stages:
            layers += [
                nn.Conv2d(in_channels, 12, kernel_size=5, padding=2, stride=step),
                nn.Sigmoid(),
            ]
        self.body = nn.Sequential(*layers)
        self.fc = nn.Sequential(nn.Linear(hideen, num_classes))

    def forward(self, x):
        """Return logits of shape ``(batch, num_classes)``."""
        feature_maps = self.body(x)
        batch_size = feature_maps.size(0)
        return self.fc(feature_maps.view(batch_size, -1))


def weights_init(m):
    """Re-initialise ``m.weight`` and ``m.bias`` in place with U(-0.5, 0.5).

    Intended as a best-effort callback, e.g. for ``module.apply(weights_init)``.
    Attributes that are missing, or registered as ``None`` (such as the bias of a
    layer built with ``bias=False``), are skipped silently. Any other failure
    prints a warning instead of raising. The weight is always drawn before the
    bias, so random-number consumption order is stable.

    Args:
        m: the module (or any object) whose ``weight``/``bias`` to initialise.
    """
    for attr in ("weight", "bias"):
        try:
            param = getattr(m, attr, None)
            # hasattr() is True for parameters registered as None; skip those
            # instead of reporting a spurious failure.
            if param is not None:
                param.data.uniform_(-0.5, 0.5)
        except Exception:
            # type(m).__name__ matches Module._get_name() and also works for
            # objects that are not nn.Module instances.
            print('warning: failed in weights_init for %s.%s' % (type(m).__name__, attr))
