'''Computes projections of word-embedding features onto the normalized he-she gender direction and saves them.'''
import os
import sys
import argparse
import random
import torch
import numpy as np
from gensim.models import KeyedVectors

from utils.model import get_word_embeddings, get_per_embeddings_type
from robustness import defaults
from robustness.datasets import DATASETS
from robustness.main import setup_args


def get_args():
    """Build the command-line parser and parse ``sys.argv``.

    Registers the script-specific options (layer index, word list, embedding
    selection/operation, normalization flag, seed), then appends the shared
    ``robustness`` option groups: config, model loader, training and PGD, in
    that order.

    Returns:
        argparse.Namespace: the parsed command-line arguments.
    """
    cli = argparse.ArgumentParser()

    cli.add_argument(
        '--layer_num', default=17, type=int,
        help='Which layer to get the representation from model',
    )

    # NLP-related options
    cli.add_argument('--words_path', default=None, type=str, help='Path to words')
    cli.add_argument(
        '--embeddings_type', default=None, type=str,
        choices=['debiased', 'gender', 'random'],
        help='Types of embeddings to get from the embeddings (for GN-Glove)',
    )
    cli.add_argument(
        '--embeddings_op', default=None, type=str,
        choices=['association', 'gender_association'],
        help='Embedding operation (association: emb(w1) - emb(w2); gender_association: [emb(w1) - emb(w2)] - [emb(he) - emb(she)]).',
    )
    cli.add_argument(
        '--norm_embeddings_op', default=False, action='store_true',
        help='Whether to normalize the individual embeddings',
    )
    cli.add_argument('--seed', default=0, type=int, help='Seed')

    # Shared option groups from the robustness library, registered in a fixed order.
    shared_groups = (
        defaults.CONFIG_ARGS,
        defaults.MODEL_LOADER_ARGS,
        defaults.TRAINING_ARGS,
        defaults.PGD_ARGS,
    )
    for arg_group in shared_groups:
        cli = defaults.add_args_to_parser(arg_group, cli)

    return cli.parse_args()

def get_path(args):
    """Ensure the output directory exists and return it.

    Args:
        args: parsed arguments; only ``args.out_dir`` is read.

    Returns:
        ``args.out_dir``, unchanged, after the directory has been created
        (including any missing parents) if it did not already exist.
    """
    print('path to save', args.out_dir)
    # The previous blocking ``input()`` pause was removed: it stalled every run
    # and raised EOFError in non-interactive/batch jobs without stdin.
    # exist_ok=True also avoids the check-then-create race of exists()+makedirs().
    os.makedirs(args.out_dir, exist_ok=True)
    return args.out_dir

def main():
    """Project word-embedding features onto the he-she direction and save them.

    Pipeline:
      1. Parse CLI arguments and post-process them with ``setup_args``.
      2. Seed torch, ``random`` and numpy from ``--seed``.
      3. Load text-format word2vec vectors from ``--resume`` and extract the
         (optionally transformed) features via ``get_word_embeddings``.
      4. Build the unit ``he - she`` vector, restrict it with
         ``get_per_embeddings_type``, and take its dot product with every
         feature row.
      5. Save the resulting 1-D tensor to
         ``<out_dir>/<embeddings_op>/<[norm_]feature_type>_magnitudes.pt``.
    """
    args = setup_args(get_args())

    torch.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

    path = get_path(args)

    model = KeyedVectors.load_word2vec_format(args.resume, binary=False)
    print(f'Loaded model from {args.resume}')
    features, _ = get_word_embeddings(
        model, args.words_path, args.embeddings_type, args.embeddings_op,
        args.norm_embeddings_op, return_words=True)
    print(f'features shape {features.shape}')

    gender_vector = model['he'] - model['she']
    gender_vector = gender_vector / np.linalg.norm(gender_vector)
    # NOTE(review): normalization happens before get_per_embeddings_type; if that
    # helper keeps only a subset of dimensions, the resulting direction may no
    # longer be unit-length -- confirm this is intended.
    gender_vector = get_per_embeddings_type(gender_vector, embeddings_type=args.embeddings_type)

    # Match the features' dtype/device so the product cannot fail on a
    # float32/float64 or CPU/GPU mismatch.
    gender_direction = torch.as_tensor(np.asarray(gender_vector)).reshape(-1).to(
        dtype=features.dtype, device=features.device)
    # (N, D) x (D,) -> (N,): one projection value per feature row.
    magnitudes = torch.mv(features, gender_direction)
    print(magnitudes.shape)

    feature_type = args.embeddings_type if args.embeddings_type is not None else 'all'
    if args.embeddings_op is not None and args.norm_embeddings_op:
        feature_type = f'norm_{feature_type}'
    embeddings_op = args.embeddings_op if args.embeddings_op is not None else ''

    out_dir = os.path.join(path, embeddings_op)
    os.makedirs(out_dir, exist_ok=True)
    out_file = os.path.join(out_dir, f'{feature_type}_magnitudes.pt')
    torch.save(magnitudes.cpu().detach(), out_file)
    print(f"Saved magnitudes in {out_file}")

if __name__ == '__main__':
    # Run the pipeline only when executed as a script, not when imported.
    main()
