import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import pickle
import datetime
from tqdm import tqdm
import time
import signal
import heapq

class Setting:
    """Monte-Carlo experiment on the largest order statistics of i.i.d. samples.

    Each trial draws ``n`` values ``V = 1/W`` with ``W`` uniform on (0, 1]
    (so ``P(V > x) = 1/x`` for ``x >= 1``), maps them through
    ``tail_quantile_fn`` (after rounding ``V`` up to an integer when
    ``Fplus`` is set) and divides by ``normalization``.  For every trial the
    sample mean and the ``top_values`` largest values are recorded; ratios
    between consecutive largest values are added when results are
    summarized.  Outputs are CSV files whose names start with
    ``result_filename``.
    """

    # Result-column names of the 1st..4th largest sample. Further order
    # statistics (top_values > 4) are named ``top{k}_max``.
    _ORDER_COLUMNS = ('max', 'second_max', 'third_max', 'fourth_max')

    def __init__(self,
                 n_values,
                 num_trials,
                 time_limit=None,
                 benchmark_fn=None,
                 tail_quantile_fn=None,
                 normalization=None,
                 result_filename=None,
                 top_values=4,
                 Fplus=False,
                 ):
        """Store the experiment configuration.

        Args:
            n_values: Sample sizes; each one is run for ``num_trials`` trials.
            num_trials: Number of trials per sample size.
            time_limit: Stored on the instance; not used by this class.
            benchmark_fn: Benchmark function for the growth rate; only its
                name is used here (in ``__str__``).
            tail_quantile_fn: Function applied to the array of ``V`` values
                to produce the samples.
            normalization: Divisor applied to every sample (default 1).
            result_filename: Prefix of every CSV output; defaults to the
                current local time as ``YYYY-MM-DD-HH-MM-SS``.
            top_values: How many of the largest samples to record per trial.
            Fplus: If true, ``V`` is rounded up to an integer before
                ``tail_quantile_fn`` is applied.

        Raises:
            ValueError: If ``top_values`` is smaller than 1.
        """
        if top_values < 1:
            raise ValueError(f"top_values must be at least 1, got {top_values}")
        self.n_values = n_values  # List of sample sizes
        self.num_trials = num_trials  # Number of trials per experiment
        self.benchmark_fn = benchmark_fn  # Benchmark function for the growth rate
        self.tail_quantile_fn = tail_quantile_fn
        self.normalization = normalization
        self.result_filename = result_filename
        self.top_values = top_values
        self.time_limit = time_limit
        self.Fplus = Fplus

        # Use the current local time as the file prefix when none is given.
        if self.result_filename is None:
            self.result_filename = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')

        if self.normalization is None:
            self.normalization = 1

    def __str__(self):
        def _name(fn):
            # None or callables without __name__ (e.g. functools.partial)
            # fall back to their repr instead of raising AttributeError.
            return getattr(fn, '__name__', repr(fn))

        return (f"Setting(n_values={self.n_values}, num_trials={self.num_trials}, "
                f"benchmark_fn={_name(self.benchmark_fn)}, "
                f"tail_quantile_fn={_name(self.tail_quantile_fn)})")

    def __repr__(self):
        return self.__str__()

    def _order_columns(self):
        """Return the result-column names of the ``top_values`` largest samples."""
        named = self._ORDER_COLUMNS
        return [named[k] if k < len(named) else f'top{k + 1}_max'
                for k in range(self.top_values)]

    def _add_ratio_columns(self):
        """Set ``ratio{k}`` = k-th largest / (k+1)-th largest on ``results_df``.

        Returns:
            The list of ratio-column names that were (re)computed.
        """
        order_cols = self._order_columns()
        ratio_cols = []
        for k in range(1, len(order_cols)):
            name = f'ratio{k}'
            self.results_df[name] = (self.results_df[order_cols[k - 1]]
                                     / self.results_df[order_cols[k]])
            ratio_cols.append(name)
        return ratio_cols

    def _result_row(self, n, samples, largest):
        """Build the result row of one trial.

        Args:
            n: Sample size of the trial.
            samples: The trial's samples.
            largest: The trial's largest samples in descending order; may be
                shorter than ``top_values`` when the trial had fewer samples.

        Returns:
            Dict with ``n``, ``average`` and one entry per order column; order
            statistics that do not exist are NaN.
        """
        order_cols = self._order_columns()
        padded = list(largest) + [np.nan] * (len(order_cols) - len(largest))
        row = {'n': n, 'average': np.mean(samples)}
        row.update(zip(order_cols, padded))
        return row

    def trial(self, n):
        """Draw one normalized sample of size ``n``.

        Returns:
            ``tail_quantile_fn(V) / normalization`` where ``V = 1/W`` with
            ``W`` uniform on (0, 1], rounded up first when ``Fplus`` is set.
        """
        # np.random.uniform samples [0, 1); 1 - U lies in (0, 1], so V is
        # always finite (1 / U could divide by zero). Same distribution.
        V = 1 / (1 - np.random.uniform(0, 1, n))
        if self.Fplus:
            V = np.ceil(V)
        return self.tail_quantile_fn(V) / self.normalization

    def run_trials(self):
        """Run every trial, then summarize, save and compute percentiles.

        Sets ``self.results_df`` (one row per trial) and writes the summary,
        raw-results and percentile CSV files. Per sample size, the total time
        spent sampling and finding the largest values is printed.
        """
        order_cols = self._order_columns()
        results = []

        for n in self.n_values:
            print(f"Running trials for n = {n}...")
            sampling_time = 0.0
            findmax_time = 0.0
            for _ in tqdm(range(self.num_trials)):
                time0 = time.perf_counter()
                trial_result = self.trial(n)
                time1 = time.perf_counter()
                largest = heapq.nlargest(self.top_values, trial_result)
                time2 = time.perf_counter()
                sampling_time += time1 - time0
                findmax_time += time2 - time1
                results.append(self._result_row(n, trial_result, largest))

            print(f"Total sampling time for n = {n}: {sampling_time} seconds")
            print(f"Total findmax time for n = {n}: {findmax_time} seconds")

        # Explicit columns keep the schema even when no trial was run.
        self.results_df = pd.DataFrame(results, columns=['n', 'average', *order_cols])
        self.summarize_results()
        self.save_results()
        self.compute_percentiles()
        # trimmed_df = self.trimmed_df(remove_percentage=1)
        # self.results_df = trimmed_df
        # self.summarize_results(filename=self.result_filename+'_trimmed_1')
        # self.compute_percentiles(num_Q=101, filename=self.result_filename+'_percentiles')

    def save_results(self, filename=None):
        """Write ``results_df`` to ``<filename>.csv`` (default: ``result_filename``)."""
        if filename is None:
            filename = self.result_filename
        self.results_df.to_csv(filename + '.csv', index=False)
        print(f"Results saved to {filename}.csv")

        # # Save the self as a pickle file
        # with open(filename + '_the_setting.pickle', 'wb') as handle:
        #             pickle.dump(self, handle, protocol = pickle.HIGHEST_PROTOCOL)

    def summarize_results(self, filename=None):
        """Add ratio columns and write per-``n`` mean/std/count of every column.

        Sets ``self.summary_df`` (indexed by ``n``) and writes it to
        ``<filename>.csv`` (default ``<result_filename>_full_summary``).
        """
        if filename is None:
            filename = self.result_filename + '_full_summary'
        self._add_ratio_columns()
        self.summary_df = self.results_df.groupby("n").agg(["mean", "std", "count"])
        # After groupby, ``n`` lives in the index: writing it is what lets a
        # reader tell which row belongs to which sample size.
        self.summary_df.to_csv(filename + '.csv')

    def compute_percentiles(self, num_Q=101, filename=None):
        """Compute per-``n`` percentiles of every result column.

        Args:
            num_Q: Number of evenly spaced percentiles from 0 to 100.
            filename: Output prefix (default ``<result_filename>_percentiles``).

        Sets ``self.percentile_df`` with columns ``n``, ``percentile``, then
        ``average``, the order columns and the ratio columns, and writes it to
        ``<filename>.csv``.
        """
        if filename is None:
            filename = self.result_filename + '_percentiles'
        percentiles = np.linspace(0, 100, num_Q)
        ratio_cols = self._add_ratio_columns()
        value_cols = ['average', *self._order_columns(), *ratio_cols]

        frames = []
        for n, group in self.results_df.groupby('n'):
            # One vectorized quantile call per group (rows = percentiles).
            quantiles = group[value_cols].quantile(percentiles / 100)
            quantiles.insert(0, 'percentile', percentiles)
            quantiles.insert(0, 'n', n)
            frames.append(quantiles)

        if frames:
            self.percentile_df = pd.concat(frames, ignore_index=True)
        else:
            self.percentile_df = pd.DataFrame(columns=['n', 'percentile', *value_cols])
        self.percentile_df.to_csv(filename + '.csv', index=False)

    def trimmed_df(self, remove_percentage=1):
        """Return a copy of ``results_df`` without the trials with the largest ``max``.

        For every sample size, the ``int(num_trials * remove_percentage / 100)``
        rows with the largest ``max`` are dropped (ties broken by row order).
        ``self.results_df`` itself is not modified.
        """
        df = self.results_df
        remove_nsamples = int(self.num_trials * remove_percentage / 100)
        print(f'Removing {remove_nsamples} samples from total of {self.num_trials}')
        # Select rows by index label so ties and row layout cannot cause
        # the same row to be picked twice.
        remove_labels = []
        for _, group in df.groupby('n', sort=False):
            remove_labels.extend(group['max'].nlargest(remove_nsamples).index)
        df_trimmed = df.drop(remove_labels)
        return df_trimmed.reset_index(drop=True)

    def sqrt_df(self):
        """Take the square root of ``average`` and the largest-value columns.

        NOTE: like before, this modifies ``self.results_df`` in place; it also
        returns a copy with a fresh RangeIndex. Ratio columns are not touched
        here (summarize_results and compute_percentiles recompute them).
        """
        df = self.results_df
        for col in ['average', *self._order_columns()]:
            df[col] = np.sqrt(df[col])
        return df.reset_index(drop=True)
        
        

    
