### Set up directory
import gc
import sys
import os
import matplotlib.pyplot as plt
from pathlib import Path
import pickle
from datetime import datetime
import numpy as np
from tqdm import tqdm
import torch
from models.llm import get_model_and_tokenizer
from models.icl_perform import icl_markov_evaluation
# Series families this script can evaluate; compared against the first two
# underscore-separated tokens of each input file's stem further below.
markov_chain_names = ['markov_chain']

# Parent of the directory that contains this script, kept as a plain string.
parent_dir = str(Path(os.path.abspath(__file__)).parent.parent)
# NOTE(review): this append runs after the `models.*` imports at the top of
# the file, so it cannot influence how those imports were resolved.
sys.path.append(parent_dir)

# Output directory for result pickles and figures; created if it is missing.
save_path = Path(parent_dir) / 'icl_run_results_deepseek_5_1'
if not save_path.exists():
    save_path.mkdir(parents=True)

# Input directory holding the pickled series to evaluate.
generated_series_dir = Path(parent_dir, 'data', 'generated_series')

# Load the model and tokenizer through the project helper.
# NOTE(review): 'xxx' looks like a placeholder identifier -- confirm the
# intended model name before running.
model, tokenizer = get_model_and_tokenizer('xxx')
# Print the tokenizer's `model_max_length` and the model config's
# `max_position_embeddings` so the two limits can be compared in the log.
print("Tokenizer limit:", tokenizer.model_max_length)
print("Model limit:", model.config.max_position_embeddings)

# Containers for the loaded series, keyed by input file name.
# `continuous_series_task` is never filled in this section; it is kept so the
# summary print below (and any later code) still finds the name.
continuous_series_task = {}
markov_chain_task = {}
print("generated_series_dir:", generated_series_dir, "save_path:", save_path)
for file in generated_series_dir.iterdir():
    # Skip inputs that already have results. The evaluation loop writes its
    # output as "<stem>_results.pkl", so that is the name to look for; the
    # original check on the bare input file name is kept for compatibility
    # with output directories that follow that older layout.
    already_done = (
        (save_path / f"{file.stem}_results.pkl").exists()
        or (save_path / file.name).exists()
    )
    if already_done:
        continue

    # The series family is the first two underscore-separated stem tokens.
    series_name = '_'.join(file.stem.split('_')[:2])
    if series_name not in markov_chain_names:
        raise Exception(f"Unrecognized series name: {series_name}")

    # NOTE: pickle can execute arbitrary code while loading; only point this
    # script at series files that were generated locally / are trusted.
    # The context manager closes the handle instead of leaking it.
    with file.open('rb') as fh:
        markov_chain_task[file.name] = pickle.load(fh)
    print("Markov chain:", file.name)
print(generated_series_dir)
print(continuous_series_task.keys())
print(markov_chain_task.keys())

         
        
# Evaluate every loaded Markov-chain series, persist the raw results, and save
# a log-scale plot of distances against positions for each one.
for series_name, series_dict in sorted(markov_chain_task.items()):
    print("Processing ", series_name)
    # The previous fallback (`else full_series`) referred to the variable being
    # assigned: it raised NameError on the first iteration and silently reused
    # the previous file's series on later ones. Fail loudly instead.
    if 'full_series_with_switches' not in series_dict:
        raise KeyError(
            f"'full_series_with_switches' missing from series file {series_name}"
        )
    full_series = series_dict['full_series_with_switches']
    llm_name = series_dict['llm_name']
    print("full series:", full_series)
    results = icl_markov_evaluation(
        model,
        tokenizer,
        full_series,
        series_dict['states'],
        series_dict['P'],
        series_dict['INTER_states'],
        series_dict['INTER_P'],
    )
    print(f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Results for {series_name}:")

    # Persist the results before plotting so they survive a plotting failure.
    data_name = Path(series_name).stem
    pkl_path = os.path.join(save_path, f"{data_name}_results.pkl")
    with open(pkl_path, "wb") as f:
        pickle.dump(results, f)

    positions = results['positions']
    distances = results['distances']

    fig = plt.figure(figsize=(10, 6))
    plt.plot(positions, distances, marker='o', linestyle='-')

    plt.xlabel("Positions")
    plt.ylabel("Distances")
    plt.yscale('log')
    plt.title("Distances vs. Positions (Log Scale)")
    plt.grid(True, which="both", ls="--", linewidth=0.5)
    plt.tight_layout()

    fig_path = os.path.join(save_path, f"{data_name}.png")
    plt.savefig(fig_path, dpi=300)
    # Release the figure; otherwise pyplot keeps one open figure per series
    # alive for the whole run.
    plt.close(fig)

