
import numpy as np
from ortools.linear_solver import pywraplp
from itertools import chain, combinations as combs
from tqdm import trange
from typing import List, Set, Dict, Tuple, Optional
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import wandb
import seaborn as sns
import pandas as pd
import pickle

## Pytorch
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch.nn.functional as F

# ## From Training folder
# import sys
# sys.path.append('../')
# from TrainModels import models
# from TrainModels import loaders

class MLPPowerIndex(nn.Module):
    '''
    Feed-forward network used for the Banzhaf and Shapley models.

    Layout: ``n_layers - 1`` hidden blocks of Linear -> ReLU -> Dropout,
    then a final Linear layer followed by a softmax over dim 1, so every
    output row is a normalised payoff vector (sums to one).

    :param input_size: width of the input vector
    :param hidden_size: width of every hidden Linear layer
    :param output_size: width of the output (payoff) vector
    :param n_layers: number of Linear layers including the output layer
        (values below 1 behave like 1: only the output layer is built)
    :param drop_prob: dropout probability after every hidden ReLU
    '''
    def __init__(self, input_size, hidden_size, output_size, n_layers, drop_prob):
        super().__init__()
        n_hidden = max(n_layers - 1, 0)
        # Input width of each Linear layer, in construction order
        widths = [input_size] + [hidden_size] * n_hidden

        modules = []
        for fan_in in widths[:-1]:
            modules.extend((
                nn.Linear(fan_in, hidden_size),
                nn.ReLU(inplace=True),
                nn.Dropout(drop_prob),
            ))
        modules.append(nn.Linear(widths[-1], output_size))
        modules.append(nn.Softmax(dim=1))  # normalise payoffs
        self.layers = nn.Sequential(*modules)

    def forward(self, x):
        return self.layers(x)


def get_df_errors_var_powerindex(PARAMS, trained_models, testset_name):
    '''
    Get the df with errors for var-size power-index (Shapley / Banzhaf) models.

    Every model is evaluated on the same pickled test set, game by game
    (batch size 1). Only positions where the input ``x`` is non-zero are
    scored; for the MAE the predicted payoffs of those positions are
    renormalised to sum to one.

    :param PARAMS: Dictionary with data configurations; reads ``test_path_var``,
        ``sol_concept`` ('Shapley' or 'Banzhaf') and ``checkpoint_path``
    :param trained_models: dict mapping model name -> model config; each config
        needs ``seq_len`` plus whatever ``load_pretrained_model_PI`` reads
    :param testset_name: Name of the test dataset (stem of the pickle file)
    :return: pandas DataFrame with one row per (model, game) and columns
        ``MAE_payoffs``, ``RMSE_payoffs``, ``n_players``
    :raises NotImplementedError: if ``PARAMS['sol_concept']`` is not supported
    '''
    columns = ['MAE_payoffs', 'RMSE_payoffs', 'n_players']
    if not trained_models:
        # Nothing to evaluate: do not touch the test file at all
        return pd.DataFrame(columns=columns).astype({'MAE_payoffs': float})

    label_keys = {'Shapley': 'Y_shap', 'Banzhaf': 'Y_banz'}
    if PARAMS['sol_concept'] not in label_keys:
        raise NotImplementedError('Specify solution concept')
    y_key = label_keys[PARAMS['sol_concept']]

    # The test set is the same for every model, so load it only once
    with open(f'{PARAMS["test_path_var"]}{testset_name}.pickle', 'rb') as handle:
        test_dict = pickle.load(handle)
    test_set = PowerIndexDataset(test_dict['X'], test_dict[y_key])
    # shuffle=False keeps row i of the output aligned with game i of the test set
    test_loader = DataLoader(test_set, batch_size=1, shuffle=False)
    n_games = len(test_loader)

    frames = []
    for model_config in trained_models.values():
        model = load_pretrained_model_PI(config=model_config,
                                         seq_len=model_config['seq_len'],
                                         cp_path=PARAMS['checkpoint_path'])

        play_in_game = torch.zeros(n_games)
        MAE_payoffs = torch.zeros(n_games)
        RMSE_payoffs = torch.zeros(n_games)

        model.eval()
        with torch.no_grad():
            for game, (x, payoffs) in enumerate(test_loader):
                payoffs_hat = model(x)

                # Players are the positions with a non-zero input weight
                play_index = (x.squeeze() != 0).nonzero()
                play_in_game[game] = play_index.shape[0]

                payoffs_play = payoffs[:, play_index]
                payoffs_hat_play = payoffs_hat[:, play_index]
                # Redistribute the mass predicted for non-players over the players
                payoffs_hat_norm = payoffs_hat_play / payoffs_hat_play.sum()

                MAE_payoffs[game] = torch.abs(payoffs_play - payoffs_hat_norm).mean()
                # NOTE: RMSE uses the raw (not renormalised) predictions, unlike MAE
                RMSE_payoffs[game] = torch.sqrt(torch.pow(payoffs_play - payoffs_hat_play, 2).mean())

        frames.append(pd.DataFrame({'MAE_payoffs': MAE_payoffs.numpy(),
                                    'RMSE_payoffs': RMSE_payoffs.numpy(),
                                    'n_players': play_in_game.numpy()}))

    # DataFrame.append no longer exists in pandas >= 2.0; concatenate once instead
    df_errors_var = pd.concat(frames, ignore_index=True)
    for col in df_errors_var.columns:
        if 'MAE' in col:
            df_errors_var[col] = df_errors_var[col].astype(float)

    return df_errors_var


def get_df_errors_var_leastcore(PARAMS, trained_models, testset_name):
    '''
    Get the df with errors for var-size least-core models.

    Every model is evaluated on the same pickled test set, game by game
    (batch size 1). Each label row holds the payoffs followed by epsilon
    (last entry). Only positions where the input ``x`` is non-zero are scored;
    predicted payoffs of those positions are renormalised to sum to one.

    :param PARAMS: Dictionary with data configurations; reads ``test_path_var``
        and ``checkpoint_path``
    :param trained_models: dict mapping model name -> model config; each config
        needs ``n_players`` plus whatever ``load_pretrained_model`` reads
    :param testset_name: Name of the test dataset (stem of the pickle file)
    :return: pandas DataFrame with one row per (model, game) and columns
        ``MAE_payoffs``, ``MAE_eps``, ``n_players``, ``Eps_actual``,
        ``Eps_pred``, ``RMSE_payoffs``, ``RMSE_eps``, ``payoffs_to_nonplay``
    '''
    columns = ['MAE_payoffs', 'MAE_eps', 'n_players']
    if not trained_models:
        # Nothing to evaluate: do not touch the test file at all
        return pd.DataFrame(columns=columns).astype({'MAE_payoffs': float, 'MAE_eps': float})

    # The test set is the same for every model, so load it only once
    with open(f'{PARAMS["test_path_var"]}{testset_name}.pickle', 'rb') as handle:
        test_dict = pickle.load(handle)
    test_set = EvalPayoffsDataset_noC(test_dict['X'], test_dict['sol_stack_lc'])
    # Make sure to keep shuffle = False: row i of the output is game i
    test_loader = DataLoader(test_set, batch_size=1, shuffle=False)
    n_games = len(test_loader)

    frames = []
    for model_config in trained_models.values():
        model = load_pretrained_model(config=model_config,
                                      input_size=model_config['n_players'],
                                      cp_path=PARAMS['checkpoint_path'])

        eps_actual = torch.zeros(n_games)
        eps_pred = torch.zeros_like(eps_actual)
        play_in_game = torch.zeros_like(eps_actual)
        MAE_payoffs = torch.zeros_like(eps_actual)
        RMSE_payoffs = torch.zeros_like(eps_actual)
        MAE_eps = torch.zeros_like(eps_actual)
        RMSE_eps = torch.zeros_like(eps_actual)
        payoffs_given_to_nonplayers = torch.zeros_like(eps_actual)

        model.eval()
        with torch.no_grad():
            for game, (x, sol) in enumerate(test_loader):
                # Separate epsilon (last entry) and payoffs, get predictions
                payoffs, eps = sol[:, :-1], sol[:, -1]
                payoffs_hat, eps_hat = model(x)
                eps_actual[game], eps_pred[game] = eps, eps_hat

                # Players are the positions with a non-zero input weight
                x_flat = x.squeeze()
                play_index = (x_flat != 0).nonzero()
                non_play_index = (x_flat == 0).nonzero()
                play_in_game[game] = play_index.shape[0]

                # Redistribute payoffs given to non-players over the players
                payoffs_norm_actual = payoffs[:, play_index]
                payoffs_hat_play = payoffs_hat[:, play_index]
                payoffs_norm_pred = payoffs_hat_play / payoffs_hat_play.sum()

                # Compute errors
                MAE_payoffs[game] = torch.abs(payoffs_norm_actual - payoffs_norm_pred).mean()
                MAE_eps[game] = torch.absolute(eps - eps_hat)
                RMSE_payoffs[game] = torch.sqrt(torch.pow(payoffs_norm_actual - payoffs_norm_pred, 2).mean())
                RMSE_eps[game] = torch.sqrt(torch.pow(eps - eps_hat, 2))

                # Total payoff the model allocated to non-players
                payoffs_given_to_nonplayers[game] = payoffs_hat[:, non_play_index].sum().item()

                # NOTE: no feasibility check (is_feasible_payoff) here: this
                # dataset does not provide the minimal winning coalitions.

        frames.append(pd.DataFrame({'Eps_actual': eps_actual.numpy(),
                                    'Eps_pred': eps_pred.numpy(),
                                    'MAE_payoffs': MAE_payoffs.numpy(),
                                    'MAE_eps': MAE_eps.numpy(),
                                    'RMSE_payoffs': RMSE_payoffs.numpy(),
                                    'RMSE_eps': RMSE_eps.numpy(),
                                    'payoffs_to_nonplay': payoffs_given_to_nonplayers.numpy(),
                                    'n_players': play_in_game.numpy()}))

    # DataFrame.append no longer exists in pandas >= 2.0; concatenate once instead
    df_errors_var = pd.concat(frames, ignore_index=True)
    # Keep the historical column order: base columns first, then the rest
    df_errors_var = df_errors_var.reindex(
        columns=columns + [c for c in df_errors_var.columns if c not in columns])
    for col in df_errors_var.columns:
        if 'MAE' in col:
            df_errors_var[col] = df_errors_var[col].astype(float)

    return df_errors_var


def get_df_errors_fixed_powerindex(PARAMS, trained_models, testset_name):
    '''
    Get the df with errors for fixed-size power-index (Shapley / Banzhaf) models.

    Each model is evaluated on the pickled test set matching its number of
    players; errors are averaged over all players of a game.

    :param PARAMS: Dictionary with data configurations; reads ``test_path_fixed``,
        ``sol_concept`` ('Shapley' or 'Banzhaf') and ``checkpoint_path``
    :param trained_models: dict mapping model name -> model config; each config
        needs ``n_players`` plus whatever ``load_pretrained_model_PI`` reads
    :param testset_name: Name of the test dataset (suffix of the pickle file)
    :return: pandas DataFrame with one row per (model, game) and columns
        ``MAE_payoffs``, ``RMSE_payoffs``, ``n_players``
    :raises NotImplementedError: if ``PARAMS['sol_concept']`` is not supported
    '''
    columns = ['MAE_payoffs', 'RMSE_payoffs', 'n_players']
    if not trained_models:
        return pd.DataFrame(columns=columns).astype({'MAE_payoffs': float})

    label_keys = {'Shapley': 'Y_shap', 'Banzhaf': 'Y_banz'}
    if PARAMS['sol_concept'] not in label_keys:
        raise NotImplementedError('Specify solution concept')
    y_key = label_keys[PARAMS['sol_concept']]

    frames = []
    for model_config in trained_models.values():
        # Load the test set matching this model's (fixed) number of players
        n_players = model_config['n_players']
        with open(f'{PARAMS["test_path_fixed"]}{n_players}players_test_mpi_{testset_name}.pickle', 'rb') as handle:
            test_dict = pickle.load(handle)
        test_set = PowerIndexDataset(test_dict['X'], test_dict[y_key])
        test_loader = DataLoader(test_set, batch_size=1, shuffle=False)

        model = load_pretrained_model_PI(config=model_config,
                                         seq_len=n_players,
                                         cp_path=PARAMS['checkpoint_path'])

        payoffs_actual = torch.zeros((len(test_loader), n_players))
        payoffs_pred = torch.zeros_like(payoffs_actual)

        model.eval()
        with torch.no_grad():
            for game, (x, payoffs) in enumerate(test_loader):
                payoffs_actual[game, :] = payoffs
                payoffs_pred[game, :] = model(x)

        # Per-game MAE and RMSE over all players
        err = payoffs_actual - payoffs_pred
        MAE_payoffs = torch.abs(err).mean(dim=1)
        RMSE_payoffs = torch.sqrt(torch.pow(err, 2).mean(dim=1))

        frames.append(pd.DataFrame({'MAE_payoffs': MAE_payoffs.numpy(),
                                    'RMSE_payoffs': RMSE_payoffs.numpy(),
                                    'n_players': np.repeat(n_players, len(MAE_payoffs))}))

    # DataFrame.append no longer exists in pandas >= 2.0; concatenate once instead
    df_errors_fixed = pd.concat(frames, ignore_index=True)
    for col in df_errors_fixed.columns:
        if 'MAE' in col:
            df_errors_fixed[col] = df_errors_fixed[col].astype(float)

    return df_errors_fixed


def get_df_errors_fixed_leastcore(PARAMS, trained_models, testset_name):
    '''
    Get the df with errors for fixed-size least-core models.

    Each model is evaluated on the pickled test set matching its number of
    players. Each label row holds the payoffs followed by epsilon (last entry).

    :param PARAMS: Dictionary with data configurations; reads ``test_path_fixed``
        and ``checkpoint_path``
    :param trained_models: dict mapping model name -> model config; each config
        needs ``n_players`` plus whatever ``load_pretrained_model`` reads
    :param testset_name: Name of the test dataset (suffix of the pickle file)
    :return: pandas DataFrame with one row per (model, game) and columns
        ``MAE_payoffs``, ``MAE_eps``, ``RMSE_payoffs``, ``RMSE_eps``,
        ``Eps_actual``, ``Eps_pred``, ``is_feasible`` (currently always NaN:
        the feasibility check is not computed) and ``n_players``
    '''
    # Initialize column layout
    columns = ['MAE_payoffs', 'MAE_eps', 'RMSE_payoffs', 'RMSE_eps',
               'Eps_actual', 'Eps_pred', 'is_feasible', 'n_players']
    if not trained_models:
        return pd.DataFrame(columns=columns).astype({'MAE_payoffs': float, 'MAE_eps': float})

    frames = []
    for model_config in trained_models.values():
        # Load the test set matching this model's (fixed) number of players
        n_players = model_config['n_players']
        with open(f'{PARAMS["test_path_fixed"]}{n_players}players_test_mpi_{testset_name}.pickle', 'rb') as handle:
            test_dict = pickle.load(handle)
        test_set = EvalPayoffsDataset_noC(test_dict['X'], test_dict['sol_stack_lc'])
        test_loader = DataLoader(test_set, batch_size=1, shuffle=False)

        model = load_pretrained_model(config=model_config,
                                      input_size=n_players,
                                      cp_path=PARAMS['checkpoint_path'])

        n_games = len(test_loader)
        payoffs_actual = torch.zeros((n_games, n_players))
        payoffs_pred = torch.zeros_like(payoffs_actual)
        eps_actual = torch.zeros(n_games)
        eps_pred = torch.zeros_like(eps_actual)

        model.eval()
        with torch.no_grad():
            for game, (x, sol) in enumerate(test_loader):
                # Separate epsilon (last entry) and payoffs, get predictions
                payoffs, eps = sol[:, :-1], sol[:, -1]
                payoffs_hat, eps_hat = model(x)

                payoffs_actual[game, :], payoffs_pred[game, :] = payoffs, payoffs_hat
                eps_actual[game], eps_pred[game] = eps, eps_hat

                # NOTE: no feasibility check (is_feasible_payoff) here: this
                # dataset does not provide the minimal winning coalitions.

        # Per-game errors; payoff errors are averaged over the players
        payoff_err = payoffs_actual - payoffs_pred
        eps_err = eps_actual - eps_pred
        MAE_payoffs = torch.abs(payoff_err).mean(dim=1)

        frames.append(pd.DataFrame({'Eps_actual': eps_actual.numpy(),
                                    'Eps_pred': eps_pred.numpy(),
                                    'MAE_payoffs': MAE_payoffs.numpy(),
                                    'MAE_eps': torch.abs(eps_err).numpy(),
                                    'RMSE_payoffs': torch.sqrt(torch.pow(payoff_err, 2).mean(dim=1)).numpy(),
                                    'RMSE_eps': torch.sqrt(torch.pow(eps_err, 2)).numpy(),
                                    'n_players': np.repeat(n_players, len(MAE_payoffs))}))

    # DataFrame.append no longer exists in pandas >= 2.0; concatenate once instead.
    # reindex restores the column order and the (all-NaN) 'is_feasible' column.
    df_errors_fixed = pd.concat(frames, ignore_index=True).reindex(columns=columns)
    for col in df_errors_fixed.columns:
        if 'MAE' in col:
            df_errors_fixed[col] = df_errors_fixed[col].astype(float)

    return df_errors_fixed

def load_pretrained_model_PI(config: dict, seq_len: int, cp_path: str):
    '''
    Build an ``MLPPowerIndex`` and fill it with weights from a checkpoint.

    :param config: run description; reads ``config['config']['hidden_size']``,
        ``config['config']['n_layers']`` and ``config['sweep_name']``
    :param seq_len: either the number of players (fixed size) or the max seq
        length; used as both input and output width
    :param cp_path: prefix of the checkpoint path; the file read is
        ``cp_path + config['sweep_name'] + '.pth'`` (plain string
        concatenation, so a directory needs its trailing separator)
    :return model: model with weights after training, mapped to CPU
    '''
    hyper = config['config']
    net = MLPPowerIndex(
        input_size=seq_len,
        hidden_size=hyper['hidden_size'],
        n_layers=hyper['n_layers'],
        output_size=seq_len,
        drop_prob=0,  # value irrelevant for evaluation
    )

    weights_file = cp_path + config['sweep_name'] + '.pth'
    saved = torch.load(weights_file, map_location=torch.device('cpu'))
    net.load_state_dict(saved['model_state_dict'])
    return net

class MLPLeastcore(nn.Module):
    '''
    Two-hidden-layer MLP used for the least core.

    A shared ReLU trunk (lin1 -> dropout -> lin2) feeds two heads:
    a payoff head normalised with a softmax over the last dim, and a
    single-unit epsilon head squashed into (0, 1) by a sigmoid.

    :param input_size: width of the input vector
    :param hidden_size: width of both hidden layers
    :param out_size_payoffs: width of the payoff head
    :param drop_prob: dropout probability after the first hidden layer
    '''
    def __init__(self, input_size, hidden_size, out_size_payoffs, drop_prob):
        super().__init__()
        # Trunk (construction order kept so initialisation is reproducible)
        self.lin1 = nn.Linear(input_size, hidden_size)
        self.lin2 = nn.Linear(hidden_size, hidden_size)
        # Heads
        self.out_payoffs = nn.Linear(hidden_size, out_size_payoffs)
        self.out_eps = nn.Linear(hidden_size, 1)
        # Parameter-free modules
        self.softmax = nn.Softmax(dim=-1)
        self.sigmoid = nn.Sigmoid()
        self.dropout = nn.Dropout(p=drop_prob)

    def forward(self, x):
        hidden = self.dropout(F.relu(self.lin1(x)))
        hidden = F.relu(self.lin2(hidden))
        payoffs = self.softmax(self.out_payoffs(hidden))
        epsilon = self.sigmoid(self.out_eps(hidden))
        return payoffs, epsilon
 

def load_pretrained_model(config: dict, input_size: int, cp_path: str):
    '''
    Build an ``MLPLeastcore`` and fill it with weights from a checkpoint.

    :param config: run description; reads ``config['config']['hidden_size']``,
        ``config['n_players']`` (payoff head width) and ``config['sweep_name']``
    :param input_size: width of the model input
    :param cp_path: prefix of the checkpoint path; the file read is
        ``cp_path + config['sweep_name'] + '.pth'`` (plain string
        concatenation, so a directory needs its trailing separator)
    :return model: model with weights after training, mapped to CPU
    '''
    hidden = config['config']['hidden_size']
    net = MLPLeastcore(
        input_size=input_size,
        hidden_size=hidden,
        out_size_payoffs=config['n_players'],
        drop_prob=0,  # value irrelevant for evaluation
    )

    weights_file = cp_path + config['sweep_name'] + '.pth'
    saved = torch.load(weights_file, map_location=torch.device('cpu'))
    net.load_state_dict(saved['model_state_dict'])
    return net


class EvalPayoffsDataset_noC(Dataset):
    '''
    Dataset of (input, least-core label) rows without coalition data.

    Row ``index`` of ``X_stack`` and ``Y_stack`` form one item; both are
    returned as float32 tensors. The evaluation helpers in this module load
    it with batch_size=1.
    '''
    def __init__(self,
                 X_stack: torch.Tensor,
                 Y_stack: torch.Tensor,
                 ):
        self.X_stack, self.Y_stack = X_stack, Y_stack

    def __len__(self):
        n_rows = self.X_stack.shape[0]
        return n_rows

    def __getitem__(self, index):
        features = self.X_stack[index, :]
        target = self.Y_stack[index, :]
        return features.float(), target.float()
        





class EvalPayoffsDataset(Dataset):
    '''
    Dataset of (input, least-core label, coalition one-hot matrix) triples.

    ``cset_min_win[index]`` is a sequence of coalitions, each a sequence of
    player indices; it is turned into a float matrix with one row per
    coalition and one column per input feature, holding 1 for members.
    The number of rows differs between games, which is why the original
    note says this only works with batch size = 1.
    '''
    def __init__(self,
                 X_stack: torch.Tensor,
                 Y_stack: torch.Tensor,
                 cset_min_win: List
                 ):
        self.X_stack = X_stack
        self.Y_stack = Y_stack
        self.cset_min_win = cset_min_win

    def __len__(self):
        n_rows = self.X_stack.shape[0]
        return n_rows

    def __getitem__(self, index):
        coalitions = self.cset_min_win[index]
        n_cols = self.X_stack.shape[1]

        # One row per coalition, 1 in the columns of its members
        onehot = torch.zeros(size=(len(coalitions), n_cols))
        for row, members in enumerate(coalitions):
            onehot[row, members] = 1

        return (
            self.X_stack[index, :].float(),
            self.Y_stack[index, :].float(),
            onehot,
        )
        
class PowerIndexDataset(Dataset):
    ''' Dataset pairing game inputs with Shapley values or Banzhaf indices.

    X, Y in shape (n_samples, n_players); item ``index`` is the pair of
    rows ``(X[index], Y[index])`` cast to float32.
    '''
    def __init__(self, X, Y):
        self.X, self.Y = X, Y

    def __len__(self):
        ''' number of samples (rows) in the dataset '''
        n_samples = self.X.shape[0]
        return n_samples

    def __getitem__(self, index):
        ''' 1D float tensors of weights and respective payoffs '''
        weights = self.X[index, :]
        payoffs = self.Y[index, :]
        return weights.float(), payoffs.float()

def plot_example_preds(Y_actual: torch.Tensor, Y_preds: torch.Tensor, fname: str, n_samples: int = 8):
    ''' Plot a number of randomly chosen samples from the test set.

    NOTE(review): another function named ``plot_example_preds`` is defined
    later in this module and replaces this one at import time, so this
    version is unreachable under this name; rename or remove one of them.

    :param Y_actual: torch.Tensor with actual solutions (n_players, n_samples);
        the x ticks cover every row except the last one
    :param Y_preds: torch.Tensor with model predictions, same layout as Y_actual
    :param fname: filename to print in the title
    :param n_samples: number of distinct samples to display (at most the
        number of columns); the grid has 4 columns and at least 2 rows
    '''
    n_players = Y_actual.shape[0] - 1
    n_total = Y_actual.shape[1]

    # Sample distinct random columns; every column (including the last) is eligible
    rand_samples = np.random.choice(np.arange(n_total), size=n_samples, replace=False)

    # Grid grows with n_samples so more than 8 samples no longer overflow the axes
    n_cols = 4
    n_rows = max(2, -(-n_samples // n_cols))  # ceil division, never fewer than 2 rows
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(25, 5 * n_rows), sharex=False, sharey=True,
                            facecolor='w', edgecolor='k')
    fig.subplots_adjust(hspace=.4, wspace=.2)
    fig.suptitle(f'Number of samples from the test set | {fname}')
    axs = axs.ravel()
    for ax, s in zip(axs, rand_samples):
        MAE = (np.abs(Y_actual[:, s] - Y_preds[:, s])).mean().item()
        ax.set_title(f'Sample {s} | Sample MAE: {MAE:.3f}')
        ax.plot(Y_actual[:, s], color='darkgreen', marker='o', markersize=10, alpha=.6, linestyle='')
        ax.plot(Y_preds[:, s], color='b', marker='^', markersize=10, alpha=.6, linestyle='')
        ax.set_xticks(list(range(n_players)))

    # Make legend (explicit handles, so the plotted lines need no labels)
    legend_elements = [Line2D([0], [0], marker='o', color='w', markerfacecolor='darkgreen', lw=4, markersize=15, label='Actual'),
                       Line2D([0], [0], marker='^', color='w', markerfacecolor='b', lw=4, markersize=15, label='Predicted'),
                       ]

    fig.legend(handles=legend_elements, fontsize=14, loc='upper right')
    sns.despine(offset=10, trim=False)
    plt.show()


def is_feasible_payoff(coals_min_win: torch.Tensor, payoffs_hat: torch.Tensor, eps_hat: torch.Tensor, tol: int):
    '''
    Checks if a given payoff vector (imputation) is feasible.
    An imputation is in the [epsilon] core if and only if all its excesses are negative or zero.

    For each coalition the excess is ``(1 - eps_hat) - sum of payoffs of its
    members``, rounded to ``tol`` decimals. Coalitions are checked in order
    and the function stops at the first one with a positive excess.

    :param coals_min_win: the minimum set of winning coalitions in shape (1, num_coals, num_players),
        with 1 marking the members of each coalition
    :param payoffs_hat: the predicted payoffs
    :param eps_hat: the predicted epsilon (a single value)
    :param tol: number of decimals used when rounding the excess
    :return: False if the excess is positive for at least one coalition, True otherwise
    '''
    coalition_value = float(1 - eps_hat)

    def _excess(coal_idx):
        members = coals_min_win[:, coal_idx, :]
        # Sum the payoffs at member positions (non-members keep their own entry)
        received = torch.where(members == 1, payoffs_hat, members).sum().item()
        return round((coalition_value - received), tol)

    n_coals = coals_min_win.shape[1]
    return not any(_excess(k) > 0 for k in range(n_coals))


def plot_example_preds(num_play_in_game: int,
                       sols_actual: torch.Tensor,
                       sols_pred: torch.Tensor,
                       title: str, n_samples: int=8):
    '''
    Plots a number of randomly chosen example predictions from the dataset.

    :param num_play_in_game: Number of players in the game (number of x ticks)
    :param sols_actual: Labels (true payoff allocations), one game per column
    :param sols_pred: Predictions, same layout as sols_actual
    :param title: Title to display
    :param n_samples: Number of distinct games to plot (at most the number of
        columns); the grid has 4 columns and at least 2 rows
    '''
    n_total = sols_actual.shape[1]

    # Sample distinct random columns; every column (including the last) is eligible
    rand_samples = np.random.choice(np.arange(n_total), size=n_samples, replace=False)

    # Grid grows with n_samples so more than 8 games no longer overflow the axes
    n_cols = 4
    n_rows = max(2, -(-n_samples // n_cols))  # ceil division, never fewer than 2 rows
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(25, 5 * n_rows), sharex=False, sharey=True,
                            facecolor='w', edgecolor='k')
    fig.subplots_adjust(hspace=.4, wspace=.2)
    fig.suptitle(f'Example games in the in sample | {title}')
    axs = axs.ravel()
    for ax, s in zip(axs, rand_samples):
        ax.plot(sols_actual[:, s], color='darkgreen', marker='o', markersize=10, alpha=.6, linestyle='')
        ax.plot(sols_pred[:, s], color='b', marker='^', markersize=10, alpha=.6, linestyle='')
        ax.set_xticks(list(range(num_play_in_game)))
        ax.set_xlabel('Player')
        ax.set_ylabel('Payoff')

    # Make legend (explicit handles, so the plotted lines need no labels)
    legend_elements = [Line2D([0], [0], marker='o', color='w', markerfacecolor='darkgreen', lw=4, markersize=15, label='Actual'),
                       Line2D([0], [0], marker='^', color='w', markerfacecolor='b', lw=4, markersize=15, label='Predicted'),
                       ]

    fig.legend(handles=legend_elements, fontsize=14, loc='upper right')
    sns.despine(offset=10, trim=False)
    plt.show()


def plot_best_games_lc(MAEs, n_players, sols_actual, sols_pred, top_n = 9):
    '''
    Plot the ``top_n`` games with the lowest MAE, best prediction first.

    Each subplot shows the actual (green circles) and predicted (blue
    triangles) solution of one game. The x axis has ``n_players + 1`` ticks,
    labelled with the player indices and 'ɛ' for the last one.

    :param MAEs: per-game MAE, indexed like the columns of ``sols_actual``
    :param n_players: number of players
    :param sols_actual: true solutions, one game per column
    :param sols_pred: predicted solutions, one game per column
    :param top_n: number of games to draw (the grid holds 3x3 = 9 axes)
    '''
    ranking = np.argsort(MAEs)  # ascending MAE: best game first
    fig, axs = plt.subplots(3, 3, figsize=(20, 12), sharex=False, sharey=True, facecolor='w', edgecolor='k')
    fig.subplots_adjust(hspace=.4, wspace=.2)
    axs = axs.ravel()

    tick_positions = list(range(n_players + 1))
    tick_labels = [str(p) for p in range(n_players)] + ['ɛ']
    point_style = dict(markersize=9, alpha=.45, linestyle='', label='medium')

    for rank in range(top_n):
        game = ranking[rank]
        ax = axs[rank]
        ax.set_title(f'Game MAE: {MAEs[game]:.3f}')
        ax.plot(sols_actual[:, game], color='darkgreen', marker='o', **point_style)
        ax.plot(sols_pred[:, game], color='b', marker='^', **point_style)
        ax.set_xticks(tick_positions)
        ax.set_xticklabels(tick_labels)
        ax.set_xlabel('Player')
        ax.set_ylabel('Payoff')

    handles = [
        Line2D([0], [0], marker=mark, color='w', markerfacecolor=colour, lw=4, markersize=12, label=text)
        for mark, colour, text in (('o', 'darkgreen', 'Actual'), ('^', 'b', 'Predicted'))
    ]
    fig.legend(handles=handles, fontsize=12, bbox_to_anchor=(0.9, 1))
    plt.show()

def plot_worst_games_lc(MAEs, n_players, sols_actual, sols_pred, top_n = 9):
    '''
    Plot the ``top_n`` games with the highest MAE, worst prediction first.

    Each subplot shows the actual (green circles) and predicted (blue
    triangles) solution of one game. The x axis has ``n_players + 1`` ticks,
    labelled with the player indices and 'ɛ' for the last one.

    :param MAEs: per-game MAE, indexed like the columns of ``sols_actual``
    :param n_players: number of players
    :param sols_actual: true solutions, one game per column
    :param sols_pred: predicted solutions, one game per column
    :param top_n: number of games to draw (the grid holds 3x3 = 9 axes)
    '''
    ranking = np.argsort(MAEs)[::-1]  # descending MAE: worst game first
    fig, axs = plt.subplots(3, 3, figsize=(20, 12), sharex=False, sharey=True, facecolor='w', edgecolor='k')
    fig.subplots_adjust(hspace=.4, wspace=.2)
    axs = axs.ravel()

    tick_positions = list(range(n_players + 1))
    tick_labels = [str(p) for p in range(n_players)] + ['ɛ']
    point_style = dict(markersize=9, alpha=.45, linestyle='', label='medium')

    for rank in range(top_n):
        game = ranking[rank]
        ax = axs[rank]
        ax.set_title(f'Game MAE: {MAEs[game]:.3f}')
        ax.plot(sols_actual[:, game], color='darkgreen', marker='o', **point_style)
        ax.plot(sols_pred[:, game], color='b', marker='^', **point_style)
        ax.set_xticks(tick_positions)
        ax.set_xticklabels(tick_labels)
        ax.set_xlabel('Player')
        ax.set_ylabel('Payoff')

    handles = [
        Line2D([0], [0], marker=mark, color='w', markerfacecolor=colour, lw=4, markersize=12, label=text)
        for mark, colour, text in (('o', 'darkgreen', 'Actual'), ('^', 'b', 'Predicted'))
    ]
    fig.legend(handles=handles, fontsize=12, bbox_to_anchor=(0.9, 1))
    plt.show()



# def compare_sols_dist(X, Y_SV, Y_LC, Y_BI, player_rep):
#     fig, axs = plt.subplots(1, figsize=(10, 5),  facecolor='w', edgecolor='k')
#     fig.subplots_adjust(hspace=.6, wspace=.4)
#     fig.suptitle(f'Data distributions with {np.unique(player_rep)} players')

#     axs.hist([X.ravel(), Y_SV.ravel(), Y_BI.ravel(), Y_LC.ravel()], stacked=True, 
#             bins=70, color=['tab:blue', 'tab:purple', 'tab:orange', 'tab:green'], 
#             alpha=0.7);
#     axs.set_xlabel('Value')
#     axs.set_ylabel('Count')

#     # Make legend
#     legend_elements = [Line2D([0], [0], marker='s', color='w', markerfacecolor='tab:blue', lw=4, markersize=15, label='Normalized weights'),
#                     Line2D([0], [0], marker='s', color='w', markerfacecolor='tab:purple', lw=4, markersize=15,  label='Banzhaf index'),
#                     Line2D([0], [0], marker='s', color='w', markerfacecolor='tab:orange', lw=4, markersize=15,  label='Shapley value'),
#                     Line2D([0], [0], marker='s', color='w', markerfacecolor='tab:green', lw=4, markersize=15,  label='Least core')
#                     ]
#     fig.legend(handles=legend_elements, fontsize=14, loc='upper right')
#     sns.despine(offset=10, trim=False)


# def compare_sols_preds(X, Y_SV, Y_LC, Y_BI, player_rep, n_samples=8):  
#     # Sample random indices in test set
#     rand_samples = np.random.choice(np.arange(X.shape[0]), size=n_samples,replace=False)

#     fig, axs = plt.subplots(2, 4, figsize=(20, 10), sharex=False, sharey=True, facecolor='w', edgecolor='k')
#     fig.subplots_adjust(hspace=.6, wspace=.4)
#     fig.suptitle('Examples of games in the dataset | Comparison between solution concepts')
#     axs = axs.ravel()

#     for i, s in enumerate(rand_samples):
#         axs[i].set_title(f'Sample {s} | n = {player_rep[s]}')
#         axs[i].set_xlabel('Player')
#         axs[i].plot(X[s, :], color='tab:blue', marker='o', alpha=0.7, linestyle='')
#         axs[i].plot(Y_BI[s, :], color='tab:purple', marker='^', markersize=10, alpha=0.7, linestyle='')
#         axs[i].plot(Y_SV[s, :], color='tab:orange', marker='*', markersize=10, alpha=0.7, linestyle='')
#         axs[i].plot(Y_LC[s, :], color='tab:green', marker='d', markersize=10, alpha=0.7, linestyle='')

#     # Make legend
#     legend_elements = [Line2D([0], [0], marker='o', color='w', markerfacecolor='tab:blue', lw=4, markersize=15, label='Normalized weights'),
#                        Line2D([0], [0], marker='^', color='w', markerfacecolor='tab:purple', lw=4, markersize=15,  label='Banzhaf index'),
#                        Line2D([0], [0], marker='*', color='w', markerfacecolor='tab:orange', lw=4, markersize=15,  label='Shapley value'),
#                        Line2D([0], [0], marker='d', color='w', markerfacecolor='tab:green', lw=4, markersize=15,  label='Least core')
#                     ]
#     fig.legend(handles=legend_elements, fontsize=14, loc='upper right')
#     sns.despine(offset=10, trim=False)
#     plt.show();


# def visualize_exmp(X, Y, player_rep, n_samples=8):
#     # Sample random indices in test set
#     rand_samples = np.random.choice(np.arange(X.shape[0]), size=n_samples,replace=False)

#     fig, axs = plt.subplots(2, 4, figsize=(20, 10), sharex=False, sharey=True, facecolor='w', edgecolor='k')
#     fig.subplots_adjust(hspace=.6, wspace=.4)
#     fig.suptitle('Examples of games in the dataset')
#     axs = axs.ravel()
#     for i, s in enumerate(rand_samples):

#         axs[i].set_title(f'Sample {s} | n = {player_rep[s]}')
#         axs[i].set_xlabel('Player')
#         axs[i].plot(X[s, :], color='tab:blue', marker='o', alpha=0.7, linestyle='')
#         axs[i].plot(Y[s, :], color='tab:purple', marker='^', markersize=10, alpha=0.7, linestyle='')

#     # Make legend
#     legend_elements = [Line2D([0], [0], marker='o', color='w', markerfacecolor='tab:blue', lw=4, markersize=15, label='Normalized weights'),
#                        Line2D([0], [0], marker='^', color='w', markerfacecolor='tab:purple', lw=4, markersize=15,  label='Solution'),
#                     ]
#     fig.legend(handles=legend_elements, fontsize=14, loc='upper right')
#     sns.despine(offset=10, trim=False)
#     plt.show();



# legend_elements = [Line2D([0], [0], color=cb_map[0],  lw=3, label='Fixed-size GAE'),
#                    Line2D([0], [0], color=cb_map[2], lw=3, label='Fixed-size MAE'),
#                    Line2D([0], [0], color=cb_map[1], lw=3, label='Var-size GAE'),
#                    Line2D([0], [0], color=cb_map[3], lw=3, label='Var-size MAE'),
#                 ]
