# !/usr/bin/env python
# coding: utf-8

# Importing python packages
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import os.path

import warnings
warnings.filterwarnings("ignore")

# Plotting functions
from plotting_functions import cumulative_regret_plotting
from plotting_functions import average_plotting
from plotting_functions import cumulative_regret_plotting_no_ylimit

# Environment
from environment import problem_instance_fixed

# Online fair division algorithms
from learners import ofd_linear
from learners import ofd_uniform

# Specific algorithms
from learners import ofd_efficient
from learners import ofd_fair


# ### Comparing algorithms ###
# Metric names, in the order they are stored in (and loaded from) the .npz cache.
_RESULT_KEYS = (
    'algos_regret',
    'algos_total_utility',
    'algos_gini_coefficient',
    'algos_min_total_utility_ratio',
    'algos_weights',
)

# Case label -> callable taking ``algorithm_parameters`` and returning
# ``(iter_regret, all_stats)``. Dict order is the order in which the cases
# are run and passed to the plotting helpers.
_CASE_RUNNERS = {
    'OFD-Uniform':  lambda params: ofd_uniform(params),
    'OFD-Greedy':   lambda params: ofd_linear(params, strategy='greedy'),
    'OFD-UCB':      lambda params: ofd_linear(params, strategy='ucb'),
    'OFD-TS':       lambda params: ofd_linear(params, strategy='ts'),
    'OFD-Eff-UCB':  lambda params: ofd_efficient(params, strategy='ucb'),
    'OFD-Eff-TS':   lambda params: ofd_efficient(params, strategy='ts'),
    'OFD-Fair-UCB': lambda params: ofd_fair(params, strategy='ucb'),
    'OFD-Fair-TS':  lambda params: ofd_fair(params, strategy='ts'),
}


def _run_experiments(algorithm_parameters, cases, R):
    """Run every case ``R`` times and collect the per-run metrics.

    Returns a dict keyed by ``_RESULT_KEYS``. Each value is a list with one
    entry per run, and each entry is a list with one element per case (in
    ``cases`` order).

    Note: ``algorithm_parameters[0]`` is shuffled in place before every run.
    An unknown case label raises ``KeyError``.
    """
    results = {key: [] for key in _RESULT_KEYS}
    for _ in tqdm(range(R)):
        run = {key: [] for key in _RESULT_KEYS}

        # Shuffle the item-agents
        np.random.shuffle(algorithm_parameters[0])

        for case in cases:
            iter_regret, all_stats = _CASE_RUNNERS[case](algorithm_parameters)
            run['algos_regret'].append(iter_regret)
            # all_stats = (total utility, Gini coefficient,
            #              min/total utility ratio, weights)
            run['algos_total_utility'].append(all_stats[0])
            run['algos_gini_coefficient'].append(all_stats[1])
            run['algos_min_total_utility_ratio'].append(all_stats[2])
            run['algos_weights'].append(all_stats[3])

        for key in _RESULT_KEYS:
            results[key].append(run[key])
    return results


def _plot_results(results, cases, file_name, R, y_limit):
    """Call the plotting helpers with output paths under ``results/plots/``."""
    prefix = "results/plots/{}".format(file_name)
    regret = results['algos_regret']

    # ### Plotting Regret ###
    cumulative_regret_plotting_no_ylimit(regret, cases, prefix + "_regret.png", 'upper right')
    cumulative_regret_plotting(regret, cases, prefix + "_regret_ylimit.png", 'upper right', y_lim=y_limit)

    # (metric key, file-name suffix, y-axis label), plotted in this order.
    averaged_metrics = (
        ('algos_total_utility', "_total_utility.png", "Total Utility"),
        ('algos_gini_coefficient', "_gini_coefficient.png", "Gini Coefficient"),
        ('algos_min_total_utility_ratio', "_min_total_utility_ratio.png", "Minimum Utility/Total Utility"),
        ('algos_weights', "_weight_diff.png", r"$||w - w^*||_2$"),
    )
    for key, suffix, y_axis in averaged_metrics:
        average_plotting(results[key], cases, prefix + suffix, 'upper right', R, y_label=y_axis)


def compare_algorithms(algorithm_parameters, rho_value, R, save_regret_data, y_limit=100):
    """Compare every OFD algorithm variant on one problem instance and plot the results.

    Results are cached in ``results/data/<file_name>.npz``. When that file
    already exists, the stored arrays are loaded instead of re-running the
    experiments.

    Args:
        algorithm_parameters: Problem instance forwarded unchanged to every
            learner. Only two entries are read here.
            ``[0]`` is shuffled in place before each run; its length and
            ``len([0][0])`` give the item/agent counts used in the file name.
            ``[2]`` is used as ``dim`` in the file name.
            NOTE(review): presumably produced by ``problem_instance_fixed`` — confirm.
        rho_value: Only used to build the cache/plot file name here.
        R: Number of independent runs; also part of the file name.
        save_regret_data: If true, freshly computed results are written to
            the cache file.
        y_limit: Forwarded as ``y_lim`` to ``cumulative_regret_plotting``.
    """
    cases = list(_CASE_RUNNERS)

    # File to save or read
    agents = len(algorithm_parameters[0][0])
    items = len(algorithm_parameters[0]) - agents
    dim = algorithm_parameters[2]
    file_name = f"items_{items}_{agents}_{dim}_{rho_value}_{R}"
    path_to_file = "results/data/{}.npz".format(file_name)

    if os.path.exists(path_to_file):
        # Loading existing regret data; the context manager closes the .npz file.
        with np.load(path_to_file) as load_data:
            results = {key: load_data[key] for key in _RESULT_KEYS}
    else:
        # Use R (not the module-level ``runs``) so the number of runs matches
        # the R encoded in the cache file name.
        results = _run_experiments(algorithm_parameters, cases, R)

        # Save regret data
        if save_regret_data:
            # np.savez does not create missing parent directories.
            os.makedirs(os.path.dirname(path_to_file), exist_ok=True)
            np.savez(path_to_file, **results)

    _plot_results(results, cases, file_name, R, y_limit)


# ########################### Bandit problem ###########################
# ### Problem Instance ###
d           = 3       # forwarded to problem_instance_fixed (presumably the feature dimension)
items       = 10000
item_copies = 1
agents      = 10
rho         = 0.85    # 0: Max-min <= rho <= 1: Efficiency
runs        = 20      # number of independent runs handed to compare_algorithms
y_limit_val = 200     # regret-plot y-limit, used only in single-instance mode

# Seed numpy's global RNG once, before any problem instance is generated.
np.random.seed(0)
save_data            = True  # write freshly computed results to the results/data cache
run_current_instance = True  # True: run the instance above; False: sweep named difficulty levels

if run_current_instance:
    algorithm_parameters = problem_instance_fixed(d, items, item_copies, agents, rho)
    compare_algorithms(algorithm_parameters, rho, runs, save_data, y_limit_val)
else:
    # Value of `d` for each difficulty level; dict order fixes the sweep order.
    problem_dims = {"easy": 2, "medium": 3, "hard": 5, "harder": 10}
    problems = list(problem_dims)

    for p in problems:
        print(F"Running for problem: {p}")

        # Build the instance for this difficulty level
        d = problem_dims[p]
        algorithm_parameters = problem_instance_fixed(d, items, item_copies, agents, rho)

        # No y_limit is passed here, so compare_algorithms uses its default.
        compare_algorithms(algorithm_parameters, rho, runs, save_data)

