import sys
sys.path.insert(0, ROOT_PATH)
import os
import argparse

from shared import project_equidistant, project_equidistant_multi

import numpy as np
import torch

from wilds import get_dataset
from snorkel.labeling.model import LabelModel


from dataloader import MultiEnvDataset
from get_clip_text_emb import get_text_embedding
from chatgpt_reprompting import get_z_prompts
from openLM_reprompting import get_z_prompts_openLM
import const
from UWS import ContinuousLabelModel
from tqdm import tqdm

from sklearn.metrics import accuracy_score, f1_score

# Select CUDA device index 1 as the current device for this process.
# NOTE(review): the index is hard-coded and this runs at import time; presumably
# it fails on hosts that do not expose a second CUDA device — confirm.
torch.cuda.set_device(1)

def eval_wilds(preds, test_Y):
    """Score predictions with the WILDS dataset's own ``eval`` and print the summary.

    Reads the module-level names ``load_dir`` (directory holding
    ``metadata.npy``), ``dataset_name`` and ``DATA_DIR``.
    NOTE(review): ``DATA_DIR`` is not defined in the visible part of this
    file — confirm where it comes from.

    Args:
        preds: predictions, forwarded unchanged to ``dataset.eval``.
        test_Y: ground-truth labels; converted with ``torch.Tensor`` when it
            is not already a tensor.

    Returns:
        The results string produced by ``dataset.eval``.
    """
    labels = test_Y if torch.is_tensor(test_Y) else torch.Tensor(test_Y)
    metadata_path = os.path.join(load_dir, 'metadata.npy')
    group_metadata = torch.Tensor(np.load(metadata_path))
    wilds_dataset = get_dataset(dataset=dataset_name, download=False, root_dir=DATA_DIR)
    # dataset.eval returns a pair; only the second element (the summary) is used.
    results_str = wilds_dataset.eval(preds, labels, group_metadata)[1]
    print(results_str)
    return results_str

def eval_domainbed(y_pred, y_true, logits):
    """Print average and worst accuracy over every non-empty (domain, class) group.

    Domain ids are loaded from ``metadata.npy`` under the module-level
    ``load_dir``; ``dataset_name`` is also read from module scope.

    Args:
        y_pred: predicted labels (tensor or array).
        y_true: ground-truth labels (tensor or array).
        logits: accepted for a uniform evaluator signature; not used here.
    """
    # Instantiated exactly as before; only the disabled per-group print needs it.
    dataset = MultiEnvDataset().dataset_dict[dataset_name]()
    y_pred = y_pred.detach().cpu().numpy() if torch.is_tensor(y_pred) else y_pred
    y_true = y_true.detach().cpu().numpy() if torch.is_tensor(y_true) else y_true

    metadata = np.load(os.path.join(load_dir, 'metadata.npy'))
    if metadata.ndim > 1:
        metadata = metadata.flatten()

    group_accs = []
    for domain in np.unique(metadata):
        in_domain = metadata == domain
        for y in np.unique(y_true):
            in_group = in_domain & (y_true == y)
            # Groups without any sample do not contribute to the statistics.
            if not in_group.any():
                continue
            group_accs.append(accuracy_score(y_true[in_group], y_pred[in_group]))
            # print(f'{dataset.metadata_map_reverse[domain]} y = {y}: {group_accs[-1]:.3f}')

    group_accs = np.array(group_accs)
    print(f"AVG acc = {np.mean(group_accs):.3f}")
    print(f"WORST group acc = {np.amin(group_accs):.3f}")
    print('\n')

def eval_synthetic(y_pred, y_true, logits):
    """Print per-(domain, class) accuracy, its mean and minimum, and macro F1.

    Domain ids are loaded from ``metadata.npy`` under the module-level
    ``load_dir``; the dataset object (looked up via the module-level
    ``dataset_name``) supplies ``class_to_idx`` and ``metadata_map_reverse``
    for the printed names.

    Args:
        y_pred: predicted labels (tensor or array).
        y_true: ground-truth labels (tensor or array).
        logits: accepted for a uniform evaluator signature; not used here.
    """
    dataset = MultiEnvDataset().dataset_dict[dataset_name]()
    if torch.is_tensor(y_pred):
        y_pred = y_pred.detach().cpu().numpy()
    if torch.is_tensor(y_true):
        y_true = y_true.detach().cpu().numpy()
    metadata = np.load(os.path.join(load_dir, 'metadata.npy'))
    # Same normalisation as eval_domainbed: a 2-D metadata array would
    # otherwise broadcast against the 1-D labels when building the group mask.
    if metadata.ndim > 1:
        metadata = metadata.flatten()
    print('per domain acc')
    idx2class = {v: k for k, v in dataset.class_to_idx.items()}
    acc_all = []
    for domain in np.unique(metadata):
        for y in np.unique(y_true):
            d_sample_idx = np.argwhere((metadata == domain) & (y_true == y))
            # Skip (domain, class) combinations that have no samples, as
            # eval_domainbed does, instead of scoring an empty selection.
            if len(d_sample_idx) == 0:
                continue
            samples_y_pred = y_pred[d_sample_idx]
            samples_y_true = y_true[d_sample_idx]
            domain_acc = accuracy_score(samples_y_true, samples_y_pred)
            print(f'{dataset.metadata_map_reverse[domain]} y = {idx2class[y]} acc: {domain_acc:.3f}')
            acc_all.append(domain_acc)
    print(f'Average acc = {np.mean(acc_all):.3f}')
    print(f'Worst acc = {np.amin(acc_all):.3f}')
    f1 = f1_score(y_true, y_pred, average='macro')
    print(f'F1 = {f1:.3f}')
    print('\n')
    
def eval_cxr(y_pred, y_true, logits):
    """Print the mean and the minimum of the per-class accuracies.

    Args:
        y_pred: predicted labels (tensor or array).
        y_true: ground-truth labels (tensor or array).
        logits: converted to numpy like the other inputs but not used in the
            computation.
    """
    def _as_numpy(value):
        # Tensors are detached and moved to CPU; anything else passes through.
        return value.detach().cpu().numpy() if torch.is_tensor(value) else value

    logits = _as_numpy(logits)
    y_pred = _as_numpy(y_pred)
    y_true = _as_numpy(y_true)

    per_class_acc = []
    for label in np.unique(y_true):
        of_class = y_true == label
        per_class_acc.append(accuracy_score(y_true[of_class], y_pred[of_class]))
    per_class_acc = np.array(per_class_acc)

    print(f'avg acc = {np.mean(per_class_acc):.3f}')
    print(f'wg acc = {np.amin(per_class_acc):.3f}')
    print('\n')


def make_clip_preds(image_features, text_features):
    """Classify image embeddings by scaled cosine similarity to text embeddings.

    Both inputs are L2-normalised along the last dimension, the similarities
    are multiplied by 100 and turned into probabilities with a softmax over
    the text axis.

    The normalisation is done out of place, so tensors passed in by the
    caller are left unmodified.

    Args:
        image_features: image embeddings, one row per sample (tensor or
            anything ``torch.Tensor`` accepts).
        text_features: text embeddings, one row per class/prompt (tensor or
            anything ``torch.Tensor`` accepts).

    Returns:
        ``(preds, text_probs)``: the argmax index per sample and the full
        probability matrix.
    """
    if not torch.is_tensor(image_features):
        image_features = torch.Tensor(image_features)
    if not torch.is_tensor(text_features):
        text_features = torch.Tensor(text_features)
    # Out-of-place division: `/=` would overwrite the caller's tensors.
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)

    text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
    return torch.argmax(text_probs, dim=1), text_probs

def group_prompt_preds(raw_preds):
    """Collapse group-prompt indices into class labels.

    Indices 0 and 1 become class 0, indices 2 and 3 become class 1; any other
    value is passed through unchanged.

    The remapping is applied to a copy: ``.numpy()`` on a CPU tensor shares
    memory with it, so writing into that array would overwrite the caller's
    tensor.

    Args:
        raw_preds: tensor of predicted prompt indices.

    Returns:
        A new tensor built with ``torch.Tensor`` holding the class labels.
    """
    prompt_idx = raw_preds.detach().cpu().numpy()
    class_idx = prompt_idx.copy()
    class_idx[(prompt_idx == 0) | (prompt_idx == 1)] = 0
    class_idx[(prompt_idx == 2) | (prompt_idx == 3)] = 1
    return torch.Tensor(class_idx)

def group_prompt_preds_multi(raw_preds, n_full_prompt, n_prompt_per_class, n_class):
    """Collapse prompt indices into class labels, ``n_prompt_per_class`` prompts per class.

    Prompt indices ``[k * n_prompt_per_class, (k + 1) * n_prompt_per_class)``
    are mapped to class ``k`` for every block that starts below
    ``n_full_prompt``. Values outside those blocks are left as they are.

    ``raw_preds`` is updated in place (callers may rely on that and ignore the
    return value) and a float32 tensor with the same labels is returned,
    whether the input is a numpy array or a torch tensor.

    Args:
        raw_preds: array or tensor of predicted prompt indices (expected to be
            integer-valued).
        n_full_prompt: total number of prompts.
        n_prompt_per_class: number of consecutive prompts belonging to a class.
        n_class: unused; kept so existing call sites stay valid.

    Returns:
        float32 tensor of class labels.
    """
    block_starts = range(0, n_full_prompt, n_prompt_per_class)
    for c_idx, p_idx in enumerate(block_starts):
        # A label written here is always smaller than the start of every later
        # block, so updating in place cannot cause a value to be remapped twice.
        in_block = (raw_preds >= p_idx) & (raw_preds < p_idx + n_prompt_per_class)
        raw_preds[in_block] = c_idx
    return torch.as_tensor(raw_preds, dtype=torch.float32)

def generate_z_prompts_square(z_prompts):
    """Cross every pair of distinct prompt pairs and append the originals.

    For each unordered pair of entries ``(a, b)`` that do not compare equal,
    the two mixed pairs ``[a[0], b[1]]`` and ``[a[1], b[0]]`` are emitted, in
    the order the pairs are met. The original entries follow at the end.

    Args:
        z_prompts: sequence of two-element prompt pairs.

    Returns:
        List with the crossed pairs followed by the entries of ``z_prompts``.
    """
    crossed = []
    for position, first in enumerate(z_prompts):
        # Only look at later entries: each unordered pair is handled once.
        for second in z_prompts[position + 1:]:
            if first == second:
                continue
            crossed += [[first[0], second[1]], [first[1], second[0]]]
    return crossed + list(z_prompts)

def evaluate(dataset_name, preds, test_Y, logits):
    """Run the evaluation routine registered for ``dataset_name``.

    ``eval_wilds`` takes ``(preds, test_Y)``; every other evaluator takes
    ``(preds, test_Y, logits)``. The call is chosen from the evaluator itself,
    so the synthetic datasets (handled by ``eval_synthetic``) receive the
    logits argument their signature requires.

    Args:
        dataset_name: key identifying the dataset (one of the ``const.*_NAME``
            values registered below).
        preds: predicted labels.
        test_Y: ground-truth labels.
        logits: scores/probabilities, forwarded to evaluators that accept them.

    Returns:
        Whatever the selected evaluator returns.

    Raises:
        KeyError: if no evaluator is registered for ``dataset_name``.
    """
    eval_func = {
        const.WATERBIRDS_NAME: eval_wilds,
        const.CELEBA_NAME: eval_wilds,
        const.PACS_NAME: eval_domainbed,
        const.SD_CATDOG_NAME: eval_synthetic,
        const.SD_NURSE_FIREFIGHTER_NAME: eval_synthetic,
        const.CXR_NAME: eval_cxr,
        const.BREEDS17_NAME: eval_domainbed,
        const.BREEDS26_NAME: eval_domainbed,
    }
    evaluator = eval_func[dataset_name]
    # eval_wilds is the only evaluator with a two-argument signature.
    if evaluator is eval_wilds:
        return evaluator(preds, test_Y)
    return evaluator(preds, test_Y, logits)

if __name__ == '__main__':
    # Script entry point. For the chosen dataset it evaluates, per CLIP model:
    #   1. the plain zero-shot baseline (and a group-prompt baseline for some datasets),
    #   2. one projected variant per z-prompt,
    #   3. an element-wise max ensemble of the per-prompt probabilities,
    #   4. a Snorkel LabelModel over the per-prompt predictions,
    #   5. embeddings with all spurious directions rejected via QR.
    # NOTE: `dataset_name` and `load_dir` are assigned at module scope here
    # because the eval_* helpers read them as globals.
    parser = argparse.ArgumentParser(description='run CLIP zero shot')
    parser.add_argument('-dataset', '--dataset_name', type=str, required=True)
    args = parser.parse_args()

    dataset_name = args.dataset_name
    max_tokens = 100
    n_paraphrases = 1
    if dataset_name == const.BREEDS17_NAME:
        max_tokens = 300
        n_paraphrases = 0
    elif dataset_name == const.BREEDS26_NAME:
        max_tokens = 750
        n_paraphrases = 0
    z_prompts = get_z_prompts(dataset_name, verbose=True, max_tokens=max_tokens, n_paraphrases=n_paraphrases)
    # z_prompts = get_z_prompts_openLM(dataset_name, model_name=const.LLAMA_NAME)
    labels_text = MultiEnvDataset().dataset_dict[dataset_name]().get_labels()
    dir_dict = {
        # const.CLIP_ALIGN_NAME: f'../{dataset_name}_features/features_gt_ALIGN/0',
        # const.CLIP_BASE_NAME: f'../{dataset_name}_features/features_gt/0',
        # const.CLIP_ALT_NAME: f'../{dataset_name}_features/features_gt_alt/0',
        # const.CLIP_BIOMED_NAME: f'../{dataset_name}_features/features_gt/0',
        # const.CLIP_OPEN_VITL14: f'../{dataset_name}_features/features_openclip_vitL14/0',
        const.CLIP_OPEN_VITB32: f'../{dataset_name}_features/features_openclip_vitB32/0',
        # const.CLIP_OPEN_VITH14: f'../{dataset_name}_features/features_openclip_vitH14/0',
        # const.CLIP_OPEN_RN50: f'../{dataset_name}_features/features_openclip_rn50/0',
    }

    for key, load_dir in dir_dict.items():
        x_avg = None
        ensemble_logs = None
        x_snorkel = []
        text_emb_all = []

        print(f'CLIP MODEL = {key}')
        test_X = np.load(os.path.join(load_dir, 'image_emb.npy'))
        test_Y = np.load(os.path.join(load_dir, 'y.npy'))

        label_emb = get_text_embedding(labels_text, model_name=key)
        preds, logits = make_clip_preds(test_X, label_emb)
        print("========= Baseline (ZS) =========")
        evaluate(dataset_name, preds, test_Y, logits)

        if dataset_name in [const.WATERBIRDS_NAME, const.CELEBA_NAME, const.PACS_NAME]:
            print("========= Baseline (Group Prompt) =========")
            group_prompt = MultiEnvDataset().dataset_dict[dataset_name]().get_group_prompts()
            group_prompt_emb = get_text_embedding(group_prompt, model_name=key)
            preds, logits = make_clip_preds(test_X, group_prompt_emb)
            if dataset_name == const.PACS_NAME:
                # Use the returned labels rather than relying on the helper
                # having modified `preds` in place.
                preds = group_prompt_preds_multi(preds, len(group_prompt), 4, 7)
            else:
                preds = group_prompt_preds(preds)
            evaluate(dataset_name, preds, test_Y, logits)

        for t_idx, prompt in tqdm(enumerate(z_prompts)):
            text_emb = get_text_embedding(prompt, model_name=key)
            text_emb_all.append(text_emb)
            # Prompts with at most two embeddings use the pairwise projection,
            # larger sets the multi-embedding variant.
            if text_emb.shape[0] <= 2:
                test_proj = project_equidistant(text_emb, test_X)
            else:
                test_proj = project_equidistant_multi(text_emb, test_X)
            test_proj = torch.Tensor(test_proj)
            label_emb = torch.Tensor(label_emb)

            preds, logits = make_clip_preds(test_proj, label_emb)
            if t_idx == 0:
                x_avg = torch.clone(test_proj)
                ensemble_logs = torch.clone(logits)
            else:
                # x_avg is a running sum of the projections; nothing below reads it.
                x_avg += test_proj
                # Keep, per sample and class, the largest probability seen so far.
                ensemble_logs = torch.maximum(ensemble_logs, logits)
            x_snorkel.append(preds.detach().cpu().numpy().reshape(-1, 1))
            print(f"========= Ours {t_idx +1} =========")
            evaluate(dataset_name, preds, test_Y, logits)
        test_Y = torch.Tensor(test_Y)

        # print(f"========= equidistant all =========")
        # text_emb_eq = np.vstack(text_emb_all)
        # if text_emb_eq.shape[0] > len(np.unique(text_emb_eq,axis=0)):
        #     text_emb_eq = np.unique(text_emb_eq,axis=0)
        # test_proj = project_equidistant_multi(text_emb_eq, test_X, even_contraint=True)
        # preds, logits = make_clip_preds(test_proj, label_emb)
        # evaluate(dataset_name, preds, test_Y, logits)

        text_emb_all = np.array(text_emb_all)

        preds = torch.argmax(ensemble_logs, dim=1)
        print("========= ENSEMBLE =========")
        evaluate(dataset_name, preds, test_Y, ensemble_logs)

        # One column of predictions per z-prompt; the label model is only
        # fitted when at least three columns are available.
        x_snorkel = np.hstack(x_snorkel)
        if x_snorkel.shape[1] >= 3:
            triplet_model = LabelModel(
                cardinality=len(np.unique(test_Y.detach().cpu().numpy())), verbose=False,
            )
            triplet_model.fit(
                x_snorkel,
                n_epochs=100, seed=123, lr=1e-3,
                # class_balance=[.6,.4],
                progress_bar=False
            )
            preds = triplet_model.predict(x_snorkel)
            proba = triplet_model.predict_proba(x_snorkel)
            preds = torch.Tensor(preds)
            print("========= SNORKEL =========")
            evaluate(dataset_name, preds, test_Y, proba)

        # --------- Rejecting all spurious directions ------------
        from sklearn.metrics.pairwise import cosine_similarity
        # One spurious direction per z-prompt: difference of its first two embeddings.
        spurious_vectors = text_emb_all[:, 0, :] - text_emb_all[:, 1, :]

        # QR decomposition --> rows of q are orthonormal vectors derived from
        # the spurious feature vectors.
        q, _ = np.linalg.qr(spurious_vectors.T)
        q = q.T

        # Transform X so that it is orthogonal to all spurious directions.
        test_proj = np.copy(test_X)
        test_proj = test_proj / np.linalg.norm(test_proj, axis=1).reshape(-1, 1)

        # Reject (subtract) the projection onto each orthonormal vector, then
        # re-normalise. Adding the projection instead would amplify the
        # spurious component rather than remove it.
        for orthonormal_vector in q:
            direction = orthonormal_vector.reshape(1, -1)
            cos = cosine_similarity(test_proj, direction)  # one value per sample
            projection = cos * direction / np.linalg.norm(orthonormal_vector)
            test_proj = test_proj - projection
            test_proj = test_proj / np.linalg.norm(test_proj, axis=1).reshape(-1, 1)
        test_proj = torch.Tensor(test_proj)
        label_emb = torch.Tensor(label_emb)
        preds, logits = make_clip_preds(test_proj, label_emb)
        print("========= OURS W/ QR Rejection =========")
        evaluate(dataset_name, preds, test_Y, logits)