import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import pickle
import os
import sys
from typing import List, Tuple, Dict
import re
from pathlib import Path
import random
from collections import deque

class ExperienceReplayBuffer:
    """Bounded FIFO memory of ``(state, task_id, next_state)`` transitions.

    Transitions live in a ``deque`` capped at ``buffer_size`` entries, so
    storing into a full buffer discards the oldest transition.
    """

    def __init__(self, buffer_size=8000, device=None):
        """Create an empty buffer.

        Args:
            buffer_size: Maximum number of transitions retained.
            device: Torch device on which sampled tensors are created;
                the CPU is used when ``None``.
        """
        self.buffer_size = buffer_size
        self.device = torch.device('cpu') if device is None else device
        self.buffer = deque(maxlen=buffer_size)

    def store(self, state, task_id, next_state):
        """Append one transition, stored as a dict keyed by field name."""
        self.buffer.append({
            'state': state,
            'task_id': task_id,
            'next_state': next_state,
        })

    def sample(self, batch_size):
        """Draw up to ``batch_size`` transitions as three long tensors.

        When fewer than ``batch_size`` transitions are stored, all of them are
        returned in insertion order; otherwise ``batch_size`` of them are drawn
        without replacement via ``random.sample``.

        Returns:
            ``(states, task_ids, next_states)``, or ``(None, None, None)`` when
            the resulting batch is empty.
        """
        chosen = list(self.buffer)
        if len(chosen) >= batch_size:
            chosen = random.sample(chosen, batch_size)

        if len(chosen) == 0:
            return None, None, None

        def as_column(field):
            # One 1-D long tensor per transition field.
            return torch.tensor([entry[field] for entry in chosen],
                                dtype=torch.long, device=self.device)

        return as_column('state'), as_column('task_id'), as_column('next_state')

    def size(self):
        """Return the number of transitions currently stored."""
        return len(self.buffer)

    def is_empty(self):
        """Return ``True`` when no transition is stored."""
        return not self.buffer

class MarkovPredictor(nn.Module):
    """MLP that maps a ``(state, task_id)`` pair to next-state logits.

    The pair is encoded either with two learned embeddings (each of width
    ``hidden_dim // 2``) or with concatenated one-hot vectors, and then passed
    through a three-layer ReLU MLP producing ``num_states`` logits.
    """

    def __init__(self, num_states=4, num_tasks=2, hidden_dim=64, use_embedding=True):
        """Build the encoder and the MLP head.

        Args:
            num_states: Number of distinct states (also the output width).
            num_tasks: Number of distinct task ids.
            hidden_dim: Width of the hidden layers.
            use_embedding: Use learned embeddings when ``True``, one-hot
                encodings otherwise.
        """
        super(MarkovPredictor, self).__init__()
        self.num_states = num_states
        self.num_tasks = num_tasks
        self.hidden_dim = hidden_dim
        self.use_embedding = use_embedding

        if use_embedding:
            embedding_dim = hidden_dim // 2
            self.state_embedding = nn.Embedding(num_states, embedding_dim)
            self.task_embedding = nn.Embedding(num_tasks, embedding_dim)
            # The MLP consumes the concatenation of both embeddings. Sizing it
            # as 2 * embedding_dim (equal to hidden_dim when hidden_dim is
            # even) keeps the shapes consistent for odd hidden_dim as well.
            input_dim = 2 * embedding_dim
        else:
            self.state_embedding = None
            self.task_embedding = None
            input_dim = num_states + num_tasks

        self.mlp = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_states)
        )

        self._initialize_weights()

    def _initialize_weights(self):
        """Xavier-initialize Linear/Embedding weights and zero Linear biases."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                nn.init.xavier_uniform_(module.weight)

    def forward(self, states, task_ids):
        """Return next-state logits for a batch of ``(state, task_id)`` pairs.

        Args:
            states: Integer tensor of state indices; a 0-d tensor is treated
                as a batch of one.
            task_ids: Integer tensor of task indices, same layout as
                ``states``.

        Returns:
            Tensor whose last dimension has ``num_states`` logits.
        """
        # Promote scalar inputs to a batch of one.
        if states.dim() == 0:
            states = states.unsqueeze(0)
        if task_ids.dim() == 0:
            task_ids = task_ids.unsqueeze(0)

        if self.use_embedding:
            state_emb = self.state_embedding(states)
            task_emb = self.task_embedding(task_ids)
            combined = torch.cat([state_emb, task_emb], dim=-1)
        else:
            batch_size = states.size(0)
            state_onehot = torch.zeros(batch_size, self.num_states, device=states.device)
            task_onehot = torch.zeros(batch_size, self.num_tasks, device=task_ids.device)

            # Write a 1 at each row's index to build the one-hot encodings.
            state_onehot.scatter_(1, states.unsqueeze(-1), 1)
            task_onehot.scatter_(1, task_ids.unsqueeze(-1), 1)

            combined = torch.cat([state_onehot, task_onehot], dim=1)

        logits = self.mlp(combined)
        return logits

    def predict_probs(self, state, task_id):
        """Return the softmax next-state distribution for one pair.

        Note that this switches the module to eval mode and leaves it there;
        callers that keep training must call ``train()`` again.

        Args:
            state: Single state index.
            task_id: Single task index.

        Returns:
            1-D tensor of ``num_states`` probabilities, computed without
            gradient tracking on the device holding the model parameters.
        """
        self.eval()
        with torch.no_grad():
            device = next(self.parameters()).device
            # Index tensors are forced to int64, which both the embedding
            # lookup and scatter_ accept regardless of the input's int type.
            state_tensor = torch.tensor([state], dtype=torch.long, device=device)
            task_tensor = torch.tensor([task_id], dtype=torch.long, device=device)

            logits = self.forward(state_tensor, task_tensor)
            probs = torch.softmax(logits, dim=1)
            return probs.squeeze(0)

def bhattacharyya_distance_corrected(pred_probs, current_state, true_next_state, 
                                   task_id, P_matrix, INTER_P_matrix):
    """Return the Bhattacharyya distance between predicted and true rows.

    The true distribution is row ``current_state`` of ``P_matrix`` when
    ``task_id == 0`` and of ``INTER_P_matrix`` otherwise.

    Args:
        pred_probs: 1-D tensor of predicted next-state probabilities.
        current_state: Row index into the selected transition matrix.
        true_next_state: Unused; kept for interface compatibility.
        task_id: ``0`` selects ``P_matrix``; any other value selects
            ``INTER_P_matrix``.
        P_matrix: Transition matrix for task 0 (tensor, array or nested list).
        INTER_P_matrix: Transition matrix for the other task.

    Returns:
        ``-log(BC + 1e-10)`` as a Python float, where ``BC`` is the
        Bhattacharyya coefficient ``sum(sqrt(p * q))``.
    """
    transition_matrix = P_matrix if task_id == 0 else INTER_P_matrix

    # as_tensor accepts tensors, arrays and lists alike; unlike torch.tensor it
    # does not warn about (or perform) a copy when the row is already a tensor.
    true_probs = torch.as_tensor(transition_matrix[current_state],
                                 dtype=pred_probs.dtype,
                                 device=pred_probs.device)
    bc = torch.sum(torch.sqrt(pred_probs * true_probs))
    # The small epsilon keeps the log finite when the distributions are disjoint.
    distance = -torch.log(bc + 1e-10)

    return distance.item()

def load_prompt_files(generated_series_dir, continuous_series_names, markov_chain_names):
    """Load every ``.pkl`` file in a directory, grouped by series family.

    The family of a file is the first two underscore-separated tokens of its
    stem (e.g. ``markov_chain_3.pkl`` -> ``markov_chain``). Files whose family
    is in neither name collection are reported and skipped.

    Args:
        generated_series_dir: Directory to scan (not recursive).
        continuous_series_names: Family names treated as continuous series.
        markov_chain_names: Family names treated as Markov chains.

    Returns:
        ``(continuous_series_task, markov_chain_task)``: two dicts mapping
        file name to the unpickled object.
    """
    continuous_series_task = {}
    markov_chain_task = {}

    generated_series_dir = Path(generated_series_dir)

    print("generated_series_dir:", generated_series_dir)

    for file in generated_series_dir.iterdir():
        if file.suffix != '.pkl':
            continue

        # Series family = first two underscore-separated tokens of the stem.
        series_name = '_'.join(file.stem.split('_')[:2])

        if series_name in continuous_series_names:
            destination, label = continuous_series_task, "Continuous series:"
        elif series_name in markov_chain_names:
            destination, label = markov_chain_task, "Markov chain:"
        else:
            print(f"Unrecognized series name: {series_name} for file: {file.name}")
            continue

        # NOTE: unpickling can execute arbitrary code; only point this at
        # directories whose contents are trusted.
        # The context manager guarantees the handle is closed after loading.
        with file.open('rb') as handle:
            destination[file.name] = pickle.load(handle)
        print(label, file.name)

    print("Continuous series keys:", list(continuous_series_task.keys()))
    print("Markov chain keys:", list(markov_chain_task.keys()))

    return continuous_series_task, markov_chain_task

def extract_transition_matrices(series_dict):
    """Pull the transition matrices and state lists out of a series record.

    Args:
        series_dict: Mapping holding the keys ``'P'``, ``'INTER_P'``,
            ``'states'`` and ``'INTER_states'``.

    Returns:
        ``(P, INTER_P, states, inter_states)`` in that order.

    Raises:
        KeyError: If any of the four keys is missing.
    """
    wanted_keys = ('P', 'INTER_P', 'states', 'INTER_states')
    P, INTER_P, states, inter_states = (series_dict[key] for key in wanted_keys)
    return P, INTER_P, states, inter_states

def parse_input_prompt(prompt_text):
    """Parse a prompt into per-step states, task ids and switch positions.

    Everything up to and including the first ``Predict next:`` marker (plus
    trailing whitespace) is discarded. The remainder is scanned left to right:

    * ``[SWITCH_TO_INTERFERENCE]`` sets the current task to 1,
    * ``[SWITCH_TO_NORMAL]`` sets the current task to 0,
    * every single decimal digit is recorded as one state under the current
      task (multi-digit numbers are therefore read digit by digit),
    * any other character is ignored.

    A switch is recorded in ``switch_positions`` (as the number of states seen
    so far) only when at least one state precedes it. A summary of the parse
    is printed.

    Args:
        prompt_text: Raw prompt string.

    Returns:
        ``(states_sequence, task_sequence, switch_positions)`` where the first
        two lists have equal length.
    """
    interference_tag = '[SWITCH_TO_INTERFERENCE]'
    normal_tag = '[SWITCH_TO_NORMAL]'

    start_pattern = r'Predict next:\s*'
    match = re.search(start_pattern, prompt_text)
    if match:
        prompt_text = prompt_text[match.end():]

    states_sequence = []
    task_sequence = []
    switch_positions = []
    current_task = 0

    i = 0
    text_length = len(prompt_text)
    while i < text_length:
        # startswith(tag, i) tests in place; slicing the tail at every
        # character would make the scan quadratic in the prompt length.
        if prompt_text.startswith(interference_tag, i):
            if states_sequence:
                switch_positions.append(len(states_sequence))
                print(f"Switch to interference mode, position: {len(states_sequence)}")
            current_task = 1
            i += len(interference_tag)

        elif prompt_text.startswith(normal_tag, i):
            if states_sequence:
                switch_positions.append(len(states_sequence))
                print(f"Switch to normal mode, position: {len(states_sequence)}")
            current_task = 0
            i += len(normal_tag)

        else:
            char = prompt_text[i]
            # isdecimal() only accepts characters int() can convert; isdigit()
            # also accepts e.g. superscript digits, on which int() raises.
            if char.isdecimal():
                states_sequence.append(int(char))
                task_sequence.append(current_task)
            i += 1

    print("Parsing completed:")
    print(f"  Total states: {len(states_sequence)}")
    if states_sequence:
        print(f"  State value range: {min(states_sequence)} - {max(states_sequence)}")
    print(f"  Number of task switches: {len(switch_positions)}")
    print(f"  Switch positions: {switch_positions}")

    normal_count = task_sequence.count(0)
    interference_count = task_sequence.count(1)
    print(f"  Normal mode states: {normal_count}")
    print(f"  Interference mode states: {interference_count}")

    if switch_positions:
        print("  Task verification:")
        prev_pos = 0
        # Only the first three switches are echoed to keep the log short.
        for switch_pos in switch_positions[:3]:
            task_after = task_sequence[switch_pos] if switch_pos < len(task_sequence) else None
            print(f"    Position {prev_pos}-{switch_pos-1}: Task {task_sequence[prev_pos] if prev_pos < len(task_sequence) else 'N/A'}")
            print(f"    Position {switch_pos}: Task changes to {task_after}")
            prev_pos = switch_pos

    return states_sequence, task_sequence, switch_positions

def load_input_prompt(pkl_file_path):
    """Read a pickle file and return the prompt text it contains.

    A pickled ``str`` is returned as is. For a pickled ``dict`` the value of
    the first present key among a list of well-known names is used; when that
    yields nothing (no such key, or its value is ``None``), the first string
    value longer than 100 characters is used instead.

    Args:
        pkl_file_path: Path of the pickle file to read.

    Returns:
        The extracted prompt text.

    Raises:
        ValueError: If the pickled object is neither ``dict`` nor ``str``, or
            if no prompt text could be located in the dict.
    """
    print(f"Loading prompt file: {pkl_file_path}")

    with open(pkl_file_path, 'rb') as handle:
        payload = pickle.load(handle)

    if isinstance(payload, str):
        prompt_text = payload
    elif isinstance(payload, dict):
        known_keys = ('prompt', 'text', 'sequence', 'content', 'data',
                      'full_series', 'full_series_with_switches')
        # Value under the first well-known key that is present, if any.
        prompt_text = next(
            (payload[name] for name in known_keys if name in payload), None)

        if prompt_text is None:
            print(f"Data dictionary keys: {list(payload.keys())}")
            # Fall back to the first reasonably long string value.
            prompt_text = next(
                (candidate for candidate in payload.values()
                 if isinstance(candidate, str) and len(candidate) > 100),
                None)
    else:
        raise ValueError(f"Cannot extract prompt text from pkl file, data type: {type(payload)}")

    if prompt_text is None:
        raise ValueError("Could not find valid prompt text in pkl file")

    print(f"Successfully loaded prompt, length: {len(prompt_text)} characters")
    return prompt_text

def generate_test_data(P_matrix, num_samples=10, seq_length=50):
    """Sample consecutive-state pairs from freshly generated Markov sequences.

    Args:
        P_matrix: Transition matrix handed to ``generate_markov_sequence``.
        num_samples: Number of independent sequences to generate.
        seq_length: Length requested for each sequence.

    Returns:
        List of ``(current_state, next_state)`` tuples, in generation order,
        covering every adjacent pair of every sequence.
    """
    test_pairs = []
    for _ in range(num_samples):
        chain = generate_markov_sequence(P_matrix, seq_length)
        # Pair each element with its successor.
        test_pairs.extend(zip(chain, chain[1:]))
    return test_pairs

def generate_markov_sequence(transition_matrix, length):
    """Sample a state sequence from a Markov chain.

    The initial state is drawn uniformly with ``np.random.randint``; each
    following state is drawn with ``np.random.choice`` from the row of the
    current state.

    Args:
        transition_matrix: Square row-stochastic matrix; a NumPy array or
            anything ``np.asarray`` can convert (e.g. a nested list).
        length: Requested sequence length. At least one state is always
            returned, even for ``length < 1``.

    Returns:
        List of state indices.
    """
    # Accept nested lists as well as arrays; arrays pass through uncopied.
    transition_matrix = np.asarray(transition_matrix)
    N_state = transition_matrix.shape[0]
    current_state = np.random.randint(0, N_state)
    sequence = [current_state]

    for _ in range(length - 1):
        probs = transition_matrix[current_state]
        next_state = np.random.choice(N_state, p=probs)
        sequence.append(next_state)
        current_state = next_state

    return sequence

def _trailing_mean(values, window_size):
    """Return, for each index, the mean of the last ``window_size`` values up to it."""
    return [
        np.mean(values[max(0, end - window_size):end])
        for end in range(1, len(values) + 1)
    ]


def _analyze_switches(switch_positions, task_sequence, distances, training_losses,
                      window=50):
    """Compare mean distance and loss in the windows before and after each switch."""
    analysis = []
    for switch_pos in switch_positions:
        if switch_pos >= len(distances):
            continue

        before_start = max(0, switch_pos - window)
        after_end = min(len(distances), switch_pos + window)
        # Skip switches that have no sample on one of the two sides.
        if before_start >= switch_pos or after_end <= switch_pos:
            continue

        dist_before = np.mean(distances[before_start:switch_pos])
        dist_after = np.mean(distances[switch_pos:after_end])
        loss_before = np.mean(training_losses[before_start:switch_pos])
        loss_after = np.mean(training_losses[switch_pos:after_end])

        analysis.append({
            'position': switch_pos,
            'task_before': task_sequence[max(0, switch_pos-1)],
            'task_after': task_sequence[min(len(task_sequence)-1, switch_pos)],
            'distance_before': dist_before,
            'distance_after': dist_after,
            'distance_change': dist_after - dist_before,
            'loss_before': loss_before,
            'loss_after': loss_after,
            'loss_change': loss_after - loss_before
        })
    return analysis


def run_forgetting_experiment_with_online_training(states_sequence, task_sequence, switch_positions, 
                                                  P_matrix, INTER_P_matrix, learning_rate=0.001, 
                                                  num_states=4, num_tasks=2, device=None,
                                                  replay_buffer_size=2000, replay_ratio=0.5, batch_size=32):
    """Train a ``MarkovPredictor`` online with experience replay and track forgetting.

    For every consecutive pair in ``states_sequence`` the function:

    1. stores the transition in the replay buffer,
    2. takes one SGD step on the current transition concatenated with up to
       ``int(batch_size * replay_ratio)`` transitions sampled from the buffer
       (the buffer already contains the current transition at that point),
    3. evaluates the model on task 0: a fresh 50-step sequence is sampled from
       ``P_matrix`` and the Bhattacharyya distance between the predicted and
       true rows is averaged over its 49 visited states.

    ``torch.manual_seed(42)`` is called before the model is built, which
    reseeds the global torch RNG. Progress and summary statistics are printed.

    Args:
        states_sequence: Observed state indices.
        task_sequence: Task id of every entry of ``states_sequence``.
        switch_positions: Indices at which the task changes.
        P_matrix: Transition matrix of task 0; also used to generate the
            evaluation sequences.
        INTER_P_matrix: Transition matrix of the interference task.
        learning_rate: SGD learning rate.
        num_states: Number of states of the predictor.
        num_tasks: Number of task ids of the predictor.
        device: Torch device; CUDA when available, else CPU, if ``None``.
        replay_buffer_size: Capacity of the replay buffer.
        replay_ratio: Fraction of ``batch_size`` drawn from the buffer.
        batch_size: Nominal batch size used to size the replay sample.

    Returns:
        ``(result, detailed_results)``. ``result`` holds ``'positions'`` and
        ``'distances'``; ``detailed_results`` additionally holds the smoothed
        distances, training losses, switch analysis, hyper-parameters, inputs,
        the trained model and the replay buffer.

    Raises:
        ValueError: If ``states_sequence`` has fewer than two states, since no
            transition can be formed.
    """
    # Without at least one transition every statistic below is undefined;
    # fail up front instead of after building the model.
    if len(states_sequence) < 2:
        raise ValueError(
            "states_sequence must contain at least two states to form a transition")

    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    print(f"\nStarting ER forgetting experiment with online training...")
    print(f"Device: {device}")
    print(f"Sequence length: {len(states_sequence)}")
    print(f"Learning rate: {learning_rate}")
    print(f"Replay buffer size: {replay_buffer_size}")
    print(f"Replay ratio: {replay_ratio}")
    print(f"Batch size: {batch_size}")
    print(f"Evaluation metric: Bhattacharyya distance")
    
    # Fixed seed so the model initialization is reproducible across runs.
    torch.manual_seed(42) 
    model = MarkovPredictor(
        num_states=num_states,
        num_tasks=num_tasks,
        hidden_dim=64,
        use_embedding=True
    ).to(device)
    
    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

    optimizer = optim.SGD(model.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss()
    replay_buffer = ExperienceReplayBuffer(buffer_size=replay_buffer_size, device=device)
    P_tensor = torch.tensor(P_matrix, dtype=torch.float32, device=device)
    INTER_P_tensor = torch.tensor(INTER_P_matrix, dtype=torch.float32, device=device)
    distances = []
    true_next_states = []
    positions = []
    training_losses = []
    
    print(f"\nStarting step-by-step prediction and training...")
    
    for i in range(len(states_sequence) - 1):
        current_state = states_sequence[i]
        current_task = task_sequence[i]
        true_next_state = states_sequence[i + 1]

        replay_buffer.store(current_state, current_task, true_next_state)

        # predict_probs() leaves the model in eval mode, so re-enable training
        # at the start of every step.
        model.train()
        optimizer.zero_grad()
        replay_size = min(int(batch_size * replay_ratio), replay_buffer.size())

        current_states = torch.tensor([current_state], device=device)
        current_tasks = torch.tensor([current_task], device=device)
        current_targets = torch.tensor([true_next_state], device=device)
        # replay_size > 0 already implies a non-empty buffer.
        if replay_size > 0:
            replay_states, replay_task_ids, replay_targets = replay_buffer.sample(replay_size)
            batch_states = torch.cat([current_states, replay_states], dim=0)
            batch_tasks = torch.cat([current_tasks, replay_task_ids], dim=0)
            batch_targets = torch.cat([current_targets, replay_targets], dim=0)
        else:
            batch_states = current_states
            batch_tasks = current_tasks
            batch_targets = current_targets

        logits = model(batch_states, batch_tasks)
        loss = criterion(logits, batch_targets)
        
        loss.backward()
        optimizer.step()

        training_losses.append(loss.item())
        
        test_data = generate_test_data(P_matrix, num_samples=1, seq_length=50)
        # The model does not change during evaluation and the distance depends
        # only on the probed state, so each distinct state is evaluated once
        # per step and reused; the accumulated sum is unchanged.
        distance_by_state = {}
        total_distance = 0.0
        for current_state_test, true_next_state_test in test_data:
            if current_state_test not in distance_by_state:
                probs = model.predict_probs(current_state_test, task_id=0)
                distance_by_state[current_state_test] = bhattacharyya_distance_corrected(
                    probs, current_state_test, true_next_state_test, 
                    0, P_tensor, INTER_P_tensor)
            total_distance += distance_by_state[current_state_test]
        distance = total_distance / len(test_data)
        distances.append(distance)
        true_next_states.append(true_next_state)
        positions.append(i)
        if (i + 1) % 100 == 0 or i == 0:
            print(f"Position {i+1:5d}: State{current_state}(Task{current_task}) -> "
                  f"True{true_next_state}, "
                  f"Distance{distance:.4f}, Loss{loss.item():.4f}, "
                  f"Buffer Size{replay_buffer.size()}")
    
    print(f"Training and prediction completed!")

    smoothed_distances = _trailing_mean(distances, window_size=100)
    switch_analysis = _analyze_switches(
        switch_positions, task_sequence, distances, training_losses, window=50)

    overall_distance = np.mean(distances)
    overall_loss = np.mean(training_losses)
    result = {
        'positions': positions,
        'distances': distances
    }
    detailed_results = {
        'positions': positions,
        'distances': distances,
        'smoothed_distances': smoothed_distances,
        'true_next_states': true_next_states,
        'switch_positions': switch_positions,
        'switch_analysis': switch_analysis,
        'overall_distance': overall_distance,
        'training_losses': training_losses,
        'overall_loss': overall_loss,
        'learning_rate': learning_rate,
        'replay_buffer_size': replay_buffer_size,
        'replay_ratio': replay_ratio,
        'batch_size': batch_size,
        'states_sequence': states_sequence,
        'task_sequence': task_sequence,
        'model': model,
        'num_states': num_states,
        'num_tasks': num_tasks,
        'P_matrix': P_matrix,
        'INTER_P_matrix': INTER_P_matrix,
        'replay_buffer': replay_buffer
    }
    
    print(f"\nExperiment statistics:")
    print(f"  Total predictions: {len(distances):,}")
    print(f"  Average Bhattacharyya distance: {overall_distance:.4f}")
    print(f"  Average training loss: {overall_loss:.4f}")
    print(f"  Distance std: {np.std(distances):.4f}")
    print(f"  Min distance: {min(distances):.4f}")
    print(f"  Max distance: {max(distances):.4f}")
    print(f"  Final buffer size: {replay_buffer.size()}")
    
    return result, detailed_results

def _save_experiment_plot(positions, distances, training_losses, switch_positions, fig_path):
    """Write the two-panel distance / training-loss figure to ``fig_path``."""
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 10))
    try:
        ax1.plot(positions, distances, 'b-', linewidth=1, alpha=0.7, label='Bhattacharyya Distance')
        for switch_pos in switch_positions:
            ax1.axvline(x=switch_pos, color='red', linestyle='--', alpha=0.7)
        ax1.set_xlabel("Positions")
        ax1.set_ylabel("Distances")
        ax1.set_yscale('log')
        ax1.set_title("Distances vs. Positions (Log Scale) - ER Method")
        ax1.grid(True, which="both", ls="--", linewidth=0.5)
        ax1.legend()

        ax2.plot(positions, training_losses, 'g-', linewidth=1, alpha=0.7, label='Training Loss')
        for switch_pos in switch_positions:
            ax2.axvline(x=switch_pos, color='red', linestyle='--', alpha=0.7)
        ax2.set_xlabel("Positions")
        ax2.set_ylabel("Training Loss")
        ax2.set_title("Training Loss Over Time - ER Method")
        ax2.grid(True, which="both", ls="--", linewidth=0.5)
        ax2.legend()

        fig.tight_layout()
        fig.savefig(fig_path, dpi=300)
    finally:
        # Always release the figure, even when plotting or saving fails.
        plt.close(fig)


def main_with_directory(generated_series_dir, continuous_series_names, markov_chain_names, save_path=None):
    """Run the replay experiment on every Markov-chain file of a directory.

    Files are loaded with ``load_prompt_files`` and processed in sorted name
    order. For each Markov-chain record the series text is parsed, the online
    experiment is run, and, when ``save_path`` is given, the compact results
    are pickled to ``<stem>_results.pkl`` and a figure is written to
    ``<stem>.png`` inside ``save_path``. A failure on one file is reported
    with its traceback and does not stop the remaining files.

    Args:
        generated_series_dir: Directory containing the ``.pkl`` series files.
        continuous_series_names: Family names of continuous series (loaded but
            not processed here).
        markov_chain_names: Family names of Markov-chain series.
        save_path: Output directory, created if missing. When ``None`` nothing
            is written to disk and only the results are returned.

    Returns:
        Dict mapping file name to the compact result dict of its experiment.
    """
    import traceback

    continuous_series_task, markov_chain_task = load_prompt_files(
        generated_series_dir, continuous_series_names, markov_chain_names
    )

    if save_path is not None:
        os.makedirs(save_path, exist_ok=True)
    
    results_collection = {}

    for series_name, series_dict in sorted(markov_chain_task.items()):
        print(f"\nProcessing {series_name}")
        
        # Per-file boundary: report the failure and move on to the next file.
        try:
            full_series = series_dict.get('full_series_with_switches', 
                                         series_dict.get('full_series', ''))
            
            if not full_series:
                print(f"Warning: No full_series found in {series_name}")
                continue
                
            P, INTER_P, states, inter_states = extract_transition_matrices(series_dict)
            
            print(f"Full series length: {len(full_series)}")
            print(f"P matrix shape: {np.array(P).shape}")
            print(f"INTER_P matrix shape: {np.array(INTER_P).shape}")
            states_sequence, task_sequence, switch_positions = parse_input_prompt(full_series)
            
            if not states_sequence:
                print(f"Warning: No states sequence parsed from {series_name}")
                continue

            results, detailed_results = run_forgetting_experiment_with_online_training(
                states_sequence, task_sequence, switch_positions,
                P, INTER_P, learning_rate=0.001, num_states=len(states), num_tasks=2,
                replay_buffer_size=8000, replay_ratio=0.5, batch_size=32
            )

            results_collection[series_name] = results

            # Persisting is optional: with the default save_path=None the
            # results are only returned.
            if save_path is not None:
                data_name = Path(series_name).stem
                pkl_path = os.path.join(save_path, f"{data_name}_results.pkl")
                with open(pkl_path, "wb") as f:
                    pickle.dump(results, f)

                _save_experiment_plot(
                    results['positions'],
                    results['distances'],
                    detailed_results['training_losses'],
                    switch_positions,
                    os.path.join(save_path, f"{data_name}.png"),
                )

            print(f"Completed processing {series_name}")
            
        except Exception as e:
            print(f"Error processing {series_name}: {e}")
            traceback.print_exc()
            continue
    
    print(f"\n{'='*80}")
    print(f"Directory processing completed!")
    print(f"Successfully processed {len(results_collection)} files")
    print(f"{'='*80}")
    
    return results_collection

if __name__ == "__main__":
    # Seed every RNG the experiment draws from so runs are reproducible.
    torch.manual_seed(42)
    np.random.seed(42)
    random.seed(42)  
    
    # Inputs and outputs are resolved relative to the parent of this script's
    # directory, which is also appended to the import path.
    parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    sys.path.append(parent_dir)
    save_path = Path(parent_dir) / 'icl_run_results_er_4_sp'
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(save_path, exist_ok=True)
    
    generated_series_dir = Path(parent_dir) / 'data_gen_600' / 'generated_series_er_4_sp'
    continuous_series_names = ["continuous_series"]  
    markov_chain_names = ["markov_chain"]  
    results_collection = main_with_directory(generated_series_dir, continuous_series_names, markov_chain_names, save_path)
    
    print("\nExperiment completed!")