#!/usr/bin/env python
# coding: utf-8

# Importing python packages
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import os.path

import warnings
warnings.filterwarnings("ignore")

# Plotting functions
from plotting_functions import cumulative_regret_plotting
from plotting_functions import average_plotting
from plotting_functions import cumulative_regret_plotting_no_ylimit

# Environment
from environment import problem_instance_type1

# Online fair division algorithms
from learners import ofd_linear
from learners import ofd_uniform
from learners import ofd_neural
from learners import ofd_gp


# ### Comparing algorithms ###
def compare_algorithms(algorithm_parameters, utility_type, rho_value, R, save_regret_data):
    """Run (or load cached) experiments for a fixed set of learners and plot them.

    The compared learners are 'OFD-UCB', 'OFD-TS', 'OFD-GP-UCB' and 'OFD-GP-TS'.
    Results are cached in ``results/data/<file_name>.npz``; if that file already
    exists it is loaded and no learner is executed.

    Parameters
    ----------
    algorithm_parameters : sequence
        Problem instance forwarded unchanged to the learners. This function
        only reads ``algorithm_parameters[0]`` (its length and the length of
        its first element are used to derive the item/agent counts for the
        file name) and ``algorithm_parameters[2]`` (used as the dimension in
        the file name). NOTE: ``algorithm_parameters[0]`` is shuffled IN PLACE
        once per run, so the caller's object is modified.
    utility_type : str
        Passed to every learner as ``utility_function`` and embedded in the
        output file names.
    rho_value : float
        Only used to build the output file names here.
    R : int
        Number of independent runs; also passed to ``average_plotting``.
    save_regret_data : bool
        If true, freshly computed results are written to the ``.npz`` cache.

    Raises
    ------
    ValueError
        If a case name has no associated learner.
    """
    # Different algorithms
    cases = ['OFD-UCB', 'OFD-TS', 'OFD-GP-UCB', 'OFD-GP-TS']

    # File to save or read
    agents = len(algorithm_parameters[0][0])
    items = len(algorithm_parameters[0]) - agents
    dim = algorithm_parameters[2]
    file_name = f"gp_{utility_type}_{items}_{agents}_{dim}_{rho_value}_{R}"
    path_to_file = "results/data/{}.npz".format(file_name)

    if os.path.exists(path_to_file):
        # Loading existing regret data. The arrays are materialised inside the
        # ``with`` block so the underlying archive can be closed afterwards.
        with np.load(path_to_file) as load_data:
            algos_regret = load_data['algos_regret']
            algos_total_utility = load_data['algos_total_utility']
            algos_gini_coefficient = load_data['algos_gini_coefficient']
            algos_min_total_utility_ratio = load_data['algos_min_total_utility_ratio']
            algos_weights = load_data['algos_weights']

    else:
        algos_regret = []
        algos_total_utility = []
        algos_gini_coefficient = []
        algos_min_total_utility_ratio = []
        algos_weights = []

        # Use the R argument (not a module-level global) for the run count
        for _ in tqdm(range(R)):
            run_regret = []
            run_total_utility = []
            run_gini_coefficient = []
            run_min_total_utility_ratio = []
            run_weights = []

            # Shuffle the item-agents (in place; all cases of this run share
            # the same shuffled order)
            np.random.shuffle(algorithm_parameters[0])

            for case in cases:
                iter_regret, all_stats = _run_case(case, algorithm_parameters, utility_type)

                # all_stats is indexed positionally: 0 -> total utility,
                # 1 -> gini coefficient, 2 -> min/total utility ratio,
                # 3 -> weights (as named by the lists they are stored in)
                run_regret.append(iter_regret)
                run_total_utility.append(all_stats[0])
                run_gini_coefficient.append(all_stats[1])
                run_min_total_utility_ratio.append(all_stats[2])
                run_weights.append(all_stats[3])

            algos_regret.append(run_regret)
            algos_total_utility.append(run_total_utility)
            algos_gini_coefficient.append(run_gini_coefficient)
            algos_min_total_utility_ratio.append(run_min_total_utility_ratio)
            algos_weights.append(run_weights)

        # Save regret data
        if save_regret_data:
            # np.savez does not create missing parent directories
            os.makedirs(os.path.dirname(path_to_file), exist_ok=True)
            np.savez(
                path_to_file,
                algos_regret=algos_regret,
                algos_total_utility=algos_total_utility,
                algos_gini_coefficient=algos_gini_coefficient,
                algos_min_total_utility_ratio=algos_min_total_utility_ratio,
                algos_weights=algos_weights,
            )

    # NOTE: total utility and weights are collected (and saved) above but are
    # intentionally not plotted at the moment.

    # ### Plotting Regret ###
    file_to_save = "results/plots/{}_regret.png".format(file_name)
    cumulative_regret_plotting(algos_regret, cases, file_to_save, 'upper right', y_lim=300)

    # ### Plotting Gini Coefficient ###
    file_to_save = "results/plots/{}_gini_coefficient.png".format(file_name)
    y_axis = "Gini Coefficient"
    average_plotting(algos_gini_coefficient, cases, file_to_save, 'upper right', R, y_label=y_axis)

    # ### Plotting ratio between minimum utility and total utility ###
    file_to_save = "results/plots/{}_min_total_utility_ratio.png".format(file_name)
    y_axis = "Minimum Utility/Total Utility"
    average_plotting(algos_min_total_utility_ratio, cases, file_to_save, 'upper right', R, y_label=y_axis)


def _run_case(case, algorithm_parameters, utility_type):
    """Run the learner registered under *case* and return its ``(regret, stats)`` pair."""
    # The table is built inside the function so the learner names are only
    # resolved when a case is actually executed.
    learners = {
        'OFD-Uniform': (ofd_uniform, {}),
        'OFD-Greedy': (ofd_linear, {'strategy': 'greedy'}),
        'OFD-UCB': (ofd_linear, {'strategy': 'ucb'}),
        'OFD-TS': (ofd_linear, {'strategy': 'ts'}),
        'OFD-GP-Greedy': (ofd_gp, {'strategy': 'greedy'}),
        'OFD-GP-UCB': (ofd_gp, {'strategy': 'ucb'}),
        'OFD-GP-TS': (ofd_gp, {'strategy': 'ts'}),
    }
    try:
        learner, extra_kwargs = learners[case]
    except KeyError:
        raise ValueError(f"Unknown algorithm case: {case!r}") from None
    return learner(algorithm_parameters, utility_function=utility_type, **extra_kwargs)


# ########################### Bandit problem ###########################
# ### Problem Instance ###
# Experiment configuration
d = 2
items = 500
item_copies = 1
agents = 10
rho = 0.85  # 0: Max-min <= rho <= 1: Efficiency
runs = 20
utility_type = 'square'

# Fix the seed before the problem instance is generated
np.random.seed(0)
save_data = True
run_current_instance = False

# The instance is built exactly once, with the same arguments, in both modes
algorithm_parameters = problem_instance_type1(d, items, item_copies, agents, rho)

if run_current_instance:
    # Single experiment using the configured utility type
    compare_algorithms(algorithm_parameters, utility_type, rho, runs, save_data)

else:
    # Sweep over several utility types on the same problem instance
    problems = ['square', 'cosine', 'sine']  # 'linear', 'square', 'cosine', 'sine'
    supported_problems = ('linear', 'square', 'cosine', 'sine')
    for p in problems:
        print(f"Running for problem: {p}")

        # Only recognised problem types are run; anything else is skipped
        if p in supported_problems:
            compare_algorithms(algorithm_parameters, p, rho, runs, save_data)
        
        
        

