import torch
import numpy as np
from sklearn.metrics import roc_auc_score
import torch.nn.functional as F
from tqdm import tqdm
import torch.optim as optim
from dataset import CUB, Cars3D, RaFD, CelebA
import faiss
import argparse
import os
from torchvision.models import resnet18, ResNet18_Weights

class LinearModel(torch.nn.Module):
    """Single-logit linear head that also returns its input unchanged.

    ``forward`` maps a batch of feature vectors to one logit per sample and
    hands the input features back, so the same module can serve both the
    linear-probe (logits) and the KNN (features) evaluations.
    """

    def __init__(self, device, in_features=512):
        """Create the head.

        Args:
            device: accepted for call-site compatibility; not used by this
                class (move the module with ``.to(device)`` instead).
            in_features: width of the input feature vectors (default 512).
        """
        super().__init__()
        self.fc = torch.nn.Linear(in_features, 1)

    def forward(self, features):
        """Return ``(logits, features)``.

        ``features`` is expected to be ``(batch, in_features)``; ``logits`` is
        the single output column of ``self.fc``, i.e. shape ``(batch,)``.
        """
        return self.fc(features)[:, 0], features


def run_epoch(model, train_loader, optimizer, criterion, device):
    """Run one training pass over ``train_loader`` and return the mean loss.

    Args:
        model: module whose forward returns ``(logits, features)``; only the
            logits are used for the loss.
        train_loader: iterable of ``(inputs, labels)`` batches.
        optimizer: optimizer over ``model``'s parameters, stepped once per batch.
        criterion: loss called as ``criterion(logits, float_targets)``
            (e.g. ``torch.nn.BCEWithLogitsLoss``).
        device: device the inputs and targets are moved to.

    Returns:
        Sample-weighted mean of the per-batch losses, or ``0.0`` when the
        loader yields no samples.
    """
    total_loss, total_num = 0.0, 0
    for imgs, labels in train_loader:
        imgs = imgs.to(device)
        # Cast through long (integral class ids) and then to float, because
        # BCE-style losses require floating-point targets.
        targets = labels.long().to(device).float()

        optimizer.zero_grad()
        logits, _ = model(imgs)
        loss = criterion(logits, targets)
        loss.backward()
        optimizer.step()

        batch_size = imgs.size(0)
        total_num += batch_size
        # Assumes the criterion averages over the batch (reduction='mean');
        # re-weighting keeps the epoch mean per-sample with a short last batch.
        total_loss += loss.item() * batch_size

    # An empty loader would otherwise raise ZeroDivisionError.
    return total_loss / total_num if total_num else 0.0

def run_test(model, test_loader, device):
    """Score every test sample with the model's logit and return the AUROC.

    The model's forward must return ``(logits, features)``; the features are
    ignored here. Labels are collected as produced by ``test_loader`` and are
    compared against the raw logits with ``roc_auc_score``.
    """
    score_batches = []
    label_batches = []
    with torch.no_grad():
        for batch_inputs, batch_labels in test_loader:
            label_batches.append(batch_labels)
            batch_scores, _ = model(batch_inputs.to(device))
            score_batches.append(batch_scores)

        scores = torch.cat(score_batches, 0).detach().cpu().numpy()
        targets = torch.cat(label_batches, 0).detach().cpu().numpy()
        return roc_auc_score(targets, scores)


# Dataset name (CLI value) -> dataset class.
DATASETS = {'cars3d': Cars3D, 'rafd': RaFD, 'celeba': CelebA}
# Upper bound on how many attribute values are evaluated per run.
MAX_VALUES = 3
# Number of nearest training neighbours whose distances are summed for KNN scoring.
KNN_K = 10
# Training epochs for the linear head.
NUM_EPOCHS = 100


def _make_loader(dataset, shuffle):
    """Wrap ``dataset`` in a DataLoader with the script's fixed batching settings."""
    return torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=shuffle,
                                       num_workers=2, drop_last=False)


def _build_datasets(dataset_cls, value, args):
    """Return ``(train_knn, test_knn, train, test)`` datasets for one attribute value."""
    common = dict(value=value, attribute=args.attribute,
                  setting=args.setting, method=args.method)
    return (dataset_cls(split='train', knn=True, **common),
            dataset_cls(split='test', knn=True, **common),
            dataset_cls(split='train', **common),
            dataset_cls(split='test', **common))


def _extract_features(model, loader, device, desc):
    """Run ``model`` over ``loader`` and return ``(features, labels)`` as numpy arrays."""
    feature_batches, label_batches = [], []
    with torch.no_grad():
        for imgs, labels in tqdm(loader, desc=desc):
            _, features = model(imgs.to(device))
            feature_batches.append(features)
            label_batches.append(labels)
    features = torch.cat(feature_batches, 0).cpu().numpy()
    labels = torch.cat(label_batches, 0).cpu().numpy()
    return features, labels


def _knn_auc(model, train_loader, test_loader, device, k=KNN_K):
    """AUROC of the summed squared-L2 distances to the ``k`` nearest training features."""
    train_feats, _ = _extract_features(model, train_loader, device, 'Train KNN...')
    test_feats, test_labels = _extract_features(model, test_loader, device, 'Test KNN...')
    index = faiss.IndexFlatL2(train_feats.shape[1])
    index.add(train_feats.astype(np.float32))
    # Never ask for more neighbours than the index holds; faiss would fill the
    # missing slots with sentinel values that distort the summed distance.
    D, _ = index.search(test_feats.astype(np.float32), min(k, train_feats.shape[0]))
    return roc_auc_score(test_labels, np.sum(D, axis=1))


def _train_fc(model, train_loader, test_loader, device, attribute, value, epochs=NUM_EPOCHS):
    """Train the linear head with SGD, logging test AUROC each epoch; return the last one."""
    criterion = torch.nn.BCEWithLogitsLoss()
    optimizer = optim.SGD(model.parameters(), lr=1e-2)
    auc_logits = float('nan')
    for epoch in tqdm(range(epochs), desc='Epoch'):
        run_epoch(model, train_loader, optimizer, criterion, device)
        auc_logits = run_test(model, test_loader, device)
        print('Attribute: {}, Value: {}, Epoch: {}, AUROC Logits: {}'.format(
            attribute, value, epoch + 1, auc_logits * 100))
    return auc_logits


def main():
    """Parse CLI arguments and run KNN and linear-probe evaluation per attribute value."""
    parser = argparse.ArgumentParser(description='')
    parser.add_argument('--dataset', default='cars3d', choices=sorted(DATASETS))
    parser.add_argument('--setting', default='multi')
    parser.add_argument('--method', default=None)
    parser.add_argument('--attribute', default=0, type=int)
    # NOTE(review): parsed but not used below; every value in
    # range(num_values) is evaluated instead.
    parser.add_argument('--value', default=0, type=int)
    parser.add_argument('--num_values', default=MAX_VALUES, type=int,
                        help='number of attribute values to evaluate (capped at {})'.format(MAX_VALUES))
    args = parser.parse_args()
    print(args.dataset, args.setting, args.method, args.attribute)

    attribute = args.attribute
    num_of_values = min(args.num_values, MAX_VALUES)
    dataset_cls = DATASETS[args.dataset]
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    aucs_fc = []
    aucs_knn = []
    for i in range(num_of_values):
        trainset_knn, testset_knn, trainset, testset = _build_datasets(dataset_cls, i, args)
        # A fresh head per value so results are independent across values.
        model = LinearModel(device).to(device).eval()

        train_loader_knn = _make_loader(trainset_knn, shuffle=True)
        test_loader_knn = _make_loader(testset_knn, shuffle=False)
        train_loader = _make_loader(trainset, shuffle=True)
        test_loader = _make_loader(testset, shuffle=False)
        print(len(trainset), len(testset))

        auc = _knn_auc(model, train_loader_knn, test_loader_knn, device)
        print('Attribute: {}, Value: {}, Epoch: {}, AUROC: {}'.format(attribute, i, 0, auc * 100))
        aucs_knn.append(auc * 100)

        auc_logits = _train_fc(model, train_loader, test_loader, device, attribute, i)
        aucs_fc.append(auc_logits * 100)
        print('Attribute: {}, Value: {}, knn: {}, fc: {}'.format(attribute, i, aucs_knn[-1], aucs_fc[-1]))

    print('Attribute: {}, knn mean: {}, fc mean: {}'.format(attribute, np.mean(aucs_knn), np.mean(aucs_fc)))


# The guard is required because DataLoader workers (num_workers=2) re-import
# this module under the 'spawn' start method (Windows/macOS).
if __name__ == '__main__':
    main()