import sys
sys.path.insert(0, ROOT_PATH)
import os
import argparse

import numpy as np
import torch

from wilds import get_dataset
from openai.embeddings_utils import cosine_similarity

import comnivore.const as const
from sklearn.metrics import accuracy_score

from scipy.spatial.distance import cosine
from tqdm import tqdm

def eval_wilds(preds, test_Y, metadata):
    """Score predictions with the evaluation routine of a WILDS dataset.

    The dataset is obtained through ``get_dataset`` (with ``download=True``)
    and its ``eval`` method is called with the three arguments unchanged.

    NOTE(review): ``dataset_name`` and ``DATA_DIR`` are read from module
    scope. ``dataset_name`` is assigned only in the ``__main__`` section and
    ``DATA_DIR`` is not defined in the visible part of this file -- confirm
    both exist before calling this function.

    Args:
        preds: Predicted labels, forwarded to ``dataset.eval``.
        test_Y: Ground-truth labels, forwarded to ``dataset.eval``.
        metadata: Per-sample metadata, forwarded to ``dataset.eval``.

    Returns:
        The second element of the pair returned by ``dataset.eval``.
    """
    wilds_dataset = get_dataset(
        dataset=dataset_name,
        download=True,
        root_dir=DATA_DIR,
    )
    _metrics, summary = wilds_dataset.eval(preds, test_Y, metadata)
    return summary

def eval_hatexplain(preds, test_Y, metadata):
    """Compute average and worst-group accuracy over (attribute, label) groups.

    A group is formed for every metadata column ``i`` and every distinct
    label ``y``: it contains the samples with ``metadata[:, i] == 1`` and
    ``test_Y == y``. Groups with 4 or fewer samples are skipped.

    Args:
        preds: Predicted labels (torch tensor or numpy array), one per sample.
        test_Y: Ground-truth labels (torch tensor or numpy array). Tensors
            are converted to an integer numpy array.
        metadata: 2-D array/tensor indexed as ``metadata[:, i]``.

    Returns:
        A two-line string ``'average acc = ...\\nworst-group acc = ...'``.
        Both values are reported as ``nan`` when no group is large enough.
    """
    print('ALL', len(test_Y))
    if torch.is_tensor(preds):
        preds = preds.detach().cpu().numpy()
    if torch.is_tensor(test_Y):
        test_Y = test_Y.detach().cpu().numpy().astype(int)
    if torch.is_tensor(metadata):
        metadata = metadata.detach().cpu().numpy()
    acc_all = []
    labels = np.unique(test_Y)
    for i in range(metadata.shape[1]):
        has_attribute = metadata[:, i] == 1
        for y in labels:
            # Flatten so the index is 1-D (as in the sibling eval functions);
            # an unflattened argwhere result would select column vectors.
            group_idx = np.argwhere(has_attribute & (test_Y == y)).flatten()
            if len(group_idx) > 4:
                print(len(group_idx))
                preds_group = preds[group_idx]
                y_true_group = test_Y[group_idx]
                acc_group = accuracy_score(y_true_group, preds_group)
                acc_all.append(acc_group)
    if not acc_all:
        # np.amin raises on an empty array; report nan instead of crashing.
        return 'average acc = nan' + '\n' + 'worst-group acc = nan'
    acc_all = np.array(acc_all)
    results_str = f'average acc = {np.mean(acc_all):.3f}' + '\n' + f'worst-group acc = {np.amin(acc_all):.3f}'
    return results_str

def eval_gender(preds, test_Y, metadata):
    """Compute average and worst per-label accuracy.

    Samples are grouped by their ground-truth label; accuracy is computed
    within each group, and the mean and minimum over groups are reported.
    ``metadata`` is converted to numpy when it is a tensor but is otherwise
    not used for grouping.

    Args:
        preds: Predicted labels (torch tensor or numpy array).
        test_Y: Ground-truth labels (torch tensor or numpy array).
        metadata: Per-sample metadata (torch tensor or numpy array).

    Returns:
        A two-line string ``'average acc = ...\\nworst-group acc = ...'``.
    """
    print('ALL', len(test_Y))

    def _as_numpy(value):
        # Tensors are detached and moved to host memory; anything else passes through.
        return value.detach().cpu().numpy() if torch.is_tensor(value) else value

    preds = _as_numpy(preds)
    test_Y = _as_numpy(test_Y)
    metadata = _as_numpy(metadata)

    per_label_acc = []
    for label in np.unique(test_Y):
        members = np.argwhere(test_Y == label).ravel()
        if len(members) == 0:
            continue
        label_acc = accuracy_score(test_Y[members], preds[members])
        print(len(members))
        if label_acc == 0:
            print('zero acc', len(members))
        per_label_acc.append(label_acc)

    per_label_acc = np.array(per_label_acc)
    mean_line = f'average acc = {np.mean(per_label_acc):.3f}'
    worst_line = f'worst-group acc = {np.amin(per_label_acc):.3f}'
    return mean_line + '\n' + worst_line

def eval_amazon(preds, test_Y, metadata):
    """Compute average and worst-group accuracy over metadata-defined groups.

    Samples are grouped by the value in metadata column 2. The group ids are
    enumerated from that same column, so every group that occurs in the data
    is evaluated exactly once.

    NOTE(review): column 2 is used because it is the column that selects the
    samples of each group -- confirm it is the intended grouping attribute.

    Args:
        preds: Predicted labels (torch tensor or numpy array).
        test_Y: Ground-truth labels (torch tensor or numpy array). Tensors
            are converted to an integer numpy array.
        metadata: 2-D array/tensor with at least 3 columns.

    Returns:
        A two-line string ``'average acc = ...\\nworst-group acc = ...'``.
        Both values are reported as ``nan`` when there are no groups.
    """
    print('ALL', len(test_Y))
    if torch.is_tensor(preds):
        preds = preds.detach().cpu().numpy()
    if torch.is_tensor(test_Y):
        test_Y = test_Y.detach().cpu().numpy().astype(int)
    if torch.is_tensor(metadata):
        metadata = metadata.detach().cpu().numpy()
    acc_all = []
    # Enumerate and match group ids on the same column; mixing two columns
    # would skip groups or evaluate unrelated ids.
    group_ids = metadata[:, 2]
    for m in np.unique(group_ids):
        group_idxs = np.argwhere(group_ids == m).flatten()
        group_preds = preds[group_idxs]
        group_y = test_Y[group_idxs]
        print(len(group_y))
        # accuracy_score expects (y_true, y_pred), matching the sibling evals.
        acc_group = accuracy_score(group_y, group_preds)
        acc_all.append(acc_group)
    if not acc_all:
        # np.amin raises on an empty array; report nan instead of crashing.
        return 'average acc = nan' + '\n' + 'worst-group acc = nan'
    acc_all = np.array(acc_all)
    results_str = f'average acc = {np.mean(acc_all):.3f}' + '\n' + f'worst-group acc = {np.amin(acc_all):.3f}'
    return results_str

def label_score(text_embedding, label_embeddings):
    """Return the similarity margin of a text embedding between two labels.

    Args:
        text_embedding: Embedding of one text sample.
        label_embeddings: 2-D array whose rows 0 and 1 are label embeddings.

    Returns:
        ``cosine_similarity`` to row 1 minus ``cosine_similarity`` to row 0;
        a positive value means the text is closer to label 1.
    """
    sim_to_label_1 = cosine_similarity(text_embedding, label_embeddings[1, :])
    sim_to_label_0 = cosine_similarity(text_embedding, label_embeddings[0, :])
    return sim_to_label_1 - sim_to_label_0
    
def get_preds_cos(emb_all, label_emb):
    """Predict binary labels from the sign of the cosine-similarity margin.

    Each row of ``emb_all`` is scored with ``label_score``; a sample is
    predicted as 1 when its score is strictly positive and 0 otherwise.

    Args:
        emb_all: 2-D array of text embeddings, one row per sample.
        label_emb: 2-D array whose rows 0 and 1 are the label embeddings.

    Returns:
        A pair ``(y_pred, scores)``: ``y_pred`` is a 1-D ``torch.Tensor`` of
        0/1 predictions and ``scores`` is a numpy array of shape ``(n, 1)``
        holding the per-sample margins.
    """
    n_samples = emb_all.shape[0]
    # Keep the margins 1-D while thresholding. Thresholding an (n, 1) array
    # through argwhere(...).flatten() mixes row indices with the column index
    # 0, which wrongly marks sample 0 as positive.
    margins = np.asarray(
        [label_score(emb_all[i, :], label_emb) for i in range(n_samples)]
    ).reshape(-1)
    y_pred = torch.Tensor((margins > 0).astype(np.float64))
    # Callers receive the scores as a column vector, one row per sample.
    scores = margins.reshape(-1, 1)
    return y_pred, scores

def get_preds_simcse(emb_all, label_emb):
    """Predict, for each text embedding, the index of the closest label.

    Closeness is measured with ``scipy.spatial.distance.cosine``, which is a
    *distance* (``1 - cosine similarity``), so the predicted label is the one
    with the smallest value.

    Args:
        emb_all: 2-D array of text embeddings, one row per sample.
        label_emb: 2-D array of label embeddings, one row per label.

    Returns:
        A pair ``(y_pred, [])``: ``y_pred`` is a 1-D ``torch.Tensor`` of
        predicted label indices; the second element is always an empty list.
    """
    def pred_simcse(text_embedding, label_embeddings):
        distances = np.array([
            cosine(text_embedding, label_embeddings[k, :])
            for k in range(label_embeddings.shape[0])
        ])
        # Smallest cosine distance == largest cosine similarity.
        return int(np.argmin(distances))

    y_pred_all = [
        pred_simcse(emb_all[i, :], label_emb) for i in range(emb_all.shape[0])
    ]
    return torch.Tensor(y_pred_all), []

def get_preds_hf(emb_all, label_emb):
    """Predict, for each text embedding, the label with the largest dot product.

    Args:
        emb_all: 2-D array of text embeddings, one row per sample.
        label_emb: 2-D array of label embeddings, one row per label.

    Returns:
        A pair ``(y_pred, [])``: ``y_pred`` is a 1-D ``torch.Tensor`` of
        predicted label indices; the second element is always an empty list.
    """
    def best_label(row):
        # Dot product of one sample against every label embedding.
        dot_products = row @ label_emb.T
        return int(np.argmax(dot_products))

    predictions = [best_label(emb_all[idx, :]) for idx in range(emb_all.shape[0])]
    return torch.Tensor(predictions), []

def remove_error(text_emb, y_true, metadata, error_batch_idx=1148, batch_size=64):
    """Drop the samples of one batch from embeddings, labels and metadata.

    The removed rows are ``[error_batch_idx * batch_size,
    (error_batch_idx + 1) * batch_size)``. When that range lies beyond the
    end of the inputs, nothing is removed.

    Args:
        text_emb: 2-D numpy array of embeddings, one row per sample.
        y_true: 1-D torch tensor of labels.
        metadata: 2-D torch tensor of metadata, one row per sample.
        error_batch_idx: Index of the batch to remove (default 1148).
        batch_size: Number of samples per batch (default 64).

    Returns:
        The tuple ``(text_emb, y_true, metadata)`` with the batch removed.
    """
    batch_start = error_batch_idx * batch_size
    # Use batch_size for the end of the range too, so both bounds stay in sync.
    batch_stop = batch_start + batch_size
    text_emb = np.vstack((text_emb[:batch_start], text_emb[batch_stop:]))
    y_true = torch.hstack((y_true[:batch_start], y_true[batch_stop:]))
    metadata = torch.vstack((metadata[:batch_start], metadata[batch_stop:]))
    return text_emb, y_true, metadata

def get_class_balance(test_Y):
    """Return the fractions of labels equal to 0 and equal to 1.

    Args:
        test_Y: Labels as a torch tensor (converted to an integer numpy
            array) or a numpy array.

    Returns:
        A tuple ``(p_0, p_1)`` of the share of entries equal to 0 and to 1,
        each relative to ``len(test_Y)``.
    """
    labels = (
        test_Y.detach().cpu().numpy().astype(int)
        if torch.is_tensor(test_Y)
        else test_Y
    )
    zeros = labels[labels == 0]
    ones = labels[labels == 1]
    frac_zero = len(zeros) / len(labels)
    frac_one = len(ones) / len(labels)
    return frac_zero, frac_one

# Script entry point: load precomputed embeddings for a dataset/model pair,
# produce zero-shot predictions and print the evaluation summary.
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='run CLIP zero shot')
    parser.add_argument('-d', '--dataset', type=str, default='civilcomments')
    parser.add_argument('-m', '--model', type=str, default='hf_sim')
    args = parser.parse_args()

    # dataset_name becomes a module-level name here; eval_wilds reads it.
    dataset_name = args.dataset
    model = args.model
    # NOTE(review): root_dir is not used anywhere in the visible code.
    root_dir = f'{dataset_name}_features/'
    load_dir = f'{dataset_name}_features/features_{model}'

    # Sub-directories of load_dir whose name does not contain 'text'.
    subdirs = [os.path.join(load_dir, p) for p in os.listdir(load_dir) if (os.path.isdir(os.path.join(load_dir, p))) and ('text' not in p)]

    # Dataset name -> evaluation function.
    eval_fn_dict = {
        const.CIVILCOMMENTS_NAME: eval_wilds,
        const.HATEXPLAIN_NAME: eval_hatexplain,
        const.AMAZON_NAME: eval_amazon,
        const.GENDER_BIAS_NAME: eval_gender,
    }

    # Model name -> prediction function; an unknown model raises KeyError below.
    pred_fn = {
        'simcse': get_preds_simcse,
        'hf_sim': get_preds_hf,
        'openai': get_preds_cos,
    }
    pred = pred_fn[model]
    eval_fn = eval_fn_dict[dataset_name]
    for i, subdir in enumerate(subdirs):
        # labels.npy lives in load_dir (not in subdir), so the same file is
        # reloaded on every iteration.
        label_emb = np.load(os.path.join(load_dir, 'labels.npy'))
        text_emb = np.load(os.path.join(subdir, 'emb.npy'))
        
        y_true = torch.Tensor(np.load(os.path.join(subdir, 'y.npy')))
        metadata = np.load(os.path.join(subdir, 'metadata.npy'))

        if dataset_name == const.AMAZON_NAME:
            # Binarize the labels: drop samples labelled 2, then map
            # {0, 1} -> 0 and {3, 4} -> 1.
            # NOTE(review): presumably labels 0-4 are star ratings -- confirm.
            y_true = y_true.squeeze()
            text_emb = text_emb.squeeze()
            take_idx = np.argwhere(y_true != 2).flatten()
            text_emb = text_emb[take_idx, :]
            y_true = y_true[take_idx]
            metadata = metadata[take_idx, :]
            y_true[np.argwhere((y_true ==0)|(y_true==1)).flatten()] = 0
            y_true[np.argwhere((y_true ==3)|(y_true==4)).flatten()] = 1
        
        metadata = torch.Tensor(metadata)
        y_pred, _ = pred(text_emb, label_emb)

        print("========= Baseline ZS =========")
        results = eval_fn(y_pred, y_true, metadata)
        print(results)
        # NOTE(review): this exit() ends the program after the first subdir,
        # so the rest of the loop body below is unreachable and later subdirs
        # are never evaluated. Confirm whether it is a debugging leftover.
        exit()
        # Entries of load_dir whose name contains 'text' (no isdir check).
        text_subdirs = [os.path.join(load_dir, p) for p in os.listdir(load_dir) if 'text' in p]
        for t_idx, t_sub in enumerate(text_subdirs):
            # Skip entries that do not provide projected embeddings.
            if 'projected_emb_single.npy'not in os.listdir(t_sub):
                continue
            print(t_sub)
            projected_emb = np.load(os.path.join(t_sub, 'projected_emb_single.npy'))
            if dataset_name == const.AMAZON_NAME:
                # Apply the same sample filter that was used for y_true/metadata.
                projected_emb = projected_emb[take_idx, :]
            y_pred, _ = pred(projected_emb, label_emb)
            print(f"========= Ours =========")
            results = eval_fn(y_pred, y_true, metadata)
            print(results)