import numpy as np
import torch
from torch import nn, optim
from sklearn.metrics import precision_score
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
import pandas as pd
import copy

from tandem import GatingNet, Tandem


# --------- Model code (GatingNet, Tandem and their components) comes from the `tandem` module imported above; paste it here instead to run this file standalone. ----------

# ----- EMBEDDING EXTRACTION -----
def extract_embeddings(model, x):
    """Return the latent embedding(s) ``model`` produces for input ``x``.

    Two model layouts are supported:

    * models with both ``nn_encoder`` and ``osdt_encoder`` attributes: the full
      ``model(x)`` forward pass is run and the last two of its four outputs are
      returned as ``(z_nn, z_osdt)``;
    * models with an ``encoder`` attribute: ``(model.encoder(x), None)``.

    The forward pass runs in eval mode under ``torch.no_grad()``, so the
    returned tensors carry no autograd history. The train/eval flag of every
    submodule is restored afterwards, so calling this never silently switches
    a model that is being trained into eval mode.

    Raises:
        ValueError: if ``model`` matches neither layout.
    """
    dual_encoder = hasattr(model, 'nn_encoder') and hasattr(model, 'osdt_encoder')
    if not dual_encoder and not hasattr(model, 'encoder'):
        raise ValueError("Unknown model structure for embedding extraction.")

    # Snapshot per-module modes (not just the root's) so submodules that the
    # caller deliberately kept in a different mode are restored exactly.
    saved_modes = [(module, module.training) for module in model.modules()]
    model.eval()
    try:
        with torch.no_grad():
            if dual_encoder:
                _, _, z_nn, z_osdt = model(x)
                return z_nn, z_osdt
            return model.encoder(x), None
    finally:
        for module, mode in saved_modes:
            module.training = mode

# ----- FINETUNE WITH LAYER-SPECIFIC LR AND FREEZING -----
def finetune_with_classifier(
    model, X_train, y_train, X_val, y_val,
    finetune_epochs=50, learning_rate=0.01, weight_decay=1e-4, device="cpu"
):
    """Finetune ``model``'s encoder(s) jointly with a fresh linear classifier.

    Training is full-batch with Adam. During the first ``finetune_epochs // 2``
    epochs the encoder parameters are frozen and the backbone is kept in eval
    mode, so only the classifier learns. Afterwards the encoder is unfrozen and
    trained at half the classifier's learning rate. Early stopping (patience
    10) on validation macro precision may only fire once the encoder is
    unfrozen, and its counter restarts at the unfreeze epoch. The best-scoring
    weights are restored before returning.

    Only the parameters of ``nn_encoder`` (or ``encoder``) and
    ``osdt_encoder`` are passed to the optimizer. ``model`` is wrapped, not
    copied, so it is updated in place; pass a copy to keep the original.

    Args:
        model: A model with ``nn_encoder`` and ``osdt_encoder`` whose forward
            returns a 4-tuple ending in ``(z_nn, z_osdt)``, or a model with an
            ``encoder`` attribute.
        X_train, X_val: Input tensors on ``device``.
        y_train, y_val: Array-likes of integer class labels (>= 0).
        finetune_epochs: Maximum number of full-batch epochs.
        learning_rate: Classifier learning rate (the encoder uses half).
        weight_decay: Adam weight decay for both parameter groups.
        device: Device for the classifier and label tensors.

    Returns:
        ``(train_precision, best_val_precision, joint_model)``: macro
        precisions on the training and validation sets, and the module that
        maps inputs to class logits.

    Raises:
        ValueError: if ``model`` has neither supported encoder layout.
    """
    # Probe the embedding widths once (no grad) to size the classifier.
    z_nn, z_osdt = extract_embeddings(model, X_train)
    latent_dim = z_nn.shape[1] + (z_osdt.shape[1] if z_osdt is not None else 0)
    # CrossEntropyLoss needs every label < num_classes; sizing from the largest
    # label (not the count of distinct labels) stays valid for gapped labels.
    num_classes = int(max(np.max(y_train), np.max(y_val))) + 1

    classifier = nn.Linear(latent_dim, num_classes).to(device)

    class JointModel(nn.Module):
        """Encoder(s) followed by the linear head; returns class logits."""

        def __init__(self, model, classifier):
            super().__init__()
            self.model = model
            self.classifier = classifier

        def forward(self, x):
            # Same dispatch as extract_embeddings, but deliberately WITHOUT
            # torch.no_grad() / forced eval(): gradients must reach the encoder
            # once it is unfrozen.
            if hasattr(self.model, 'nn_encoder') and hasattr(self.model, 'osdt_encoder'):
                _, _, z_a, z_b = self.model(x)
            elif hasattr(self.model, 'encoder'):
                z_a, z_b = self.model.encoder(x), None
            else:
                raise ValueError("Unknown model structure for embedding extraction.")
            z = torch.cat([z_a, z_b], dim=1) if z_b is not None else z_a
            return self.classifier(z)

    joint_model = JointModel(model, classifier).to(device)

    # Encoder parameters (for freezing and a separate learning rate).
    encoder_params = []
    if hasattr(model, 'nn_encoder'):
        encoder_params += list(model.nn_encoder.parameters())
    elif hasattr(model, 'encoder'):
        encoder_params += list(model.encoder.parameters())
    if hasattr(model, 'osdt_encoder'):
        encoder_params += list(model.osdt_encoder.parameters())

    classifier_params = list(classifier.parameters())

    # Adam with per-group LR: the pretrained encoder moves more slowly.
    optimizer = optim.Adam([
        {'params': encoder_params, 'lr': 0.5 * learning_rate},
        {'params': classifier_params, 'lr': learning_rate},
    ], weight_decay=weight_decay)

    loss_fn = nn.CrossEntropyLoss()
    y_train_t = torch.tensor(y_train, dtype=torch.long, device=device)

    best_val = 0.0
    best_state = None
    patience, patience_limit = 0, 10
    unfreeze_epoch = finetune_epochs // 2
    # Remember requires_grad so the freezing below does not leak to the caller.
    saved_requires_grad = [p.requires_grad for p in encoder_params]

    try:
        for epoch in range(finetune_epochs):
            frozen = epoch < unfreeze_epoch
            if epoch == unfreeze_epoch:
                # Fresh patience budget for the phase that trains the encoder.
                patience = 0
            for param in encoder_params:
                param.requires_grad_(not frozen)

            joint_model.train()
            if frozen:
                # Frozen backbone: keep dropout/normalisation layers in
                # inference mode so its embeddings stay fixed.
                model.eval()
            # Clear grads on *all* parameters, not only the optimizer's, so
            # submodules outside the encoder groups don't accumulate grads.
            joint_model.zero_grad(set_to_none=True)

            loss = loss_fn(joint_model(X_train), y_train_t)
            loss.backward()
            optimizer.step()

            joint_model.eval()
            with torch.no_grad():
                preds = torch.argmax(joint_model(X_val), 1).cpu().numpy()
            val_prec = precision_score(y_val, preds, average="macro", zero_division=0)
            if val_prec > best_val:
                best_val = val_prec
                best_state = copy.deepcopy(joint_model.state_dict())
                patience = 0
            else:
                patience += 1
                # Never stop before the encoder has had a chance to train.
                if not frozen and patience >= patience_limit:
                    break
    finally:
        for param, flag in zip(encoder_params, saved_requires_grad):
            param.requires_grad_(flag)

    if best_state is not None:
        joint_model.load_state_dict(best_state)
    joint_model.eval()
    with torch.no_grad():
        preds = torch.argmax(joint_model(X_train), 1).cpu().numpy()
    train_prec = precision_score(y_train, preds, average="macro", zero_division=0)
    return train_prec, best_val, joint_model

# ----- DOWNSTREAM EMBEDDING EVALUATION -----
def evaluate_embeddings(
    data, labels, model, sample_sizes, n_trials,
    classifier_fn=None, finetune_encoder=False,
    finetune_epochs=50, learning_rate=0.01, device="cpu"
):
    """Score ``model``'s embeddings on a downstream classification task.

    For every entry of ``sample_sizes`` the rows are split ``n_trials`` times
    (stratified, ``random_state=trial``) into a training subset of that size
    and a validation set holding all remaining rows. Each split is scored by
    macro precision on the validation set, using one of two modes:

    * ``finetune_encoder=False``: fit ``classifier_fn()`` on frozen embeddings;
    * ``finetune_encoder=True``: finetune a fresh copy of ``model`` together
      with a linear head via ``finetune_with_classifier``.

    Args:
        data: 2-D feature matrix (DataFrame or array-like), one row per sample.
        labels: Integer class label per row (Series or array-like).
        model: Model accepted by ``extract_embeddings``.
        sample_sizes: Training-set sizes, passed as ``train_size`` to
            ``train_test_split``.
        n_trials: Number of random splits averaged per sample size.
        classifier_fn: Zero-argument factory returning a fresh estimator with
            ``fit``/``predict``; required when ``finetune_encoder`` is False.
        finetune_encoder: Finetune the encoder instead of probing embeddings.
        finetune_epochs, learning_rate: Forwarded to ``finetune_with_classifier``.
        device: Torch device the features are placed on.

    Returns:
        List with the mean validation macro precision for each sample size.

    Raises:
        ValueError: if ``classifier_fn`` is None while ``finetune_encoder`` is False.
    """
    if not finetune_encoder and classifier_fn is None:
        raise ValueError("classifier_fn must be provided if finetune_encoder=False.")

    # Plain NumPy labels, indexed positionally below (works for Series input too).
    labels = np.asarray(labels).astype(int)
    X_all = torch.tensor(np.asarray(data), dtype=torch.float32, device=device)
    row_ids = np.arange(len(X_all))
    results_per_samples = []

    for sample_size in sample_sizes:
        trial_scores = []
        for trial in range(n_trials):
            # Stratified split for reproducibility and fair label distribution.
            idx_train, idx_val = train_test_split(
                row_ids, train_size=sample_size,
                stratify=labels, random_state=trial
            )
            X_train, X_val = X_all[idx_train], X_all[idx_val]
            y_train, y_val = labels[idx_train], labels[idx_val]

            if finetune_encoder:
                # finetune_with_classifier wraps (does not copy) the model it is
                # given, so hand it a fresh copy: otherwise any encoder updates
                # from one trial would carry into the next, whose validation
                # rows may have been training rows before.
                _, val_prec, _ = finetune_with_classifier(
                    copy.deepcopy(model), X_train, y_train, X_val, y_val,
                    finetune_epochs=finetune_epochs, learning_rate=learning_rate,
                    device=device,
                )
                trial_scores.append(val_prec)
                continue

            z_nn, z_osdt = extract_embeddings(model, X_train)
            z_nn_val, z_osdt_val = extract_embeddings(model, X_val)
            if z_osdt is not None and z_osdt_val is not None:
                emb_train = torch.cat([z_nn, z_osdt], dim=1).cpu().numpy()
                emb_val = torch.cat([z_nn_val, z_osdt_val], dim=1).cpu().numpy()
            else:
                emb_train = z_nn.cpu().numpy()
                emb_val = z_nn_val.cpu().numpy()
            clf = classifier_fn()
            clf.fit(emb_train, y_train)
            y_pred = clf.predict(emb_val)
            trial_scores.append(
                precision_score(y_val, y_pred, average="macro", zero_division=0)
            )
        results_per_samples.append(np.mean(trial_scores))
    return results_per_samples

# ========== RUN ON TOY DATASET ==========
if __name__ == "__main__":
    from sklearn.linear_model import LogisticRegression

    # Seed the global torch/NumPy RNGs so any random initialisation inside
    # GatingNet/Tandem is reproducible across runs (the dataset and the
    # train/validation splits are already seeded explicitly).
    np.random.seed(0)
    torch.manual_seed(0)

    # -- Toy dataset: 2000 samples, 48 features, 3 classes --
    X, y = make_classification(
        n_samples=2000, n_features=48, n_informative=35, n_redundant=10,
        n_classes=3, n_clusters_per_class=1, random_state=42
    )
    data = pd.DataFrame(X)
    labels = pd.Series(y)

    # -- Model: TANDEM --
    input_dim = data.shape[1]  # derived from the data so the two cannot drift apart
    osdt_depth = 7
    final_latent = 2 ** osdt_depth
    # The last hidden layer is final_latent wide; presumably so the NN
    # embedding matches the OSDT embedding width — TODO confirm.
    hidden_layers = [128, 256, 128, final_latent]
    num_trees = 2

    gating_net = GatingNet(input_dim, hidden_dim=128)
    tandem = Tandem(
        input_dim=input_dim,
        hidden_layers=hidden_layers,
        num_trees=num_trees,
        osdt_depth=osdt_depth,
        gating_net=gating_net
    )

    # -- Downstream evaluation with Logistic Regression on frozen embeddings --
    probe_sizes = [100, 300, 1000]
    print("== Downstream macro-precision (no finetune):")
    results = evaluate_embeddings(
        data, labels, tandem, sample_sizes=probe_sizes, n_trials=3,
        classifier_fn=lambda: LogisticRegression(max_iter=200, solver='liblinear'),
        finetune_encoder=False
    )
    for size, score in zip(probe_sizes, results):
        print(f"  n={size}: {score:.4f}")

    # -- Joint finetuning of encoder + linear head --
    finetune_sizes = [300]
    print("\n== Finetune macro-precision (encoder frozen first half, per-layer LR):")
    results_ft = evaluate_embeddings(
        data, labels, tandem, sample_sizes=finetune_sizes, n_trials=2,
        classifier_fn=None, finetune_encoder=True, finetune_epochs=20, learning_rate=0.01
    )
    for size, score in zip(finetune_sizes, results_ft):
        print(f"  n={size}: {score:.4f}")
