# Author: Hassan Ismail Fawaz <hassan.ismail-fawaz@uha.fr>
#         Germain Forestier <germain.forestier@uha.fr>
#         Jonathan Weber <jonathan.weber@uha.fr>
#         Lhassane Idoumghar <lhassane.idoumghar@uha.fr>
#         Pierre-Alain Muller <pierre-alain.muller@uha.fr>
# License: GPL3

import numpy as np
import pandas as pd
import matplotlib

matplotlib.use('agg')
import matplotlib.pyplot as plt

matplotlib.rcParams['font.family'] = 'sans-serif'
matplotlib.rcParams['font.sans-serif'] = 'DejaVu Sans'

import operator
import math
from scipy.stats import wilcoxon
from scipy.stats import friedmanchisquare
import networkx
import json
from math import sqrt
import os
import glob

# Model categorization
# Tagged "[T]" by get_model_prefix.
TREE_MODELS = [
    'catboost', 'lightgbm', 'xgboost',
]
# Tagged "[B]" by get_model_prefix.
BASELINE_MODELS = [
    'decisiontree', 'knn', 'linearmodel', 'randomforest', 'svm',
]
# Tagged "[N]" by get_model_prefix. TabPFN v2 lives here ('tabpfnv2');
# there is no separate TabPFN group.
NEURAL_MODELS = [
    'danet', 'imlp', 'mlp', 'resnet', 'saint', 'stg', 'tabnet', 'vime',
    'modernnca', 'realmlp', 'tabm', 'tabpfnv2', 'tabr',
]

# Substrings of a CSV filename that mark a lower-is-better metric.
LOWER_IS_BETTER = ['log_loss', 'time', 'energy']

# --- display label for a model, tagged with its family (no separate TabPFN group; 'tabpfnv2' is in NEURAL_MODELS) ---
def get_model_prefix(model_name):
    """Return `model_name` prefixed with its family tag.

    "[T] " for TREE_MODELS, "[B] " for BASELINE_MODELS, "[N] " for
    NEURAL_MODELS (checked in that order); unknown names come back unchanged.
    Note: this module defines get_model_prefix a second time further down,
    and that later definition is the one bound at import end.
    """
    for family, tag in ((TREE_MODELS, "T"), (BASELINE_MODELS, "B"), (NEURAL_MODELS, "N")):
        if model_name in family:
            return f"[{tag}] {model_name}"
    return model_name


def is_lower_better(csv_filename):
    """Return True if the metric named in the CSV's filename is lower-is-better.

    Matching is a case-insensitive substring test of each LOWER_IS_BETTER
    token against the file's *basename* only, so a directory such as
    'runtime_results/' (which contains 'time') no longer flips an accuracy
    file to lower-is-better.

    Args:
        csv_filename: path to the input CSV.

    Returns:
        True when any LOWER_IS_BETTER token occurs in the basename.
    """
    filename_lower = os.path.basename(csv_filename).lower()
    return any(metric in filename_lower for metric in LOWER_IS_BETTER)

# Adapted from Orange3: https://docs.orange.biolab.si/3/data-mining-library/reference/evaluation.cd.html
def graph_ranks(avranks, names, p_values, cd=None, cdmethod=None, lowv=None, highv=None,
                width=6, textspace=1, reverse=False, filename=None, labels=False, **kwargs):
    """
    Draws a CD graph, which is used to display the differences in methods'
    performance. See Janez Demsar, Statistical Comparisons of Classifiers over
    Multiple Data Sets, 7(Jan):1--30, 2006.

    Needs matplotlib to work.

    The image is plotted on `plt` imported using
    `import matplotlib.pyplot as plt`.

    Args:
        avranks (list of float): average ranks of methods, expected in
            ascending order (the first half is drawn on the left, and the
            no-difference bars are positioned by index).
        names (numpy.ndarray of str): names of methods, in the same order as
            `avranks`.
        p_values (list of tuple): (name1, name2, p_value, significant) tuples;
            methods linked by pairs with significant == False are joined by a
            horizontal bar (via form_cliques).
        cd (float, optional): not used by this implementation; kept for
            signature compatibility with the Orange3 original.
        cdmethod (int, optional): not used by this implementation; kept for
            signature compatibility with the Orange3 original.
        lowv (int, optional): the lowest shown rank
        highv (int, optional): the highest shown rank
        width (int, optional): default width in inches (default: 6)
        textspace (int, optional): space on figure sides (in inches) for the
            method names (default: 1)
        reverse (bool, optional):  if set to `True`, the lowest rank is on the
            right (default: `False`)
        filename (str, optional): output file name (with extension). If given,
            the figure is also written there; extra keyword arguments are
            forwarded to `FigureCanvasAgg.print_figure`. If not given, no file
            is written.
        labels (bool, optional): if set to `True`, the calculated avg rank
            values will be displayed
    """
    try:
        import matplotlib
        import matplotlib.pyplot as plt
        from matplotlib.backends.backend_agg import FigureCanvasAgg
    except ImportError:
        raise ImportError("Function graph_ranks requires matplotlib.")

    width = float(width)
    textspace = float(textspace)

    def nth(l, n):
        """
        Returns only nth element of every item in a list.
        """
        n = lloc(l, n)
        return [a[n] for a in l]

    def lloc(l, n):
        """
        List location in list of list structure.
        Enable the use of negative locations:
        -1 is the last element, -2 second last...
        """
        if n < 0:
            return len(l[0]) + n
        else:
            return n

    def print_figure(fig, *args, **kwargs):
        canvas = FigureCanvasAgg(fig)
        canvas.print_figure(*args, **kwargs)

    sums = avranks

    nnames = names
    ssums = sums

    if lowv is None:
        lowv = min(1, int(math.floor(min(ssums))))
    if highv is None:
        highv = max(len(avranks), int(math.ceil(max(ssums))))

    cline = 0.4

    k = len(sums)

    linesblank = 0
    scalewidth = width - 2 * textspace

    def rankpos(rank):
        if not reverse:
            a = rank - lowv
        else:
            a = highv - rank
        return textspace + scalewidth / (highv - lowv) * a

    distanceh = 0.25

    cline += distanceh

    # calculate the height needed for the image
    minnotsignificant = max(2 * 0.2, linesblank)
    height = cline + ((k + 1) / 2) * 0.2 + minnotsignificant

    fig = plt.figure(figsize=(width, height))
    fig.set_facecolor('white')
    ax = fig.add_axes([0, 0, 1, 1])
    ax.set_axis_off()

    hf = 1. / height  # height factor
    wf = 1. / width

    def hfl(l):
        return [a * hf for a in l]

    def wfl(l):
        return [a * wf for a in l]

    # Upper left corner is (0,0): the y axis is reversed by set_ylim(1, 0).
    ax.plot([0, 1], [0, 1], c="w")
    ax.set_xlim(0, 1)
    ax.set_ylim(1, 0)

    def line(l, color='k', **kwargs):
        """
        Input is a list of pairs of points.
        """
        ax.plot(wfl(nth(l, 0)), hfl(nth(l, 1)), color=color, **kwargs)

    def text(x, y, s, *args, **kwargs):
        ax.text(wf * x, hf * y, s, *args, **kwargs)

    line([(textspace, cline), (width - textspace, cline)], linewidth=2)

    bigtick = 0.3
    smalltick = 0.15
    linewidth = 2.0
    linewidth_sign = 4.0

    tick = None
    for a in list(np.arange(lowv, highv, 0.5)) + [highv]:
        tick = smalltick
        if a == int(a):
            tick = bigtick
        line([(rankpos(a), cline - tick / 2),
              (rankpos(a), cline)],
             linewidth=2)

    for a in range(lowv, highv + 1):
        text(rankpos(a), cline - tick / 2 - 0.05, str(a),
             ha="center", va="bottom", size=16)

    def filter_names(name):
        return get_model_prefix(name)

    space_between_names = 0.24

    # best-ranked half: names on the left
    for i in range(math.ceil(k / 2)):
        chei = cline + minnotsignificant + i * space_between_names
        line([(rankpos(ssums[i]), cline),
              (rankpos(ssums[i]), chei),
              (textspace - 0.1, chei)],
             linewidth=linewidth)
        if labels:
            text(textspace + 0.15, chei - 0.075, format(ssums[i], '.2f'), ha="right", va="center", size=10)
        text(textspace - 0.2, chei, filter_names(nnames[i]), ha="right", va="center", size=16)

    # remaining half: names on the right
    for i in range(math.ceil(k / 2), k):
        chei = cline + minnotsignificant + (k - i - 1) * space_between_names
        line([(rankpos(ssums[i]), cline),
              (rankpos(ssums[i]), chei),
              (textspace + scalewidth + 0.3, chei)],
             linewidth=linewidth)
        if labels:
            text(textspace + scalewidth - 0.05, chei - 0.075, format(ssums[i], '.2f'), ha="left", va="center", size=10)
        text(textspace + scalewidth + 0.4, chei, filter_names(nnames[i]),
             ha="left", va="center", size=16)

    # draw the no-significance bars, one per clique of mutually
    # non-different methods
    start = cline + 0.2
    side = -0.02
    bar_step = 0.1  # vertical spacing between consecutive bars

    cliques = form_cliques(p_values, nnames)
    achieved_half = False
    for clq in cliques:
        if len(clq) == 1:
            continue
        min_idx = np.array(clq).min()
        max_idx = np.array(clq).max()
        # restart the bar stack once cliques begin in the right-hand half
        if min_idx >= len(nnames) / 2 and not achieved_half:
            start = cline + 0.25
            achieved_half = True
        line([(rankpos(ssums[min_idx]) - side, start),
              (rankpos(ssums[max_idx]) + side, start)],
             linewidth=linewidth_sign)
        start += bar_step

    if filename:
        print_figure(fig, filename, **kwargs)


def form_cliques(p_values, nnames):
    """
    Group methods whose pairwise differences are NOT significant.

    Builds an undirected graph with one node per method (node i stands for
    ``nnames[i]``) and an edge for every pair in `p_values` flagged as not
    significant, then returns the graph's maximal cliques.

    Args:
        p_values: iterable of (name1, name2, p_value, significant) tuples.
        nnames: method names in drawing order; a list or a NumPy array.

    Returns:
        The iterator produced by ``networkx.find_cliques``; each clique is a
        list of node indices into `nnames`.

    Raises:
        IndexError: if a name in `p_values` does not occur in `nnames`.
    """
    # Convert to an array so the elementwise comparison below also works when
    # a plain list is passed (``list == str`` is just False, which made the
    # ``[0][0]`` lookup raise IndexError).
    names = np.asarray(nnames)
    m = len(names)
    g_data = np.zeros((m, m), dtype=np.int64)
    for p in p_values:
        if p[3] == False:  # also matches numpy.bool_(False)
            i = np.where(names == p[0])[0][0]
            j = np.where(names == p[1])[0][0]
            # fill the upper triangle only; networkx.Graph is undirected
            g_data[min(i, j), max(i, j)] = 1

    g = networkx.Graph(g_data)
    return networkx.find_cliques(g)

# --- subset filtering and output-path helpers ---
def filter_subset(df_perf, subset):
    """Return the rows of `df_perf` belonging to the requested classifier subset.

    Args:
        df_perf: long-format frame with a 'classifier_name' column.
        subset: "all" (every row) or "neural" (rows whose lower-cased
            classifier name is listed in NEURAL_MODELS).

    Returns:
        A new DataFrame (a copy, never a view of `df_perf`).

    Raises:
        ValueError: for an unknown `subset`, or when the "neural" subset keeps
            fewer than 3 distinct classifiers (too few for a CD diagram).
    """
    if subset not in ("all", "neural"):
        raise ValueError(f"Unknown subset: {subset}")
    if subset == "all":
        return df_perf.copy()

    lowered = df_perf["classifier_name"].str.lower()
    neural_rows = df_perf.loc[lowered.isin(NEURAL_MODELS)].copy()
    n_cls = neural_rows["classifier_name"].nunique()
    # a CD diagram needs at least 3 classifiers
    if n_cls < 3:
        raise ValueError(f"Not enough neural classifiers after filtering (got {n_cls}).")
    return neural_rows

def suffix_name(path, suffix):
    """Insert `suffix` between the stem and the extension of `path`.

    >>> suffix_name("out/plot.pdf", "_nn")
    'out/plot_nn.pdf'
    """
    root, extension = os.path.splitext(path)
    return "{}{}{}".format(root, suffix, extension)


# ------------------- config (rebinds the model lists defined at the top of the file) -------------------
TREE_MODELS = ['catboost', 'lightgbm', 'xgboost']
BASELINE_MODELS = ['decisiontree', 'knn', 'linearmodel', 'randomforest', 'svm']
NEURAL_MODELS = ['danet', 'imlp', 'mlp', 'resnet', 'saint', 'stg', 'tabnet', 'vime',
                 'modernnca', 'realmlp', 'tabm', 'tabpfnv2', 'tabr']
# Substrings matched against the CSV file's basename to decide that a
# metric is lower-is-better (see is_lower_better_from_filename).
LOWER_IS_BETTER_TOKENS = ['log_loss', 'time', 'energy']

def get_model_prefix(model_name: str) -> str:
    """Label a model with its family tag for display on the CD diagram.

    Returns "[T] name", "[B] name" or "[N] name" when `model_name` is in
    TREE_MODELS, BASELINE_MODELS or NEURAL_MODELS (checked in that order);
    otherwise the name unchanged.
    """
    if model_name in TREE_MODELS:
        tag = "T"
    elif model_name in BASELINE_MODELS:
        tag = "B"
    elif model_name in NEURAL_MODELS:
        tag = "N"
    else:
        return model_name
    return f"[{tag}] {model_name}"

def is_lower_better_from_filename(csv_filename: str) -> bool:
    """Return True if the CSV's basename mentions a lower-is-better metric.

    Case-insensitive substring match against LOWER_IS_BETTER_TOKENS; the
    directory part of the path is ignored.
    """
    name = os.path.basename(csv_filename).lower()
    for token in LOWER_IS_BETTER_TOKENS:
        if token in name:
            return True
    return False

def _detect_value_column(df: pd.DataFrame) -> str:
    """Pick the column holding the metric values in a long-format frame.

    Preference order:
      1. the first well-known column name (below) whose dtype is numeric;
      2. otherwise the last numeric column that is not one of the identifier
         columns 'classifier_name' / 'dataset_name'.

    Raises:
        ValueError: if no suitable numeric column exists.
    """
    def _is_numeric(col):
        return pd.api.types.is_numeric_dtype(df[col])

    for cand in ['value', 'score', 'metric', 'accuracy', 'log_loss', 'time', 'energy']:
        # A name match alone is not enough: in long-format files a column
        # such as 'metric' may hold the metric *name* (a string).
        if cand in df.columns and _is_numeric(cand):
            return cand

    # Fallback: last numeric column, never a (possibly numeric) id column.
    id_cols = {'classifier_name', 'dataset_name'}
    num_cols = [c for c in df.columns if c not in id_cols and _is_numeric(c)]
    if not num_cols:
        raise ValueError("No numeric value column found")
    return num_cols[-1]

def _wilcoxon_pvalue(x, y):
    """Two-sided Wilcoxon signed-rank p-value with Pratt's zero handling.

    When every paired difference is zero the test is undefined (depending on
    the SciPy version it raises ValueError or yields NaN). Identical scores
    show no evidence of a difference, so 1.0 is returned. A NaN p-value is
    likewise mapped to 1.0 so the pairs can be sorted safely.
    """
    diffs = np.asarray(x, dtype=float) - np.asarray(y, dtype=float)
    if not np.any(diffs):
        return 1.0
    p = float(wilcoxon(x, y, zero_method='pratt')[1])
    return 1.0 if math.isnan(p) else p


def _holm_step_down(raw_pairs, alpha):
    """Apply Holm's step-down correction.

    Args:
        raw_pairs: list of (name1, name2, p_value) tuples.
        alpha: family-wise significance level.

    Returns:
        List of (name1, name2, p_value, significant) sorted by ascending
        p-value. The i-th smallest p (0-based) is compared to alpha/(m - i);
        as soon as one comparison fails, it and every larger p-value are
        retained (significant=False), as the step-down procedure requires.
    """
    ordered = sorted(raw_pairs, key=operator.itemgetter(2))
    m = len(ordered)
    corrected = []
    rejecting = True
    for i, (c1, c2, p) in enumerate(ordered):
        if rejecting and p <= alpha / (m - i):
            corrected.append((c1, c2, p, True))
        else:
            rejecting = False
            corrected.append((c1, c2, p, False))
    return corrected


def wilcoxon_holm(df_perf: pd.DataFrame, csv_filename: str, alpha: float = 0.05):
    """Friedman test, average ranks and Holm-corrected pairwise Wilcoxon tests.

    Args:
        df_perf: long-format frame with 'classifier_name', 'dataset_name' and
            a numeric value column (chosen by _detect_value_column).
        csv_filename: source file name; decides the ranking direction via
            is_lower_better_from_filename and is used in error messages.
        alpha: significance level for the Friedman test and Holm correction.

    Returns:
        dict with key 'skip'. If the Friedman test is not significant,
        {'skip': True, 'friedman': {...}}. Otherwise 'avg_ranks' (Series,
        ascending, rank 1 = best), 'pairs' (list of (clf1, clf2, p,
        significant) sorted by p) and 'friedman' (test statistics, Nemenyi CD,
        per-classifier mean/std across datasets).

    Raises:
        ValueError: on missing columns, no numeric value column, or fewer than
            2 datasets on which all kept classifiers have a value.
    """
    # expected identifier columns
    for col in ['classifier_name', 'dataset_name']:
        if col not in df_perf.columns:
            raise ValueError(f"Missing required column '{col}' in {csv_filename}")

    val_col = _detect_value_column(df_perf)
    ascending = is_lower_better_from_filename(csv_filename)

    # keep only classifiers with the maximal number of rows
    counts = df_perf.groupby('classifier_name').size()
    classifiers = counts[counts == counts.max()].index.tolist()

    # classifiers x datasets matrix; keep only complete datasets
    wide = (df_perf[df_perf['classifier_name'].isin(classifiers)]
            .pivot(index='classifier_name', columns='dataset_name', values=val_col)
            .sort_index())
    wide = wide.dropna(axis=1)
    if wide.shape[1] < 2:
        raise ValueError("Not enough complete datasets for Friedman")

    stat, pval = friedmanchisquare(*[wide.loc[c].values for c in wide.index])
    if pval >= alpha:
        return dict(skip=True, friedman=dict(
            n_datasets=int(wide.shape[1]),
            n_classifiers=int(wide.shape[0]),
            friedman_chi2=float(stat),
            friedman_p=float(pval),
            CD=None
        ))

    # per-dataset ranks (rank 1 = best), averaged over datasets
    ranks = wide.rank(axis=0, ascending=ascending)
    avg_ranks = ranks.mean(axis=1).sort_values()

    # pairwise Wilcoxon signed-rank tests with Holm correction
    names = list(wide.index)
    vals = {c: wide.loc[c].values for c in names}
    raw_pairs = []
    for i, c1 in enumerate(names):
        for c2 in names[i + 1:]:
            raw_pairs.append((c1, c2, _wilcoxon_pvalue(vals[c1], vals[c2])))
    pairs = _holm_step_down(raw_pairs, alpha)

    # Nemenyi critical difference. The q values below are the alpha = 0.05
    # column (Demsar 2006), so no CD is reported for any other alpha.
    k = len(names)
    N = wide.shape[1]
    q_alpha = None
    if math.isclose(alpha, 0.05):
        q_alpha = {3: 2.343, 4: 2.569, 5: 2.728, 6: 2.850, 7: 2.949, 8: 3.031, 9: 3.102, 10: 3.164,
                   11: 3.219, 12: 3.268, 13: 3.312, 14: 3.352, 15: 3.389, 16: 3.422, 17: 3.453,
                   18: 3.481, 19: 3.507, 20: 3.532}.get(k)
    CD = q_alpha * sqrt(k * (k + 1) / (6.0 * N)) if q_alpha is not None else None

    return dict(
        skip=False,
        avg_ranks=avg_ranks,
        pairs=pairs,
        friedman=dict(
            n_datasets=int(N),
            n_classifiers=int(k),
            friedman_chi2=float(stat),
            friedman_p=float(pval),
            CD=float(CD) if CD is not None else None,
            mean=wide.mean(axis=1).to_dict(),
            std=wide.std(axis=1, ddof=1).to_dict()
        )
    )

# --- CD diagram driver (consumes the dict returned by wilcoxon_holm) ---
_CD_INPUT_SUFFIX = '_cd_input_long.csv'


def _cd_output_path(csv_filename, new_suffix):
    """Return the path of an output file derived from `csv_filename`.

    A trailing '_cd_input_long.csv' is swapped for `new_suffix`; any other
    name has its extension swapped instead. Unlike ``str.replace`` on the
    whole path, this never hands back the input path itself (for the
    '_'-prefixed suffixes used here), so outputs cannot overwrite the input
    CSV, and matching text elsewhere in the path is left alone.
    """
    if csv_filename.endswith(_CD_INPUT_SUFFIX):
        return csv_filename[:-len(_CD_INPUT_SUFFIX)] + new_suffix
    return os.path.splitext(csv_filename)[0] + new_suffix


def draw_cd_diagram(df_perf=None, alpha=0.05, title=None, labels=False, csv_filename=None, subset="all"):
    """Run the statistics for one CSV and write the CD diagram plus reports.

    Outputs, next to `csv_filename` (with an extra '_nn' before the extension
    when subset == "neural"):
      * <base>_cd_diagram.pdf  - the CD diagram
      * <base>_stats.json      - Friedman statistics, CD, mean/std
      * <base>_avg_ranks.csv   - average rank per classifier
      * <base>_pairwise.csv    - Holm-corrected pairwise Wilcoxon results
    Nothing is written when the Friedman test is not significant.

    Args:
        df_perf: long-format performance frame (required).
        alpha: significance level passed to wilcoxon_holm.
        title: optional plot title.
        labels: draw the average-rank values next to each method.
        csv_filename: path of the input CSV (required); drives the output
            names and the ranking direction.
        subset: "all" or "neural" (see filter_subset).

    Raises:
        ValueError: if `df_perf` or `csv_filename` is missing, or anything
            raised by filter_subset / wilcoxon_holm.
    """
    if df_perf is None or csv_filename is None:
        raise ValueError("draw_cd_diagram requires both df_perf and csv_filename")

    df_sub = filter_subset(df_perf, subset)

    res = wilcoxon_holm(df_perf=df_sub, csv_filename=csv_filename, alpha=alpha)
    if res.get("skip", False):
        print(f"Friedman non-significant for {csv_filename}. Skipping CD.")
        return

    avg_ranks = res["avg_ranks"]
    p_values = res["pairs"]
    friedman = res["friedman"]

    out_pdf = _cd_output_path(csv_filename, '_cd_diagram.pdf')
    out_stats = _cd_output_path(csv_filename, '_stats.json')
    out_ranks = _cd_output_path(csv_filename, '_avg_ranks.csv')
    out_pairs = _cd_output_path(csv_filename, '_pairwise.csv')

    if subset == "neural":
        out_pdf = suffix_name(out_pdf, "_nn")
        out_stats = suffix_name(out_stats, "_nn")
        out_ranks = suffix_name(out_ranks, "_nn")
        out_pairs = suffix_name(out_pairs, "_nn")

    # close the figure even if plotting/saving fails, so repeated calls in a
    # loop do not accumulate open figures
    try:
        graph_ranks(avg_ranks.values, avg_ranks.index.values, p_values,
                    cd=None, reverse=False, width=9, textspace=1.5, labels=labels)

        if title:
            plt.title(title + ("" if subset == "all" else " (Neural only)"),
                      fontdict={'family': 'sans-serif', 'color': 'black', 'size': 20}, y=0.95, x=0.5)

        plt.savefig(out_pdf, bbox_inches='tight')
    finally:
        plt.close()

    with open(out_stats, 'w') as f:
        json.dump(friedman, f, indent=2)
    avg_ranks.to_csv(out_ranks, header=['avg_rank'])
    pd.DataFrame(p_values, columns=['clf1', 'clf2', 'p', 'significant']).to_csv(out_pairs, index=False)



def _find_csvs():
    """Return the sorted, de-duplicated list of '*_cd_input_long.csv' files.

    Looks in ./statistical_test_data/, the parent directory and the current
    directory (relative to the working directory).
    """
    patterns = (
        os.path.join('.', 'statistical_test_data', '*_cd_input_long.csv'),
        os.path.join('..', '*_cd_input_long.csv'),
        os.path.join('.', '*_cd_input_long.csv'),
    )
    matches = {path for pattern in patterns for path in glob.glob(pattern)}
    return sorted(matches)

# --- command-line entry point ---
def main():
    """Draw CD diagrams for the CSV files named on the command line.

    Usage: python <script> [file1.csv [file2.csv ...]]
    With no arguments, inputs are discovered with _find_csvs().

    For each file, a diagram is drawn for all classifiers and one for the
    neural-only subset. Each subset is attempted and reported independently,
    so a failure in one does not silently skip the other. Errors are printed,
    not raised (this is the script's top-level boundary).
    """
    import sys
    # every positional argument is an input file (previously only argv[1])
    files = sys.argv[1:] or _find_csvs()
    if not files:
        print("No CSV files found. Expected '*_cd_input_long.csv'.")
        return
    print(f"Found {len(files)} CSV files")
    for csv_file in files:
        try:
            df = pd.read_csv(csv_file)
        except Exception as e:
            print(f"Error processing {csv_file}: {e}")
            continue
        metric_name = os.path.basename(csv_file).replace('_cd_input_long.csv', '').replace('_', ' ').title()
        print(f"Processing {csv_file} ...")
        for subset in ("all", "neural"):
            try:
                draw_cd_diagram(df_perf=df, title=metric_name, labels=True,
                                csv_filename=csv_file, subset=subset)
            except Exception as e:
                print(f"Error processing {csv_file} (subset={subset}): {e}")


# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()
