import pandas as pd
import numpy as np


class ModelSelection:
    """Base class for choosing one checkpoint per trial from a results table.

    ``df`` is a long-format results table; the code reads the columns
    ``trial_seed``, ``seed``, ``Epoch``, ``Split``, ``Metric-Name`` and
    ``Metric-Value``.  Subclasses must define ``NAME`` and implement
    ``select_epoch()``, which returns a ``(validation_df, test_df)`` pair of
    candidate rows.  Within each ``trial_seed`` the candidate with the best
    validation ``Metric-Value`` is chosen and its test value is recorded.

    The caller's ``df`` is not modified; all work happens on a copy.

    Attributes set by ``__init__``:
        trial_values: selected test ``Metric-Value`` per trial, in sorted
            ``trial_seed`` order.
        trial_best_identifiers: ``{'epoch', 'seed'}`` dicts of the selected
            checkpoints, in the same order as ``trial_values``.
        trial_best_identifiers_by_trial: the same dicts keyed by
            ``trial_seed`` value.
    """

    def __init__(self, df):
        # Work on a private copy so the tie-breaking noise and the extra
        # column never leak into (or accumulate on) the caller's frame when
        # several selection methods are run on the same table.
        self.df = df.copy()
        self.df['Selection-Method'] = self.NAME

        # doing this to break ties
        # we will never report accuracies to the granularity of 1/1000
        # so this won't impact our results (the noise is below 1e-4)
        num_rows = len(self.df.index)
        noise = np.random.rand(num_rows) / 10000.
        self.df['Metric-Value'] += noise

        # Higher is better for accuracy metrics; lower is better otherwise.
        # Only the first row's metric name is inspected.
        metric_name = self.df['Metric-Name'].iloc[0]
        self.sort_ascending = 'Accuracy' not in metric_name

        # call sub-class
        self.validation_df, self.test_df = self.select_epoch()

        # Dense rank within each trial by validation score; rank 1 = best.
        self.validation_df['trial_rank'] = self.validation_df.groupby(
            'trial_seed'
        )['Metric-Value'].rank(method='dense', ascending=self.sort_ascending)
        # Ranks are copied by position: this assumes select_epoch() returns
        # validation and test rows that correspond one-to-one, in the same
        # order.
        self.test_df['trial_rank'] = self.validation_df['trial_rank'].tolist()

        self.trial_values = []
        self.trial_best_identifiers = []
        self.trial_best_identifiers_by_trial = {}
        for trial, trial_df in self.test_df.groupby('trial_seed'):
            best_rows = trial_df[trial_df.trial_rank == 1.0]
            self.trial_values.append(best_rows['Metric-Value'].iloc[0])

            # identifiers used to backout scores of other metrics
            selected_row = best_rows.iloc[0]
            identifiers = {
                'epoch': selected_row['Epoch'],
                'seed': selected_row['seed'],
            }
            self.trial_best_identifiers.append(identifiers)
            self.trial_best_identifiers_by_trial[trial] = identifiers

    def backout_other_metrics(self, sweep_df):
        """Collect every metric's test value at each trial's selected checkpoint.

        Args:
            sweep_df: results table with the same ``trial_seed`` / ``seed`` /
                ``Epoch`` / ``Split`` / ``Metric-Name`` / ``Metric-Value``
                columns; it may contain several metric names.

        Returns:
            dict mapping each metric name found in ``sweep_df`` to a list of
            test values, appended in sorted ``trial_seed`` order.  A metric
            with no test row at a trial's selected (seed, epoch) gets no
            entry for that trial.

        Raises:
            KeyError: if ``sweep_df`` contains a ``trial_seed`` for which this
                selector made no choice.
        """
        trial_values = {m: [] for m in sweep_df['Metric-Name'].unique()}
        for trial, trial_df in sweep_df.groupby('trial_seed'):
            # Look the choice up by trial_seed value rather than by list
            # position, so trial seeds need not be exactly 0..k-1.
            identifiers = self.trial_best_identifiers_by_trial[trial]
            selected = trial_df[
                (trial_df.seed == identifiers['seed']) &
                (trial_df.Epoch == identifiers['epoch']) &
                (trial_df.Split == 'Test')
            ]

            # If a metric name appears more than once, the last row wins.
            values = dict(zip(
                selected['Metric-Name'],
                selected['Metric-Value']
            ))
            for metric, val in values.items():
                trial_values[metric].append(val)

        return trial_values

class LastStep(ModelSelection):
    """Model selection from the *last* step of training.

    Only rows whose ``Epoch`` equals the largest epoch present anywhere in
    the table are kept as candidates.
    """

    NAME = 'LastStep'

    def __init__(self, df):
        super().__init__(df)

    def select_epoch(self):
        """Keep the final-epoch rows and split them by ``Split``.

        Side effect: ``self.df`` is narrowed to the final-epoch rows.

        Returns:
            (validation_df, test_df): independent copies of the final-epoch
            rows with ``Split == 'Validation'`` and ``Split == 'Test'``.
        """
        final_epoch = max(self.df['Epoch'].unique())
        self.df = self.df[self.df['Epoch'] == final_epoch]

        split_col = self.df['Split']
        return (
            self.df[split_col == 'Validation'].copy(),
            self.df[split_col == 'Test'].copy(),
        )
        
class EarlyStop(ModelSelection):
    """Model selection from the *best* step of training.

    A run is one ``(trial_seed, seed)`` pair.  For each run, the epoch with
    the best validation ``Metric-Value`` (see ``find_best``) is chosen. The
    validation and test rows at that epoch become that run's candidates.
    """

    NAME = 'EarlyStop'

    def __init__(self, df):
        super().__init__(df)

    def select_epoch(self):
        """Pick each run's best validation epoch.

        Returns:
            (validation_df, test_df): freshly indexed frames. Each is built by
            concatenating, run by run in ``groupby`` order, that run's rows
            at its chosen epoch (validation rows and test rows respectively).
        """
        val_rows = self.df[self.df.Split == 'Validation']
        test_rows = self.df[self.df.Split == 'Test']

        kept_val, kept_test = [], []
        runs = val_rows.groupby(['trial_seed', 'seed'])
        for (trial_seed, run_seed), run_val in runs:
            # First epoch (in row order) whose score equals the best score.
            hit = run_val['Metric-Value'] == self.find_best(run_val)
            chosen_epoch = run_val.loc[hit, 'Epoch'].iloc[0]

            kept_val.append(run_val[run_val.Epoch == chosen_epoch])

            same_run_and_epoch = (
                (test_rows.Epoch == chosen_epoch)
                & (test_rows.seed == run_seed)
                & (test_rows.trial_seed == trial_seed)
            )
            kept_test.append(test_rows[same_run_and_epoch])

        return (
            pd.concat(kept_val, ignore_index=True),
            pd.concat(kept_test, ignore_index=True),
        )

    def find_best(self, df):
        """Return the best ``Metric-Value`` in *df*.

        Returns the maximum when ``self.sort_ascending`` is ``False`` (higher
        is better); otherwise returns the minimum.
        """
        scores = df['Metric-Value']
        return scores.max() if self.sort_ascending is False else scores.min()
