import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import pickle
import os
import sys
from typing import List, Tuple, Dict
import re
from pathlib import Path
import copy

class FisherInformationCalculator:
    """Estimate the diagonal (empirical) Fisher information of a model's parameters.

    For each item drawn from a data loader the loss gradient is back-propagated
    and its element-wise square is accumulated per trainable parameter; the
    accumulated sums are finally divided by the number of items actually used.
    """

    def __init__(self, model, device=None):
        """Bind the calculator to ``model``; inputs are moved to ``device`` (CPU if None)."""
        self.model = model
        self.device = device if device is not None else torch.device('cpu')

    def compute_fisher_information(self, data_loader, criterion, num_samples=None):
        """Return a per-parameter diagonal Fisher estimate.

        Args:
            data_loader: Sized iterable yielding ``(states, task_ids, targets)``
                tensor triples. 0-dim tensors are promoted to shape ``(1,)``;
                triples containing an empty tensor are skipped.
            criterion: Loss callable invoked as ``criterion(logits, targets)``.
            num_samples: Maximum number of items to consume; ``None`` means
                ``len(data_loader)``.

        Returns:
            Dict mapping each trainable parameter name to a tensor of the same
            shape holding the mean squared gradient (all zeros when no usable
            item was consumed).

        The model runs in eval mode during the computation; its previous
        train/eval mode is restored before returning.
        """
        was_training = self.model.training
        self.model.eval()
        try:
            fisher_dict = {
                name: torch.zeros_like(param.detach())
                for name, param in self.model.named_parameters()
                if param.requires_grad
            }

            if not data_loader or len(data_loader) == 0:
                print("Warning: Empty data_loader for Fisher information computation")
                return fisher_dict

            sample_count = 0
            max_samples = num_samples if num_samples is not None else len(data_loader)

            print(f"Computing Fisher information using {min(max_samples, len(data_loader))} samples...")

            for states, task_ids, targets in data_loader:
                if sample_count >= max_samples:
                    break

                if states.numel() == 0 or task_ids.numel() == 0 or targets.numel() == 0:
                    continue

                # Promote scalar (0-dim) tensors to a batch of one.
                if states.dim() == 0:
                    states = states.unsqueeze(0)
                if task_ids.dim() == 0:
                    task_ids = task_ids.unsqueeze(0)
                if targets.dim() == 0:
                    targets = targets.unsqueeze(0)

                states = states.to(self.device)
                task_ids = task_ids.to(self.device)
                targets = targets.to(self.device)

                # Fresh gradients per item: each squared gradient is one Fisher term.
                self.model.zero_grad()
                logits = self.model(states, task_ids)
                loss = criterion(logits, targets)
                loss.backward()

                for name, param in self.model.named_parameters():
                    if param.requires_grad and param.grad is not None:
                        fisher_dict[name] += param.grad.detach() ** 2

                sample_count += 1

            if sample_count > 0:
                for name in fisher_dict:
                    fisher_dict[name] = fisher_dict[name] / sample_count
            else:
                print("Warning: No valid samples for Fisher information computation")

            # Do not leave Fisher-computation gradients behind for the caller's optimizer.
            self.model.zero_grad()

            print(f"Fisher information computed using {sample_count} samples")
            return fisher_dict
        finally:
            # Restore the caller's mode; previously the model was left in eval mode.
            self.model.train(was_training)

class EWCRegularizer:
    """Elastic Weight Consolidation penalty keyed by task type.

    For every saved task type it stores a snapshot of the trainable parameters
    (the anchor) and a Fisher dictionary. The penalty is
    ``lambda_ewc / 2 * sum_types sum_params (F * (theta - theta_anchor) ** 2)``.
    """

    def __init__(self, model, lambda_ewc=1000.0, device=None):
        self.model = model
        self.lambda_ewc = lambda_ewc
        self.device = torch.device('cpu') if device is None else device

        # task_type -> {param_name: anchored parameter copy}
        self.task_params = {}
        # task_type -> {param_name: Fisher tensor}
        self.task_fisher = {}
        # every task type for which an anchor has been stored
        self.saved_task_types = set()

    def save_task_params_by_type(self, fisher_dict, task_type):
        """Snapshot current parameters and ``fisher_dict`` for ``task_type``.

        A task type that was saved before is overwritten, not accumulated.
        """
        already_known = task_type in self.saved_task_types
        if already_known:
            print(f"Task type {task_type} already exists, updating Fisher information...")
        else:
            print(f"New task type {task_type}, saving Fisher information...")
            self.saved_task_types.add(task_type)
        action = "Updated" if already_known else "Saved"

        self.task_params[task_type] = {
            name: param.data.clone()
            for name, param in self.model.named_parameters()
            if param.requires_grad
        }
        self.task_fisher[task_type] = {
            name: fisher.clone() for name, fisher in fisher_dict.items()
        }

        print(f"{action} task type {task_type}. Total unique task types: {len(self.saved_task_types)}")

    def compute_ewc_loss(self):
        """Return the scaled EWC penalty (plain ``0.0`` when nothing is saved)."""
        penalty = 0.0

        for task_type in self.saved_task_types:
            anchors = self.task_params.get(task_type)
            fishers = self.task_fisher.get(task_type)
            if anchors is None or fishers is None:
                continue
            for name, param in self.model.named_parameters():
                if not param.requires_grad:
                    continue
                if name not in anchors or name not in fishers:
                    continue
                penalty += (fishers[name] * (param - anchors[name]) ** 2).sum()

        return self.lambda_ewc / 2.0 * penalty

    def get_num_saved_tasks(self):
        """Number of distinct task types with a stored anchor."""
        return len(self.saved_task_types)

    def get_saved_task_types(self):
        """Saved task types in ascending order."""
        return sorted(self.saved_task_types)

class MarkovPredictor(nn.Module):
    """MLP that maps a (state, task) pair to logits over the next state.

    With ``use_embedding=True`` the state id and task id are each embedded into
    ``hidden_dim // 2`` dimensions and concatenated. Otherwise they are one-hot
    encoded and concatenated. The result feeds a 3-layer ReLU MLP that outputs
    ``num_states`` logits.
    """

    def __init__(self, num_states=4, num_tasks=2, hidden_dim=64, use_embedding=True):
        super(MarkovPredictor, self).__init__()
        self.num_states = num_states
        self.num_tasks = num_tasks
        self.hidden_dim = hidden_dim
        self.use_embedding = use_embedding

        if use_embedding:
            embed_dim = hidden_dim // 2
            self.state_embedding = nn.Embedding(num_states, embed_dim)
            self.task_embedding = nn.Embedding(num_tasks, embed_dim)
            # The concatenation has 2 * embed_dim features. For an odd hidden_dim
            # that is hidden_dim - 1, so using hidden_dim here broke forward().
            input_dim = 2 * embed_dim
        else:
            self.state_embedding = None
            self.task_embedding = None
            input_dim = num_states + num_tasks

        self.mlp = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_states)
        )

        self._initialize_weights()

    def _initialize_weights(self):
        """Xavier-uniform init for Linear/Embedding weights; zero Linear biases."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                nn.init.xavier_uniform_(module.weight)

    def forward(self, states, task_ids):
        """Return next-state logits of shape ``(batch, num_states)``.

        ``states`` and ``task_ids`` are integer index tensors; 0-dim inputs are
        treated as a batch of one.
        """
        if len(states.shape) == 0:
            states = states.unsqueeze(0)
        if len(task_ids.shape) == 0:
            task_ids = task_ids.unsqueeze(0)

        if self.use_embedding:
            state_emb = self.state_embedding(states)
            task_emb = self.task_embedding(task_ids)
            combined = torch.cat([state_emb, task_emb], dim=-1)
        else:
            batch_size = states.size(0)
            state_onehot = torch.zeros(batch_size, self.num_states, device=states.device)
            task_onehot = torch.zeros(batch_size, self.num_tasks, device=task_ids.device)

            # scatter_ requires int64 indices; .long() is a no-op for long tensors.
            state_onehot.scatter_(1, states.long().unsqueeze(-1), 1)
            task_onehot.scatter_(1, task_ids.long().unsqueeze(-1), 1)

            combined = torch.cat([state_onehot, task_onehot], dim=1)

        logits = self.mlp(combined)
        return logits

    def predict_probs(self, state, task_id):
        """Return the softmax next-state distribution (1-D) for a single pair.

        Note: this switches the module to eval mode and leaves it there.
        """
        self.eval()
        with torch.no_grad():
            device = next(self.parameters()).device
            state_tensor = torch.tensor([state], device=device)
            task_tensor = torch.tensor([task_id], device=device)

            logits = self.forward(state_tensor, task_tensor)
            probs = torch.softmax(logits, dim=1)
            return probs.squeeze(0)

def bhattacharyya_distance_corrected(pred_probs, current_state, true_next_state,
                                   task_id, P_matrix, INTER_P_matrix):
    """Bhattacharyya distance between predicted and true next-state distributions.

    The reference distribution is row ``current_state`` of ``P_matrix`` when
    ``task_id == 0`` and of ``INTER_P_matrix`` otherwise. The matrices may be
    torch tensors, numpy arrays or nested lists. ``true_next_state`` is accepted
    for interface compatibility but is not used.

    Returns:
        ``-log(sum(sqrt(pred * true)) + 1e-10)`` as a Python float.
    """
    transition_matrix = P_matrix if task_id == 0 else INTER_P_matrix

    # torch.as_tensor reuses an existing tensor (no copy and no copy-construct
    # UserWarning, unlike torch.tensor(tensor)) and still accepts lists/ndarrays.
    true_probs = torch.as_tensor(transition_matrix[current_state],
                                 dtype=pred_probs.dtype,
                                 device=pred_probs.device)

    bc = torch.sum(torch.sqrt(pred_probs * true_probs))

    # The small epsilon keeps the log finite when the supports do not overlap.
    distance = -torch.log(bc + 1e-10)

    return distance.item()

def load_prompt_files(generated_series_dir, continuous_series_names, markov_chain_names):
    """Load generated series pickles from a directory and split them by family.

    A file ``<a>_<b>_....pkl`` is classified by its first two underscore-separated
    stem components (``"<a>_<b>"``). It goes to the continuous-series dict if that
    prefix is in ``continuous_series_names``, or to the Markov-chain dict if it is
    in ``markov_chain_names``. Other ``.pkl`` files are reported and skipped.

    Returns:
        ``(continuous_series_task, markov_chain_task)``, two dicts keyed by file
        name whose values are the unpickled objects.

    Security: ``pickle.load`` executes arbitrary code from the file. Only point
    this at trusted, locally generated data.
    """
    continuous_series_task = {}
    markov_chain_task = {}

    generated_series_dir = Path(generated_series_dir)

    print("generated_series_dir:", generated_series_dir)

    # Sorted for a reproducible load/print order (iterdir order is filesystem-defined).
    for file in sorted(generated_series_dir.iterdir()):
        if file.suffix != '.pkl' or not file.is_file():
            continue
        series_name = '_'.join(file.stem.split('_')[:2])
        if series_name in continuous_series_names:
            target, label = continuous_series_task, "Continuous series:"
        elif series_name in markov_chain_names:
            target, label = markov_chain_task, "Markov chain:"
        else:
            print(f"Unrecognized series name: {series_name} for file: {file.name}")
            continue
        # Context manager closes the handle (the original leaked one per file).
        with file.open('rb') as handle:
            target[file.name] = pickle.load(handle)
        print(label, file.name)

    print("Continuous series keys:", list(continuous_series_task.keys()))
    print("Markov chain keys:", list(markov_chain_task.keys()))

    return continuous_series_task, markov_chain_task

def extract_transition_matrices(series_dict):
    """Pull the normal/interference transition matrices and state lists.

    Returns the tuple ``(series_dict['P'], series_dict['INTER_P'],
    series_dict['states'], series_dict['INTER_states'])``. A missing key raises
    ``KeyError``; keys are looked up in that order.
    """
    required_keys = ('P', 'INTER_P', 'states', 'INTER_states')
    return tuple(series_dict[key] for key in required_keys)

def parse_input_prompt(prompt_text):
    """Parse a generated prompt into states, per-state task labels and switch points.

    Text up to and including the first ``"Predict next:"`` (plus trailing
    whitespace) is discarded when present. Each ASCII digit is one state (so
    states are single-digit). ``[SWITCH_TO_INTERFERENCE]`` sets the current task
    to 1 and ``[SWITCH_TO_NORMAL]`` sets it to 0; the task starts at 0. All other
    characters are ignored.

    Returns:
        ``(states_sequence, task_sequence, switch_positions)``. Here
        ``switch_positions`` holds the index of the first state after each marker.
        Markers seen before any state are not recorded as switches.
    """
    interference_marker = '[SWITCH_TO_INTERFERENCE]'
    normal_marker = '[SWITCH_TO_NORMAL]'

    match = re.search(r'Predict next:\s*', prompt_text)
    if match:
        prompt_text = prompt_text[match.end():]

    states_sequence = []
    task_sequence = []
    switch_positions = []
    current_task = 0
    i = 0
    text_length = len(prompt_text)
    while i < text_length:
        # startswith(prefix, start) checks in place. Slicing prompt_text[i:] copied
        # the rest of the text on every character, making parsing O(n^2).
        if prompt_text.startswith(interference_marker, i):
            if states_sequence:
                switch_positions.append(len(states_sequence))
                print(f"Switch to interference mode, position: {len(states_sequence)}")
            current_task = 1
            i += len(interference_marker)
        elif prompt_text.startswith(normal_marker, i):
            if states_sequence:
                switch_positions.append(len(states_sequence))
                print(f"Switch to normal mode, position: {len(states_sequence)}")
            current_task = 0
            i += len(normal_marker)
        else:
            ch = prompt_text[i]
            # ASCII digits only: str.isdigit() also accepts characters such as
            # '\u00b2' (superscript two) for which int() raises ValueError.
            if '0' <= ch <= '9':
                states_sequence.append(int(ch))
                task_sequence.append(current_task)
            i += 1

    print("Parsing completed:")
    print(f"  Total states: {len(states_sequence)}")
    if states_sequence:
        print(f"  State value range: {min(states_sequence)} - {max(states_sequence)}")
    print(f"  Number of task switches: {len(switch_positions)}")
    print(f"  Switch positions: {switch_positions}")

    normal_count = sum(1 for t in task_sequence if t == 0)
    interference_count = sum(1 for t in task_sequence if t == 1)
    print(f"  Normal mode states: {normal_count}")
    print(f"  Interference mode states: {interference_count}")

    if switch_positions:
        # Show the task label on either side of the first three switches.
        print("  Task verification:")
        prev_pos = 0
        for switch_pos in switch_positions[:3]:
            task_after = task_sequence[switch_pos] if switch_pos < len(task_sequence) else None
            task_at_prev = task_sequence[prev_pos] if prev_pos < len(task_sequence) else 'N/A'
            print(f"    Position {prev_pos}-{switch_pos-1}: Task {task_at_prev}")
            print(f"    Position {switch_pos}: Task changes to {task_after}")
            prev_pos = switch_pos

    return states_sequence, task_sequence, switch_positions

def load_input_prompt(pkl_file_path):
    """Load prompt text from a pickle file.

    The pickle may contain the text directly (a ``str``) or a ``dict``. For a
    dict, the first present key among ``prompt``, ``text``, ``sequence``,
    ``content``, ``data``, ``full_series`` and ``full_series_with_switches`` is
    used. If that yields nothing (or ``None``), the dict's keys are printed and
    the first string value longer than 100 characters is used instead.

    Raises:
        ValueError: if the pickle is neither ``str`` nor ``dict``, or no prompt
            text can be found.

    Security: ``pickle.load`` executes arbitrary code; use only trusted files.
    """
    with open(pkl_file_path, 'rb') as handle:
        data = pickle.load(handle)

    if isinstance(data, str):
        prompt_text = data
    elif isinstance(data, dict):
        preferred_keys = ('prompt', 'text', 'sequence', 'content', 'data',
                          'full_series', 'full_series_with_switches')
        found_key = next((key for key in preferred_keys if key in data), None)
        prompt_text = None if found_key is None else data[found_key]

        if prompt_text is None:
            print(f"Data dictionary keys: {list(data.keys())}")
            prompt_text = next(
                (value for value in data.values() if isinstance(value, str) and len(value) > 100),
                None,
            )
    else:
        raise ValueError(f"Cannot extract prompt text from pkl file, data type: {type(data)}")

    if prompt_text is None:
        raise ValueError("Could not find valid prompt text in pkl file")

    print(f"Successfully loaded prompt, length: {len(prompt_text)} characters")
    return prompt_text

def generate_test_data(P_matrix, num_samples=10, seq_length=50):
    """Sample ``(state, next_state)`` transition pairs from a Markov chain.

    Draws ``num_samples`` independent sequences of ``seq_length`` states with
    ``generate_markov_sequence`` and returns every consecutive pair, in order,
    as a flat list of tuples.
    """
    pairs = []
    for _ in range(num_samples):
        chain = generate_markov_sequence(P_matrix, seq_length)
        pairs.extend(zip(chain[:-1], chain[1:]))
    return pairs

def generate_markov_sequence(transition_matrix, length):
    """Sample a state sequence of ``length`` steps from a Markov chain.

    Args:
        transition_matrix: Square row-stochastic matrix, as an ndarray or any
            array-like such as nested lists. Row ``s`` is the next-state
            distribution from state ``s``.
        length: Number of states to return. Non-positive values give ``[]``.

    Returns:
        A list of state indices. The first state is drawn uniformly with
        ``np.random.randint`` and each following state with ``np.random.choice``.
    """
    if length <= 0:
        # Previously a length of 0 still produced one state.
        return []

    # np.asarray accepts list matrices (which have no .shape); ndarrays pass through unchanged.
    matrix = np.asarray(transition_matrix)
    n_states = matrix.shape[0]
    current_state = np.random.randint(0, n_states)
    sequence = [current_state]

    for _ in range(length - 1):
        next_state = np.random.choice(n_states, p=matrix[current_state])
        sequence.append(next_state)
        current_state = next_state

    return sequence

def prepare_fisher_data_for_completed_task(states_sequence, task_sequence, switch_positions,
                                          current_position, target_task_type):
    """Build ``(state, task, next_state)`` long-tensor triples for Fisher estimation.

    The start index is the last entry of ``switch_positions`` (in list order)
    that is below ``current_position`` and whose task label is
    ``target_task_type``. If there is none, it is the first index before
    ``current_position`` labelled ``target_task_type``, or 0 if there is no such
    index either.

    From the start up to ``current_position`` (exclusive), every position whose
    task is ``target_task_type`` and that has a successor state contributes one
    triple. Positions whose state, task or next state is not an ``int`` /
    ``np.integer`` are skipped.
    """
    prior_switches = [pos for pos in switch_positions if pos < current_position]
    matching_starts = [
        pos for pos in prior_switches
        if pos < len(task_sequence) and task_sequence[pos] == target_task_type
    ]
    if matching_starts:
        task_start_pos = matching_starts[-1]
    else:
        task_start_pos = next(
            (idx for idx in range(current_position) if task_sequence[idx] == target_task_type),
            0,
        )

    print(f"Computing Fisher for task type {target_task_type}, data range: {task_start_pos} to {current_position}")

    integer_types = (int, np.integer)
    fisher_data = []
    for idx in range(task_start_pos, current_position):
        if idx < 0 or idx + 1 >= len(states_sequence):
            continue
        state = states_sequence[idx]
        task = task_sequence[idx]
        next_state = states_sequence[idx + 1]
        if task != target_task_type:
            continue
        if not all(isinstance(value, integer_types) for value in (state, task, next_state)):
            continue
        fisher_data.append(tuple(
            torch.tensor(value, dtype=torch.long) for value in (state, task, next_state)
        ))

    print(f"Prepared {len(fisher_data)} samples for Fisher computation from task type {target_task_type}")
    return fisher_data

def _moving_average(values, window_size):
    """Trailing moving average: entry i is the mean of values[max(0, i-window_size+1):i+1]."""
    smoothed = []
    for i in range(len(values)):
        start_idx = max(0, i - window_size + 1)
        smoothed.append(np.mean(values[start_idx:i + 1]))
    return smoothed


def _analyze_switches(switch_positions, task_sequence, distances, training_losses, ewc_losses,
                      before_window=50, after_window=50):
    """Compare mean distance/loss/EWC loss in windows just before and after each switch."""
    analysis = []
    for switch_pos in switch_positions:
        if switch_pos >= len(distances):
            continue

        before_start = max(0, switch_pos - before_window)
        before_end = switch_pos
        after_start = switch_pos
        after_end = min(len(distances), switch_pos + after_window)
        if before_end <= before_start or after_end <= after_start:
            continue

        dist_before = np.mean(distances[before_start:before_end])
        dist_after = np.mean(distances[after_start:after_end])
        loss_before = np.mean(training_losses[before_start:before_end])
        loss_after = np.mean(training_losses[after_start:after_end])
        ewc_before = np.mean(ewc_losses[before_start:before_end])
        ewc_after = np.mean(ewc_losses[after_start:after_end])

        analysis.append({
            'position': switch_pos,
            'task_before': task_sequence[max(0, switch_pos-1)],
            'task_after': task_sequence[min(len(task_sequence)-1, switch_pos)],
            'distance_before': dist_before,
            'distance_after': dist_after,
            'distance_change': dist_after - dist_before,
            'loss_before': loss_before,
            'loss_after': loss_after,
            'loss_change': loss_after - loss_before,
            'ewc_before': ewc_before,
            'ewc_after': ewc_after,
            'ewc_change': ewc_after - ewc_before
        })
    return analysis


def run_forgetting_experiment_with_online_training(states_sequence, task_sequence, switch_positions,
                                                  P_matrix, INTER_P_matrix, learning_rate=0.001,
                                                  num_states=4, num_tasks=2, device=None,
                                                  lambda_ewc=1000.0):
    """Train a MarkovPredictor online with EWC and track forgetting of the normal task.

    At each position ``i`` the model takes one SGD step on the transition
    ``states_sequence[i] -> states_sequence[i+1]`` labelled ``task_sequence[i]``.
    The loss is cross-entropy plus the EWC penalty. At every switch position
    ``i > 50``, a Fisher estimate is computed for the task type that just ended.
    It is stored in the regularizer only when more than 10 samples are available.

    After every step the model is evaluated on freshly sampled transitions of
    ``P_matrix`` (task 0 only) using the Bhattacharyya distance.

    Returns:
        ``(result, detailed_results)``. ``result`` holds ``positions`` and
        ``distances``; ``detailed_results`` also contains smoothed distances,
        losses, per-switch analysis, the model, and the EWC objects.

    Raises:
        ValueError: if ``states_sequence`` has fewer than two states, so there is
            no transition to train on. Previously this surfaced late as
            ``min()`` of an empty list.
    """
    if len(states_sequence) < 2:
        raise ValueError("states_sequence must contain at least two states to form a transition")

    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print(f"\nStarting EWC forgetting experiment with online training...")
    print(f"Device: {device}")
    print(f"Sequence length: {len(states_sequence)}")
    print(f"Learning rate: {learning_rate}")
    print(f"EWC lambda: {lambda_ewc}")
    print(f"Switch positions: {switch_positions}")
    print(f"Evaluation metric: Bhattacharyya distance")

    # Seed right before model construction so the initial weights are reproducible.
    torch.manual_seed(42)
    model = MarkovPredictor(
        num_states=num_states,
        num_tasks=num_tasks,
        hidden_dim=64,
        use_embedding=True
    ).to(device)

    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

    optimizer = optim.SGD(model.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss()

    fisher_calculator = FisherInformationCalculator(model, device)
    ewc_regularizer = EWCRegularizer(model, lambda_ewc, device)

    P_tensor = torch.tensor(P_matrix, dtype=torch.float32, device=device)
    INTER_P_tensor = torch.tensor(INTER_P_matrix, dtype=torch.float32, device=device)

    distances = []
    true_next_states = []
    positions = []
    training_losses = []
    ewc_losses = []

    # Set for O(1) membership tests inside the per-step loop.
    switch_set = set(switch_positions)
    # Switches at or before this position are ignored for Fisher consolidation.
    min_switch_position = 50

    print(f"\nStarting step-by-step prediction and training...")

    for i in range(len(states_sequence) - 1):
        current_state = states_sequence[i]
        current_task = task_sequence[i]
        true_next_state = states_sequence[i + 1]

        if i in switch_set and i > min_switch_position:
            print(f"\n=== Task switch detected at position {i} ===")

            # i > min_switch_position >= 0, so i - 1 is always a valid index.
            previous_task_type = task_sequence[i - 1]

            fisher_data = prepare_fisher_data_for_completed_task(
                states_sequence, task_sequence, switch_positions, i, previous_task_type
            )

            if len(fisher_data) > 10:
                print(f"Computing Fisher information for completed task {previous_task_type} with {len(fisher_data)} samples...")
                fisher_dict = fisher_calculator.compute_fisher_information(
                    fisher_data, criterion, num_samples=len(fisher_data)
                )
                ewc_regularizer.save_task_params_by_type(fisher_dict, previous_task_type)

                print(f"EWC: Task types saved: {ewc_regularizer.get_saved_task_types()}")
            else:
                print(f"Warning: Insufficient data ({len(fisher_data)} samples) for Fisher computation at position {i}")

        # predict_probs() leaves the model in eval mode, so re-enable training every step.
        model.train()
        optimizer.zero_grad()

        state_tensor = torch.tensor([current_state], device=device)
        task_tensor = torch.tensor([current_task], device=device)
        target_tensor = torch.tensor([true_next_state], device=device)

        logits = model(state_tensor, task_tensor)
        current_loss = criterion(logits, target_tensor)
        # Plain float 0.0 until a task type is saved, a tensor afterwards.
        ewc_loss = ewc_regularizer.compute_ewc_loss()
        total_loss = current_loss + ewc_loss

        total_loss.backward()
        optimizer.step()

        loss_value = current_loss.item()
        ewc_value = ewc_loss.item() if isinstance(ewc_loss, torch.Tensor) else ewc_loss
        training_losses.append(loss_value)
        ewc_losses.append(ewc_value)

        # Evaluate forgetting on the normal task (task 0) with freshly sampled transitions.
        test_data = generate_test_data(P_matrix, num_samples=1, seq_length=50)
        total_distance = 0.0
        for test_state, test_next_state in test_data:
            probs = model.predict_probs(test_state, task_id=0)
            total_distance += bhattacharyya_distance_corrected(
                probs, test_state, test_next_state,
                0, P_tensor, INTER_P_tensor)
        distance = total_distance / len(test_data)

        distances.append(distance)
        true_next_states.append(true_next_state)
        positions.append(i)
        if (i + 1) % 100 == 0 or i == 0:
            print(f"Position {i+1:5d}: State{current_state}(Task{current_task}) -> "
                  f"True{true_next_state}, "
                  f"Distance{distance:.4f}, Loss{loss_value:.4f}, "
                  f"EWC{ewc_value:.4f}, TaskTypes{ewc_regularizer.get_saved_task_types()}")

    print(f"Training and prediction completed!")

    smoothed_distances = _moving_average(distances, window_size=100)
    switch_analysis = _analyze_switches(
        switch_positions, task_sequence, distances, training_losses, ewc_losses
    )

    overall_distance = np.mean(distances)
    overall_loss = np.mean(training_losses)
    overall_ewc_loss = np.mean(ewc_losses)

    result = {
        'positions': positions,
        'distances': distances
    }

    detailed_results = {
        'positions': positions,
        'distances': distances,
        'smoothed_distances': smoothed_distances,
        'true_next_states': true_next_states,
        'switch_positions': switch_positions,
        'switch_analysis': switch_analysis,
        'overall_distance': overall_distance,
        'training_losses': training_losses,
        'ewc_losses': ewc_losses,
        'overall_loss': overall_loss,
        'overall_ewc_loss': overall_ewc_loss,
        'learning_rate': learning_rate,
        'lambda_ewc': lambda_ewc,
        'states_sequence': states_sequence,
        'task_sequence': task_sequence,
        'model': model,
        'num_states': num_states,
        'num_tasks': num_tasks,
        'P_matrix': P_matrix,
        'INTER_P_matrix': INTER_P_matrix,
        'ewc_regularizer': ewc_regularizer,
        'fisher_calculator': fisher_calculator
    }

    print(f"\nExperiment statistics:")
    print(f"  Total predictions: {len(distances):,}")
    print(f"  Average Bhattacharyya distance: {overall_distance:.4f}")
    print(f"  Average training loss: {overall_loss:.4f}")
    print(f"  Average EWC loss: {overall_ewc_loss:.4f}")
    print(f"  Distance std: {np.std(distances):.4f}")
    print(f"  Min distance: {min(distances):.4f}")
    print(f"  Max distance: {max(distances):.4f}")
    print(f"  Number of saved task types: {ewc_regularizer.get_num_saved_tasks()}")
    print(f"  Saved task types: {ewc_regularizer.get_saved_task_types()}")

    return result, detailed_results

def _plot_experiment_curves(positions, distances, training_losses, ewc_losses,
                            switch_positions, fig_path):
    """Save a 3-panel figure (distance, training loss, EWC loss) with switch markers."""
    fig, axes = plt.subplots(3, 1, figsize=(12, 15))
    panels = (
        (distances, 'b-', 'Bhattacharyya Distance', "Distances",
         "Distances vs. Positions (Log Scale) - EWC Method (Task-Type Based)", True),
        (training_losses, 'g-', 'Training Loss', "Training Loss",
         "Training Loss Over Time - EWC Method (Task-Type Based)", False),
        (ewc_losses, 'orange', 'EWC Regularization Loss', "EWC Loss",
         "EWC Regularization Loss Over Time - Task-Type Based", False),
    )
    try:
        for ax, (values, fmt, label, ylabel, title, log_scale) in zip(axes, panels):
            ax.plot(positions, values, fmt, linewidth=1, alpha=0.7, label=label)
            for switch_pos in switch_positions:
                ax.axvline(x=switch_pos, color='red', linestyle='--', alpha=0.7)
            ax.set_xlabel("Positions")
            ax.set_ylabel(ylabel)
            if log_scale:
                ax.set_yscale('log')
            ax.set_title(title)
            ax.grid(True, which="both", ls="--", linewidth=0.5)
            ax.legend()

        fig.tight_layout()
        fig.savefig(fig_path, dpi=300)
    finally:
        # Close this figure even if plotting fails, so figures do not accumulate.
        plt.close(fig)


def main_with_directory(generated_series_dir, continuous_series_names, markov_chain_names, save_path=None,
                        *, learning_rate=0.001, lambda_ewc=700.0):
    """Run the EWC forgetting experiment on every Markov-chain pickle in a directory.

    For each Markov-chain file (processed in sorted name order) the prompt is
    parsed and the online EWC experiment is run. When ``save_path`` is given,
    ``<stem>_results.pkl`` and ``<stem>.png`` are written there; the directory is
    created if needed. With ``save_path=None`` nothing is written; the original
    code instead failed on ``os.path.join(None, ...)``.

    A failure on one file is printed with its traceback and that file is
    skipped. Only files that complete fully are added to the returned dict,
    which maps file name to the experiment's ``result`` dict.
    """
    _continuous_series_task, markov_chain_task = load_prompt_files(
        generated_series_dir, continuous_series_names, markov_chain_names
    )

    if save_path is not None:
        save_path = Path(save_path)
        save_path.mkdir(parents=True, exist_ok=True)
    else:
        print("Warning: save_path is None; results and figures will not be written to disk")

    results_collection = {}

    for series_name in sorted(markov_chain_task):
        series_dict = markov_chain_task[series_name]
        print(f"\nProcessing {series_name}")

        try:
            full_series = series_dict.get('full_series_with_switches',
                                          series_dict.get('full_series', ''))

            if not full_series:
                print(f"Warning: No full_series found in {series_name}")
                continue

            P, INTER_P, states, _ = extract_transition_matrices(series_dict)

            print(f"Full series length: {len(full_series)}")
            print(f"P matrix shape: {np.array(P).shape}")
            print(f"INTER_P matrix shape: {np.array(INTER_P).shape}")
            states_sequence, task_sequence, switch_positions = parse_input_prompt(full_series)

            if not states_sequence:
                print(f"Warning: No states sequence parsed from {series_name}")
                continue

            results, detailed_results = run_forgetting_experiment_with_online_training(
                states_sequence, task_sequence, switch_positions,
                P, INTER_P, learning_rate=learning_rate, num_states=len(states), num_tasks=2,
                lambda_ewc=lambda_ewc
            )

            if save_path is not None:
                data_name = Path(series_name).stem
                with open(save_path / f"{data_name}_results.pkl", "wb") as f:
                    pickle.dump(results, f)
                _plot_experiment_curves(
                    results['positions'], results['distances'],
                    detailed_results['training_losses'], detailed_results['ewc_losses'],
                    switch_positions, save_path / f"{data_name}.png",
                )

            # Record only after every step (including saving) succeeded.
            results_collection[series_name] = results
            print(f"Completed processing {series_name}")

        except Exception as e:
            # Per-file boundary: report and move on to the next file.
            print(f"Error processing {series_name}: {e}")
            import traceback
            traceback.print_exc()
            continue

    print(f"\n{'='*80}")
    print(f"Directory processing completed!")
    print(f"Successfully processed {len(results_collection)} files")
    print(f"{'='*80}")

    return results_collection


    
    
    
if __name__ == "__main__":
    # Seed torch (model init) and numpy (sequence sampling) for reproducibility.
    torch.manual_seed(42)
    np.random.seed(42)

    parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    sys.path.append(parent_dir)
    project_root = Path(parent_dir)

    save_path = project_root / 'icl_run_results_ewc_8_sp'
    if not save_path.exists():
        save_path.mkdir(parents=True)

    generated_series_dir = project_root / 'data_gen_600' / 'generated_series_er_8_sp'
    continuous_series_names = ["continuous_series"]
    markov_chain_names = ["markov_chain"]
    results_collection = main_with_directory(
        generated_series_dir,
        continuous_series_names,
        markov_chain_names,
        save_path,
    )

    print("\nExperiment completed!")