import itertools
import random
import torch
import re
from collections import defaultdict, Counter
import itertools
import math 

def get_value_positions(tokens, original_tokens):
    """Locate the runs of tokens in ``tokens`` that were filled into a template.

    ``tokens`` is walked left to right while a cursor advances through
    ``original_tokens`` (the tokenized, uninstantiated template). A token that
    occurs in the template at or after the cursor is treated as template text
    and moves the cursor just past that occurrence (greedy alignment); any
    other token is treated as part of a substituted value. Consecutive value
    tokens are grouped together.

    Args:
        tokens: Tokens of the instantiated text.
        original_tokens: Tokens of the template.

    Returns:
        A list of runs; each run is the list of indices into ``tokens`` of one
        contiguous stretch of value tokens, in order of appearance.
    """
    cursor = 0
    is_value_flags = []
    for token in tokens:
        try:
            # Jump just past the next template occurrence of this token.
            cursor = original_tokens.index(token, cursor) + 1
        except ValueError:
            # Not found in the rest of the template: part of a filled-in value.
            is_value_flags.append(True)
        else:
            is_value_flags.append(False)

    runs = itertools.groupby(enumerate(is_value_flags), key=lambda pair: pair[1])
    return [[position for position, _ in run] for is_value, run in runs if is_value]


def init_example_generator(model, templates, value_dictionary, batch_size=1, shuffle=False):
    """Build an indexable generator of template-instantiated examples.

    Every template is paired with every combination of attribute values in
    which all attributes take pairwise-distinct values, giving
    ``n_examples = len(templates) * len(value_combinations)`` examples. Each
    example can optionally be paired with a counterexample that differs either
    in one attribute value or in the template's attribute order.

    Args:
        model: Object exposing ``model.tokenizer.tokenize(text)``; used to find
            the token positions of the substituted values.
        templates: Strings with ``{attribute}`` placeholders; every placeholder
            name must be a key of ``value_dictionary``.
        value_dictionary: Maps attribute name -> list of candidate values.
            Duplicate values are ignored when building combinations.
        batch_size: Number of consecutive examples returned per call.
        shuffle: If True, shuffle the value combinations and the example index
            order using the global ``random`` module.

    Returns:
        ``(get_example, n_examples)``; see ``get_example`` for the call contract.
    """
    placeholder = re.compile(r'\{([^}]*)\}')

    def name_order_pattern(order):
        # Canonical letter pattern of an attribute order,
        # e.g. ('io', 'subject', 'subject') -> 'abb'.
        letters = {v: chr(97 + i) for i, v in enumerate(dict.fromkeys(order))}
        return ''.join(letters[x] for x in order)

    # Attribute order (placeholder names, left to right) of every template.
    attribute_orders = [tuple(placeholder.findall(t)) for t in templates]
    order_patterns = {order: name_order_pattern(order) for order in set(attribute_orders)}

    # All combinations taking one value per attribute with no value repeated
    # across attributes, in itertools.product order. Values are deduplicated
    # (first occurrence kept) so repeated list entries do not yield duplicate
    # combinations.
    attributes = list(value_dictionary.keys())
    distinct_values = [list(dict.fromkeys(value_dictionary[attr])) for attr in attributes]
    value_combinations = [
        comb for comb in itertools.product(*distinct_values)
        if len(set(comb)) == len(comb)
    ]
    n_examples = len(value_combinations) * len(templates)

    # counter_templates[pattern] lists indices of templates whose attribute
    # order differs from an order having that pattern (one entry per such
    # (template, order-occurrence) pair, so random.choice stays proportional).
    counter_templates = defaultdict(list)
    for i, attr_order in enumerate(attribute_orders):
        for counter_order in attribute_orders:
            if counter_order != attr_order:
                counter_templates[order_patterns[counter_order]].append(i)
    # Computed once: later lookups use .get() so they never add empty keys.
    order_counterexamples_available = bool(counter_templates)

    # Template indices grouped by order pattern, for {"order": pattern} overrides.
    templates_by_pattern = defaultdict(list)
    for i, order in enumerate(attribute_orders):
        templates_by_pattern[order_patterns[order]].append(i)

    if shuffle:
        random.shuffle(value_combinations)
        shuffled_indices = list(range(n_examples))
        random.shuffle(shuffled_indices)
        index_mapping = dict(enumerate(shuffled_indices))
    else:
        index_mapping = {i: i for i in range(n_examples)}

    # Tally of counter-attributes drawn at random. Bookkeeping only: the draw
    # itself is uniform and never consults these counts.
    counterexample_counts = Counter()

    def instantiate(template, combination, attribute_order, change_attr=None, new_value=None, overrides=None):
        """Fill ``template``'s placeholders and return ``(text, values)``.

        Values come from ``combination`` (aligned with ``value_dictionary``'s
        keys), updated by ``overrides``; ``change_attr`` (if given) takes
        ``new_value`` instead. ``values`` lists the value used for each entry of
        ``attribute_order``.
        """
        attr_value_map = dict(zip(attributes, combination))
        if overrides:
            attr_value_map.update(overrides)
        text = template
        values = []
        for attribute in attribute_order:
            value = new_value if attribute == change_attr else attr_value_map[attribute]
            # Replace only the first remaining occurrence; the loop visits every
            # placeholder occurrence in left-to-right order.
            text = text.replace(f"{{{attribute}}}", str(value), 1)
            values.append(value)
        return text, values

    def get_example(idx, get_counterexample=True, attr_val_dict=None, counter_attr_specified=None):
        """Return the example(s) starting at index ``idx``.

        Args:
            idx: Index of the first example; must be < ``n_examples``. The batch
                wraps around modulo ``n_examples``.
            get_counterexample: Whether to also build a counterexample per item.
            attr_val_dict: Optional overrides for the example itself. Each
                ``attribute: value`` item replaces that attribute's value; the
                special key ``"order"`` maps to an order pattern (e.g. ``"aba"``)
                and switches to a random template with that pattern. Supplying
                this (even empty) disables counterexamples.
            counter_attr_specified: Attribute to change for the counterexample,
                or ``"order"``; if falsy, one is drawn uniformly at random.

        Returns:
            A list of result dicts when more than one item is produced,
            otherwise a single dict.

        Raises:
            IndexError: ``idx >= n_examples`` or there are no examples.
            ValueError: The requested order pattern matches no template, or no
                template with a different order exists for an order
                counterexample.
        """
        if n_examples == 0 or idx >= n_examples:
            raise IndexError("Index out of range")

        overrides = {}
        requested_pattern = None
        if attr_val_dict is not None:
            get_counterexample = False
            overrides = dict(attr_val_dict)
            requested_pattern = overrides.pop("order", None)

        results = []
        for i in range(batch_size):
            mapped_idx = index_mapping[(idx + i) % n_examples]
            template_idx, combination_idx = divmod(mapped_idx, len(value_combinations))
            combination = value_combinations[combination_idx]

            if (requested_pattern is not None
                    and order_patterns[attribute_orders[template_idx]] != requested_pattern):
                candidates = templates_by_pattern.get(requested_pattern)
                if not candidates:
                    raise ValueError(f"No template has attribute order pattern {requested_pattern!r}")
                template_idx = random.choice(candidates)

            template = templates[template_idx]
            example_attribute_order = list(attribute_orders[template_idx])
            example_text, example_values = instantiate(
                template, combination, example_attribute_order, overrides=overrides
            )
            value_positions = get_value_positions(
                model.tokenizer.tokenize(example_text), model.tokenizer.tokenize(template)
            )

            if not get_counterexample:
                results.append({
                    "example": example_text,
                    "order": order_patterns[tuple(example_attribute_order)],
                    "value-positions": value_positions,
                    "values": example_values,
                    "attributes": example_attribute_order,
                    "counterexample": "",
                    "counter-order": "",
                    "counter-value-positions": [],
                    "counter-values": [],
                    "counter-attributes": [],
                    "counter-attribute": "",
                })
                continue

            if counter_attr_specified:
                counter_attr_choice = counter_attr_specified
            else:
                # dict.fromkeys keeps first-seen order, so the candidate list (and
                # the draw for a given random seed) is independent of hash seeds.
                choices = list(dict.fromkeys(example_attribute_order))
                if order_counterexamples_available:
                    choices = ["order"] + choices
                counter_attr_choice = random.choice(choices)
                counterexample_counts[counter_attr_choice] += 1

            if counter_attr_choice == "order":
                example_order = attribute_orders[template_idx]
                # Filter on the actual order so a template whose order merely
                # shares the example's pattern is never chosen.
                available = [
                    j for j in counter_templates.get(order_patterns[example_order], [])
                    if attribute_orders[j] != example_order
                ]
                if not available:
                    raise ValueError("No template with a different attribute order is available")
                counter_template = templates[random.choice(available)]
                counter_attribute_order = placeholder.findall(counter_template)
                counter_text, counter_values = instantiate(
                    counter_template, combination, counter_attribute_order
                )
            else:
                counter_template = template
                counter_attribute_order = example_attribute_order
                current_value = dict(zip(example_attribute_order, example_values))[counter_attr_choice]
                # NOTE: the replacement may coincide with another attribute's
                # value in this example; only the current value is excluded.
                alternatives = [v for v in value_dictionary[counter_attr_choice] if v != current_value]
                counter_value = random.choice(alternatives)
                counter_text, counter_values = instantiate(
                    counter_template,
                    combination,
                    counter_attribute_order,
                    change_attr=counter_attr_choice,
                    new_value=counter_value,
                )

            counter_value_positions = get_value_positions(
                model.tokenizer.tokenize(counter_text), model.tokenizer.tokenize(counter_template)
            )

            results.append({
                "example": example_text,
                "order": order_patterns[tuple(example_attribute_order)],
                "value-positions": value_positions,
                "values": example_values,
                "attributes": example_attribute_order,
                "counterexample": counter_text,
                "counter-order": order_patterns[tuple(counter_attribute_order)],
                "counter-value-positions": counter_value_positions,
                "counter-values": counter_values,
                "counter-attributes": counter_attribute_order,
                "counter-attribute": counter_attr_choice,
            })

        return results if len(results) > 1 else results[0]

    return get_example, n_examples



    
def get_induction_sequence(model, cross_entropy_threshhold=0.5, seq_length=5, test_targets=10):
    """Sample a random token string on which the model performs induction well.

    Each attempt draws ``seq_length`` random vocabulary ids, decodes them, joins
    them into a string with a single leading space and re-encodes it. Attempts
    that do not re-encode to exactly ``seq_length`` tokens are rejected without
    running the model. Otherwise ``test_targets`` random vocabulary tokens are
    drawn and, for each target, the model is run on ``seq + target + seq``; the
    cross-entropy (natural log) of predicting the target's first token at the
    last position is averaged. The first string whose mean cross-entropy is
    strictly below ``cross_entropy_threshhold`` is printed and returned.

    Args:
        model: Object with ``cfg.d_vocab``, a ``tokenizer`` offering ``decode``,
            ``convert_tokens_to_string`` and ``encode(..., add_special_tokens=False)``,
            and callable on a ``(1, seq)`` token tensor, returning logits indexed
            as ``[batch, position, vocab]``.
        cross_entropy_threshhold: Acceptance bound on the mean cross-entropy.
        seq_length: Required token length of the returned string.
        test_targets: Number of random target tokens evaluated per attempt.

    Returns:
        The accepted string (it begins with a space).

    Note:
        Loops until a string is accepted; there is no attempt limit.
    """
    d_vocab = model.cfg.d_vocab
    tokenizer = model.tokenizer

    while True:
        sampled_ids = torch.randint(d_vocab, (seq_length,)).tolist()
        induction_words = [tokenizer.decode(token_id) for token_id in sampled_ids]
        induction_string = " " + tokenizer.convert_tokens_to_string(induction_words).strip()
        induction_tokens = tokenizer.encode(induction_string, add_special_tokens=False)
        if len(induction_tokens) != seq_length:
            # Decoding/re-encoding merged or split tokens; no point scoring it.
            continue

        # Decode each sampled id on its own: decoding the whole list at once
        # yields one string whose characters (not tokens) would be iterated.
        induction_targets = [
            tokenizer.encode(" " + tokenizer.decode(token_id).strip(), add_special_tokens=False)
            for token_id in random.sample(range(d_vocab), test_targets)
        ]
        induction_targets = [target for target in induction_targets if target]
        if not induction_targets:
            continue

        losses = []
        with torch.no_grad():
            for target in induction_targets:
                sequence = torch.tensor(induction_tokens + target + induction_tokens).unsqueeze(0)
                logits = model(sequence)
                # log_softmax avoids log(0) = inf when a probability underflows.
                log_probs = torch.log_softmax(logits[0, -1, :], dim=-1)
                losses.append(-log_probs[target[0]].item())

        mean_cross_entropy = sum(losses) / len(losses)
        if mean_cross_entropy < cross_entropy_threshhold:
            print(induction_string)
            return induction_string
        

def evaluate_example_generator(model, example_generator, target_attr, k, value_target=True, target_value=None):
    """Return the model's top-1 next-token accuracy over ``k`` generator calls.

    For each ``i`` in ``range(k)`` the batch ``example_generator(i)`` is
    tokenized, the model is run once, and the argmax over the final position's
    logits is compared with the expected answer.

    Args:
        model: Object with ``tokenizer.encode(text, return_tensors="pt")``,
            ``tokenizer.decode(token_id)`` and a call returning logits indexed
            as ``[batch, position, vocab]``.
        example_generator: Callable ``i -> dict | list[dict]`` yielding result
            dicts with ``"example"``, ``"values"`` and ``"attributes"`` keys. A
            single dict is treated as a batch of one.
        target_attr: Attribute whose value is the expected answer when
            ``value_target`` is True.
        k: Number of generator calls; must be positive.
        value_target: If True, compare the decoded, stripped prediction with the
            value of ``target_attr``; otherwise compare the predicted token id
            with ``target_value``.
        target_value: Expected token id when ``value_target`` is False.

    Returns:
        Fraction of correct predictions.

    Raises:
        ValueError: If ``k`` is not positive.
    """
    if k <= 0:
        raise ValueError("k must be a positive number of batches")

    correct = []
    for i in range(k):
        batch = example_generator(i)
        if isinstance(batch, dict):
            # init_example_generator returns a bare dict when batch_size == 1.
            batch = [batch]
        # torch.cat along dim 0 needs the encoded examples to have equal length.
        token_batch = torch.cat(
            [model.tokenizer.encode(example["example"], return_tensors="pt") for example in batch],
            dim=0,
        )
        with torch.no_grad():
            predicted_ids = model(token_batch)[:, -1].argmax(-1)

        if value_target:
            predictions = [model.tokenizer.decode(token_id).strip() for token_id in predicted_ids]
            targets = [example["values"][example["attributes"].index(target_attr)] for example in batch]
        else:
            predictions = predicted_ids
            targets = [target_value] * len(batch)

        correct.extend(int(target == prediction) for target, prediction in zip(targets, predictions))

    return sum(correct) / len(correct)


def init_ioi(model, features, batch_size, train=True):
    """Build the example generator for IOI-style (indirect object) prompts.

    Both splits fill ``{subject}`` and ``{io}`` from the same ``features`` list
    and shuffle the examples. The training split uses four templates (two
    scenes, each with both name orders); the evaluation split uses two
    "cafe" templates that do not appear in the training split.

    Returns:
        ``(get_example, n_examples)`` as produced by ``init_example_generator``.
    """
    training_templates = [
        "Then, {io} and {subject} had a long argument. {subject} gave a drink to",
        "Then, {io} and {subject} went to the store. {subject} gave an apple to",
        "Then, {subject} and {io} had a long argument. {subject} gave a drink to",
        "Then, {subject} and {io} went to the store. {subject} gave an apple to",
    ]
    evaluation_templates = [
        "Then, {io} and {subject} went to the cafe. {subject} gave the cake to",
        "Then, {subject} and {io} went to the cafe. {subject} gave the cake to",
    ]

    example_fn, example_count = init_example_generator(
        model,
        templates=training_templates if train else evaluation_templates,
        value_dictionary={"subject": features, "io": features},
        batch_size=batch_size,
        shuffle=True,
    )
    return example_fn, example_count


def init_ind(model, features, batch_size, cross_entropy_threshhold=0.5, train=True, seq_length=3):
    """Build the example generator for induction-style prompts.

    Each template is ``prefix + <pattern> + prefix + <pattern start>``, where
    ``prefix`` comes from ``get_induction_sequence`` and the two patterns are
    ``" {ind2} , {ind1} , {ind2} , {ind1}"`` and
    ``" {ind1} , {ind1} , {ind2} , {ind2}"``. Training draws two prefixes
    (four templates); evaluation draws one fresh prefix (two templates).
    ``{ind1}``/``{ind2}`` are filled from ``features`` and examples are
    shuffled.

    Returns:
        ``(get_example, n_examples)`` as produced by ``init_example_generator``.
    """
    def sample_prefix():
        # Prefixes are made of arbitrary vocabulary tokens. Braces are reserved
        # for placeholders: a literal "{" would be parsed by the r'\{([^}]*)\}'
        # placeholder regex as (part of) an attribute name, so redraw.
        while True:
            prefix = get_induction_sequence(
                model, seq_length=seq_length, cross_entropy_threshhold=cross_entropy_threshhold
            )
            if "{" not in prefix and "}" not in prefix:
                return prefix

    def templates_for(prefix):
        # [alternating ind2/ind1 pattern, paired ind1,ind1/ind2,ind2 pattern]
        return [
            prefix + " {ind2} , {ind1} , {ind2} , {ind1}" + prefix + " {ind2} , {ind1} ,",
            prefix + " {ind1} , {ind1} , {ind2} , {ind2}" + prefix + " {ind1} , {ind1} ,",
        ]

    if train:
        first = templates_for(sample_prefix())
        second = templates_for(sample_prefix())
        # Both prefixes for the first pattern, then both for the second.
        templates = [first[0], second[0], first[1], second[1]]
    else:
        templates = templates_for(sample_prefix())

    get_ind_examples, n_ind_examples = init_example_generator(
        model,
        templates=templates,
        value_dictionary={"ind1": features, "ind2": features},
        batch_size=batch_size,
        shuffle=True,
    )
    return get_ind_examples, n_ind_examples
