import pickle
import sys
import numpy as np
import copy
import torch
import os
import pandas as pd


def approx_search_fixed(iteration, rest_gain, rest_mask, get_gain, pred_cost,
                        sample_num, search_num, ground_truth_idx, ground_truth_mask_idx):
    """Greedily select candidate groups and score the selection on ground truth.

    Each row of ``rest_gain`` is one candidate group, holding a (predicted) value per
    column; the same row of ``rest_mask`` identifies that group. The search keeps
    ``search_num`` independent greedy beams:

    1. Seeding: each beam starts from one of the ``search_num`` rows with the largest
       row sum (ties go to the lower row index).
    2. ``iteration - 1`` growth rounds: each beam adds the row whose positive
       column-wise improvement over the beam's current per-column maximum, summed,
       is largest -- only if that improvement is strictly positive.
    3. The beam with the largest sum of per-column maxima wins (first on ties).

    Each selected row's mask is located in ``ground_truth_mask_idx`` (first exact
    match). The score sums, over columns, the ground-truth value of the selected group
    whose *predicted* value is largest in that column.

    Args:
        iteration: number of seeding + growth rounds; must be >= 1.
        rest_gain: 2-D array-like of predicted values, one row per candidate.
        rest_mask: 2-D array-like, row-aligned with ``rest_gain``.
        get_gain, pred_cost, sample_num: unused; kept for call compatibility.
        search_num: number of beams; capped at the number of candidate rows.
        ground_truth_idx: 2-D array-like of ground-truth values, row-aligned with
            ``ground_truth_mask_idx``.
        ground_truth_mask_idx: 2-D array-like of masks to match selected rows against.

    Returns:
        ``(total_gain, selected_gr)``: the score as a float and the ground-truth rows
        of the selected groups, in selection order.

    Raises:
        ValueError: if ``iteration`` < 1, there is no candidate or ``search_num`` < 1,
            the mask widths differ, or a selected mask has no exact match.
    """
    rest_gain = np.asarray(rest_gain)
    rest_mask = np.asarray(rest_mask)
    ground_truth_idx = np.asarray(ground_truth_idx)
    ground_truth_mask_idx = np.asarray(ground_truth_mask_idx)

    if iteration < 1:
        raise ValueError(f"iteration must be >= 1, got {iteration}")
    num_beams = min(search_num, rest_gain.shape[0])
    if num_beams < 1:
        raise ValueError("need at least one candidate row and search_num >= 1")

    # Seeding: rows with the largest totals. A stable sort keeps the lowest index first
    # among ties, which is what repeated argmax-and-delete produced.
    row_sums = np.sum(rest_gain, axis=1)
    seed_rows = np.argsort(-row_sums, kind='stable')[:num_beams]
    beams = [[int(row)] for row in seed_rows]

    # Growth: every beam adds the candidate that raises its per-column maximum the
    # most; a beam that cannot improve is left unchanged this round.
    for _ in range(iteration - 1):
        for rows in beams:
            best_gain = np.max(rest_gain[rows], axis=0)
            gain_sums = np.sum(np.clip(rest_gain - best_gain, 0, None), axis=1)
            idx = int(np.argmax(gain_sums))
            if gain_sums[idx] > 0:
                rows.append(idx)

    # Final selection: the beam with the largest sum of per-column maxima.
    beam_scores = [np.sum(np.max(rest_gain[rows], axis=0)) for rows in beams]
    best_rows = beams[int(np.argmax(beam_scores))]
    select_list = rest_gain[best_rows]
    select_mask = rest_mask[best_rows]

    if ground_truth_mask_idx.ndim != 2 or ground_truth_mask_idx.shape[1] != select_mask.shape[1]:
        raise ValueError("ground_truth_mask_idx must be 2-D with the same width as rest_mask")

    # Exactly one ground-truth row per selected row (the first identical mask) keeps
    # selected_gr row-aligned with select_list; duplicate masks must not add rows.
    matched_rows = []
    for row_mask in select_mask:
        hits = np.flatnonzero(np.all(ground_truth_mask_idx == row_mask, axis=1))
        if hits.size == 0:
            raise ValueError("selected group mask has no exact match in ground_truth_mask_idx")
        matched_rows.append(hits[0])
    selected_gr = ground_truth_idx[matched_rows]

    # Per column: ground-truth value of the selected group with the highest prediction.
    best_per_column = np.argmax(select_list, axis=0)
    total_gain = np.sum(selected_gr[best_per_column, np.arange(select_list.shape[1])])

    return float(total_gain), selected_gr


def approx_optimal_fixed(sample_num, search_num, final_predictions, ground_truth,
                         ground_truth_mask, dataset, origtesty, origtestx, mask=None):
    """Run ``approx_search_fixed`` for every selection budget and collect the results.

    Two modes, chosen by the input:

    * Per-sample (``final_predictions`` is 3-D, or ``mask`` is not None): sample ``i``
      of ``final_predictions`` is searched with ``ground_truth_mask[i]`` as its mask
      and ``ground_truth[i]`` as its ground truth. Budgets run from 1 to the sample's
      column count, halved for ``dataset == 'School'``. Returns two lists of lists
      (gains, selections), one inner list per sample.
    * Pooled (otherwise): ``final_predictions`` is searched with ``origtestx`` as the
      mask and ``origtesty`` as the ground truth, for budgets 1..column count.
      Returns two flat lists (gains, selections).

    Entries whose mask is 0 get cost -999 before searching. ``mask`` is only used to
    choose the mode.
    NOTE(review): with ``mask`` given but 2-D predictions, each sample is 1-D and
    ``.shape[1]`` raises IndexError -- confirm the intended use of ``mask``.
    """
    final_predictions = np.array(final_predictions)
    ground_truth = np.array(ground_truth)
    ground_truth_mask = np.array(ground_truth_mask)

    def run_budgets(cost, cost_mask, budget_count, gt_gain, gt_mask):
        # One search per budget 1..budget_count; each search starts from scratch.
        budget_gains = []
        budget_selections = []
        for budget in range(1, budget_count + 1):
            gain, chosen = approx_search_fixed(budget, cost, cost_mask, cost, cost,
                                               sample_num, search_num,
                                               gt_gain, gt_mask)
            budget_gains.append(gain)
            budget_selections.append(chosen)
        return budget_gains, budget_selections

    per_sample = final_predictions.ndim == 3 or mask is not None
    if not per_sample:
        pooled_cost = np.copy(final_predictions)
        pooled_cost[origtestx == 0] = -999
        return run_budgets(pooled_cost, origtestx, pooled_cost.shape[1],
                           origtesty, origtestx)

    trajectories, selections = [], []
    for sample_idx, sample_predictions in enumerate(final_predictions):
        cost = np.copy(sample_predictions)
        cost_mask = ground_truth_mask[sample_idx]
        cost[cost_mask == 0] = -999
        budget_count = cost.shape[1]
        if dataset == 'School':
            budget_count = budget_count // 2
        sample_gains, sample_selections = run_budgets(cost, cost_mask, budget_count,
                                                      ground_truth[sample_idx],
                                                      ground_truth_mask[sample_idx])
        trajectories.append(sample_gains)
        selections.append(sample_selections)
    return trajectories, selections


if __name__ == '__main__':
    # Run the greedy group selection for every model / dataset / seed and pickle the
    # gain trajectory returned by approx_optimal_fixed.
    # NOTE(review): torch.load unpickles its input; only point it at trusted,
    # locally generated files.
    mtg_data_path = 'PredData/'
    pred_dir = 'PredData/InitialPred/'
    SEEDS = [1, 2, 3]
    search_num = 10

    for ModelName in ['Quadratic', 'ITA', 'GRADTAE', 'LR']:
        # LR uses the ground-truth gains as its "predictions" and is saved under the
        # Oracle name/directory.
        boosting_model = 'Oracle' if ModelName == 'LR' else ''
        if boosting_model == 'Oracle':
            Groups_Path = f'RESULTS/GROUP_SELECTION/{boosting_model}'
            file_tag = 'ORACLE'
        else:
            Groups_Path = f'RESULTS/GROUP_SELECTION/{ModelName}'
            file_tag = ModelName
        # Create the directory the pickles are actually written to. Previously only
        # RESULTS/GROUP_SELECTION/<ModelName>/ was created, so saving LR's results
        # into .../Oracle failed with FileNotFoundError on a fresh checkout.
        os.makedirs(Groups_Path, exist_ok=True)

        for dataset in ['School', 'Chemical', 'Landmine']:
            for seed in SEEDS:
                testx = torch.load(f'{mtg_data_path}RANDOMIZED_{dataset}_tasks_map_ALL_run_{seed}_GroundTruth.pt')
                testy = torch.load(f'{mtg_data_path}RANDOMIZED_{dataset}_gains_ALL_run_{seed}_GroundTruth.pt')
                testx, testy = torch.FloatTensor(testx), torch.FloatTensor(testy)

                if boosting_model == 'Oracle':
                    pred_data_all_GT = testy
                else:
                    pred_data_all_GT = torch.load(
                        f'{pred_dir}{ModelName}_RANDOMIZED_{dataset}_Predicted_Gains_ALL_run_{seed}_GroundTruth.pt')
                    print(f'shape of pred_data_all_GT: {pred_data_all_GT.size()}')

                tasks_map_all_GT = torch.load(
                    f'{pred_dir}{ModelName}_RANDOMIZED_{dataset}_tasks_map_ALL_run_{seed}_GroundTruth.pt')

                final_predictions = np.array(pred_data_all_GT)
                task_map = np.array(tasks_map_all_GT)
                ground_truth = np.array(testy)

                print(f'final_predictions shape = {np.shape(final_predictions)}')
                print(f'task_map shape = {np.shape(task_map)}')
                print(f'ground_truth shape = {np.shape(ground_truth)}')

                # The loaded task map is used as the ground-truth mask.
                ground_truth_mask = task_map.copy()
                sample_num = len(ground_truth)

                total_gain_traj, total_select = approx_optimal_fixed(
                    sample_num, search_num, final_predictions, ground_truth,
                    ground_truth_mask, dataset, testy, testx)
                print(f'np.shape(total_gain_traj) = {np.shape(total_gain_traj)}')

                # Save the gain trajectory returned by approx_optimal_fixed.
                saved_gains_filename = (f'{Groups_Path}/{dataset}_{file_tag}'
                                        f'_GROUPINGSELECTION_avg_gains_seed_{seed}.pkl')
                with open(saved_gains_filename, 'wb') as f:
                    pickle.dump(total_gain_traj, f)
                print(f'saved {saved_gains_filename}')

import os
import csv
import pickle
import numpy as np
import matplotlib.pyplot as plt

# Number of tasks per dataset; also the number of selection budgets plotted.
TASKS_DICT = {'School': 139,
              'Chemical': 35,
              'Landmine': 29}

out_dir = "RESULTS/GROUP_SELECTION/"
os.makedirs(out_dir, exist_ok=True)

# For each dataset: one subplot per seed with the normalized gain curve of every
# model, plus one CSV per (dataset, seed) holding the plotted points.
for dataset, num_tasks in TASKS_DICT.items():
    print(f'\nDATASET: {dataset}')
    Seeds = [1, 2, 3]
    # squeeze=False always yields a 2-D array of Axes, so one seed needs no special case.
    fig, axes = plt.subplots(1, len(Seeds), figsize=(20, 4), sharey=True, squeeze=False)

    for ax, seed in zip(axes[0], Seeds):
        print(f'*******MTL Training for Test Groups = {seed}')
        ax.set_title(f"Seed {seed}")
        ax.set_xlabel("Group-Selection Budget (Number of Task Groups Selected)")
        ax.grid(True)
        ax.set_ylabel("Normalized Gain")

        # Column name -> list of num_tasks values; becomes the CSV below.
        results = {"GroupCount": list(range(1, num_tasks + 1))}

        for ModelName in ['Oracle', 'GRADTAE', 'ITA', 'Quadratic']:   # skip LR
            if ModelName == 'Oracle':
                grouping_Selection = f'{out_dir}Oracle/{dataset}_ORACLE_GROUPINGSELECTION_avg_gains_seed_{seed}.pkl'
            else:
                grouping_Selection = f'{out_dir}{ModelName}/{dataset}_{ModelName}_GROUPINGSELECTION_avg_gains_seed_{seed}.pkl'

            with open(grouping_Selection, "rb") as fp:
                total_gain_traj = pickle.load(fp)
            # First num_tasks budgets, divided by the task count ("Normalized Gain").
            gains = np.array(total_gain_traj)[:num_tasks] / num_tasks
            results[ModelName] = gains.tolist()

            ax.plot(results["GroupCount"], gains, label=f'{ModelName[:3]}:{np.max(gains):.4f}', marker='+', linestyle='--')

        # Build the legend once, after every curve of this subplot is drawn.
        ax.legend(title="Predictions")

        # === SAVE COMBINED DATA POINTS TO ONE CSV ===
        csv_path = os.path.join(out_dir, f"{dataset}_seed{seed}.csv")
        with open(csv_path, "w", newline="") as f:
            writer = csv.DictWriter(f, fieldnames=list(results))
            writer.writeheader()
            # Transpose column lists into one dict per budget row.
            writer.writerows(dict(zip(results, row)) for row in zip(*results.values()))
        print(f"[SAVED] {csv_path}")

    fig.suptitle(f'Total Gain vs Number of Selected Groups — {dataset}', fontsize=16)
    fig.tight_layout()

# A single blocking show displays every dataset's figure.
plt.show()

# Report, for one chosen seed per dataset, the budget (GroupCount) at which each
# method's normalized gain peaks, together with that peak rounded to 4 decimals.
datasets = ['School', 'Chemical', 'Landmine']
seeds = [3, 1, 1]
# (label printed in the report, CSV column holding that method's curve)
REPORTED_METHODS = [('Oracle', 'Oracle'), ('ITA', 'ITA'),
                    ('GRADTAE', 'GRADTAE'), ('Ours', 'Quadratic')]

for dataset, seed in zip(datasets, seeds):
    print(f'Dataset: {dataset}')
    df = pd.read_csv(f"{out_dir}/{dataset}_seed{seed}.csv")
    # Read every column up front so a missing one fails before any method line prints.
    curves = {column: list(df[column]) for _, column in REPORTED_METHODS}
    for label, column in REPORTED_METHODS:
        curve = curves[column]
        peak = round(max(curve), 4)
        coordinates = (df["GroupCount"][np.argmax(curve)], peak)
        print(f'Method: {label}, best coordinates: {coordinates}')
