from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import torch
from scipy.stats import gaussian_kde

import matplotlib.patches as mpatches
from matplotlib.lines import Line2D

# compute tpr from input
def compute_tpr(s0, s1, alpha):
    """Return the fraction of ``s1`` scores not below the ``(1 - alpha)``-quantile of ``s0``.

    The threshold is ``np.quantile(s0, 1 - alpha)``, so roughly an ``alpha``
    fraction of ``s0`` lies above it. The result is ``1 - mean(s1 < threshold)``.
    """
    threshold = np.quantile(s0, 1 - alpha)
    below = s1 < threshold
    return 1 - below.mean()

# compute fpr from input
def compute_fpr(s0, s1, q):
    """Return the fraction of ``s0`` scores not below the ``(1 - q)``-quantile of ``s1``.

    The threshold is ``np.quantile(s1, 1 - q)``, so roughly a ``q`` fraction of
    ``s1`` lies above it. The result is ``1 - mean(s0 < threshold)``.
    """
    threshold = np.quantile(s1, 1 - q)
    below = s0 < threshold
    return 1 - below.mean()

def compute_evalfpr(instances, window_sys):
    """Return the false-positive rate over the trailing ``window_sys`` outcomes.

    Args:
        instances: Sequence (list or NumPy array) of outcome labels. Only the
            ``'fp'`` and ``'tn'`` entries are counted.
        window_sys: Number of trailing entries to look at. Because the window
            is taken with ``[-window_sys:]``, a value of 0 selects the whole
            sequence.

    Returns:
        ``fp / (fp + tn)`` over the window, or 0 when the window contains no
        ``'fp'``/``'tn'`` labels.
    """
    # np.asarray makes the == comparisons elementwise. On a plain list they
    # would collapse to a single False, and the rate would always be 0.
    valid_instances = np.asarray(instances)[-window_sys:]
    tn = np.sum(valid_instances == 'tn')
    fp = np.sum(valid_instances == 'fp')
    return fp / (fp + tn) if fp + tn != 0 else 0

def compute_evaltpr(instances, window_sys):
    """Return the true-positive rate over the trailing ``window_sys`` outcomes.

    Args:
        instances: Sequence (list or NumPy array) of outcome labels. Only the
            ``'tp'`` and ``'fn'`` entries are counted.
        window_sys: Number of trailing entries to look at. Because the window
            is taken with ``[-window_sys:]``, a value of 0 selects the whole
            sequence.

    Returns:
        ``tp / (tp + fn)`` over the window, or 0 when the window contains no
        ``'tp'``/``'fn'`` labels.
    """
    # np.asarray makes the == comparisons elementwise. On a plain list they
    # would collapse to a single False, and the rate would always be 0.
    valid_instances = np.asarray(instances)[-window_sys:]
    fn = np.sum(valid_instances == 'fn')
    tp = np.sum(valid_instances == 'tp')
    return tp / (tp + fn) if tp + fn != 0 else 0

def _window_starts(n, window_sys):
    """Start index of ``seq[:i][-window_sys:]`` within ``seq``, for every i in 1..n."""
    ends = np.arange(1, n + 1)
    if window_sys > 0:
        return np.maximum(ends - window_sys, 0)
    if window_sys == 0:
        # seq[-0:] is seq[0:], i.e. the whole prefix.
        return np.zeros(n, dtype=ends.dtype)
    # Negative window: seq[:i][k:] with k = -window_sys, clamped to the prefix length.
    return np.minimum(-window_sys, ends)


def _windowed_rate_lst(instances, window_sys, hit, miss):
    """Return ``100 * hit / (hit + miss)`` over the trailing window ending at each position.

    Entries whose window holds neither ``hit`` nor ``miss`` labels get 0.
    """
    labels = np.asarray(instances)
    n = len(labels)
    if n == 0:
        return []
    ends = np.arange(1, n + 1)
    starts = _window_starts(n, window_sys)
    # Prefix counts: cum[k] = occurrences in labels[:k], so every window count
    # is a single subtraction. This replaces re-scanning each window (O(n * window)).
    hit_cum = np.concatenate(([0], np.cumsum(labels == hit)))
    miss_cum = np.concatenate(([0], np.cumsum(labels == miss)))
    hits = hit_cum[ends] - hit_cum[starts]
    total = hits + (miss_cum[ends] - miss_cum[starts])
    rates = np.divide(hits, total, out=np.zeros(n), where=total != 0)
    return (100 * rates).tolist()


def get_evalfpr_lst(instances, window_sys):
    """Return the FPR (in %) over the trailing ``window_sys`` outcomes, at every time step.

    Element ``i - 1`` equals ``100 * compute_evalfpr(instances[:i], window_sys)``,
    i.e. ``fp / (fp + tn)`` over the window ending at position ``i`` (0 if the
    window holds no ``'fp'``/``'tn'`` labels). Accepts lists or NumPy arrays.
    """
    return _windowed_rate_lst(instances, window_sys, 'fp', 'tn')


def get_evaltpr_lst(instances, window_sys):
    """Return the TPR (in %) over the trailing ``window_sys`` outcomes, at every time step.

    Element ``i - 1`` equals ``100 * compute_evaltpr(instances[:i], window_sys)``,
    i.e. ``tp / (tp + fn)`` over the window ending at position ``i`` (0 if the
    window holds no ``'tp'``/``'fn'`` labels). Accepts lists or NumPy arrays.
    """
    return _windowed_rate_lst(instances, window_sys, 'tp', 'fn')

# plot score distributions of s0 and s1
def plot_score_distrs(s0, s1, title, alpha):
    """Plot KDE density curves of two score samples and save them to ``./{title}.png``.

    A dashed red vertical line marks the threshold at the ``(1 - alpha)``-quantile
    of ``s0``. Its legend label reports, as a whole percentage, the fraction of
    ``s1`` that is not below that threshold.

    Args:
        s0: Scores drawn in orange; the threshold is computed from this sample.
        s1: Scores drawn in blue; the reported TPR is measured on this sample.
        title: File stem; the figure is written to ``./{title}.png``.
        alpha: Fraction of ``s0`` that lies above the threshold.
    """
    kde0 = gaussian_kde(s0, bw_method='scott')
    kde1 = gaussian_kde(s1, bw_method='scott')

    x_min = min(min(s0), min(s1))
    x_max = max(max(s0), max(s1))
    x_range = np.linspace(x_min - 0.1, x_max + 0.1, 1000)

    d0 = kde0(x_range)
    d1 = kde1(x_range)

    plt.figure(figsize=(5, 4))

    # Threshold at the (1 - alpha)-quantile of s0 (same rule as compute_tpr).
    lam_g = np.quantile(s0, 1 - alpha)
    tpr = 1 - (s1 < lam_g).mean()
    # Round to the nearest percent. The former int(round(x, 2)) truncated,
    # e.g. 87.96 -> "87%".
    plt.axvline(x=lam_g, color='r', linestyle='--', lw=3, label=f'TPR {int(round(tpr * 100))}%')

    plt.plot(x_range, d0, lw=2, color='orange')
    plt.fill_between(x_range, d0, color='#FDBB84', alpha=0.6)

    plt.plot(x_range, d1, lw=2, color='#5A9BCF')
    plt.fill_between(x_range, d1, color='#A0C4E5', alpha=0.6)

    plt.ylim(bottom=0)
    plt.xticks([])
    plt.yticks([])
    plt.xlabel('Scores', fontsize=18)
    plt.ylabel('Density', fontsize=18)
    plt.legend(loc='upper left', fontsize=18)
    plt.savefig(f'./{title}.png', dpi=100, bbox_inches='tight')
    plt.close()

def plot_simple(fpr_lst, tpr_lst, alpha=0.05):
    """Show mean ± std curves of FPR (left) and TPR (right) across runs.

    Args:
        fpr_lst: Per-run FPR curves. They are stacked with ``np.array`` and
            averaged over axis 0, so all runs must have the same length.
            Presumably in percent, since the reference line is drawn at
            ``alpha * 100``.
        tpr_lst: Per-run TPR curves, with the same layout as ``fpr_lst``.
        alpha: Target FPR. A dashed black line is drawn at ``alpha * 100`` on
            the FPR panel. The default 0.05 matches the previously hard-coded value.
    """
    plt.figure(figsize=(15, 5))

    # Left panel: FPR, with the target-rate reference line.
    fpr_runs = np.array(fpr_lst)
    fpr_mean = np.mean(fpr_runs, axis=0)
    fpr_std = np.std(fpr_runs, axis=0)
    plt.subplot(1, 2, 1)
    plt.plot(fpr_mean, color='orange')
    # Lower edge of the band is clipped at 0 (rates are non-negative).
    plt.fill_between(range(len(fpr_mean)), np.maximum(fpr_mean - fpr_std, 0), fpr_mean + fpr_std, color='orange', alpha=0.15)
    plt.axhline(alpha * 100, lw=3, color='black', linestyle='--')

    # Right panel: TPR.
    tpr_runs = np.array(tpr_lst)
    tpr_mean = np.mean(tpr_runs, axis=0)
    tpr_std = np.std(tpr_runs, axis=0)
    plt.subplot(1, 2, 2)
    plt.plot(tpr_mean, color='green')
    plt.fill_between(range(len(tpr_mean)), np.maximum(tpr_mean - tpr_std, 0), tpr_mean + tpr_std, color='green', alpha=0.15)

    plt.show()
    plt.close()

# Plot for stationary settings
def plot_stationary(fpr_methods, tpr_methods, tpr_upper_bounds, shift_point, alpha, N, title, dst_path):
    """Plot mean ± std FPR/TPR curves of three methods and save the figure.

    Each column uses a split (broken) y-axis. Method 0 is drawn in a short upper
    panel whose y-range is centred on its own mean. Methods 1 and 2 share the
    taller lower panel. The left column shows FPR and the right column shows TPR.

    Args:
        fpr_methods: Indexable with three entries. Entry ``k`` holds the per-run
            FPR curves of method ``k``. Runs are stacked with ``np.array`` and
            averaged over axis 0, so every curve must have length ``N``.
        tpr_methods: Same layout as ``fpr_methods``, for TPR.
        tpr_upper_bounds: ``None`` or an empty sequence to skip. Otherwise, two
            values (scaled by 100) drawn as dash-dot reference lines for
            methods 1 and 2 in the lower TPR panel.
        shift_point: x-position of a vertical marker in the upper TPR panel,
            or ``None`` for no marker.
        alpha: Target FPR. A dashed reference line is drawn at ``alpha * 100``.
        N: Number of time steps (the length of every curve).
        title: Figure super-title.
        dst_path: Path the figure is saved to.
    """
    fig = plt.figure(figsize=(15, 6))

    # Colors and markers for each method
    colors = ['#428bca', '#4cae4b', '#f37735']
    markers = ['s', 'x', '^', '|']

    # 4x2 grid: a 1-row panel above a 3-row panel in each column
    ax0 = plt.subplot2grid((4, 2), (0, 0))  # FPR, method 0 (upper-left)
    ax2 = plt.subplot2grid((4, 2), (0, 1))  # TPR, method 0 (upper-right)
    ax1 = plt.subplot2grid((4, 2), (1, 0), rowspan=3)  # FPR, methods 1-2 (lower-left)
    ax3 = plt.subplot2grid((4, 2), (1, 1), rowspan=3)  # TPR, methods 1-2 (lower-right)

    def plot(ax, methods, index, marker, color):
        """Draw mean ± std of ``methods[index]`` on ``ax``; return the mean of the mean curve."""
        runs = np.array(methods[index])
        mean = np.mean(runs, axis=0)
        std = np.std(runs, axis=0)
        # Different marker spacing for neighbouring methods, presumably so
        # their markers don't coincide.
        d = 25000 if index % 2 == 0 else 30000
        ax.plot(range(N), mean, lw=3, marker=marker, markersize=15, markeredgewidth=1, markevery=range(d, 150000, d), color=color)
        ax.fill_between(range(N), np.maximum(mean - std, 0), mean + std, color=color, alpha=0.15)
        return np.mean(mean)

    ### FPR ###
    m = plot(ax0, fpr_methods, 0, markers[0], colors[0])
    ax0.set_ylim(m - 10, m + 10)
    ax0.spines['bottom'].set_visible(False)
    ax0.set_xticks([])
    ax0.tick_params(axis='y', labelsize=22)
    ax0.set_xlabel('')
    ax0.set_title('FPR(%)', fontsize=27)
    ax0.set_xlim(left=0, right=150000)

    ax1.spines['top'].set_visible(False)
    plot(ax1, fpr_methods, 1, markers[1], colors[1])
    plot(ax1, fpr_methods, 2, markers[2], colors[2])
    ax1.axhline(alpha * 100, lw=3, color='black', linestyle='--', label=f'FPR-{alpha*100}')
    ax1.set_xlabel('time(t)', fontsize=24)
    ax1.set_xticks([0, 50000, 100000, 150000])
    ax1.set_xticklabels(['0', '50k', '100k', '150k'], fontsize=22)
    ax1.tick_params(axis='y', labelsize=22)
    ax1.set_xlim(left=0, right=150000)
    ax1.set_ylim(bottom=0)

    ### TPR ###
    m = plot(ax2, tpr_methods, 0, markers[0], colors[0])
    if shift_point is not None:
        ax2.axvline(shift_point, color='purple', linestyle='--', label='shift')
    ax2.set_ylim(m - 10, max(m + 10, 100))
    ax2.spines['bottom'].set_visible(False)
    ax2.set_xticks([])
    ax2.tick_params(axis='y', labelsize=22)
    ax2.set_xlabel('')
    ax2.set_title('TPR(%)', fontsize=27)
    ax2.set_xlim(left=0, right=150000)

    ax3.spines['top'].set_visible(False)
    plot(ax3, tpr_methods, 1, markers[1], colors[1])
    plot(ax3, tpr_methods, 2, markers[2], colors[2])
    ax3.set_xlabel('time(t)', fontsize=24)
    ax3.set_xticks([0, 50000, 100000, 150000])
    ax3.set_xticklabels(['0', '50k', '100k', '150k'], fontsize=22)
    ax3.tick_params(axis='y', labelsize=22)
    ax3.set_xlim(left=0, right=150000)
    ax3.set_ylim(bottom=0)

    # Omission signs ('~') at the top corners of the lower panels mark the broken y-axis.
    for ax in (ax1, ax3):
        fig.text(0, 1.025, '~', fontsize=30, ha='center', va='center', transform=ax.transAxes)
        fig.text(1, 1.025, '~', fontsize=30, ha='center', va='center', transform=ax.transAxes)

    # Upper bound. The explicit None/len check replaces `!= []`: that comparison
    # raised for None, raised or warned elementwise for NumPy arrays, and was
    # always True for tuples.
    if tpr_upper_bounds is not None and len(tpr_upper_bounds) > 0:
        ax3.axhline(tpr_upper_bounds[0] * 100, lw=3, color=colors[1], linestyle='dashdot')
        ax3.axhline(tpr_upper_bounds[1] * 100, lw=3, color=colors[2], linestyle='dashdot')

    plt.subplots_adjust(hspace=0.1, wspace=0.1)
    plt.suptitle(title)
    plt.tight_layout()
    plt.savefig(dst_path, dpi=100, bbox_inches='tight')

# Plot stats for multiple methods and multiple runs
def plot_distr_shift(fpr_methods, tpr_methods, tpr_upper_bounds, alpha, N, title, dst_path='./result.png', bound_shift=50000):
    """Plot mean ± std FPR/TPR curves of three methods over runs and save the figure.

    The left column shows FPR with a split (broken) y-axis: method 0 is in a
    short upper panel, and methods 1 and 2 are in the taller lower panel. The
    right column is a single full-height TPR panel holding all three methods.

    Args:
        fpr_methods: Indexable with three entries. Entry ``k`` holds the per-run
            FPR curves of method ``k``. Runs are stacked with ``np.array`` and
            averaged over axis 0, so every curve must have length ``N``.
        tpr_methods: Same layout as ``fpr_methods``, for TPR.
        tpr_upper_bounds: ``None`` or an empty sequence to skip. Otherwise, two
            ``(before, after)`` pairs (scaled by 100), for methods 1 and 2. Each
            pair is drawn as a dash-dot step that changes level at ``bound_shift``.
        alpha: Target FPR. A dashed reference line is drawn at ``alpha * 100``.
        N: Number of time steps (the length of every curve).
        title: Figure super-title.
        dst_path: Path the figure is saved to.
        bound_shift: x-position (time step) where the TPR upper bounds step
            from their first to their second value.
    """
    fig = plt.figure(figsize=(15, 6))

    # Colors and markers for each method
    colors = ['#428bca', '#4cae4b', '#f37735']
    markers = ['s', 'x', '^', '|']
    x_right = 175000  # right x-limit shared by all panels

    # 4x2 grid: the left column is split 1 + 3 rows, the right column spans all 4 rows.
    ax0 = plt.subplot2grid((4, 2), (0, 0))  # FPR, method 0 (upper-left)
    ax1 = plt.subplot2grid((4, 2), (1, 0), rowspan=3)  # FPR, methods 1-2 (lower-left)
    ax3 = plt.subplot2grid((4, 2), (0, 1), rowspan=4)  # TPR, all methods (right)

    def plot(ax, methods, index, marker, color):
        """Draw mean ± std of ``methods[index]`` on ``ax``.

        Returns the mean, 20th and 80th percentiles of the mean curve.
        """
        runs = np.array(methods[index])
        mean = np.mean(runs, axis=0)
        std = np.std(runs, axis=0)
        # Different marker spacing for neighbouring methods, presumably so
        # their markers don't coincide.
        d = 25000 if index % 2 == 0 else 30000
        ax.plot(range(N), mean, lw=3, marker=marker, markersize=15, markeredgewidth=1, markevery=range(d, 150000, d), color=color)
        ax.fill_between(range(N), np.maximum(mean - std, 0), mean + std, color=color, alpha=0.15)
        return np.mean(mean), np.quantile(mean, 0.2), np.quantile(mean, 0.8)

    ### FPR ###
    _, q1, q2 = plot(ax0, fpr_methods, 0, markers[0], colors[0])
    # The upper panel's y-range hugs the 20th-80th percentile band of method 0's mean curve.
    ax0.set_ylim(q1 - 5, q2 + 5)
    ax0.spines['bottom'].set_visible(False)
    ax0.set_xticks([])
    ax0.tick_params(axis='y', labelsize=22)
    ax0.set_xlabel('')
    ax0.set_title('FPR(%)', fontsize=27)
    ax0.set_xlim(left=0, right=x_right)

    ax1.spines['top'].set_visible(False)
    plot(ax1, fpr_methods, 1, markers[1], colors[1])
    plot(ax1, fpr_methods, 2, markers[2], colors[2])
    ax1.axhline(alpha * 100, lw=3, color='black', linestyle='--', label=f'FPR-{alpha*100}')
    ax1.set_xlabel('time(t)', fontsize=25)
    ax1.set_xticks([20000, 50000, 80000, 110000, 140000, 170000])
    ax1.set_xticklabels(['20k', '50k', '80k', '110k', '140k', '170k'], fontsize=22)
    ax1.tick_params(axis='y', labelsize=22)
    ax1.set_xlim(left=0, right=x_right)
    ax1.set_ylim(bottom=0)

    ### TPR ###
    for k in range(3):
        plot(ax3, tpr_methods, k, markers[k], colors[k])
    ax3.set_xlabel('time(t)', fontsize=25)
    ax3.set_xticks([20000, 50000, 80000, 110000, 140000, 170000])
    ax3.set_xticklabels(['20k', '50k', '80k', '110k', '140k', '170k'], fontsize=22)
    ax3.tick_params(axis='y', labelsize=22)
    ax3.set_ylim(0, 103)
    ax3.set_xlim(left=0, right=x_right)
    ax3.set_title('TPR(%)', fontsize=27)

    # Omission signs ('~') at the top corners of the lower FPR panel mark the broken y-axis.
    fig.text(0, 1.025, '~', fontsize=30, ha='center', va='center', transform=ax1.transAxes)
    fig.text(1, 1.025, '~', fontsize=30, ha='center', va='center', transform=ax1.transAxes)

    # Upper bound, drawn as a step at bound_shift. The explicit None/len check
    # replaces `!= []`, which raised for None and for NumPy arrays.
    if tpr_upper_bounds is not None and len(tpr_upper_bounds) > 0:
        # axhline's xmin/xmax are axes fractions, not data coordinates. With
        # xlim = (0, x_right), the step sits at bound_shift / x_right. The old
        # hard-coded 0.3 (= 52.5k) did not line up with the connector at 50k.
        split = bound_shift / x_right
        for k, color in ((0, colors[1]), (1, colors[2])):
            before = tpr_upper_bounds[k][0] * 100
            after = tpr_upper_bounds[k][1] * 100
            ax3.axhline(before, xmin=0, xmax=split, lw=2.5, color=color, linestyle='dashdot')
            ax3.axhline(after, xmin=split, lw=2.5, color=color, linestyle='dashdot')
            # Vertical connector joining the two levels, drawn explicitly on ax3
            # rather than on pyplot's implicit current axes.
            ax3.plot([bound_shift, bound_shift], [before, after], lw=2.5, color=color, linestyle='dashdot')

    plt.subplots_adjust(hspace=0.1, wspace=0.1)
    plt.suptitle(title)
    plt.tight_layout()
    plt.savefig(dst_path, dpi=100, bbox_inches='tight')