import pandas as pd
import numpy as np

def prepare_fifty_ita_matrix(TASKS, ITA_file):
    """Sum the per-epoch scores of an ITA log into a task-by-task matrix.

    Every non-blank line of ``ITA_file`` is expected to have the form
    ``<task_id>,[{<key>: <score>, ...}, {<key>: <score>, ...}, ...]`` where
    each ``{...}`` group is one epoch. Lines whose task id is not in
    ``TASKS`` are ignored, and blank lines are skipped.

    Scores are assigned to columns by their *position* inside the epoch
    group; the key in front of each ``:`` is not inspected. Scores from
    all epochs (and from repeated lines for the same task) are added up,
    not averaged.

    Args:
        TASKS: sequence of integer task ids; defines row and column order.
        ITA_file: path of the text log to parse.

    Returns:
        A ``pandas.DataFrame`` of shape ``(len(TASKS), len(TASKS))`` whose
        columns are labelled with ``TASKS``; row ``i`` holds the summed
        scores read from the line(s) of ``TASKS[i]``.

    Raises:
        ValueError: if a task id or a score cannot be parsed as a number.
        IndexError: if a non-blank line lacks the ``,[`` separator, an
            entry lacks ``:``, or an epoch lists more entries than there
            are tasks.
    """
    task_count = len(TASKS)
    gain_matrix = np.zeros((task_count, task_count))

    # Row lookup built once; setdefault keeps the first position of a task
    # id, which is what TASKS.index() would return.
    row_of_task = {}
    for position, task in enumerate(TASKS):
        row_of_task.setdefault(task, position)

    with open(ITA_file, 'r') as f:
        for raw_line in f:
            stripped = raw_line.strip()
            if not stripped:
                # A blank or trailing empty line carries no data.
                continue

            fields = stripped.split(',[')
            task_idx = row_of_task.get(int(fields[0]))
            if task_idx is None:
                continue

            for epoch_text in fields[1].split('}, {'):
                for task_no, entry in enumerate(epoch_text.split(',')):
                    # The last entry of the line still carries the closing '}]'.
                    score = float(entry.replace('}]', '').split(':')[1])
                    gain_matrix[task_idx][task_no] += score

    return pd.DataFrame(gain_matrix, columns=TASKS)


def _off_diagonal_correlation(first_matrix, second_matrix, task_num):
    """Return the Pearson correlation between the off-diagonal entries of two arrays.

    Both arguments are NumPy arrays that are flattened before the entries at
    the diagonal positions of a ``task_num`` x ``task_num`` layout are dropped.
    """
    # In a row-major flattened square matrix, element (i, i) sits at i * n + i.
    diagonal_indices = np.arange(task_num) * task_num + np.arange(task_num)
    first_values = np.delete(first_matrix.flatten(), diagonal_indices)
    second_values = np.delete(second_matrix.flatten(), diagonal_indices)
    return np.corrcoef(first_values, second_values)[0][1]


if __name__ == '__main__':

    DATASETS = ['School', 'Chemical', 'Landmine']
    # Directory holding the training results of each dataset.
    RESULT_DIRS = {
        'Chemical': '../mtl_training/chem_results/',
        'School': '../mtl_training/sch_results/',
        'Landmine': '../mtl_training/landmine_results/',
    }
    Arch = '_Arch_Arch_1'
    Tot_run = 6
    Method = 'ITA'
    start_idx = 0

    for dataset in DATASETS:
        datapath = RESULT_DIRS[dataset]

        stl_res = pd.read_csv(f'../RESULTS/{dataset}_FIXED_STL_Avg.csv')
        TASKS = [int(t) for t in stl_res['TASKS']]  # ensure integer task IDs

        print(f'\n\n**********************dataset: {dataset}*******************************')
        print(f'Total tasks: {len(TASKS)}')

        # Step 1: rebuild the per-run ITA matrices from the raw gradient-metric logs.
        # Only runs 4..Tot_run are rebuilt here; step 2 also reads runs
        # start_idx..3, so their matrix files must already exist on disk.
        # NOTE(review): confirm that the differing run ranges are intended.
        for run in range(4, 1 + Tot_run):
            df_filename = f'{datapath}/ITA/gradient_metrics_ITA_run_{run}{Arch}_FIXED_ALL.csv'
            gain_matrix = np.array(prepare_fifty_ita_matrix(TASKS, df_filename))

            Avg_ITA_dict = {
                row_task: {
                    col_task: gain_matrix[idx][jdx]
                    for jdx, col_task in enumerate(TASKS)
                }
                for idx, row_task in enumerate(TASKS)
            }

            # from_dict turns the outer keys into columns, so the saved cell at
            # (row j, column i) holds gain_matrix[i][j].
            # NOTE(review): confirm this orientation matches the pairwise-affinity files.
            run_matrix = pd.DataFrame.from_dict(Avg_ITA_dict)
            run_matrix.columns = TASKS
            run_matrix.to_csv(f'{datapath}/ITA/{Method}_matrix_run_{run}_FIXED.csv', index=False)

        # Step 2: correlate every run's matrix with the pairwise affinities and
        # accumulate the matrices and run times for averaging.
        Run_times = []
        avg_corr = []
        avg_matrix = None
        for run in range(start_idx, 1 + Tot_run):
            method_file = pd.read_csv(f'{datapath}/ITA/{Method}_matrix_run_{run}_FIXED.csv')

            run_time_file = f'{datapath}/{dataset}_{Method}_time_run_{run}_SGD{Arch}_FIXED_ALL.txt'
            with open(run_time_file, 'r') as f:
                # The run time is stored on the first line of the file.
                Run_times.append(float(f.readline().strip()))

            # Running sum over all runs; divided by the run count after the loop.
            avg_matrix = method_file if avg_matrix is None else avg_matrix + method_file

            pairwise_affinities = pd.read_csv(
                f'../RESULTS/{dataset}_Pairwise_Affinity_run_{run}_SGD_FIXED.csv'
            ).to_numpy()
            corr = _off_diagonal_correlation(pairwise_affinities, method_file.values, len(TASKS))
            print(
                f'run = {run}, correlation between pairwise affinity and ({Method}) = {corr:.3f}')
            avg_corr.append(corr)

        print(f'Avg correlation: {np.mean(avg_corr):.4f} $\\pm$ {np.std(avg_corr):.2f}')

        # Divide by the number of runs actually summed (start_idx..Tot_run
        # inclusive), which is not necessarily Tot_run.
        avg_matrix = avg_matrix / len(Run_times)
        avg_matrix_df = pd.DataFrame(avg_matrix)
        avg_matrix_df.columns = TASKS
        avg_matrix_df.to_csv(f'{datapath}/ITA/{Method}_matrix_avg.csv', index=False)

        print(f'\n****Avg run time for {Method} = {np.mean(Run_times)}****\n')
        print(f'last 3: {np.mean(Run_times[-3:])}')
        print(f'all run times: {Run_times}')
