import argparse
import torch
import torch.nn.functional as F
from torch_geometric.datasets import MNISTSuperpixels
from torch_geometric.loader import DataLoader
import torch_geometric.transforms as T
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np

# Import models
from models import MC_AE_EGNN
from torch_geometric.transforms import BaseTransform
from torch_geometric.utils import get_laplacian, to_dense_adj

class EigenvectorProcessor(BaseTransform):
    """
    Transform to prepare eigenvectors and eigenspace projectors for MC_AE_EGNN model.

    For every graph it writes three attributes onto the data object:
    - ``data.pos``: eigenvectors 1..k (the first one is skipped) of the
      symmetrically normalized graph Laplacian, shape ``[num_nodes, k]``
    - ``data.eigvals``: the matching eigenvalues, shape ``[k]``
    - ``data.edge_attr``: for each edge ``(i, j)`` and each kept eigenvector
      ``v``, the projector entry ``(v v^T)[i, j] = v[i] * v[j]``,
      shape ``[num_edges, k]``

    If the graph has fewer than ``k + 1`` nodes, the missing eigenpairs are
    zero-padded so the output shapes always have ``k`` columns. If the
    eigendecomposition itself fails, all three attributes are set to zeros.

    Args:
        k: Number of eigenvectors/eigenvalues to keep per graph.
    """
    def __init__(self, k=8):
        self.k = k

    def __call__(self, data):
        """Attach ``pos``, ``eigvals`` and ``edge_attr`` to ``data`` and return it.

        The data object is modified in place.
        """
        num_nodes = data.num_nodes
        num_edges = data.edge_index.size(1)

        # Symmetrically normalized Laplacian in sparse (COO) form.
        lap_index, lap_weight = get_laplacian(
            data.edge_index, normalization='sym',
            num_nodes=num_nodes
        )

        # Densify; pin the size explicitly so the matrix is always
        # [num_nodes, num_nodes] regardless of which indices appear.
        laplacian = to_dense_adj(
            lap_index, edge_attr=lap_weight, max_num_nodes=num_nodes
        )[0]

        try:
            # eigh returns eigenvalues in ascending order, so no extra sort
            # is needed afterwards.
            eigvals, eigvecs = torch.linalg.eigh(laplacian)
        except RuntimeError:
            # Fallback if eigendecomposition fails
            print(f"Warning: Eigendecomposition failed. Using zeros.")
            data.pos = torch.zeros((num_nodes, self.k))
            data.eigvals = torch.zeros(self.k)
            data.edge_attr = torch.zeros((num_edges, self.k))
            return data

        # Keep eigenpairs 1..k, skipping the first (smallest-eigenvalue) one.
        eigvals_k = eigvals[1:self.k + 1]
        eigvecs_k = eigvecs[:, 1:self.k + 1]

        # Small graphs have fewer than k eigenpairs left after the skip;
        # zero-pad so downstream shapes are always [*, k].
        missing = self.k - eigvecs_k.size(1)
        if missing > 0:
            eigvals_k = torch.cat([eigvals_k, eigvals_k.new_zeros(missing)])
            eigvecs_k = torch.cat(
                [eigvecs_k, eigvecs_k.new_zeros((num_nodes, missing))], dim=1
            )

        data.pos = eigvecs_k
        data.eigvals = eigvals_k

        # (v v^T)[i, j] == v[i] * v[j], so the per-edge projector entries can
        # be read off directly without materializing k dense N x N matrices.
        row, col = data.edge_index
        data.edge_attr = eigvecs_k[row] * eigvecs_k[col]  # [num_edges, k]

        return data

def train(model, train_loader, optimizer, device):
    """Run one optimization pass over ``train_loader``.

    Each batch is moved to ``device``, the model output is turned into
    log-probabilities and scored with negative log-likelihood against
    ``data.y``, and the optimizer takes one step.

    Returns:
        The sum of per-batch losses weighted by ``data.num_graphs``,
        divided by ``len(train_loader.dataset)``.
    """
    model.train()
    running_loss = 0

    progress = tqdm(train_loader, desc='Training', leave=False)
    for batch in progress:
        batch = batch.to(device)

        optimizer.zero_grad()
        log_probs = F.log_softmax(model(batch), dim=1)
        batch_loss = F.nll_loss(log_probs, batch.y)
        batch_loss.backward()
        optimizer.step()

        # The batch loss is a mean; weight it by the batch's graph count.
        running_loss += batch_loss.item() * batch.num_graphs

    num_samples = len(train_loader.dataset)
    return running_loss / num_samples

def validate(model, loader, device):
    """Evaluate ``model`` on ``loader`` without tracking gradients.

    Returns:
        A ``(avg_loss, accuracy)`` tuple, both divided by
        ``len(loader.dataset)``: the graph-weighted negative log-likelihood
        and the fraction of predictions equal to ``data.y``.
    """
    model.eval()
    loss_sum = 0
    num_correct = 0

    with torch.no_grad():
        for batch in tqdm(loader, desc='Validating', leave=False):
            batch = batch.to(device)
            logits = model(batch)

            batch_loss = F.nll_loss(F.log_softmax(logits, dim=1), batch.y)
            loss_sum += batch_loss.item() * batch.num_graphs

            # Predicted class = index of the largest score per row.
            predicted = logits.max(dim=1).indices
            num_correct += (predicted == batch.y).sum().item()

    num_samples = len(loader.dataset)
    return loss_sum / num_samples, num_correct / num_samples

def plot_training_curves(k_results, save_path='training_curves_mcae.png'):
    """Plot train/validation loss per epoch for every k and save the figure.

    Args:
        k_results: Mapping from k to a dict holding ``'train_loss'`` and
            ``'val_loss'`` lists (one entry per epoch).
        save_path: File the figure is written to.
    """
    plt.figure(figsize=(12, 6))

    for k, history in k_results.items():
        train_curve = history['train_loss']
        # Epochs are numbered from 1 on the x-axis.
        epoch_axis = range(1, len(train_curve) + 1)

        # Solid line for training loss, dashed for validation loss.
        plt.plot(epoch_axis, train_curve,
                 linestyle='-', label=f'Train Loss (k={k})')
        plt.plot(epoch_axis, history['val_loss'],
                 linestyle='--', label=f'Val Loss (k={k})')

    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training and Validation Loss for MC_AE_EGNN with Different k Values')
    plt.legend()
    plt.grid(True)
    plt.tight_layout()

    plt.savefig(save_path)
    plt.close()

def _resolve_device(requested):
    """Turn the ``--device`` string into a usable ``torch.device``.

    Any valid device string (e.g. ``'cuda'``, ``'cuda:1'``, ``'cpu'``) is
    honoured. CPU is used instead when a CUDA device is requested but CUDA is
    unavailable, when the requested CUDA index does not exist, or when the
    string is not a recognized device.
    """
    cpu = torch.device('cpu')
    try:
        device = torch.device(requested)
    except RuntimeError:
        print(f"Warning: unrecognized device '{requested}'. Falling back to CPU.")
        return cpu

    if device.type == 'cuda':
        if not torch.cuda.is_available():
            return cpu
        if device.index is not None and device.index >= torch.cuda.device_count():
            print(f"Warning: CUDA device index {device.index} not found. "
                  f"Falling back to CPU.")
            return cpu

    return device


def main():
    """Train and evaluate MC_AE_EGNN on MNISTSuperpixels for each requested k.

    Parses command-line arguments, then for every value in ``--k_values``:
    builds the datasets with the eigenvector transform, trains for
    ``--epochs`` epochs, evaluates on the train and test splits after each
    epoch, and saves the weights with the best test-split accuracy to
    ``best_mcae_model_k{k}.pt``. Finally plots the loss curves and prints the
    best accuracy per k.
    """
    parser = argparse.ArgumentParser()
    # Dataset and model parameters
    parser.add_argument('--k_values', nargs='+', type=int, default=[8], 
                      help='List of k values for positional encoding dimensions')
    parser.add_argument('--epochs', type=int, default=50, help='Number of epochs')
    parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
    parser.add_argument('--lr', type=float, default=0.01, help='Learning rate')
    parser.add_argument('--device', type=str, default='cuda', help='Device to use')
    parser.add_argument('--data_dir', type=str, default='data', help='Data directory')
    
    # MC_AE_EGNN specific parameters
    parser.add_argument('--nf_edge', type=int, default=128, help='Hidden edge dimension')
    parser.add_argument('--nf_node', type=int, default=128, help='Hidden node dimension')
    parser.add_argument('--nf_coord', type=int, default=128, help='Hidden coordinate dimension')
    parser.add_argument('--n_layers', type=int, default=4, help='Number of layers')
    parser.add_argument('--update_vel', action='store_true', help='Update velocity')
    parser.add_argument('--with_coords', action='store_true', default=False, help='Use coordinates')
    parser.add_argument('--norm_diff', action='store_true', help='Normalize difference')
    parser.add_argument('--tanh', action='store_true', help='Use tanh activation')
    parser.add_argument('--num_vectors', type=int, default=1, help='Number of vectors')
    
    args = parser.parse_args()
    
    # Set device (honours indexed devices such as 'cuda:0', falls back to CPU
    # when the requested CUDA device cannot be used)
    device = _resolve_device(args.device)
    print(f'Using device: {device}')
    print(f'Training with eigenspace projectors')
    
    # Results dictionary to store metrics for each k
    k_results = {}
    
    # Train and evaluate for each k value
    for k in args.k_values:
        print(f'\nTraining MC_AE_EGNN model with k={k}')
        
        # Define transform to process eigenvectors and compute projectors.
        # It is applied on access, so it runs each time a sample is loaded.
        transform = T.Compose([
            EigenvectorProcessor(k=k),
            T.NormalizeFeatures()
        ])
        
        # Load datasets
        train_dataset = MNISTSuperpixels(
            root=args.data_dir, 
            train=True,
            transform=transform
        )
        
        test_dataset = MNISTSuperpixels(
            root=args.data_dir, 
            train=False,
            transform=transform
        )
        
        # Create data loaders
        train_loader = DataLoader(
            train_dataset, 
            batch_size=args.batch_size, 
            shuffle=True
        )
        test_loader = DataLoader(
            test_dataset, 
            batch_size=args.batch_size
        )
        
        # Initialize MC_AE_EGNN model
        model = MC_AE_EGNN(
            in_node_nf=k,               # Input node features = eigenvectors dimension
            in_edge_nf=k,               # Input edge features = eigenspace projectors
            hidden_edge_nf=args.nf_edge, 
            hidden_node_nf=args.nf_node, 
            hidden_coord_nf=args.nf_coord,
            device=device, 
            n_layers=args.n_layers,
            K=k,                        # Number of eigenvalues
            with_coords=args.with_coords,
            recurrent=False, 
            norm_diff=args.norm_diff, 
            tanh=args.tanh, 
            num_vectors=args.num_vectors, 
            update_vel=args.update_vel
        ).to(device)
        
        # Initialize optimizer
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
        
        # Initialize metrics storage
        k_results[k] = {
            'train_loss': [],
            'val_loss': [],
            'train_acc': [],
            'val_acc': []
        }
        
        # Training loop
        best_val_acc = 0
        for epoch in range(1, args.epochs + 1):
            # Train
            train_loss = train(model, train_loader, optimizer, device)
            
            # Validate (the test split doubles as the validation set here)
            val_loss, val_acc = validate(model, test_loader, device)
            train_loss_val, train_acc = validate(model, train_loader, device)
            
            # Store metrics. The stored train loss is the eval-mode loss from
            # validate(), while the printed one below is the running loss
            # returned by train().
            k_results[k]['train_loss'].append(train_loss_val)
            k_results[k]['val_loss'].append(val_loss)
            k_results[k]['train_acc'].append(train_acc)
            k_results[k]['val_acc'].append(val_acc)
            
            # Save best model
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                torch.save(model.state_dict(), f'best_mcae_model_k{k}.pt')
            
            print(f'Epoch: {epoch:02d}, '
                  f'Train Loss: {train_loss:.4f}, '
                  f'Val Loss: {val_loss:.4f}, '
                  f'Train Acc: {train_acc:.4f}, '
                  f'Val Acc: {val_acc:.4f}, '
                  f'Best Val Acc: {best_val_acc:.4f}')
    
    # Plot results
    plot_training_curves(k_results)
    
    # Print final results
    print('\nFinal Results:')
    for k in args.k_values:
        # default guards against an empty history (e.g. --epochs 0)
        best_val_acc = max(k_results[k]['val_acc'], default=0.0)
        print(f'k={k}: Best Validation Accuracy: {best_val_acc:.4f}')

# Run the experiment only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
