import sys
import pandas as pd
import ast
import numpy as np
import copy
import matplotlib.pyplot as plt
import seaborn as sns

def fix_result_col(res):
    """Rebuild ``Individual_Task_Score`` as dicts keyed by ``'School_<task>'``.

    Both ``Task_group`` and ``Individual_Task_Score`` are expected to hold
    Python-literal strings (they are parsed with ``ast.literal_eval``).  For
    each row the parsed score may be either

    * a list, in which case ``score[i]`` belongs to the i-th task of the
      row's task group, or
    * any other indexable object (e.g. a dict), which is indexed directly by
      the task value.

    Args:
        res: DataFrame with ``Task_group`` and ``Individual_Task_Score``
            columns.

    Returns:
        A new DataFrame in which ``Individual_Task_Score`` has been dropped
        and re-added as the last column, holding one
        ``{'School_<task>': score}`` dict per row.  ``Task_group`` is left as
        the original strings and the input frame is not modified.
    """
    task_scores = []
    # Walk the two columns positionally instead of looking values up by index
    # label: a label lookup returns a Series (and literal_eval then fails)
    # whenever the frame's index contains duplicate labels.
    for raw_group, raw_score in zip(res['Task_group'], res['Individual_Task_Score']):
        group = ast.literal_eval(raw_group)
        score = ast.literal_eval(raw_score)
        if isinstance(score, list):
            # Scores are stored in the same order as the tasks of the group.
            by_task = {f'School_{task}': score[pos] for pos, task in enumerate(group)}
        else:
            # Scores are addressed by the task value itself.
            by_task = {f'School_{task}': score[task] for task in group}
        task_scores.append(by_task)

    res = res.drop(columns=['Individual_Task_Score'])
    res['Individual_Task_Score'] = task_scores
    return res

# Datasets to process; the position in this list is also the index of the
# subplot that receives the dataset's histogram.
DATASET = ['School',  'Chemical','Landmine',]

# Task identifiers of every dataset.
TASKS_DICT = {
    'School': list(range(1, 140)),  # 1-based
    'Landmine': list(range(0, 29)), # 0-based
    'Chemical': [2, 5, 6, 9, 10, 12, 18, 20, 22, 24, 25, 27, 28, 30, 46, 52, 55,
                 57, 59, 61, 67, 70, 76, 78, 80, 81, 83, 84, 85, 86, 87, 89, 90, 91, 92],
}

# One vertically stacked histogram panel per dataset.  The panel count follows
# DATASET instead of being a separate hard-coded number.
# NOTE: with a single dataset plt.subplots returns a bare Axes rather than an
# array, so code indexing hist_ax[...] would need adjusting in that case.
histogram, hist_ax = plt.subplots(len(DATASET), 1)

# Run index, taken from the first command-line argument.  It is embedded in
# the names of the result files that are read and written.
if len(sys.argv) < 2:
    sys.exit(f'usage: {sys.argv[0]} RUN  (RUN: integer run index)')
try:
    RUN = int(sys.argv[1])
except ValueError:
    sys.exit(f'RUN must be an integer, got {sys.argv[1]!r}')
# Per-dataset settings:
#   (prefix of the per-task keys inside Individual_Task_Score,
#    directory holding the single-task result CSV,
#    architecture tag used in the result file names)
DATASET_CONFIG = {
    'Chemical': ('molecule', '../mtl_training/chem_results/', 'Arch_1'),
    'School': ('School', '../mtl_training/sch_results/', 'Arch_1'),
    'Landmine': ('Landmine', '../mtl_training/landmine_results/', 'Arch_1'),
}

# Percentiles at which the affinities are clipped before they are plotted.
# The histogram title is derived from these values so the label always
# describes what is actually drawn.
CLIP_PERCENTILES = (0.5, 99.5)

# For every dataset: load single-task (STL) and pairwise (MTL) results, build
# the pairwise affinity matrix (relative loss improvement of a task when it is
# trained together with a partner task), save it and draw its histogram.
for col_idx, dataset in enumerate(DATASET):
    loss_str, datapath, ARCH = DATASET_CONFIG[dataset]

    # NOTE(review): ARCH already starts with 'Arch_', so these file names
    # contain 'Arch_Arch_1'. Left unchanged; confirm it matches the files on disk.
    single_results = pd.read_csv(f'{datapath}{dataset}_FIXED_STL_run_{RUN}_SGD_Arch_{ARCH}.csv')
    pair_results = pd.read_csv(f'../RESULTS/GROUPS_MTL/{dataset}_FIXED_pairs_run_{RUN}_SGD_Arch_{ARCH}.csv')

    if dataset == 'School':
        # School scores need re-keying to 'School_<task>' dicts first.
        single_results = fix_result_col(single_results)
        pair_results = fix_result_col(pair_results)

    # Each STL row is identified by the first entry of its task group.
    single_results['Task_group'] = single_results['Task_group'].apply(ast.literal_eval)
    single_results['TASKS'] = [int(group[0]) for group in single_results['Task_group']]
    single_results = single_results.sort_values(by='TASKS', ascending=True)
    single_results.to_csv(f'../RESULTS/{dataset}_FIXED_STL_run_{RUN}.csv', index=False)
    pair_results.to_csv(f'../RESULTS/{dataset}_FIXED_PTL_run_{RUN}.csv', index=False)

    stl_tasks = list(single_results['TASKS'])
    print(f'TASKS = {stl_tasks}')

    # The matrix layout follows the fixed task list, not the tasks found in the CSV.
    TASKS = TASKS_DICT[dataset]
    task_index = {task: idx for idx, task in enumerate(TASKS)}

    # STL loss per task; the first row wins if a task occurs more than once.
    # Values stay NumPy scalars so a zero loss yields inf/nan instead of raising.
    stl_loss_by_task = {}
    for task_id, stl_loss in zip(single_results['TASKS'].values, single_results['Total_Loss'].values):
        stl_loss_by_task.setdefault(int(task_id), stl_loss)

    Pairs = [ast.literal_eval(p) for p in pair_results.Task_group]

    # Never populated in this script; kept so the summary print below is unchanged.
    Pairwise_accuracy_dict_indiv = {}
    # Maps a task pair (as a tuple) to its {'<loss_str>_<task>': loss} dict.
    Pairwise_loss_dict_indiv = {}
    for pair, raw_score in zip(Pairs, pair_results.Individual_Task_Score):
        # Scores are still strings unless fix_result_col already built dicts.
        if isinstance(raw_score, str):
            raw_score = ast.literal_eval(raw_score)
        Pairwise_loss_dict_indiv[tuple(pair)] = raw_score

    print(len(Pairwise_loss_dict_indiv), len(Pairwise_accuracy_dict_indiv))

    '''convert into a matrix'''
    pairwise_affinity_matrix = np.zeros((len(TASKS), len(TASKS)))
    # NOTE(review): pairwise_loss_matrix[i][j] stores the loss of task i when
    # paired with j, i.e. the opposite orientation to the affinity matrix
    # below, and it is not read anywhere in this script.
    pairwise_loss_matrix = np.zeros((len(TASKS), len(TASKS)))
    Pairwise_affinity_dict = {task: {task: 0 for task in TASKS} for task in TASKS}
    cont_neg = 0
    cont_pos = 0

    for pair in Pairs:
        results_from_mtl = Pairwise_loss_dict_indiv[tuple(pair)]
        task_pair = tuple(int(task) for task in pair)
        first, second = task_pair[0], task_pair[1]
        row, col = task_index[first], task_index[second]

        pairwise_loss_matrix[row][col] = results_from_mtl[f'{loss_str}_{first}']
        pairwise_loss_matrix[col][row] = results_from_mtl[f'{loss_str}_{second}']

        gain_val = []
        for task in task_pair:
            indi_stl_loss = stl_loss_by_task[task]
            indi_mtl_loss = results_from_mtl[f'{loss_str}_{task}']
            # Relative loss reduction: -ve gain bad, +ve gain good.
            gain = (indi_stl_loss - indi_mtl_loss) / indi_stl_loss
            if gain < 0:
                cont_neg += 1
            elif gain > 0:
                cont_pos += 1
            gain_val.append(gain)

        # Entry [i][j] is the gain of task j when it is trained together with task i.
        pairwise_affinity_matrix[row][col] = gain_val[1]
        pairwise_affinity_matrix[col][row] = gain_val[0]

        Pairwise_affinity_dict[first][second] = gain_val[1]
        Pairwise_affinity_dict[second][first] = gain_val[0]

    print(len(pairwise_affinity_matrix))
    print(np.shape(pairwise_affinity_matrix))
    print(f'dataset {dataset}, cont_neg = {cont_neg}, cont_pos = {cont_pos}')

    print(f'max = {np.max(pairwise_affinity_matrix)}, min = {np.min(pairwise_affinity_matrix)}')

    # Sanity check: recompute the gain of the second task of every pair and
    # compare it with the matrix.  The same expression on the same operands is
    # evaluated, so exact equality is expected.
    for (s, t), gains in Pairwise_loss_dict_indiv.items():
        stl_j = stl_loss_by_task[int(t)]
        mtl_j = gains[f'{loss_str}_{t}']
        assert pairwise_affinity_matrix[task_index[int(s)], task_index[int(t)]] == (stl_j - mtl_j) / stl_j, \
            f'affinity mismatch for pair ({s}, {t})'

    task_num = len(TASKS)
    '''save the matrix'''
    # Flat positions of the diagonal (a task paired with itself); these are
    # excluded from the statistics and the histogram.
    diagonal_indices = np.arange(task_num) * task_num + np.arange(task_num)
    filtered_pairwise_affinity = np.delete(pairwise_affinity_matrix.flatten(), diagonal_indices)

    print(f'shape of filtered_pairwise_affinity: {np.shape(filtered_pairwise_affinity)}')
    pairwise_affinity_matrix = pd.DataFrame(pairwise_affinity_matrix)

    pairwise_affinity_matrix.columns = TASKS
    pairwise_affinity_matrix.to_csv(f'../RESULTS/{dataset}_Pairwise_Affinity_run_{RUN}_SGD_FIXED.csv', index=False)

    '''plot histogram'''
    bin_number = 150 if dataset == 'School' else 100

    # Clip extreme values at the configured percentiles so that a few outliers
    # do not stretch the x-axis.
    lower, upper = np.percentile(filtered_pairwise_affinity, list(CLIP_PERCENTILES))
    clipped_affinity = np.clip(filtered_pairwise_affinity, lower, upper)

    # Plot the histogram using clipped values
    hist_ax[col_idx].hist(clipped_affinity,
                          bins=bin_number, alpha=0.5, color='skyblue',
                          edgecolor='black')

    hist_ax[col_idx].set_title(
        f'{dataset} Affinity Histogram ({CLIP_PERCENTILES[0]}th–{CLIP_PERCENTILES[1]}th pct)')
    counts, bin_edges = np.histogram(clipped_affinity, bins=bin_number)

    # Optional: Save histogram data to CSV
    hist_data_df = pd.DataFrame({
        'bin_start': bin_edges[:-1],
        'bin_end': bin_edges[1:],
        'count': counts
    })
    # hist_data_df.to_csv(f'{dataset}_pairwise_affinity_histogram_clipped.csv', index=False)

    '''check positive and negative affinity'''
    num_positive = np.sum(filtered_pairwise_affinity > 0)
    num_negative = np.sum(filtered_pairwise_affinity < 0)
    num_zero = np.sum(filtered_pairwise_affinity == 0)  # optional
    print(f'dataset {dataset}, pos_count = {num_positive}, neg_count = {num_negative}, zero = {num_zero}')

    # Only the bottom panel carries the shared x-axis label.
    if col_idx == len(DATASET) - 1:
        hist_ax[col_idx].set_xlabel('Affinity-Value')
    hist_ax[col_idx].set_ylabel('Frequency')

# plt.show() displays every open pyplot figure, including the histogram
# figure, so a separate Figure.show() call beforehand is unnecessary (and on
# non-GUI backends it would only emit a warning).
plt.show()