import torch
import numpy as np
from dataset import CUB, Cars3D, RaFD, CelebA
from tqdm import tqdm
from torchvision.models import resnet18
import argparse
import os

# Command-line configuration. `choices` on --dataset rejects unsupported names
# up front; the extraction code below only knows how to build these three.
parser = argparse.ArgumentParser(
    description='Extract ResNet-18 features for the training split of a dataset.')
parser.add_argument('--dataset', default='cars3d', choices=['cars3d', 'rafd', 'celeba'],
                    help='dataset whose training split is featurized')
parser.add_argument('--setting', default='multi',
                    help='experiment setting; used as a directory component of the '
                         'checkpoint and feature paths')
parser.add_argument('--method', default=None,
                    help='method whose per-value checkpoints are loaded; omit to use '
                         'the plain ImageNet backbone')
parser.add_argument('--attribute', default=0, type=int,
                    help='attribute index; selects files named a{attribute}_v{value}')
args = parser.parse_args()


class INModel(torch.nn.Module):
    """ImageNet-pretrained ResNet-18 used as a feature extractor.

    The final fully-connected layer is replaced with ``Identity``, so the
    network outputs the backbone features that would have fed ``fc`` instead
    of class scores.
    """

    def __init__(self, device=None):
        """Build the backbone with ImageNet weights.

        Args:
            device: Unused; accepted for call-site compatibility. Move the
                module with ``.to(device)`` instead.
        """
        super().__init__()
        self.model = resnet18(weights='IMAGENET1K_V1', progress=False)
        self.model.fc = torch.nn.Identity()

    def forward(self, x):
        """Return ``(logits, features)`` for a batch of images ``x``.

        ``fc`` is ``Identity``, so the backbone returns a single tensor with
        one feature row per input; there is no classification head, hence
        ``logits`` is always ``None``. The 2-tuple shape is kept so callers
        can keep indexing ``[1]`` for features.
        """
        # The backbone returns one tensor, not a (logits, features) pair;
        # unpacking it would split along the batch dimension.
        features = self.model(x)
        return None, features

def _build_dataset(name):
    """Return the full training split of dataset *name*.

    ``value`` and ``attribute`` are pinned to 0 because ``all=True`` is
    passed; presumably the dataset classes ignore them in that mode
    (TODO confirm in dataset.py).

    Raises:
        ValueError: if *name* is not a supported dataset.
    """
    if name == 'cars3d':
        return Cars3D(split='train', value=0, attribute=0, all=True)
    if name == 'rafd':
        return RaFD(split='train', value=0, attribute=0, all=True)
    if name == 'celeba':
        return CelebA(split='train', value=0, attribute=0, all=True)
    # Fail fast with a clear message instead of an unbound-variable error later.
    raise ValueError('unsupported dataset {!r}; expected cars3d, rafd or celeba'.format(name))


def _load_model(device, ckpt_path=None):
    """Build a fresh INModel, optionally overlay a checkpoint, and set eval mode.

    A new model is built for every checkpoint so that parameters absent from
    one checkpoint (``strict=False``) fall back to ImageNet weights rather
    than to values loaded from a previous checkpoint.
    """
    model = INModel(device)
    if ckpt_path is not None:
        # map_location lets checkpoints saved on GPU load on a CPU-only host.
        state_dict = torch.load(ckpt_path, map_location=device)
        # strict=False tolerates checkpoint keys INModel lacks (presumably a
        # classifier head — TODO confirm), but report backbone parameters the
        # checkpoint did not provide so a key mismatch is not silent.
        result = model.load_state_dict(state_dict, strict=False)
        if result.missing_keys:
            print('Warning: {} model parameters not found in {} (e.g. {}); '
                  'they keep their ImageNet weights.'.format(
                      len(result.missing_keys), ckpt_path, result.missing_keys[0]))
    model = model.to(device)
    model.eval()
    return model


def _extract_features(model, loader, device):
    """Run *model* over every batch of *loader*; return the features as a NumPy array.

    Uses element ``[1]`` of the model's output tuple; rows follow loader order.
    """
    chunks = []
    with torch.no_grad():
        for imgs, _labels in tqdm(loader, desc='Train set feature extracting'):
            chunks.append(model(imgs.to(device))[1])
    return torch.cat(chunks, dim=0).contiguous().cpu().numpy()


def _main():
    """Extract and save train-set features for the configuration in ``args``.

    Without ``--method`` the ImageNet backbone is run once and saved to
    ``./features/<dataset>/<setting>/imagenet_resnet18.npy``. With
    ``--method`` each checkpoint ``a<attribute>_v<value>.pth`` for value in
    0..2 is loaded and its features saved to
    ``./features/<dataset>/<setting>/<method>/a<attribute>_v<value>.npy``.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # The dataset and loader do not depend on `value`, so build them once.
    loader = torch.utils.data.DataLoader(_build_dataset(args.dataset), batch_size=64,
                                         shuffle=False, num_workers=2, drop_last=False)
    # Create output dirs before the expensive extraction so permission
    # problems surface early.
    base_dir = './features/{}/{}/'.format(args.dataset, args.setting)
    os.makedirs(base_dir, exist_ok=True)

    if args.method is None:
        # Plain ImageNet features don't depend on `value`: extract a single time.
        features = _extract_features(_load_model(device), loader, device)
        print(features.shape)
        np.save(base_dir + 'imagenet_resnet18.npy', features)
        print('ALL')
        return

    method_dir = '{}{}/'.format(base_dir, args.method)
    os.makedirs(method_dir, exist_ok=True)
    for value in range(3):
        ckpt_path = './checkpoints/{}/{}/{}/a{}_v{}.pth'.format(
            args.dataset, args.setting, args.method, args.attribute, value)
        features = _extract_features(_load_model(device, ckpt_path), loader, device)
        print(features.shape)
        np.save('{}a{}_v{}.npy'.format(method_dir, args.attribute, value), features)


# Guard required: with num_workers > 0 and the "spawn" start method
# (default on Windows/macOS), DataLoader workers re-import this module, which
# would otherwise re-execute the extraction and fail while bootstrapping.
if __name__ == '__main__':
    _main()

