#!/usr/bin/env python3

from cls_exp import cls_exp
import os
import pandas as pd
from datetime import datetime

def get_subfolders(folder_path):
    """
    List the names of the directories located directly inside a folder.

    Only the immediate children of ``folder_path`` are examined. Entries that
    are not directories are skipped, and the returned values are bare names,
    not full paths.
    """
    subfolders = []
    for entry in os.listdir(folder_path):
        entry_path = os.path.join(folder_path, entry)
        if os.path.isdir(entry_path):
            subfolders.append(entry)
    return subfolders

def extract_metadata(ts_file_path):
    """Open a .ts file in text mode and return the metadata parsed from its header."""
    with open(ts_file_path, 'r') as ts_file:
        # The header parser consumes lines from the open stream; the file is
        # closed by the context manager once the dictionary has been built.
        return _load_header_info(ts_file)

def _load_header_info(file):
    """Load the meta data from a .ts file and advance file to the data.

    Parameters
    ----------
    file : stream.
        input file to read header from, assumed to be just opened

    Returns
    -------
    meta_data : dict.
        dictionary with the data characteristics stored in the header.
    """
    meta_data = {
        "problemname": "none",
        "timestamps": False,
        "missing": False,
        "univariate": True,
        "equallength": True,
        "classlabel": True,
        "targetlabel": False,
        "class_values": [],
    }
    boolean_keys = ["timestamps", "missing", "univariate", "equallength", "targetlabel"]
    for line in file:
        line = line.strip().lower()
        if line and not line.startswith("#"):
            tokens = line.split(" ")
            token_len = len(tokens)
            key = tokens[0][1:]
            if key == "data":
                if line != "@data":
                    raise IOError("data tag should not have an associated value")
                return meta_data
            if key in meta_data.keys():
                if key in boolean_keys:
                    if token_len != 2:
                        raise IOError(f"{tokens[0]} tag requires a boolean value")
                    if tokens[1] == "true":
                        meta_data[key] = True
                    elif tokens[1] == "false":
                        meta_data[key] = False
                elif key == "problemname":
                    meta_data[key] = tokens[1]
                elif key == "classlabel":
                    if tokens[1] == "true":
                        meta_data["classlabel"] = True
                        if token_len == 2:
                            raise IOError(
                                "if the classlabel tag is true then class values "
                                "must be supplied"
                            )
                    elif tokens[1] == "false":
                        meta_data["classlabel"] = False
                    else:
                        raise IOError("invalid class label value")
                    meta_data["class_values"] = [token.strip() for token in tokens[2:]]
        if meta_data["targetlabel"]:
            meta_data["classlabel"] = False
    return meta_data

def find_matching_datasets(parent_folder):
    """Find datasets with @univariate true and @equalLength true in subfolders.

    Each sub-directory of ``parent_folder`` is treated as one dataset. Its
    ``.ts`` files are inspected one by one, and the dataset is selected as
    soon as one file's header reports both ``univariate`` and ``equallength``
    as true. Entries of ``parent_folder`` that are not directories are ignored.

    Parameters
    ----------
    parent_folder : str
        Directory whose sub-directories are the candidate datasets.

    Returns
    -------
    list of str
        Names (not paths) of the matching dataset folders, in the order
        ``os.listdir`` yields them.
    """
    matching_datasets = []

    for dataset_name in os.listdir(parent_folder):
        dataset_path = os.path.join(parent_folder, dataset_name)
        if not os.path.isdir(dataset_path):
            continue

        for filename in os.listdir(dataset_path):
            if not filename.endswith(".ts"):
                continue
            metadata = extract_metadata(os.path.join(dataset_path, filename))
            # Header flags are plain booleans, so test them directly rather
            # than comparing against True.
            if metadata.get('univariate') and metadata.get('equallength'):
                matching_datasets.append(dataset_name)
                break  # One qualifying file is enough for this dataset

    return matching_datasets



def main():
    """Run the classification experiment grid and save the results as CSV.

    Steps:
      1. Scan the UCR folder for univariate, equal-length datasets.
      2. Check that every requested dataset is among them.
      3. Call ``cls_exp.run_exp`` over the pooling methods, seeds, model
         sizes and optimizers configured below.
      4. Sort the returned records and write them to a timestamped CSV file
         in the results folder.

    Raises
    ------
    ValueError
        If a requested dataset was not found among the matching datasets.
    """
    parent_folder = '../data/Timeseries-PILE/classification/UCR' # Replace with the correct path
    matching_datasets = find_matching_datasets(parent_folder)
    print("Matching datasets:", matching_datasets)

    # For a quick smoke run, shrink this to e.g. ['FordA'].
    dataset_names = ['FordA', 'FordB', 'ECG200', 'ElectricDevices', 'SmallKitchenAppliances', 'SwedishLeaf']

    # Validate with an explicit exception: an assert would be stripped when
    # Python runs with -O, letting a missing dataset through to the run.
    available = set(matching_datasets)
    for d in dataset_names:
        if d not in available:
            raise ValueError(f"Dataset {d} not found in matching datasets.")

    pooling_methods = ['max', 'sum', 'mean', 'last', 'attention', 'weighted_average']

    seed_list = [42, 43, 44, 45, 46]

    # For a quick smoke run, shrink this to e.g. ['small'].
    model_size_list = ['small', 'base', 'large']
    optimizer_name_list = ['adam']

    # One timestamp shared by both output files of this run.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    result_folder = './results/classification/'
    # Create the output folder up front so a missing directory cannot make
    # the final CSV write fail after all experiments have already run.
    os.makedirs(result_folder, exist_ok=True)
    lr = 1e-3

    # Run experiments
    # NOTE(review): presumably run_exp writes per-run records to output_path
    # and returns them as a list of dicts -- confirm against cls_exp.
    results = cls_exp.run_exp(
        pooling_methods=pooling_methods,
        dataset_names=dataset_names,
        seed_list=seed_list,
        optimizer_name_list=optimizer_name_list,
        model_size_list=model_size_list,
        train_epoch=20,
        use_mixed_precision=True,
        output_path=f"{result_folder}/exp_results_{timestamp}.jsonl",
        lr=lr,
    )

    # Process and save results
    results_df = pd.DataFrame(results)
    results_df = results_df.sort_values(by=['dataset_name', 'optimizer_name'], ascending=False)

    csv_file_name = f'{result_folder}/cls_results_{timestamp}.csv'
    results_df.to_csv(csv_file_name, index=False)
    print(f"Results saved to {csv_file_name}")

if __name__ == '__main__':
    main()
