import torch
import numpy as np
import pandas as pd
import random
import ast
import tqdm
# Print every DataFrame column instead of truncating wide frames.
pd.set_option('display.max_columns', None)

# Seed every RNG the pipeline may touch so repeated runs are reproducible.
SEED = 2025
np.random.seed(SEED)
random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)


# Benchmark datasets; the slice selects which ones this script processes
# (currently only 'Chemical').
datasets = ['School', 'Landmine', 'Chemical'][2:]

# Task ids for each dataset. These ids are the values matched against the
# 'TASKS' column of the STL results and against the entries of 'Task_group'.
TASKS_DICT = {
    'School': list(range(1, 140)),
    'Landmine': list(range(0, 29)),
    'Chemical': [2, 5, 6, 9, 10, 12, 18, 20, 22, 24, 25, 27, 28, 30, 46, 52, 55,
                 57, 59, 61, 67, 70, 76, 78, 80, 81, 83, 84, 85, 86, 87, 89, 90, 91, 92],
    'Parkinsons': list(range(1, 43)),
}




def _normalize_score_keys(task_score):
    """Return a copy of *task_score* with string keys converted to int task ids.

    A string key is split on '_' and its second field is parsed as the task id
    (e.g. 'task_5' -> 5). Non-string keys are kept unchanged.
    """
    normalized = {}
    for key, value in task_score.items():
        if isinstance(key, str):
            key = int(key.split('_')[1])
        normalized[key] = value
    return normalized


def _relative_gains(task_list, mtl_scores, stl_losses):
    """Return (stl - mtl) / stl for each task in *task_list*, in order.

    A positive value means the multi-task loss is lower than the single-task
    baseline loss for that task.
    """
    return [(stl_losses[t] - mtl_scores[t]) / stl_losses[t] for t in task_list]


def _report_map_gain_mismatches(tasks_map, gain_collection):
    """Print rows whose membership mask and gain row disagree on non-zero slots.

    Returns the number of rows whose non-zero counts differ. A task whose
    relative gain is exactly 0.0 is indistinguishable from an absent task, so
    it is also reported here.
    """
    count = 0
    for idx, (map_row, gain_row) in enumerate(zip(tasks_map, gain_collection)):
        map_nz = list(np.nonzero(map_row)[0])
        gain_nz = list(np.nonzero(gain_row)[0])
        if len(map_nz) != len(gain_nz):
            count += 1
            print(f'idx, {idx}, lengths: {len(map_nz), len(gain_nz)}')
            print(map_row)
            print(gain_row)
            print(map_nz)
            print(gain_nz)
        elif not np.equal(map_nz, gain_nz).all():
            print('mismatch')
    return count


for dataset in datasets:
    print(f'\n******* Dataset: {dataset} *******')
    RUNS = [1, 2, 3]
    for run in RUNS:
        TASKS = TASKS_DICT[dataset]
        for group_type in ['TRAIN', 'TEST', 'ALL'][2:]:
            # Read each CSV once per iteration and reuse it below.
            groups = pd.read_csv(f'../RESULTS/GROUPS_MTL/{dataset}_GROUPS_run_{run}.csv')
            if group_type == 'ALL':
                new_groups = pd.read_csv(f'../RESULTS/GROUPS_MTL/{dataset}_FIXED_GroundTruth_NEW_run_{run}.csv')
                GROUPS_to_CONSIDER = list(groups['Task_group']) + list(new_groups['Task_group'])
            else:
                new_groups = None
                GROUPS_to_CONSIDER = list(groups['Task_group'])
            print(f'run {run}\tTotal {group_type} groups: {len(GROUPS_to_CONSIDER)}')
            # new_groups only exists for 'ALL'; printing it unconditionally
            # raised NameError for TRAIN/TEST.
            if new_groups is not None:
                print(new_groups.columns)

            if run == 'AVG':
                single_results = pd.read_csv(f'../RESULTS/{dataset}_FIXED_STL_Avg.csv')
            else:
                single_results = pd.read_csv(f'../RESULTS/{dataset}_FIXED_STL_run_{run}.csv')

            if group_type == 'ALL':
                Trained_Groups = pd.concat([groups, new_groups], ignore_index=True)
                # Drops rows identical in every column (not only Task_group);
                # exact task-set dedup happens later via np.unique. reset_index
                # keeps row positions contiguous after rows are dropped.
                Trained_Groups = Trained_Groups.drop_duplicates().reset_index(drop=True)
                print(len(Trained_Groups))
            else:
                Trained_Groups = groups.copy()

            # Column position of each task id in the per-group vectors.
            tasks_index = {task: pos for pos, task in enumerate(TASKS)}
            print(f'tasks_index = {tasks_index}')
            # STL baseline loss per task (first matching row of the STL results).
            stl_loss_dict = {
                task: single_results[single_results['TASKS'] == task]['Total_Loss'].values[0]
                for task in TASKS
            }

            # CSV cells hold Python literals; parse them and store each task
            # group as a sorted tuple so equal groups compare equal.
            Trained_Groups['Task_group'] = Trained_Groups['Task_group'].apply(
                lambda cell: tuple(sorted(ast.literal_eval(cell))))
            Trained_Groups['Individual_Task_Score'] = Trained_Groups['Individual_Task_Score'].apply(ast.literal_eval)

            n_tasks = len(TASKS)
            tasks_map = []        # 0/1 membership row per group
            gain_collection = []  # relative MTL-vs-STL gain per member task, 0 elsewhere
            loss_collection = []  # MTL loss per member task, NaN elsewhere
            measure_gain = []     # summed gain per group

            # Iterate positionally, which does not depend on the DataFrame index.
            rows = zip(Trained_Groups['Task_group'], Trained_Groups['Individual_Task_Score'])
            for task_group, raw_score in tqdm.tqdm(rows, total=len(Trained_Groups)):
                task_score = _normalize_score_keys(raw_score)
                gain_val = _relative_gains(task_group, task_score, stl_loss_dict)
                measure_gain.append(sum(gain_val))

                map_row = [0] * n_tasks
                gain_row = [0] * n_tasks
                loss_row = [np.nan] * n_tasks
                for task, gain in zip(task_group, gain_val):
                    pos = tasks_index[task]
                    map_row[pos] = 1
                    gain_row[pos] = gain
                    loss_row[pos] = task_score[task]
                tasks_map.append(map_row)
                gain_collection.append(gain_row)
                loss_collection.append(loss_row)

            # Sanity check: non-zero slots in tasks_map and gains should match.
            count = _report_map_gain_mismatches(tasks_map, gain_collection)
            tasks_map = np.array(tasks_map)
            gain_collection = np.array(gain_collection)
            print(f'count: {count}')
            print(f'shape: {tasks_map.shape}, {gain_collection.shape}')
            # Keep only unique task maps (first occurrence of each).
            print(f'Before deduplication: {tasks_map.shape}, {gain_collection.shape}')

            unique_task_maps, unique_indices = np.unique(tasks_map, axis=0, return_index=True)
            unique_gain_collection = gain_collection[unique_indices]

            print(f'After deduplication: {unique_task_maps.shape}, {unique_gain_collection.shape}')

            torch.save(unique_task_maps, f'PredData/RANDOMIZED_{dataset}_tasks_map_{group_type}_run_{run}_GroundTruth.pt')
            torch.save(unique_gain_collection, f'PredData/RANDOMIZED_{dataset}_gains_{group_type}_run_{run}_GroundTruth.pt')
