import numpy as np
import pandas as pd
import copy
import torch
import random
import tqdm
from sklearn.metrics import r2_score

# Seed every RNG the script may touch (NumPy, Python `random`, PyTorch CPU and
# CUDA) with one shared value so repeated runs are reproducible.
_SEED = 2025
for _seed_fn in (np.random.seed, random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seed_fn(_SEED)

def gen_task_combinations_new(group_type, run, *, dataset=None, suffix=None,
                              tasks=None, affinities=None, data_path='PredData/'):
    """Score every task group present in a saved ground-truth task map.

    Loads
    ``{data_path}RANDOMIZED_{dataset}_tasks_map_{group_type}_run_{run}_GroundTruth{suffix}.pt``
    (one row per group; an entry equal to 1.0 marks a member task, with column
    ``i`` corresponding to ``tasks[i]``). For every distinct group it sums the
    pairwise affinities each member receives from the other members.

    Args:
        group_type: Group-set tag used in the file name (e.g. ``'ALL'``).
        run: Run index used in the file name.
        dataset, suffix, tasks, affinities: Optional overrides. When omitted,
            the module-level globals ``dataset``, ``suffix``, ``TASKS`` and
            ``revised_integrals`` are used (the script's original behaviour).
            ``affinities[other][task]`` is the contribution of ``other`` to
            ``task``.
        data_path: Directory prefix (including trailing separator) of the map.

    Returns:
        dict mapping ``'|'.join(sorted(member_names))`` to ``{task: score}``, in
        order of first appearance in the map. For multi-task groups
        ``score = sum(affinities[other][task] for other != task)``; single-task
        groups get the sentinel ``-1e8``.
    """
    module_globals = globals()
    if dataset is None:
        dataset = module_globals['dataset']
    if suffix is None:
        suffix = module_globals['suffix']
    if tasks is None:
        tasks = module_globals['TASKS']
    if affinities is None:
        affinities = module_globals['revised_integrals']

    map_path = f'{data_path}RANDOMIZED_{dataset}_tasks_map_{group_type}_run_{run}_GroundTruth{suffix}.pt'
    # as_tensor accepts lists, ndarrays and tensors of any dtype; the legacy
    # torch.FloatTensor(tensor) constructor rejects non-float32 tensors.
    task_map = torch.as_tensor(torch.load(map_path), dtype=torch.float32)

    group_names = []
    for row in task_map:
        member_idx = torch.nonzero(row == 1.0).flatten().tolist()
        # Sort the names so the same set of tasks always produces the same key.
        group_names.append('|'.join(sorted(tasks[i] for i in member_idx)))
    print(f'len(combinations) = {len(group_names)}')

    n_duplicates = len(group_names) - len(set(group_names))
    if n_duplicates:
        # Duplicate rows collapse into a single dict entry, so tensors built
        # from the result will have fewer rows than the ground-truth map.
        print(f'WARNING: {n_duplicates} duplicate group rows in {map_path}')

    rtn = {}
    for group_name in group_names:
        if group_name in rtn:
            continue
        members = group_name.split('|')
        if len(members) == 1:
            # A single task has no partners; flag it with a sentinel score.
            rtn[group_name] = {members[0]: -1e8}
            continue
        rtn[group_name] = {
            task: sum((affinities[other][task] for other in members if other != task), 0.)
            for task in members
        }
    return rtn


# Datasets evaluated by the main loop, in processing order.
datasets = ['School', 'Chemical', 'Landmine']

# Task identifiers per dataset. The main loop stringifies them; their position
# in each list fixes the row/column order used for the affinity matrices and
# the per-group tensors. 'Parkinsons' is defined but not listed in `datasets`.
TASKS_DICT = {
    'School': list(range(1, 140)),
    'Landmine': list(range(29)),
    'Chemical': [
        2, 5, 6, 9, 10, 12, 18, 20, 22, 24, 25, 27, 28, 30, 46, 52, 55,
        57, 59, 61, 67, 70, 76, 78, 80, 81, 83, 84, 85, 86, 87, 89, 90, 91, 92,
    ],
    'Parkinsons': list(range(1, 43)),
}



# Per-dataset directories holding the ITA affinity matrices.
ITA_RESULT_DIRS = {
    'Chemical': '../mtl_training/chem_results/',
    'School': '../mtl_training/sch_results/',
    'Landmine': '../mtl_training/landmine_results/',
}


def _load_pairwise_preds(model_name, dataset_name, run_id):
    """Load the pairwise task-affinity matrix predicted by *model_name* for one run.

    Row ``i`` / column ``j`` are later paired with ``TASKS[i]`` / ``TASKS[j]``.
    For ITA and GRADTAE the diagonal (self-affinity) is zeroed.

    Raises:
        ValueError: unknown *model_name*, or no ITA directory for *dataset_name*.
    """
    if model_name in ('Quadratic', 'LR'):
        aff_run = 4  # fixed id of the pairwise-prediction run; identical for every run_id
        return np.load(f'../RESULTS/New_{model_name}_{dataset_name}_Predicted_Pairwise_Affinity_{aff_run}_2025.npy')

    if model_name == 'ITA':
        if dataset_name not in ITA_RESULT_DIRS:
            raise ValueError(f'No ITA results directory configured for dataset {dataset_name!r}')
        csv_path = f'{ITA_RESULT_DIRS[dataset_name]}/ITA/{model_name}_matrix_run_{run_id}_FIXED.csv'
    elif model_name == 'GRADTAE':
        grad_run = run_id + 1
        print(f'grad_run: {grad_run}')
        csv_path = f'../mtl_estimation/Results/Estimated_Task_Affinity_{dataset_name}_{grad_run}.csv'
    else:
        raise ValueError(f'Unknown model name: {model_name!r}')

    # NOTE(review): pd.read_csv consumes the first line as a header; this assumes the
    # CSVs were written with a header row and without an index column -- confirm.
    preds = np.array(pd.read_csv(csv_path))
    # Zero the diagonal (self-affinity).
    np.fill_diagonal(preds, 0.0)
    return preds


def _affinity_dict(task_names, preds):
    """Return ``{src: {dst: preds[i][j]}}`` where i/j are the positions of src/dst in *task_names*."""
    return {
        src: {dst: preds[i][j] for j, dst in enumerate(task_names)}
        for i, src in enumerate(task_names)
    }


def _average_over_partners(task_combinations):
    """Deep-copy *task_combinations*, dividing each multi-task group's scores by (group size - 1).

    Entries whose key has no ``'|'`` (single-task groups) are copied unchanged.
    """
    averaged = copy.deepcopy(task_combinations)
    for group, scores in averaged.items():
        if '|' in group:
            n_partners = len(group.split('|')) - 1
            for task in scores:
                scores[task] /= n_partners
    return averaged


def _groups_to_tensors(group_scores, task_names):
    """Encode ``{'a|b': {'a': s_a, 'b': s_b}, ...}`` as two (num_groups, num_tasks) float tensors.

    Returns ``(tasks_map, gains)``. ``tasks_map`` has 1 in each member task's
    column; ``gains`` holds that task's score there and 0 elsewhere.
    """
    tasks_index = {task: i for i, task in enumerate(task_names)}
    tasks_map, gain_collection = [], []
    for group_str, task_scores in group_scores.items():
        group = group_str.split('|')
        if len(group) < 3:
            print('group_str', group_str)
            print(group)
        task_vec = [0] * len(task_names)
        gain_vec = [0.0] * len(task_names)
        for task in group:
            task_idx = tasks_index[task]
            task_vec[task_idx] = 1
            gain_vec[task_idx] = task_scores[task]
        tasks_map.append(task_vec)
        gain_collection.append(gain_vec)
    return torch.FloatTensor(tasks_map), torch.FloatTensor(gain_collection)


def _score_predictions(true_gains, pred_gains):
    """Return ``(R², Pearson r)`` computed over entries where *true_gains* is non-zero.

    Raises:
        ValueError: if the two gain matrices differ in shape.
    """
    true_gains = torch.as_tensor(true_gains)
    if true_gains.shape != pred_gains.shape:
        raise ValueError("Mismatch in shape between true and predicted gains")
    mask = true_gains != 0
    y_true = true_gains[mask].numpy()
    y_pred = pred_gains[mask].numpy()
    return r2_score(y_true, y_pred), np.corrcoef(y_true, y_pred)[0, 1]


for dataset in datasets:
    print(f'\n******* Dataset: {dataset} *******')
    for ModelName in ['Quadratic', 'GRADTAE', 'ITA']:
        print(f'Model name: {ModelName}')
        TASKS = [str(i) for i in TASKS_DICT[dataset]]
        n_tasks = len(TASKS)

        AVG_CORR = []
        AVG_RSQ = []

        RUNS = [1, 2, 3]
        for run in RUNS:
            pairwise_preds = _load_pairwise_preds(ModelName, dataset, run)
            if pairwise_preds.shape != (n_tasks, n_tasks):
                # A mis-sized matrix would silently misalign task indices below.
                print(f'WARNING: {ModelName} affinity matrix has shape {pairwise_preds.shape}, '
                      f'expected {(n_tasks, n_tasks)}')

            # gen_task_combinations_new reads `dataset`, `suffix`, `TASKS` and
            # `revised_integrals` as module globals, so they stay module-level names.
            revised_integrals = _affinity_dict(TASKS, pairwise_preds)

            for group_type in ['ALL']:
                suffix = '_Special' if group_type == 'TEST' else ''
                task_combinations = gen_task_combinations_new(group_type=group_type, run=run)
                rtn = _average_over_partners(task_combinations)
                tasks_map_tensor, gain_tensor = _groups_to_tensors(rtn, TASKS)

                # Save
                output_dir = 'PredData/InitialPred/'
                torch.save(tasks_map_tensor, f'{output_dir}{ModelName}_RANDOMIZED_{dataset}_tasks_map_{group_type}_run_{run}_GroundTruth{suffix}.pt')
                torch.save(gain_tensor, f'{output_dir}{ModelName}_RANDOMIZED_{dataset}_Predicted_Gains_{group_type}_run_{run}_GroundTruth{suffix}.pt')

                if group_type == 'ALL':
                    data_dir = 'PredData/'
                    # map_location keeps the mask/index ops on CPU even if the file was saved from GPU.
                    true_task_gains = torch.load(
                        f'{data_dir}RANDOMIZED_{dataset}_gains_{group_type}_run_{run}_GroundTruth{suffix}.pt',
                        map_location='cpu',
                    )
                    test_rsq, test_corr = _score_predictions(true_task_gains, gain_tensor)

                    print(f"✅ dataset: {dataset},  R²: {test_rsq:.4f}, Correlation: {test_corr:.4f}")
                    AVG_CORR.append(test_corr)
                    AVG_RSQ.append(test_rsq)

        # `\\pm` yields the literal LaTeX `\pm` without an invalid-escape warning.
        print(f'✅{dataset.upper()} - AVG_RSQ: {np.mean(AVG_RSQ):0.4f} AVG_CORR: {np.mean(AVG_CORR):0.2f} $\\pm$ {np.std(AVG_CORR):0.2f}')
