import sys
sys.path.insert(0, ROOT_PATH)
from shared import project_equidistant, project_equidistant_multi

import numpy as np

import os
import argparse
from tqdm import tqdm
import comnivore.const as const
from sklearn.metrics.pairwise import cosine_similarity

def remove_error(text_emb, metadata, start=(133782 // 1148) - 1,
                 n_text_rows=63, n_metadata_rows=64):
    """Drop a fixed range of rows from the embeddings and the metadata.

    Rows ``[start, start + n_text_rows)`` are removed from ``text_emb``.
    Rows ``[start, start + n_metadata_rows)`` are removed from ``metadata``.
    The defaults reproduce the previously hard-coded values (``start == 115``).
    133782 and 1148 are presumably a sample count and batch size from a
    failed extraction run — TODO confirm.

    NOTE(review): the defaults remove 63 embedding rows but 64 metadata rows,
    so the two arrays change length by different amounts. Confirm this
    asymmetry is intended (e.g. the embedding file already lacks one row).

    Args:
        text_emb: array whose first axis indexes samples.
        metadata: array whose first axis indexes samples.
        start: first row index to drop in both arrays.
        n_text_rows: number of rows to drop from ``text_emb``.
        n_metadata_rows: number of rows to drop from ``metadata``.

    Returns:
        ``(text_emb, metadata)`` with the selected rows removed.
    """
    # concatenate along axis 0 (rather than vstack) so 1-D inputs are also
    # spliced correctly instead of being stacked into two rows.
    text_emb = np.concatenate(
        (text_emb[:start], text_emb[start + n_text_rows:]), axis=0)
    metadata = np.concatenate(
        (metadata[:start], metadata[start + n_metadata_rows:]), axis=0)
    return text_emb, metadata
    
def project_orthonormal(W, x):
    """Remove every spurious direction in ``W`` from ``x``; return a unit vector.

    Each spurious direction is ``W[i, 0, :] - W[i, 1, :]``. The directions
    are orthonormalised with a reduced QR decomposition. The normalised ``x``
    is then projected onto their orthogonal complement and rescaled to unit
    length.

    Args:
        W: array indexed as ``W[i, j, :]`` with ``j`` in {0, 1}; presumably
            shape ``(k, 2, d)`` — TODO confirm against callers.
        x: 1-D embedding of length ``d``.

    Returns:
        Array of shape ``(1, d)``. It contains NaN if ``x`` is all zeros or
        lies entirely inside the span of the spurious directions.
    """
    spurious_vectors = W[:, 0, :] - W[:, 1, :]
    # Rows of ``basis`` are orthonormal, so subtracting all their components in
    # one matrix product equals rejecting them one at a time (and the
    # intermediate renormalisations only rescale, which the final one undoes).
    basis = np.linalg.qr(spurious_vectors.T)[0].T
    unit_x = np.asarray(x) / np.linalg.norm(x)
    projected = unit_x - (basis @ unit_x) @ basis
    projected = projected / np.linalg.norm(projected)
    # Keep the row-vector shape callers have always received.
    return projected.reshape(1, -1)

def project_orthonormal_all(W, X):
    """Remove every spurious direction in ``W`` from each row of ``X``.

    Each spurious direction is ``W[i, 0, :] - W[i, 1, :]``. The directions
    are orthonormalised with a reduced QR decomposition. Every row of ``X`` is
    normalised, projected onto their orthogonal complement, and renormalised
    to unit length.

    Args:
        W: array indexed as ``W[i, j, :]`` with ``j`` in {0, 1}; presumably
            shape ``(k, 2, d)`` — TODO confirm against callers.
        X: embeddings of shape ``(n, d)``; a single 1-D vector of length
            ``d`` is treated as one row.

    Returns:
        Array of shape ``(n, d)`` with unit-norm rows. A row is NaN if the
        input row is all zeros or lies entirely inside the spurious span.
    """
    spurious_vectors = W[:, 0, :] - W[:, 1, :]
    basis = np.linalg.qr(spurious_vectors.T)[0].T
    # atleast_2d keeps the (n, d) layout even for a single row; the previous
    # np.squeeze-based code collapsed an n == 1 result to 0-d and then failed.
    X = np.atleast_2d(X)
    unit_X = X / np.linalg.norm(X, axis=1, keepdims=True)
    # Orthonormal basis rows => one matrix product rejects all directions.
    projected = unit_X - (unit_X @ basis.T) @ basis
    return projected / np.linalg.norm(projected, axis=1, keepdims=True)

def _load_all_spurious_pairs(subdirs):
    """Concatenate every subdir's ``texts.npy`` into one 3-D stack of pairs.

    A file holding a single pair (2-D, presumably ``(2, d)`` — TODO confirm)
    is promoted to ``(1, 2, d)``. The result is therefore always indexable as
    ``W[i, j, :]``, which is how ``project_orthonormal_all`` reads it.
    """
    pairs = []
    for subdir in subdirs:
        emb = np.load(os.path.join(subdir, 'texts.npy'))
        if emb.ndim == 2:
            emb = emb[np.newaxis]
        pairs.append(emb)
    return np.concatenate(pairs, axis=0)


def _project_rows_by_metadata(spurious_emb, text_emb, metadata):
    """Project each embedding row away from the spurious pairs its metadata flags.

    For row ``r``, the pairs ``spurious_emb[i]`` with ``metadata[r, i] == 1``
    are removed via ``project_orthonormal``. Rows with no flag set are passed
    through unchanged (not normalised).

    Returns:
        Array with one row per metadata row.
    """
    projected_X = []
    for row in tqdm(range(metadata.shape[0])):
        x = text_emb[row, :]
        mentioned = np.flatnonzero(metadata[row, :] == 1)
        if mentioned.size == 0:
            projected_X.append(x)
            continue
        # Fancy indexing gathers the (m, 2, d) stack of flagged pairs.
        projected_X.append(project_orthonormal(spurious_emb[mentioned], x))
    # vstack accepts both the raw (d,) rows and the (1, d) projected rows.
    return np.vstack(projected_X)


if __name__ == '__main__':
    # Load precomputed embeddings/metadata, remove spurious text directions,
    # and save the projected embeddings next to the spurious-text features.
    parser = argparse.ArgumentParser(description='run CLIP zero shot')
    parser.add_argument('-d', '--dataset', type=str, default='civilcomments')
    parser.add_argument('-m', '--model', type=str, default='hf_sim')

    args = parser.parse_args()

    dataset_name = args.dataset
    model = args.model

    load_dir = f'{dataset_name}_features/features_{model}'

    # Sorted so processing order does not depend on the filesystem's
    # arbitrary os.listdir order.
    subdirs = [os.path.join(load_dir, p)
               for p in sorted(os.listdir(load_dir)) if 'text' in p]
    metadata = np.load(os.path.join(load_dir, '0', 'metadata.npy'))
    text_emb = np.load(os.path.join(load_dir, '0', 'emb.npy'))

    if dataset_name in [const.AMAZON_NAME, const.GENDER_BIAS_NAME]:
        if not subdirs:
            sys.exit(f'no spurious-text subdirectories found in {load_dir}')
        # Project away the spurious pairs from *all* subdirs at once.
        # NOTE(review): earlier code built this stack but then projected with
        # only the last-listed subdir's pair; confirm all directions are intended.
        W_all = _load_all_spurious_pairs(subdirs)
        projected_X = project_orthonormal_all(W_all, text_emb)
        save_path = os.path.join(load_dir, 'text_emb_0', 'projected_emb_single.npy')
        np.save(save_path, projected_X)
        print(f'projection saved to {save_path}')
    else:
        for subdir in subdirs:
            out_path = os.path.join(subdir, 'projected_emb_single.npy')
            # Skip subdirs that already have a saved projection.
            if os.path.exists(out_path):
                continue
            spurious_emb = np.load(os.path.join(subdir, 'texts.npy'))
            projected_X = _project_rows_by_metadata(spurious_emb, text_emb, metadata)
            np.save(out_path, projected_X)
            
