import math
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

from .utils.weights_init import weights_init_classifier, weights_init_kaiming
from .model_base_f import BasicModel

class CosClassifier(nn.Module):
    """Cosine-similarity classification head.

    Holds one bias-free linear layer whose weight rows act as per-class
    prototype vectors. ``forward`` L2-normalises both the incoming embeddings
    and the weight rows and returns their dot products, i.e. the cosine of
    the angle between every embedding and every class vector.

    Args:
        embedding_size: Width of the incoming embedding vectors.
        num_class: Number of classes (rows of the weight matrix).
        **model_kwargs: Accepted for call-site compatibility; not used here.
    """

    def __init__(self, embedding_size=1024, num_class=51332, **model_kwargs):
        super().__init__()
        self.embedding_size = embedding_size
        self.classnum = num_class
        self.fc = nn.Linear(embedding_size, self.classnum, bias=False)
        # Project helper; applies the classifier-specific weight init to fc.
        weights_init_classifier(self.fc)

    def forward(self, embeddings):
        """Return the cosine similarity of each embedding to each class vector.

        Args:
            embeddings: Tensor whose last dimension is ``embedding_size``.
                Any number of leading dimensions (including none) is
                accepted.

        Returns:
            Tensor with the same leading dimensions as ``embeddings`` and a
            last dimension of ``num_class``, clamped to
            ``[-1 + epsilon, 1 - epsilon]``.
        """
        # NOTE(review): in float32, 1 - 1e-8 is likely not representable and
        # rounds to 1.0, so this margin may effectively clamp to [-1, 1] only.
        # Confirm whether a dtype-aware epsilon is wanted before changing it.
        epsilon = 1e-8
        kernel_norm = F.normalize(self.fc.weight, p=2, dim=1)
        n_embeddings = F.normalize(embeddings, p=2, dim=-1)
        # F.linear computes n_embeddings @ kernel_norm.T; unlike torch.mm it
        # is not restricted to 2-D inputs, matching the dim=-1 normalisation.
        cos_theta = F.linear(n_embeddings, kernel_norm)
        return cos_theta.clamp(-1 + epsilon, 1 - epsilon)


class FModel(nn.Module):
    """Backbone feature extractor followed by global average pooling and BN.

    ``forward`` runs the backbone, average-pools the resulting feature map
    over its full spatial extent and applies ``f_global_bn``.

    The remaining layers built in ``__init__`` (``fc_combine``, ``bn_final``,
    ``relu``, ``gelu`` and the optional ``f_mlp_fc_*`` / ``bn_*`` head) are
    not used by ``forward`` in this class; they are still constructed, in the
    original order, so the module's parameters and state_dict keys are
    unchanged.

    Args:
        embedding_size: Output width of the optional two-layer MLP head
            (``feature_in_dim -> 4096 -> embedding_size``). ``0`` skips
            building the head. Original author note: for softmax loss the
            embedding size is the class number; for arcface, maybe 2048.
        extract_feature_model: One of ``'pure'``, ``'pcb'``, ``'fpn_f'``.
            Only ``'pure'`` has a backbone wired up in this class.
        **model_kwargs: For ``'pure'``, must contain ``'name'``, which is
            passed to ``BasicModel``.

    Raises:
        AssertionError: If ``extract_feature_model`` is not a known name.
        NotImplementedError: If ``extract_feature_model`` is ``'pcb'`` or
            ``'fpn_f'``, for which no backbone is constructed here.
        KeyError: If ``'name'`` is missing from ``model_kwargs``.
    """

    def __init__(self, embedding_size=1024, extract_feature_model='pure', **model_kwargs):
        super().__init__()
        self.embedding_size = embedding_size
        assert extract_feature_model in ['pure', 'pcb', 'fpn_f'], 'Please check the F model name!'
        if extract_feature_model != 'pure':
            # Previously these names passed the check above and then failed
            # with an AttributeError on the first use of self.feature_in_dim.
            raise NotImplementedError(
                "extract_feature_model=%r is not implemented; only 'pure' is supported"
                % (extract_feature_model,)
            )

        name = model_kwargs['name']
        self.f_model = BasicModel(name=name)
        self.feature_in_dim = self.f_model.feature_in_dim

        # Layers applied to the pooled backbone features.
        self.f_global_bn = nn.BatchNorm2d(self.feature_in_dim).apply(weights_init_kaiming)
        self.relu = nn.RReLU(inplace=True)
        self.gelu = nn.GELU()
        self.f_avg_pool = F.avg_pool2d
        self.fc_combine = nn.Linear(self.feature_in_dim, self.feature_in_dim)
        self.bn_final = nn.BatchNorm1d(self.feature_in_dim).apply(weights_init_kaiming)
        weights_init_kaiming(self.fc_combine)

        if self.embedding_size != 0:
            self.f_mlp_fc_1 = nn.Linear(self.feature_in_dim, 4096)
            self.bn_1 = nn.BatchNorm1d(4096).apply(weights_init_kaiming)
            self.f_mlp_fc_2 = nn.Linear(4096, self.embedding_size)
            self.bn_2 = nn.BatchNorm1d(self.embedding_size).apply(weights_init_kaiming)
            # NOTE(review): feature_in_dim now reports embedding_size, but
            # forward() below still returns features with the backbone's
            # channel count. Confirm that callers expect this.
            self.feature_in_dim = self.embedding_size
            weights_init_classifier(self.f_mlp_fc_1)
            weights_init_classifier(self.f_mlp_fc_2)

    def forward(self, batch_imgs):
        """Extract pooled, batch-normalised features from a batch of images.

        Args:
            batch_imgs: Input batch passed unchanged to the backbone
                (B C H W according to the original comment).

        Returns:
            Tuple ``(embedding_feature, compress_feature)``:
            ``compress_feature`` is the batch-normalised pooled map with
            spatial size 1x1, and ``embedding_feature`` is a view of the same
            data reshaped to its first two dimensions.
        """
        raw_img_feature = self.f_model(batch_imgs)
        # Average-pool over the whole spatial extent:
        # B feature_dim H W -> B feature_dim 1 1
        pooled = self.f_avg_pool(raw_img_feature, raw_img_feature.size()[2:])
        compress_feature = self.f_global_bn(pooled)
        # Embedding feature (original note: used to calculate the l1 loss).
        embedding_feature = compress_feature.view(pooled.size()[:2])
        return embedding_feature, compress_feature