
import os
import torch
import numpy as np
from sklearn.ensemble import RandomForestClassifier
import random
import pandas as pd
from joblib import dump, load
import argparse
import warnings
# Silence every Python warning for the whole run (sklearn/torch UserWarnings included).
warnings.filterwarnings("ignore")
parser = argparse.ArgumentParser(
    description="Fit random-forest probes on CLIP embeddings, select the most "
                "important embedding dimensions for debiasing, and evaluate on "
                "Flickr and FACET.")
parser.add_argument('--target', nargs='+', default=['image'],
                    help="Encoders to debias. This script acts on 'image' and/or "
                         "'text'; the list is also forwarded to the evaluation "
                         "functions via args.")
parser.add_argument('--base', default='ViT-B/32', type=str,
                    help="CLIP backbone name; '/' and '-' are stripped from it to "
                         "build the data/ and weight/ file names.")
parser.add_argument('--image_num', default=50, type=int,
                    help="Number of highest-importance image-embedding dimensions "
                         "to select (>= 0).")
parser.add_argument('--text_num', default=50, type=int,
                    help="Number of highest-importance text-embedding dimensions "
                         "to select (>= 0).")
parser.add_argument('--t', default=0.7, type=float,
                    help="Confidence threshold: validation samples whose maximum "
                         "predicted class probability is below this value are "
                         "treated as low-confidence.")
args = parser.parse_args()
# A negative count would turn the top-k slice `[-n:]` into `[n:]` and silently
# select an unrelated set of dimensions, so reject it up front.
if args.image_num < 0 or args.text_num < 0:
    parser.error("--image_num and --text_num must be non-negative")

clip_model = args.base
# File-name-safe form of the backbone name, e.g. 'ViT-B/32' -> 'ViTB32'.
clip_name = clip_model.replace("/", "").replace('-', '')
threshold = args.t
device = "cuda" if torch.cuda.is_available() else "cpu"
# NOTE(review): imported here rather than with the other imports; keep this
# position unless `evaluation` is confirmed to have no import-order dependency.
from evaluation import evaluate_facet, evaluate_flickr

# Seed the Python, NumPy and torch RNGs so runs are reproducible.
seed = 0
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)
torch.backends.cudnn.deterministic = True
# The cuDNN flag is `benchmark`; the previous `benchmarks` spelling only created
# an unused attribute and left the real setting untouched.
torch.backends.cudnn.benchmark = False
# PYTHONHASHSEED is read at interpreter start-up, so setting it here does not
# change str hashing in this process; it only applies to child processes.
os.environ['PYTHONHASHSEED'] = str(seed)

# Per-encoder debiasing statistics passed to the evaluation functions.
# Each pair stays None when its encoder is not listed in --target.
img_important_indices = None
img_mean_features_misclassified = None

text_important_indices = None
text_mean_features_misclassified = None


def _load_split(train_path, val_path):
    """Return (X_train, y_train, X_test, y_test) read from two .npz archives.

    The archives are opened with ``with`` so their file handles are closed
    once the arrays have been read into memory.
    """
    with np.load(train_path) as train, np.load(val_path) as val:
        return train['X_train'], train['y_train'], val['X_test'], val['y_test']


def _load_or_fit_forest(model_path, X_train, y_train):
    """Load the RandomForestClassifier cached at model_path, or fit and cache one."""
    if os.path.exists(model_path):
        clf = load(model_path)
        print("Load pretrained Random Forest.")
        return clf
    clf = RandomForestClassifier(n_estimators=100)
    clf.fit(X_train, y_train)
    # dump() does not create parent directories; without this the first run
    # fails right after the (expensive) fit when 'weight/' is missing.
    model_dir = os.path.dirname(model_path)
    if model_dir:
        os.makedirs(model_dir, exist_ok=True)
    dump(clf, model_path)
    return clf


def _debias_statistics(clf, X_test, pruning_num):
    """Return (important_indices, mean_low_confidence_features) on ``device``.

    important_indices: int64 tensor with the indices of the ``pruning_num``
        features of largest ``clf.feature_importances_``, in ascending
        importance order (empty when pruning_num == 0).
    mean_low_confidence_features: float32 tensor with the column-wise mean of
        the rows of X_test whose maximum predicted class probability is below
        the module-level ``threshold``.

    Raises:
        ValueError: if pruning_num is negative.
    """
    if pruning_num < 0:
        raise ValueError(f"pruning_num must be non-negative, got {pruning_num}")

    probabilities = clf.predict_proba(X_test)
    low_confidence_samples = X_test[probabilities.max(axis=1) < threshold]
    if len(low_confidence_samples) == 0:
        # torch.mean over zero rows yields an all-NaN vector; say so instead of
        # letting it propagate silently (warnings are globally ignored).
        print(f"Warning: no validation sample has max probability < {threshold}; "
              "the mean low-confidence feature vector will be NaN.")
    mean_features = torch.mean(torch.tensor(low_confidence_samples).float(), dim=0)

    order = np.argsort(clf.feature_importances_)
    # Slice from an explicit start index: `order[-pruning_num:]` would return
    # the WHOLE array for pruning_num == 0 because -0 == 0.
    start = max(order.size - pruning_num, 0)
    important_indices = torch.tensor(order[start:])
    return important_indices.to(device), mean_features.to(device)


if 'image' in args.target:
    print("Debias Image Encoder")
    X_train, y_train, X_test, y_test = _load_split(
        f'data/fairface_{clip_name}_train.npz',
        f'data/fairface_{clip_name}_val.npz')
    img_clf = _load_or_fit_forest(
        f'weight/{clip_name}_img_random_forest_model.joblib', X_train, y_train)
    img_important_indices, img_mean_features_misclassified = _debias_statistics(
        img_clf, X_test, int(args.image_num))

if 'text' in args.target:
    print("Debias Text Encoder")
    X_train, y_train, X_test, y_test = _load_split(
        f'data/bios_{clip_name}_train.npz',
        f'data/bios_{clip_name}_val.npz')
    text_clf = _load_or_fit_forest(
        f'weight/{clip_name}_text_random_forest_model.joblib', X_train, y_train)
    text_important_indices, text_mean_features_misclassified = _debias_statistics(
        text_clf, X_test, int(args.text_num))

# Both benchmarks receive the same debiasing statistics (None for any encoder
# that was not debiased).
debias_kwargs = dict(
    img_important_indices=img_important_indices,
    img_mean_features_misclassified=img_mean_features_misclassified,
    text_important_indices=text_important_indices,
    text_mean_features_misclassified=text_mean_features_misclassified,
)
banner = "*" * 10

print(banner, "Evaluate Flickr", banner)
recall1, recall5, recall10, flickrskew = evaluate_flickr(
    args, clip_model, device, clip_name, **debias_kwargs)

print(banner, "Evaluate FACET", banner)
acc, mean_dp = evaluate_facet(
    args, clip_model, device, clip_name, **debias_kwargs)
