### Set up directory
import sys
import os
# Put the parent of the current working directory on sys.path — presumably so
# the project-local `data.serialize` import resolves when this script/notebook
# is run from a subdirectory (TODO confirm the expected working directory).
parent_dir = os.path.dirname(os.getcwd())
sys.path.append(parent_dir)

# Output directory for the generated series. `exist_ok=True` creates it if
# missing and avoids the check-then-create race of `os.path.exists` followed
# by `os.makedirs`.
save_path = os.path.join(parent_dir, 'generated_series_er_8_1')
os.makedirs(save_path, exist_ok=True)


    
import numpy as np
import pandas as pd
from datetime import datetime
import matplotlib.pyplot as plt
import os
from data.serialize import serialize_arr, deserialize_str, SerializerSettings
import pickle
from scipy.integrate import odeint


def serialize_gaussian(prec, time_series, mean_series, sigma_series,
                       rescale_factor=0.7, up_shift=0.15):
    """Min-max rescale a series, serialize it, and rescale its mean/sigma to match.

    ``time_series`` is mapped affinely so that its minimum lands at ``up_shift``
    and its maximum at ``up_shift + rescale_factor``. The same affine map is
    applied to ``mean_series``. ``sigma_series`` receives only the scale part,
    because a spread is unaffected by a shift. The rescaled mean and sigma are
    then multiplied by 10 (presumably to match the serializer's
    ``max_val=10`` — TODO confirm).

    Args:
        prec: Precision forwarded to ``SerializerSettings``.
        time_series: Observed values (array-like of numbers).
        mean_series: Values mapped with the same affine transform as
            ``time_series``.
        sigma_series: Spreads mapped with the same scale factor (no shift).
        rescale_factor: Width of the target interval. Defaults to 0.7.
        up_shift: Lower end of the target interval. Defaults to 0.15.

    Returns:
        Tuple ``(serialized, rescaled_mean * 10, rescaled_sigma * 10)``.
        ``serialized`` is the output of ``serialize_arr`` on the rescaled
        series.

    Raises:
        ValueError: If ``time_series`` is empty (NumPy's ``min`` raises), or
            if it is constant, since min-max rescaling would divide by zero.
    """
    settings = SerializerSettings(base=10, prec=prec, signed=True, time_sep=',', bit_sep='', minus_sign='-', fixed_length=False, max_val=10)
    time_series = np.array(time_series)

    # Compute the min-max statistics once instead of once per expression.
    series_min = time_series.min()
    series_range = time_series.max() - series_min
    if series_range == 0:
        raise ValueError(
            "serialize_gaussian: time_series is constant; "
            "min-max rescaling would divide by zero"
        )

    # Keep the original operation order: (x - min) / range * factor + shift.
    # This keeps results bit-identical to the previous implementation.
    rescaled_array = (time_series - series_min) / series_range * rescale_factor + up_shift
    rescaled_true_mean_arr = (np.array(mean_series) - series_min) / series_range * rescale_factor + up_shift
    # Sigma is a spread, so only the scale (not the shift) applies.
    rescaled_true_sigma_arr = np.array(sigma_series) / series_range * rescale_factor
    rescaled_true_mean_arr *= 10
    rescaled_true_sigma_arr *= 10

    full_series = serialize_arr(rescaled_array, settings)
    return (full_series, rescaled_true_mean_arr, rescaled_true_sigma_arr)

def generate_transition_matrix(N_state):
    """Return a random row-stochastic matrix of shape ``(N_state, N_state)``.

    Raw weights are drawn uniformly from [0, 1) using the global NumPy RNG.
    Each row is then divided by its own total, so every row is a probability
    distribution over the next state.
    """
    weights = np.random.rand(N_state, N_state)
    row_totals = weights.sum(axis=1, keepdims=True)
    return weights / row_totals

def generate_gaussian_matrix(N_state, sigma=0.5):
    """Return an ``(N_state, N_state)`` matrix whose rows are one shared Gaussian.

    A zero-mean Gaussian of width ``sigma`` is evaluated on ``N_state`` evenly
    spaced points in [-3, 3] and normalized to sum to 1. That row is repeated
    for every row of the result. Because all rows are identical, a chain
    sampled from this matrix picks each next state independently of the
    current one.
    """
    grid = np.linspace(-3, 3, N_state)
    density = np.exp(-0.5 * (grid / sigma) ** 2)
    row = density / density.sum()
    return np.repeat(row[np.newaxis, :], N_state, axis=0)


# ---------------------------------------------------------------------------
# Experiment configuration: the per-series lists below are zipped together,
# one generated series per position.
# ---------------------------------------------------------------------------
num_series = 1  # NOTE(review): not used in this section
llm_name = 'er'
llm_name_list = ['er'] * 16
Nt_list = [1240] * 16            # chain length (number of sampled states)
N_state_list = [8] * 16          # number of Markov states
random_seed_list = list(range(1, 17))
# True: several short interference windows plus a tail window;
# False: one window starting at step 500.
multi_time_learning = True

traj_name = 'markov_chain'

# Multi-window schedule: open intervals (lo, hi), plus every step > tail start.
_MULTI_WINDOWS = ((100, 110), (210, 220), (320, 330), (430, 440))
_MULTI_TAIL_START = 540
# Single-window schedule: half-open interval [lo, hi).
_SINGLE_WINDOW = (500, 500 + 3600)

# Prompt prefix prepended to the annotated series.
task_description = (
    "Learn pattern, predict next state. "
    "[SWITCH_TO_INTERFERENCE] = interference pattern starts. "
    "[SWITCH_TO_NORMAL] = target pattern resumes. "
    "Predict next: "
)


def _in_interference(t, multi):
    """Return True when step ``t`` should be sampled from the interference matrix."""
    if multi:
        return t > _MULTI_TAIL_START or any(lo < t < hi for lo, hi in _MULTI_WINDOWS)
    lo, hi = _SINGLE_WINDOW
    return lo <= t < hi


def _simulate_chain(Nt, states, P, INTER_states, INTER_P, multi):
    """Sample an ``Nt``-long chain starting at state 0.

    Steps inside an interference window use ``INTER_P``; all other steps use
    ``P``. Returns ``(chain, task_switches)``. ``task_switches`` holds one
    ``{"position", "switch_to"}`` dict per change of regime, and ``position``
    is the chain index of the first state sampled in the new regime. Draws
    come from the global NumPy RNG in step order, so a fixed seed reproduces
    the chain.
    """
    activated_msg = "multi noise mode activated" if multi else "noise mode activated"
    chain = [0]
    task_switches = []
    current_task = "normal"
    for t in range(1, Nt):
        current_state = chain[-1]
        if _in_interference(t, multi):
            if current_task == "normal":
                # Announce once per window entry rather than on every step.
                print(activated_msg)
                task_switches.append({"position": t, "switch_to": "noise"})
                current_task = "noise"
            next_state = np.random.choice(INTER_states, p=INTER_P[current_state])
        else:
            if current_task == "noise":
                task_switches.append({"position": t, "switch_to": "normal"})
                current_task = "normal"
            next_state = np.random.choice(states, p=P[current_state])
        chain.append(next_state)
    return chain, task_switches


def _annotate_with_switches(chain, task_switches, prefix):
    """Return ``prefix`` followed by the chain, with a marker before each switch state."""
    # Positions are unique (at most one switch per step), so a dict lookup is
    # equivalent to walking the switch list in order.
    markers = {
        sw["position"]: "[SWITCH_TO_INTERFERENCE]" if sw["switch_to"] == "noise" else "[SWITCH_TO_NORMAL]"
        for sw in task_switches
    }
    parts = [prefix]
    for i, state in enumerate(chain):
        if i in markers:
            parts.append(markers[i])
        parts.append(str(state))
    return "".join(parts)


for llm_name, Nt, N_state, random_seed in zip(llm_name_list, Nt_list,
                                               N_state_list, random_seed_list):
    states = np.arange(N_state)
    INTER_states = np.arange(N_state)
    # Seed before any draw so each series depends only on its own seed.
    np.random.seed(random_seed)
    P = generate_transition_matrix(N_state)
    INTER_P = generate_gaussian_matrix(N_state)

    chain, task_switches = _simulate_chain(
        Nt, states, P, INTER_states, INTER_P, multi_time_learning
    )

    full_series_with_switches = _annotate_with_switches(chain, task_switches, task_description)
    # Plain digit concatenation; only unambiguous while N_state <= 10.
    full_series = "".join(str(x) for x in chain)

    data_dict = {
        'full_series': full_series,
        'full_series_with_switches': full_series_with_switches,
        'full_array': np.array(chain),
        'llm_name': llm_name,
        'random_seed': random_seed,
        'P': P,
        'task_switches': task_switches,
        'states': states.tolist(),
        'INTER_states': INTER_states.tolist(),
        'INTER_P': INTER_P,
    }

    # Number the output after the count of existing `markov_chain*` files,
    # so repeated runs do not overwrite earlier outputs.
    file_indices = sum(
        1 for name in os.listdir(save_path)
        if os.path.isfile(os.path.join(save_path, name)) and name.startswith(traj_name)
    )

    print("file indices:", file_indices, "N_state_list:", N_state_list)
    file_index = int(multi_time_learning)
    print("file_index:", file_index, "N_state:", N_state, "random_seed:", random_seed)
    # Use this series' own N_state; previously N_state_list[0] was always used.
    save_name = os.path.join(save_path, f'{traj_name}_{file_indices}_state_{N_state}_multi_{file_index}.pkl')
    with open(save_name, 'wb') as f:
        pickle.dump(data_dict, f)
    print(f"Saved data to {save_name}", "save_name shape:", data_dict['full_array'].shape)