import argparse
import time
import numpy as np
import pandas as pd
from pathlib import Path
import torch
import sys
import gc
from collections import defaultdict
sys.path.append('..')
sys.path.append('../utils/SGW/lib')

from QDOT.QDOT_numpy import *
from risgw import risgw_gpu
from sgw_numpy import sgw_cpu


def _generate_data(n, d=3, seed=None):
    """Draw two independent standard-normal point clouds of shape ``(n, d)``.

    Args:
        n: Number of points in each cloud.
        d: Dimension of each point.
        seed: Seed handed to ``np.random.default_rng``; ``None`` gives a
            non-reproducible draw.

    Returns:
        A tuple ``(X, Y)`` of NumPy arrays, each of shape ``(n, d)``.
    """
    generator = np.random.default_rng(seed)
    shape = (n, d)
    # X is sampled before Y, so a fixed seed always yields the same pair.
    first_cloud = generator.normal(size=shape)
    second_cloud = generator.normal(size=shape)
    return first_cloud, second_cloud

def time_once(X, Y, method='QDOT', rep_dim=100):
    """Return the wall-clock seconds taken by one ``Geo_Compare`` call.

    Args:
        X, Y: Inputs forwarded unchanged to ``Geo_Compare``.
        method: Method name forwarded to ``Geo_Compare``.
        rep_dim: Representation size forwarded to ``Geo_Compare``.

    Returns:
        Elapsed time in seconds, measured with ``time.perf_counter``.
    """
    started = time.perf_counter()
    Geo_Compare(X, Y, method, rep_dim)
    return time.perf_counter() - started

# NOTE: `summarize` is defined a second time further down in this file; that
# later definition is the one bound to the name once the module has loaded.
def summarize(df: pd.DataFrame) -> pd.DataFrame:
    """Aggregate timings into per-(method, n) mean and sample standard deviation.

    Args:
        df: Frame with at least the columns ``method``, ``n`` and ``seconds``.

    Returns:
        A frame with columns ``method``, ``n``, ``mean`` and ``std``. ``std``
        uses ``ddof=1``, so a group with a single timing yields NaN.
    """
    timings = df.groupby(['method', 'n'])['seconds']
    stat_columns = [
        timings.mean().rename('mean'),
        timings.std(ddof=1).rename('std'),
    ]
    return pd.concat(stat_columns, axis=1).reset_index()

def _gw_inputs(X, Y):
    """Build the solver inputs shared by the GW and EGW branches.

    Returns the max-normalised pairwise distance matrices of X and Y together
    with uniform weights over their points.
    """
    # NOTE(review): `sp` and `ot` are not imported by name in this file;
    # presumably they arrive through the star import of QDOT.QDOT_numpy.
    C1 = sp.spatial.distance.cdist(X, X)
    C2 = sp.spatial.distance.cdist(Y, Y)
    C1 /= C1.max()
    C2 /= C2.max()
    return C1, C2, ot.unif(len(X)), ot.unif(len(Y))


def Geo_Compare(X, Y, method='QDOT', rep_dim=100, return_coupling=False):
    """Compare two point clouds with the selected method.

    Args:
        X, Y: Point clouds; this file passes NumPy arrays of shape ``(n, d)``.
        method: One of ``'QDOT'``, ``'IQDOT'``, ``'GW'``, ``'EGW'``, ``'SGW'``
            or ``'RISGW'``.
        rep_dim: Passed as ``n_quantile`` to QDOT/IQDOT and as ``nproj`` to
            SGW/RISGW. Ignored by GW and EGW.
        return_coupling: When true, return ``(loss, P)`` instead of ``loss``.

    Returns:
        ``loss``, or ``(loss, P)`` when ``return_coupling`` is true. ``P`` is
        the coupling returned by QDOT, GW and EGW; it is ``None`` for IQDOT
        and SGW, and ``0`` for RISGW, because those branches produce no
        coupling here.

    Raises:
        ValueError: If ``method`` is not one of the supported names.
    """
    # Default for branches that yield no coupling; without it,
    # return_coupling=True raised UnboundLocalError for IQDOT and SGW.
    P = None

    if method == 'QDOT':
        loss, P = QDOT(X, Y, n_quantile=rep_dim, EMD=True, initial=False, sigma=1000)
    elif method == 'IQDOT':
        # NOTE(review): the keyword is spelled `intergal` here; confirm it
        # matches the parameter name in QDOT's signature.
        loss = QDOT(X, Y, n_quantile=rep_dim, intergal=True, initial=False, sigma=1000)
    elif method == 'GW':
        C1, C2, p, q = _gw_inputs(X, Y)
        P, log = ot.gromov.gromov_wasserstein(C1, C2, p, q, verbose=False, log=True)
        loss = log['gw_dist']
    elif method == 'EGW':
        C1, C2, p, q = _gw_inputs(X, Y)
        P, log = ot.gromov.entropic_gromov_wasserstein(
            C1, C2, p, q, loss_fun='square_loss', epsilon=0.1, log=True
        )
        loss = log['gw_dist']
    elif method == 'SGW':
        loss = sgw_cpu(X, Y, nproj=rep_dim)
    elif method == 'RISGW':
        # Run on the CPU with float32 tensors.
        device = torch.device('cpu')
        loss = risgw_gpu(torch.from_numpy(X).to(torch.float32).to(device),
                         torch.from_numpy(Y).to(torch.float32).to(device),
                         device, nproj=rep_dim, max_iter=100, lr=0.01)
        # Kept as 0 (not None) so existing callers see the same placeholder.
        P = 0
    else:
        raise ValueError(
            'Method not recognized. Choose from QDOT, IQDOT, GW, EGW, SGW, RISGW.'
        )

    if return_coupling:
        return loss, P
    return loss

def parse_method_token(token: str):
    """Split a ``METHOD-<dim>`` token into its label, base method and dimension.

    Args:
        token: A method token such as ``'SGW-50'`` or plain ``'QDOT'``.

    Returns:
        ``(label, base, dim)``. ``label`` is always the token itself. When the
        token ends in ``-<decimal digits>``, ``base`` is the text before that
        last hyphen and ``dim`` is the integer suffix. Otherwise ``base`` is
        the whole token and ``dim`` is ``None``.
    """
    # rpartition always returns three parts, so a token without any '-' no
    # longer fails on tuple unpacking the way rsplit('-', 1) did.
    base, hyphen, suffix = token.rpartition('-')
    # isdecimal() only accepts characters that int() can parse, unlike
    # isdigit(), which also accepts e.g. superscript digits.
    if hyphen and suffix.isdecimal():
        return token, base, int(suffix)
    return token, token, None

def summarize(df: pd.DataFrame) -> pd.DataFrame:
    """Summarise raw timings per ``(method, n)`` group.

    Args:
        df: Frame with at least the columns ``method``, ``n`` and ``seconds``.

    Returns:
        A frame with columns ``method``, ``n``, ``mean`` and ``std``, where
        ``std`` is the sample standard deviation (``ddof=1``) of ``seconds``.
        A group holding a single timing therefore has a NaN ``std``.
    """
    # The string aggregator 'std' uses pandas' default ddof of 1.
    stats = df.groupby(['method', 'n'])['seconds'].agg(['mean', 'std'])
    return stats.reset_index()

def _warm_call(X, Y, method='QDOT', rep_dim=100, warmup=3, disable_gc=True):
    """Run ``Geo_Compare`` ``warmup`` times and discard the results.

    Args:
        X, Y: Inputs forwarded to ``Geo_Compare``.
        method: Method name forwarded to ``Geo_Compare``.
        rep_dim: Representation size forwarded to ``Geo_Compare``.
        warmup: Number of untimed calls; zero or negative makes this a no-op.
        disable_gc: When true, the garbage collector is switched off for the
            duration of the warmup calls.

    Returns:
        None.
    """
    if warmup <= 0:
        return

    # Only re-enable the collector afterwards if it was on to begin with and
    # this function is the one that turned it off.
    restore_gc = gc.isenabled() and disable_gc
    if disable_gc:
        gc.disable()
    try:
        for _ in range(warmup):
            Geo_Compare(X, Y, method=method, rep_dim=rep_dim)
    finally:
        if restore_gc:
            gc.enable()

def main():
    """Benchmark the requested methods and write per-method CSV outputs.

    Command-line options:
        --sizes:   point-cloud sizes to benchmark.
        --methods: method tokens handed to ``parse_method_token``.
        --repeats: number of timed calls per (size, method) pair.
        --outdir:  directory that receives one sub-directory per method token.
        --warmup:  number of untimed warmup calls before timing.

    For each method token, ``raw.csv`` (one row per timed call) and
    ``summary.csv`` (the output of ``summarize``) are written under
    ``<outdir>/<token>/``.
    """
    parser = argparse.ArgumentParser(
        description='CPU-only benchmarking (METHOD-<dim>) per-method outputs'
    )
    parser.add_argument('--sizes', type=int, nargs='*', default=[100, 1000])
    parser.add_argument('--methods', type=str, nargs='*', default=['SGW-50'])
    parser.add_argument('--repeats', type=int, default=10)
    parser.add_argument('--outdir', type=str, default='./benchmark_out')
    parser.add_argument(
        '--warmup', type=int, default=1,
        help='number of warmup calls before timing',
    )
    opts = parser.parse_args()

    out_root = Path(opts.outdir)
    out_root.mkdir(parents=True, exist_ok=True)

    # Resolve every token once into (label, base method, optional rep_dim).
    method_specs = [parse_method_token(token) for token in opts.methods]

    # rep_dim used when a token carries no numeric suffix.
    fallback_rep = 100
    records_by_label = defaultdict(list)

    for size in opts.sizes:
        # One data set per size, shared by every method so timings are comparable.
        X, Y = _generate_data(size, d=2, seed=1234 + size)
        for label, base, requested_rep in method_specs:
            rep = fallback_rep if requested_rep is None else requested_rep

            _warm_call(X, Y, method=base, rep_dim=rep, warmup=opts.warmup)
            for trial in range(opts.repeats):
                elapsed = time_once(X, Y, method=base, rep_dim=rep)
                records_by_label[label].append({
                    'method': label,
                    'n': int(size),
                    'trial': trial,
                    'seconds': float(elapsed),
                    'rep_dim': rep,
                    'base': base,
                })
                print(
                    f"[n={size:>6}] {label:<12} rep_dim={rep:<4} "
                    f"trial {trial + 1}/{opts.repeats}: {elapsed:.6f}s"
                )

    for label, records in records_by_label.items():
        method_dir = out_root / label
        method_dir.mkdir(parents=True, exist_ok=True)
        raw_path = method_dir / 'raw.csv'
        summary_path = method_dir / 'summary.csv'

        raw_frame = pd.DataFrame.from_records(records)
        raw_frame.to_csv(raw_path, index=False)

        summary_frame = summarize(raw_frame[['method', 'n', 'seconds']])
        summary_frame.to_csv(summary_path, index=False)

        print(f"Saved -> {raw_path.resolve()}")
        print(f"Saved -> {summary_path.resolve()}")

if __name__ == '__main__':
    main()
