import random
import timeit
import functools
import sympy
import json
import os

from argparse import ArgumentParser
from Parsers import BinaryParser, EquationParser
from config import configs, Config
from Horn import evaluate, learn_horn_envelope
from transformers import pipeline



def get_attribute_vector(length, allow_zero=True):
    """
    Draw a random one-hot vector of the requested length.

    When ``allow_zero`` is True the all-zero vector is an extra outcome drawn with
    probability ``1 / (length + 1)``; the remaining mass is spread uniformly over the
    ``length`` positions, so all ``length + 1`` outcomes are equally likely. When
    ``allow_zero`` is False exactly one position is set.

    Args:
        length (int): Number of entries in the returned vector.
        allow_zero (bool): Whether the all-zero vector may be returned.

    Returns:
        list: ``length`` integers, each 0 or 1, containing at most one 1
              (exactly one 1 when ``allow_zero`` is False).
    """
    if allow_zero:
        p_zero = 1 / (length + 1)
        (draw_zero,) = random.choices((True, False), weights=(p_zero, 1 - p_zero))
        if draw_zero:
            return [0] * length

    hot_index = random.randint(0, length - 1)
    return [1 if idx == hot_index else 0 for idx in range(length)]

def create_sample(binary_parser:BinaryParser, unmasking_model, verbose=False):
    """
    Draw one random labelled sample by querying the unmasking model.

    For every non-gender feature an optional one-hot group is drawn and the
    concatenated attribute vector is rendered as a masked sentence for the model.
    This is retried up to ``binary_parser.pac_hypothesis_space`` times (always at
    least once) until the model proposes a gendered token. A gender vector is drawn
    independently, and the sample is labelled positive when the model's prediction
    matches that gender.

    Args:
        binary_parser: Provides ``features`` (attribute name -> collection of values),
            ``pac_hypothesis_space`` and ``binary_to_sentence``.
        unmasking_model: Fill-mask model forwarded to ``get_prediction``.
        verbose: If True, print the sentence and the sample details.

    Returns:
        tuple: ``(sample_vector, label)``. ``sample_vector`` is the attribute bits
        followed by the two gender bits ``[female, male]``. ``label`` is True when the
        prediction agrees with the drawn gender. It is False otherwise, including when
        no gendered token was ever predicted.
    """
    female_words = ('she', 'woman', 'female')
    male_words = ('he', 'man', 'male')

    # Always make at least one attempt so that sentence/prediction/vectors are bound
    # even when pac_hypothesis_space is 0 (previously an UnboundLocalError).
    attempts = max(1, binary_parser.pac_hypothesis_space)
    for _ in range(attempts):
        # Build the attribute part of the sample (gender is handled separately).
        sample_vector = []
        for attribute, values in binary_parser.features.items():
            if attribute != 'gender':
                sample_vector.extend(get_attribute_vector(len(values), allow_zero=True))

        # allow_zero=False guarantees exactly one gender bit is set.
        gender_vector = get_attribute_vector(2, allow_zero=False)
        assert 1 in gender_vector

        # The gender itself is masked; ask the model to fill it in.
        sentence = binary_parser.binary_to_sentence(sample_vector, is_masked=True)
        prediction = get_prediction(unmasking_model, sentence)
        # '' means no gendered token was among the model's candidates: resample.
        if prediction != '':
            break

    predicted_word = prediction.lower()
    label = (predicted_word in female_words and gender_vector[0] == 1) or (predicted_word in male_words and gender_vector[1] == 1)

    # Combine attribute bits and gender bits into the full sample.
    sample_vector = [*sample_vector, *gender_vector]
    if verbose:
        print(sentence)
        print((sample_vector, prediction, gender_vector, label))

    return (sample_vector, label)

def equivalence_oracle(hypothesis, unmasking_model, binary_parser:BinaryParser):
    """
    Approximate an equivalence query by random sampling.

    Up to ``binary_parser.pac_hypothesis_space`` samples are drawn with
    ``create_sample``. The first sample whose model-derived label disagrees with the
    hypothesis (checked via ``evaluate``) is returned as a counterexample.

    Args:
        hypothesis: A 3-tuple ``(H, Q, _)``. ``H`` and ``Q`` support ``union`` and hold
            clauses combinable with ``&`` (presumably sets of sympy boolean clauses;
            confirm against ``learn_horn_envelope``).
        unmasking_model: Fill-mask model forwarded to ``create_sample``.
        binary_parser: Supplies ``pac_hypothesis_space`` and the variables ``V``.

    Returns:
        tuple: ``(assignment, i, label, {})`` for the first disagreeing sample, where
        ``i`` is its 0-based index among the drawn samples. Otherwise
        ``(True, pac_hypothesis_space, True, {})`` when no counterexample was found.
    """
    (H, Q, _) = hypothesis

    # Conjunction of all clauses. Starting from sympy.true keeps an empty hypothesis
    # valid (reduce without an initial value raises TypeError on an empty iterable)
    # and leaves any non-empty conjunction unchanged, since true & c == c.
    hypo = functools.reduce(lambda x, y: x & y, H.union(Q), sympy.true)

    for i in range(binary_parser.pac_hypothesis_space):
        (assignment, label) = create_sample(binary_parser, unmasking_model)
        model_label = bool(label)
        if not (model_label == evaluate(hypo, assignment, binary_parser.V)):
            return (assignment, i, model_label, {})

    # No counterexample was found within the sampling budget; accept the hypothesis.
    return (True, binary_parser.pac_hypothesis_space, True, {})

def membership_oracle(assignment, unmasking_model, binary_parser:BinaryParser):
    """
    Answer a membership query by asking the unmasking model.

    The last two entries of ``assignment`` are the gender bits ``[female, male]``.
    The preceding entries are rendered as a masked sentence, and the query is
    answered positively when the model's gendered prediction matches the gender
    set in the assignment.

    Args:
        assignment: Attribute bits followed by the two gender bits.
        unmasking_model: Fill-mask model forwarded to ``get_prediction``.
        binary_parser: Renders the attribute bits via ``binary_to_sentence``.

    Returns:
        bool: True if no gender bit is set or the prediction matches the gender,
        otherwise False.
    """
    attribute_bits, gender_bits = assignment[:-2], assignment[-2:]

    # The gender is exactly the masked token, so an assignment without any gender
    # bit can never be refuted by the model: answer positively.
    if 1 not in gender_bits:
        return True

    expected_words = ('she', 'woman', 'female') if gender_bits[0] == 1 else ('he', 'man', 'male')
    masked_sentence = binary_parser.binary_to_sentence(attribute_bits, is_masked=True)
    return get_prediction(unmasking_model, masked_sentence) in expected_words

def get_prediction(unmasking_model, sentence, gender_preferred=True, top_k=10):
    """
    Query the fill-mask model and return its answer for the masked token.

    The literal placeholder ``<mask>`` in ``sentence`` is replaced with the model's
    own ``tokenizer.mask_token``, so callers can use one placeholder for every model.

    Args:
        unmasking_model:  Fill-mask model. It must expose ``tokenizer.mask_token``, accept
                          ``(sentence, top_k=...)`` and return a list of dicts that
                          contain ``'token_str'``.
        sentence:         Sentence containing the ``<mask>`` placeholder.
        gender_preferred: If True, return the highest-ranked gendered word ('she',
                          'he', 'woman', 'man', 'female', 'male') among the top
                          ``top_k`` candidates, or '' if none appears. If False,
                          return the single most probable token.
        top_k:            Number of candidates requested from the model (default 10).

    Returns:
        str: The chosen token, whitespace-stripped and lower-cased, or '' when no
        applicable token was returned.
    """
    gendered_words = {'she', 'he', 'woman', 'man', 'female', 'male'}

    sentence = sentence.replace('<mask>', unmasking_model.tokenizer.mask_token)
    predictions = unmasking_model(sentence, top_k=top_k)
    # Candidates arrive best-first; normalise them once for comparison and return.
    tokens = [pred['token_str'].strip().lower() for pred in predictions]

    if not gender_preferred:
        # Previously this path always returned '', contradicting the documented contract.
        return tokens[0] if tokens else ''

    return next((token for token in tokens if token in gendered_words), '')

def learn_with_unmasking_model(args, config):
    """
    Learn a Horn envelope of a fill-mask language model's gender predictions.

    Args:
        args: Parsed CLI arguments. ``epsilon``, ``delta``, ``model_name`` and
            ``iterations`` are read here, and the whole object is forwarded to
            ``learn_horn_envelope``.
        config: Configuration passed to ``BinaryParser``.

    Returns:
        tuple: ``(runtime, metadata, (H, Q))``.
        - ``runtime``: wall-clock seconds spent in ``learn_horn_envelope``.
        - ``metadata``: the dict returned by ``learn_horn_envelope``, extended with
          ``'background'``, ``'h_rules'`` and ``'q_rules'`` (sympy ``srepr`` strings)
          and ``'rules'`` (the H clauses outside the background, parsed by
          ``EquationParser``).
    """
    binary_parser = BinaryParser(config, args.epsilon, args.delta)
    equation_parser = EquationParser(binary_parser)
    unmasking_model = pipeline('fill-mask', model=args.model_name)

    # Bind the model and parser so the learner can query the oracles with one argument.
    def ask_membership_oracle(assignment):
        return membership_oracle(assignment, unmasking_model, binary_parser)

    def ask_equivalence_oracle(hypothesis):
        return equivalence_oracle(hypothesis, unmasking_model, binary_parser)

    start = timeit.default_timer()
    H, Q, metadata = learn_horn_envelope(ask_membership_oracle, ask_equivalence_oracle, (binary_parser, equation_parser), args, num_iterations=args.iterations)
    runtime = timeit.default_timer() - start

    # Materialise the background once instead of rebuilding list(background) for
    # every clause of H, and compute the learned (non-background) clauses only once.
    background = list(binary_parser.background)
    learned = [clause for clause in H if clause not in background]

    metadata['background'] = [sympy.srepr(clause) for clause in background]
    metadata['h_rules'] = [sympy.srepr(clause) for clause in learned]
    metadata['q_rules'] = [sympy.srepr(clause) for clause in Q]
    metadata['rules'] = [equation_parser.parse(clause) for clause in learned]

    return (runtime, metadata, (H, Q))


if __name__ == '__main__':

    # Models this script supports. argparse rejects anything else with a usage error;
    # the previous `assert` check was silently skipped under `python -O`.
    supported_models = ['roberta-base', 'roberta-large', 'bert-base-cased', 'bert-large-cased']

    argparser = ArgumentParser()
    argparser.add_argument('--iterations', type=int, default=-1)
    argparser.add_argument('--epsilon', type=float, default=0.2)
    argparser.add_argument('--delta', type=float, default=0.1)
    argparser.add_argument('--verbose', action="store_true")
    argparser.add_argument('--is_bert', action="store_true")
    argparser.add_argument('--config_idx', type=int, default=0)
    argparser.add_argument('--id', type=int, default=-1)
    argparser.add_argument('--model_name', type=str, default='roberta-base', choices=supported_models)
    argparser.add_argument('--output_dir', type=str, default='results')
    args = argparser.parse_args()

    config = Config(**configs[args.config_idx])
    config.verbose = args.verbose

    print("This is a run with older models")
    if args.iterations >= 0:
        print(f"This will run up to {args.iterations} iterations")

    print(f'Model: {args.model_name}')

    (runtime, metadata, (H, Q)) = learn_with_unmasking_model(args, config)
    metadata['model name'] = args.model_name

    print('Sentences:      ', metadata['rules'])
    print('Runtime:        ', runtime)
    print('Terminated:     ', metadata['terminated'])
    print('Iterations:     ', metadata['iteration_time'])
    print('queries:        ', metadata['queries'])
    print('Errors:         ', metadata['errors'])
    print('Sample numbers: ', metadata['sample_numbers'])
    print('H:              ', metadata['h_rules'])
    print('Q:              ', metadata['q_rules'])

    # The output file is named after the SLURM job id, suffixed with --id when given.
    # Outside SLURM the id is None and is rendered as the string 'None'.
    job_id = os.environ.get('SLURM_JOB_ID')
    if args.id != -1:
        job_id = str(job_id) + '-' + str(args.id)
    out_dir = os.path.join(args.output_dir, metadata["model name"])
    os.makedirs(out_dir, exist_ok=True)
    # ensure_ascii=False can emit non-ASCII characters, so write UTF-8 explicitly
    # instead of relying on the platform's default encoding.
    with open(os.path.join(out_dir, f'{job_id}.json'), 'w', encoding='utf-8') as f:
        json.dump(metadata, f, ensure_ascii=False)
    

            

    