#!/usr/bin/env python
# experiment_vis.py
# --------------------------------------------------------------------
import os
import argparse
import random
import numpy as np
import torch
from pathlib import Path
import time
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib.lines import Line2D
from sklearn.decomposition import PCA
from matplotlib.ticker import MaxNLocator

from utils_data import load_and_preprocess_data
from models import CLMFixed, CLMLearnable, ResMLP

SEED = 42

def set_all_seeds(seed=SEED):
    """Seed every random number generator used by this script.

    Sets the ``PYTHONHASHSEED`` environment variable, seeds Python's
    ``random``, NumPy and PyTorch (CPU and CUDA), and configures cuDNN for
    deterministic behaviour with benchmarking turned off.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Applied in this order: stdlib, NumPy, torch CPU, torch CUDA (current
    # device), torch CUDA (all devices).
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Seed once, as soon as the module is imported.
set_all_seeds(SEED)

# Use the first CUDA device when one is available, otherwise the CPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def lighten_color(color, amount=0.5):
    """Blend ``color`` toward white and return it as an RGB tuple.

    ``color`` is converted with ``matplotlib.colors.to_rgb``. ``amount`` is
    the blend fraction: 0 returns the colour unchanged, 1 returns white.
    """
    rgb = np.array(mcolors.to_rgb(color))
    distance_to_white = 1.0 - rgb
    lightened = rgb + distance_to_white * amount
    return tuple(lightened)

def create_legend_file(dataset, tag, output_dir="vis", num_classes=6):
    """Render a standalone legend image and save it as ``<tag>_legend.png``.

    The legend holds one coloured dot per class (labelled "Class 1",
    "Class 2", ... with colours taken from the rainbow colormap), followed by
    a "Class Mean" star entry and a "Classifier" arrow-like entry.

    Args:
        dataset: Dataset name. Not used when rendering; kept so existing
            callers keep working.
        tag: Prefix of the output file name.
        output_dir: Directory the image is written to. It is created when it
            does not exist yet.
        num_classes: Number of class entries to draw.

    Returns:
        ``pathlib.Path`` of the saved PNG.
    """
    base_colors = plt.cm.rainbow(np.linspace(0, 1, num_classes))

    fig, ax = plt.subplots(figsize=(14, 1.5))
    ax.set_axis_off()

    # White line colour hides the connecting line so only the dot shows.
    handles = [
        Line2D([0], [0], marker='o', color='w', markerfacecolor=base_colors[i], markersize=12)
        for i in range(num_classes)
    ]
    labels = [f"Class {i+1}" for i in range(num_classes)]

    handles.append(Line2D([0], [0], marker='*', color='gold', markeredgecolor='k', markersize=15))
    labels.append("Class Mean")

    handles.append(Line2D([0], [0], marker='>', color='k', markeredgecolor='k', markersize=12))
    labels.append("Classifier")

    ax.legend(handles, labels, loc='center', ncol=num_classes+2, frameon=True, fontsize=14, handletextpad=0.5, columnspacing=1.0)

    output_path = Path(output_dir) / f"{tag}_legend.png"
    # savefig does not create missing directories, so make sure the target exists.
    output_path.parent.mkdir(parents=True, exist_ok=True)

    try:
        fig.tight_layout(pad=1.0)
        fig.savefig(output_path, dpi=300, bbox_inches='tight')
    finally:
        # Release the figure even when saving fails.
        plt.close(fig)
    print(f"Legend file saved: {output_path}")

    return output_path

def create_thresholds_legend_file(output_dir="vis", tag="thresholds"):
    """Render a one-entry legend for the threshold lines and save it.

    The single entry is a red dashed line labelled "Thresholds". The image is
    written to ``<output_dir>/<tag>_legend.png``.

    Args:
        output_dir: Directory the image is written to. It is created when it
            does not exist yet.
        tag: Prefix of the output file name.

    Returns:
        ``pathlib.Path`` of the saved PNG.
    """
    fig, ax = plt.subplots(figsize=(6, 1.5))
    ax.set_axis_off()

    handles = [Line2D([0], [0], color='r', linestyle='--', linewidth=2)]
    labels = ["Thresholds"]

    ax.legend(handles, labels, loc='center', frameon=True, fontsize=14)

    output_path = Path(output_dir) / f"{tag}_legend.png"
    # savefig does not create missing directories, so make sure the target exists.
    output_path.parent.mkdir(parents=True, exist_ok=True)

    try:
        fig.tight_layout(pad=1.0)
        fig.savefig(output_path, dpi=300, bbox_inches='tight')
    finally:
        # Release the figure even when saving fails.
        plt.close(fig)
    print(f"Thresholds legend file saved: {output_path}")

    return output_path

def plot_feat_custom(net, X_tr, y_tr, X_val=None, y_val=None, fname="vis/feat.png", epoch=0):
    """Plot the features produced by ``net.f`` in a 2-D PCA basis and save it.

    PCA is fitted on the training features only. Validation features, when
    both ``X_val`` and ``y_val`` are given, are projected with that same fit
    and drawn with lighter, smaller markers. Each class's training mean is
    drawn as a star, and the weight vector of ``net.h`` is drawn as an arrow
    starting at the origin of the PCA plane.

    Args:
        net: Model exposing ``f`` (called on the inputs to get features) and
            ``h`` (its ``weight`` is flattened into a single vector, so this
            presumably expects a one-output linear head; confirm against the
            model definition).
        X_tr: Training inputs, converted with ``torch.tensor``.
        y_tr: Training labels; one colour is assigned per unique value.
        X_val: Optional validation inputs.
        y_val: Optional validation labels.
        fname: Output image path; parent directories are created if needed.
        epoch: Epoch number shown in the plot title.
    """
    figsize = (4, 2)  # inches; also used to match the axis ranges below
    has_val = X_val is not None and y_val is not None

    # The boolean masks below (``y == cls``) need arrays, not plain sequences.
    y_tr = np.asarray(y_tr)
    if has_val:
        y_val = np.asarray(y_val)

    net.eval()
    device = net.h.weight.device

    with torch.no_grad():
        feats_tr = net.f(torch.tensor(X_tr, device=device)).cpu().numpy()
        feats_val = net.f(torch.tensor(X_val, device=device)).cpu().numpy() if has_val else None

    pca = PCA(n_components=2).fit(feats_tr)
    F2d_tr = pca.transform(feats_tr)  # (N_tr, 2)
    F2d_val = pca.transform(feats_val) if has_val else None  # (N_val, 2)

    # The classifier weight is a direction, not a data point: shifting the
    # features does not change it. It is therefore only projected onto the
    # principal axes and NOT mean-centred the way ``pca.transform`` centres
    # samples; centring would bend the arrow by the projected feature mean.
    w_full = net.h.weight.detach().cpu().numpy().reshape(-1)  # (d,)
    w2d = pca.components_.dot(w_full)                         # (2,)

    classes = np.unique(y_tr)
    num_classes = len(classes)
    base_colors = plt.cm.rainbow(np.linspace(0, 1, num_classes))
    color_map = {cls: base_colors[i] for i, cls in enumerate(classes)}
    light_map = {cls: lighten_color(color_map[cls], amount=0.6) for cls in classes}

    fig, ax = plt.subplots(figsize=figsize, dpi=300)

    for cls in classes:
        mask = (y_tr == cls)
        ax.scatter(F2d_tr[mask,0], F2d_tr[mask,1], color=color_map[cls], s=40, alpha=0.9, edgecolors='none')

    if has_val:
        for cls in classes:
            mask = (y_val == cls)
            ax.scatter(F2d_val[mask,0], F2d_val[mask,1], color=light_map[cls], s=20, alpha=0.6, edgecolors='none')

    for cls in classes:
        center = F2d_tr[y_tr == cls].mean(axis=0)
        ax.scatter(center[0], center[1],marker='*', s=200, color=color_map[cls], edgecolor='k')

    ax.quiver(0, 0, w2d[0], w2d[1], angles='xy', scale_units='xy', scale=1, color='k', width=0.008)

    # Give the y axis a span of x_range * (height / width), centred on the
    # current y limits, so the data ranges follow the figure's proportions.
    xmin, xmax = ax.get_xlim()
    x_range = xmax - xmin

    fig_width, fig_height = figsize
    y_range = x_range * (fig_height / fig_width)

    current_y_min, current_y_max = ax.get_ylim()
    y_center = (current_y_min + current_y_max) / 2
    ax.set_ylim(y_center - y_range / 2, y_center + y_range / 2)

    ax.set_xlabel('PC1', fontsize=11)
    ax.set_ylabel('PC2', fontsize=11)

    ax.set_title(f'Feature Space (Epoch {epoch})', fontsize=12, fontweight='bold')

    ax.grid(True, linestyle='--', alpha=0.3, linewidth=0.5)

    fig.tight_layout(pad=0.5)

    Path(fname).parent.mkdir(exist_ok=True, parents=True)
    fig.savefig(fname, dpi=300, bbox_inches='tight')
    # Close this figure explicitly rather than whichever one is current.
    plt.close(fig)

def plot_z_thresholds(net, loss_layer, X_tr, y_tr, X_val, y_val, fname, device, epoch=0, link="logit"):
    """Plot the model's scalar outputs per class together with the thresholds.

    Each sample is drawn at ``(z, label + 1)`` where ``z`` is the flattened
    output of ``net``. Training samples use the class colour; validation
    samples (drawn only when both ``X_val`` and ``y_val`` are given) use a
    lightened colour and smaller markers. Every threshold of ``loss_layer``
    is drawn as a vertical red dashed line.

    Args:
        net: Model called directly on the inputs.
        loss_layer: Object providing thresholds, either through a ``_b()``
            method (preferred when present) or a ``b`` tensor attribute.
        X_tr: Training inputs, converted with ``torch.tensor``.
        y_tr: Training labels.
        X_val: Validation inputs, or ``None`` to skip validation points.
        y_val: Validation labels, or ``None`` to skip validation points.
        fname: Output image path; parent directories are created if needed.
        device: Device the input tensors are created on.
        epoch: Epoch number shown in the plot title.
        link: Not used by the plot; kept so existing callers keep working.
    """
    has_val = X_val is not None and y_val is not None

    # The boolean masks below (``y == cls``) need arrays, not plain sequences.
    y_tr = np.asarray(y_tr)
    if has_val:
        y_val = np.asarray(y_val)

    net.eval()
    with torch.no_grad():
        zs_tr = net(torch.tensor(X_tr, device=device)).detach().cpu().numpy().flatten()

        # Only run the validation pass when validation data was supplied;
        # torch.tensor(None) would raise otherwise.
        zs_val = None
        if has_val:
            zs_val = net(torch.tensor(X_val, device=device)).detach().cpu().numpy().flatten()

        thresholds = loss_layer._b() if hasattr(loss_layer, '_b') else loss_layer.b
        # detach() is needed because no_grad does not strip requires_grad from
        # an existing parameter, and .numpy() refuses such tensors.
        b_eff = thresholds.detach().cpu().numpy().reshape(-1)

    fig, ax = plt.subplots(figsize=(4, 2), dpi=300)

    classes = np.unique(y_tr)
    base_colors = plt.cm.rainbow(np.linspace(0, 1, len(classes)))
    color_map = {cls: base_colors[i] for i, cls in enumerate(classes)}
    light_map = {cls: lighten_color(color_map[cls], amount=0.6) for cls in classes}

    for cls in classes:
        mask = (y_tr == cls)
        ax.scatter(zs_tr[mask], y_tr[mask] + 1, color=color_map[cls], s=12, alpha=0.9, edgecolors='none')

    if has_val:
        for cls in classes:
            mask = (y_val == cls)
            ax.scatter(zs_val[mask], y_val[mask] + 1, color=light_map[cls], s=6, alpha=0.6, edgecolors='none')

    for t in b_eff:
        ax.axvline(t, color='r', linestyle='--', linewidth=1.5)

    ax.set_xlabel(r'Latent Variable $z=\boldsymbol{w}^{\top} \boldsymbol{h}_\theta(\boldsymbol{x})$', fontsize=11)
    ax.set_ylabel('Class', fontsize=11)
    ax.yaxis.set_major_locator(MaxNLocator(integer=True))
    ax.set_title(f'Latent vs Thresholds (Epoch {epoch})', fontsize=12, fontweight='bold')

    fig.tight_layout(pad=0.5)

    Path(fname).parent.mkdir(exist_ok=True, parents=True)
    fig.savefig(fname, dpi=300, bbox_inches='tight')
    # Close this figure explicitly rather than whichever one is current.
    plt.close(fig)

def should_visualize(epoch):
    """Tell whether plots should be generated for ``epoch``.

    Schedule: every epoch up to and including 100, then every 100th epoch up
    to and including 1000, then every 1000th epoch.
    """
    if epoch <= 100:
        return True
    # Past the dense phase only multiples of the current stride qualify.
    stride = 100 if epoch <= 1000 else 1000
    return epoch % stride == 0

def train_and_visualize(Xtr, ytr, Xte, yte, loss_layer, tag, device, epochs=5000, output_dir="vis", link="logit", reset_seed=True):
    """Train a ``ResMLP`` against ``loss_layer`` and save plots along the way.

    A feature-space plot and a latent/threshold plot are written before
    training (epoch 0) and after every epoch selected by
    ``should_visualize``. Two legend images are written once up front.

    Args:
        Xtr, ytr: Training inputs and labels.
        Xte, yte: Held-out inputs and labels, drawn as the lighter points.
        loss_layer: Callable as ``loss_layer(net(x), y)``; the first element
            of what it returns is back-propagated. If it has a ``delta``
            attribute, that is optimised too (without weight decay).
        tag: Run identifier used in output names. The text before the first
            underscore is taken as the dataset name.
        device: Device the network and loss layer are moved to.
        epochs: Number of training epochs.
        output_dir: Directory under which the per-run plot folders live.
        link: Link function name; selects the learning rate and is forwarded
            to the threshold plot.
        reset_seed: When true, all seeds are reset before the network is built.

    Returns:
        ``(net, loss_layer)``, both moved to the CPU.
    """
    # IMPORTANT: re-seed before the network is created
    if reset_seed:
        print(f"Resetting random seed to {SEED} before creating network...")
        set_all_seeds(SEED)

    start_time = time.time()

    dataset = tag.split('_')[0]

    # Only the probit link has dataset-specific learning rates; everything
    # else trains with 1e-2.
    probit_lrs = {
        "LEV": 5e-3,
        "SWD": 5e-3,
        "car": 5e-3,
        "ERA": 1e-3,
        "winequality-red": 1e-3,
    }
    lr = probit_lrs.get(dataset, 1e-2) if link == "probit" else 1e-2

    print(f"Using learning rate: {lr} (dataset: {dataset}, link function: {link})")

    train_set = torch.utils.data.TensorDataset(torch.tensor(Xtr), torch.tensor(ytr))
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=2048, shuffle=True)

    net = ResMLP(Xtr.shape[1], width=128, depth=4).to(device)
    loss_layer = loss_layer.to(device)

    # Network weights are decayed; the loss layer's ``delta`` is not.
    groups = [{'params': net.parameters(), 'weight_decay': 5e-3}]
    if hasattr(loss_layer, 'delta'):
        groups.append({'params': [loss_layer.delta], 'weight_decay': 0.0})

    optimizer = torch.optim.Adam(groups, lr=lr)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[200, 800, 3000], gamma=0.1)

    out_root = Path(output_dir)
    feat_vis_dir = out_root / f"{tag}_feat_epochs"
    z_vis_dir = out_root / f"{tag}_z_epochs"
    for folder in (feat_vis_dir, z_vis_dir):
        folder.mkdir(parents=True, exist_ok=True)

    num_classes = len(np.unique(ytr))
    create_legend_file(dataset, tag, output_dir, num_classes)
    create_thresholds_legend_file(output_dir, f"{tag}_thresholds")

    # Snapshot of the untrained network.
    print(f"Generating epoch 0 visualization → {feat_vis_dir}/epoch_0000.png")
    net.eval()
    plot_feat_custom(net, Xtr, ytr, Xte, yte, feat_vis_dir / "epoch_0000.png", epoch=0)
    plot_z_thresholds(net, loss_layer, Xtr, ytr, Xte, yte, z_vis_dir / "epoch_0000.png", device, epoch=0, link=link)

    for epoch in range(1, epochs+1):
        net.train()
        for batch_x, batch_y in train_loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            loss, *_ = loss_layer(net(batch_x), batch_y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        scheduler.step()

        if not should_visualize(epoch):
            continue

        net.eval()
        epoch_str = f"{epoch:04d}"
        feat_vis_path = feat_vis_dir / f"epoch_{epoch_str}.png"
        z_vis_path = z_vis_dir / f"epoch_{epoch_str}.png"
        plot_feat_custom(net, Xtr, ytr, Xte, yte, feat_vis_path, epoch=epoch)
        plot_z_thresholds(net, loss_layer, Xtr, ytr, Xte, yte, z_vis_path, device, epoch=epoch, link=link)

        elapsed = time.time() - start_time
        current_lr = scheduler.get_last_lr()[0]
        print(f"[{tag}] Epoch {epoch}/{epochs} | "
              f"lr={current_lr:.1e} | "
              f"time={elapsed:.1f}s | "
              f"Generated visualizations → {feat_vis_path}, {z_vis_path}")

    print(f"Complete! All feature space visualizations saved to: {feat_vis_dir}")
    print(f"Complete! All latent variable visualizations saved to: {z_vis_dir}")

    return net.cpu(), loss_layer.cpu()

def process_dataset(dataset, fold=29, epochs=5000, root_dir="datasets-orreview/ordinal-regression", link="logit"):
    """Load one dataset fold and train/visualise both threshold variants.

    The fixed-threshold model (``CLMFixed``) is trained first, then the
    learnable-threshold model (``CLMLearnable``); both write their plots
    under ``vis/<dataset>``.

    Args:
        dataset: Dataset name.
        fold: Fold index handed to ``load_and_preprocess_data``.
        epochs: Number of training epochs per model.
        root_dir: Root directory handed to ``load_and_preprocess_data``.
        link: Link function name passed to both loss layers.

    Returns:
        ``True`` when both runs finish, ``False`` if any exception was raised
        (the traceback is printed).
    """
    print(f"\n{'='*70}\nProcessing dataset: {dataset} (fold {fold}, link: {link})\n{'='*70}")

    vis_root = Path("vis") / dataset
    vis_root.mkdir(parents=True, exist_ok=True)

    try:
        print(f"Loading and preprocessing dataset: {dataset}, fold: {fold}")
        Xtr, y_tr, Xte, y_te = load_and_preprocess_data(Path(root_dir), dataset, fold)
        print(f"Data shapes - Xtr: {Xtr.shape}, Xte: {Xte.shape}")
        print(f"Classes - train: {sorted(set(y_tr))}, test: {sorted(set(y_te))}")

        K = len(np.unique(y_tr))
        print(f"Number of classes: {K}")

        # Settings shared by both runs; each run re-seeds before building
        # its network.
        shared_kwargs = dict(
            device=DEVICE,
            epochs=epochs,
            output_dir=str(vis_root),
            link=link,
            reset_seed=True,
        )

        print(f"\n{'-'*50}\nTraining Fixed threshold model (link: {link})\n{'-'*50}")
        fixed_net, fixed_loss = train_and_visualize(
            Xtr, y_tr, Xte, y_te,
            CLMFixed(K, link=link),
            tag=f"{dataset}_{link}_fix_f{fold}",
            **shared_kwargs,
        )

        print(f"\n{'-'*50}\nTraining Learnable threshold model (link: {link})\n{'-'*50}")
        learn_net, learn_loss = train_and_visualize(
            Xtr, y_tr, Xte, y_te,
            CLMLearnable(K, link=link),
            tag=f"{dataset}_{link}_learn_f{fold}",
            **shared_kwargs,
        )

        print(f"\nDataset {dataset} processing completed!")
        return True

    except Exception as e:
        # Report the failure and let the caller move on to the next dataset.
        print(f"Error processing dataset {dataset}: {e}")
        import traceback
        traceback.print_exc()
        return False
    
def get_available_datasets(root_dir):
    """Return the known benchmark datasets that exist under ``root_dir``.

    The candidates are the fixed list ERA, LEV, SWD, car and
    winequality-red. A candidate is kept when ``root_dir/<name>`` exists,
    which is the same check ``main`` applies to an explicitly requested
    dataset.

    Args:
        root_dir: Directory that holds one entry per dataset.

    Returns:
        Names of the datasets that were found, in the fixed order above.
        Empty when none are present.
    """
    known_datasets = ("ERA", "LEV", "SWD", "car", "winequality-red")
    root = Path(root_dir)
    return [name for name in known_datasets if (root / name).exists()]

def main(fold=29, epochs=5000, root_dir="datasets-orreview/ordinal-regression", specific_dataset=None, link="logit"):
    """Process one dataset or all available ones and print a summary.

    Args:
        fold: Fold index forwarded to ``process_dataset``.
        epochs: Number of training epochs per model.
        root_dir: Directory that holds the datasets.
        specific_dataset: When given, only this dataset is processed; it must
            exist under ``root_dir``. Otherwise the list comes from
            ``get_available_datasets``.
        link: Link function name forwarded to ``process_dataset``.
    """
    print(f"Using device: {DEVICE} (CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')})")
    print(f"Using link function: {link}")
    print(f"Random seed: {SEED}")

    # Work out what to process, bailing out early when there is nothing valid.
    if specific_dataset:
        if not (Path(root_dir) / specific_dataset).exists():
            print(f"Error: Dataset '{specific_dataset}' does not exist in {root_dir}!")
            return
        datasets = [specific_dataset]
    else:
        datasets = get_available_datasets(root_dir)
        if not datasets:
            print(f"Error: No valid datasets found in {root_dir}!")
            return

    print(f"Found {len(datasets)} datasets: {', '.join(datasets)}")

    start_time = time.time()

    successful, failed = [], []

    for i, dataset in enumerate(datasets, 1):
        print(f"\n[{i}/{len(datasets)}] Starting to process dataset: {dataset}")
        finished = process_dataset(dataset, fold, epochs, root_dir, link)
        (successful if finished else failed).append(dataset)

    # Break the wall-clock time into whole hours, minutes and seconds.
    elapsed = time.time() - start_time
    whole_hours, leftover = divmod(elapsed, 3600)
    whole_minutes, leftover_seconds = divmod(leftover, 60)
    hours = int(whole_hours)
    minutes = int(whole_minutes)
    seconds = int(leftover_seconds)

    print("\n" + "="*70)
    print(f"All datasets processed! Total time: {hours}h {minutes}m {seconds}s")
    print(f"Successful: {len(successful)}/{len(datasets)} - {', '.join(successful)}")
    if failed:
        print(f"Failed: {len(failed)}/{len(datasets)} - {', '.join(failed)}")
    print("="*70)

if __name__ == "__main__":
    # Fixed experiment settings; only the dataset and the link function are
    # selectable from the command line.
    fold = 29
    epochs = 5000
    root_dir = "../datasets-orreview/ordinal-regression"

    parser = argparse.ArgumentParser(description="Run ordinal regression experiments with fixed random seed and visualize feature spaces at each epoch")
    parser.add_argument("--dataset", help="Specific dataset name, process all datasets if not specified")
    parser.add_argument(
        "--link",
        type=str,
        default="logit",
        choices=["logit", "probit"],
        help="Link function type",
    )
    args = parser.parse_args()

    main(
        fold=fold,
        epochs=epochs,
        root_dir=root_dir,
        specific_dataset=args.dataset,
        link=args.link,
    )