"""
Memory-efficient PRNG data generation utilities.

This module provides on-demand sequence generation that initially stores only the PRNG
parameters (a, c and the shared generator settings); the initial state x_0 is derived
deterministically when a sequence is generated during training. This significantly reduces
up-front memory usage for large datasets while maintaining identical training behavior.
"""

import time
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import argparse
from typing import List, Dict, Tuple, Optional
from utils.prng_data import (
    find_as, find_coprimes,
    base_b_lcg, base_tlcg, base_b_pcg_rs, base_b_pcg_rr, base_b_pcg_xsh_rr,
    base_b_pcg_xsh_rs, base_b_pcg_xsl_rr, convert_to_base_b
)


# Validation functions
def validate_bits_to_keep(m: int, bits_to_keep: int) -> bool:
    """
    Check that bits_to_keep lies in the range 1..ceil(log2(m)).

    Args:
        m: Modulus of the generator.
        bits_to_keep: Number of bits retained by a truncated generator, or None.

    Returns:
        False when bits_to_keep is None or outside the range, True otherwise.
    """
    if bits_to_keep is None:
        return False
    # Width of the modulus in bits, rounded up for non-powers of two.
    max_bits = int(np.ceil(np.log2(m)))
    return 0 < bits_to_keep <= max_bits


def validate_pcg_rs_constraints(m: int, control_bits: int, bits_to_keep: int) -> bool:
    """
    Check the PCG_RS constraint: ceil(log2(m)) > (2**control_bits - 1) + bits_to_keep.

    Returns False when control_bits or bits_to_keep is None.
    """
    if control_bits is not None and bits_to_keep is not None:
        modulus_bits = int(np.ceil(np.log2(m)))
        required_bits = (2 ** control_bits - 1) + bits_to_keep
        return modulus_bits > required_bits
    return False


def validate_pcg_rr_constraints(m: int, control_bits: int, bits_to_keep: int) -> bool:
    """
    Check the PCG_RR constraint: bits_to_keep >= 2**control_bits.

    The modulus ``m`` is accepted for signature parity with the other
    validators but does not take part in the check. Returns False when
    control_bits or bits_to_keep is None.
    """
    if control_bits is not None and bits_to_keep is not None:
        minimum_bits = 2 ** control_bits
        return bits_to_keep >= minimum_bits
    return False


def validate_pcg_xsh_rr_constraints(m: int, control_bits: int, bits_to_keep: int) -> bool:
    """
    Check the PCG_XSH_RR constraint: ceil(log2(m)) > control_bits + bits_to_keep.

    Returns False when control_bits or bits_to_keep is None.
    """
    if control_bits is not None and bits_to_keep is not None:
        modulus_bits = int(np.ceil(np.log2(m)))
        required_bits = control_bits + bits_to_keep
        return modulus_bits > required_bits
    return False


def validate_pcg_xsh_rs_constraints(control_bits: int, bits_to_keep: int) -> bool:
    """
    Check the PCG_XSH_RS constraint that the constant shift is positive.

    The constant shift is bits_to_keep - control_bits - 2**control_bits + 1.
    Returns False when control_bits or bits_to_keep is None.
    """
    if control_bits is not None and bits_to_keep is not None:
        shift = bits_to_keep - control_bits - 2 ** control_bits + 1
        return shift > 0
    return False


def validate_pcg_rxs_m_xs_8_8_constraints(m: int) -> bool:
    """Return True only when the modulus is exactly 2**8 (256), as PCG_RXS_M_XS_8_8 requires."""
    required_modulus = 256
    return m == required_modulus


def validate_pcg_rxs_m_xs_16_16_constraints(m: int) -> bool:
    """Return True only when the modulus is exactly 2**16 (65536), as PCG_RXS_M_XS_16_16 requires."""
    required_modulus = 65536
    return m == required_modulus


def validate_prng_parameters(prng_type: str, **kwargs) -> tuple[bool, str]:
    """
    Validate parameters for a specific PRNG type.

    Only the keyword arguments ``m``, ``control_bits`` and ``bits_to_keep`` are
    inspected. Each check is skipped when a value it needs is missing or None,
    so partially specified configurations are accepted. A ``prng_type`` with no
    type-specific rule only receives the generic ``bits_to_keep`` range check.

    Args:
        prng_type: Generator identifier, e.g. 'lcg', 'truncated_lcg', 'pcg_rs',
            'pcg_rr', 'pcg_xsh_rr', 'pcg_xsh_rs', 'pcg_rxs_m_xs_8_8',
            'pcg_rxs_m_xs_16_16'.
        **kwargs: May contain 'm', 'control_bits' and 'bits_to_keep'.

    Returns:
        (is_valid, error_message); error_message is "" when is_valid is True.
    """
    m = kwargs.get('m')
    control_bits = kwargs.get('control_bits')
    bits_to_keep = kwargs.get('bits_to_keep')

    # Special validation for LCG: bits_to_keep should equal the bit length of m
    if prng_type == 'lcg':
        if m is not None and bits_to_keep is not None:
            expected_bits_to_keep = int(np.ceil(np.log2(m)))
            if bits_to_keep != expected_bits_to_keep:
                return False, f"LCG constraint: bits_to_keep ({bits_to_keep}) should equal bit length of m ({expected_bits_to_keep})"
        if control_bits is not None and control_bits != 0:
            return False, f"LCG constraint: control_bits ({control_bits}) should be 0"
    # TLCG does not use control_bits
    elif prng_type == 'truncated_lcg':
        if control_bits is not None and control_bits != 0:
            return False, f"TLCG constraint: control_bits ({control_bits}) should be 0"

    # Check the bits_to_keep range for every type that truncates its output.
    # LCG is excluded because the equality check above is already stricter, and
    # the fixed-output types do not use bits_to_keep. This is a separate `if`
    # (not an `elif` of the TLCG branch) so that TLCG is range-checked too.
    if prng_type not in ('lcg', 'pcg_rxs_m_xs_8_8', 'pcg_rxs_m_xs_16_16'):
        if m is not None and bits_to_keep is not None and not validate_bits_to_keep(m, bits_to_keep):
            bit_length = int(np.ceil(np.log2(m)))
            return False, f"Invalid bits_to_keep: must be 0 < {bits_to_keep} <= {bit_length}"

    # Then check type-specific constraints for PCG variants
    if prng_type == 'pcg_rs':
        if m is not None and control_bits is not None and bits_to_keep is not None:
            if not validate_pcg_rs_constraints(m, control_bits, bits_to_keep):
                bit_length = int(np.ceil(np.log2(m)))
                required = (2 ** control_bits - 1) + bits_to_keep
                return False, f"PCG_RS constraint violated: bit_length ({bit_length}) must be > {required}"

    elif prng_type == 'pcg_rr':
        if control_bits is not None and bits_to_keep is not None:
            if not validate_pcg_rr_constraints(m, control_bits, bits_to_keep):
                required = 2 ** control_bits
                return False, f"PCG_RR constraint violated: bits_to_keep ({bits_to_keep}) must be >= {required}"

    elif prng_type == 'pcg_xsh_rr':
        if m is not None and control_bits is not None and bits_to_keep is not None:
            if not validate_pcg_xsh_rr_constraints(m, control_bits, bits_to_keep):
                bit_length = int(np.ceil(np.log2(m)))
                required = control_bits + bits_to_keep
                return False, f"PCG_XSH_RR constraint violated: bit_length ({bit_length}) must be > {required}"

    elif prng_type == 'pcg_xsh_rs':
        if control_bits is not None and bits_to_keep is not None:
            if not validate_pcg_xsh_rs_constraints(control_bits, bits_to_keep):
                constant_shift = bits_to_keep - control_bits - 2 ** control_bits + 1
                return False, f"PCG_XSH_RS constraint violated: constant_shift ({constant_shift}) must be > 0"

    elif prng_type == 'pcg_rxs_m_xs_8_8':
        if m is not None and not validate_pcg_rxs_m_xs_8_8_constraints(m):
            return False, f"PCG_RXS_M_XS_8_8 constraint violated: m must be 256, got {m}"

    elif prng_type == 'pcg_rxs_m_xs_16_16':
        if m is not None and not validate_pcg_rxs_m_xs_16_16_constraints(m):
            return False, f"PCG_RXS_M_XS_16_16 constraint violated: m must be 65536, got {m}"

    return True, ""


def _expand_param_sets(base_params: Dict, train_a, train_c, val_a, val_c,
                       n_example: int) -> Tuple[List[Dict], List[Dict]]:
    """Combine one base parameter dict with every train and test (a, c) pair."""
    # Each train (a, c) pair is repeated n_example times; test pairs appear once.
    train_sets = [
        {**base_params, 'a': a, 'c': c}
        for a in train_a
        for c in train_c
        for _ in range(n_example)
    ]
    test_sets = [{**base_params, 'a': a, 'c': c} for a in val_a for c in val_c]
    return train_sets, test_sets


def generate_param_sets(config, rng, master_process=True) -> Tuple[List[Dict], List[Dict], List[int], List[int], List[int], List[int]]:
    """
    Generate only the parameter sets without creating full sequences.
    This is the memory-efficient alternative to generate_data().

    Each parameter set is a dict holding 'prng_type', 'm', 'seq_len', 'base',
    'digits', 'a' and 'c', plus 'bits_to_keep' and 'control_bits' for the types
    that use them. No initial state x0 is stored in the parameter sets.

    Args:
        config: Configuration object. Reads m, seq_len, base, digits,
            bits_to_keep, control_bits (an int, or a comma-separated string
            for several values), n_a, n_test_a, n_c, n_test_c, n_example and
            either type_list (several types) or type (a single type).
        rng: Random number generator passed to find_as / find_coprimes
        master_process: Whether this is the master process (for logging)

    Returns:
        Tuple of (train_param_sets, test_param_sets, train_a, train_c, val_a, val_c)
        where param_sets are lists of dicts containing parameters, and a/c lists
        are for API compatibility

    Raises:
        AssertionError: If fewer a or c values are found than requested.
    """
    t0 = time.time()

    # Parse control_bits
    if isinstance(config.control_bits, str) and ',' in config.control_bits:
        control_bits_list = [int(x.strip()) for x in config.control_bits.split(',')]
    else:
        control_bits_list = [int(config.control_bits)]

    # Handle multiple types
    has_type_list = hasattr(config, 'type_list')
    types_to_process = config.type_list if has_type_list else [config.type]
    if master_process:
        print("="*80)
        if has_type_list:
            print(f"GENERATING PARAM SETS: Multiple types {'+'.join(types_to_process)} with m={config.m}")
        else:
            print(f"GENERATING PARAM SETS: {config.type} with m={config.m}")
        print(f"Control bits: {control_bits_list}")
        print("="*80)

    # Generate a and c values
    a_list = find_as(config.m, rng=rng, num=config.n_a+config.n_test_a)
    c_list = find_coprimes(config.m, rng=rng, num=config.n_c+config.n_test_c)
    assert len(a_list) >= config.n_a+config.n_test_a, "not enough a values"
    assert len(c_list) >= config.n_c+config.n_test_c, "not enough c values"

    train_a, val_a = a_list[:config.n_a], a_list[config.n_a:]
    train_c, val_c = c_list[:config.n_c], c_list[config.n_c:]

    # Generate parameter sets for all types and control_bits combinations
    train_param_sets = []
    test_param_sets = []

    for current_type in types_to_process:
        if master_process:
            print(f"Processing type: {current_type}")

        # Common parameters
        common_params = {
            'prng_type': current_type.lower(),
            'm': config.m,
            'seq_len': config.seq_len,
            'base': config.base,
            'digits': config.digits
        }

        # Add bits_to_keep for types that use it
        if current_type not in ['RXSMXS1616', 'LCG']:
            common_params['bits_to_keep'] = config.bits_to_keep

        # Collect the base parameter dicts for this type; each one is then
        # expanded over every (a, c) combination below.
        if current_type in ['RS', 'RR', 'XSHRR', 'XSHRS', 'XSLRR']:
            # PCG variants: one base dict per control_bits value
            base_variants = [
                {**common_params, 'control_bits': control_bits_val}
                for control_bits_val in control_bits_list
            ]

        elif current_type == 'XSPCGs':
            # XSPCGs expands into the individual xorshift PCG variants
            base_variants = []
            for control_bits_val in control_bits_list:
                pcg_types = ['xslrr', 'xshrr']

                # Include PCG_XSH_RS only when its constraint is satisfied.
                # No arithmetic on bits_to_keep here: it may be None.
                if validate_pcg_xsh_rs_constraints(control_bits_val, config.bits_to_keep):
                    pcg_types.append('xshrs')
                elif master_process:
                    print(f"Warning: Skipping PCG_XSH_RS for control_bits={control_bits_val} - constraint not satisfied")

                base_variants.extend(
                    {**common_params, 'control_bits': control_bits_val, 'prng_type': pcg_type}
                    for pcg_type in pcg_types
                )

        else:
            # Simple case: LCG, TLCG, and fixed-output PCG variants.
            # For LCG, common_params already lacks bits_to_keep and control_bits.
            base_variants = [common_params]

        for base_params in base_variants:
            train_sets, test_sets = _expand_param_sets(
                base_params, train_a, train_c, val_a, val_c, config.n_example
            )
            train_param_sets.extend(train_sets)
            test_param_sets.extend(test_sets)

    t1 = time.time()
    if master_process:
        print("-"*80)
        print("PARAM SET GENERATION COMPLETE:")
        print(f"  - Train param sets: {len(train_param_sets)}")
        print(f"  - Test param sets: {len(test_param_sets)}")
        print(f"  - Time taken: {t1-t0:.2f} seconds")
        print(f"  - Memory saved: ~{(len(train_param_sets) + len(test_param_sets)) * config.seq_len * 8 / 1024**2:.1f} MB")
        print("-"*80)

    return train_param_sets, test_param_sets, train_a, train_c, val_a, val_c


class ParameterBasedPRNGDataset(Dataset):
    """
    Dataset that generates PRNG sequences on-demand from stored parameters.

    With ``pre_generate_first_epoch=True`` every sequence is cached the first
    time its index is requested and served from the cache afterwards; once all
    indices have been seen the cache is stacked into a single numpy array.
    With ``pre_generate_first_epoch=False`` sequences are regenerated on every
    access. Generation is seeded from the index and (a, c, m), so a given index
    yields the same sequence whether or not it is cached.
    """

    def __init__(self, param_sets: List[Dict], pre_generate_first_epoch: bool = True):
        """
        Args:
            param_sets: List of parameter dictionaries, each containing
                       the parameters needed to generate a sequence
            pre_generate_first_epoch: Whether to cache sequences as they are
                       first generated
        """
        self.param_sets = param_sets
        self.pre_generate_first_epoch = pre_generate_first_epoch
        # None until the first cached access, then a list with one slot per
        # index (None = not generated yet), finally a stacked numpy array.
        self._sequences = None
        # Number of distinct indices cached so far.
        self._access_count = 0

        # Type mapping for PRNG functions
        self.type_mapping = {
            'lcg': base_b_lcg,
            'truncated_lcg': base_tlcg,
            'rs': base_b_pcg_rs,
            'rr': base_b_pcg_rr,
            'xshrr': base_b_pcg_xsh_rr,
            'xshrs': base_b_pcg_xsh_rs,
            'xslrr': base_b_pcg_xsl_rr,
        }

    def __len__(self) -> int:
        return len(self.param_sets)

    @staticmethod
    def _to_input_target(sequence) -> Tuple[torch.Tensor, torch.Tensor]:
        """Split a sequence into next-token (input, target) long tensors."""
        x = torch.tensor(sequence[:-1], dtype=torch.long)
        y = torch.tensor(sequence[1:], dtype=torch.long)
        return x, y

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Return the (input, target) pair for ``idx``.

        Raises:
            IndexError: If idx is negative or not smaller than len(self).
        """
        total = len(self.param_sets)
        if idx >= total or idx < 0:
            raise IndexError(f"Index {idx} out of range for dataset of size {len(self.param_sets)}")

        # Serve from the cache when this index has already been generated
        # (works for both the partially filled list and the final array).
        if self._sequences is not None:
            cached = self._sequences[idx]
            if cached is not None:
                return self._to_input_target(cached)

        # Otherwise, generate on-demand
        sequence = self._generate_sequence_from_params(self.param_sets[idx], idx)

        if self.pre_generate_first_epoch:
            if self._sequences is None:
                self._sequences = [None] * total
            self._sequences[idx] = sequence.copy()
            # Only reached for an index not cached before, so this counts
            # distinct indices; repeated accesses cannot finalize the cache
            # while some slots are still empty.
            self._access_count += 1

            # Once every index is cached, convert to numpy array for efficiency
            if self._access_count == total:
                self._sequences = np.array(self._sequences)
                print(f"✅ Pre-generated all {len(self.param_sets)} sequences")

        return self._to_input_target(sequence)

    def _generate_sequence_from_params(self, param_set: Dict, idx: int) -> np.ndarray:
        """
        Generate the sequence described by ``param_set``.

        Args:
            param_set: Parameter dictionary; must contain 'prng_type', 'a',
                'c', 'm', 'seq_len', 'base' and 'digits'. 'control_bits' and
                'bits_to_keep' are optional for non-LCG types (defaults 0 and 8).
            idx: Dataset index, mixed into the seed.

        Raises:
            ValueError: If 'prng_type' is not in type_mapping.
            RuntimeError: If the generator call fails; the original exception
                is chained as the cause.
        """
        prng_type = param_set['prng_type']
        prng_func = self.type_mapping.get(prng_type)

        if prng_func is None:
            raise ValueError(f"Unknown PRNG type: {prng_type}")

        # Create deterministic seed from index and parameters
        # This ensures same sequence for same index across calls
        seed_base = hash((idx, param_set['a'], param_set['c'], param_set['m'])) % (2**31)
        rng = np.random.default_rng(seed_base)

        # The drawn value is not passed on (the generator functions are given
        # rng, not x0), but the draw advances rng. It is kept so generated
        # sequences stay identical to those produced before this cleanup.
        rng.integers(1, param_set['m'])

        try:
            call_kwargs = {
                'm': param_set['m'],
                'seq_len': param_set['seq_len'],
                'a': param_set['a'],
                'c': param_set['c'],
                'base': param_set['base'],
                'digits': param_set['digits'],
            }
            # Only the plain LCG is called without control_bits / bits_to_keep
            if prng_type != 'lcg':
                call_kwargs['control_bits'] = param_set.get('control_bits', 0)
                call_kwargs['bits_to_keep'] = param_set.get('bits_to_keep', 8)

            sequence = prng_func(**call_kwargs, num_examples=1, rng=rng)

            # With num_examples=1 the result may have shape (1, seq_len); drop
            # the leading axis in that case.
            if sequence.ndim > 1 and sequence.shape[0] == 1:
                sequence = sequence[0]

            return sequence

        except Exception as e:
            raise RuntimeError(f"Error generating sequence for index {idx} with params {param_set}: {e}") from e


def generate_lowmem_data(config, rng, master_process=True, cache_size: int = 1000) -> Tuple[Dataset, Dataset, List[int], List[int], List[int], List[int]]:
    """
    Build train/test datasets that generate their sequences on demand.

    Memory-efficient replacement for generate_data(): parameter sets are
    produced by generate_param_sets() and wrapped in ParameterBasedPRNGDataset
    instances. Only the train dataset caches the sequences it generates.

    Args:
        config: Configuration object, forwarded to generate_param_sets()
        rng: Random number generator, forwarded to generate_param_sets()
        master_process: Whether this is the master process (controls logging)
        cache_size: Accepted for API compatibility; not used

    Returns:
        Tuple of (train_dataset, test_dataset, train_a, train_c, val_a, val_c)
    """
    started = time.time()

    # A single call yields both the parameter sets and the a/c value lists.
    param_bundle = generate_param_sets(config, rng, master_process)
    train_params, test_params, train_a, train_c, val_a, val_c = param_bundle

    train_dataset = ParameterBasedPRNGDataset(train_params, pre_generate_first_epoch=True)
    # The test split is built without caching.
    test_dataset = ParameterBasedPRNGDataset(test_params, pre_generate_first_epoch=False)

    elapsed = time.time() - started
    if master_process:
        rule = "-" * 80
        report_lines = (
            rule,
            "LOWMEM DATA GENERATION COMPLETE:",
            f"  - Train dataset size: {len(train_dataset)} sequences",
            f"  - Test dataset size: {len(test_dataset)} sequences",
            f"  - Time taken: {elapsed:.2f} seconds",
            "  - Memory usage: Parameters only",
            "  - Smart caching: Pre-generates first epoch",
            rule,
        )
        for line in report_lines:
            print(line)

    return train_dataset, test_dataset, train_a, train_c, val_a, val_c


def create_multi_modulus_efficient_datasets(moduli_configs: List, base_rng_seed: int = 97, 
                                           cache_size: int = 1000) -> Tuple[List[Dataset], List[Dataset]]:
    """
    Create memory-efficient datasets for multiple moduli configurations.

    Each configuration gets its own generator seeded with
    ``base_rng_seed + i * 1000`` (i = position in moduli_configs), and its
    datasets are built with generate_lowmem_data() with logging disabled.

    Args:
        moduli_configs: List of configuration objects, one per modulus
        base_rng_seed: Base seed for RNG
        cache_size: Cache size per dataset (ignored)

    Returns:
        Tuple of (train_datasets, test_datasets) lists; both are empty when
        moduli_configs is empty
    """
    train_datasets = []
    test_datasets = []

    for i, config in enumerate(moduli_configs):
        rng = np.random.default_rng(base_rng_seed + i * 1000)
        # generate_lowmem_data is the dataset builder defined in this module;
        # the a/c value lists it also returns are not needed here.
        train_dataset, test_dataset, *_ = generate_lowmem_data(
            config, rng, master_process=False, cache_size=cache_size
        )
        train_datasets.append(train_dataset)
        test_datasets.append(test_dataset)

    return train_datasets, test_datasets


def create_curriculum_lowmem_datasets(config, master_process: bool = True, 
                                        cache_size: int = 1000, ddp: bool = False, 
                                        rank: Optional[int] = None, world_size: Optional[int] = None,
                                        num_workers: int = 4) -> Tuple[List[Dataset], List[DataLoader], List[List[int]], List[List[int]]]:
    """
    Create memory-efficient datasets for curriculum learning.

    For every modulus in ``config.moduli`` a copy of the configuration is made
    with ``m`` set to that modulus and ``bits_to_keep`` taken from
    ``config.moduli_bits_to_keep[i]`` when available, otherwise
    ``ceil(log2(m))``. Data for modulus i is generated with a generator seeded
    by ``config.data_seed + i * 1000``.

    Args:
        config: Configuration object with moduli list and other parameters
        master_process: Whether this is the master process
        cache_size: Cache size per modulus dataset (ignored)
        ddp: Whether using distributed data parallel
        rank: Process rank for distributed training
        world_size: Total number of processes for distributed training
        num_workers: Number of DataLoader workers (0 loads in the main process)

    Returns:
        Tuple of (train_datasets, test_loaders, train_a_values, train_c_values) where:
        - train_datasets are efficient datasets
        - test_loaders are DataLoader instances
        - train_a_values is a list of training a values for each modulus
        - train_c_values is a list of training c values for each modulus
    """
    train_datasets = []
    test_loaders = []
    train_a_values = []
    train_c_values = []

    # Options shared by every test loader.
    loader_kwargs = {
        'batch_size': getattr(config, 'batch_size', None),
        'pin_memory': True,
        'drop_last': True,
        'num_workers': num_workers,
    }
    # DataLoader rejects an explicit prefetch_factor when num_workers == 0,
    # so only request prefetching when worker processes are actually used.
    if num_workers > 0:
        loader_kwargs['prefetch_factor'] = 2

    per_modulus_bits = getattr(config, 'moduli_bits_to_keep', None)

    for i, m in enumerate(config.moduli):
        if master_process:
            print(f"Creating dataset for modulus {m} ({i+1}/{len(config.moduli)})")

        # Create config for this modulus (same logic as original)
        config_m = argparse.Namespace(**vars(config))
        config_m.m = m

        # Set bits_to_keep: explicit per-modulus value if given, else the
        # bit length of the modulus.
        if per_modulus_bits is not None and i < len(per_modulus_bits):
            config_m.bits_to_keep = per_modulus_bits[i]
        else:
            config_m.bits_to_keep = int(np.ceil(np.log2(m)))

        # Generate dataset with different seed for each modulus
        rng = np.random.default_rng(config.data_seed + i * 1000)
        train_dataset, test_dataset, train_a, train_c, _, _ = generate_lowmem_data(
            config_m, rng, master_process=False, cache_size=cache_size
        )

        train_datasets.append(train_dataset)
        train_a_values.append(train_a)
        train_c_values.append(train_c)

        # Create test loader
        if ddp:
            test_sampler = torch.utils.data.distributed.DistributedSampler(
                test_dataset, num_replicas=world_size, rank=rank, shuffle=False, drop_last=True
            )
            test_loader = DataLoader(
                test_dataset,
                batch_size=config.batch_size,
                sampler=test_sampler,
                **{k: v for k, v in loader_kwargs.items() if k != 'batch_size'}
            )
        else:
            test_loader = DataLoader(
                test_dataset,
                batch_size=config.batch_size,
                shuffle=False,
                **{k: v for k, v in loader_kwargs.items() if k != 'batch_size'}
            )

        test_loaders.append(test_loader)

    if master_process:
        total_train = sum(len(ds) for ds in train_datasets)
        total_test = sum(len(loader.dataset) for loader in test_loaders)
        print("Curriculum datasets created:")
        print(f"  - {len(config.moduli)} moduli")
        print(f"  - Total train sequences: {total_train}")
        print(f"  - Total test sequences: {total_test}")
        print(f"  - Memory saved: ~{total_train * config.seq_len * 8 / 1024**2:.1f} MB")
        print("  - Smart caching: Pre-generates first epoch")

    return train_datasets, test_loaders, train_a_values, train_c_values
