import os
import sys
import torch
from pathlib import Path
import argparse
import numpy as np
from torch_geometric.data import Data, DataLoader

# Startup banner: the script title framed by two 70-character rules.
# Emitted at module load time, before any argument parsing happens.
print("=" * 70, "SIMPLIFIED Virtual Dataset Training", "=" * 70, sep="\n")


def create_simple_virtual_dataset(batch_size=8, num_graphs=200, num_classes=4,
                                  feature_dim=16, train_ratio=0.6, val_ratio=0.2):
    """Build a synthetic graph-classification dataset and wrap it in loaders.

    Each graph gets the label ``i % num_classes`` and between 20 and 49 nodes.
    Node features are the sum of three terms:

    * Gaussian noise scaled by 0.2,
    * a constant offset of ``class_idx * 0.3``,
    * a class-specific curve over the feature axis scaled by 0.2
      (sin for class 0, cos for class 1, tanh for class 2, exp for every
      other class).

    Edges are drawn uniformly at random, so self-loops and duplicate edges
    can occur. All randomness comes from the global NumPy and PyTorch RNGs;
    seed those before calling if reproducibility is needed.

    Args:
        batch_size: Batch size used for all three loaders.
        num_graphs: Total number of graphs to generate.
        num_classes: Number of distinct labels.
        feature_dim: Number of features per node (default 16).
        train_ratio: Fraction of graphs assigned to the training split.
        val_ratio: Fraction of graphs assigned to the validation split.
            Whatever remains after train and validation becomes the test split.

    Returns:
        A ``(train_loader, val_loader, test_loader)`` tuple. Only the training
        loader shuffles its data between epochs.

    Raises:
        ValueError: If a ratio is negative or the two ratios sum to more
            than 1.
    """
    # Small tolerance so float sums such as 0.7 + 0.3 are not rejected.
    if train_ratio < 0 or val_ratio < 0 or train_ratio + val_ratio > 1 + 1e-9:
        raise ValueError(
            "train_ratio and val_ratio must be non-negative and sum to at most 1, "
            f"got train_ratio={train_ratio}, val_ratio={val_ratio}"
        )

    print("Creating simple virtual dataset...")

    all_data = []

    # Create learnable dataset
    for i in range(num_graphs):
        class_idx = i % num_classes
        num_nodes = np.random.randint(20, 50)  # upper bound is exclusive

        # Class-related features: noise plus a per-class constant offset.
        base_noise = np.random.randn(num_nodes, feature_dim) * 0.2
        class_pattern = np.ones((num_nodes, feature_dim)) * (class_idx * 0.3)

        # Add a class-specific curve along the feature axis. It has shape
        # (1, feature_dim) and is broadcast over all nodes of the graph.
        if class_idx == 0:
            pattern = np.sin(np.linspace(0, 2 * np.pi, feature_dim)).reshape(1, -1)
        elif class_idx == 1:
            pattern = np.cos(np.linspace(0, 2 * np.pi, feature_dim)).reshape(1, -1)
        elif class_idx == 2:
            pattern = np.tanh(np.linspace(-2, 2, feature_dim)).reshape(1, -1)
        else:
            pattern = np.exp(np.linspace(-1, 1, feature_dim)).reshape(1, -1)

        x = torch.tensor(base_noise + class_pattern + pattern * 0.2, dtype=torch.float32)

        # Create edges: between 2x and 3x as many edges as nodes, with
        # endpoints sampled independently and uniformly.
        num_edges = np.random.randint(num_nodes * 2, num_nodes * 3)
        edge_index = torch.randint(0, num_nodes, (2, num_edges), dtype=torch.long)

        y = torch.tensor([class_idx], dtype=torch.long)

        all_data.append(Data(x=x, edge_index=edge_index, y=y))

    # Shuffle in place so the splits are not ordered by label.
    np.random.shuffle(all_data)

    # Split dataset: train and validation sizes are truncated to integers,
    # and the test split takes everything left over.
    num_train = int(len(all_data) * train_ratio)
    num_val = int(len(all_data) * val_ratio)

    train_data = all_data[:num_train]
    val_data = all_data[num_train:num_train + num_val]
    test_data = all_data[num_train + num_val:]

    print("✅ Simple dataset created:")
    print(f"   Training: {len(train_data)} graphs")
    print(f"   Validation: {len(val_data)} graphs")
    print(f"   Test: {len(test_data)} graphs")

    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader, test_loader


def _train_one_epoch(model, loader, optimizer, device):
    """Run one optimisation pass over ``loader``.

    The model is called as ``model(batch, labels=batch.y, training=True)`` and
    is expected to return a mapping with ``'loss'`` and ``'logits'`` entries.

    Returns:
        ``(mean_batch_loss, accuracy)``; both are 0 when the loader is empty.
    """
    model.train()
    loss_sum = 0.0
    correct = 0
    seen = 0

    for batch in loader:
        batch = batch.to(device)
        optimizer.zero_grad()

        output = model(batch, labels=batch.y, training=True)
        loss = output['loss']

        loss.backward()
        optimizer.step()

        loss_sum += loss.item()

        preds = output['logits'].argmax(dim=1)
        correct += (preds == batch.y).sum().item()
        seen += batch.y.shape[0]

    num_batches = len(loader)
    mean_loss = loss_sum / num_batches if num_batches > 0 else 0.0
    accuracy = correct / seen if seen > 0 else 0
    return mean_loss, accuracy


def _evaluate(model, loader, device):
    """Score ``model`` on ``loader`` in eval mode without tracking gradients.

    The model is called as ``model(batch, labels=batch.y, training=False)`` and
    is expected to return a mapping with ``'loss'`` and ``'logits'`` entries.

    Returns:
        ``(mean_batch_loss, accuracy)``; both are 0 when the loader is empty.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    seen = 0

    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            output = model(batch, labels=batch.y, training=False)

            loss_sum += output['loss'].item()

            preds = output['logits'].argmax(dim=1)
            correct += (preds == batch.y).sum().item()
            seen += batch.y.shape[0]

    num_batches = len(loader)
    mean_loss = loss_sum / num_batches if num_batches > 0 else 0.0
    accuracy = correct / seen if seen > 0 else 0
    return mean_loss, accuracy


def main():
    """Train, validate and test a ``DIGLModel`` on the synthetic dataset.

    Steps: parse CLI arguments, build the dataset loaders, import and build
    the model, train for ``--epochs`` epochs while keeping a snapshot of the
    weights with the best validation accuracy, restore that snapshot,
    evaluate on the test split and save the result to
    ``./results/simple/model.pth``.

    Returns early (without raising) if ``models.DIGLModel`` cannot be imported.
    """
    # Local import: only needed here, for snapshotting the best weights.
    import copy

    parser = argparse.ArgumentParser(description='Simplified virtual dataset training')
    parser.add_argument('--epochs', type=int, default=50, help='Number of epochs')
    parser.add_argument('--lr', type=float, default=0.001, help='Learning rate')
    parser.add_argument('--hidden_dim', type=int, default=128, help='Hidden dimension')
    parser.add_argument('--dropout', type=float, default=0.3, help='Dropout rate')
    parser.add_argument('--batch_size', type=int, default=8, help='Batch size')

    args = parser.parse_args()

    # Clear any cached copy of the project's top-level ``models`` package so
    # the import below picks it up from the path inserted next. Only ``models``
    # and its submodules are evicted: a substring match would also drop
    # unrelated third-party modules whose dotted name contains "models".
    stale_modules = [
        name for name in list(sys.modules)
        if name == 'models' or name.startswith('models.')
    ]
    for name in stale_modules:
        sys.modules.pop(name)

    # Add project path (the parent of this script's directory).
    project_root = Path(__file__).parent.parent
    sys.path.insert(0, str(project_root))

    # 1. Create dataset
    print("\n1. Creating dataset...")
    num_classes = 4  # shared by the dataset labels and the model's output size
    train_loader, val_loader, test_loader = create_simple_virtual_dataset(
        batch_size=args.batch_size,
        num_graphs=200,
        num_classes=num_classes
    )

    # 2. Create model
    print("\n2. Creating model...")

    try:
        from models import DIGLModel
    except ImportError as exc:
        # Report the cause instead of hiding it; still a soft exit.
        print(f"❌ Could not import DIGLModel: {exc}")
        return
    print("✅ Imported DIGLModel")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"   Device: {device}")

    # Read the input width from the data so model and dataset cannot drift apart.
    in_dim = train_loader.dataset[0].x.size(1)

    model = DIGLModel(
        in_dim=in_dim,
        hidden_dim=args.hidden_dim,
        out_dim=num_classes,
        dropout=args.dropout,
        num_environments=2,
        use_wasserstein=False,
        use_causal_intervention=False
    ).to(device)

    print(f"   Model parameters: {sum(p.numel() for p in model.parameters()):,}")

    # 3. Training
    print("\n3. Training...")

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    best_val_acc = 0
    best_model_state = None

    for epoch in range(args.epochs):
        avg_train_loss, train_acc = _train_one_epoch(model, train_loader, optimizer, device)
        avg_val_loss, val_acc = _evaluate(model, val_loader, device)

        print(f"   Epoch {epoch + 1:3d}/{args.epochs}: "
              f"Train Loss={avg_train_loss:.4f}, Train Acc={train_acc:.2%}, "
              f"Val Loss={avg_val_loss:.4f}, Val Acc={val_acc:.2%}")

        # Save best model. state_dict() hands back references to the live
        # parameter tensors, so a shallow dict copy would keep changing as
        # training continues; deepcopy freezes the weights of this epoch.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_model_state = copy.deepcopy(model.state_dict())
            print(f"     ↳ New best validation accuracy: {val_acc:.2%}")

    # Restore best model
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
        print(f"\n   Restored best model with validation accuracy: {best_val_acc:.2%}")

    # 4. Testing
    print("\n4. Testing...")
    _, test_acc = _evaluate(model, test_loader, device)

    print("\n" + "=" * 70)
    print("🎯 TRAINING COMPLETED!")
    print(f"   Best validation accuracy: {best_val_acc:.2%}")
    print(f"   Test accuracy: {test_acc:.2%}")
    print("=" * 70)

    # 5. Save results
    output_dir = Path("./results/simple")
    output_dir.mkdir(parents=True, exist_ok=True)

    torch.save({
        'model_state_dict': model.state_dict(),
        'best_val_acc': best_val_acc,
        'test_acc': test_acc,
        'args': vars(args)
    }, output_dir / "model.pth")

    print(f"\nModel saved to: {output_dir / 'model.pth'}")


# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()