import sys
import os

import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import gaussian_kde

def qqplot(sample, *samples, n_points=100, ax=None):
    """Draw a QQ plot of one or more samples against a reference sample.

    Each sample in ``samples`` is drawn as a scatter of its percentiles (y)
    against the same percentiles of ``sample`` (x). A larger marker at
    (mean of ``sample``, mean of the compared sample) is added in the
    scatter's colour, and a dashed ``y = x`` reference line spans the range
    of every quantile that was computed.

    Args:
        sample: Reference data (any array-like accepted by ``np.percentile``).
        *samples: Data sets to compare against ``sample``. Empty ones are
            skipped with a printed warning.
        n_points: Number of evenly spaced percentiles in [0, 100] to plot.
        ax: Axes to draw on; a new 8x8 figure is created when ``None``.

    Returns:
        ``(ax, scatters)`` where ``scatters`` holds one scatter artist per
        non-empty compared sample, in input order. If ``sample`` is empty,
        only the title and grid are set and ``scatters`` is ``[]``.
    """
    quantiles = np.linspace(0, 100, n_points)

    if ax is None:
        _, ax = plt.subplots(figsize=(8, 8))
    ax.set_title("QQ Plot")
    ax.grid(True)

    # Check for empty sample
    if len(sample) == 0:
        print("[WARN] qqplot: Reference sample is empty, skipping plot.")
        return ax, []

    q = np.percentile(sample, quantiles)
    # np.mean (rather than the .mean() method) also accepts plain lists.
    mean_ref = np.mean(sample)

    scatters = []
    # Quantile arrays of everything drawn; used to size the y = x diagonal
    # without recomputing percentiles a second time.
    all_q = [q]
    for s in samples:
        if len(s) == 0:
            print("[WARN] qqplot: One of the compared samples is empty, skipping this scatter.")
            continue
        q_s = np.percentile(s, quantiles)
        all_q.append(q_s)
        scatter = ax.scatter(q, q_s, alpha=0.5)
        scatters.append(scatter)

        # Mean-vs-mean marker in the same colour as the scatter it belongs to.
        color = scatter.get_facecolor()[0]
        ax.plot(mean_ref, np.mean(s), 'o', markerfacecolor=color, markeredgecolor='k', markersize=10)

    # all_q always contains the reference quantiles, so this is never empty.
    qmin = min(a.min() for a in all_q)
    qmax = max(a.max() for a in all_q)
    ax.plot([qmin, qmax], [qmin, qmax], 'k--', alpha=0.75)

    return ax, scatters

def cdfplot(*samples, ax=None):
    """Plot the empirical CDF of each sample, plus a dashed line at its mean.

    For a sample of size n, the sorted values are plotted against
    ``1/n, 2/n, ..., 1``, joined by straight segments. A dashed vertical
    line in the same colour marks the sample mean.

    Args:
        *samples: Array-likes to plot. Empty ones are skipped with a printed
            warning.
        ax: Axes to draw on; a new 8x8 figure is created when ``None``.

    Returns:
        ``(ax, lines)`` where ``lines`` holds one CDF line per non-empty
        sample, in input order (mean lines are not included).
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(8, 8))
    ax.set_title("CDF Plot")
    ax.grid(True)

    lines = []
    for s in samples:
        if len(s) == 0:
            print("[WARN] cdfplot: One of the samples is empty, skipping this line.")
            continue
        sorted_sample = np.sort(s)
        n = len(sorted_sample)
        vals = np.arange(1, n + 1) / n
        line, = ax.plot(sorted_sample, vals)
        lines.append(line)

        # np.mean (rather than the .mean() method) also accepts plain lists.
        ax.axvline(np.mean(s), color=line.get_color(), linestyle='--')

    return ax, lines

def densityplot(*samples, ax=None, bw_method='scott'):
    """Plot a Gaussian KDE of each sample, plus a dashed line at its mean.

    The density is evaluated on 1000 evenly spaced points between the
    sample's minimum and maximum. A dashed vertical line in the same colour
    marks the sample mean.

    Args:
        *samples: 1-D array-likes to plot. Empty ones are skipped with a
            printed warning. A sample whose values are all equal (including a
            single value) has no KDE. It is drawn as a solid full-height
            vertical line at that value, and a warning is printed.
        ax: Axes to draw on; a new 8x8 figure is created when ``None``.
        bw_method: Bandwidth selector passed to ``scipy.stats.gaussian_kde``.

    Returns:
        ``(ax, lines)`` where ``lines`` holds one density line per non-empty
        sample, in input order (mean lines are not included).
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(8, 8))
    ax.set_title("Density Plot")
    ax.grid(True)

    lines = []
    for s in samples:
        if len(s) == 0:
            print("[WARN] densityplot: One of the samples is empty, skipping this line.")
            continue
        lo, hi = np.min(s), np.max(s)
        if lo == hi:
            # gaussian_kde raises on a single point or a constant sample
            # (singular covariance). Show the point mass as a vertical line
            # spanning the axes' full height, instead of aborting the whole
            # figure. ax.plot with the x-axis blended transform keeps the line
            # on the colour cycle. It also keeps exactly one returned line per
            # non-empty sample, so callers' positional labels stay aligned.
            print("[WARN] densityplot: Sample has zero variance, drawing a vertical line instead of a density.")
            line, = ax.plot([lo, lo], [0, 1], transform=ax.get_xaxis_transform())
        else:
            kde = gaussian_kde(s, bw_method=bw_method)
            xs = np.linspace(lo, hi, 1000)
            line, = ax.plot(xs, kde(xs))
        lines.append(line)

        # np.mean (rather than the .mean() method) also accepts plain lists.
        ax.axvline(np.mean(s), color=line.get_color(), linestyle='--')

    return ax, lines

def plot_train_test_val(train, test, val, labels=('Train', 'Test', 'Val')):
    """Compare train/test/val samples side by side with QQ, CDF and density plots.

    The QQ panel plots train and val quantiles against the test quantiles;
    the CDF and density panels overlay all three samples.

    Args:
        train, test, val: Samples to compare. Empty samples are skipped by
            the individual plotting functions (with a printed warning).
        labels: Display names for ``(train, test, val)``, in that order.

    Returns:
        ``(fig, axs)``: the new figure and its three axes (QQ, CDF, density).
    """
    fig, axs = plt.subplots(1, 3, figsize=(18, 6))

    # The plotting helpers drop empty samples, so the artists they return are
    # not positionally aligned with ``labels`` when an input is empty. Pair
    # each label only with a sample that was actually drawn.

    # QQ Plot
    _, scatters = qqplot(test, train, val, ax=axs[0])
    qq_labels = [label for s, label in ((train, labels[0]), (val, labels[2])) if len(s) > 0]
    for scatter, label in zip(scatters, qq_labels):
        scatter.set_label(label)

    axs[0].set_xlabel(f'{labels[1]} Set Quantiles')
    axs[0].set_ylabel(f'{labels[0]}/{labels[2]} Set Quantiles')
    axs[0].set_title('QQ Plot')
    if scatters:
        axs[0].legend()

    drawn_labels = [label for s, label in zip((train, test, val), labels) if len(s) > 0]

    # CDF Plot
    _, lines = cdfplot(train, test, val, ax=axs[1])
    for line, label in zip(lines, drawn_labels):
        line.set_label(label)

    axs[1].set_xlabel('Values')
    axs[1].set_ylabel('Cumulative Probability')
    axs[1].set_title('CDF Plot')
    if lines:
        axs[1].legend()

    # Density Plot
    _, lines = densityplot(train, test, val, ax=axs[2])
    for line, label in zip(lines, drawn_labels):
        line.set_label(label)

    axs[2].set_xlabel('Values')
    axs[2].set_ylabel('Density')
    axs[2].set_title('Density Plot')
    if lines:
        axs[2].legend()

    return fig, axs

def plot_pairwise(train, val, labels=('Train', 'Val')):
    """
    Create a pairwise comparison plot between two datasets with QQ, CDF, and Density plots.

    Args:
        train: Reference sample; its quantiles form the QQ plot's x-axis.
        val: Sample compared against ``train``.
        labels: Display names for ``(train, val)``, in that order.

    Returns:
        ``(fig, axs)``: the new figure and its three axes (QQ, CDF, density).
        If either input is empty, a warning is printed and the axes are left
        blank.
    """
    fig, axs = plt.subplots(1, 3, figsize=(18, 6))

    # Check for empty arrays before plotting. After this point both samples
    # are non-empty, so the helpers return one artist per sample, in order.
    if len(train) == 0 or len(val) == 0:
        print("[WARN] plot_pairwise: One or both input arrays are empty, skipping plots.")
        return fig, axs

    _, scatters = qqplot(train, val, ax=axs[0])
    if scatters:
        scatters[0].set_label(f'{labels[1]} vs {labels[0]}')

    axs[0].set_xlabel(f'{labels[0]} Quantiles')
    axs[0].set_ylabel(f'{labels[1]} Quantiles')
    axs[0].set_title('QQ Plot')
    axs[0].legend()

    _, lines = cdfplot(train, val, ax=axs[1])
    for line, label in zip(lines, labels):
        line.set_label(label)

    axs[1].set_xlabel('Values')
    axs[1].set_ylabel('Cumulative Probability')
    axs[1].set_title('CDF Plot')
    axs[1].legend()

    _, lines = densityplot(train, val, ax=axs[2])
    for line, label in zip(lines, labels):
        line.set_label(label)

    axs[2].set_xlabel('Values')
    axs[2].set_ylabel('Density')
    axs[2].set_title('Density Plot')
    axs[2].legend()

    return fig, axs
