
import numpy as np 
from tqdm import tqdm 
from collections import Counter
from numpy.lib.stride_tricks import sliding_window_view

class GeneratingData:
    """data class for target/context generation."""
    
    def __init__(self,chain_length:int):

        self.chain_length = chain_length
        

    def generate_chain_1(self, J=1):
        """
        Génère une chaîne de spins Ising 1D avec interactions de voisinage immédiat.
        
        Le modèle est assimilé à une chaîne de Markov :
        P(σ_{i+1}=+1 | σ_i) ∝ exp(J * σ_i * (+1))
        P(σ_{i+1}=-1 | σ_i) ∝ exp(J * σ_i * (-1))
        
        Renvoie un tableau de spins (+1 ou -1) de longueur 'length'.
        """
        chain = np.empty(self.chain_length, dtype=int)
        chain[0] = np.random.choice([1, -1])
        for i in range(1, self.chain_length):
            s_prev = chain[i-1]
            # Calcul de la probabilité conditionnelle pour σ = +1
            # On a p(σ=+1) = exp(J*s_prev) / [exp(J*s_prev) + exp(-J*s_prev)]
            p_plus = np.exp(J * s_prev) / (np.exp(J * s_prev) + np.exp(-J * s_prev))
            chain[i] = 1 if np.random.rand() < p_plus else -1
        chain1_bin = (chain == 1).astype(int)
        return chain1_bin
    
    def generate_chain1_with_context(self, J=1, lambda_param=0.5):
        """
        Génère simultanément :
        - La chaîne principale (spins) σ_i ∈ {-1, +1},
        - La chaîne de contexte c_i ∈ {0, 1, 2, 3}, tirée i.i.d. (par exemple, uniformément).
        
        Le modèle de transition pour la chaîne principale est défini par :
        
        P(σ_{i+1}=+1 | σ_i, c_{i+1}=c) = exp(J*σ_i + λ*(c-1.5))
            / ( exp(J*σ_i + λ*(c-1.5)) + exp(-J*σ_i - λ*(c-1.5)) ).
        
        Renvoie :
            chain_bin : tableau binaire (0 et 1) correspondant aux spins (+1 → 1, -1 → 0)
            context : tableau des contextes (valeurs 0, 1, 2, 3)
        """
        # Génération du contexte : tiré indépendamment parmi {0,1,2,3}
        context = np.random.choice([0,1], size=self.chain_length)
        
        # Initialisation de la chaîne principale
        chain = np.empty(self.chain_length, dtype=int)
        chain[0] = np.random.choice([1, -1])
        
        for i in range(1, self.chain_length):
            # On utilise le contexte à l'instant i pour conditionner la transition
            c = context[i]
            # Mapping : f(c) = c - 1.5, ce qui donne : -1.5, -0.5, 0.5, 1.5
            f_c = c - 1.5
            # Calcul de la probabilité de passer à +1
            num = np.exp(J * chain[i-1] + lambda_param * f_c)
            den = num + np.exp(-J * chain[i-1] - lambda_param * f_c)
            p_plus = num / den
            chain[i] = 1 if np.random.rand() < p_plus else -1
        
        # Conversion en binaire pour la chaîne principale (+1 -> 1, -1 -> 0)
        chain_bin = (chain == 1).astype(int)
        return chain_bin, context



    def generate_chain_2(self, block_size=400000):
        """
        Chaîne 2 : Chaîne Ising 1D avec interactions de voisinage unique
        dont le couplage J est réinitialisé toutes les 'block_size' positions.
        Pour chaque bloc, J est tiré d'une loi gaussienne de moyenne 0 et variance 1.  reset every 400,000 spins
        """
        chain = np.empty(self.chain_length, dtype=int)
        chain[0] = np.random.choice([1, -1])
        i = 0
        while i < self.chain_length - 1:
            # Tirage d'un J pour le bloc courant
            J = np.random.normal(loc=0.0, scale=1.0) #a random J for all the block size. We need to see the block to estimate this J. 
            block_end = min(i + block_size, self.chain_length)
            for j in range(i+1, block_end):
                s_prev = chain[j-1]
                p_plus = np.exp(J * s_prev) / (np.exp(J * s_prev) + np.exp(-J * s_prev))
                chain[j] = 1 if np.random.rand() < p_plus else -1
            i = block_end - 1  # Continuer à partir du dernier spin du bloc
        chain1_bin = (chain == 1).astype(int)
        return chain1_bin
    
    
    def generate_chain_2_with_context(self, block_size=400000, lambda_param=0.5):
        """
        Adaptation de generate_chain_2 pour incorporer un contexte.
        
        La chaîne principale (spins) est générée par blocs, avec un couplage J réinitialisé aléatoirement
        pour chaque bloc (J tiré d'une loi gaussienne de moyenne 0 et écart type 1).
        De plus, pour chaque position, un contexte c est généré i.i.d. parmi {0, 1, 2, 3}.
        
        La transition est alors conditionnée par le contexte selon :
        
            P(σ_{j}=+1 | σ_{j-1}, c_j) = exp(J*σ_{j-1} + λ*(c_j-1.5))
                                        / [ exp(J*σ_{j-1} + λ*(c_j-1.5)) + exp(-J*σ_{j-1} - λ*(c_j-1.5)) ]
        
        Renvoie :
        chain_bin : tableau binaire (0 pour -1, 1 pour +1) pour la chaîne principale.
        context   : tableau des contextes (valeurs 0,1,2,3).
        """
        chain = np.empty(self.chain_length, dtype=int)
        chain[0] = np.random.choice([1, -1])
        
        # Génération du contexte : tiré indépendamment parmi {0,1,2,3}
        context = np.random.choice([0, 1], size=self.chain_length)
        
        # Pour chaque bloc, on réinitialise J et on génère les spins conditionnés par le contexte
        for i in range(0, self.chain_length - 1, block_size):
            # Tirage d'un J pour le bloc courant
            J = np.random.normal(loc=0.0, scale=1.0)
            block_end = min(i + block_size, self.chain_length)
            for j in range(i + 1, block_end):
                s_prev = chain[j - 1]
                c = context[j]
                # Mapping : f(c)=c-1.5 (donne -1.5, -0.5, 0.5, 1.5)
                f_c = c - 1.5
                num = np.exp(J * s_prev + lambda_param * f_c)
                den = num + np.exp(-J * s_prev - lambda_param * f_c)
                p_plus = num / den
                chain[j] = 1 if np.random.rand() < p_plus else -1
        
        # Conversion de la chaîne principale en binaire : +1 → 1, -1 → 0
        chain_bin = (chain == 1).astype(int)
        return chain_bin, context


    def generate_chain_3(self,chain_length, block_size=200, equil_sweeps=100):
        """
        Chaîne 3 : Chaîne Ising 1D avec interactions longues portées.
        
        Pour chaque bloc de 'block_size' spins, on crée la matrice symétrique J
        de dimension LxL (L = block_size ou le reste), où pour i < j, 
            J[i, j] ~ N(0, (1/(j-i))^2)
        On utilise une version vectorisée pour construire J et mettre à jour les spins
        via l'algorithme de Metropolis.
        """
        chain = []
        num_blocks = chain_length // block_size
        
        for _ in range(num_blocks):
            L = block_size
            # Construction vectorisée de la matrice J
            i_idx, j_idx = np.triu_indices(L, k=1)
            diffs = j_idx - i_idx
            J_vals = np.random.normal(loc=0.0, scale=1.0/diffs)
            J = np.zeros((L, L))
            J[i_idx, j_idx] = J_vals
            J[j_idx, i_idx] = J_vals
            
            # Initialisation aléatoire du bloc de spins
            block = np.random.choice([1, -1], size=L)
            
            # Équilibration vectorisée par Metropolis
            for _ in range(equil_sweeps):
                local_field = np.dot(J, block)
                dE = 2 * block * local_field
                # Critère d'acceptation vectorisé
                accept = (dE <= 0) | (np.random.rand(L) < np.exp(-dE))
                block[accept] = -block[accept]
            
            chain.extend(block.tolist())
        
        # Gestion des spins restants (si length n'est pas multiple de block_size)
        leftover = chain_length - num_blocks * block_size
        if leftover > 0:
            L = leftover
            i_idx, j_idx = np.triu_indices(L, k=1)
            diffs = j_idx - i_idx
            J_vals = np.random.normal(loc=0.0, scale=1.0/diffs)
            J = np.zeros((L, L))
            J[i_idx, j_idx] = J_vals
            J[j_idx, i_idx] = J_vals
            
            block = np.random.choice([1, -1], size=L)
            for _ in range(equil_sweeps):
                local_field = np.dot(J, block)
                dE = 2 * block * local_field
                accept = (dE <= 0) | (np.random.rand(L) < np.exp(-dE))
                block[accept] = -block[accept]
            chain.extend(block.tolist())
        #chain1_bin = (chain == 1).astype(int)
        return chain
        

#--- very basics ways to estimate entropy for a word here - and the conditionnal entropy. 

def compute_word_entropy(chain, word_length):
    """
    Extrait tous les mots (sous-séquences continues) de longueur 'word_length'
    à partir de la chaîne (tableau d'entiers, ici 0 et 1) et calcule 
    l'entropie de leur distribution empirique en bits.
    
    La formule utilisée est : 
      S(N) = - Σ p(W) log2(p(W))
    """
    L = len(chain)
    if L < word_length:
        return 0.0
    freq = {}
    for i in tqdm(range(L - word_length + 1)):
        word = tuple(chain[i:i+word_length])
        freq[word] = freq.get(word, 0) + 1
    total = sum(freq.values())
    entropy = 0.0
    for count in freq.values():
        p = count / total
        entropy -= p * np.log2(p)
    return entropy


def compute_conditional_word_entropy_fast(chain, context, word_length):
    """
    Calcule l'entropie conditionnelle H(W|C) en bits, où
      - W est un mot (de longueur word_length) extrait de la séquence 'chain'
      - C est le mot de contexte correspondant extrait de la séquence 'context'
      
    La formule utilisée est :
      H(W|C) = - sum_{c in C} P(c) sum_{w in W} P(w|c) log2(P(w|c))
      
    Cette version accélérée utilise collections.Counter pour compter efficacement les fréquences.
    
    Args:
        chain (array-like): séquence principale (par exemple, tableau de 0 et 1)
        context (array-like): séquence de contexte (doit avoir la même longueur que chain)
        word_length (int): longueur des mots à considérer
        
    Returns:
        cond_entropy (float): l'entropie conditionnelle en bits.
    """
    L = len(chain)
    if L < word_length:
        return 0.0

    # Extraire toutes les fenêtres pour le mot et pour le contexte
    context_windows = [tuple(context[i:i+word_length]) for i in range(L - word_length + 1)]
    word_windows = [tuple(chain[i:i+word_length]) for i in range(L - word_length + 1)]
    
    # Utiliser Counter pour compter les occurrences
    context_counter = Counter(context_windows)
    joint_counter = Counter(zip(context_windows, word_windows))
    
    total_contexts = sum(context_counter.values())
    cond_entropy = 0.0

    # Parcourir chaque contexte unique et calculer l'entropie conditionnelle pour ce contexte
    for ctxt, cnt_ctxt in context_counter.items():
        p_ctxt = cnt_ctxt / total_contexts
        entropy_ctxt = 0.0
        # Pour tous les mots qui apparaissent avec ce contexte
        for (ct_key, w), joint_count in joint_counter.items():
            if ct_key == ctxt:
                p_w_given_ctxt = joint_count / cnt_ctxt
                entropy_ctxt -= p_w_given_ctxt * np.log2(p_w_given_ctxt)
        cond_entropy += p_ctxt * entropy_ctxt

    return cond_entropy


def compute_conditional_word_entropy_vectorized(chain, context, word_length):
    """
    Calcule l'entropie conditionnelle H(W|C) en bits, où
      - W est un mot (de longueur word_length) extrait de la séquence 'chain'
      - C est le mot de contexte correspondant extrait de la séquence 'context'
    
    La formule utilisée est :
      H(W|C) = - ∑_c P(c) ∑_w P(w|c) log2(P(w|c))
    
    Cette version vectorisée utilise sliding_window_view et np.unique pour éviter
    une double boucle sur l'ensemble des fenêtres.
    
    Args:
        chain (array-like): séquence principale (par exemple, tableau de 0 et 1)
        context (array-like): séquence de contexte (doit avoir la même longueur que chain)
        word_length (int): longueur des mots à considérer
        
    Returns:
        cond_entropy (float): l'entropie conditionnelle en bits.
    """
    L = len(chain)
    if L < word_length:
        return 0.0
    
    # Obtenir toutes les fenêtres (de taille word_length) pour la chaîne et le contexte
    chain_windows = sliding_window_view(chain, window_shape=word_length)  # forme: (L-word_length+1, word_length)
    context_windows = sliding_window_view(context, window_shape=word_length)  # même forme
    
    n_windows = chain_windows.shape[0]
    
    # Pour pouvoir compter rapidement les fenêtres uniques, on convertit chaque ligne en un blob binaire.
    # La taille d'un blob est le nombre d'octets par élément * word_length.
    dt_chain = np.dtype((np.void, chain_windows.dtype.itemsize * word_length))
    dt_context = np.dtype((np.void, context_windows.dtype.itemsize * word_length))
    
    # Vues contiguës
    chain_view = np.ascontiguousarray(chain_windows).view(dt_chain).reshape(-1)
    context_view = np.ascontiguousarray(context_windows).view(dt_context).reshape(-1)
    
    # Comptage des contextes uniques
    uniq_context, inv_context, counts_context = np.unique(context_view, return_inverse=True, return_counts=True)
    
    # Comptage conjoint (contexte, mot)
    joint = np.array([context_view, chain_view]).T  # forme (n_windows, 2)
    # Pour joindre les deux, on concatène les deux blobs en un seul
    dt_joint = np.dtype((np.void, context_windows.dtype.itemsize * word_length + chain_windows.dtype.itemsize * word_length))
    joint_view = np.ascontiguousarray(np.hstack((context_windows, chain_windows))).view(dt_joint).reshape(-1)
    uniq_joint, inv_joint, counts_joint = np.unique(joint_view, return_inverse=True, return_counts=True)
    
    # Calcul de l'entropie conditionnelle
    cond_entropy = 0.0
    # On boucle sur les indices des contextes uniques – le nombre de contextes uniques sera beaucoup plus petit.
    for idx, cnt in tqdm(enumerate(counts_context)):
        p_ctxt = cnt / n_windows
        # Sélection des indices pour lesquels le contexte est uniq_context[idx]
        mask = (inv_context == idx)
        # Parmi ces indices, on regarde les valeurs jointes (qui correspondent aux p(w|c))
        joint_vals = inv_joint[mask]
        uniq_joint_vals, joint_counts = np.unique(joint_vals, return_counts=True)
        p_w_given_ctxt = joint_counts / cnt  # probabilités conditionnelles
        entropy_ctxt = -np.sum(p_w_given_ctxt * np.log2(p_w_given_ctxt))
        cond_entropy += p_ctxt * entropy_ctxt

    return cond_entropy