import os
import csv
import torch
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from datetime import datetime
from torch_geometric.loader import DataLoader
from sklearn.model_selection import StratifiedKFold

from CTQWformer_2026AAAI_submission.utils.Datasetloader import load_and_preprocess_tu_dataset
from CTQWformer import CTQWformer
from QWEncoder import CTQWEncoder, T_Ham


def preprocess_all_qw(dataset, hamiltonian_model, ctqw_encoder, device):
    """
    Precompute CTQW evolution tensors (Q) for all graphs in the dataset.

    Each graph is moved to ``device``, tagged with its position in
    ``dataset`` (``data.idx``), passed through ``hamiltonian_model`` and
    ``ctqw_encoder``, and then moved back to the CPU.

    The whole pass runs under ``torch.no_grad()``: the results are cached
    inputs, so recording an autograd graph for every graph in the dataset
    would only pin memory for the lifetime of the returned list.

    Args:
        dataset: Iterable of graph objects exposing ``edge_index``,
            ``num_nodes``, ``x``, ``to(device)`` and ``cpu()``.
        hamiltonian_model: Callable invoked with the keyword arguments
            ``edge_index``, ``num_nodes`` and ``x``; returns the Hamiltonian.
        ctqw_encoder: Callable mapping a Hamiltonian to the tensor Q.
        device: Device on which the two models are evaluated.

    Returns:
        A pair ``(new_data_list, qw_probs_list)`` of equal length and in
        dataset order: the graphs (on CPU, with ``idx`` set) and the matching
        detached CPU tensors Q.
    """
    qw_probs_list = []
    new_data_list = []

    with torch.no_grad():
        for idx, data in enumerate(dataset):
            data = data.to(device)
            data.idx = idx

            H = hamiltonian_model(
                edge_index=data.edge_index,
                num_nodes=data.num_nodes,
                x=data.x,  # Use node features as input to the Hamiltonian
            )

            Q = ctqw_encoder(H)  # presumably [T, N, N] -- confirm against the encoder
            # detach() also covers an encoder that returns a parameter or
            # buffer directly, which no_grad() alone would not strip.
            qw_probs_list.append(Q.detach().cpu())
            new_data_list.append(data.cpu())

    return new_data_list, qw_probs_list


def generate_log_path(dataset_name):
    """
    Build the path of a fresh, timestamped CSV results file.

    The ``results`` directory (relative to the current working directory) is
    created if it does not exist yet; the CSV file itself is not created.
    The file name has the form ``<dataset_name>_<YYYYmmdd_HHMMSS>.csv``,
    using the local time at the moment of the call.
    """
    out_dir = "results"
    os.makedirs(out_dir, exist_ok=True)

    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    file_name = f"{dataset_name}_{stamp}.csv"
    return os.path.join(out_dir, file_name)


def _train_one_epoch(model, loader, optimizer, device):
    """Run one optimisation pass over ``loader``; return the mean per-batch loss."""
    model.train()
    total_loss = 0.0

    for data in loader:
        data = data.to(device)
        # The walk tensor is a fixed, precomputed input: detach it so that
        # backward() never reaches into whatever produced it.
        qw_probs = data.qw_probs.to(device).detach()

        optimizer.zero_grad()
        out = model(data, qw_probs.unsqueeze(0))
        loss = F.cross_entropy(out, data.y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    return total_loss / len(loader)


def _evaluate_accuracy(model, loader, device):
    """Return the fraction of labels in ``loader`` that ``model`` predicts correctly."""
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            qw_probs = data.qw_probs.to(device)
            out = model(data, qw_probs.unsqueeze(0))
            pred = out.argmax(dim=1)
            correct += (pred == data.y).sum().item()
            # Count evaluated labels, not batches: len(loader) is the number
            # of batches and only equals the sample count for batch_size == 1.
            total += data.y.numel()

    return correct / total


def _append_result_row(result_log_path, row):
    """Append ``row`` to the CSV log, writing the header first if the file is new or empty."""
    needs_header = (
        not os.path.exists(result_log_path)
        or os.path.getsize(result_log_path) == 0
    )
    with open(result_log_path, 'a', newline='', encoding='utf-8') as csvfile:
        # Column order follows the insertion order of ``row``.
        writer = csv.DictWriter(csvfile, fieldnames=list(row))
        if needs_header:
            writer.writeheader()
        writer.writerow(row)


def train_ctqw_model(
    dataset_name='MUTAG',
    epochs=100,
    folds=10,
    time_steps=torch.tensor([1.0, 2.0]),
    hidden_dim=64,
    lr=1e-4,
    batch_size=1,
    fusion='cat',
    heads=4,
    use_attention_bias=True,
    use_sequence_model=True,
    num_layers=4,
    dropout=0.3,
    earlystop_patience=20,
    device=None,
    result_log_path=None,
    return_embeddings=False,
):
    """
    Train the CTQWformer model on a TU dataset using k-fold cross validation.

    The CTQW tensors are precomputed once for the whole dataset with a
    ``T_Ham`` / ``CTQWEncoder`` pair that is never optimised. For every
    stratified fold a fresh ``CTQWformer`` is trained with Adam and a
    ``ReduceLROnPlateau`` scheduler driven by the mean training loss. The
    score of a fold is the best accuracy observed on its held-out split over
    all epochs; that same accuracy drives early stopping. A summary row is
    appended to a CSV log.

    Args:
        dataset_name: Name passed to ``load_and_preprocess_tu_dataset``; also
            written to the log and used in the default log file name.
        epochs: Maximum number of epochs per fold.
        folds: Number of stratified folds.
        time_steps: Passed to both ``CTQWEncoder`` and ``CTQWformer``. Logged
            as a ``|``-joined list when it is a tensor, else via ``str``.
        hidden_dim: Forwarded to ``CTQWformer``.
        lr: Learning rate of the Adam optimiser.
        batch_size: Batch size of the train and test loaders.
        fusion: Forwarded to ``CTQWformer``.
        heads: Forwarded to ``CTQWformer``.
        use_attention_bias: Forwarded to ``CTQWformer``.
        use_sequence_model: Forwarded to ``CTQWformer``.
        num_layers: Forwarded to ``CTQWformer``.
        dropout: Forwarded to ``CTQWformer``.
        earlystop_patience: Number of consecutive epochs without a new best
            held-out accuracy after which training of the fold stops.
        device: Torch device; defaults to CUDA when available, else CPU.
        result_log_path: CSV file to append the summary row to. When ``None``
            a timestamped path is obtained from ``generate_log_path``.
        return_embeddings: Accepted for interface compatibility; not used by
            this function.

    Returns:
        The mean over folds of the per-fold best accuracy, rounded to four
        decimals.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    raw_dataset = load_and_preprocess_tu_dataset(dataset_name)
    input_dim = raw_dataset[0].x.size(1)
    num_classes = len({data.y.item() for data in raw_dataset})

    ham_model = T_Ham(in_dim=input_dim).to(device)
    ctqw_encoder = CTQWEncoder(time_steps).to(device)

    raw_dataset, qw_all = preprocess_all_qw(raw_dataset, ham_model, ctqw_encoder, device)

    # Attach the precomputed tensors once; they do not depend on the fold.
    for data, qw_probs in zip(raw_dataset, qw_all):
        data.qw_probs = qw_probs

    # NOTE(review): the model is always called with qw_probs.unsqueeze(0);
    # whether batch_size > 1 works depends on how qw_probs is collated and on
    # CTQWformer itself -- confirm before raising batch_size.

    targets = [data.y.item() for data in raw_dataset]
    skf = StratifiedKFold(n_splits=folds, shuffle=True, random_state=42)
    all_acc = []

    # Only the labels determine the folds, so a placeholder stands in for X.
    splits = skf.split(np.zeros(len(targets)), targets)
    for fold, (train_idx, test_idx) in enumerate(splits, start=1):
        print(f"\n=== Fold {fold}/{folds} ===")

        train_loader = DataLoader(
            [raw_dataset[i] for i in train_idx], batch_size=batch_size, shuffle=True
        )
        test_loader = DataLoader(
            [raw_dataset[i] for i in test_idx], batch_size=batch_size, shuffle=False
        )

        model = CTQWformer(
            input_dim=input_dim,
            hidden_dim=hidden_dim,
            num_classes=num_classes,
            time_steps=time_steps,
            num_layers=num_layers,
            heads=heads,
            fusion=fusion,
            use_attention_bias=use_attention_bias,
            use_sequence_model=use_sequence_model,
            dropout=dropout,
        ).to(device)

        optimizer = optim.Adam(model.parameters(), lr=lr)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=0.5, patience=3
        )

        best_acc = 0.0
        patience_counter = 0

        for epoch in range(1, epochs + 1):
            avg_loss = _train_one_epoch(model, train_loader, optimizer, device)
            acc = _evaluate_accuracy(model, test_loader, device)
            print(f"Epoch {epoch:03d}: Loss = {avg_loss:.4f}, Acc = {acc:.4f}")

            scheduler.step(avg_loss)

            if acc > best_acc:
                best_acc = acc
                patience_counter = 0
                continue

            patience_counter += 1
            if patience_counter >= earlystop_patience:
                print(f"⏹️ Early stopping at epoch {epoch}. Best acc: {best_acc:.4f}")
                break

        all_acc.append(best_acc)

    avg_acc = round(float(np.mean(all_acc)), 4)
    std_acc = round(float(np.std(all_acc)), 4)

    print(f"\n=== Final Avg Accuracy over {folds} folds: {avg_acc:.4f} ± {std_acc:.4f} ===")

    if result_log_path is None:
        result_log_path = generate_log_path(dataset_name)

    if isinstance(time_steps, torch.Tensor):
        time_steps_repr = '|'.join(map(str, time_steps.tolist()))
    else:
        time_steps_repr = str(time_steps)

    _append_result_row(result_log_path, {
        "dataset": dataset_name,
        "avg_acc": avg_acc,
        "std_acc": std_acc,
        "epochs": epochs,
        "folds": folds,
        "hidden_dim": hidden_dim,
        "lr": lr,
        "batch_size": batch_size,
        "dropout": dropout,
        "heads": heads,
        "fusion": fusion,
        "use_attention_bias": use_attention_bias,
        "use_sequence_model": use_sequence_model,
        "num_layers": num_layers,
        "time_steps": time_steps_repr,
    })

    print(f"✅ Results saved to {result_log_path}")

    return avg_acc
