
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import pickle
import os
import sys
from typing import List, Tuple, Dict
import re
from pathlib import Path

class MarkovPredictor(nn.Module):
    """MLP that maps a (state id, task id) pair to logits over the next state.

    With ``use_embedding=True`` both ids are looked up in learned embeddings
    whose widths sum to ``hidden_dim``; otherwise they are one-hot encoded
    (width ``num_states + num_tasks``). The concatenated features pass through
    a 3-layer ReLU MLP that outputs ``num_states`` logits.
    """

    def __init__(self, num_states=4, num_tasks=2, hidden_dim=64, use_embedding=True):
        super(MarkovPredictor, self).__init__()
        self.num_states = num_states
        self.num_tasks = num_tasks
        self.hidden_dim = hidden_dim
        self.use_embedding = use_embedding

        if use_embedding:
            # Split hidden_dim between the two embeddings so that their
            # concatenation is exactly hidden_dim wide, also for odd values
            # (hidden_dim // 2 twice would be one short and break the MLP input).
            state_dim = hidden_dim // 2
            task_dim = hidden_dim - state_dim
            self.state_embedding = nn.Embedding(num_states, state_dim)
            self.task_embedding = nn.Embedding(num_tasks, task_dim)
            input_dim = state_dim + task_dim
        else:
            self.state_embedding = None
            self.task_embedding = None
            input_dim = num_states + num_tasks

        self.mlp = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_states)
        )

        self._initialize_weights()

    def _initialize_weights(self):
        """Xavier-uniform init for Linear/Embedding weights; zero Linear biases."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                nn.init.xavier_uniform_(module.weight)

    def forward(self, states, task_ids):
        """Return next-state logits for a batch of (state, task) ids.

        Args:
            states: Integer tensor of state ids. A 0-d tensor is promoted to
                shape (1,). The one-hot path requires a 1-D tensor (it uses
                ``size(0)`` and scatters along dim 1).
            task_ids: Integer tensor of task ids, same shape as ``states``.

        Returns:
            Logits with a trailing dimension of size ``num_states``.
        """
        if len(states.shape) == 0:
            states = states.unsqueeze(0)
        if len(task_ids.shape) == 0:
            task_ids = task_ids.unsqueeze(0)

        if self.use_embedding:
            state_emb = self.state_embedding(states)
            task_emb = self.task_embedding(task_ids)
            combined = torch.cat([state_emb, task_emb], dim=-1)
        else:
            batch_size = states.size(0)
            state_onehot = torch.zeros(batch_size, self.num_states, device=states.device)
            task_onehot = torch.zeros(batch_size, self.num_tasks, device=task_ids.device)

            state_onehot.scatter_(1, states.unsqueeze(-1), 1)
            task_onehot.scatter_(1, task_ids.unsqueeze(-1), 1)

            combined = torch.cat([state_onehot, task_onehot], dim=1)

        logits = self.mlp(combined)
        return logits

    def predict_probs(self, state, task_id):
        """Return the softmax next-state distribution for one (state, task) pair.

        Runs without gradient tracking in eval mode, then restores the
        module's previous train/eval mode so callers are not surprised by a
        silently changed mode.

        Returns:
            1-D tensor of length ``num_states`` on the model's device.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                device = next(self.parameters()).device
                state_tensor = torch.tensor([state], device=device)
                task_tensor = torch.tensor([task_id], device=device)

                logits = self.forward(state_tensor, task_tensor)
                probs = torch.softmax(logits, dim=1)
                return probs.squeeze(0)
        finally:
            self.train(was_training)

def bhattacharyya_distance_corrected(pred_probs, current_state, true_next_state,
                                     task_id, P_matrix, INTER_P_matrix):
    """Bhattacharyya distance between a predicted and the true next-state distribution.

    The true distribution is row ``current_state`` of ``P_matrix`` when
    ``task_id == 0`` and of ``INTER_P_matrix`` otherwise. The distance is
    ``-log(sum(sqrt(p * q)) + 1e-10)``; the epsilon keeps it finite when the
    distributions have disjoint support. For identical distributions the
    result is about ``-1e-10``, i.e. it may be marginally negative.

    Args:
        pred_probs: 1-D tensor of predicted probabilities.
        current_state: Row index into the selected transition matrix.
        true_next_state: Unused; kept for interface compatibility.
        task_id: 0 selects ``P_matrix``, anything else ``INTER_P_matrix``.
        P_matrix, INTER_P_matrix: Tensors, numpy arrays or nested lists.

    Returns:
        The distance as a Python float.
    """
    transition_matrix = P_matrix if task_id == 0 else INTER_P_matrix

    # torch.as_tensor accepts tensors, arrays and lists without the
    # "copy construct from a tensor" UserWarning that torch.tensor emits.
    true_probs = torch.as_tensor(transition_matrix[current_state],
                                 dtype=pred_probs.dtype,
                                 device=pred_probs.device)

    bc = torch.sum(torch.sqrt(pred_probs * true_probs))

    distance = -torch.log(bc + 1e-10)

    return distance.item()

def load_prompt_files(generated_series_dir, continuous_series_names, markov_chain_names):
    """Load every ``.pkl`` file in a directory and bucket it by series type.

    A file's series name is the first two ``_``-separated parts of its stem
    (e.g. ``markov_chain_3.pkl`` -> ``markov_chain``). Files whose series name
    is in neither list are reported and skipped.

    Args:
        generated_series_dir: Directory to scan (str or Path).
        continuous_series_names: Series names treated as continuous series.
        markov_chain_names: Series names treated as Markov chains.

    Returns:
        ``(continuous_series_task, markov_chain_task)``: dicts mapping file
        name to the unpickled object.
    """
    continuous_series_task = {}
    markov_chain_task = {}

    generated_series_dir = Path(generated_series_dir)

    print("generated_series_dir:", generated_series_dir)

    # Sorted so load/print order does not depend on filesystem iteration order.
    for file in sorted(generated_series_dir.iterdir()):
        if file.suffix != '.pkl' or not file.is_file():
            continue

        series_name = '_'.join(file.stem.split('_')[:2])
        if series_name in continuous_series_names:
            target, label = continuous_series_task, "Continuous series:"
        elif series_name in markov_chain_names:
            target, label = markov_chain_task, "Markov chain:"
        else:
            print(f"Unrecognized series name: {series_name} for file: {file.name}")
            continue

        # NOTE(security): pickle.load can execute arbitrary code -- only load trusted files.
        # The context manager closes the handle that was previously leaked.
        with file.open('rb') as fh:
            target[file.name] = pickle.load(fh)
        print(label, file.name)

    print("Continuous series keys:", list(continuous_series_task.keys()))
    print("Markov chain keys:", list(markov_chain_task.keys()))

    return continuous_series_task, markov_chain_task

def extract_transition_matrices(series_dict):
    """Pull the transition matrices and state lists out of a loaded series dict.

    Returns:
        Tuple ``(P, INTER_P, states, inter_states)`` read from the keys
        'P', 'INTER_P', 'states' and 'INTER_states' respectively.

    Raises:
        KeyError: for the first of those keys (in that order) that is missing.
    """
    wanted_keys = ('P', 'INTER_P', 'states', 'INTER_states')
    return tuple(series_dict[key] for key in wanted_keys)

def parse_input_prompt(prompt_text):
    """Parse a prompt string into per-step states, task ids and switch positions.

    Anything up to and including the first ``"Predict next:"`` (plus trailing
    whitespace) is dropped. The remainder is scanned left to right:

    * ``[SWITCH_TO_INTERFERENCE]`` sets the current task to 1,
    * ``[SWITCH_TO_NORMAL]`` sets it to 0,
    * each single decimal digit character is one state, tagged with the
      current task (multi-digit numbers are split into separate states),
    * every other character is ignored.

    A switch is recorded (as the index of the next state) only if at least
    one state was seen before it. The task starts at 0.

    Returns:
        ``(states_sequence, task_sequence, switch_positions)`` as lists of ints.
    """
    print("Parsing input prompt...")

    start_pattern = r'Predict next:\s*'
    match = re.search(start_pattern, prompt_text)
    if match:
        prompt_text = prompt_text[match.end():]

    interference_marker = '[SWITCH_TO_INTERFERENCE]'
    normal_marker = '[SWITCH_TO_NORMAL]'

    states_sequence = []
    task_sequence = []
    switch_positions = []
    current_task = 0
    i = 0
    text_length = len(prompt_text)
    while i < text_length:
        # startswith(prefix, start) tests in place; slicing prompt_text[i:]
        # copied the tail on every character and made the scan O(n^2).
        if prompt_text.startswith(interference_marker, i):
            if states_sequence:
                switch_positions.append(len(states_sequence))
                print(f"Switch to interference mode, position: {len(states_sequence)}")
            current_task = 1
            i += len(interference_marker)

        elif prompt_text.startswith(normal_marker, i):
            if states_sequence:
                switch_positions.append(len(states_sequence))
                print(f"Switch to normal mode, position: {len(states_sequence)}")
            current_task = 0
            i += len(normal_marker)

        # isdecimal (not isdigit): characters such as '²' satisfy isdigit()
        # but make int() raise ValueError.
        elif prompt_text[i].isdecimal():
            states_sequence.append(int(prompt_text[i]))
            task_sequence.append(current_task)
            i += 1

        else:
            i += 1

    print(f"Parsing completed:")
    print(f"  Total states: {len(states_sequence)}")
    if states_sequence:
        print(f"  State value range: {min(states_sequence)} - {max(states_sequence)}")
    print(f"  Number of task switches: {len(switch_positions)}")
    print(f"  Switch positions: {switch_positions}")
    print(f"  Normal mode states: {task_sequence.count(0)}")
    print(f"  Interference mode states: {task_sequence.count(1)}")
    if switch_positions:
        # Show the task of the first few segments as a sanity check.
        print(f"  Task verification:")
        prev_pos = 0
        for switch_pos in switch_positions[:3]:
            task_after = task_sequence[switch_pos] if switch_pos < len(task_sequence) else None
            segment_task = task_sequence[prev_pos] if prev_pos < len(task_sequence) else 'N/A'
            print(f"    Position {prev_pos}-{switch_pos-1}: Task {segment_task}")
            print(f"    Position {switch_pos}: Task changes to {task_after}")
            prev_pos = switch_pos

    return states_sequence, task_sequence, switch_positions

def load_input_prompt(pkl_file_path):
    """Load a pickled prompt and return it as a string.

    The pickle may hold the prompt string itself, or a dict containing it.
    For a dict, the first of a fixed list of known keys whose value is a
    ``str`` is used; failing that, the first string value longer than 100
    characters.

    Args:
        pkl_file_path: Path to the ``.pkl`` file.

    Returns:
        The prompt text.

    Raises:
        ValueError: if the pickle is neither a str nor a dict, or no suitable
            string is found in the dict.
    """
    print(f"Loading prompt file: {pkl_file_path}")

    # NOTE(security): pickle.load can execute arbitrary code -- only load trusted files.
    with open(pkl_file_path, 'rb') as f:
        data = pickle.load(f)

    if isinstance(data, dict):
        possible_keys = ['prompt', 'text', 'sequence', 'content', 'data', 'full_series', 'full_series_with_switches']
        # Accept only string values: the result is parsed as text downstream,
        # so e.g. a list stored under 'data' must not shadow a real prompt
        # stored under a later key.
        prompt_text = next(
            (data[key] for key in possible_keys
             if key in data and isinstance(data[key], str)),
            None,
        )

        if prompt_text is None:
            print(f"Data dictionary keys: {list(data.keys())}")
            prompt_text = next(
                (value for value in data.values()
                 if isinstance(value, str) and len(value) > 100),
                None,
            )
    elif isinstance(data, str):
        prompt_text = data
    else:
        raise ValueError(f"Cannot extract prompt text from pkl file, data type: {type(data)}")

    if prompt_text is None:
        raise ValueError("Could not find valid prompt text in pkl file")

    print(f"Successfully loaded prompt, length: {len(prompt_text)} characters")
    return prompt_text



def generate_test_data(P_matrix, num_samples=10, seq_length=50):
    """Sample ``(current_state, next_state)`` transition pairs from a Markov chain.

    Draws ``num_samples`` independent sequences of length ``seq_length`` with
    ``generate_markov_sequence`` and returns every consecutive pair, in order.
    Each sequence contributes ``seq_length - 1`` pairs.
    """
    pairs = []
    for _ in range(num_samples):
        chain = generate_markov_sequence(P_matrix, seq_length)
        pairs.extend(zip(chain[:-1], chain[1:]))
    return pairs

def generate_markov_sequence(transition_matrix, length):
    """Sample a state sequence from a row-stochastic transition matrix.

    The initial state is uniform over all states; each following state is
    drawn from the current state's row using NumPy's global RNG.

    Args:
        transition_matrix: Square matrix (ndarray, nested list, or anything
            ``np.asarray`` accepts); row ``s`` is the next-state distribution
            from state ``s``.
        length: Number of states to return. Values below 1 still yield the
            single initial state.

    Returns:
        List of sampled state indices.
    """
    # np.asarray lets callers pass nested lists (e.g. matrices straight from a
    # pickle); the original `.shape` access failed on them.
    transition_matrix = np.asarray(transition_matrix)
    N_state = transition_matrix.shape[0]
    current_state = np.random.randint(0, N_state)
    sequence = [current_state]

    for _ in range(length - 1):
        probs = transition_matrix[current_state]
        next_state = np.random.choice(N_state, p=probs)
        sequence.append(next_state)
        current_state = next_state

    return sequence




def _trailing_moving_average(values, window_size):
    """Return, per index, the mean of the last ``window_size`` values up to it (shorter at the start)."""
    return [
        np.mean(values[max(0, i - window_size + 1):i + 1])
        for i in range(len(values))
    ]


def _evaluate_on_normal_task(model, P_matrix, P_tensor, INTER_P_tensor):
    """Return the mean Bhattacharyya distance of the model to P on a fresh task-0 sample.

    Draws one 50-step sequence from ``P_matrix`` and averages, over its
    transitions, the distance between ``model.predict_probs(s, 0)`` and row
    ``s`` of ``P_tensor``.
    """
    test_data = generate_test_data(P_matrix, num_samples=1, seq_length=50)
    total_distance = 0.0
    for current_state_test, true_next_state_test in test_data:
        probs = model.predict_probs(current_state_test, task_id=0)
        total_distance += bhattacharyya_distance_corrected(
            probs, current_state_test, true_next_state_test,
            0, P_tensor, INTER_P_tensor)
    return total_distance / len(test_data)


def _analyze_switches(switch_positions, task_sequence, distances, training_losses, window=50):
    """Compare mean distance and loss over ``window`` steps before vs. after each switch.

    Switches at or beyond ``len(distances)``, or with an empty before/after
    window, are skipped.
    """
    switch_analysis = []
    for switch_pos in switch_positions:
        if switch_pos >= len(distances):
            continue

        before_start = max(0, switch_pos - window)
        before_end = switch_pos
        after_start = switch_pos
        after_end = min(len(distances), switch_pos + window)
        if before_end <= before_start or after_end <= after_start:
            continue

        dist_before = np.mean(distances[before_start:before_end])
        dist_after = np.mean(distances[after_start:after_end])
        loss_before = np.mean(training_losses[before_start:before_end])
        loss_after = np.mean(training_losses[after_start:after_end])

        switch_analysis.append({
            'position': switch_pos,
            'task_before': task_sequence[max(0, switch_pos-1)],
            'task_after': task_sequence[min(len(task_sequence)-1, switch_pos)],
            'distance_before': dist_before,
            'distance_after': dist_after,
            'distance_change': dist_after - dist_before,
            'loss_before': loss_before,
            'loss_after': loss_after,
            'loss_change': loss_after - loss_before
        })
    return switch_analysis


def run_forgetting_experiment_with_online_training(states_sequence, task_sequence, switch_positions,
                                                  P_matrix, INTER_P_matrix, learning_rate=0.001,
                                                  num_states=4, num_tasks=2, device=None):
    """Train a MarkovPredictor online (one SGD step per transition) and track forgetting.

    For every transition ``states_sequence[i] -> states_sequence[i+1]`` (with
    task ``task_sequence[i]``) the model takes one SGD step on cross-entropy.
    After each step it is evaluated against the *normal* task only (task 0,
    matrix ``P_matrix``) on a freshly sampled sequence, so the distance curve
    shows how much of the normal task is lost while training on interference.

    Args:
        states_sequence: State ids, at least two.
        task_sequence: Task id per state; assumed to have the same length as
            ``states_sequence``.
        switch_positions: Indices where the task changes (for analysis only).
        P_matrix, INTER_P_matrix: Transition matrices of the normal and
            interference tasks.
        learning_rate: SGD learning rate.
        num_states, num_tasks: Model vocabulary sizes.
        device: Torch device; defaults to CUDA if available, else CPU.

    Returns:
        ``(result, detailed_results)``: ``result`` holds ``positions`` and
        ``distances``; ``detailed_results`` adds losses, smoothed distances,
        switch analysis, the trained model and the inputs.

    Raises:
        ValueError: if fewer than two states are given (no transition to train on).
    """
    if len(states_sequence) < 2:
        # Previously this crashed late with an opaque error from min() on an empty list.
        raise ValueError(
            "Need at least two states to form a transition; "
            f"got {len(states_sequence)}")

    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print(f"\nStarting SGD forgetting experiment with online training...")
    print(f"Device: {device}")
    print(f"Sequence length: {len(states_sequence)}")
    print(f"Learning rate: {learning_rate}")
    print(f"Evaluation metric: Bhattacharyya distance")

    torch.manual_seed(42)
    model = MarkovPredictor(
        num_states=num_states,
        num_tasks=num_tasks,
        hidden_dim=64,
        use_embedding=True
    ).to(device)

    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
    optimizer = optim.SGD(model.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss()
    P_tensor = torch.tensor(P_matrix, dtype=torch.float32, device=device)
    INTER_P_tensor = torch.tensor(INTER_P_matrix, dtype=torch.float32, device=device)

    distances = []
    true_next_states = []
    positions = []
    training_losses = []

    print(f"\nStarting step-by-step prediction and training...")

    for i in range(len(states_sequence) - 1):
        current_state = states_sequence[i]
        current_task = task_sequence[i]
        true_next_state = states_sequence[i + 1]

        # One online SGD step on the observed transition.
        model.train()
        optimizer.zero_grad()

        state_tensor = torch.tensor([current_state], device=device)
        task_tensor = torch.tensor([current_task], device=device)
        target_tensor = torch.tensor([true_next_state], device=device)

        logits = model(state_tensor, task_tensor)
        loss = criterion(logits, target_tensor)

        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        training_losses.append(loss_value)

        # Forgetting probe: always measured on the normal task (task 0).
        distance = _evaluate_on_normal_task(model, P_matrix, P_tensor, INTER_P_tensor)
        distances.append(distance)
        true_next_states.append(true_next_state)
        positions.append(i)
        if (i + 1) % 100 == 0 or i == 0:
            print(f"Position {i+1:5d}: State{current_state}(Task{current_task}) -> "
                  f"True{true_next_state}, "
                  f"Distance{distance:.4f}, Loss{loss_value:.4f}")

    print(f"Training and prediction completed!")
    smoothed_distances = _trailing_moving_average(distances, window_size=100)
    switch_analysis = _analyze_switches(
        switch_positions, task_sequence, distances, training_losses, window=50)

    overall_distance = np.mean(distances)
    overall_loss = np.mean(training_losses)
    result = {
        'positions': positions,
        'distances': distances
    }
    detailed_results = {
        'positions': positions,
        'distances': distances,
        'smoothed_distances': smoothed_distances,
        'true_next_states': true_next_states,
        'switch_positions': switch_positions,
        'switch_analysis': switch_analysis,
        'overall_distance': overall_distance,
        'training_losses': training_losses,
        'overall_loss': overall_loss,
        'learning_rate': learning_rate,
        'states_sequence': states_sequence,
        'task_sequence': task_sequence,
        'model': model,
        'num_states': num_states,
        'num_tasks': num_tasks,
        'P_matrix': P_matrix,
        'INTER_P_matrix': INTER_P_matrix
    }

    print(f"\nExperiment statistics:")
    print(f"  Total predictions: {len(distances):,}")
    print(f"  Average Bhattacharyya distance: {overall_distance:.4f}")
    print(f"  Average training loss: {overall_loss:.4f}")
    print(f"  Distance std: {np.std(distances):.4f}")
    print(f"  Min distance: {min(distances):.4f}")
    print(f"  Max distance: {max(distances):.4f}")

    return result, detailed_results

def _plot_experiment(positions, distances, training_losses, switch_positions, fig_path):
    """Save a two-panel figure (log-scale distance, training loss) with task-switch markers."""
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 10))
    try:
        ax1.plot(positions, distances, 'b-', linewidth=1, alpha=0.7, label='Bhattacharyya Distance')
        for switch_pos in switch_positions:
            ax1.axvline(x=switch_pos, color='red', linestyle='--', alpha=0.7)
        ax1.set_xlabel("Positions")
        ax1.set_ylabel("Distances")
        ax1.set_yscale('log')
        ax1.set_title("Distances vs. Positions (Log Scale)")
        ax1.grid(True, which="both", ls="--", linewidth=0.5)
        ax1.legend()

        ax2.plot(positions, training_losses, 'g-', linewidth=1, alpha=0.7, label='Training Loss')
        for switch_pos in switch_positions:
            ax2.axvline(x=switch_pos, color='red', linestyle='--', alpha=0.7)
        ax2.set_xlabel("Positions")
        ax2.set_ylabel("Training Loss")
        ax2.set_title("Training Loss Over Time")
        ax2.grid(True, which="both", ls="--", linewidth=0.5)
        ax2.legend()

        fig.tight_layout()
        fig.savefig(fig_path, dpi=300)
    finally:
        # Release the figure even if saving fails, so repeated runs do not leak figures.
        plt.close(fig)


def main_with_directory(generated_series_dir, continuous_series_names, markov_chain_names, save_path=None):
    """Run the forgetting experiment on every Markov-chain pickle in a directory.

    For each Markov-chain file (processed in sorted name order) the series
    string is parsed, the experiment is run, and -- when ``save_path`` is
    given -- ``<stem>_results.pkl`` and ``<stem>.png`` are written there.
    Failures in one file are reported with a traceback and do not stop the
    others.

    Args:
        generated_series_dir: Directory containing the ``.pkl`` series files.
        continuous_series_names, markov_chain_names: Passed to
            ``load_prompt_files`` to classify files.
        save_path: Output directory (created if missing). ``None`` disables
            writing results and plots; previously it crashed each file after
            the experiment had already run.

    Returns:
        Dict mapping file name to the experiment's ``result`` dict.
    """
    continuous_series_task, markov_chain_task = load_prompt_files(
        generated_series_dir, continuous_series_names, markov_chain_names
    )

    if save_path is not None:
        os.makedirs(save_path, exist_ok=True)
    else:
        print("No save_path given; results and plots will not be written to disk.")

    results_collection = {}

    for series_name, series_dict in sorted(markov_chain_task.items()):
        print(f"\nProcessing {series_name}")

        try:
            full_series = series_dict.get('full_series_with_switches',
                                          series_dict.get('full_series', ''))

            if not full_series:
                print(f"Warning: No full_series found in {series_name}")
                continue

            P, INTER_P, states, _ = extract_transition_matrices(series_dict)

            print(f"Full series length: {len(full_series)}")
            print(f"P matrix shape: {np.array(P).shape}")
            print(f"INTER_P matrix shape: {np.array(INTER_P).shape}")

            states_sequence, task_sequence, switch_positions = parse_input_prompt(full_series)

            if not states_sequence:
                print(f"Warning: No states sequence parsed from {series_name}")
                continue

            results, detailed_results = run_forgetting_experiment_with_online_training(
                states_sequence, task_sequence, switch_positions,
                P, INTER_P, learning_rate=0.001, num_states=len(states), num_tasks=2
            )

            results_collection[series_name] = results

            if save_path is not None:
                data_name = Path(series_name).stem
                pkl_path = os.path.join(save_path, f"{data_name}_results.pkl")
                with open(pkl_path, "wb") as f:
                    pickle.dump(results, f)

                fig_path = os.path.join(save_path, f"{data_name}.png")
                _plot_experiment(results['positions'], results['distances'],
                                 detailed_results['training_losses'],
                                 switch_positions, fig_path)

            print(f"Completed processing {series_name}")

        except Exception as e:
            # Top-level per-file boundary: report and move on to the next file.
            print(f"Error processing {series_name}: {e}")
            import traceback
            traceback.print_exc()
            continue

    print(f"\n{'='*80}")
    print(f"Directory processing completed!")
    print(f"Successfully processed {len(results_collection)} files")
    print(f"{'='*80}")

    return results_collection

if __name__ == "__main__":
    # Seed both RNGs used by the experiment (torch for the model, NumPy for sampling).
    torch.manual_seed(42)
    np.random.seed(42)
    parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    # Make project modules importable -- presumably needed so pickles that
    # reference project classes can be loaded; TODO confirm.
    sys.path.append(parent_dir)
    # NOTE(review): 'xxx' looks like a placeholder for the real output/input
    # directory names -- confirm before running.
    save_path = Path(parent_dir) / 'xxx'
    # Race-free replacement for the exists()-then-makedirs() check.
    save_path.mkdir(parents=True, exist_ok=True)
    generated_series_dir = Path(parent_dir) / 'xxx'
    continuous_series_names = ["continuous_series"]
    markov_chain_names = ["markov_chain"]
    results_collection = main_with_directory(generated_series_dir, continuous_series_names, markov_chain_names, save_path)

    print("\nExperiment completed!")