import pickle
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
# import statsmodels.api as sm
import scipy
import numpy as np
import torch
from sklearn import preprocessing
from scipy import stats
import warnings 
warnings.filterwarnings('ignore')
from scipy.stats import norm,gamma,expon
import glob
import random
import json
import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit

import math

def Hill_estimator(data, max_k=1500, eps=0.0001):
    """
    Return inverse Hill estimates (tail-index estimates) for a 1D data set.

    NaN entries are dropped. The estimates are computed from the
    ``min(len(data), max_k)`` LARGEST observations (the upper order
    statistics), so for long inputs ``max_k`` only limits how many values of k
    are evaluated; it does not change which tail is examined.

    For k = 1, ..., m the Hill estimate is

        H_k = (1/k) * sum_{i=0}^{k-1} log(X_(i) / (X_(k) + eps))

    where X_(0) >= X_(1) >= ... are the observations in descending order.

    Parameters
    ----------
    data : array_like
        1D sample. NaN entries are ignored.
    max_k : int, optional
        Maximum number of upper order statistics to use (default 1500).
    eps : float, optional
        Constant added to the threshold order statistic in the denominator to
        avoid division by zero (default 1e-4).

    Returns
    -------
    kappa : numpy.ndarray
        ``1 / H_k`` for k = 1, ..., m.
    m : int
        Number of estimates, ``max(min(len(data), max_k) - 1, 0)``.
    """
    values = np.asarray(data, dtype=float)
    values = values[~np.isnan(values)]
    n = min(values.size, max_k)

    # The n largest observations in descending order: desc[0] is the maximum.
    # (Taking the head of an ascending sort would analyse the smallest values.)
    desc = np.sort(values)[::-1][:n]

    m = max(n - 1, 0)
    hill_est = np.empty(m)
    # Zero thresholds or estimates yield inf/nan rather than a warning, as before.
    with np.errstate(divide="ignore", invalid="ignore"):
        for k in range(m):
            # Mean log-ratio of the k+1 largest values to the (k+2)-th largest.
            hill_est[k] = np.mean(np.log(desc[:k + 1] / (desc[k + 1] + eps)))
        kappa = 1.0 / hill_est
    return kappa, m





def empirical_mgf(samples, t):
    """
    Return the log of the empirical moment generating function of the centred samples.

    Computes ``log(mean(exp(t * (x - mean(x)))))``, i.e. the empirical cumulant
    generating function evaluated at ``t``. The value is formed with
    ``logsumexp`` so that large ``t * (x - mean(x))`` does not overflow to
    ``inf`` the way a direct ``np.exp`` would.

    Parameters
    ----------
    samples : array_like
        Sample values.
    t : float
        Point at which to evaluate.

    Returns
    -------
    float
        The log-MGF value; NaN if ``samples`` is empty or contains NaN.
    """
    from scipy.special import logsumexp

    centred = np.asarray(samples, dtype=float)
    if centred.size == 0:
        return np.nan
    centred = centred - np.mean(centred)
    # log(mean(exp(z))) == logsumexp(z) - log(N), computed without overflow.
    return logsumexp(t * centred) - np.log(centred.size)


def _plot_fitted_pdf(ax, data, dist, params, label, title):
    """Draw a density histogram of `data` on `ax` and overlay `dist.pdf` evaluated with `params`."""
    ax.hist(data, bins=75, density=True, alpha=0.4)
    xmin, xmax = ax.get_xlim()
    x = np.linspace(xmin, xmax, 1000)
    ax.plot(x, dist.pdf(x, *params), 'k', linewidth=3, label=label)
    ax.legend()
    ax.set_title(title)


def plot_anthropic_plots(data, name, savepath):
    """
    Save a 3x3 grid of distribution diagnostics for `data` to ``savepath + name + ".pdf"``.

    Layout:
      * row 0: histogram, inverse Hill estimates (from ``Hill_estimator``), and
        the empirical log-MGF for t in [-10, 10] with a fitted quadratic;
      * row 1: density histograms overlaid with fitted normal, exponential and
        gamma pdfs;
      * row 2: probability (Q-Q) plots against those same fitted distributions.

    NaN entries are dropped before any plotting or fitting.

    Parameters
    ----------
    data : array_like
        Sample values.
    name : str
        Used in the figure title and as the output file stem.
    savepath : str
        Prefix prepended verbatim to ``name + ".pdf"`` (include a trailing
        separator if it is a directory).

    Raises
    ------
    ValueError
        If `data` contains no non-NaN values.
    """
    data = np.asarray(data, dtype=float)
    data = data[~np.isnan(data)]
    if data.size == 0:
        raise ValueError(f"{name}: no non-NaN values to plot")

    fig = plt.figure(figsize=(25, 20))
    try:
        fig.suptitle(f"{name} Distribution Plots")
        gs = gridspec.GridSpec(3, 3)

        # Row 0: raw histogram, Hill estimator, empirical log-MGF.
        ax0 = fig.add_subplot(gs[0, 0])
        sns.histplot(x=data, stat="frequency", ax=ax0)
        ax0.set_title("Histograms of rewards")

        ax01 = fig.add_subplot(gs[0, 1])
        hill_estimator_op, n = Hill_estimator(data)
        sns.lineplot(x=list(range(n)), y=hill_estimator_op, ax=ax01)
        ax01.set_title("Hill Estimator Plot")

        ax02 = fig.add_subplot(gs[0, 2])
        t_values = np.linspace(-10, 10, 100)  # range of t values at which to evaluate the MGF
        mgf_values = np.array([empirical_mgf(data, t) for t in t_values])
        ax02.plot(t_values, mgf_values, label="Empirical MGF")
        # np.polyfit cannot handle inf/nan, which a heavy tail at |t| = 10 can
        # produce; fit the quadratic on the finite points only.
        finite = np.isfinite(mgf_values)
        if np.count_nonzero(finite) >= 3:
            polynomial = np.poly1d(np.polyfit(t_values[finite], mgf_values[finite], 2))
            ax02.plot(t_values, polynomial(t_values), color='red',
                      label='Fitted polynomial', linestyle=':')
        ax02.legend()
        ax02.set_title("Empirical MGF Plot with fitted curve")

        # Row 1: candidate distributions fitted with scipy's default estimator (MLE).
        norm_params = norm.fit(data)
        expon_params = expon.fit(data)
        gamma_params = gamma.fit(data)
        _plot_fitted_pdf(fig.add_subplot(gs[1, 0]), data, norm, norm_params,
                         'Fitted Normal Distribution', "Normal Distribution Plot")
        _plot_fitted_pdf(fig.add_subplot(gs[1, 1]), data, expon, expon_params,
                         'Fitted Exponential Distribution', "Exponential Distribution Plot")
        _plot_fitted_pdf(fig.add_subplot(gs[1, 2]), data, gamma, gamma_params,
                         'Fitted Gamma Distribution', "Gamma Distribution Plot")

        # Row 2: probability plots against the fitted distributions.
        ax20 = fig.add_subplot(gs[2, 0])
        stats.probplot(data, plot=ax20, fit=True, sparams=norm_params)
        ax20.set_title("Normal Quantile Plot")

        ax21 = fig.add_subplot(gs[2, 1])
        stats.probplot(data, dist=scipy.stats.distributions.expon, fit=True, plot=ax21,
                       sparams=expon_params)
        ax21.set_title("Exponential Quantile Plot")

        ax22 = fig.add_subplot(gs[2, 2])
        stats.probplot(data, dist=scipy.stats.distributions.gamma, fit=True, plot=ax22,
                       sparams=gamma_params)
        ax22.set_title("Gamma Quantile Plot")

        fig.savefig(savepath + name + ".pdf")
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)








def get_hist_mean_sd(x):
    """Return ``(mean, standard deviation, maximum)`` of `x`, each rounded to 2 decimals."""
    reducers = (np.mean, np.std, np.max)
    return tuple(np.round(reduce_fn(x), 2) for reduce_fn in reducers)




def plot_histograms(file, title, savepath):
    """
    Plot reward histograms for one pickled sample file and return summary stats.

    The figure has four panels: all rewards pooled, then best-of-n rewards for
    n = 10, 50, 100, where best-of-n is the max over the first n columns of the
    reshaped ``(-1, 100)`` array. The figure is saved to `savepath` and shown.

    The model name used as the histogram label is the second "/"-separated
    component of `file`, truncated at "_samples", with the prefixes "ibm-",
    "ibm_", "meta-llama_" and "mistralai_" removed.

    Parameters
    ----------
    file : str
        Path to a pickle holding an array of rewards.
    title : str
        Figure title.
    savepath : str
        Output path passed to ``savefig``.

    Returns
    -------
    dict
        Keys "original", "Best10", "Best50", "Best100", each mapping to
        ``{"mean", "SD", "Max"}`` as returned by ``get_hist_mean_sd``.
    """
    model_name = file.split("/")[1].split("_samples")[0]
    for prefix in ("ibm-", "ibm_", "meta-llama_", "mistralai_"):
        model_name = model_name.replace(prefix, "")

    # Load first so the file handle is not held open while plotting.
    # NOTE(review): pickle.load executes arbitrary code; only use on trusted files.
    with open(file, "rb") as f:
        data = pickle.load(f)

    fig = plt.figure(figsize=(14, 4))
    try:
        gs = gridspec.GridSpec(1, 4)
        hist_details = {}

        ax = fig.add_subplot(gs[0, 0])
        # Pool every reward into one histogram: given a 2-D array, hist would
        # instead draw one histogram per column.
        ax.hist(np.ravel(data), bins=75, label=model_name)
        ax.set_title("All rewards")

        data = data.T
        m, sd, max_r = get_hist_mean_sd(data)
        hist_details["original"] = {"mean": m, "SD": sd, "Max": max_r}
        # NOTE(review): assumes each row of the reshaped array is 100 samples
        # for a single prompt — confirm against how the pickle is produced.
        formatted_data = np.reshape(data, (-1, 100))

        for idx, n in enumerate([10, 50, 100]):
            # Best-of-n: maximum over the first n samples of each row.
            op = np.max(formatted_data[:, :n], axis=1)
            m, sd, max_r = get_hist_mean_sd(op)
            hist_details[f"Best{n}"] = {"mean": m, "SD": sd, "Max": max_r}
            ax = fig.add_subplot(gs[0, idx + 1])
            ax.hist(op, label=model_name)
            ax.set_title(f"n = {n}")

        fig.suptitle(title)
        fig.savefig(savepath)
        plt.show()
    finally:
        plt.close(fig)
    return hist_details



def get_x_y(file):
    """
    Load pickled rewards and return best-of-n curves for n = 2, ..., 99.

    Returns ``(x, y, model_name)`` where, for each n:
      * ``x[n-2] = log(n) - (n - 1) / n``;
      * ``y[n-2]`` is the mean over rows of the max of the first n columns of
        the transposed data reshaped to ``(-1, 100)``, minus the overall mean.

    ``model_name`` is the second "/"-separated component of `file`, cut at
    "_samples", with "ibm-", "ibm_", "meta-llama_" and "mistralai_" removed.
    """
    stem = file.split("/")[1].split("_samples")[0]
    for vendor_prefix in ("ibm-", "ibm_", "meta-llama_", "mistralai_"):
        stem = stem.replace(vendor_prefix, "")

    with open(file, "rb") as handle:
        rewards = pickle.load(handle)
    rewards = rewards.T

    baseline = np.mean(rewards)
    per_row = np.reshape(rewards, (-1, 100))
    sizes = range(2, 100)

    gains = [np.mean(np.max(per_row[:, :k], axis=1)) - baseline for k in sizes]
    bounds = [np.log(k) - ((k - 1) / k) for k in sizes]
    return np.array(bounds), np.array(gains), stem


