import timeit
import copy
import torch
import functools
import json
import sympy
import os
import re

from typing import Tuple, Set, List, Union
from argparse import ArgumentParser
from transformers import AutoModelForCausalLM, AutoTokenizer

from Horn import learn_horn_envelope, evaluate
from Parsers import BinaryParser, EquationParser
from config import configs, Config
from masked_model import get_attribute_vector

def construct_prompt(binary_parser, hypothesis_str, past_conversation):
    """Build the chat-message list for an equivalence query (EQ).

    A deep copy of ``binary_parser.config.EQ`` is used as the base so the
    shared template is never modified.

    Args:
        binary_parser: Object exposing ``config.EQ`` (a list of chat-message
            dicts whose last ``content`` contains ``{hypo}`` and
            ``{systems_prompt}`` placeholders) and ``templates`` (a mapping
            with ``'systems_prompt'`` and ``'base_hypothesis'`` entries).
        hypothesis_str: Textual rendering of the current hypothesis.
        past_conversation: Earlier chat messages (list of dicts). When empty
            (or ``None``) the template is filled with ``hypothesis_str``
            directly. Otherwise the template is filled with the base
            hypothesis, the history is appended, and the content of the final
            appended message is replaced with the updated hypothesis.

    Returns:
        A new list of chat-message dicts. The caller's ``past_conversation``
        is left untouched.
    """
    prompt = copy.deepcopy(binary_parser.config.EQ)
    systems_prompt = binary_parser.templates['systems_prompt']

    # First query: there is no history, so the hypothesis goes straight into the template.
    if not past_conversation:
        prompt[-1]['content'] = prompt[-1]['content'].format(hypo=hypothesis_str, systems_prompt=systems_prompt)
        return prompt

    prompt[-1]['content'] = prompt[-1]['content'].format(hypo=binary_parser.templates['base_hypothesis'], systems_prompt=systems_prompt)

    # Copy the history before editing it: the last message is overwritten
    # below and the caller's dicts must not change as a side effect.
    prompt.extend(copy.deepcopy(past_conversation))
    prompt[-1]['content'] = f"The updated hypothesis is the following: {hypothesis_str}\n\n Please provide another counterexample to my hypothesis if possible.\n\n"

    return prompt
    
def membership_oracle(counterexample:List[int], model:AutoModelForCausalLM, tokenizer:AutoTokenizer, binary_parser:BinaryParser, device:str) -> Union[bool, str]:
    """Ask the model whether a single assignment is possible (membership query).

    Args:
        counterexample: Binary assignment, rendered into text through
            ``binary_parser.binary_to_MQ``.
        model: Causal language model used to answer the query.
        tokenizer: Tokenizer belonging to ``model``.
        binary_parser: Supplies the ``config.MQ`` prompt template, the
            ``reasoning`` flag and ``config.verbose``.
        device: Forwarded to ``generate_response``.

    Returns:
        ``True`` when the reply contains "it is possible", ``False`` when it
        contains "it is not possible" (both case-insensitive). Any other
        reply is reported on stdout and returned unchanged as a string.
    """
    # Fill the last message of a private copy of the MQ template.
    mq_messages = copy.deepcopy(binary_parser.config.MQ)
    final_message = mq_messages[-1]
    final_message['content'] = final_message['content'].format(MQ_example=binary_parser.binary_to_MQ(counterexample))

    prompt, response = generate_response(model, tokenizer, mq_messages, device)

    # For reasoning models keep only what follows the last closing think tag;
    # rpartition leaves the reply untouched when the tag is absent.
    if binary_parser.reasoning:
        response = response.rpartition("</think>")[2]

    if binary_parser.config.verbose:
        for chunk in (
            "\n============= MQ ===============\n",
            prompt,
            "\n--------------------------------\n",
            response,
            "\n================================\n",
        ):
            print(chunk)

    normalized = response.lower()
    if 'it is possible' in normalized:
        return True
    if 'it is not possible' in normalized:
        return False

    # Neither verdict phrase was found: hand the raw reply back to the caller.
    print(f"NOT VALID MQ:{response}")
    return response

def generate_response(model:AutoModelForCausalLM, tokenizer:AutoTokenizer, machine_prompt:List[dict], device:str) -> Tuple[str, str]:
    """Run one chat completion and return the decoded prompt and reply.

    Args:
        model: Causal language model; inputs are moved to ``model.device``.
        tokenizer: Tokenizer whose chat template is applied to the messages.
        machine_prompt: Chat messages (list of ``{'role', 'content'}`` dicts).
        device: Unused here; kept so existing callers keep working.

    Returns:
        ``(prompt, response)``: the decoded input tokens and the decoded newly
        generated tokens. Special tokens are not stripped from either string.
    """
    # add_generation_prompt=True appends the tokens that open an assistant
    # turn, so the model answers the conversation instead of having to start
    # (or guess) the turn header itself.
    tokenized_prompt = tokenizer.apply_chat_template(
        machine_prompt,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)
    prompt_length = tokenized_prompt.shape[1]

    # Inference only: no gradients are needed.
    with torch.no_grad():
        output = model.generate(tokenized_prompt, max_new_tokens=10000)

    # The generated sequence starts with the prompt tokens; split at that boundary.
    prompt = tokenizer.decode(output[0, :prompt_length])
    response = tokenizer.decode(output[0, prompt_length:])

    return prompt, response

def equivalence_oracle(hypothesis:Tuple[Set, Set, List[List[int]]], model:AutoModelForCausalLM, tokenizer:AutoTokenizer, binary_parser:BinaryParser, equation_parser:EquationParser, device:str):
    """Ask the language model for a counterexample to the current hypothesis.

    The hypothesis is rendered as text and sent together with the previously
    returned counterexamples. The reply is parsed and validated; each invalid
    reply is appended to the conversation with a corrective user message and
    the query is retried until ``binary_parser.config.error_threshold``
    errors have accumulated.

    Args:
        hypothesis: ``(H, Q, previous_counterexamples)``. ``H`` and ``Q`` are
            sets of formulas that are combined with ``&``;
            ``previous_counterexamples`` is iterated as
            ``(assignment, is_positive)`` pairs.
        model: Causal language model answering the query.
        tokenizer: Tokenizer belonging to ``model``.
        binary_parser: Supplies prompt templates, parsing helpers, the
            ``reasoning`` / ``summary`` flags and the ``config`` object.
        equation_parser: Not used by this function; kept for the existing
            call signature.
        device: Forwarded to ``generate_response``.

    Returns:
        A 4-tuple ``(result, iteration_count, is_positive, errors)``:

        * ``(True, n, True, errors)`` when the model replied with the stop
          phrase and no parsable counterexample;
        * ``(counterexample_binary, n, is_positive, errors)`` when a valid,
          new counterexample was obtained;
        * ``(False, n, False, errors)`` when the error budget ran out.

        ``errors`` counts the retries per error category.
    """
    (H, Q, previous_counterexamples) = hypothesis
    # Conjunction of all clauses; candidate counterexamples are evaluated against it.
    h_union_q = functools.reduce(lambda x,y: x & y, H.union(Q))

    # Rebuild conversation history
    past_conversation = []
    if previous_counterexamples:
        past_conversation.append({'role': 'assistant', 'content': 'Here are counterexamples to the hypothesis:\n\n'})
        for (past_counterexample, example_is_positive) in previous_counterexamples:
            counterexample_string = binary_parser.binary_to_sentence(past_counterexample)
            sign = " is possible" if example_is_positive else " is not possible"
            past_conversation[-1]['content'] = past_conversation[-1]['content'] + f'{counterexample_string + sign}.\n'

        # Placeholder user turn; its content is supplied by construct_prompt.
        past_conversation.append({'role': 'user', 'content': ''})

        # Deconstruct current hypothesis (background clauses are not shown to the model)
        not_clauses = []
        imply_clauses = []
        for clause in [clause for clause in H if clause not in binary_parser.background]:

            # A => B
            if isinstance(clause, sympy.Implies):
                antecedent, consequent = clause.args
                # A bare symbol has no args; treat it as a one-element antecedent.
                ant = antecedent.args if not(len(antecedent.args) == 0) else tuple([antecedent])
                # Drop atoms from the consequent that already appear in the antecedent.
                consequent = sympy.And(*list(set(consequent.args).difference(set(ant))))

                vector_ant = [0] * binary_parser.total_length
                for atom in antecedent.atoms():
                    vector_ant[binary_parser.V.index(atom)] = 1

                vector_con = [0] * binary_parser.total_length
                for atom in consequent.atoms():
                    vector_con[binary_parser.V.index(atom)] = 1
                imply_clauses.append(f'{binary_parser.binary_to_sentence(vector_ant, is_antecedent=True)}{binary_parser.binary_to_sentence(vector_con, is_consequent=True)},')

            # NOT(A and B)
            else:
                vector = [0] * binary_parser.total_length
                for atom in clause.atoms():
                    vector[binary_parser.V.index(atom)] = 1
                not_clauses.append(f'{str(binary_parser.binary_to_sentence(vector)).replace(".", "")} is not possible')

        # Build hypothesis string representation
        hypothesis_str = binary_parser.templates['updated_hypothesis'].format(not_clauses='; '.join(not_clauses), imply_clauses=','.join(imply_clauses)) + '.'
    else:
        hypothesis_str = binary_parser.templates['base_hypothesis'] + '.'

    # Build EQ prompt
    base_prompt = construct_prompt(binary_parser, hypothesis_str, past_conversation)
    iteration_count = 0
    errors = {'format':0, 'values':0, 'duplicate': 0, 'validity': 0, 'thinking':0, 'summary format': 0, 'summary values': 0}

    if binary_parser.config.verbose:
        print("\n============== EQ ================\n")

    while iteration_count < binary_parser.config.error_threshold:

        prompt, response = generate_response(model, tokenizer, base_prompt, device)

        # Thinking check
        if binary_parser.reasoning:
            if '</think>' not in response:
                # The reasoning never closed; record the error and ask again.
                errors['thinking'] += 1
                iteration_count += 1
                base_prompt.append({'role': 'assistant', 'content': "..."})
                base_prompt.append({'role':'user', 'content': f"You thought for too long. Please provide a counterexample to the hypothesis: {hypothesis_str}. \nIf you think that my hypothesis accurately and completely describes the real world then reply with {binary_parser.templates['stop_prompt']}."})
                continue
            # Keep only the text after the last closing think tag.
            response = response[response.rfind("</think>")+len("</think>"):]
        counterexample_binary, response_valid = binary_parser.sentence_to_binary(response)

        if binary_parser.config.verbose:
            print(prompt)
            print("\n================================\n")
            print(response)
            print("\n================================\n")

        # Termination check
        if binary_parser.templates['stop_prompt'].lower() in response.lower():
            # Sometimes the model returns a counterexample along with the stop signal. We check if the response contains a valid counterexample.
            if not response_valid:
                print("\n================================\n")
                print("======== STOP =========")
                print("\n================================\n")
                return (True, iteration_count, True, errors)

        # Smaller models may need a summary step to reduce hallucinations
        if binary_parser.summary:
            summary_prompt = copy.deepcopy(binary_parser.config.string_templates['summary_prompt'])
            summary_prompt[0]['content'] = summary_prompt[0]['content'].format(response=response)

            if binary_parser.config.verbose:
                print("\n============== EQ SUMMARY ================\n")
            summary_flag = False

            while iteration_count < binary_parser.config.error_threshold:
                # The decoded prompt text goes into its own variable:
                # summary_prompt must stay a list of chat messages, because
                # the thinking-retry branch below appends to it and the next
                # iteration feeds it back into generate_response.
                summary_prompt_text, summary_response = generate_response(model, tokenizer, summary_prompt, device)
                # Thinking check
                if binary_parser.reasoning:
                    if '</think>' not in summary_response:
                        errors['thinking'] += 1
                        iteration_count += 1
                        summary_prompt.append({'role': 'assistant', 'content': "..."})
                        summary_prompt.append({'role':'user', 'content': f"You thought for too long. Please provide a counterexample to the hypothesis: {hypothesis_str}. \nIf you think that my hypothesis accurately and completely describes the real world then reply with {binary_parser.templates['stop_prompt']}."})
                        continue
                    summary_response = summary_response[summary_response.rfind("</think>")+len("</think>"):]
                summary_binary, summary_valid = binary_parser.sentence_to_binary(summary_response)
                summary_parsed = binary_parser.binary_to_sentence(summary_binary)

                if binary_parser.config.verbose:
                    print(summary_prompt_text)
                    print("\n--------------------------------\n")
                    print(summary_response)
                    print("\n--------------------------------\n")
                    print(summary_parsed)
                    print("\n================================\n")

                # Summary values check: the summary must parse into a valid assignment.
                if not summary_valid:
                    summary_prompt = copy.deepcopy(binary_parser.config.string_templates['summary_hallucinate'])
                    summary_prompt[0]['content'] = summary_prompt[0]['content'].format(response=response, summary_response=summary_response)

                    if binary_parser.config.verbose:
                        print("\n============== EQ SUMMARY VALUES ERROR ================\n")
                    iteration_count += 1
                    errors['summary values'] += 1
                    continue

                # Summary format check: the summary must state 'can' or 'cannot'.
                if (not (re.search(r'\b' + 'cannot' + r'\b', summary_response.lower()) or re.search(r'\b' + 'can' + r'\b', summary_response.lower()))):

                    summary_prompt = copy.deepcopy(binary_parser.config.string_templates['summary_format'])
                    summary_prompt[0]['content'] = summary_prompt[0]['content'].format(response=response, summary_response=summary_response)

                    if binary_parser.config.verbose:
                        print("\n============== EQ SUMMARY FORMAT ERROR ================\n")
                    iteration_count += 1
                    errors['summary format'] += 1
                    continue

                # From here on the summary replaces the raw reply; its
                # polarity is determined below together with the non-summary path.
                summary_flag = True
                response = summary_response
                break

            # If we do not get a summary, then we have hit the error threshold and we terminate.
            if not summary_flag:
                print("======== NO SUMMARY ========")
                return (False, iteration_count, False, errors)

        # Valid format check
        lowered_response = response.lower()
        has_can = re.search(r'\b' + 'can' + r'\b', lowered_response)
        has_cannot = re.search(r'\b' + 'cannot' + r'\b', lowered_response)
        has_is_possible = re.search(r'\b' + 'is possible' + r'\b', lowered_response)
        has_is_not_possible = re.search(r'\b' + 'is not possible' + r'\b', lowered_response)
        if (not (has_cannot or has_can or has_is_not_possible or has_is_possible)):
            # Update base prompt to reflect error
            base_prompt.append({'role': 'assistant', 'content': f"{response}\n"})
            base_prompt.append({'role': 'user', 'content': f"Your counterexample did not contain 'can' or 'cannot'. Please provide a counterexample to the hypothesis: {hypothesis_str}. \nIf you think that my hypothesis accurately and completely describes the real world then reply with {binary_parser.templates['stop_prompt']}."})

            if binary_parser.config.verbose:
                print("\n============== EQ INVALID FORMAT ================\n")
            iteration_count += 1
            errors['format'] += 1
            continue

        # Polarity of the counterexample. 'can' marks a positive example, as
        # before. A reply phrased as "... is possible" is accepted by the
        # format check above, so it must count as positive too, unless an
        # explicit negative marker is present in the same reply.
        is_positive_counterexample = bool(has_can) or (bool(has_is_possible) and not (has_cannot or has_is_not_possible))
        # Re-parse: in summary mode `response` now holds the summary text.
        counterexample_binary, response_valid = binary_parser.sentence_to_binary(response)
        if not response_valid:
            # Update base prompt to reflect error
            base_prompt.append({'role': 'assistant', 'content': f"{response}\n"})
            base_prompt.append({'role': 'user', 'content': f"Your counterexample contained invalid combination of values. You may only use one value from each category in the counterexample. Please provide a counterexample in the format '{binary_parser.config.string_templates['reply_format']}' that only contains the allowed attribute values. If you think that my hypothesis accurately and completely describes the real world then reply with {binary_parser.templates['stop_prompt']}.\n#COUNTEREXAMPLE"})

            if binary_parser.config.verbose:
                print("\n============== EQ INVALID VALUES ================\n")
            iteration_count += 1
            errors['values'] += 1
            continue

        # Duplicate check
        if (counterexample_binary, is_positive_counterexample) in previous_counterexamples:
            # Update base prompt to reflect error
            base_prompt.append({'role': 'assistant', 'content': f"{response}\n"})
            base_prompt.append({'role': 'user', 'content': f"The counterexample in the assistant reply is a duplicate counterexample. Provide one counterexample in the format '{binary_parser.config.string_templates['reply_format']}' that is not a duplicate of a previously given counterexample. If you think that my hypothesis accurately and completely describes the real world then reply with {binary_parser.templates['stop_prompt']}. \n#COUNTEREXAMPLE"})

            if binary_parser.config.verbose:
                print("\n============== EQ DUPLICATE ================\n")
            iteration_count += 1
            errors['duplicate'] += 1
            continue

        # Valid counterexample check: a real counterexample must disagree with the hypothesis.
        hypothesis_eval = evaluate(h_union_q, counterexample_binary, binary_parser.V)
        if is_positive_counterexample == hypothesis_eval:
            # Update base prompt to reflect error
            base_prompt.append({'role': 'assistant', 'content': f"\n<Counterexample>{response}</Counterexample>\n"})
            if is_positive_counterexample:
                base_prompt.append({'role': 'user', 'content': f"The Counterexample is not an actual counterexample since it is possible both the real world and the current hypothesis. Provide one valid counterexample in the format '{binary_parser.config.string_templates['reply_format']}'.\n"})
            else:
                base_prompt.append({'role': 'user', 'content': f"The Counterexample is not an actual counterexample since it is not possible both the real world and the hypothesis. Provide one valid counterexample in the format '{binary_parser.config.string_templates['reply_format']}'.\n"})

            if binary_parser.config.verbose:
                print("\n============== EQ VALIDATION FAIL ================\n")
            iteration_count += 1
            errors['validity'] += 1
            continue

        # Counterexample found
        return (counterexample_binary, iteration_count, is_positive_counterexample, errors)

    # Counterexample not found within error threshold
    return (False, iteration_count, False, errors)
    

def equivalence_oracle_masked(hypothesis:Tuple[Set, Set, List, List], 
                              model:AutoModelForCausalLM, 
                              tokenizer:AutoTokenizer, 
                              binary_parser:BinaryParser, 
                              device:str) -> Union[bool, Tuple[List[int], int, bool, List[str]]]:
    """Sampling-based equivalence check using membership queries.

    Up to ``binary_parser.pac_hypothesis_space`` sample assignments are drawn
    (one ``get_attribute_vector`` segment per feature, concatenated). Each
    sample is sent to ``membership_oracle`` and the verdict is compared with
    the value of the hypothesis on that sample.

    Args:
        hypothesis: Unpacked as ``(H, Q, previous_counterexamples)``; only
            ``H`` and ``Q`` are used.
        model: Causal language model answering the membership queries.
        tokenizer: Tokenizer belonging to ``model``.
        binary_parser: Supplies ``features``, ``V`` and the sample budget.
        device: Forwarded to ``membership_oracle``.

    Returns:
        ``(sample_vector, sample_index, model_verdict, {})`` for the first
        sample on which model and hypothesis disagree, otherwise
        ``(True, binary_parser.pac_hypothesis_space, True, {})``.
    """
    H, Q, _previous_counterexamples = hypothesis
    conjunction = functools.reduce(lambda left, right: left & right, H.union(Q))

    for attempt in range(binary_parser.pac_hypothesis_space):
        # Assemble one candidate assignment, feature by feature.
        candidate = []
        for values in binary_parser.features.values():
            candidate.extend(get_attribute_vector(len(values), allow_zero=True))

        reply = membership_oracle(candidate, model, tokenizer, binary_parser, device)

        # A string reply means the model's answer could not be interpreted; skip this sample.
        if isinstance(reply, str):
            continue

        verdict = bool(reply)
        if verdict != evaluate(conjunction, candidate, binary_parser.V):
            return (candidate, attempt, verdict, {})

    # No counterexamples were found, hypothesis is considered equivalent in PAC-learning
    return (True, binary_parser.pac_hypothesis_space, True, {})

def learn_with_modern_model(args, device, config):
    """Load the model, wire up the oracles and run the Horn-envelope learner.

    Args:
        args: Parsed command-line namespace; reads ``model_path``,
            ``epsilon``, ``delta``, ``reasoning_model``, ``summary``,
            ``masked`` and ``iterations`` (and is passed on to
            ``learn_horn_envelope``).
        device: Forwarded to the oracle functions.
        config: Configuration object handed to ``BinaryParser``.

    Returns:
        ``(runtime, metadata, (H, Q))`` where ``runtime`` is the time in
        seconds spent inside ``learn_horn_envelope`` and ``metadata`` is the
        learner's metadata extended with printable rules and the feature map.
    """
    # Load model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    tokenizer.pad_token_id = tokenizer.eos_token_id
    # NOTE(review): trust_remote_code=True executes code shipped with the
    # checkpoint; only point --model_path at trusted models.
    load_options = dict(
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        pad_token_id=tokenizer.eos_token_id,
        device_map="auto",
    )
    model = AutoModelForCausalLM.from_pretrained(args.model_path, **load_options)
    model.eval()

    # Parsers; the optional flags are only switched on, never explicitly off.
    binary_parser = BinaryParser(config, args.epsilon, args.delta)
    if args.reasoning_model:
        binary_parser.reasoning = True
    if args.summary:
        binary_parser.summary = True
    equation_parser = EquationParser(binary_parser)

    # Oracles as closures over the loaded model and parsers
    def ask_membership_oracle(assignment):
        return membership_oracle(assignment, model, tokenizer, binary_parser, device)

    if args.masked:
        def ask_equivalence_oracle(hypothesis):
            return equivalence_oracle_masked(hypothesis, model, tokenizer, binary_parser, device)
    else:
        def ask_equivalence_oracle(hypothesis):
            return equivalence_oracle(hypothesis, model, tokenizer, binary_parser, equation_parser, device)

    # Run and time the learning algorithm
    started_at = timeit.default_timer()
    H, Q, metadata = learn_horn_envelope(ask_membership_oracle, ask_equivalence_oracle, (binary_parser, equation_parser), args, num_iterations=args.iterations)
    runtime = timeit.default_timer() - started_at

    # Background clauses are excluded from the reported hypothesis.
    learned_clauses = [clause for clause in H if clause not in binary_parser.background]
    metadata['h_rules'] = [sympy.pretty(clause, use_unicode=False) for clause in learned_clauses]
    metadata['q_rules'] = [sympy.pretty(clause, use_unicode=False) for clause in Q]
    metadata['rules'] = [equation_parser.parse(clause) for clause in learned_clauses]
    metadata['features'] = binary_parser.features

    return (runtime, metadata, (H, Q))
    
if __name__ == '__main__':
    # Command-line entry point: parse options, run the learner, print a
    # summary and write the metadata to
    # <output_dir>/<model name>/config<config_idx>/<job id>.json

    argparser = ArgumentParser()
    argparser.add_argument('--iterations', type=int, default=-1)
    argparser.add_argument('--masked', action="store_true")
    argparser.add_argument('--summary', action="store_true") 
    argparser.add_argument('--model_path', type=str, required=True)
    argparser.add_argument('--verbose', action="store_true")
    argparser.add_argument('--epsilon', type=float, default=0.2)
    argparser.add_argument('--delta', type=float, default=0.1)
    argparser.add_argument('--id', type=str, default='0')
    argparser.add_argument('--reasoning_model', action="store_true")
    argparser.add_argument('--config_idx', type=int, default=0)
    argparser.add_argument('--output_dir', type=str, default='results')
    args = argparser.parse_args()

    if args.masked:
        print("This is a sampling run")
    else:
        print("This is a non-sampling run.")

    if args.iterations >= 0:
        print(f"This will run up to {args.iterations} times")
    print('Using config: ', args.config_idx)
    print('The model is: ', args.model_path)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    config = Config(**configs[args.config_idx])
    config.verbose = args.verbose
    (runtime, metadata, (H, Q)) = learn_with_modern_model(args, device, config)
    metadata['runtime'] = runtime

    print('Sentences:  ', metadata['rules'])
    print('Runtime:    ', runtime)
    print('Terminated: ', metadata['terminated'])
    print('Iterations: ', metadata['iteration_time'])
    print('queries:    ', metadata['queries'])
    print('Errors:     ', metadata['errors'])

    # Strip trailing slashes first so a path like 'models/foo/' still yields 'foo'.
    metadata['model name'] = args.model_path.rstrip('/').split('/')[-1]

    # Prefer the SLURM job id; outside SLURM fall back to --id so the result
    # file is not named 'None.json'.
    job_id = os.environ.get('SLURM_JOB_ID') or args.id

    result_dir = f'{args.output_dir}/{metadata["model name"]}/config{args.config_idx}'
    os.makedirs(result_dir, exist_ok=True)
    # ensure_ascii=False can emit non-ASCII characters, so fix the file
    # encoding instead of relying on the platform default.
    with open(f'{result_dir}/{job_id}.json', 'w', encoding='utf-8') as f:
        json.dump(metadata, f, ensure_ascii=False)