import sys
sys.path.insert(0, ROOT_PATH)

import os

from transformers import pipeline
import argparse
from tqdm import tqdm
from comnivore.dataloader import MultiEnvDataset
import numpy as np

import torch
from sklearn.metrics import accuracy_score
from wilds import get_dataset

import comnivore.const as const
from text_prompts import text_prompts

# Shared zero-shot classifier, built once when the module is loaded and used
# by the script entry point below to score each text against the candidate
# labels of the selected dataset.
# NOTE(review): device=0 presumably selects the first CUDA device, which
# would make a GPU mandatory for importing this module — confirm.
pipe = pipeline("zero-shot-classification", model="facebook/bart-large-mnli", device=0)

def eval_wilds(preds, test_Y, metadata):
    """Evaluate predictions with the WILDS dataset's own ``eval`` method.

    Each argument that is not already a torch tensor is converted with
    ``torch.Tensor(...)``. The dataset is then loaded through ``get_dataset``
    and its ``eval`` is called with the three tensors.

    Args:
        preds: predicted labels (tensor or anything ``torch.Tensor`` accepts).
        test_Y: ground-truth labels, same convention as ``preds``.
        metadata: per-example metadata, same convention as ``preds``.

    Returns:
        The second element of the pair returned by ``dataset.eval`` (the
        results string).
    """
    def _to_tensor(values):
        # Leave real tensors untouched; wrap everything else.
        return values if torch.is_tensor(values) else torch.Tensor(values)

    preds, test_Y, metadata = (
        _to_tensor(values) for values in (preds, test_Y, metadata)
    )

    # `dataset_name` and `DATA_DIR` are read from module globals.
    # `dataset_name` is assigned by the __main__ block of this script.
    # NOTE(review): `DATA_DIR` is not defined in the visible source — confirm
    # where it is supposed to come from.
    wilds_dataset = get_dataset(dataset=dataset_name, download=True, root_dir=DATA_DIR)
    _metrics, summary = wilds_dataset.eval(preds, test_Y, metadata)
    return summary

def eval_hatexplain(preds, test_Y, metadata):
    """Report average and worst-group accuracy over (metadata column, label) groups.

    A group is the set of rows whose value in one metadata column equals 1
    and whose true label equals one of the distinct values of ``test_Y``.
    Groups with four or fewer members are ignored.

    Args:
        preds: numpy array of predicted labels.
        test_Y: numpy array of true labels.
        metadata: 2-D numpy array; every column is checked against the value 1.

    Returns:
        A two-line string: ``average acc = ...`` and ``worst-group acc = ...``,
        both formatted to three decimals.
    """
    group_scores = []
    n_columns = metadata.shape[1]
    for col in range(n_columns):
        for label in np.unique(test_Y):
            members = np.argwhere((metadata[:, col] == 1) & (test_Y == label))
            # Too-small groups are skipped (five members minimum).
            if len(members) <= 4:
                continue
            group_preds = preds[members]
            group_truth = test_Y[members]
            score = accuracy_score(group_truth, group_preds)
            if score == 0:
                print('acc 0', len(members))
            group_scores.append(score)

    group_scores = np.array(group_scores)
    print(group_scores)
    average_line = f'average acc = {np.mean(group_scores):.3f}'
    worst_line = f'worst-group acc = {np.amin(group_scores):.3f}'
    return average_line + '\n' + worst_line

def eval_gender(preds, test_Y, metadata):
    """Report average and worst-group accuracy, one group per distinct metadata value.

    The distinct predicted values and the per-group accuracies are printed
    as a side effect.

    Args:
        preds: numpy array of predicted labels.
        test_Y: numpy array of true labels.
        metadata: numpy array whose distinct values define the groups.

    Returns:
        A two-line string: ``average acc = ...`` and ``worst-group acc = ...``,
        both formatted to three decimals.
    """
    print(np.unique(preds))

    group_scores = []
    for value in np.unique(metadata):
        members = np.argwhere(metadata == value).flatten()
        if len(members) == 0:
            continue
        group_preds = preds[members]
        group_truth = test_Y[members]
        score = accuracy_score(group_truth, group_preds)
        if score == 0:
            print('acc 0', len(members))
        group_scores.append(score)

    group_scores = np.array(group_scores)
    print(group_scores)
    average_line = f'average acc = {np.mean(group_scores):.3f}'
    worst_line = f'worst-group acc = {np.amin(group_scores):.3f}'
    return average_line + '\n' + worst_line

def eval_amazon(preds, test_Y, metadata, *, group_col=2):
    """Report average and worst-group accuracy, one group per value of a metadata column.

    The group ids are taken from ``metadata[:, group_col]`` and the same
    column is used to select each group's rows, so every group that occurs
    in the data is evaluated exactly once.

    NOTE(review): the previous implementation enumerated group ids from
    column 1 but selected rows by comparing against column 2, which dropped
    every group whose id did not also happen to appear in column 1. The
    default ``group_col=2`` keeps the column that actually decided group
    membership; confirm that this is the intended grouping column.

    Args:
        preds: numpy array of predicted labels.
        test_Y: numpy array of true labels.
        metadata: 2-D numpy array of per-example metadata.
        group_col: index of the metadata column that defines the groups.

    Returns:
        A two-line string: ``average acc = ...`` and ``worst-group acc = ...``,
        both formatted to three decimals.

    Raises:
        ValueError: from ``np.amin`` when ``metadata`` has no rows, because
            there is then no group to take a minimum over.
    """
    group_ids = metadata[:, group_col]
    acc_all = []
    for m in np.unique(group_ids):
        # Ids come from the very column being compared, so no group is empty.
        group_idxs = np.flatnonzero(group_ids == m)
        group_preds = preds[group_idxs]
        group_y = test_Y[group_idxs]
        # Arguments in the documented (y_true, y_pred) order.
        acc_all.append(accuracy_score(group_y, group_preds))
    acc_all = np.array(acc_all)
    results_str = f'average acc = {np.mean(acc_all):.3f}' + '\n' + f'worst-group acc = {np.amin(acc_all):.3f}'
    return results_str


if __name__ == '__main__':
    # Script entry point: run the zero-shot BART classifier over every
    # example of the selected dataset, store predictions / labels / metadata
    # under bart_preds/<dataset>/ and print the group-wise evaluation.
    parser = argparse.ArgumentParser(description='run BART zero shot')
    parser.add_argument('-d', '--dataset', type=str, default='civilcomments')

    args = parser.parse_args()
    # Kept as a module-level name on purpose: eval_wilds reads `dataset_name`
    # as a global.
    dataset_name = args.dataset

    eval_fn_dict = {
        const.CIVILCOMMENTS_NAME: eval_wilds,
        const.HATEXPLAIN_NAME: eval_hatexplain,
        const.AMAZON_NAME: eval_amazon,
        const.GENDER_BIAS_NAME: eval_gender,
    }
    if dataset_name not in eval_fn_dict:
        parser.error(
            f'unsupported dataset {dataset_name!r}; '
            f'expected one of {sorted(eval_fn_dict)}'
        )
    eval_fn = eval_fn_dict[dataset_name]

    dataloaders = MultiEnvDataset().get_dataloaders(dataset_name, 1)
    labels = text_prompts[dataset_name]['labels_bart']
    # Loop-invariant: used to map the top predicted label back to its index.
    label_array = np.array(labels)

    # All three lists grow together across every dataloader so that row k of
    # each one always describes the same example.
    y_true_all = []
    y_pred_all = []
    metadata_all = []
    n_skipped = 0
    for dataloader in dataloaders:
        for j, labeled_batch in enumerate(tqdm(dataloader)):
            try:
                # Batches without metadata fail to unpack here and are
                # skipped: every evaluation function needs one metadata row
                # per prediction.
                x, y_true, metadata = labeled_batch
                metadata_row = metadata.detach().cpu().numpy().tolist()[0]
                y_true = y_true.detach().cpu().numpy().tolist()
                out = pipe(x, candidate_labels=labels, device=0)
                # out['labels'][0] is read as the top-ranked candidate label.
                y_pred = np.argwhere(label_array == out['labels'][0]).flatten()[0]
            except Exception as exc:
                # Best effort: a failing example is reported and skipped.
                # KeyboardInterrupt / SystemExit are deliberately not caught.
                n_skipped += 1
                tqdm.write(f'skipping batch {j}: {exc!r}')
                continue
            # Append only after everything succeeded, so a failure can never
            # leave the three lists with different lengths.
            y_true_all.extend(y_true)
            y_pred_all.append(y_pred)
            metadata_all.append(metadata_row)
    if n_skipped:
        print(f'skipped {n_skipped} batches')

    y_pred_all = np.array(y_pred_all)
    y_true_all = np.array(y_true_all)
    metadata_all = np.array(metadata_all)
    store_dir = os.path.join('bart_preds', dataset_name)
    os.makedirs(store_dir, exist_ok=True)
    np.save(os.path.join(store_dir, 'y_pred.npy'), y_pred_all)
    np.save(os.path.join(store_dir, 'y_true.npy'), y_true_all)
    np.save(os.path.join(store_dir, 'metadata.npy'), metadata_all)
    print(eval_fn(y_pred_all, y_true_all, metadata_all))
    