from itertools import combinations
import collections

import pandas as pd
import numpy as np

# Axis along which the components of a single vector lie.
# 1 means each row of a 2-D array is one vector (row vectors).
vec_dim = 1

# ---------------------------------------------------------------------------
# Word lists used as target / attribute sets by the built-in tests in `weat`.
# ---------------------------------------------------------------------------

# Given names.
masc_names    = ["john", "paul", "mike", "kevin", "steve", "greg", "jeff", "bill"]
fem_names     = ["amy", "joan", "lisa", "sarah", "diana", "kate", "ann", "donna"]

# Career vs. home/family vocabulary.
career_words  = ["executive", "management", "professional", "corporation", "salary", "office", "business", "career"]
home_words    = ["home", "parents", "children", "family", "cousins", "marriage", "wedding", "relatives"]

# Gendered terms, variant 1 (paired with the math/art lists below).
fem_words1    = ["female", "woman", "girl", "sister", "she", "her", "hers", "daughter"]
masc_words1   = ["male", "man", "boy", "brother", "he", "him", "his", "son"]

art_words1    = ["poetry", "art", "dance", "literature", "novel", "symphony", "drama", "sculpture"]
math_words    = ["math", "algebra", "geometry", "calculus", "equations", "computation", "numbers", "addition"]

# Gendered terms, variant 2 (paired with the science/art lists below).
fem_words2    = ["sister", "mother", "aunt", "grandmother", "daughter", "she", "hers", "her"]
masc_words2   = ["brother", "father", "uncle", "grandfather", "son", "he", "his", "him"]

science_words = ["science", "technology", "physics", "chemistry", "einstein", "nasa", "experiment", "astronomy"]
art_words2    = ["poetry", "art", "shakespeare", "dance", "literature", "novel", "symphony", "drama"]

# Occupations used by the 'in_ja' test.
masc_jobs     = ['plumber', 'mechanic', 'carpenter', 'electrician', 'machinist', 'engineer', 'programmer', 'architect', 'officer', 'paramedic']
fem_jobs      = ['therapist', 'planner', 'librarian', 'paralegal', 'nurse', 'receptionist', 'hairdresser', 'nutritionist', 'hygienist', 'pathologist']

# Adjectives used as attribute sets by the 'in_*' tests.  An earlier, shorter
# pair of lists used to be assigned to these same names and was immediately
# overwritten; only the effective definitions are kept.
masc_adjs     = ['handsome', 'aggressive', 'tough', 'courageous', 'strong', 'forceful', 'arrogant', 'egotistical', 'boastful', 'dominant']
fem_adjs      = ['affectionate', 'sensitive', 'appreciative', 'sentimental', 'sympathetic', 'nagging', 'fussy', 'emotional']

# Flowers vs. insects.
flower_words     = ["aster", "clover", "hyacinth", "marigold", "poppy",
                    "azalea", "crocus", "iris", "orchid", "rose",
                    "bluebell", "daffodil", "lilac", "pansy", "tulip",
                    "buttercup", "daisy", "lily", "peony", "violet",
                    "carnation", "gladiola", "magnolia", "petunia", "zinnia"]
insect_words     = ["ant", "caterpillar", "flea", "locust", "spider",
                    "bedbug", "centipede", "fly", "maggot", "tarantula",
                    "bee", "cockroach", "gnat", "mosquito", "termite",
                    "beetle", "cricket", "hornet", "moth", "wasp",
                    "blackfly", "dragonfly", "horsefly", "roach", "weevil"]

# Pleasant vs. unpleasant attribute words.
pleasant_words   = ["caress", "freedom", "health", "love", "peace",
                    "cheer", "friend", "heaven", "loyal", "pleasure",
                    "diamond", "gentle", "honest", "lucky", "rainbow",
                    "diploma", "gift", "honor", "miracle", "sunrise",
                    "family", "happy", "laughter", "paradise", "vacation"]
unpleasant_words = ["abuse", "crash", "filth", "murder", "sickness",
                    "accident", "death", "grief", "poison", "stink",
                    "assault", "disaster", "hatred", "pollute", "tragedy",
                    "divorce", "jail", "poverty", "ugly", "cancer",
                    "kill", "rotten", "vomit", "agony", "prison"]

# Musical instruments vs. weapons.
music_words      = ["bagpipe", "cello", "guitar", "lute", "trombone",
                    "banjo", "clarinet", "harmonica", "mandolin", "trumpet",
                    "bassoon", "drum", "harp", "oboe", "tuba",
                    "bell", "fiddle", "harpsichord", "piano", "viola",
                    "bongo", "flute", "horn", "saxophone", "violin"]
weapon_words     = ["arrow", "club", "gun", "missile", "spear",
                    "axe", "dagger", "harpoon", "pistol", "sword",
                    "blade", "dynamite", "hatchet", "rifle", "tank",
                    "bomb", "firearm", "knife", "shotgun", "teargas",
                    "cannon", "grenade", "mace", "slingshot", "whip"]

def preprocess(X, Y, A, B):
    """Scale every vector to unit L2 norm and derive the attribute direction.

    Each of the four arrays is divided by the norm of its vectors, where the
    norm is taken along the module-level ``vec_dim`` axis.  The attribute
    direction is the mean of the normalised ``A`` vectors minus the mean of
    the normalised ``B`` vectors, with the means taken over the other axis.

    Args:
        X, Y: target-set arrays (assumed 2-D -- TODO confirm with callers).
        A, B: attribute-set arrays (same layout assumption).

    Returns:
        Tuple ``(X, Y, A, B, diff)`` holding the normalised arrays and the
        attribute direction ``diff``.
    """
    member_axis = 1 - vec_dim  # the axis that enumerates the vectors in a set

    def _unit(M):
        # keepdims leaves the norm broadcastable against M.
        return M / np.linalg.norm(M, axis=vec_dim, keepdims=True)

    A = _unit(A)
    B = _unit(B)
    mean_a = A.sum(axis=member_axis) / A.shape[member_axis]
    mean_b = B.sum(axis=member_axis) / B.shape[member_axis]

    X = _unit(X)
    Y = _unit(Y)

    return X, Y, A, B, mean_a - mean_b

def weat_test_statistic(X, Y, A, B):
    """Return the association test statistic for targets X, Y and attributes A, B.

    The statistic is the dot product of (mean of normalised X minus mean of
    normalised Y) with the attribute direction returned by ``preprocess``.
    Means are used rather than sums, so X and Y need not have the same number
    of vectors; when they do, this rescaling leaves the p-value and the
    effect size unchanged.
    """
    X, Y, A, B, attribute_direction = preprocess(X, Y, A, B)

    member_axis = 1 - vec_dim
    mean_x = X.sum(axis=member_axis) / X.shape[member_axis]
    mean_y = Y.sum(axis=member_axis) / Y.shape[member_axis]
    target_direction = mean_x - mean_y

    return target_direction.dot(attribute_direction)

def p_value(X, Y, A, B, SAMPLE_SIZE=3000000):
    """Return a one-sided permutation-test p-value for the association.

    Every vector in X and Y is scored by its dot product with the attribute
    direction from ``preprocess``.  The p-value is the fraction of ways of
    picking ``len(X)`` vectors out of the pooled X and Y vectors whose summed
    score is at least the summed score of the real X.

    With at most 24 pooled vectors every subset is enumerated (at most
    C(24, 12), about 2.7 million).  Otherwise ``SAMPLE_SIZE`` random subsets
    are drawn.  The random subsets are generated in fixed-size batches so peak
    memory does not grow with ``SAMPLE_SIZE``.

    Args:
        X, Y: target-set arrays.
        A, B: attribute-set arrays.
        SAMPLE_SIZE: number of random subsets to draw when the pooled set is
            too large to enumerate; ignored otherwise.

    Returns:
        The p-value as a float in [0, 1].

    Raises:
        ValueError: if sampling is required and ``SAMPLE_SIZE`` is below 1.
    """
    X, Y, A, B, diff = preprocess(X, Y, A, B)

    member_axis = 1 - vec_dim
    n_x = X.shape[member_axis]
    n_y = Y.shape[member_axis]

    XY = np.concatenate([X, Y], axis=member_axis)
    n_total = XY.shape[member_axis]

    # One score per pooled vector; the first n_x entries belong to X.
    scores = np.zeros((n_total,))
    vectors_as_rows = XY.T if vec_dim == 0 else XY
    scores[:] = (vectors_as_rows @ diff.reshape((-1, 1))).squeeze()

    target = np.sum(scores[:n_x])

    if n_total <= 24:
        # Small enough to enumerate every possible choice of n_x vectors.
        count = 0
        total = 0
        for c in combinations(range(n_total), n_x):
            if np.sum(scores[np.array(c)]) >= target:
                count += 1
            total += 1
        return count / total

    if SAMPLE_SIZE < 1:
        raise ValueError("SAMPLE_SIZE must be a positive integer")

    rng = np.random.default_rng()
    batch_size = 100000  # bounds the size of the per-batch index matrix
    all_indices = np.arange(n_total).reshape(1, -1)

    hits = 0
    for start in range(0, SAMPLE_SIZE, batch_size):
        rows = min(batch_size, SAMPLE_SIZE - start)
        row_ids = np.arange(rows)

        # Each row starts as all pooled indices.  Dropping one uniformly
        # random remaining entry per row, n_y times, leaves a uniformly random
        # subset of size n_x in every row.
        inds = np.repeat(all_indices, rows, axis=0)
        for i in range(n_y):
            mask = np.ones(inds.shape, dtype=bool)
            mask[row_ids, rng.integers(n_total - i, size=rows)] = False
            inds = inds[mask].reshape(rows, -1)

        hits += int(np.sum(np.sum(scores[inds], axis=1) >= target))

    return hits / SAMPLE_SIZE
def effect_size(X, Y, A, B):
    """Return the test statistic divided by the spread of the per-vector scores.

    The numerator is the same quantity computed by ``weat_test_statistic``.
    The denominator is ``np.std`` of the scores of all X and Y vectors, where
    a score is the dot product of a normalised vector with the attribute
    direction from ``preprocess``.
    """
    X, Y, A, B, attribute_direction = preprocess(X, Y, A, B)

    member_axis = 1 - vec_dim
    mean_x = X.sum(axis=member_axis) / X.shape[member_axis]
    mean_y = Y.sum(axis=member_axis) / Y.shape[member_axis]
    target_direction = mean_x - mean_y

    pooled = np.concatenate([X, Y], axis=member_axis)
    scores = np.zeros((pooled.shape[member_axis],))

    # Put one vector per row before projecting onto the attribute direction.
    vectors_as_rows = pooled.T if vec_dim == 0 else pooled
    scores[:] = (vectors_as_rows @ attribute_direction.reshape((-1, 1))).squeeze()

    return target_direction.dot(attribute_direction) / np.std(scores)


# Built-in test identifier -> (row label, callable returning the (X, Y, A, B)
# word lists).  The callables defer the lookup of the module-level word lists
# until a test actually runs, so the lists are read at call time.
_WEAT_BUILTIN_TESTS = {
    'mf_ch': ('Career/Home--Masc/Fem',
              lambda: (masc_names, fem_names, career_words, home_words)),
    'mf_ma': ('Math/Art--Masc/Fem',
              lambda: (math_words, art_words1, masc_words1, fem_words1)),
    'mf_sa': ('Science/Art--Masc/Fem',
              lambda: (science_words, art_words2, masc_words2, fem_words2)),
    'ma_mf': ('Math/Art--Male/Female',
              lambda: (math_words, art_words1, masc_words1, fem_words1)),
    'sa_mf': ('Science/Art--Male/Female',
              lambda: (science_words, art_words2, masc_words2, fem_words2)),
    'fi_pu': ('Flower/Insect--Pleasant/Unpleasant',
              lambda: (flower_words, insect_words, pleasant_words, unpleasant_words)),
    'mw_pu': ('Music Instruments/Weapons--Pleasant/Unpleasant',
              lambda: (music_words, weapon_words, pleasant_words, unpleasant_words)),
    'in_ja': ('Indirect Jobs/Adjs',
              lambda: (masc_jobs, fem_jobs, masc_adjs, fem_adjs)),
    'in_ma': ('Indirect MathArt/Adjs',
              lambda: (math_words, art_words1, masc_adjs, fem_adjs)),
    'in_sa': ('Indirect ScienceArt/Adjs',
              lambda: (science_words, art_words2, masc_adjs, fem_adjs)),
    'in_ca': ('Indirect CareerHome/Adjs',
              lambda: (career_words, home_words, masc_adjs, fem_adjs)),
    'sa_names': ('Science/Art--Names',
                 lambda: (masc_names, fem_names, science_words, art_words2)),
    'ma_names': ('Math/Art--Names',
                 lambda: (masc_names, fem_names, math_words, art_words1)),
    'ch_nonames': ('Career/Family--NoNames',
                   lambda: (masc_words2, fem_words2, career_words, home_words)),
}

_WEAT_CUSTOM_INPUT_ERROR = "Xs, Ys, As, Bs must all be assigned and be an iterable of iterable of strings.  Strings and iterables of strings are not allowed.  names must also be assigned as an iterable of strings."


def _is_nested_word_collection(obj):
    """Return True if obj is a non-string iterable whose first item is a non-string iterable starting with a str."""
    if isinstance(obj, str) or not isinstance(obj, collections.abc.Iterable):
        return False
    try:
        first = obj[0]
        if isinstance(first, str) or not isinstance(first, collections.abc.Iterable):
            return False
        return isinstance(first[0], str)
    except (IndexError, KeyError, TypeError):
        # Empty or non-indexable input cannot describe a custom test.
        return False


def _custom_inputs_valid(Xs, Ys, As, Bs, names):
    """Return True if the custom word sets and their names have the expected shape."""
    if not all(_is_nested_word_collection(c) for c in (Xs, Ys, As, Bs)):
        return False
    if not isinstance(names, collections.abc.Iterable):
        return False
    try:
        return isinstance(names[0], str)
    except (IndexError, KeyError, TypeError):
        return False


def _lookup_vectors(vec_arr, w2i, words):
    """Stack the rows of vec_arr for every word present in w2i; other words are skipped."""
    return np.array([vec_arr[w2i[w], :] for w in words if w in w2i])


def weat(vec_arr, w2i, tests=('mf_ch', 'mf_ma', 'mf_sa'), require_equal=True, force_equal=True, Xs=None, Ys=None, As=None, Bs=None, names=None):
    """Run one or more association tests and tabulate the results.

    For each requested test the target words (X, Y) and attribute words (A, B)
    are looked up in ``vec_arr`` via ``w2i``; words missing from ``w2i`` are
    skipped.  The test statistic, p-value and effect size are then computed
    with ``weat_test_statistic``, ``p_value`` and ``effect_size``.

    All test identifiers and custom inputs are validated before any test is
    run, so a bad request fails immediately.

    Args:
        vec_arr: 2-D array indexed as ``vec_arr[w2i[word], :]``.
        w2i: mapping from word to row index in ``vec_arr``.
        tests: sequence of built-in identifiers ('mf_ch', 'mf_ma', 'mf_sa',
            'ma_mf', 'sa_mf', 'fi_pu', 'mw_pu', 'in_ja', 'in_ma', 'in_sa',
            'in_ca', 'sa_names', 'ma_names', 'ch_nonames').  A ``None`` entry
            selects the next custom test from ``Xs``/``Ys``/``As``/``Bs``.
        require_equal: if true, X and Y must contain the same number of
            found vectors.
        force_equal: when ``require_equal`` is set and the sizes differ,
            randomly subsample both sets (without replacement) down to the
            smaller size instead of raising.  The subsample is not seeded.
        Xs, Ys, As, Bs: for custom tests, sequences of word lists; the k-th
            ``None`` in ``tests`` uses element k of each.
        names: sequence of row labels for the custom tests.

    Returns:
        ``pandas.DataFrame`` indexed by 'Test' with float columns
        'Test Statistic', 'p-value' and 'Effect Size', one row per test.

    Raises:
        ValueError: for an unknown identifier, malformed custom inputs, a
            word set with no words found in ``w2i``, or unequal X/Y sizes
            when ``require_equal`` is set and ``force_equal`` is not.
    """
    tests = list(tests)

    # Validate everything up front: a single p-value can be slow to compute,
    # so a bad request should not surface after earlier tests have run.
    custom_checked = False
    for t in tests:
        if t is None:
            if not custom_checked and not _custom_inputs_valid(Xs, Ys, As, Bs, names):
                raise ValueError(_WEAT_CUSTOM_INPUT_ERROR)
            custom_checked = True
        elif not (isinstance(t, str) and t in _WEAT_BUILTIN_TESTS):
            raise ValueError(f"{t} is not a valid test identifier")

    rows = []
    custom_idx = 0  # position of the next custom test in Xs/Ys/As/Bs/names
    for t in tests:
        if t is None:
            label = names[custom_idx]
            word_sets = (Xs[custom_idx], Ys[custom_idx], As[custom_idx], Bs[custom_idx])
            custom_idx += 1
        else:
            label, get_word_sets = _WEAT_BUILTIN_TESTS[t]
            word_sets = get_word_sets()

        X, Y, A, B = (_lookup_vectors(vec_arr, w2i, words) for words in word_sets)

        for set_name, vectors in zip("XYAB", (X, Y, A, B)):
            if vectors.shape[0] == 0:
                raise ValueError(
                    f"none of the {set_name} words for test '{label}' were found in w2i"
                )

        if require_equal and X.shape[0] != Y.shape[0]:
            if not force_equal:
                raise ValueError("X and Y classes do not have the same number of vectors")
            rng = np.random.default_rng()
            keep_num = min(X.shape[0], Y.shape[0])
            X = rng.choice(X, size=keep_num, replace=False)
            Y = rng.choice(Y, size=keep_num, replace=False)

        rows.append((
            label,
            weat_test_statistic(X, Y, A, B),
            p_value(X, Y, A, B),
            effect_size(X, Y, A, B),
        ))

    # Build the frame in one go so the label column is never written into a
    # float column.
    df = pd.DataFrame(rows, columns=['Test', 'Test Statistic', 'p-value', 'Effect Size'])
    df = df.set_index('Test')

    return df.astype(float)
