import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch import optim
from sklearn.linear_model import LogisticRegression

# ==== Your models ====
from tandem import GatingNet, Tandem, SSAE

# ==== Utility ====
def load_openml_processed(file_path, label_column='label'):
    """Read a processed OpenML CSV and split it into features and target.

    Args:
        file_path: path (or file-like object) accepted by ``pd.read_csv``.
        label_column: name of the column holding the target values.

    Returns:
        ``(X, y)``: a DataFrame with every column except ``label_column``,
        and the ``label_column`` Series.
    """
    frame = pd.read_csv(file_path)
    target = frame[label_column]
    features = frame.drop(columns=[label_column])
    return features, target

def cosine_loss(z1, z2):
    """Mean cosine distance (1 - cosine similarity) between paired rows.

    Row ``i`` of ``z1`` is compared with row ``i`` of ``z2`` (norms taken
    along ``dim=1``). A small epsilon (1e-8) is added to each row norm so
    all-zero rows do not divide by zero; such rows get similarity 0.
    """
    eps = 1e-8

    def _unit_rows(t):
        # Scale every row to (approximately) unit L2 norm.
        return t / (t.norm(dim=1, keepdim=True) + eps)

    similarity = (_unit_rows(z1) * _unit_rows(z2)).sum(dim=1)
    return (1 - similarity).mean()

def generic_train(
    model, X_train, loss_fns, epochs=10, lr=1e-3, batch_size=64, device="cpu",
    loss_weights=None, verbose=True, **forward_kwargs
):
    """Train ``model`` with Adam on shuffled mini-batches of ``X_train``.

    The model is called as ``model(x, **forward_kwargs)`` on every batch ``x``.
    Two output conventions are supported:

    * a tuple ``(recon_nn, recon_osdt, z_nn, z_osdt)`` (dual encoder). The loss
      terms are, in order, ``loss_fns[0](recon_nn, x)``,
      ``loss_fns[1](recon_osdt, x)`` and ``loss_fns[2](z_nn, z_osdt)``;
    * any other output is treated as a reconstruction of ``x``. It gives the
      single term ``loss_fns[0](outputs, x)``.

    Args:
        model: ``nn.Module`` to train. It is moved to ``device``.
        X_train: 2-D feature table (DataFrame, ndarray or nested list). It is
            converted to float32.
        loss_fns: a loss callable, or a list/tuple with one callable per term.
            A term with no entry (or a ``None`` entry) falls back to MSE for
            reconstructions and ``cosine_loss`` for the embedding-agreement
            term.
        epochs: number of passes over the data.
        lr: Adam learning rate.
        batch_size: rows per mini-batch. The last batch may be smaller.
        device: torch device for the model and data.
        loss_weights: per-term weights. Defaults to 1.0 for each entry of
            ``loss_fns``. Only the first ``len(loss_weights)`` terms contribute.
        verbose: if true, print the mean loss of the final epoch.
        **forward_kwargs: extra keyword arguments forwarded to ``model``.

    Returns:
        ``(model, final_loss)``. ``final_loss`` is the mean batch loss of the
        last epoch, ``None`` if ``epochs == 0``, and NaN if ``X_train`` has
        no rows.

    Raises:
        ValueError: if there is no loss weight (e.g. ``loss_fns=[]``), since no
            loss term could be optimised.
    """
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    # np.asarray accepts DataFrames as well as plain arrays/lists, and the
    # explicit dtype also handles mixed bool/numeric frames.
    X_train_t = torch.tensor(np.asarray(X_train, dtype=np.float32), device=device)
    n = X_train_t.shape[0]
    mse = nn.MSELoss()
    if not isinstance(loss_fns, (list, tuple)):
        loss_fns = [loss_fns]
    loss_fns = list(loss_fns)
    if loss_weights is None:
        loss_weights = [1.0] * len(loss_fns)
    if len(loss_weights) == 0:
        raise ValueError("loss_weights must contain at least one weight")
    final_loss = None

    for epoch in range(epochs):
        model.train()
        idx = np.random.permutation(n)
        losses = []
        for i in range(0, n, batch_size):
            batch_idx = idx[i:i + batch_size]
            x = X_train_t[batch_idx]
            outputs = model(x, **forward_kwargs)
            # Each term is (prediction, target, default loss used when the
            # caller supplied no loss for that position).
            if isinstance(outputs, tuple):
                recon_nn, recon_osdt, z_nn, z_osdt = outputs
                terms = [
                    (recon_nn, x, mse),
                    (recon_osdt, x, mse),
                    (z_nn, z_osdt, cosine_loss),
                ]
            else:
                terms = [(outputs, x, mse)]

            # zip truncates to the shorter of weights/terms, as before.
            total_loss = 0
            for k, (weight, (pred, target, default_fn)) in enumerate(zip(loss_weights, terms)):
                has_user_fn = k < len(loss_fns) and loss_fns[k] is not None
                fn = loss_fns[k] if has_user_fn else default_fn
                total_loss = total_loss + weight * fn(pred, target)

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()
            losses.append(total_loss.item())
        # Avoid np.mean([]) (RuntimeWarning) when there are no rows.
        final_loss = float(np.mean(losses)) if losses else float("nan")
        if verbose and (epoch == epochs - 1):
            print(f"Final Train Loss (epoch {epoch+1}): {final_loss:.4f}")
    return model, final_loss

def ensure_labels_int(labels):
    """Coerce class labels to integers.

    A pandas Series is cast with ``astype(int)`` and given a fresh 0..n-1
    index (the original index is discarded). Anything else goes through
    ``np.asarray`` and is returned as an integer ndarray.
    """
    if not isinstance(labels, pd.Series):
        return np.asarray(labels).astype(int)
    as_int = labels.astype(int)
    return as_int.reset_index(drop=True)

def main(
    dataset='default-of-credit-card-clients_categorical',
    data_root='openml_datasets',
    osdt_depth=7,
    num_trees=2,
    epochs=30,
    device=None,
):
    """Train a TANDEM model on one processed OpenML dataset and report macro-precision.

    Reads ``train_data.csv`` and ``test_df.csv`` from ``<data_root>/<dataset>``
    and trains the model on the training features with ``generic_train``. It
    then calls ``evaluate_embeddings`` on the test split twice:

    * with a frozen encoder plus logistic regression (5 trials);
    * with encoder fine-tuning (2 trials, 20 epochs, lr 0.01).

    All arguments default to the values previously hard-coded here, so
    ``main()`` behaves as before.

    Args:
        dataset: dataset folder name under ``data_root``.
        data_root: directory containing the per-dataset folders.
        osdt_depth: OSDT depth. The last hidden layer width is ``2 ** osdt_depth``.
        num_trees: number of trees passed to ``Tandem``.
        epochs: pre-training epochs.
        device: torch device string. ``None`` picks CUDA when available, else CPU.
    """
    # ==== CONFIG ====
    hidden_layers = [128, 256, 128, 2 ** osdt_depth]
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # ==== LOAD DATA ====
    print(f"Loading data for dataset: {dataset}")
    base_dir = os.path.join(data_root, dataset)
    train_file = os.path.join(base_dir, "train_data.csv")
    test_file = os.path.join(base_dir, "test_df.csv")
    label_col = 'label'
    X_train, y_train = load_openml_processed(train_file, label_col)
    X_test, y_test = load_openml_processed(test_file, label_col)
    input_dim = X_train.shape[1]

    # ==== MODEL: Example with TANDEM (dual encoder) ====
    gating_net = GatingNet(input_dim, hidden_dim=128)
    model = Tandem(
        input_dim=input_dim,
        hidden_layers=hidden_layers,
        num_trees=num_trees,
        osdt_depth=osdt_depth,
        gating_net=gating_net
    )

    # ---- For single encoder model (uncomment to try) ----
    # model = SSAE(input_dim=input_dim, hidden_layers=hidden_layers, osdt_depth=osdt_depth)

    print(f"\nTraining model {model.__class__.__name__} ...")
    # Loss terms must line up with what generic_train derives from the model
    # output: three terms (rec_nn, rec_osdt, cosine) for the dual encoder, one
    # reconstruction term otherwise.
    if isinstance(model, Tandem):
        loss_fns = [nn.MSELoss(), nn.MSELoss(), cosine_loss]
        loss_weights = [1.0, 1.0, 1.0]   # [rec_nn, rec_osdt, cosine]
    else:
        loss_fns = nn.MSELoss()
        loss_weights = [1.0]
    model, final_train_loss = generic_train(
        model,
        X_train,
        loss_fns=loss_fns,
        loss_weights=loss_weights,
        epochs=epochs,
        device=device,
        verbose=True
    )

    print(f"Final Training Loss: {final_train_loss:.4f}")

    # ==== EVALUATE: Macro-Precision (no finetune) ====
    print("\nEvaluating downstream classifier (no finetune)...")
    from evaluate import evaluate_embeddings
    results_no_ft = evaluate_embeddings(
        X_test, ensure_labels_int(y_test), model, sample_sizes=[400], n_trials=5,
        classifier_fn=lambda: LogisticRegression(max_iter=200, solver='liblinear'),
        finetune_encoder=False, device=device
    )
    print(f"{model.__class__.__name__} | {dataset} | Macro-precision (no finetune): {results_no_ft}")

    # ==== EVALUATE: Macro-Precision (finetune) ====
    # Runs after the frozen evaluation. NOTE(review): confirm whether
    # evaluate_embeddings fine-tunes `model` in place or works on a copy.
    print("\nEvaluating downstream classifier (with finetune)...")
    results_ft = evaluate_embeddings(
        X_test, ensure_labels_int(y_test), model, sample_sizes=[400], n_trials=2,
        classifier_fn=None, finetune_encoder=True, finetune_epochs=20, learning_rate=0.01, device=device
    )
    print(f"{model.__class__.__name__} | {dataset} | Macro-precision (finetune): {results_ft}")

    # ==== PRINT SUMMARY ====
    # Assumes evaluate_embeddings returns an indexable whose first element is
    # the mean macro-precision for the first sample size — TODO confirm.
    print("\n==== FINAL RESULTS ====")
    print(f"{model.__class__.__name__} (no finetune): Mean Macro-precision: {results_no_ft[0]:.4f}")
    print(f"{model.__class__.__name__} (finetune):    Mean Macro-precision: {results_ft[0]:.4f}")

if __name__ == "__main__":
    # Run the experiment only when this file is executed directly, not on import.
    main()
