#!/usr/bin/python

import sys,os
import numpy as np
import pandas as pd
import random

import shapleyValues as shapley

def update_weights(weights: np.ndarray, improvedObjs: np.ndarray, impairedObjs: np.ndarray, update_factor: float = 2) -> np.ndarray:
    """Build one adjusted copy of a weight vector per (improve, impair) index pair.

    Row ``i`` of the result is a copy of ``weights`` in which the weight at
    ``improvedObjs[i]`` is multiplied by ``update_factor`` and the weight at
    ``impairedObjs[i]`` is divided by it. Either entry may be an empty list
    (``[]``) to skip that half of the update for row ``i``. If both entries name
    the same objective, the two updates cancel out.

    Args:
        weights (np.ndarray): The current objective weights (1-D). Boolean or integer
            weights are promoted to float so that fractional updates are possible.
        improvedObjs (array_like): The indices of the objective weights that should be
            increased, one entry per output row.
        impairedObjs (array_like): The indices of the objective weights that should be
            reduced, one entry per output row. Must be the same length as improvedObjs.
        update_factor (float): Updated objective weights will be multiplied with / divided
            by this value. Should be > 1. Default: 2.

    Returns:
        np.ndarray: A ``(len(improvedObjs), len(weights))`` matrix; row ``i`` is the
        original weight vector with the updates for pair ``i`` applied. ``weights``
        itself is not modified.

    Raises:
        ValueError: If improvedObjs and impairedObjs have different lengths.
    """
    # Explicit check instead of `assert`, which is stripped when Python runs with -O.
    if len(improvedObjs) != len(impairedObjs):
        raise ValueError(
            "improvedObjs and impairedObjs must have the same length "
            f"(got {len(improvedObjs)} and {len(impairedObjs)})"
        )
    weights = np.asarray(weights)
    if weights.dtype.kind in "biu":
        # In-place `/=` (and `*=` with a non-integer factor) raises on bool/int arrays.
        weights = weights.astype(float)
    # np.tile copies, so the caller's array is never mutated.
    adjusted = np.tile(weights, (len(improvedObjs), 1))
    for row, (improve_idx, impair_idx) in enumerate(zip(improvedObjs, impairedObjs)):
        adjusted[row, improve_idx] *= update_factor
        adjusted[row, impair_idx] /= update_factor
    return adjusted

def normalized(a, ord=2, axis=-1):
    """Scale an array so that every slice along ``axis`` has unit norm.

    Slices whose norm is exactly zero are returned unchanged rather than
    divided by zero.

    Args:
        a (array_like): Input array to be normalized.
        ord (int, optional): Order of the norm, passed to ``np.linalg.norm``. Defaults to 2 (Euclidean norm).
        axis (int, optional): Axis along which the norm is taken. Defaults to -1.

    Returns:
        np.ndarray: ``a`` divided by its per-slice norm.
    """
    norms = np.linalg.norm(a, ord=ord, axis=axis, keepdims=True)
    # Substitute 1 for zero norms so all-zero slices pass through untouched.
    safe_norms = np.where(norms == 0, 1, norms)
    return a / safe_norms

def _random_power_of_ten_weights(nr_objs):
    """Return ``nr_objs`` weights, each an independently drawn power of ten from 10**0 to 10**4."""
    exponents = [random.randrange(0, 5) for _ in range(nr_objs)]
    return np.array([10.0 ** e for e in exponents])

# Policies for the starting weight vector; each maps the number of objectives to a 1-D array.
ALL_INITIAL_WEIGHT_POLICIES = {
    "ones": lambda nr_objs: np.ones(nr_objs),
    "uniform": lambda nr_objs: np.ones(nr_objs) / nr_objs,
    "random_exp": _random_power_of_ten_weights,
    "random_exp_norm": lambda nr_objs: normalized(_random_power_of_ten_weights(nr_objs)),
}

# Human-readable labels of the weight update strategies; a strategy's list position is its numeric id.
ALL_UPDATE_STRATEGIES = [
    "A: Strengthen target, impair rival",
    "B: Strengthen target",
    "C: Strengthen target, impair random other",
    "D: Impair rival",
    "E: Impair random other",
]

# Labels of the statistics collected for every strategy, in the order they are recorded.
STATISTIC_LABELS = [
    "# improving",
    "# worsening",
    "# equal best",
    "# equal best if 0",
    "total absolute improvement",
]

# One row of statistics per (run, instance), in the column order of the MultiIndex built at the end.
combined_statistic_lists = []

np.set_printoptions(formatter={'float': '{: 0.3f}'.format})

# Experiment configuration
runs = 10
instances = range(1, 34)
initial_weights = ["ones", "random_exp"]
update_factors = [1.5, 2, 10]
# NOTE(review): `- 1` leaves out the last strategy ("E: Impair random other") even though it is
# implemented below -- confirm this exclusion is intentional.
strategies = range(len(ALL_UPDATE_STRATEGIES) - 1)


def index_of_pair(improveIdx, impairIdx, nr_objs):
    """Return the row of the pairwise weight-update matrix that improves `improveIdx` and impairs `impairIdx`.

    Pairwise updates are laid out improve-index-major (see `improve_idx` / `impair_idx` in the
    loop below), so that row is ``nr_objs * improveIdx + impairIdx``. Works element-wise on
    index arrays.
    """
    return nr_objs * improveIdx + impairIdx


def _random_others(targets, rivals, nr_objs):
    """For each (target, rival) pair, pick a random objective that is neither of the two.

    `random.choice` raises IndexError when no such objective exists (e.g. nr_objs <= 2).
    """
    all_objs = set(range(nr_objs))
    return np.array([random.choice(list(all_objs - {target, rival})) for target, rival in zip(targets, rivals)])


for run in range(runs):
    for instanceNr in instances:
        print(f"Instance {instanceNr}, run {run}")
        csv_in = os.path.join("pareto-fronts", f"instance{instanceNr}.csv")

        df = pd.read_csv(csv_in)
        df = df.drop('Hard conflicts', axis=1)

        objs = list(df.columns)
        nr_objs = len(objs)

        pareto_front = df.to_numpy()

        targets = np.arange(nr_objs)
        # Empty index lists leave the corresponding weight untouched in update_weights.
        no_change = [[]] * nr_objs
        # All (improve, impair) pairs, improve-index-major, matching index_of_pair().
        improve_idx = np.repeat(np.arange(nr_objs), nr_objs)
        impair_idx = np.tile(np.arange(nr_objs), nr_objs)

        stats = []

        for iw in initial_weights:

            initial_weight_set = ALL_INITIAL_WEIGHT_POLICIES[iw](nr_objs)

            original_best = shapley.optimizer(initial_weight_set, pareto_front)

            shapley_values = shapley.shapley_values(pareto_front, initial_weight_set)

            # Rival of target k: the column with the largest Shapley value in row k, excluding
            # column k itself (worst inhibiting objective, unless the target is itself the worst,
            # then the second-worst) -- corresponds to the rival in R-XIMO.
            shapley_nondiag = np.array(shapley_values, dtype=float)
            np.fill_diagonal(shapley_nondiag, -np.inf)  # np.NINF was removed in NumPy 2.0
            rivals = np.argmax(shapley_nondiag, axis=1)

            for update_factor in update_factors:

                all_weight_updates_incr = update_weights(initial_weight_set, targets, no_change, update_factor)
                all_weight_updates_decr = update_weights(initial_weight_set, no_change, targets, update_factor)
                all_weight_updates_pairs = update_weights(initial_weight_set, improve_idx, impair_idx, update_factor)

                adjusted_best_incr = shapley.optimizer(all_weight_updates_incr, pareto_front)
                adjusted_best_decr = shapley.optimizer(all_weight_updates_decr, pareto_front)
                adjusted_best_pairs = shapley.optimizer(all_weight_updates_pairs, pareto_front)

                # Per objective: the smallest change reachable by any single or pairwise update.
                best_possible_obj_diff_single = np.minimum(np.min(adjusted_best_incr, axis=0), np.min(adjusted_best_decr, axis=0)) - original_best
                best_possible_obj_diff_pair = np.min(adjusted_best_pairs, axis=0) - original_best
                best_possible_obj_diff = np.minimum(best_possible_obj_diff_pair, best_possible_obj_diff_single)

                for strategy in strategies:

                    if strategy == 0:  # A: strengthen target, impair rival
                        adjusted_best = adjusted_best_pairs[index_of_pair(targets, rivals, nr_objs)]
                    elif strategy == 1:  # B: strengthen target
                        adjusted_best = adjusted_best_incr[targets]
                    elif strategy == 2:  # C: strengthen target, impair random other
                        others = _random_others(targets, rivals, nr_objs)
                        adjusted_best = adjusted_best_pairs[index_of_pair(targets, others, nr_objs)]
                    elif strategy == 3:  # D: impair rival
                        adjusted_best = adjusted_best_decr[rivals]
                    elif strategy == 4:  # E: impair random other
                        others = _random_others(targets, rivals, nr_objs)
                        adjusted_best = adjusted_best_decr[others]
                    else:
                        raise NotImplementedError(f"Strategy {strategy} not implemented")

                    # Entry k: change of objective k under the update chosen for target k.
                    obj_diff = np.diagonal(adjusted_best) - original_best

                    nr_improved = (obj_diff < 0).sum()
                    nr_worsened = (obj_diff > 0).sum()
                    nr_equal_best_improvement = (obj_diff == best_possible_obj_diff).sum()
                    nr_equal_best_improvement_of_0 = ((obj_diff == best_possible_obj_diff) & (obj_diff == 0)).sum()
                    total_improvement_absolute = obj_diff.sum()
                    # Total relative improvement is ill-defined due to objectives with value 0

                    stats += [nr_improved, nr_worsened, nr_equal_best_improvement, nr_equal_best_improvement_of_0, total_improvement_absolute]

        combined_statistic_lists.append(stats)

# Build the result table once after all runs (it used to be rebuilt every run and was
# undefined when runs == 0).
idx = pd.MultiIndex.from_product([initial_weights, update_factors, np.take(ALL_UPDATE_STRATEGIES, strategies), STATISTIC_LABELS])
combined_statistics = pd.DataFrame(combined_statistic_lists, columns=idx)

print(combined_statistics)

combined_statistics.to_csv("results.csv")
