import os
import random
import re

from pathlib import Path

import numpy as np
import pandas as pd

import torch
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

from tqdm import tqdm

from backbone.model_resnet import ResNet_50, ResNet_18
from backbone.MobileFaceNets import MobileFaceNet
from data import IdentificationFaceRace, VerificationFaceRace

def l2_norm(input, axis=1):
    """Scale `input` to unit Euclidean (L2) length along dimension `axis`.

    The norm is taken with ``keepdim=True`` so it broadcasts back over `input`.
    No epsilon is added: a slice whose norm is zero comes out as NaN (0 / 0).
    """
    magnitude = input.norm(p=2, dim=axis, keepdim=True)
    return input / magnitude

# Run on CUDA when it is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Seed every RNG the script uses (torch CPU + all CUDA devices, NumPy, Python's
# `random`) and ask cuDNN for deterministic kernels without autotuning, so runs
# are as reproducible as these settings allow.
torch.manual_seed(222)
torch.cuda.manual_seed_all(222)
np.random.seed(222)
random.seed(222)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Spatial size handed to the ResNet constructors below, and the embedding width
# requested from MobileFaceNet.
input_size = [112, 112]
emb_size = 512

# Evaluation preprocessing: resize to 128x128, keep the central crop of
# `input_size` (112x112), convert to a tensor (8-bit images land in [0, 1]),
# then normalise each of the 3 channels with mean 0.5 / std 0.5, i.e. to [-1, 1].
test_transform = transforms.Compose(
    [
        transforms.Resize([128, 128]),
        transforms.CenterCrop(input_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)

# One instance per architecture, keyed by the name that is also used to find
# that architecture's checkpoints on disk. Construction order is unchanged.
backbone_dict = {
    'ResNet_18': ResNet_18(input_size),
    'MobileFaceNet': MobileFaceNet(embedding_size=emb_size, out_h=7, out_w=7),
    'ResNet_50': ResNet_50(input_size),
}

# Question sets read from CSV, with face images from the "aligned" directory,
# all preprocessed with `test_transform`.
ident_data = IdentificationFaceRace("csvs/identification_questions.csv", "aligned", test_transform)
verif_data = VerificationFaceRace("csvs/verification_questions.csv", "aligned", test_transform)

# Default DataLoader settings (batch_size=1, no shuffling): one question per
# step, in dataset order.
identification_dataloader = DataLoader(dataset=ident_data)
verification_dataloader = DataLoader(dataset=verif_data)

# Per-model result folders; exist_ok=True makes reruns harmless.
for _subdir in ("identification", "verification", "predictions"):
    os.makedirs(f"csvs/{_subdir}", exist_ok=True)

def _embed(backbone, images):
    """Embed one question's images with horizontal-flip test-time augmentation.

    `images` comes from a DataLoader with the default ``batch_size=1``, so the
    leading singleton axis is squeezed away and the question's images become the
    model batch. Each image's embedding is summed with that of its mirror along
    dim 3 (presumably the width axis of an (N, C, H, W) tensor -- TODO confirm
    against the dataset classes), then L2-normalised row-wise.
    """
    # The model is moved to `device`, so its inputs must be too; otherwise CUDA
    # runs fail with a device mismatch.
    images = images.squeeze(0).to(device)
    flipped_images = torch.flip(images, (3,))
    return l2_norm(backbone(images) + backbone(flipped_images))


def _run_identification(backbone):
    """Answer every identification question; return one DataFrame row per question.

    Row 0 of each question's embeddings is the probe and rows 1.. form the
    gallery. ``cos_sim`` / ``l2_dist`` hold the 1-based gallery index with the
    highest cosine similarity / smallest Euclidean distance, and ``truth`` is
    the dataset label.
    """
    rows = []
    for qid, images, label in tqdm(identification_dataloader):
        features_batch = _embed(backbone, images)
        target = features_batch[0, :]
        gallery = features_batch[1:, :]
        # Embeddings are unit-norm, so the dot product is the cosine similarity.
        cos_pick = (gallery @ target).argmax().item() + 1
        l2_pick = torch.cdist(target.unsqueeze(0), gallery).squeeze(0).argmin().item() + 1
        rows.append([qid.item(), 'identification', cos_pick, l2_pick, label.item()])
    # Build the frame once instead of growing it with .loc, which is quadratic.
    return pd.DataFrame(rows, columns=['id', 'qtype', 'cos_sim', 'l2_dist', 'truth'])


def _run_verification(backbone):
    """Score every verification pair; return one DataFrame row per question.

    ``cos_sim`` is the dot product of the two normalised embeddings (their
    cosine similarity), ``l2_dist`` their Euclidean distance, and ``truth`` the
    dataset's positive flag.
    """
    rows = []
    for qid, images, positive in tqdm(verification_dataloader):
        features_batch = _embed(backbone, images)
        image_1, image_2 = features_batch[0], features_batch[1]
        rows.append([qid.item(), 'verification', (image_1 @ image_2).item(),
                     image_1.dist(image_2).item(), positive.item()])
    return pd.DataFrame(rows, columns=['id', 'qtype', 'cos_sim', 'l2_dist', 'truth'])


def _best_threshold(scores, labels, predict):
    """Pick the observed score that, used as a threshold, maximises accuracy.

    `predict(scores, threshold)` must return a 0/1 integer array. Every observed
    score is tried as a candidate (O(n^2) overall). Ties go to the earliest
    candidate, as with ``argmax``. Returns ``(threshold, predictions)``.
    """
    hits = np.array([(predict(scores, t) == labels).sum() for t in scores])
    threshold = scores[hits.argmax()]
    return threshold, predict(scores, threshold)


for backbone_name, backbone in backbone_dict.items():
    # Path.glob yields files in arbitrary order; sort for a deterministic run.
    for checkpoint_path in sorted(Path('models').glob(f"*{backbone_name}*.pth")):
        # The method tag is the 7 letters that follow "<backbone_name>_" in the file name.
        match = re.search(re.escape(backbone_name) + r"_([a-zA-Z]{7})", checkpoint_path.name)
        if match is None:
            print(f"Skipping {checkpoint_path}: no 7-letter method tag after '{backbone_name}_'")
            continue
        method = match.group(1)

        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints
        # (torch >= 1.13 also accepts weights_only=True for plain state dicts).
        checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
        backbone.load_state_dict(checkpoint)
        backbone = backbone.to(device)
        backbone.eval()

        # ---- identification ----
        print(("-" * 40) + f" On ({backbone_name}, {method}): identification " + ("-" * 40))

        with torch.no_grad():  # pure inference: don't build autograd graphs
            ident_df = _run_identification(backbone)

        ident_df.to_csv(f'csvs/identification/{backbone_name}-{method}.csv', index=False)

        truth = ident_df['truth'].to_numpy()
        cos_accuracy = (ident_df['cos_sim'].to_numpy() == truth).mean()
        l2_accuracy = (ident_df['l2_dist'].to_numpy() == truth).mean()
        print(f"[Cosine Similarity] Accuracy: {cos_accuracy*100:.2f}")
        print(f"[L2 Distance] Accuracy: {l2_accuracy*100:.2f}\n")

        # ---- verification ----
        print(("-" * 40) + f"On ({backbone_name}, {method}): verification ------------------" + ("-" * 40))

        with torch.no_grad():
            verif_df = _run_verification(backbone)

        # Written before the prediction columns are added, as the raw-score file.
        verif_df.to_csv(f'csvs/verification/{backbone_name}-{method}.csv', index=False)

        labels = verif_df['truth'].to_numpy().astype(int)

        # Cosine similarity: predict "same" (1) when the score exceeds the threshold.
        cos_scores = verif_df['cos_sim'].to_numpy()
        split_val, cos_pred = _best_threshold(cos_scores, labels, lambda s, t: (s > t).astype(int))
        verif_df['cos_pred'] = cos_pred
        accuracy = (cos_pred == labels).mean()
        print(f"[Cosine Similarity] Best accuracy: {accuracy*100:.2f}, 0 if Value <= {split_val:.2f} else 1")

        # L2 distance: predict "same" (1) when the distance is at most the threshold.
        l2_scores = verif_df['l2_dist'].to_numpy()
        split_val, l2_pred = _best_threshold(l2_scores, labels, lambda s, t: (s <= t).astype(int))
        verif_df['l2_pred'] = l2_pred
        accuracy = (l2_pred == labels).mean()
        print(f"[L2 Distance] Best accuracy: {accuracy*100:.2f}, 1 if Value <= {split_val:.2f} else 0\n")

        # Final answers: identification rows (cosine pick) followed by verification
        # rows (thresholded cosine prediction). Built directly rather than via
        # chained assignment, which does not write through under pandas Copy-on-Write.
        preds_df = pd.concat(
            [
                ident_df[['id', 'qtype', 'cos_sim']].rename(columns={'cos_sim': 'prediction'}),
                verif_df[['id', 'qtype', 'cos_pred']].rename(columns={'cos_pred': 'prediction'}),
            ],
            ignore_index=True,
        )
        preds_df.to_csv(f'csvs/predictions/{backbone_name}-{method}.csv', index=False)
