import torch
import numpy as np
import re
from typing import Dict, List, Tuple, Any
from transformers import AutoTokenizer

def icl_markov_evaluation(model, tokenizer, full_series: str, target_state_space: List[int], 
                         target_p: np.ndarray, interference_state_space: List[int], 
                         interference_p: np.ndarray, step_size: int = 2, smooth_window: int = 40,
                         prompt_exist: bool = False) -> Dict[str, Any]:
    """Measure how closely a model's next-state prediction tracks the target chain.

    The digit sequence found after ``Predict next:`` in ``full_series`` is
    sampled every ``step_size`` positions (the last position is always
    included). At each sampled position the model is prompted with the sequence
    up to that position, its probabilities over the target states are compared
    with the matching row of ``target_p``, and the Bhattacharyya distance
    ``-ln(sum(sqrt(p * q)))`` is recorded.

    Args:
        model: Causal LM; passed unchanged to ``_get_model_prediction``.
        tokenizer: Tokenizer providing ``convert_tokens_to_ids``, ``encode``
            and ``unk_token_id``.
        full_series: Text containing ``Predict next:`` followed by the digit
            sequence and optional ``[SWITCH_TO_*]`` markers.
        target_state_space: States of the target chain; row ``k`` of
            ``target_p`` is the transition distribution out of
            ``target_state_space[k]``.
        target_p: Transition matrix of the target chain.
        interference_state_space: Accepted for interface compatibility; not
            used by this function.
        interference_p: Accepted for interface compatibility; not used by this
            function.
        step_size: Stride between evaluated positions.
        smooth_window: Moving-average width; a falsy or non-positive value, or
            a width larger than the number of sampled positions, disables
            smoothing.
        prompt_exist: When true the prompt is built by
            ``_construct_input_prompt`` (task description and markers kept);
            otherwise by ``_construct_input_no_prompt``.

    Returns:
        Dict with ``distances``, ``positions``, ``smoothed_distances``,
        ``smoothed_positions``, ``chain_labels``, ``is_target_position``,
        ``numeric_sequence``, ``step_size`` and ``smooth_window``.

    Raises:
        ValueError: If ``model`` is None, the series contains no digits, or a
            target state cannot be mapped to a single token id.
    """
    # Fail fast: without a model every position would fail and be reported as inf.
    if model is None:
        raise ValueError("model cannot be loaded")

    np.random.seed(42)
    torch.manual_seed(42)

    numeric_sequence, chain_labels = _parse_full_series(full_series)

    if len(numeric_sequence) == 0:
        raise ValueError("no valid state")

    target_tokens = _resolve_state_token_ids(tokenizer, target_state_space)

    sequence_length = len(numeric_sequence)
    sampled_positions = list(range(0, sequence_length, step_size))

    # Always evaluate the final position, even when the stride skips it.
    if sampled_positions[-1] < sequence_length - 1:
        sampled_positions.append(sequence_length - 1)

    distance_array = np.zeros(len(sampled_positions))
    is_target_position = np.array([label == 'target' for label in chain_labels])

    print(f"Target P shape: {target_p.shape}")
    print(f"Target states: {target_state_space}")
    print(f"P matrix sum per row: {target_p.sum(axis=1)}")  
    print(f"Sample transition from state 0: {target_p[0]}")
    print(f"First 50 numeric sequence: {numeric_sequence[:50]}")
    print(f"Unique states in sequence: {set(numeric_sequence)}")
    print(f"Are all target states present? {set(target_state_space).issubset(set(numeric_sequence))}")

    build_prompt = _construct_input_prompt if prompt_exist else _construct_input_no_prompt

    for idx, i in enumerate(sampled_positions):
        try:
            if is_target_position[i]:
                current_effective_state = numeric_sequence[i]
            else:
                # The prompt builders re-seed with the same value (42 + position)
                # and make the same draw, so the state substituted into the
                # prompt is the one scored here.
                np.random.seed(42 + i)
                current_effective_state = np.random.choice(target_state_space)

            input_prompt = build_prompt(full_series, numeric_sequence, i,
                                        is_target_position, target_state_space)
            learned_p_out = _get_model_prediction(model, tokenizer, input_prompt, target_tokens)

            state_idx = target_state_space.index(current_effective_state)
            true_p_out = torch.tensor(target_p[state_idx], dtype=torch.float32)

            distance_array[idx] = _bhattacharyya_distance(learned_p_out, true_p_out)

        except Exception as e:
            # Best effort: report the failure, then carry the previous value
            # forward (inf when there is no previous value).
            print(f"evaluation failed at position {i}: {type(e).__name__}: {e}")
            distance_array[idx] = distance_array[idx - 1] if idx > 0 else float('inf')

    positions_array = np.array(sampled_positions)
    if smooth_window and 0 < smooth_window <= len(distance_array):
        kernel = np.ones(smooth_window) / smooth_window
        smoothed_distances = np.convolve(distance_array, kernel, mode='valid')
        # Slice by the convolution's own length so both arrays stay aligned
        # for odd as well as even window sizes.
        start = smooth_window // 2
        smoothed_positions = positions_array[start:start + len(smoothed_distances)]
    else:
        smoothed_distances = distance_array
        smoothed_positions = positions_array

    sampled_chain_labels = [chain_labels[i] for i in sampled_positions if i < len(chain_labels)]
    sampled_is_target = is_target_position[sampled_positions]
    return {
        'distances': distance_array,
        'positions': positions_array,
        'smoothed_distances': smoothed_distances,
        'smoothed_positions': smoothed_positions,
        'chain_labels': sampled_chain_labels,
        'is_target_position': sampled_is_target,
        'numeric_sequence': numeric_sequence,
        'step_size': step_size,
        'smooth_window': smooth_window
    }


def _resolve_state_token_ids(tokenizer, state_space: List[int]) -> List[int]:
    """Map every state to one vocabulary id, preferring the space-prefixed form."""
    token_ids = []
    for state in state_space:
        token_id = tokenizer.convert_tokens_to_ids(f" {state}")

        if token_id is None or token_id == tokenizer.unk_token_id:
            # Direct lookup failed: accept the first spelling that encodes to
            # exactly one token.
            for candidate in (f" {state}", str(state)):
                encoded = tokenizer.encode(candidate, add_special_tokens=False)
                if len(encoded) == 1:
                    token_id = encoded[0]
                    break
            else:
                raise ValueError(f" state {state} cannot be mapped")

        token_ids.append(token_id)
    return token_ids


def _bhattacharyya_distance(learned_p: torch.Tensor, true_p: torch.Tensor) -> float:
    """Return -ln(sum(sqrt(p * q))) after normalising both inputs to sum to one."""
    true_p = true_p / torch.sum(true_p)
    learned_p = learned_p / torch.sum(learned_p)
    # Clamp so disjoint supports yield a large finite distance, not log(0).
    coefficient = torch.clamp(torch.sum(torch.sqrt(learned_p * true_p)), min=1e-10)
    return (-torch.log(coefficient)).item()


def _parse_full_series(full_series: str) -> Tuple[List[int], List[str]]:
    """Extract the digit sequence and per-digit chain labels from a series.

    Only the text after the first matching ``Predict next:`` is read. Every
    single digit character becomes one state. Digits are labelled ``'target'``
    until a ``[SWITCH_TO_INTERFERENCE]`` marker is seen, ``'interference'``
    until the next ``[SWITCH_TO_NORMAL]`` marker, and so on.

    Returns:
        ``(numeric_sequence, chain_labels)``, two lists of equal length.

    Raises:
        ValueError: If ``Predict next:`` is not found.
    """
    located = re.search(r'Predict next:\s*(.+)$', full_series, re.DOTALL)
    if located is None:
        raise ValueError("cannot find 'Predict next:'")

    marker_to_chain = {
        '[SWITCH_TO_INTERFERENCE]': 'interference',
        '[SWITCH_TO_NORMAL]': 'target',
    }

    states: List[int] = []
    labels: List[str] = []
    active_chain = 'target'

    # The capturing group keeps the markers as their own items in the split.
    segments = re.split(r'(\[SWITCH_TO_INTERFERENCE\]|\[SWITCH_TO_NORMAL\])',
                        located.group(1).strip())
    for raw_segment in segments:
        segment = raw_segment.strip()
        if segment in marker_to_chain:
            active_chain = marker_to_chain[segment]
            continue
        found = [int(ch) for ch in re.findall(r'\d', segment)]
        states.extend(found)
        labels.extend([active_chain] * len(found))

    return states, labels


def _construct_input_prompt(full_series: str, numeric_sequence: List[int], position: int,
                           is_target_position: np.ndarray, target_state_space: List[int]) -> str:
    """Build the prompt (task description + sequence prefix) for one position.

    The sequence after ``Predict next:`` is replayed up to and including digit
    index ``position``, keeping the switch markers. If that digit is not a
    target-chain digit it is replaced by ``[SWITCH_TO_NORMAL]`` followed by a
    randomly drawn target state. Every digit and marker is followed by a
    single space.

    Args:
        full_series: Text containing ``Predict next:`` and the sequence.
        numeric_sequence: Unused; kept for interface compatibility.
        position: Zero-based index of the last digit to include.
        is_target_position: Boolean flag per digit, true for target-chain digits.
        target_state_space: States the replacement digit is drawn from.

    Returns:
        ``"<task description> Predict next: <sequence prefix>"``.

    Raises:
        ValueError: If ``Predict next:`` is not found in ``full_series``.
    """
    # Re-seeds the global NumPy RNG so the replacement draw below is a pure
    # function of ``position``.
    np.random.seed(42 + position)

    sequence_match = re.search(r'Predict next:\s*(.+)$', full_series, re.DOTALL)
    if sequence_match is None:
        raise ValueError("cannot find 'Predict next:'")

    header_match = re.search(r'(.+?)Predict next:', full_series, re.DOTALL)
    if not header_match:
        # Nothing precedes the marker, so fall back to a built-in description.
        task_description = """This is a number sequence prediction task with two patterns.
        [SWITCH_TO_INTERFERENCE] indicates interference pattern starts.
        [SWITCH_TO_NORMAL] indicates target pattern resumes.
        Your task: must predict exactly ONE single digit number that follows the sequence.
        Only predict the next digit number.

        Examples:
        0 1 2 -> 3
        1 3 0 -> 2
        0 1 [SWITCH_TO_INTERFERENCE] 2 3 [SWITCH_TO_NORMAL] 1 -> 0

        """
    else:
        task_description = header_match.group(1).strip()

    parts = re.split(r'(\[SWITCH_TO_INTERFERENCE\]|\[SWITCH_TO_NORMAL\])',
                     sequence_match.group(1).strip())

    pieces: List[str] = []
    digit_count = 0

    for part in parts:
        part = part.strip()
        if part in ('[SWITCH_TO_INTERFERENCE]', '[SWITCH_TO_NORMAL]'):
            # Trailing space keeps the marker separated from the next digit,
            # matching the inserted marker below and the prompt examples.
            pieces.append(part + " ")
            continue
        if not part:
            continue
        for digit in re.findall(r'\d', part):
            if digit_count > position:
                break
            if (digit_count == position and position < len(is_target_position)
                    and not is_target_position[position]):
                # Interference digit at the query position: announce a switch
                # back and substitute a random target state.
                temp_random_target = np.random.choice(target_state_space)
                pieces.append("[SWITCH_TO_NORMAL] " + str(temp_random_target) + " ")
            else:
                pieces.append(str(digit) + " ")
            digit_count += 1
        if digit_count > position:
            break

    reconstructed_sequence = "".join(pieces)
    return f"{task_description} Predict next: {reconstructed_sequence}"



def _construct_input_no_prompt(full_series: str, numeric_sequence: List[int], position: int,
                           is_target_position: np.ndarray, target_state_space: List[int]) -> str:
    """Build the bare digit prefix (no task description, no markers) for one position.

    The sequence after ``Predict next:`` is replayed up to and including digit
    index ``position``. Each digit is followed by one space and each switch
    marker is replaced by one space. If the digit at ``position`` is not a
    target-chain digit it is replaced by a randomly drawn target state.

    Args:
        full_series: Text containing ``Predict next:`` and the sequence.
        numeric_sequence: Unused; kept for interface compatibility.
        position: Zero-based index of the last digit to include.
        is_target_position: Boolean flag per digit, true for target-chain digits.
        target_state_space: States the replacement digit is drawn from.

    Returns:
        The space-separated digit prefix, ending with a trailing space.

    Raises:
        ValueError: If ``Predict next:`` is not found in ``full_series``.
    """
    # Re-seeds the global NumPy RNG so the replacement draw below is a pure
    # function of ``position``.
    np.random.seed(42 + position)

    sequence_match = re.search(r'Predict next:\s*(.+)$', full_series, re.DOTALL)
    if sequence_match is None:
        raise ValueError("cannot find 'Predict next:'")

    parts = re.split(r'(\[SWITCH_TO_INTERFERENCE\]|\[SWITCH_TO_NORMAL\])',
                     sequence_match.group(1).strip())

    pieces: List[str] = []
    digit_count = 0

    for part in parts:
        part = part.strip()
        if part in ('[SWITCH_TO_INTERFERENCE]', '[SWITCH_TO_NORMAL]'):
            # NOTE(review): the preceding digit already ends with a space, so
            # this leaves two consecutive spaces where a switch occurred;
            # confirm that this is intended.
            pieces.append(" ")
            continue
        if not part:
            continue
        for digit in re.findall(r'\d', part):
            if digit_count > position:
                break
            if (digit_count == position and position < len(is_target_position)
                    and not is_target_position[position]):
                # Interference digit at the query position: substitute a
                # random target state.
                temp_random_target = np.random.choice(target_state_space)
                pieces.append(str(temp_random_target) + " ")
            else:
                pieces.append(str(digit) + " ")
            digit_count += 1
        if digit_count > position:
            break

    return "".join(pieces)


def _get_model_prediction(model, tokenizer, input_prompt: str, target_tokens: List[int]) -> torch.Tensor:
    """Return the model's next-token distribution restricted to ``target_tokens``.

    The prompt is tokenized (special tokens included) and run through the
    model without gradient tracking. The logits at the last position are
    indexed by ``target_tokens`` and soft-maxed over those entries only, so the
    result sums to one across the candidate tokens.

    Returns:
        1-D CPU tensor with one probability per entry of ``target_tokens``.
    """
    encoded = tokenizer(input_prompt, return_tensors="pt", add_special_tokens=True)

    # Run on whichever device holds the model's parameters.
    model_device = next(model.parameters()).device
    prompt_ids = encoded['input_ids'].to(model_device)

    with torch.no_grad():
        all_logits = model(prompt_ids).logits

    candidate_ids = torch.tensor(target_tokens, device=model_device)
    # First batch item, final sequence position, then only the candidate ids.
    candidate_logits = all_logits[0, -1, :][candidate_ids]

    return torch.softmax(candidate_logits, dim=-1).cpu()


