#!/usr/bin/env python
# experiment.py - Main experiment script
# --------------------------------------------------------------------
import os
import argparse
import random
import numpy as np
import torch
import pandas as pd
from pathlib import Path

from utils_data import load_and_preprocess_data
from models import CLMFixed, CLMLearnable
from training import train_clm_explicit

# Reproducibility: a single seed is pushed into every random number generator
# this script touches (Python, NumPy, PyTorch CPU and CUDA).
SEED = 42

# NOTE(review): the running interpreter reads PYTHONHASHSEED at start-up, so
# this assignment presumably only influences child processes that inherit the
# environment -- confirm whether str/bytes hash determinism is needed here.
os.environ['PYTHONHASHSEED'] = str(SEED)

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# Ask cuDNN for deterministic kernels and disable its autotuner.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Device: first CUDA GPU when one is available, CPU otherwise.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Locations used when main() is called without explicit overrides. Both are
# relative paths, i.e. resolved against the current working directory.
_DEFAULT_DATA_ROOT = Path("../datasets-orreview/ordinal-regression")
_DEFAULT_RESULTS_ROOT = Path("results")

# Learning rate used unless a probit-specific override below applies.
_DEFAULT_LR = 1e-2

# Per-dataset learning rates that replace the default for the probit link only.
_PROBIT_LR_OVERRIDES = {
    "LEV": 5e-3,
    "SWD": 5e-3,
    "car": 5e-3,
    "ERA": 1e-3,
    "winequality-red": 1e-3,
}

# (summary key prefix, history column) pairs whose last test value is reported
# in the summary for both models; the order fixes the summary column order.
_SUMMARY_METRICS = (
    ("ONC1", "onc1_te"),
    ("ONC21", "onc21_te"),
    ("ONC22", "onc22_te"),
    ("ONC3", "onc3_te"),
    ("ACC", "te_acc"),
    ("Within1", "te_within1"),
    ("MinSens", "te_min_sens"),
    ("QWK", "te_qwk"),
)


def _select_learning_rate(dataset, link):
    """Return the learning rate for *dataset* under *link*."""
    if link == "probit":
        return _PROBIT_LR_OVERRIDES.get(dataset, _DEFAULT_LR)
    return _DEFAULT_LR


def _train_variant(model, tag, Xtr, y_tr, Xte, y_te, lr):
    """Train *model* with ``train_clm_explicit`` and collect its history.

    Returns a ``(final_nll, history)`` pair. ``final_nll`` is the third value
    returned by ``train_clm_explicit``; ``history`` maps each history CSV
    column name to the corresponding returned sequence, in CSV column order.
    """
    # Unpack by position so that a change in the number of values returned by
    # train_clm_explicit fails loudly here. Underscore-prefixed names are
    # returned by the trainer but not used by this script.
    (_net, _unused,
     final_nll,
     _Xtr_used, _ytr_used,
     tr_nll, te_nll,
     onc1_tr, onc21_tr, onc22_tr, onc3_tr,
     onc1_te, onc21_te, onc22_te, onc3_te,
     tr_mae, te_mae,
     tr_acc, te_acc,
     tr_within1, te_within1,
     tr_min_sens, te_min_sens,
     tr_qwk, te_qwk,
     ) = train_clm_explicit(
        Xtr, y_tr, Xte, y_te,
        model,
        tag=tag,
        device=DEVICE,
        lr=lr
    )

    history = {
        'epoch':    np.arange(1, len(tr_nll) + 1),
        'tr_nll':   tr_nll,
        'te_nll':   te_nll,
        'tr_mae':   tr_mae,
        'te_mae':   te_mae,
        'onc1_tr':  onc1_tr,
        'onc1_te':  onc1_te,
        'onc21_tr': onc21_tr,
        'onc21_te': onc21_te,
        'onc22_tr': onc22_tr,
        'onc22_te': onc22_te,
        'onc3_tr':  onc3_tr,
        'onc3_te':  onc3_te,
        'tr_acc':   tr_acc,
        'te_acc':   te_acc,
        'tr_within1': tr_within1,
        'te_within1': te_within1,
        'tr_min_sens': tr_min_sens,
        'te_min_sens': te_min_sens,
        'tr_qwk':   tr_qwk,
        'te_qwk':   te_qwk,
    }
    return final_nll, history


def main(dataset: str, fold: int, link: str = "logit", *,
         data_root=None, results_root=None) -> dict:
    """Train the fixed- and learnable-threshold models on one fold.

    Loads the requested fold, trains ``CLMFixed`` and then ``CLMLearnable``
    with the same learning rate, and writes three CSV files into
    ``<results_root>/<dataset>/``: one per-epoch history per model and a
    one-row summary holding the last recorded test metrics of both models.

    Args:
        dataset: Dataset name, passed to ``load_and_preprocess_data`` and used
            as the results sub-directory name.
        fold: Fold index, passed to ``load_and_preprocess_data`` and used in
            the output file names.
        link: Link function name forwarded to both models. Only ``"probit"``
            changes the learning rate, and only for selected datasets.
        data_root: Optional directory holding the datasets; defaults to
            ``../datasets-orreview/ordinal-regression``.
        results_root: Optional directory under which results are written;
            defaults to ``results``.

    Returns:
        The summary dictionary that is also written to the summary CSV.
    """
    print(f"Using link function: {link}")
    lr = _select_learning_rate(dataset, link)
    print(f"Using learning rate: {lr}")

    root = _DEFAULT_DATA_ROOT if data_root is None else Path(data_root)
    results_base = _DEFAULT_RESULTS_ROOT if results_root is None else Path(results_root)

    res_dir = results_base / dataset
    res_dir.mkdir(parents=True, exist_ok=True)

    print(f"Loading and preprocessing dataset: {dataset}, fold: {fold}")
    Xtr, y_tr, Xte, y_te = load_and_preprocess_data(root, dataset, fold)
    print(f"Data shapes - Xtr: {Xtr.shape}, Xte: {Xte.shape}")
    print(f"Classes - train: {sorted(set(y_tr))}, test: {sorted(set(y_te))}")

    # NOTE(review): K counts the distinct labels present in the training fold
    # only; confirm every class is guaranteed to appear in each training fold.
    K = len(np.unique(y_tr))
    print(f"Number of classes: {K}")

    print(f"\n{'='*50}\nTraining Fixed threshold model with {link} link (lr={lr})\n{'='*50}")
    nll_fix, hist_fix = _train_variant(
        CLMFixed(K, link=link),
        f"{dataset}_{link}_fix_f{fold}",
        Xtr, y_tr, Xte, y_te, lr,
    )

    # The learnable model is deliberately constructed only after the fixed
    # model has finished training, so the order in which the two models draw
    # from the global random state is unchanged.
    print(f"\n{'='*50}\nTraining Learnable threshold model with {link} link (lr={lr})\n{'='*50}")
    nll_learn, hist_learn = _train_variant(
        CLMLearnable(K, link=link),
        f"{dataset}_{link}_learn_f{fold}",
        Xtr, y_tr, Xte, y_te, lr,
    )

    # Histories are written only once both trainings have completed.
    pd.DataFrame(hist_fix).to_csv(
        res_dir / f"fold{fold}_history_{link}_fix.csv", index=False)
    pd.DataFrame(hist_learn).to_csv(
        res_dir / f"fold{fold}_history_{link}_learn.csv", index=False)

    summary = {
        'dataset': dataset,
        'fold': fold,
        'link': link,
        'NLL_fix': nll_fix,
        'MAE_fix': hist_fix['te_mae'][-1],
        'NLL_learn': nll_learn,
        'MAE_learn': hist_learn['te_mae'][-1],
    }
    # Remaining metrics: last test value, fixed model first, then learnable.
    for prefix, column in _SUMMARY_METRICS:
        summary[f"{prefix}_fix"] = hist_fix[column][-1]
        summary[f"{prefix}_learn"] = hist_learn[column][-1]

    pd.DataFrame([summary]).to_csv(res_dir / f"fold{fold}_{link}_summary.csv", index=False)

    print(f"\nExperiment completed for dataset: {dataset}, fold: {fold}, link: {link}")
    print(f"Results saved to: {res_dir}")

    return summary

if __name__ == "__main__":
    # Command-line entry point: each invocation runs one experiment for the
    # given dataset, fold and link function.
    cli = argparse.ArgumentParser(description="Run ordinal regression experiment")
    cli.add_argument(
        "--dataset",
        required=True,
        help="Dataset name (e.g. ERA)",
    )
    cli.add_argument(
        "--fold",
        required=True,
        type=int,
        help="Fold number 0..29",
    )
    cli.add_argument(
        "--link",
        type=str,
        choices=["logit", "probit"],
        default="logit",
        help="Link function type: logit, probit",
    )
    opts = cli.parse_args()

    main(dataset=opts.dataset, fold=opts.fold, link=opts.link)