import os
import subprocess
import hashlib
import os
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
import pandas as pd

def generate_version_name_finetune(finetune_pretrained, freeze_layers, token_mask_ratio, channel_mask_ratio, end=False):
    """Return the MD5 hex digest identifying a finetuning configuration.

    The digest is computed over a string that spells out the four
    hyperparameters; when *end* is truthy the suffix ``_test`` is appended
    to that string before hashing, which yields a different digest.
    """
    key = (
        f"pretrained={finetune_pretrained}"
        f"_freeze_layers={freeze_layers}"
        f"_token_mask_ratio={token_mask_ratio}"
        f"_channel_mask_ratio={channel_mask_ratio}"
    )
    if end:
        key += "_test"
    digest = hashlib.md5(key.encode())
    return digest.hexdigest()

def generate_version_name(token_mask_ratio, channel_mask_ratio):
    """Return the MD5 hex digest identifying a pretraining configuration.

    The digest is computed over a string that spells out both mask ratios.
    """
    key = f"token_mask_ratio={token_mask_ratio}_channel_mask_ratio={channel_mask_ratio}"
    digest = hashlib.md5(key.encode())
    return digest.hexdigest()

def read_tfevents(path, tag):
    """Return the values of scalar *tag* from the richest event file in *path*.

    Every ``events.out.tfevents*`` file directly inside *path* is loaded and
    the one holding the most entries for *tag* is selected.

    Args:
        path: Directory to scan for event files.
        tag: Name of the scalar series to read.

    Returns:
        A list with the ``value`` of each scalar event recorded for *tag* in
        the selected file, or an empty list when *path* holds no event file
        or none of them has entries for *tag*.
    """
    best_values = []
    for name in os.listdir(path):
        if not name.startswith('events.out.tfevents'):
            continue

        accumulator = EventAccumulator(os.path.join(path, name))
        accumulator.Reload()
        if tag not in accumulator.Tags()['scalars']:
            continue

        events = accumulator.Scalars(tag)
        # Strict '>' keeps the first file (in os.listdir order) on ties.
        # The values are captured right away so the winning file does not
        # have to be parsed a second time after the scan.
        if len(events) > len(best_values):
            best_values = [event.value for event in events]

    return best_values

def get_last_epoch_from_events(event_file):
    """Return the largest value logged under the 'epoch' scalar in *event_file*.

    Returns 0 when the file has no 'epoch' scalar tag.
    """
    accumulator = EventAccumulator(event_file)
    accumulator.Reload()

    # Guard clause: nothing logged under 'epoch' means no epoch was recorded.
    if 'epoch' not in accumulator.Tags()['scalars']:
        return 0

    return max(scalar.value for scalar in accumulator.Scalars('epoch'))

# Root directory that holds one sub-directory per experiment tag.
TF_EVENTS_PATH = '/home/thoriri/projects/data/LightningIR/experiments/'

# Grid of masking ratios; every (token, channel) pair is inspected.
token_mask_ratio = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
channel_mask_ratio = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]

# Names of the experiment sub-directories under TF_EVENTS_PATH.
pretrain_tag, pretrain_tag_new = 'TSTT_pretrain', 'TSTT_pretrain_new'
finetune_tag, finetune_tag_new = 'TSTT_finetune', 'TSTT_finetune_new'

# Entries found directly inside each experiment directory, listed in the
# order: pretrain, pretrain (new), finetune, finetune (new).
hash_list, hash_list_new, hash_list_finetune, hash_list_finetune_new = (
    os.listdir(os.path.join(TF_EVENTS_PATH, tag))
    for tag in (pretrain_tag, pretrain_tag_new, finetune_tag, finetune_tag_new)
)

def _last_epoch_in_dir(events_path):
    """Return the highest last-epoch value over all event files in *events_path*.

    Returns 0 when the directory contains no ``events.out.tfevents*`` file,
    so a run directory that exists but holds no event file does not abort
    the whole scan with a ValueError from ``max()`` on an empty sequence.
    """
    return max(
        (
            get_last_epoch_from_events(os.path.join(events_path, f))
            for f in os.listdir(events_path)
            if f.startswith('events.out.tfevents')
        ),
        default=0,
    )


# Status of every (token_mask_ratio, channel_mask_ratio) combination.
run_dict = {}
for tok in token_mask_ratio:
    for cha in channel_mask_ratio:
        # Start from "nothing found"; the key order fixes the column order of
        # the DataFrame that is later built from run_dict.
        entry = {
            'pretrained': False,
            'finetuned': False,
            'pretrain_epoch': 0,
            'finetune_epoch': 0,
            'test_balanced_acc': None,
        }
        run_dict[(tok, cha)] = entry

        # Pretraining: the original experiment directory takes precedence
        # over the "_new" one when the hash is present in both.
        hash_pretrain = generate_version_name(tok, cha)
        if hash_pretrain in hash_list:
            model_path = pretrain_tag
        elif hash_pretrain in hash_list_new:
            model_path = pretrain_tag_new
        else:
            model_path = None

        if model_path is not None:
            events_path = os.path.join(TF_EVENTS_PATH, model_path, hash_pretrain)
            entry['pretrained'] = True
            entry['pretrain_epoch'] = _last_epoch_in_dir(events_path)

        # Finetuning: the "_new" directory is only consulted when at least
        # one ratio exceeds 0.3, and its runs are hashed with end=True.
        hash_finetune = generate_version_name_finetune(True, False, tok, cha, end=False)
        if tok > 0.3 or cha > 0.3:
            hash_finetune_new = generate_version_name_finetune(True, False, tok, cha, end=True)
        else:
            hash_finetune_new = None

        if hash_finetune in hash_list_finetune:
            events_path = os.path.join(TF_EVENTS_PATH, finetune_tag, hash_finetune)
        elif hash_finetune_new is not None and hash_finetune_new in hash_list_finetune_new:
            events_path = os.path.join(TF_EVENTS_PATH, finetune_tag_new, hash_finetune_new)
        else:
            events_path = None

        if events_path is not None:
            entry['finetuned'] = True
            entry['finetune_epoch'] = _last_epoch_in_dir(events_path)
def _read_test_balanced_acc(file_path):
    """Return the test_balanced_acc value found in *file_path*, or '~'.

    '~' is returned both when the file does not exist and when it exists but
    contains no 'test_balanced_acc' table row.
    """
    if not os.path.exists(file_path):
        return '~'
    with open(file_path) as f:
        for line in f:
            # Expected row: '│     test_balanced_acc     │    0.78...     │'
            # which splits into ['│', 'test_balanced_acc', '│', '<value>', '│'].
            if '│     test_balanced_acc     │' in line:
                return float(line.split()[3])
    return '~'


# Convert the run_dict to a pandas DataFrame; the (token, channel) key of
# each entry becomes the first two columns.
df = pd.DataFrame.from_dict(run_dict, orient='index').reset_index()
df.columns = ['token_mask_ratio', 'channel_mask_ratio', 'pretrained', 'finetuned', 'pretrain_epoch', 'finetune_epoch', 'test_balanced_acc']

# Drop the rows where neither pretraining nor finetuning was found, and
# renumber the remaining rows. reset_index returns a new frame, so the
# assignments below do not write into a slice of the unfiltered frame.
df = df[(df['pretrained'] == True) | (df['finetuned'] == True)].reset_index(drop=True)

# The status columns are about to hold the string 'Partial' next to booleans
# and the accuracy column a float or '~'; use object dtype explicitly instead
# of relying on pandas to upcast on assignment.
df = df.astype({'pretrained': object, 'finetuned': object, 'test_balanced_acc': object})

# A stage that was found but logged fewer than 150 epochs is marked "Partial".
for status_col, epoch_col in (('pretrained', 'pretrain_epoch'), ('finetuned', 'finetune_epoch')):
    partial = (df[status_col] == True) & (df[epoch_col] < 150)
    df.loc[partial, status_col] = 'Partial'

# For fully pretrained and fully finetuned runs, look up the matching
# "output_token_..._channel_....txt" file in the "experiment_outputs/" folder
# and store its test_balanced_acc value; store '~' when the file is missing
# or the value cannot be found in it.
for i, row in df.iterrows():
    if row['pretrained'] == True and row['finetuned'] == True:
        file_path = os.path.join('/','home','thoriri','Documents','TimeFM','experiment_outputs', f'output_token_{row["token_mask_ratio"]}_channel_{row["channel_mask_ratio"]}.txt')
        df.at[i, 'test_balanced_acc'] = _read_test_balanced_acc(file_path)

# Only fully pretrained and fully finetuned configurations get tested.
to_run = df[(df['pretrained'] == True) & (df['finetuned'] == True)][['token_mask_ratio', 'channel_mask_ratio']]


# Command shared by every test run; the two mask ratios are appended per run.
base_command = "python run_test.py +experiment=TSTT_test freeze_layers=False finetune_pretrained=True gpus=4 batch_size=512"

# Create a directory to save the outputs
output_dir = "experiment_outputs"
os.makedirs(output_dir, exist_ok=True)

# Run the test once for every configuration selected in to_run. The loop
# variables are named differently from the module-level ratio lists so that
# those lists are not overwritten.
for token_ratio, channel_ratio in to_run.itertuples(index=False):
    # An existing output file means this configuration was already run.
    output_file = os.path.join(output_dir, f"output_token_{token_ratio}_channel_{channel_ratio}.txt")
    if os.path.exists(output_file):
        print(f"Output file {output_file} already exists. Skipping...")
        continue

    command = f"{base_command} token_mask_ratio={token_ratio} channel_mask_ratio={channel_ratio}"
    print(f"Running: {command}")

    # Run the command and capture both output streams as text.
    result = subprocess.run(command, shell=True, capture_output=True, text=True)

    # Save stdout followed by stderr to the output file.
    with open(output_file, "w") as f:
        f.write(result.stdout)
        f.write(result.stderr)

    # The file is written even for a failed run, and its presence makes later
    # invocations skip this configuration, so make the failure visible.
    if result.returncode != 0:
        print(f"Warning: command exited with status {result.returncode}; see {output_file}")

    print(f"Output saved to: {output_file}")