"""
Data Generator for the Induction Heads Task 

This script provides a configurable function to generate datasets for the
Induction Heads task across a full spectrum of systematically increasing 
difficulty levels. This allows for a rigorous, diagnostic evaluation of 
sequence models, from basic recall to autonomous, dynamic learning.
"""

import torch
import random
import numpy as np

def _noise(pool, max_noise_len, rng):
    """Return a run of 1..max_noise_len tokens drawn uniformly, with replacement, from pool."""
    return [rng.choice(pool) for _ in range(rng.randint(1, max_noise_len))]


def _prefixed_sequence(pairs, pool, prefix_token, difficulty_level, max_noise_len, rng):
    """Build a level 0-3 context followed by its query; return (sequence, target_token)."""
    noise_within = difficulty_level in (2, 3)
    noise_between = difficulty_level in (1, 3)
    sequence = []
    for a, b in pairs:
        sequence.append(prefix_token)
        if noise_within:
            noise1 = _noise(pool, max_noise_len, rng)
            noise2 = _noise(pool, max_noise_len, rng)
            sequence.extend([*noise1, a, *noise2, b])
        else:
            sequence.extend([a, b])
        if noise_between:
            sequence.extend(_noise(pool, max_noise_len, rng))
    query_a, query_b = rng.choice(pairs)
    # The query mirrors the triplet layout: prefix, (noise at levels 2/3), then the key.
    sequence.append(prefix_token)
    if noise_within:
        sequence.extend(_noise(pool, max_noise_len, rng))
    sequence.append(query_a)
    return sequence, query_b


def _autonomous_sequence(pairs, pool, noise_type, max_noise_len, rng):
    """Build a level-4 (no prefix token) context followed by its query; return (sequence, target_token)."""
    sequence = []
    if noise_type == 'conflict':
        # Every pair (the first included) is followed by noise; the first key is
        # then re-bound to a fresh value, and that newest value is the target.
        key, old_value = pairs[0]
        new_value = rng.choice([t for t in pool if t not in (key, old_value)])
        for a, b in pairs:
            sequence.extend([a, b])
            sequence.extend(_noise(pool, max_noise_len, rng))
        sequence.extend([key, new_value])
        sequence.extend(_noise(pool, max_noise_len, rng))
        query_a, query_b = key, new_value
    else:
        for a, b in pairs:
            sequence.extend([a, b])
            if noise_type == 'between':
                sequence.extend(_noise(pool, max_noise_len, rng))
        query_a, query_b = rng.choice(pairs)
    sequence.append(query_a)
    return sequence, query_b


def torch_induction_heads_data(
    seq_len,
    vocab_size,
    num_induction_pairs,
    batch_shape=(),
    difficulty_level=0,
    max_noise_len=4,
    level_4_noise_type='none',
    rng=None,
):
    """
    Generates a dataset for the Induction Heads task with varying difficulty.

    Each sample is a context of key/value pairs (A_i, B_i) followed by a query
    that repeats one key A_q. The model must predict B_q at the query position.
    Pair tokens are drawn without replacement from ``range(vocab_size - 1)``.
    Token ``vocab_size - 1`` is reserved as the prefix/trigger token ``P``.
    In the structures below, every ``N`` is a run of 1 to ``max_noise_len``
    random tokens.

    --- Difficulty Levels ---

    Level 0: Baseline
        - Description: Clean, adjacent triplets [P,A,B].
        - Structure: [P,A1,B1,P,A2,B2,...]  query: [P,A_q]
        - Tests: Basic selective recall with a clear trigger.

    Level 1: Memory Robustness
        - Description: Noise is added BETWEEN (after every one of) the [P,A,B] triplets.
        - Structure: [P,A1,B1,N, P,A2,B2,N, ...]  query: [P,A_q]
        - Tests: Long-term memory retention against continuous distraction.

    Level 2: Abstract Pattern Recognition
        - Description: Noise is added WITHIN the [P,A,B] triplets.
        - Structure: [P,N,A1,N,B1, P,N,A2,N,B2, ...]  query: [P,N,A_q]
        - Tests: Ability to learn a non-continuous, abstract association rule.

    Level 3: Combined Stress Test
        - Description: Noise is added both WITHIN and BETWEEN triplets.
        - Structure: [P,N,A1,N,B1,N, P,N,A2,N,B2,N, ...]  query: [P,N,A_q]
        - Tests: Overall performance under combined memory and abstraction pressure.

    Level 4: Autonomous Learning Suite (No Prefix 'P'; query is just [A_q])
        Level 4.0 (Sanity Check):
            - Config: level_4_noise_type='none'
            - Structure: [A1,B1,A2,B2,...]
            - Tests: Basic pattern learning without any explicit trigger.

        Level 4.1 (Robust Discovery):
            - Config: level_4_noise_type='between'
            - Structure: [A1,B1,N, A2,B2,N, ...]
            - Tests: Unsupervised discovery of meaningful pairs in a noisy environment.

        Level 4.2 (Dynamic World Modeling):
            - Config: level_4_noise_type='conflict'
            - Structure: [A,B,N, A2,B2,N, ..., A,C,N]  query: [A] -> target C
              (C differs from A and B)
            - Tests: Unsupervised state updating and temporal reasoning
              (understanding that new information overrides old).

    NOTE(review): noise and padding are drawn from the full non-prefix
    vocabulary, so noise runs may contain pair tokens, including the query key.
    That can make the label ambiguous for level 4 'between'/'conflict'.
    Kept as in the original generator; confirm it is intended.

    Args:
        seq_len: Length of every returned sequence.
        vocab_size: Number of token ids; ``vocab_size - 1`` is the prefix token.
        num_induction_pairs: Requested pairs per sample. Capped at
            ``(vocab_size - 1) // 2``.
        batch_shape: Leading shape of the outputs (``()`` yields one unbatched sample).
        difficulty_level: 0, 1, 2, 3 or 4 (see above).
        max_noise_len: Maximum length of each noise run (must be >= 1 whenever
            the configuration inserts noise).
        level_4_noise_type: 'none', 'between' or 'conflict'. Only used at level 4.
        rng: Optional ``random.Random``-like object (``choice``, ``randint``,
            ``shuffle``) used for ALL sampling. Defaults to the module-level
            ``random`` generator, so ``random.seed(...)`` makes output reproducible.

    Returns:
        (x, y): LongTensors of shape ``(*batch_shape, seq_len)``. ``x`` holds
        token ids. ``y`` is -100 everywhere except at the query-key position,
        which holds B_q. Shorter sequences are right-padded with random tokens.
        Longer ones are truncated; if truncation removes the query, that sample
        carries no label (all -100).

    Raises:
        ValueError: on an invalid ``difficulty_level`` or ``level_4_noise_type``,
        ``max_noise_len < 1`` for a noisy configuration, ``num_induction_pairs < 1``,
        a vocabulary too small for one pair, or fewer than two pairs for 'conflict'.
    """
    if rng is None:
        rng = random
    if difficulty_level not in (0, 1, 2, 3, 4):
        raise ValueError(f"difficulty_level must be 0, 1, 2, 3 or 4, got {difficulty_level!r}.")
    if difficulty_level == 4 and level_4_noise_type not in ('none', 'between', 'conflict'):
        raise ValueError(
            "level_4_noise_type must be 'none', 'between' or 'conflict', "
            f"got {level_4_noise_type!r}."
        )
    uses_noise = difficulty_level in (1, 2, 3) or (
        difficulty_level == 4 and level_4_noise_type != 'none'
    )
    if uses_noise and max_noise_len < 1:
        raise ValueError(f"max_noise_len must be >= 1 for this configuration, got {max_noise_len}.")
    if num_induction_pairs < 1:
        raise ValueError(f"num_induction_pairs must be >= 1, got {num_induction_pairs}.")

    prefix_token = vocab_size - 1
    base_tokens = list(range(vocab_size - 1))
    current_num_pairs = min(len(base_tokens) // 2, num_induction_pairs)
    if current_num_pairs == 0:
        raise ValueError(f"Vocab size {vocab_size} too small.")
    if difficulty_level == 4 and level_4_noise_type == 'conflict' and current_num_pairs < 2:
        raise ValueError("Conflict task needs at least 2 pairs.")

    total_samples = int(np.prod(batch_shape))
    # Preallocated so an empty batch (e.g. batch_shape=(0,)) also works.
    inputs = torch.empty((total_samples, seq_len), dtype=torch.long)
    targets = torch.full((total_samples, seq_len), -100, dtype=torch.long)

    for sample in range(total_samples):
        pool = base_tokens.copy()
        rng.shuffle(pool)
        pairs = [(pool[2 * i], pool[2 * i + 1]) for i in range(current_num_pairs)]

        if difficulty_level == 4:
            sequence, query_b = _autonomous_sequence(
                pairs, pool, level_4_noise_type, max_noise_len, rng
            )
        else:
            sequence, query_b = _prefixed_sequence(
                pairs, pool, prefix_token, difficulty_level, max_noise_len, rng
            )

        # The query key is the last token; its position carries the label.
        prediction_index = len(sequence) - 1
        if prediction_index < seq_len:
            targets[sample, prediction_index] = query_b
            sequence.extend(rng.choice(pool) for _ in range(seq_len - len(sequence)))
        else:
            # Context + query exceeds seq_len: truncate; the query is lost, so no label.
            del sequence[seq_len:]
        inputs[sample] = torch.tensor(sequence, dtype=torch.long)

    return inputs.view(*batch_shape, seq_len), targets.view(*batch_shape, seq_len)


def generate_dataset(dataset_config, training_config, seq_len=None):
    """Build one batch of induction-heads data from configuration dictionaries.

    Args:
        dataset_config: Mapping read for ``train_seq_len`` (required only when
            ``seq_len`` is None) and the optional keys ``vocab_size`` (16),
            ``num_induction_pairs`` (3), ``difficulty_level`` (0),
            ``max_noise_len`` (4) and ``level_4_noise_type`` ('none').
        training_config: Mapping that must contain ``batch_size``.
        seq_len: Optional override of ``dataset_config["train_seq_len"]``.

    Returns:
        The ``(x, y)`` pair produced by ``torch_induction_heads_data`` with a
        one-dimensional batch of ``training_config["batch_size"]`` samples.

    Raises:
        KeyError: if a required configuration key is missing.
    """
    resolved_len = dataset_config["train_seq_len"] if seq_len is None else seq_len

    # Optional settings fall back to the same defaults the generator uses.
    generator_kwargs = dict(
        seq_len=resolved_len,
        vocab_size=dataset_config.get("vocab_size", 16),
        num_induction_pairs=dataset_config.get("num_induction_pairs", 3),
        batch_shape=(training_config["batch_size"],),
        difficulty_level=dataset_config.get("difficulty_level", 0),
        max_noise_len=dataset_config.get("max_noise_len", 4),
        level_4_noise_type=dataset_config.get("level_4_noise_type", 'none'),
    )

    features, labels = torch_induction_heads_data(**generator_kwargs)
    return features, labels