import ast
from sklearn import preprocessing
from scipy.spatial import distance
import tqdm
from scipy.spatial.distance import cdist,pdist
import pandas as pd
from scipy.stats import rankdata
from sklearn.metrics.pairwise import cosine_similarity,euclidean_distances
from sklearn.neighbors import NearestNeighbors
import numpy as np
from scipy.stats import skew
from sklearn.decomposition import PCA
import warnings




def normalize_based_on_all_tasks(TASKS):
    """Fit a quantile transformer on the pooled rows of every task in ``TASKS``.

    Reads two module-level globals: ``dataset`` (selects the CSV naming scheme
    and the per-dataset column handling) and ``DataPath`` (prefix of the CSV
    files).

    Args:
        TASKS: iterable of task identifiers; each one is interpolated into the
            per-dataset CSV file name.

    Returns:
        A ``preprocessing.QuantileTransformer(n_quantiles=30)`` fitted on the
        concatenation of all task frames after rows containing NaN are dropped.
        The target column is not removed, so it is part of the fitted matrix.
    """
    school_columns = ['1985', '1986', '1987',
                      'ESWI', 'African', 'Arab', 'Bangladeshi', 'Caribbean', 'Greek', 'Indian', 'Pakistani',
                      'SE_Asian', 'Turkish', 'Other',
                      'VR_Band', 'Gender',
                      'FSM', 'VR_BAND_Student', 'School_Gender', 'Maintained', 'Church', 'Roman_Cath',
                      'ExamScore']

    frames = []
    for task_id in TASKS:
        if dataset == 'School':
            frame = pd.read_csv(f"{DataPath}/{task_id}_School_Data.csv", low_memory=False)
            frame = frame[school_columns]
        elif dataset == 'Landmine':
            frame = pd.read_csv(f"{DataPath}LandmineData_{task_id}.csv", low_memory=False)
        elif dataset == 'Chemical':
            frame = pd.read_csv(f"{DataPath}{task_id}_Molecule_Data.csv", low_memory=False)
            # Negative values in column '181' are replaced by 0 before pooling.
            frame.loc[frame['181'] < 0, '181'] = 0
        frames.append(frame)

    # Pool all tasks, discard incomplete rows, and fit on the raw value matrix
    # (features and target column together).
    pooled_values = pd.concat(frames).dropna().values
    quantile_scaler = preprocessing.QuantileTransformer(n_quantiles=30)
    quantile_scaler.fit(pooled_values)
    return quantile_scaler

def get_unified_distance_labelVariance(Tasks_list):
    """Pool several tasks and summarise the merged sample set.

    Reads the module-level globals ``dataset``, ``DataPath`` and
    ``standard_scaler`` (an already fitted transformer exposing ``transform``).

    Args:
        Tasks_list: iterable of task identifiers whose CSV files are loaded and
            stacked row-wise.

    Returns:
        tuple: ``(mean_pairwise_distance, label_variance, label_std)`` where
        the distance is the mean of ``pdist`` over the scaled pooled rows
        (Hamming for 'Chemical', Euclidean for 'School' / 'Landmine' /
        'Parkinsons') and the label statistics are ``np.var`` / ``np.std`` of
        the concatenated target column.
    """
    school_columns = ['1985', '1986', '1987',
                      'ESWI', 'African', 'Arab', 'Bangladeshi', 'Caribbean', 'Greek', 'Indian', 'Pakistani',
                      'SE_Asian', 'Turkish', 'Other',
                      'VR_Band', 'Gender',
                      'FSM', 'VR_BAND_Student', 'School_Gender', 'Maintained', 'Church', 'Roman_Cath',
                      'ExamScore']

    frames = []
    pooled_labels = []
    for task_id in Tasks_list:
        if dataset == 'School':
            frame = pd.read_csv(f"{DataPath}/{task_id}_School_Data.csv", low_memory=False)
            frame = frame[school_columns]
            pooled_labels.extend(frame['ExamScore'])
        elif dataset == 'Landmine':
            frame = pd.read_csv(f"{DataPath}LandmineData_{task_id}.csv", low_memory=False)
            pooled_labels.extend(frame['Labels'])
        elif dataset == 'Chemical':
            frame = pd.read_csv(f"{DataPath}{task_id}_Molecule_Data.csv", low_memory=False)
            # Negative values in column '181' are replaced by 0 before use.
            frame.loc[frame['181'] < 0, '181'] = 0
            pooled_labels.extend(frame['181'])
        # The target column stays in the frame, so it takes part in the distances.
        frames.append(frame)

    pooled = pd.concat(frames, axis=0, join='outer', ignore_index=True)
    scaled_rows = standard_scaler.transform(pooled.values)

    if dataset in ['School', 'Landmine', 'Parkinsons']:
        dists = distance.pdist(scaled_rows, metric='euclidean')
    elif dataset == 'Chemical':
        dists = distance.pdist(scaled_rows, metric='hamming')
    return np.mean(dists), np.var(pooled_labels), np.std(pooled_labels)


def fast_energy_distance(X, Y, metric='euclidean'):
    """
    Energy distance between samples X and Y, with a vectorised Euclidean path.

    The value returned is
    ``2 * mean(d(x, y)) - mean(d(x, x')) - mean(d(y, y'))``
    where the two within-sample means run over ordered pairs of distinct rows
    (denominator ``n * (n - 1)``).

    Parameters:
        X (np.ndarray): Samples from distribution 1, shape (n1, d)
        Y (np.ndarray): Samples from distribution 2, shape (n2, d)
        metric (str): 'euclidean' uses the expanded squared-norm formula for
            the cross term and ``pdist`` for the within terms; any other value
            is forwarded to ``cdist``.

    Returns:
        float: Energy distance
    """
    n_x = len(X)
    n_y = len(Y)
    cross_pairs = n_x * n_y
    x_pairs = n_x * (n_x - 1)
    y_pairs = n_y * (n_y - 1)

    if metric != 'euclidean':
        # Generic path: full distance matrices, diagonal forced to zero.
        cross = cdist(X, Y, metric=metric)
        cross_term = 2 * np.sum(cross) / cross_pairs

        within_x = cdist(X, X, metric=metric)
        np.fill_diagonal(within_x, 0)
        x_term = np.sum(within_x) / x_pairs

        within_y = cdist(Y, Y, metric=metric)
        np.fill_diagonal(within_y, 0)
        y_term = np.sum(within_y) / y_pairs

        return cross_term - x_term - y_term

    # Euclidean path: ||x - y||^2 = ||x||^2 + ||y||^2 - 2 x.y, via broadcasting.
    sq_x = np.sum(X**2, axis=1).reshape(-1, 1)
    sq_y = np.sum(Y**2, axis=1).reshape(1, -1)
    squared = sq_x + sq_y - 2 * X @ Y.T
    # Round-off can push tiny values below zero; clamp before the square root.
    squared = np.maximum(squared, 0)
    cross_term = 2 * np.sum(np.sqrt(squared)) / cross_pairs

    # pdist lists every unordered pair once, hence the factor 2.
    x_term = 2 * np.sum(pdist(X, metric='euclidean')) / x_pairs
    y_term = 2 * np.sum(pdist(Y, metric='euclidean')) / y_pairs

    return cross_term - x_term - y_term


def compute_energy_distance(X, Y, metric='euclidean'):
    """
    Energy distance between samples X and Y using full ``cdist`` matrices.

    Computes ``2 * mean(d(x, y)) - mean(d(x, x')) - mean(d(y, y'))``; the
    within-sample means divide the matrix sum (diagonal zeroed) by
    ``n * (n - 1)``.

    Parameters:
        X (np.ndarray): Samples from distribution 1, shape (n1, d)
        Y (np.ndarray): Samples from distribution 2, shape (n2, d)
        metric (str): Distance metric name forwarded to ``cdist``
    Returns:
        float: Energy distance
    """
    n_x, n_y = len(X), len(Y)

    def _mean_within(sample, count):
        # Mean distance over ordered pairs of distinct rows of one sample.
        pairwise = cdist(sample, sample, metric=metric)
        np.fill_diagonal(pairwise, 0)
        return np.sum(pairwise) / (count * (count - 1))

    cross_term = 2 * np.sum(cdist(X, Y, metric=metric)) / (n_x * n_y)
    x_term = _mean_within(X, n_x)
    y_term = _mean_within(Y, n_y)
    return cross_term - x_term - y_term

# Load raw task features for energy distance
def load_task_features(tid, dataset):
    """Read one task's CSV file and return it as a DataFrame.

    The file prefix comes from the module-level global ``DataPath``. The
    target column is kept in the returned frame.

    Args:
        tid: task identifier interpolated into the file name.
        dataset (str): 'School', 'Landmine' or 'Chemical'.

    Returns:
        pd.DataFrame for the three known datasets (for 'School' restricted to
        a fixed column list, for 'Chemical' with negative values of column
        '181' replaced by 0); ``None`` for any other dataset name.
    """
    if dataset == 'School':
        school_columns = ['1985', '1986', '1987',
                          'ESWI', 'African', 'Arab', 'Bangladeshi', 'Caribbean', 'Greek', 'Indian', 'Pakistani',
                          'SE_Asian', 'Turkish', 'Other',
                          'VR_Band', 'Gender',
                          'FSM', 'VR_BAND_Student', 'School_Gender', 'Maintained', 'Church', 'Roman_Cath',
                          'ExamScore']
        school_frame = pd.read_csv(f"{DataPath}/{tid}_School_Data.csv")
        return school_frame[school_columns]

    if dataset == 'Landmine':
        return pd.read_csv(f"{DataPath}/LandmineData_{tid}.csv")

    if dataset == 'Chemical':
        molecule_frame = pd.read_csv(f"{DataPath}/{tid}_Molecule_Data.csv")
        molecule_frame.loc[molecule_frame['181'] < 0, '181'] = 0
        return molecule_frame

    # Unknown dataset names fall through without loading anything.
    return None



def rank_based_similarity(taskA_df, taskB_df, agg_method='mean'):
    """
    Compute a rank-based distance between two tasks' datasets.

    For every column, the values of both tasks are pooled and ranked with
    ``scipy.stats.rankdata``; the per-column gap is the absolute difference
    between the mean rank of task A's rows and the mean rank of task B's rows.
    The gaps are then aggregated into a single number.

    Parameters:
    - taskA_df: 2-D array-like of shape (n_A, d) -- ``np.ndarray`` or
      ``pd.DataFrame``; converted with ``np.asarray`` so both are accepted
    - taskB_df: 2-D array-like of shape (n_B, d) -- same column count as task A
    - agg_method: str -- how to aggregate per-feature rank gaps
      ('mean', 'l2', 'max')

    Returns:
    - float: rank-based distance (lower means more similar)

    Raises:
    - AssertionError: if the two inputs have a different number of columns
    - ValueError: if ``agg_method`` is not one of 'mean', 'l2', 'max'
    """
    # Positional column slicing below needs ndarrays; a DataFrame would reject
    # the `[:, idx]` indexer, so normalise the inputs first.
    features_a = np.asarray(taskA_df)
    features_b = np.asarray(taskB_df)

    assert features_a.shape[1] == features_b.shape[1], "Both tasks must have same number of features"

    aggregators = {
        'mean': np.mean,
        'l2': np.linalg.norm,
        'max': np.max,
    }
    # Reject a bad aggregation name before doing any ranking work.
    if agg_method not in aggregators:
        raise ValueError("agg_method must be one of ['mean', 'l2', 'max']")

    n_A = len(features_a)
    rank_gaps = []
    for idx in range(features_a.shape[1]):
        pooled_column = np.hstack([features_a[:, idx], features_b[:, idx]])
        ranks = rankdata(pooled_column)
        # The first n_A ranks belong to task A, the remainder to task B.
        rank_gaps.append(abs(np.mean(ranks[:n_A]) - np.mean(ranks[n_A:])))

    return aggregators[agg_method](np.array(rank_gaps))



def graph_based_similarity(X1, X2, k=5, metric='euclidean', seed=42, normalize=True):
    """
    Compute graph-based similarity using proportion of cross-dataset edges in a k-NN graph.

    Both samples are stacked, every point is linked to its ``k`` nearest
    neighbours in the pooled set, and the score is the fraction of those links
    whose two endpoints come from different samples.

    Args:
        X1 (array-like): Samples from dataset 1 (n1 x d)
        X2 (array-like): Samples from dataset 2 (n2 x d)
        k (int): Number of nearest neighbors
        metric (str): Distance metric (e.g., 'euclidean')
        seed (int): Value passed to ``np.random.seed`` on entry; this reseeds
            NumPy's global generator (no random draw is made in this body)
        normalize (bool): When True (default, the previous unconditional
            behaviour) the pooled rows are passed through the module-level
            ``standard_scaler`` before the neighbour search; pass False when
            the inputs are already scaled

    Returns:
        float: Proportion of edges that connect points from different datasets (higher = more similar)

    Raises:
        ZeroDivisionError: if ``k`` is 0, because no edges are counted.
    """
    np.random.seed(seed)
    first = np.array(X1)
    second = np.array(X2)

    # 0.0 marks rows of X1, 1.0 marks rows of X2.
    membership = np.concatenate([np.zeros(len(first)), np.ones(len(second))])
    pooled = np.vstack([first, second])

    if normalize:
        pooled = standard_scaler.transform(pooled)

    # k + 1 neighbours are requested because the first hit is skipped as "self".
    knn = NearestNeighbors(n_neighbors=k + 1, metric=metric)
    knn.fit(pooled)
    _, neighbor_idx = knn.kneighbors(pooled)

    others = neighbor_idx[:, 1:]  # skip self
    crossing = membership[others] != membership[:, np.newaxis]

    cross_edges = int(np.count_nonzero(crossing))
    total_edges = others.size
    return cross_edges / total_edges


def compute_additional_pairwise_features(X1_scaled, X2_scaled, results, n_pca_components=5, n_bins=30):
    """
    Append four extra pairwise statistics for datasets X1 and X2 to ``results``.

    One value is appended to each of the lists stored under the keys
    'Mean_Diff_L2', 'Skewness_Diff_L2', 'PCA_Top_CosSim_Mean' and
    'Cosine_Similarity' (a list is created if the key is missing).

    Args:
        X1_scaled (np.ndarray): shape (n1, d)
        X2_scaled (np.ndarray): shape (n2, d)
        results (dict): mapping of feature name -> list, modified in place
        n_pca_components (int): upper bound on the number of PCA components;
            capped by the column count of either input
        n_bins (int): accepted for compatibility; not used in this function

    Returns:
        None. The function has no return statement; callers read the values
        from ``results``.
    """
    # ----- L2 norm of the difference between column means -----
    centroid_1 = np.mean(X1_scaled, axis=0)
    centroid_2 = np.mean(X2_scaled, axis=0)
    results.setdefault('Mean_Diff_L2', []).append(np.linalg.norm(centroid_1 - centroid_2))

    # ----- Skewness difference, restricted to columns that vary in both sets -----
    informative = (X1_scaled.std(axis=0) > 1e-5) & (X2_scaled.std(axis=0) > 1e-5)
    varying_1 = X1_scaled[:, informative]
    varying_2 = X2_scaled[:, informative]

    with warnings.catch_warnings():
        # RuntimeWarnings raised while computing skewness are silenced.
        warnings.simplefilter("ignore", category=RuntimeWarning)
        skewness_1 = skew(varying_1, axis=0)
        skewness_2 = skew(varying_2, axis=0)

    results.setdefault('Skewness_Diff_L2', []).append(np.linalg.norm(skewness_1 - skewness_2))

    # ----- Alignment of paired principal axes -----
    # NOTE(review): the cap uses column counts only; presumably every task has
    # at least that many rows -- confirm against the data.
    component_count = min(n_pca_components, X1_scaled.shape[1], X2_scaled.shape[1])
    axes_1 = PCA(n_components=component_count).fit(X1_scaled).components_
    axes_2 = PCA(n_components=component_count).fit(X2_scaled).components_
    alignments = [np.dot(axis_1, axis_2) for axis_1, axis_2 in zip(axes_1, axes_2)]
    results.setdefault('PCA_Top_CosSim_Mean', []).append(np.mean(np.abs(alignments)))

    # ----- Mean cosine similarity over all (row of X1, row of X2) pairs -----
    pairwise_cosine = cosine_similarity(X1_scaled, X2_scaled)
    results.setdefault('Cosine_Similarity', []).append(np.mean(pairwise_cosine))


def compute_pairwise_task_features(task_info_df, pair_results, dataset, task_len, variance_dict, std_dev_dict, Single_res_dict):

    """
    Computes task pairwise features: distance differences, variance/std differences, and loss changes.

    Besides its arguments, this function reads the module-level global
    ``standard_scaler`` (a fitted transformer) and calls the sibling helpers
    ``get_unified_distance_labelVariance``, ``load_task_features``,
    ``rank_based_similarity``, ``graph_based_similarity``,
    ``fast_energy_distance`` and ``compute_additional_pairwise_features``.

    Pairs for which ``d1 + d2``, ``var1 + var2`` or ``std1 + std2`` is zero are
    skipped entirely, so the output may have fewer rows than ``pair_results``.

    Args:
        task_info_df (pd.DataFrame): dataframe containing task average distances
            (columns 'Task_Name' and 'Average_<Hamming|Euclidean>_Distance_within_Task')
        pair_results (pd.DataFrame): dataframe containing task pairs; column
            'Task_group' holds a Python-literal string per pair and column
            'Total_Loss' the loss of the jointly trained pair
        dataset (str): 'School', 'Landmine', 'Chemical'
        task_len (dict): task -> dataset size
        variance_dict (dict): task -> variance in target
        std_dev_dict (dict): task -> std deviation in target
        Single_res_dict (dict): task -> loss of STL model

    Returns:
        pd.DataFrame: Pairwise features dataframe, one row per retained pair
    """

    # Pick the correct average distance column and the matching metric
    if dataset == 'Chemical':
        avg_dist_col = 'Average_Hamming_Distance_within_Task'
        metric = 'hamming'
    else:
        avg_dist_col = 'Average_Euclidean_Distance_within_Task'
        metric = 'euclidean'

    distance_dict = dict(zip(task_info_df['Task_Name'], task_info_df[avg_dist_col]))

    # The key order here fixes the column order of the returned frame.
    results = {
        'Task1': [],
        'Task2': [],

        'DatasetSize_Ratio_t1': [],
        'DatasetSize_Ratio_t2': [],
        'Intra_Distance_t1': [],
        'Intra_Distance_t2': [],

        'Total_Dataset_Size': [],
        'DatasetSize_Diff': [],

        'Variance_Diff': [],
        'Avg_Variance': [],
        'Unified_Variance': [],
        'Unified_Variance_over_Sum': [],
        'Unified_Variance_over_Prod': [],

        'StdDev_Diff': [],
        'Avg_StdDev': [],
        'Unified_StdDev': [],
        'Unified_StdDev_over_Sum': [],
        'Unified_StdDev_over_Prod': [],

        'Distance_Diff': [],
        'Unified_Distance': [],

        'Distance_Diff_over_Sum': [],
        'Distance_Diff_over_Prod': [],
        'Unified_Dist_over_Sum': [],
        'Unified_Dist_over_Prod': [],
        'Energy_Distance': [],

        'Rank_based_Similarity': [],
        'Graph_based_Similarity': [],
        'Cosine_Similarity': [],

        'Single_Task_Loss': [],
        'Change_in_Loss': [],

        'Mean_Diff_L2': [],
        'Skewness_Diff_L2': [],
        'PCA_Top_CosSim_Mean': [],
    }

    Pairs = [ast.literal_eval(p) for p in pair_results.Task_group]

    for group, pair in enumerate(tqdm.tqdm(Pairs)):
        t1, t2 = int(pair[0]), int(pair[1])
        tasks = [t1, t2]

        d1, d2 = distance_dict[t1], distance_dict[t2]
        var1, var2 = variance_dict[t1], variance_dict[t2]
        std1, std2 = std_dev_dict[t1], std_dev_dict[t2]
        len1, len2 = task_len[t1], task_len[t2]
        loss1, loss2 = Single_res_dict[t1], Single_res_dict[t2]

        # Avoid division by zero in the "over_Sum" ratios below
        if (d1 + d2 == 0) or (var1 + var2 == 0) or (std1 + std2 == 0):
            continue

        avg_var = np.mean([var1, var2])
        avg_stddev = np.mean([std1, std2])

        unified_distance, unified_variance, unified_stddev = get_unified_distance_labelVariance(tasks)
        geometric_mean_dist = np.sqrt(d1 * d2)

        X1_raw = np.array(load_task_features(t1, dataset))
        X2_raw = np.array(load_task_features(t2, dataset))
        X1 = standard_scaler.transform(X1_raw)
        X2 = standard_scaler.transform(X2_raw)

        sim = rank_based_similarity(X1, X2, agg_method='mean')

        # 5% of the pooled samples, but never 0: with k == 0 the k-NN graph has
        # no edges and the similarity ratio would divide by zero.
        k_val = max(1, int(0.05 * (len(X1) + len(X2))))

        # graph_based_similarity applies standard_scaler itself, so it gets the
        # unscaled rows; passing X1/X2 here would transform the data twice.
        score = graph_based_similarity(X1_raw, X2_raw, k=k_val)

        energy_dist = fast_energy_distance(X1, X2, metric=metric)

        sum_loss_single_task = loss1 + loss2
        # Pairs are enumerated by position, so the loss is fetched by position
        # too; a label lookup would pick the wrong row on a non-default index.
        group_total_loss = pair_results['Total_Loss'].iloc[group]

        row = {
            'Task1': t1,
            'Task2': t2,

            # task-specific features
            'DatasetSize_Ratio_t1': len1 / (len1 + len2),
            'DatasetSize_Ratio_t2': len2 / (len1 + len2),
            'Intra_Distance_t1': d1 / (d1 + d2),
            'Intra_Distance_t2': d2 / (d1 + d2),
            'DatasetSize_Diff': abs(len1 - len2) / (len1 + len2),
            'Total_Dataset_Size': len1 + len2,

            # distance features
            'Distance_Diff_over_Sum': abs(d1 - d2) / (d1 + d2),
            'Distance_Diff_over_Prod': abs(d1 - d2) / geometric_mean_dist,
            'Distance_Diff': abs(d1 - d2),
            'Unified_Distance': unified_distance,
            'Unified_Dist_over_Sum': unified_distance / (d1 + d2),
            'Unified_Dist_over_Prod': unified_distance / geometric_mean_dist,
            'Energy_Distance': energy_dist,

            # label features
            'Unified_Variance_over_Sum': unified_variance / (var1 + var2),
            'Unified_Variance_over_Prod': (unified_variance**2) / (var1 * var2),
            'Unified_StdDev_over_Sum': unified_stddev / (std1 + std2),
            'Unified_StdDev_over_Prod': (unified_stddev**2) / (std1 * std2),
            'Variance_Diff': abs(var1 - var2),
            'Unified_Variance': unified_variance,
            'StdDev_Diff': abs(std1 - std2),
            'Unified_StdDev': unified_stddev,
            'Avg_Variance': avg_var,
            'Avg_StdDev': avg_stddev,

            # similarity scores
            'Rank_based_Similarity': sim,
            'Graph_based_Similarity': score,

            # losses: relative gain of joint training over the two single-task models
            'Single_Task_Loss': sum_loss_single_task,
            'Change_in_Loss': (sum_loss_single_task - group_total_loss) / sum_loss_single_task,
        }
        for name, value in row.items():
            results[name].append(value)

        # Appends 'Mean_Diff_L2', 'Skewness_Diff_L2', 'PCA_Top_CosSim_Mean'
        # and 'Cosine_Similarity' for this pair.
        compute_additional_pairwise_features(X1, X2, results)

    # Sanity output: number of columns, then the length of every column.
    print(len(results))
    for key in results.keys():
        print(len(results[key]), key)
    return pd.DataFrame(results)




def compute_interTask_distance_fast(DataSet_Scaled, dataset, max_pairs=10000):
    """Mean pairwise distance between the rows of one (scaled) task matrix.

    Up to 3000 rows, every unordered pair is evaluated with ``pdist``. Above
    that, ``max_pairs`` index pairs are drawn with replacement from NumPy's
    global random generator and the mean over those sampled pairs is returned.

    Args:
        DataSet_Scaled (np.ndarray): 2-D sample matrix, one row per sample.
        dataset (str): 'School', 'Landmine' or 'Parkinsons' select Euclidean
            distance; 'Chemical' selects Hamming distance (fraction of
            differing columns).
        max_pairs (int): number of random pairs used on the sampling path.

    Returns:
        float: mean distance.

    Raises:
        ValueError: if ``dataset`` is not one of the supported names.
    """
    # Resolve the metric once so an unsupported name fails with a clear message
    # instead of an unbound local variable further down.
    if dataset in ('School', 'Landmine', 'Parkinsons'):
        metric = 'euclidean'
    elif dataset == 'Chemical':
        metric = 'hamming'
    else:
        raise ValueError(f"Unsupported dataset for distance computation: {dataset!r}")

    n_samples = len(DataSet_Scaled)

    if n_samples <= 3000:
        # Exact: all pairwise distances.
        return np.mean(distance.pdist(DataSet_Scaled, metric=metric))

    # Large input: estimate from randomly sampled pairs.
    idx = np.random.choice(n_samples, size=(max_pairs, 2), replace=True)
    left = DataSet_Scaled[idx[:, 0]]
    right = DataSet_Scaled[idx[:, 1]]
    if metric == 'euclidean':
        d = np.linalg.norm(left - right, axis=1)
    else:
        d = np.mean(left != right, axis=1)
    return np.mean(d)


def get_task_information():
    """Collect per-task statistics and write them to a CSV file.

    Reads the module-level globals ``dataset``, ``DataPath`` and ``TASKS``.
    Works in two passes over ``TASKS``:

    1. every task file is loaded and pooled, rows with NaN are dropped, and a
       ``QuantileTransformer(n_quantiles=30)`` is fitted on the pooled values
       (target column included);
    2. every task file is loaded again to record its row count, the population
       variance / standard deviation (``ddof=0``) of its target column, and
       the mean within-task distance of its scaled rows.

    The table is written to
    ``{DataPath}Task_Information_w_Distance_{dataset}.csv``.

    NOTE(review): for 'Chemical' the scaler in pass 1 is fitted on frames in
    which negative '181' values are NOT zeroed, while pass 2 zeroes them
    before transforming -- confirm whether this difference is intended.

    Returns:
        list: number of rows of each task, in ``TASKS`` order.
    """
    dist_type = 'Hamming' if dataset == 'Chemical' else 'Euclidean'

    school_columns = ['1985', '1986', '1987',
                      'ESWI', 'African', 'Arab', 'Bangladeshi', 'Caribbean', 'Greek', 'Indian', 'Pakistani',
                      'SE_Asian', 'Turkish', 'Other',
                      'VR_Band', 'Gender',
                      'FSM', 'VR_BAND_Student', 'School_Gender', 'Maintained', 'Church', 'Roman_Cath',
                      'ExamScore']

    # ---- Pass 1: pool all tasks and fit the scaler ----
    pooled_frames = []
    for task_id in TASKS:
        if dataset == 'School':
            frame = pd.read_csv(f"{DataPath}/{task_id}_School_Data.csv", low_memory=False)
            frame = frame[school_columns]
        elif dataset == 'Landmine':
            frame = pd.read_csv(f"{DataPath}LandmineData_{task_id}.csv", low_memory=False)
        elif dataset == 'Chemical':
            frame = pd.read_csv(f"{DataPath}{task_id}_Molecule_Data.csv", low_memory=False)
        pooled_frames.append(frame)

    X_All = pd.concat(pooled_frames).dropna().values
    print(f'Shape of X is {X_All.shape}')
    quantile_scaler = preprocessing.QuantileTransformer(n_quantiles=30)
    quantile_scaler.fit(X_All)

    # ---- Pass 2: per-task statistics ----
    task_names = []
    sizes = []
    variances = []
    std_devs = []
    mean_distances = []
    for task_id in TASKS:
        if dataset == 'School':
            frame = pd.read_csv(f"{DataPath}/{task_id}_School_Data.csv", low_memory=False)
            frame = frame[school_columns]
            target = frame['ExamScore']
        elif dataset == 'Landmine':
            frame = pd.read_csv(f"{DataPath}LandmineData_{task_id}.csv", low_memory=False)
            target = frame['Labels']
        elif dataset == 'Chemical':
            frame = pd.read_csv(f"{DataPath}{task_id}_Molecule_Data.csv", low_memory=False)
            frame.loc[frame['181'] < 0, '181'] = 0
            target = frame['181']

        variances.append(target.var(ddof=0))
        std_devs.append(target.std(ddof=0))
        task_names.append(task_id)
        sizes.append(len(frame))

        scaled_rows = quantile_scaler.transform(frame.values)
        mean_distances.append(compute_interTask_distance_fast(scaled_rows, dataset))

    print(f'dataset: {dataset}, avg distance: {np.mean(mean_distances):0.3f} ± std dev: {np.std(mean_distances):0.3f}')

    summary = pd.DataFrame({'Task_Name': task_names,
                            'Dataset_Size': sizes,
                            'Variance': variances,
                            'Std_Dev': std_devs,
                            f'Average_{dist_type}_Distance_within_Task': mean_distances})
    summary.to_csv(f'{DataPath}Task_Information_w_Distance_{dataset}.csv', index=False)
    return sizes






# Task identifiers per dataset, used by the driver loop to select which tasks
# take part in scaler fitting for the pairwise features.
TASKS_DICT = {
    'School': list(range(1, 140)),
    'Landmine': list(range(29)),
    'Chemical': [
        2, 5, 6, 9, 10, 12, 18, 20, 22, 24, 25, 27, 28, 30, 46, 52, 55,
        57, 59, 61, 67, 70, 76, 78, 80, 81, 83, 84, 85, 86, 87, 89, 90, 91, 92,
    ],
    'Parkinsons': list(range(1, 43)),
}

# Driver: for every dataset, build the per-task statistics table, fit the
# scaler used for pairwise features, load single-task and pair training
# results, and write the pairwise feature table.
#
# The loop rebinds the module-level names `dataset`, `DataPath`, `TASKS` and
# `standard_scaler`, which the functions above read as globals, so the order of
# the statements below matters.

# AVG selects averaged result files (True) or the files of a single run (False).
AVG = False
for dataset in ['School', 'Chemical', 'Landmine',]:

    DataPath = f"../Dataset/{dataset.upper()}/"
    # Task list for get_task_information(): taken from the dataset's
    # Task_Information CSV (column 'Molecule' for Chemical, 'Task_Name' otherwise).
    if dataset == 'School':
        Task_InfoData = pd.read_csv(f'{DataPath}Task_Information_{dataset}.csv', low_memory=False)
        TASKS = list(Task_InfoData['Task_Name'])
    if dataset == 'Chemical':
        ChemicalData = pd.read_csv(f'{DataPath}Task_Information_{dataset}.csv', low_memory=False)
        TASKS = list(ChemicalData['Molecule'])
    if dataset == 'Landmine':
        Task_InfoData = pd.read_csv(f'{DataPath}Task_Information_{dataset}.csv', low_memory=False)
        TASKS = list(Task_InfoData['Task_Name'])

    print(f'dataset = {dataset} : TASKS = {TASKS}')

    # Writes Task_Information_w_Distance_<dataset>.csv and returns the task sizes.
    length = get_task_information()
    print(f'dataset: {dataset}, maxlen: {max(length)}, minlen: {min(length)}, avg {np.mean(length):0.3f}  ± std dev: {np.std(length):0.3f}\n')


    # Re-read the table that get_task_information() has just written.
    task_info_df = pd.read_csv(f'{DataPath}Task_Information_w_Distance_{dataset}.csv')
    # From here on TASKS is replaced by the hard-coded list in TASKS_DICT, and
    # the global scaler is fitted on exactly those tasks.
    TASKS = TASKS_DICT[dataset]
    standard_scaler = normalize_based_on_all_tasks(TASKS)

    # Per-task lookups keyed by task name.
    task_len = dict(zip(task_info_df['Task_Name'], task_info_df['Dataset_Size']))
    variance_dict = dict(zip(task_info_df['Task_Name'], task_info_df['Variance']))
    std_dev_dict = dict(zip(task_info_df['Task_Name'], task_info_df['Std_Dev']))

    if AVG:
        single_results = pd.read_csv(f'../RESULTS/{dataset}_FIXED_STL_Avg.csv')
        pair_results = pd.read_csv(f'../RESULTS/{dataset}_FIXED_PTL_Avg.csv')
    else:
        RUN = 1
        if dataset == 'Chemical':
            datapath = '../mtl_training/chem_results/'
            ARCH = 'Arch_1'
        if dataset == 'School':
            datapath = '../mtl_training/sch_results/'
            ARCH = 'Arch_1'
        if dataset == 'Landmine':
            datapath = '../mtl_training/landmine_results/'
            ARCH = 'Arch_1'
        # NOTE(review): with ARCH = 'Arch_1' the file names below end in
        # '_SGD_Arch_Arch_1.csv' -- confirm this matches the files on disk.
        single_results = pd.read_csv(f'{datapath}{dataset}_FIXED_STL_run_{RUN}_SGD_Arch_{ARCH}.csv')
        pair_results = pd.read_csv(f'{datapath}{dataset}_FIXED_pairs_run_{RUN}_SGD_Arch_{ARCH}.csv')

    # Single-task loss per task, keyed by the 'TASKS' column of the STL results.
    Single_res_dict = dict(zip(single_results['TASKS'], single_results['Total_Loss']))  # or however you store STL loss

    features_df = compute_pairwise_task_features(task_info_df, pair_results, dataset, task_len, variance_dict, std_dev_dict, Single_res_dict)

    features_df.to_csv(f'{DataPath}Pairwise_Task_Features_{dataset}_FIXED.csv', index=False)
