from utils import *
from models.gnn_models import *
from models.egnn_models import *
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report
from torch_geometric.loader import DataLoader
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
import argparse

# Fix the seed once at import time via the project helper from utils
# (presumably seeds Python / NumPy / torch -- confirm in utils).
set_random_seed(42)
# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# Command-line options. Parsing happens at module level, so `args` and
# `device` are globals that the functions below read directly.
# NOTE(review): main() always loads the nanowire data; --dataset does not
# select a different loader.
parser = argparse.ArgumentParser(description="Graph Regression with Various GNN Models")
parser.add_argument('--model', type=str, default='gcn',
                    choices=['gcn', 'sage', 'gat', 'egnn', 'equiformer'])
parser.add_argument('--hid_dim', type=int, default=32)
parser.add_argument('--dataset', type=str, default='nanowire', choices=['nanofibres', 'nanowire', 'polymer'], help="Dataset to use")
args = parser.parse_args()




def create_data_splits(data_list, test_size=0.2, val_size=0.2, random_state=42):
    """Partition ``data_list`` into stratified train / validation / test lists.

    The scalar target of every graph (``data.y.item()``) is used as the
    stratification key. The test split is carved out first; the validation
    split is then taken from what remains, so ``val_size`` is a fraction of
    the post-test remainder rather than of the full dataset.

    Args:
        data_list: Sequence of graph objects exposing a single-element ``y``.
        test_size: Fraction of ``data_list`` held out for testing.
        val_size: Fraction of the remaining graphs held out for validation.
        random_state: Seed forwarded to both ``train_test_split`` calls.

    Returns:
        Tuple ``(train_data, val_data, test_data)``.
    """

    def _histogram(values):
        # Map each distinct target value to its number of occurrences.
        uniques, counts = np.unique(values, return_counts=True)
        return dict(zip(uniques, counts))

    all_targets = [graph.y.item() for graph in data_list]

    # Stage 1: hold out the test graphs.
    remainder, test_graphs, remainder_targets, test_targets = train_test_split(
        data_list,
        all_targets,
        test_size=test_size,
        stratify=all_targets,
        random_state=random_state,
    )

    # Stage 2: split the remainder into train and validation graphs.
    train_graphs, val_graphs, train_targets, val_targets = train_test_split(
        remainder,
        remainder_targets,
        test_size=val_size,
        stratify=remainder_targets,
        random_state=random_state,
    )

    print(f"Train: {len(train_graphs)}, Val: {len(val_graphs)}, Test: {len(test_graphs)}")
    print(f"Train distribution: {_histogram(train_targets)}")
    print(f"Val distribution: {_histogram(val_targets)}")
    print(f"Test distribution: {_histogram(test_targets)}")

    return train_graphs, val_graphs, test_graphs


def train_epoch(model, loader, optimizer, device):
    """Run one optimisation pass over ``loader`` and report regression metrics.

    Args:
        model: Module called as ``model(x, edge_index, batch)``.
        loader: Iterable of batches exposing ``x``, ``edge_index``, ``batch``
            and ``y`` and supporting ``.to(device)``.
        optimizer: Optimizer stepping ``model``'s parameters.
        device: Device every batch is moved to before the forward pass.

    Returns:
        Tuple ``(mean_batch_loss, mse, mae, r2)`` where the loss is averaged
        over batches and the metrics are computed over all samples seen.
    """
    model.train()
    total_loss = 0
    predictions = []
    targets = []

    for batch in loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        # Flatten instead of squeeze(): squeeze() collapses a single-graph
        # batch to a 0-d tensor, which cannot be iterated by list.extend and
        # no longer matches the target's shape. Flattening both sides also
        # prevents silent broadcasting when y is stored as [B, 1].
        out = model(batch.x, batch.edge_index, batch.batch).reshape(-1)
        target = batch.y.reshape(-1)
        loss = F.mse_loss(out, target)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        predictions.extend(out.detach().cpu().numpy())
        targets.extend(target.cpu().numpy())

    mse = mean_squared_error(targets, predictions)
    mae = mean_absolute_error(targets, predictions)
    r2 = r2_score(targets, predictions)

    return total_loss / len(loader), mse, mae, r2

@torch.no_grad()
def test_epoch(model, loader, device):
    """Evaluate ``model`` on ``loader`` without tracking gradients.

    Args:
        model: Module called as ``model(x, edge_index, batch)``.
        loader: Iterable of batches exposing ``x``, ``edge_index``, ``batch``
            and ``y`` and supporting ``.to(device)``.
        device: Device every batch is moved to before the forward pass.

    Returns:
        Tuple ``(mean_batch_loss, mse, mae, r2, predictions, targets,
        rel_error)``; ``predictions`` and ``targets`` are NumPy arrays and
        ``rel_error`` is the mean of ``|target - prediction| / |target|``.
    """
    model.eval()
    total_loss = 0
    predictions = []
    targets = []

    for batch in loader:
        batch = batch.to(device)
        # Flatten instead of squeeze(): squeeze() collapses a single-graph
        # batch to a 0-d tensor, which cannot be iterated by list.extend and
        # no longer matches the target's shape. Flattening both sides also
        # prevents silent broadcasting when y is stored as [B, 1].
        out = model(batch.x, batch.edge_index, batch.batch).reshape(-1)
        target = batch.y.reshape(-1)
        loss = F.mse_loss(out, target)

        total_loss += loss.item()
        predictions.extend(out.cpu().numpy())
        targets.extend(target.cpu().numpy())

    targets = np.array(targets)
    predictions = np.array(predictions)
    mse = mean_squared_error(targets, predictions)
    mae = mean_absolute_error(targets, predictions)
    r2 = r2_score(targets, predictions)
    # The small epsilon keeps the ratio finite when a target is exactly zero.
    rel_error = np.mean(np.abs(targets - predictions) / (np.abs(targets) + 1e-8))

    return total_loss / len(loader), mse, mae, r2, predictions, targets, rel_error


def get_model(args, in_dim=32):
    """Instantiate the regressor named by ``args.model`` together with its optimizer.

    Args:
        args: Namespace providing ``model`` (architecture key) and
            ``hid_dim`` (hidden width, read by the egnn / equiformer builders).
        in_dim: Number of input node features.

    Returns:
        Tuple ``(model, optimizer)``; the model is already moved to the
        module-level ``device`` and the optimizer is Adam.

    Raises:
        ValueError: If ``args.model`` is not a known architecture key.
    """
    # Builders are lambdas so that only the selected architecture is constructed.
    # Note: gcn / sage / gat use a fixed hidden width of 64 and do not read
    # args.hid_dim.
    builders = {
        'gcn': lambda: GCNGraphRegressor(in_channels=in_dim, hidden_channels=64),
        'sage': lambda: SAGEGraphRegressor(in_channels=in_dim, hidden_channels=64),
        'gat': lambda: GATGraphRegressor(in_channels=in_dim, hidden_channels=64, heads=4),
        'egnn': lambda: BatchedEGNNGraphRegressor(
            in_channels=in_dim,
            hidden_channels=args.hid_dim,
            dropout=0.1,
            batch_size=512,  # Adjust based on your GPU memory
        ),
        'equiformer': lambda: EquiformerGraphRegressor(
            in_channels=in_dim,
            hidden_channels=args.hid_dim,
            num_layers=2,
            max_ell=2,
            dropout=0.1,
        ),
    }

    if args.model not in builders:
        raise ValueError(f"Unknown model type: {args.model}")

    network = builders[args.model]().to(device)

    # The equiformer is trained with a 10x larger learning rate than the rest.
    if args.model == 'equiformer':
        learning_rate = 0.01
    else:
        learning_rate = 0.001
    adam = torch.optim.Adam(network.parameters(), lr=learning_rate)

    return network, adam



def run_experiment(data_list, train_data, val_data, test_data):
    """Train a fresh model on the given splits and evaluate it on the test split.

    Training runs for at most 500 epochs with early stopping (patience 50)
    on the validation loss; the weights from the best validation epoch are
    restored before the final test evaluation.

    Args:
        data_list: Full dataset; only ``data_list[0].num_node_features`` is
            read, to size the model input.
        train_data: Graphs used for optimisation.
        val_data: Graphs used for LR scheduling, early stopping and model
            selection.
        test_data: Graphs used for the final evaluation.

    Returns:
        Dict with keys ``test_mse``, ``test_mae``, ``test_r2``,
        ``predictions``, ``targets`` and ``relative_error``.
    """

    # Create data loaders
    batch_size = 4 if args.model == 'egnn' else 16

    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

    # Model setup
    model, optimizer = get_model(args, in_dim=data_list[0].num_node_features)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10)

    # Training
    best_val_loss = float('inf')
    patience = 50
    epochs_no_improve = 0
    best_model_state = None
    max_epochs = 500

    for epoch in range(max_epochs):
        train_loss, _, _, train_r2 = train_epoch(model, train_loader, optimizer, device)
        val_loss, _, _, val_r2, _, _, _ = test_epoch(model, val_loader, device)

        scheduler.step(val_loss)

        if epoch % 10 == 0:
            print(f'Epoch {epoch:03d}, Train Loss: {train_loss:.4f}, Train R²: {train_r2:.4f}, '
                  f'Val Loss: {val_loss:.4f}, Val R²: {val_r2:.4f}')

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            # state_dict() hands back tensors that share storage with the
            # live parameters, and dict.copy() is shallow, so the snapshot
            # must clone every tensor or later optimizer steps overwrite it.
            best_model_state = {
                name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()
            }
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1

        if epochs_no_improve >= patience:
            print(f"Early stopping at epoch {epoch}. Best Val Loss: {best_val_loss:.4f}")
            break

    # Load best model and test. The snapshot is None only when the validation
    # loss never improved on infinity (e.g. it was NaN every epoch); in that
    # case the final weights are evaluated as-is.
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
    _, test_mse, test_mae, test_r2, test_preds, test_targets, test_rel_error = test_epoch(model, test_loader, device)

    return {
        'test_mse': test_mse,
        'test_mae': test_mae,
        'test_r2': test_r2,
        'predictions': test_preds,
        'targets': test_targets,
        'relative_error': test_rel_error
    }

def main():
    """Load the dataset, split it once, then train and evaluate several times.

    The same train / validation / test split is reused for every run; the
    per-run test metrics are printed and finally summarised as mean and
    standard deviation across runs.
    """
    print("Loading isotropic network data for graph regression...")

    # Only the nanowire loader is called here, so make it explicit when a
    # different --dataset value was requested instead of ignoring it silently.
    if args.dataset != 'nanowire':
        print(f"Warning: --dataset={args.dataset} is not wired up; loading the nanowire dataset instead.")

    # Load data
    data_list = get_nanowire_network()
    # Create splits
    train_data, val_data, test_data = create_data_splits(data_list, random_state=42)

    # Run multiple experiments
    num_runs = 3
    results = {'mse': [], 'mae': [], 'r2': [], 're': []}

    for run in range(num_runs):
        print(f"\n=== Run {run + 1}/{num_runs} ===")
        result = run_experiment(data_list, train_data, val_data, test_data)

        results['mse'].append(result['test_mse'])
        results['mae'].append(result['test_mae'])
        results['r2'].append(result['test_r2'])
        results['re'].append(result['relative_error'])

        # Adjacent f-strings instead of backslash continuations: a backslash
        # inside the literal pulls the next line's indentation into the output.
        print(
            f"Test MSE: {result['test_mse']:.4f} | "
            f"MAE: {result['test_mae']:.4f} | "
            f"R²: {result['test_r2']:.4f} | "
            f"RelErr: {result['relative_error']:.4f}"
        )

    # Print final results
    print(f"\n=== Final Results (Mean ± Std over {num_runs} runs) ===")
    print(f"Test MSE: {np.mean(results['mse']):.4f} ± {np.std(results['mse']):.4f}")
    print(f"Test MAE: {np.mean(results['mae']):.4f} ± {np.std(results['mae']):.4f}")
    print(f"Test R²: {np.mean(results['r2']):.4f} ± {np.std(results['r2']):.4f}")
    print(f"Test Relative Error: {np.mean(results['re']):.4f} ± {np.std(results['re']):.4f}")

# Run the full experiment when this file is executed as a script.
if __name__ == "__main__":
    main()