import os
import time
import pickle
import numpy as np
import random
from random import shuffle
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import LabelBinarizer
from sklearn.metrics.pairwise import cosine_similarity


class APIVR(nn.Module):
    """Two independent ``Mapper`` branches, one per modality ('i' and 'v').

    Args:
        dim_i: Feature dimension handed to the 'i' branch's ``Mapper``.
        dim_v: Feature dimension handed to the 'v' branch's ``Mapper``.
    """

    def __init__(self, dim_i, dim_v):
        super().__init__()
        # Construction order is kept (i first, then v) so parameter
        # registration order and random initialisation are unchanged.
        self.i_mapper = Mapper(dim_i, 'i')
        self.v_mapper = Mapper(dim_v, 'v')

    def forward(self, i, v):
        """Map each input through its own branch; return ``(i_out, v_out)``."""
        i_embedding = self.i_mapper(i)
        v_embedding = self.v_mapper(v)
        return i_embedding, v_embedding


class Mapper(nn.Module):
    """Stack of ``nn.Linear`` layers projecting one modality's features to 64 dims.

    Args:
        dim: Size of the last axis of the input features; used as the input
            width of the first layer. (Previously this argument was ignored and
            the input width was hard-coded to 128 for 'i' and 4096 for 'v';
            passing those values yields the exact same architecture.)
        modal: ``'i'`` or ``'v'``; selects the hidden-layer widths.

    Raises:
        ValueError: If ``modal`` is not ``'i'`` or ``'v'``.
    """

    # Output widths of each successive Linear layer, per modality.
    _WIDTHS = {'i': (100, 80, 64), 'v': (500, 200, 64)}

    def __init__(self, dim, modal):
        super().__init__()
        # Explicit check instead of ``assert`` so validation survives ``python -O``.
        if modal not in self._WIDTHS:
            raise ValueError(f"modal must be 'i' or 'v', got {modal!r}")
        sizes = (dim, *self._WIDTHS[modal])
        # NOTE(review): there is no activation between the layers, so the
        # stack composes to a single affine map — confirm this is intended.
        # Layers are created in the same order as before, so the state_dict
        # keys (mapper.0/1/2) and seeded initialisation are unchanged.
        self.mapper = nn.Sequential(
            *(nn.Linear(n_in, n_out) for n_in, n_out in zip(sizes[:-1], sizes[1:]))
        )

    def forward(self, feature):
        """Apply the linear stack to ``feature`` (last axis must equal ``dim``)."""
        return self.mapper(feature)


if __name__ == "__main__":
    from thop import profile

    # Two-branch model: 128-d input for the 'i' branch, 4096-d for 'v'.
    model = APIVR(128, 4096)

    # Random dummy batches: a single 'i' vector and 100 'v' vectors.
    # NOTE(review): the batch sizes differ (1 vs 100); the branches are
    # independent so this runs, but confirm it is the intended scenario.
    i_batch = torch.randn(1, 128)
    v_batch = torch.randn(100, 4096)

    # thop's ``profile`` returns a multiply-accumulate (MAC) count and a
    # parameter count; the MAC count is doubled below to report FLOPs.
    macs, n_params = profile(model, inputs=(i_batch, v_batch))
    print('FLOPs = {:.2f}G'.format(2 * macs / 1000**3))
    print('Params = {:.2f}M'.format(n_params / 1000**2))
