import argparse
import torch
import torch.nn.functional as F
from torch_geometric.loader import DataLoader
import torch_geometric.transforms as T
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
import itertools
import json
import os
from datetime import datetime

from models import NetWithPE, LaplacianPE
from dataset_handler import get_dataset

def train(model, train_loader, optimizer, device):
    """Run one optimisation epoch over ``train_loader``.

    Puts ``model`` in training mode, then for every batch: moves it to
    ``device``, computes the NLL loss against ``batch.y``, back-propagates and
    steps ``optimizer``.

    Returns:
        The sum over batches of (batch loss * ``batch.num_graphs``), divided by
        ``len(train_loader.dataset)``.
    """
    model.train()
    weighted_loss_sum = 0

    progress = tqdm(train_loader, desc='Training', leave=False)
    for batch in progress:
        batch = batch.to(device)
        optimizer.zero_grad()
        out = model(batch)
        batch_loss = F.nll_loss(out, batch.y)
        batch_loss.backward()
        optimizer.step()
        # nll_loss uses a mean reduction by default, so re-weight by the number
        # of graphs to obtain a dataset-level average (assumes one target per graph).
        weighted_loss_sum += batch_loss.item() * batch.num_graphs

    return weighted_loss_sum / len(train_loader.dataset)

def validate(model, loader, device):
    """Evaluate ``model`` on every batch of ``loader`` without tracking gradients.

    Returns:
        ``(avg_loss, accuracy)``. ``avg_loss`` is the sum of batch NLL losses
        weighted by ``num_graphs``, divided by ``len(loader.dataset)``.
        ``accuracy`` is the number of argmax predictions equal to ``y``,
        divided by the same count.
    """
    model.eval()
    loss_sum = 0
    num_correct = 0

    with torch.no_grad():
        for batch in tqdm(loader, desc='Validating', leave=False):
            batch = batch.to(device)
            out = model(batch)
            loss_sum += F.nll_loss(out, batch.y).item() * batch.num_graphs
            predicted = out.max(dim=1).indices
            num_correct += predicted.eq(batch.y).sum().item()

    dataset_size = len(loader.dataset)
    accuracy = num_correct / dataset_size
    avg_loss = loss_sum / dataset_size
    return avg_loss, accuracy

def plot_training_curves(results, save_path='training_curves.png'):
    """Plot per-epoch loss and accuracy curves for one experiment and save them.

    Losses are drawn on the left y-axis and accuracies on a twin right y-axis
    that shares the epoch x-axis. All four series appear in a single legend.

    Args:
        results: Mapping with equal-length sequences under the keys
            'train_loss', 'val_loss', 'train_acc' and 'val_acc'.
        save_path: Destination image path passed to ``Figure.savefig``.
    """
    fig, loss_ax = plt.subplots(figsize=(12, 6))
    try:
        epochs = range(1, len(results['train_loss']) + 1)

        loss_ax.plot(epochs, results['train_loss'], label='Train Loss', linestyle='-')
        loss_ax.plot(epochs, results['val_loss'], label='Val Loss', linestyle='--')
        loss_ax.set_xlabel('Epoch')
        loss_ax.set_ylabel('Loss')
        loss_ax.set_title('Training and Validation Metrics')
        loss_ax.grid(True)

        # Accuracy goes on a secondary y-axis because its scale differs from the loss.
        acc_ax = loss_ax.twinx()
        acc_ax.plot(epochs, results['train_acc'], label='Train Acc', linestyle=':', color='g')
        acc_ax.plot(epochs, results['val_acc'], label='Val Acc', linestyle='-.', color='r')
        acc_ax.set_ylabel('Accuracy')

        # Each twin axis owns its own artists, so calling legend() on both gives
        # two independent legends that can overlap. Merge the handles into one
        # legend on the top-most (twin) axis instead.
        loss_handles, loss_labels = loss_ax.get_legend_handles_labels()
        acc_handles, acc_labels = acc_ax.get_legend_handles_labels()
        acc_ax.legend(loss_handles + acc_handles, loss_labels + acc_labels,
                      loc='center right')

        fig.tight_layout()
        fig.savefig(save_path)
    finally:
        # Close even if saving fails so a long sweep does not accumulate open figures.
        plt.close(fig)

def get_model(dataset, args, device):
    """Build a ``NetWithPE`` sized for ``dataset`` and move it to ``device``.

    If ``dataset`` has no ``num_features`` / ``num_classes`` attribute, the
    fallbacks are 1 input feature and 10 classes. ``args.k`` is used as the
    positional-encoding dimension. ``args.hidden_channels`` is forwarded when
    ``args`` carries it, matching how ``run_single_experiment`` constructs the
    model. Otherwise the model's own default is used.

    Returns:
        The constructed model, already moved to ``device``.
    """
    num_features = getattr(dataset, 'num_features', 1)
    num_classes = getattr(dataset, 'num_classes', 10)

    model_kwargs = {
        'num_features': num_features,
        'pos_enc_dim': args.k,
        'num_classes': num_classes,
    }
    # Without this, the helper would silently build a differently-sized network
    # than the training path does.
    hidden_channels = getattr(args, 'hidden_channels', None)
    if hidden_channels is not None:
        model_kwargs['hidden_channels'] = hidden_channels

    return NetWithPE(**model_kwargs).to(device)


def get_dataset_config(dataset_name):
    """Return the default hyperparameters for ``dataset_name``.

    The lookup lower-cases the name, so it is case-insensitive. Unknown names
    receive a generic fallback configuration. Each call builds fresh dicts, so
    callers may modify the result freely.

    Returns:
        dict with keys ``batch_size``, ``learning_rate``, ``hidden_channels``,
        ``dropout``, ``weight_decay``, ``pos_enc_dim`` and ``description``.
    """
    known = {
        'mnist': dict(
            batch_size=32, learning_rate=0.01, hidden_channels=64,
            dropout=0.2, weight_decay=5e-4, pos_enc_dim=5,
            description='Superpixel representation of MNIST digits',
        ),
        'proteins': dict(
            batch_size=32, learning_rate=0.001, hidden_channels=128,
            dropout=0.3, weight_decay=1e-4, pos_enc_dim=3,
            description='Protein structures as graphs',
        ),
        'mutag': dict(
            batch_size=32, learning_rate=0.005, hidden_channels=32,
            dropout=0.5, weight_decay=1e-3, pos_enc_dim=8,
            description='Molecule graphs for mutagenicity prediction',
        ),
        'imdb-binary': dict(
            batch_size=32, learning_rate=0.0005, hidden_channels=64,
            dropout=0.5, weight_decay=1e-3, pos_enc_dim=32,
            description='Movie collaboration graphs for genre classification',
        ),
        'reddit-binary': dict(
            batch_size=128, learning_rate=0.0001, hidden_channels=128,
            dropout=0.5, weight_decay=5e-3, pos_enc_dim=64,
            description='Reddit discussion graphs for community type prediction',
        ),
        'collab': dict(
            batch_size=64, learning_rate=0.0005, hidden_channels=128,
            dropout=0.3, weight_decay=1e-4, pos_enc_dim=32,
            description='Scientific collaboration graphs',
        ),
        'enzymes': dict(
            batch_size=32, learning_rate=0.01, hidden_channels=64,
            dropout=0.4, weight_decay=1e-4, pos_enc_dim=16,
            description='Enzyme structure graphs',
        ),
    }

    # Used for any dataset name not listed above.
    fallback = dict(
        batch_size=32, learning_rate=0.001, hidden_channels=64,
        dropout=0.5, weight_decay=1e-4, pos_enc_dim=16,
        description='Default configuration',
    )

    key = dataset_name.lower()
    if key in known:
        return known[key]
    return fallback

def get_all_flag_combinations():
    """Enumerate every on/off assignment of the four experiment flags.

    Returns:
        A list of 16 dicts mapping each flag name to a bool, in
        ``itertools.product((False, True), repeat=4)`` order: all-False first,
        all-True last, with ``use_node_attr`` varying fastest.
    """
    flag_names = ('with_pos_enc', 'with_proj', 'with_virtual_node', 'use_node_attr')
    return [
        dict(zip(flag_names, values))
        for values in itertools.product((False, True), repeat=len(flag_names))
    ]

def get_experiment_name(dataset_name, flags):
    """Build an experiment identifier such as ``'mnist_pe_vn'``.

    The dataset name is followed by one short tag per truthy flag, always in
    the order pe, proj, vn, attr, joined with underscores.

    Raises:
        KeyError: if any of the four expected flag names is missing.
    """
    tag_for_flag = (
        ('with_pos_enc', 'pe'),
        ('with_proj', 'proj'),
        ('with_virtual_node', 'vn'),
        ('use_node_attr', 'attr'),
    )
    enabled_tags = [tag for flag, tag in tag_for_flag if flags[flag]]
    return '_'.join([dataset_name, *enabled_tags])

def run_single_experiment(args, flags, device, config):
    """Train and evaluate one model for a single flag combination.

    Side effects:
        * Copies every entry of ``flags`` onto ``args`` with ``setattr``, so the
          caller's namespace keeps these values after the call.
        * Creates ``results_<dataset>/`` if needed and writes into it the best
          checkpoint (``best_model_<name>.pt``), the metrics
          (``results_<name>.json``) and a curve plot
          (``training_curves_<name>.png``). ``<name>`` comes from
          ``get_experiment_name``.

    Args:
        args: Parsed arguments. Reads ``dataset``, ``data_dir``, ``k``,
            ``batch_size``, ``hidden_channels``, ``lr``, ``dropout``,
            ``weight_decay`` and ``epochs``.
        flags: Mapping of flag name -> bool (``with_pos_enc``, ``with_proj``,
            ``with_virtual_node``, ``use_node_attr``).
        device: Torch device for the model and batches.
        config: Dataset config dict. Not read here; kept for the caller's
            signature.

    Returns:
        The metrics dict that is also written to JSON. It contains:
        * per-epoch lists ``train_loss``/``val_loss``/``train_acc``/``val_acc``
          (train metrics are re-measured in eval mode after each epoch);
        * ``flags`` and ``config``;
        * ``best_val_acc``;
        * ``best_epoch`` (``None`` when ``args.epochs`` < 1).
    """
    # Update args with current flag combination
    for flag_name, flag_value in flags.items():
        setattr(args, flag_name, flag_value)

    # Create experiment name and directory
    experiment_name = get_experiment_name(args.dataset, flags)
    results_dir = f'results_{args.dataset}'
    os.makedirs(results_dir, exist_ok=True)

    print(f'\nRunning experiment: {experiment_name}')
    print('Configuration:')
    for flag_name, flag_value in flags.items():
        print(f'- {flag_name}: {flag_value}')

    # Define transforms
    transform = T.Compose([
        LaplacianPE(k=args.k,
                    with_pos_enc=args.with_pos_enc,
                    with_proj=args.with_proj,
                    with_virtual=args.with_virtual_node),
        T.NormalizeFeatures()
    ])

    # Load dataset
    train_dataset, test_dataset = get_dataset(
        args.dataset,
        args.data_dir,
        transform,
        args.k,
        args.use_node_attr
    )

    # Create data loaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True
    )
    # NOTE(review): the "validation" metrics and best-checkpoint selection below
    # use this test split, so best_val_acc is an optimistic estimate.
    test_loader = DataLoader(
        test_dataset,
        batch_size=args.batch_size
    )

    # Initialize model
    # NOTE(review): args.dropout is recorded in results['config'] but is not
    # passed to NetWithPE here — confirm whether the model should receive it.
    model = NetWithPE(
        num_features=train_dataset.num_features,
        hidden_channels=args.hidden_channels,
        pos_enc_dim=args.k,
        num_classes=train_dataset.num_classes,
    ).to(device)

    # Initialize optimizer
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)

    # Initialize metrics storage
    results = {
        'train_loss': [],
        'val_loss': [],
        'train_acc': [],
        'val_acc': [],
        'flags': flags,
        'config': {
            'lr': args.lr,
            'hidden_channels': args.hidden_channels,
            'k': args.k,
            'batch_size': args.batch_size,
            'dropout': args.dropout,
            'weight_decay': args.weight_decay
        }
    }

    # Training loop
    best_val_acc = 0
    best_epoch = None
    for epoch in range(1, args.epochs + 1):
        # One optimisation pass; this loss is the running train-mode average.
        train_loss = train(model, train_loader, optimizer, device)

        # Re-evaluate both splits in eval mode so train/val metrics are comparable.
        val_loss, val_acc = validate(model, test_loader, device)
        train_eval_loss, train_acc = validate(model, train_loader, device)

        # Store metrics (train loss stored is the eval-mode one)
        results['train_loss'].append(train_eval_loss)
        results['val_loss'].append(val_loss)
        results['train_acc'].append(train_acc)
        results['val_acc'].append(val_acc)

        # Save best model. The first epoch is always saved so a checkpoint
        # exists even if validation accuracy never rises above 0.
        if best_epoch is None or val_acc > best_val_acc:
            best_val_acc = val_acc
            best_epoch = epoch
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_acc': val_acc,
                'flags': flags,
                'config': results['config']
            }, f'{results_dir}/best_model_{experiment_name}.pt')

        # The printed train loss is the running train-mode value, not the
        # eval-mode value stored in results['train_loss'].
        print(f'Epoch: {epoch:02d}, '
              f'Train Loss: {train_loss:.4f}, '
              f'Val Loss: {val_loss:.4f}, '
              f'Train Acc: {train_acc:.4f}, '
              f'Val Acc: {val_acc:.4f}, '
              f'Best Val Acc: {best_val_acc:.4f}')

    # Save final results
    results['best_val_acc'] = best_val_acc
    results['best_epoch'] = best_epoch
    with open(f'{results_dir}/results_{experiment_name}.json', 'w') as f:
        json.dump(results, f, indent=4)

    # Plot training curves
    plot_training_curves(results, save_path=f'{results_dir}/training_curves_{experiment_name}.png')

    return results

def plot_combined_results(all_results, dataset_name, save_dir=None):
    """Overlay the validation-accuracy curves of every experiment in one figure.

    Args:
        all_results: Mapping of experiment name -> results dict; only each
            dict's ``'val_acc'`` sequence is read.
        dataset_name: Used in the title and, by default, in the output
            directory name.
        save_dir: Directory for ``combined_results.png``. Defaults to
            ``results_<dataset_name>``. It is created if missing, so the
            function also works when no experiment has created it yet.
    """
    if save_dir is None:
        save_dir = f'results_{dataset_name}'
    os.makedirs(save_dir, exist_ok=True)

    fig, ax = plt.subplots(figsize=(15, 10))
    try:
        # Plot validation accuracy for each experiment
        for experiment_name, results in all_results.items():
            val_acc = results['val_acc']
            ax.plot(range(1, len(val_acc) + 1), val_acc, label=experiment_name)

        ax.set_xlabel('Epoch')
        ax.set_ylabel('Validation Accuracy')
        ax.set_title(f'Validation Accuracy for Different Configurations on {dataset_name}')
        # The legend sits outside the axes on the right; bbox_inches='tight'
        # keeps it from being clipped in the saved image.
        ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
        ax.grid(True)
        fig.tight_layout()
        fig.savefig(os.path.join(save_dir, 'combined_results.png'), bbox_inches='tight')
    finally:
        # Always release the figure, even if saving fails.
        plt.close(fig)

def main():
    """Command-line entry point: sweep all flag combinations on one dataset.

    Dataset-specific defaults:
        ``--k``, ``--batch_size`` and ``--lr`` default to ``None`` and are then
        filled from ``get_dataset_config``, so each dataset gets its own values
        unless they are overridden on the command line. ``hidden_channels``,
        ``dropout`` and ``weight_decay`` always come from that config.

    Outputs:
        * ``results_<dataset>_<timestamp>/experiment_config.json``: the
          resolved arguments.
        * ``results_<dataset>_<timestamp>/all_results.json``: metrics of
          every flag combination.
        * ``results_<dataset>/``: per-experiment files written by
          ``run_single_experiment`` and the combined plot written by
          ``plot_combined_results``.
    """
    # Available datasets
    available_datasets = ['mnist', 'PROTEINS', 'MUTAG', 'IMDB-BINARY',
                          'REDDIT-BINARY', 'COLLAB', 'ENZYMES']

    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default='mnist', choices=available_datasets,
                        help='Dataset name')
    # k / batch_size / lr default to None so the dataset-specific values below
    # apply; concrete defaults here would make those fallbacks unreachable.
    parser.add_argument('--k', type=int, default=None,
                        help='Dimension of positional encoding (default: dataset-specific)')
    parser.add_argument('--epochs', type=int, default=50, help='Number of epochs')
    parser.add_argument('--batch_size', type=int, default=None,
                        help='Batch size (default: dataset-specific)')
    parser.add_argument('--lr', type=float, default=None,
                        help='Learning rate (default: dataset-specific)')
    parser.add_argument('--device', type=str, default='cuda', help='Device to use')
    parser.add_argument('--data_dir', type=str, default='data', help='Data directory')
    # NOTE: the four flags below are overwritten for every combination in the
    # sweep (run_single_experiment copies each combination onto args).
    parser.add_argument('--with_pos_enc', action='store_true',
                        help='Use positional encoding')
    parser.add_argument('--with_proj', action='store_true',
                        help='Use edge projections')
    parser.add_argument('--with_virtual_node', action='store_true',
                        help='Add virtual node to graphs')
    parser.add_argument('--use_node_attr', action='store_true',
                        help='Use additional node attributes for TU datasets if available')

    args = parser.parse_args()

    # Get dataset-specific configuration
    config = get_dataset_config(args.dataset)

    # Fill in anything not given on the command line from the config
    if args.k is None:
        args.k = config['pos_enc_dim']
    if args.batch_size is None:
        args.batch_size = config['batch_size']
    if args.lr is None:
        args.lr = config['learning_rate']
    args.hidden_channels = config['hidden_channels']
    args.dropout = config['dropout']
    args.weight_decay = config['weight_decay']

    # Honour the requested device (e.g. 'cuda:1', 'cpu'). Fall back to CPU only
    # when a CUDA device is requested but CUDA is unavailable.
    if args.device.startswith('cuda') and not torch.cuda.is_available():
        device = torch.device('cpu')
    else:
        device = torch.device(args.device)

    # Get all possible flag combinations
    flag_combinations = get_all_flag_combinations()

    # Timestamped directory for the sweep-level config and combined metrics
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    results_dir = f'results_{args.dataset}_{timestamp}'
    os.makedirs(results_dir, exist_ok=True)
    # run_single_experiment and plot_combined_results write into this
    # non-timestamped directory instead.
    experiment_dir = f'results_{args.dataset}'

    # Save experiment configuration
    config_path = f'{results_dir}/experiment_config.json'
    with open(config_path, 'w') as f:
        json.dump(vars(args), f, indent=4)

    # Run experiments for all combinations
    all_results = {}
    for flags in tqdm(flag_combinations, desc='Running experiments'):
        experiment_name = get_experiment_name(args.dataset, flags)
        all_results[experiment_name] = run_single_experiment(args, flags, device, config)

    # Save combined results
    combined_results_path = f'{results_dir}/all_results.json'
    with open(combined_results_path, 'w') as f:
        json.dump(all_results, f, indent=4)

    # Plot combined results
    plot_combined_results(all_results, args.dataset)

    # Print final summary
    print('\nExperiment Summary:')
    print(f'Dataset: {args.dataset}')
    print('\nBest results for each configuration:')
    for experiment_name, results in all_results.items():
        print(f'{experiment_name}: {results["best_val_acc"]:.4f}')

    print(f'\nRun config and combined metrics saved in: {results_dir}')
    print(f'Per-experiment checkpoints/metrics/plots and combined plot saved in: {experiment_dir}')

# Run the sweep only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()