import os
import argparse
import yaml
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
from datetime import datetime

from data.dataset import get_dataloaders
from models.mihc import MIHC
from utils.evaluation import evaluate_predictions
from utils.visualization import visualize_congestion_map, visualize_bottleneck_subgraph


def parse_args():
    """Read the test script's command-line options from ``sys.argv``.

    Returns:
        argparse.Namespace with attributes ``config``, ``checkpoint``
        (required), ``results_dir``, ``visualize`` and ``device``. The
        ``device`` default is ``'cuda'`` when CUDA is available, else ``'cpu'``.
    """
    fallback_device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # (flag, keyword arguments for add_argument), registered in this order.
    option_specs = (
        ('--config', dict(type=str, default='configs/config.yaml',
                          help='Path to configuration file')),
        ('--checkpoint', dict(type=str, required=True,
                              help='Path to model checkpoint')),
        ('--results_dir', dict(type=str, default='results',
                               help='Directory to save results')),
        ('--visualize', dict(action='store_true',
                             help='Generate visualizations')),
        ('--device', dict(type=str, default=fallback_device,
                          help='Device to use (cuda or cpu)')),
    )

    cli = argparse.ArgumentParser(
        description='Test MIHC model for chip congestion prediction')
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    options = cli.parse_args()
    return options


def load_config(config_path):
    """Load the experiment configuration from a YAML file.

    Args:
        config_path: Path to the YAML configuration file.

    Returns:
        The document parsed by ``yaml.safe_load`` (a nested dict for the
        configs ``main`` reads; ``None`` for an empty file).
    """
    # YAML is specified as Unicode (UTF-8 in practice); read it explicitly as
    # UTF-8 rather than relying on the platform's locale-dependent default.
    with open(config_path, 'r', encoding='utf-8') as f:
        return yaml.safe_load(f)


def load_checkpoint(checkpoint_path, device):
    """Deserialize a checkpoint written by ``torch.save``.

    Every tensor in the file is remapped onto ``device`` while loading, via
    ``torch.load``'s ``map_location``.

    Args:
        checkpoint_path: Path of the checkpoint file.
        device: Target device (``torch.device`` or device string).

    Returns:
        The saved object, unchanged. ``main`` expects a dict with
        ``'model_state_dict'`` and ``'epoch'`` entries.

    Note:
        Depending on the installed PyTorch version's ``weights_only`` default,
        ``torch.load`` may unpickle arbitrary objects. Only load trusted files.
    """
    return torch.load(checkpoint_path, map_location=device)


def _optional_output(predictions, key):
    """Return ``predictions[key]`` when the key is present, else ``None``."""
    return predictions[key] if key in predictions else None


def _rows_to_values(tensor):
    """Convert a tensor into one Python value per leading-dimension row.

    Single-element rows become Python scalars (what ``.item()`` returns).
    Multi-element rows become flat Python lists instead of raising the
    ``.item()`` "only one element tensors can be converted" error.
    """
    if tensor.dim() == 0:
        return [tensor.item()]
    if tensor.dim() == 1:
        return tensor.tolist()
    flat = tensor.flatten(start_dim=1)
    return [row.item() if row.numel() == 1 else row.tolist() for row in flat]


def _visualize_batch(predictions, cell_hypergraph, grid_congestion,
                     batch_design_names, batch_placement_names,
                     congestion_dir, bottleneck_dir, grid_size):
    """Save per-sample figures for one batch.

    Always saves a congestion-map figure. Also saves a bottleneck figure when
    the model returned ``cell_probs``.
    """
    cell_probs_batch = _optional_output(predictions, 'cell_probs')

    for i, (design_name, placement_name) in enumerate(
            zip(batch_design_names, batch_placement_names)):
        file_name = f"{design_name}_{placement_name}.png"

        fig = visualize_congestion_map(
            grid_congestion[i].cpu().numpy(),
            grid_size,
            prediction=predictions['grid_congestion'][i].cpu().numpy(),
            title=f"{design_name} - {placement_name}",
            save_path=os.path.join(congestion_dir, file_name)
        )
        plt.close(fig)

        if cell_probs_batch is None:
            continue

        # NOTE(review): indexing cell-level outputs with [i] assumes they have
        # a per-sample leading dimension. If they are per-node tensors
        # flattened across the batch, this picks the i-th cell rather than the
        # i-th design. Confirm against MIHC's output layout.
        fig = visualize_bottleneck_subgraph(
            cell_hypergraph[i].to('cpu'),
            cell_probs_batch[i].cpu().numpy(),
            cell_congestion=predictions['cell_congestion'][i].cpu().numpy(),
            title=f"{design_name} - {placement_name} (Bottleneck Subgraph)",
            save_path=os.path.join(bottleneck_dir, file_name)
        )
        plt.close(fig)


def _save_predictions(results_dir, design_names, placement_names, columns):
    """Write per-sample predictions to CSV and all raw tensors to a ``.pt`` file.

    ``columns`` maps CSV column names to tensors (``None`` entries are
    omitted). A tensor becomes a CSV column only when its leading dimension
    equals the number of samples. Other tensors, such as per-node outputs, are
    reported as omitted and remain available in the ``.pt`` file, so no data
    is lost.
    """
    present = {name: t for name, t in columns.items() if t is not None}

    raw_path = os.path.join(results_dir, 'test_predictions.pt')
    torch.save({'design_names': list(design_names),
                'placement_names': list(placement_names),
                **present}, raw_path)

    num_samples = len(design_names)
    table = {'Design': design_names, 'Placement': placement_names}
    skipped = []
    for name, tensor in present.items():
        num_rows = tensor.shape[0] if tensor.dim() > 0 else 1
        if num_rows == num_samples:
            table[name] = _rows_to_values(tensor)
        else:
            skipped.append(f"{name} ({num_rows} rows)")

    predictions_path = os.path.join(results_dir, 'test_predictions.csv')
    pd.DataFrame(table).to_csv(predictions_path, index=False)
    print(f"Test predictions saved to {predictions_path}")
    if skipped:
        print(f"  Omitted from CSV (not one row per sample, {num_samples} samples): "
              + ", ".join(skipped))
    print(f"Raw prediction tensors saved to {raw_path}")


def test(model, test_loader, device, results_dir, visualize=False,
         grid_size=(64, 64), max_vis_batches=5):
    """Run inference over ``test_loader``, print metrics and save results.

    Args:
        model: Called as ``model(cell_hypergraph, grid_hypergraph)``. It must
            return a mapping with ``'cell_congestion'`` and
            ``'grid_congestion'``, and optionally ``'cell_probs'`` /
            ``'grid_probs'``.
        test_loader: Iterable of batch dicts with keys ``'cell_hypergraph'``,
            ``'grid_hypergraph'``, ``'cell_congestion'``, ``'grid_congestion'``,
            ``'design_names'`` and ``'placement_names'``.
        device: Device the batch inputs are moved to.
        results_dir: Output directory (created if missing). Receives
            ``test_metrics.csv``, ``test_predictions.csv`` and
            ``test_predictions.pt``.
        visualize: If True, save per-sample figures under
            ``<results_dir>/visualizations``.
        grid_size: Grid shape passed to ``visualize_congestion_map``.
        max_vis_batches: Number of leading batches to visualize.

    Returns:
        ``(cell_metrics, grid_metrics)`` as produced by ``evaluate_predictions``.

    Raises:
        RuntimeError: If ``test_loader`` yields no batches.
    """
    model.eval()
    os.makedirs(results_dir, exist_ok=True)

    cell_predictions, cell_targets = [], []
    grid_predictions, grid_targets = [], []
    design_names, placement_names = [], []
    cell_probs_list, grid_probs_list = [], []

    congestion_dir = bottleneck_dir = None
    if visualize:
        vis_dir = os.path.join(results_dir, 'visualizations')
        congestion_dir = os.path.join(vis_dir, 'congestion_maps')
        bottleneck_dir = os.path.join(vis_dir, 'bottleneck_subgraphs')
        os.makedirs(congestion_dir, exist_ok=True)
        os.makedirs(bottleneck_dir, exist_ok=True)

    with torch.no_grad():
        for batch_idx, batch in enumerate(tqdm(test_loader, desc='Testing')):
            cell_hypergraph = batch['cell_hypergraph'].to(device)
            grid_hypergraph = batch['grid_hypergraph'].to(device)
            cell_congestion = batch['cell_congestion'].to(device)
            grid_congestion = batch['grid_congestion'].to(device)
            batch_design_names = batch['design_names']
            batch_placement_names = batch['placement_names']

            predictions = model(cell_hypergraph, grid_hypergraph)

            cell_predictions.append(predictions['cell_congestion'].cpu())
            cell_targets.append(cell_congestion.cpu())
            grid_predictions.append(predictions['grid_congestion'].cpu())
            grid_targets.append(grid_congestion.cpu())
            design_names.extend(batch_design_names)
            placement_names.extend(batch_placement_names)

            # Bottleneck probabilities are optional. Collect each one
            # independently so a model returning only one of them does not
            # crash on ``None.cpu()``.
            batch_cell_probs = _optional_output(predictions, 'cell_probs')
            batch_grid_probs = _optional_output(predictions, 'grid_probs')
            if batch_cell_probs is not None:
                cell_probs_list.append(batch_cell_probs.cpu())
            if batch_grid_probs is not None:
                grid_probs_list.append(batch_grid_probs.cpu())

            if visualize and batch_idx < max_vis_batches:
                _visualize_batch(predictions, cell_hypergraph, grid_congestion,
                                 batch_design_names, batch_placement_names,
                                 congestion_dir, bottleneck_dir, grid_size)

    if not cell_predictions:
        raise RuntimeError("test_loader yielded no batches; nothing to evaluate")

    cell_predictions = torch.cat(cell_predictions)
    cell_targets = torch.cat(cell_targets)
    grid_predictions = torch.cat(grid_predictions)
    grid_targets = torch.cat(grid_targets)

    cell_metrics = evaluate_predictions(cell_predictions, cell_targets)
    grid_metrics = evaluate_predictions(grid_predictions, grid_targets)

    cell_probs = torch.cat(cell_probs_list) if cell_probs_list else None
    grid_probs = torch.cat(grid_probs_list) if grid_probs_list else None

    print("\nTest Results:")
    print("Cell-based Metrics:")
    for metric_name, metric_value in cell_metrics.items():
        print(f"  {metric_name}: {metric_value:.4f}")

    print("\nGrid-based Metrics:")
    for metric_name, metric_value in grid_metrics.items():
        print(f"  {metric_name}: {metric_value:.4f}")

    metrics_df = pd.DataFrame({
        'Metric': list(cell_metrics) + list(grid_metrics),
        'Value': list(cell_metrics.values()) + list(grid_metrics.values()),
        'Type': ['Cell-based'] * len(cell_metrics) + ['Grid-based'] * len(grid_metrics),
    })
    metrics_path = os.path.join(results_dir, 'test_metrics.csv')
    metrics_df.to_csv(metrics_path, index=False)
    print(f"\nTest metrics saved to {metrics_path}")

    # Insertion order fixes the CSV column order.
    _save_predictions(results_dir, design_names, placement_names, {
        'Cell_Prediction': cell_predictions,
        'Cell_Target': cell_targets,
        'Grid_Prediction': grid_predictions,
        'Grid_Target': grid_targets,
        'Cell_Bottleneck_Prob': cell_probs,
        'Grid_Bottleneck_Prob': grid_probs,
    })

    return cell_metrics, grid_metrics


def main():
    """Evaluate a trained MIHC checkpoint on the test split.

    Steps:
    1. Create a timestamped sub-directory of ``--results_dir``.
    2. Build the model from the YAML config, inferring input feature sizes
       from one test batch.
    3. Restore the checkpoint weights and run ``test``.
    4. Save the config and CLI arguments alongside the results.

    Raises:
        RuntimeError: If the test loader is empty, since feature sizes cannot
            be inferred.
    """
    args = parse_args()
    config = load_config(args.config)

    device = torch.device(args.device)
    print(f"Using device: {device}")

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    results_dir = os.path.join(args.results_dir, timestamp)
    os.makedirs(results_dir, exist_ok=True)

    test_loader = get_dataloaders(config)['test']

    # Take a single batch, from a single iterator, to infer input feature
    # sizes. Creating two iterators would load the first batch twice.
    try:
        sample_batch = next(iter(test_loader))
    except StopIteration:
        raise RuntimeError(
            "Test loader is empty; cannot infer model feature dimensions") from None
    cell_feature_dim = sample_batch['cell_hypergraph']['cell'].x.shape[1]
    grid_feature_dim = sample_batch['grid_hypergraph']['grid'].x.shape[1]
    del sample_batch

    model_cfg = config['model']
    model = MIHC(
        cell_feature_dim=cell_feature_dim,
        grid_feature_dim=grid_feature_dim,
        hidden_dim=model_cfg['hidden_dim'],
        num_layers=model_cfg['mv_hgnn']['num_layers'],
        num_heads=model_cfg['mv_hgnn']['num_attention_heads'],
        dropout=model_cfg['dropout'],
        bottleneck_enable=model_cfg['bottleneck']['enable'],
        temperature=model_cfg['contrastive']['temperature'],
        beta=model_cfg['bottleneck']['beta']
    ).to(device)

    checkpoint = load_checkpoint(args.checkpoint, device)
    model.load_state_dict(checkpoint['model_state_dict'])
    # 'epoch' is informational only, so do not fail when it is absent.
    epoch = checkpoint.get('epoch')
    if epoch is None:
        print(f"Loaded checkpoint from {args.checkpoint}")
    else:
        print(f"Loaded checkpoint from epoch {epoch + 1}")

    test(model, test_loader, device, results_dir, args.visualize)

    with open(os.path.join(results_dir, 'config.yaml'), 'w', encoding='utf-8') as f:
        yaml.dump(config, f)

    with open(os.path.join(results_dir, 'args.txt'), 'w', encoding='utf-8') as f:
        f.write(f"checkpoint: {args.checkpoint}\n")
        f.write(f"config: {args.config}\n")
        f.write(f"device: {args.device}\n")
        f.write(f"visualize: {args.visualize}\n")

    print("Testing completed!")


if __name__ == '__main__':
    # Script entry point: run the evaluation pipeline only when executed
    # directly, not when this module is imported.
    main()