import os
import pandas as pd

from sklearn.decomposition import PCA
from openTSNE import TSNE
import umap as up
import torch

import hdbscan
import itertools


# Return pandas df of WSJ data
def get_WSJ(csv_path = "rsfinkit/data/embeddings/data/wsj_compiled_data.csv"):
    """Load the compiled WSJ headline table, building it on first use.

    Args:
        csv_path: Location of the compiled CSV. If the file already exists it
            is read back with ``pd.read_csv``; otherwise the per-year headline
            files are compiled into it first.

    Returns:
        A pandas DataFrame with the compiled headline data.
    """
    if os.path.exists(csv_path):
        # Fast path: reuse the previously compiled file.
        return pd.read_csv(csv_path)
    # First run: build (and persist) the compiled table.
    return _compile_WSJ_data(csv_path)
    
# Compiles the yearly WSJ headline CSV files into a single .csv file. Expected layout:
# wsj_headlines
#   /2010.csv
#   /2011.csv
#   /2012.csv
#   ...
#   /2023.csv
def _compile_WSJ_data(csv_path):
    WSJ_HEADLINES_PATH = "rsfinkit/data/wsj_headlines"

    # Initialize an empty list to store the dataframes
    dataframes = []

    # Loop through the files in the directory
    for filename in sorted(os.listdir(WSJ_HEADLINES_PATH)):
        if filename.endswith('.csv') and filename[:4].isdigit():
            # Construct the full file path
            filepath = os.path.join(WSJ_HEADLINES_PATH, filename)
            # Read the CSV file and append it to the list
            df = pd.read_csv(filepath)

            # Drop rows where any of the data is missing
            df.dropna(inplace=True)

            # Append the dataframe to the list if it's not empty
            if not df.empty:
                dataframes.append(df)

    # Concatenate all the dataframes in the list
    concatenated_df = pd.concat(dataframes, ignore_index=True)

    # Drop duplicates based on the "news" column, keeping the first occurrence
    concatenated_df.drop_duplicates(subset='News', keep='first', inplace=True)

    # Reset index to add an 'id' column
    concatenated_df.reset_index(drop=False, inplace=True)
    concatenated_df.rename(columns={'index': 'id'}, inplace=True)

    # Save the concatenated dataframe to a new CSV file
    concatenated_df.to_csv(csv_path, index=False)

    return concatenated_df

def get_stock_movement(ods_path = "rsfinkit/data/embeddings/data/finance_data.ods", tags = None):
    """Load the stock-movement table from an OpenDocument spreadsheet.

    Args:
        ods_path: Path to the ``.ods`` file. It is read with pandas' ``"odf"``
            engine, which requires the ``odfpy`` package.
        tags: Currently unused; accepted so existing call sites keep working.

    Returns:
        The DataFrame produced by ``pd.read_excel`` (first sheet by default).
    """
    stock_df = pd.read_excel(ods_path, engine="odf")
    return stock_df

def cluster_data(embeddings, cluster_size=10):
    """Assign a cluster label to every embedding using HDBSCAN.

    Args:
        embeddings: Data to cluster, passed unchanged to
            ``HDBSCAN.fit_predict``.
        cluster_size: Forwarded as HDBSCAN's ``min_cluster_size``.

    Returns:
        The per-sample labels from ``fit_predict``. HDBSCAN labels noise
        points as -1.
    """
    return hdbscan.HDBSCAN(min_cluster_size=cluster_size).fit_predict(embeddings)

# Returns PCA, t-SNE and/or UMAP embeddings, cached on disk via torch.save
def reduce_dimensions(X, path, n_components = 2, method='all'):
    """Reduce ``X`` to ``n_components`` dimensions, caching results on disk.

    Args:
        X: Input passed to the reducer's fit method (presumably an array-like
            of shape (n_samples, n_features) — confirm against callers).
        path: Cache file written with ``torch.save`` and read back with
            ``torch.load``. For a single method, ``path`` is used as-is. For
            ``method='all'``, one file per method is used, derived by
            inserting ``_pca``/``_tsne``/``_umap`` before the extension
            (e.g. ``emb.pt`` -> ``emb_pca.pt``), so the three results do not
            overwrite each other.
        n_components: Target dimensionality.
        method: ``'pca'``, ``'tsne'``, ``'umap'`` or ``'all'``.

    Returns:
        The reduced embedding for a single method, or a ``(pca, tsne, umap)``
        tuple for ``method='all'``.

    Raises:
        ValueError: If ``method`` is not one of the values above.
    """
    if method == 'all':
        # Each method needs its own cache file; the previous version dropped
        # the required ``path`` argument here and always raised TypeError.
        root, ext = os.path.splitext(path)
        pca = reduce_dimensions(X, f"{root}_pca{ext}", n_components=n_components, method='pca')
        tsne = reduce_dimensions(X, f"{root}_tsne{ext}", n_components=n_components, method='tsne')
        umap = reduce_dimensions(X, f"{root}_umap{ext}", n_components=n_components, method='umap')
        return pca, tsne, umap

    if method not in ("pca", "tsne", "umap"):
        raise ValueError("Invalid method. Choose 'pca', 'tsne', 'umap', or 'all'.")

    if os.path.exists(path):
        # weights_only=False: the cache holds pickled NumPy/openTSNE objects,
        # which torch>=2.6's default (weights_only=True) refuses to load.
        # This unpickles arbitrary objects - only point ``path`` at cache
        # files written by this function, never at untrusted input.
        return torch.load(path, weights_only=False)

    reduced = _fit_reducer(X, n_components, method)

    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    torch.save(reduced, path, pickle_protocol=4)

    return reduced


def _fit_reducer(X, n_components, method):
    """Fit the requested reducer ('pca', 'tsne' or 'umap') on X and return its embedding."""
    if method == 'pca':
        return PCA(n_components=n_components).fit_transform(X)
    if method == 'tsne':
        tsne = TSNE(
            n_components=n_components,
            n_jobs=-1,  # Use all available cores
            perplexity=30,  # Adjust as needed
            initialization="pca",  # PCA initialization
            metric="euclidean",  # Can be changed based on the data
        )
        # openTSNE's TSNE.fit returns a TSNEEmbedding object rather than a
        # plain ndarray; it is cached and returned as-is.
        return tsne.fit(X)
    # method == 'umap' (validated by the caller)
    return up.UMAP(n_components=n_components).fit_transform(X)

def generate_runs(hyperparameters):
    """Expand a hyperparameter grid into the list of runs to execute.

    Every combination of the values in ``hyperparameters`` (a mapping from
    name to a list of candidate values) is built with ``itertools.product``.
    Only the valid (task_type, tags) pairings are KEPT:

    1) ``task_type == "tag_pred"`` with ``tags`` in ``["location", "genre"]``
    2) ``task_type == "finance"`` with ``tags == "finance"``

    All runs of kind 1 come first (in grid order), followed by all runs of
    kind 2. Each kept run gets a sequential ``"id"`` starting at 0.

    Raises:
        KeyError: If the grid lacks a ``"task_type"`` (or, where needed,
            ``"tags"``) entry.
        ValueError: If ``hyperparameters`` is empty.
    """
    # Grid expansion, see
    # https://stackoverflow.com/questions/3873654/combinations-from-dictionary-with-list-values-using-python
    names, choices = zip(*hyperparameters.items())
    grid = [dict(zip(names, combo)) for combo in itertools.product(*choices)]

    tag_pred_runs = [
        cfg for cfg in grid
        if cfg["task_type"] == "tag_pred" and cfg["tags"] in ["location", "genre"]
    ]
    finance_runs = [
        cfg for cfg in grid
        if cfg["task_type"] == "finance" and cfg["tags"] == "finance"
    ]
    selected = tag_pred_runs + finance_runs

    # Number the surviving runs in their final order.
    for run_id, cfg in enumerate(selected):
        cfg["id"] = run_id

    return selected