import os
import argparse
import yaml
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime

from data.dataset import get_dataloaders
from models.mihc import MIHC
from utils.evaluation import evaluate_predictions
from utils.visualization import visualize_training_curves


def parse_args():
    """Build the training CLI and parse the process's command-line arguments.

    Returns:
        argparse.Namespace: Holds the string attributes ``config``,
        ``checkpoint_dir``, ``log_dir`` and ``device``.
    """
    # Prefer the GPU by default, but only when CUDA is actually usable.
    fallback_device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # (flag, default, help) for every option; all of them are plain strings.
    option_table = (
        ('--config', 'configs/config.yaml', 'Path to configuration file'),
        ('--checkpoint_dir', 'checkpoints', 'Directory to save checkpoints'),
        ('--log_dir', 'logs', 'Directory to save logs'),
        ('--device', fallback_device, 'Device to use (cuda or cpu)'),
    )

    cli = argparse.ArgumentParser(description='Train MIHC model for chip congestion prediction')
    for flag, default_value, help_text in option_table:
        cli.add_argument(flag, type=str, default=default_value, help=help_text)

    return cli.parse_args()


def load_config(config_path):
    """Load configuration from a YAML file.

    Args:
        config_path: Path of the YAML file to read.

    Returns:
        Whatever ``yaml.safe_load`` produces for the file's contents.

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the contents are not valid YAML.
    """
    # Decode explicitly as UTF-8: without an encoding argument, open() falls
    # back to the platform's locale encoding, which can mis-decode non-ASCII
    # characters in the config on some systems.
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    return config


def save_checkpoint(model, optimizer, epoch, metrics, checkpoint_path):
    """Save a training checkpoint to ``checkpoint_path`` with ``torch.save``.

    The saved object is a dict with the keys ``'epoch'``,
    ``'model_state_dict'``, ``'optimizer_state_dict'`` and ``'metrics'``.

    Args:
        model: Object whose ``state_dict()`` is stored.
        optimizer: Object whose ``state_dict()`` is stored.
        epoch: Value stored unchanged under ``'epoch'``.
        metrics: Value stored unchanged under ``'metrics'``.
        checkpoint_path: Destination file. Its parent directory is created
            when missing; a bare file name is written to the current
            working directory.
    """
    # A bare file name has an empty dirname, and os.makedirs('') raises
    # FileNotFoundError, so only create the parent when there is one.
    parent_dir = os.path.dirname(checkpoint_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'metrics': metrics,
    }

    torch.save(checkpoint, checkpoint_path)
    print(f"Checkpoint saved to {checkpoint_path}")


def train_epoch(model, train_loader, optimizer, device):
    """Run one optimization pass over ``train_loader``.

    For each batch, the two hypergraphs and two congestion targets are moved
    to ``device``, the model is called on the hypergraphs,
    ``model.compute_loss`` is evaluated against the targets, and a single
    backward pass and optimizer step follow.

    Args:
        model: Callable as ``model(cell_hypergraph, grid_hypergraph)`` and
            exposing ``compute_loss(predictions, targets)``; put into
            training mode here.
        train_loader: Sized iterable of batches indexable by
            ``'cell_hypergraph'``, ``'grid_hypergraph'``,
            ``'cell_congestion'`` and ``'grid_congestion'``.
        optimizer: Object whose ``zero_grad()`` and ``step()`` are called
            around each backward pass.
        device: Target handed to every ``.to(...)`` call.

    Returns:
        Tuple ``(avg_loss, cell_metrics, grid_metrics)``: the sum of the
        per-batch ``total_loss`` values divided by ``len(train_loader)``, and
        the results of ``evaluate_predictions`` on the concatenated cell and
        grid outputs.
    """
    model.train()

    loss_sum = 0.0
    cell_out, cell_ref = [], []
    grid_out, grid_ref = [], []

    batches = tqdm(train_loader, desc='Training')
    for batch in batches:
        # Transfer inputs first, then targets.
        cell_graph = batch['cell_hypergraph'].to(device)
        grid_graph = batch['grid_hypergraph'].to(device)
        cell_true = batch['cell_congestion'].to(device)
        grid_true = batch['grid_congestion'].to(device)

        # Clear stale gradients before the forward pass.
        optimizer.zero_grad()
        outputs = model(cell_graph, grid_graph)

        losses = model.compute_loss(
            outputs,
            {'cell_congestion': cell_true, 'grid_congestion': grid_true},
        )
        total = losses['total_loss']

        total.backward()
        optimizer.step()

        batch_loss = total.item()
        loss_sum += batch_loss

        # Keep graph-free CPU copies so metrics can be computed after the loop.
        cell_out.append(outputs['cell_congestion'].detach().cpu())
        cell_ref.append(cell_true.detach().cpu())
        grid_out.append(outputs['grid_congestion'].detach().cpu())
        grid_ref.append(grid_true.detach().cpu())

        batches.set_postfix({
            'loss': batch_loss,
            'sup_loss': losses['supervision_loss'].item(),
            'ib_loss': losses['ib_loss'].item(),
            'cont_loss': losses['contrastive_loss'].item()
        })

    # Score the whole epoch at once rather than batch by batch.
    cell_metrics = evaluate_predictions(torch.cat(cell_out), torch.cat(cell_ref))
    grid_metrics = evaluate_predictions(torch.cat(grid_out), torch.cat(grid_ref))

    return loss_sum / len(train_loader), cell_metrics, grid_metrics


def validate(model, val_loader, device):
    """Evaluate the model on ``val_loader`` with gradient tracking disabled.

    Args:
        model: Callable as ``model(cell_hypergraph, grid_hypergraph)`` and
            exposing ``compute_loss(predictions, targets)``; put into
            evaluation mode here.
        val_loader: Sized iterable of batches indexable by
            ``'cell_hypergraph'``, ``'grid_hypergraph'``,
            ``'cell_congestion'`` and ``'grid_congestion'``.
        device: Target handed to every ``.to(...)`` call.

    Returns:
        Tuple ``(avg_loss, cell_metrics, grid_metrics)``: the sum of the
        per-batch ``total_loss`` values divided by ``len(val_loader)``, and
        the results of ``evaluate_predictions`` on the concatenated cell and
        grid outputs.
    """
    model.eval()

    loss_sum = 0.0
    cell_out, cell_ref = [], []
    grid_out, grid_ref = [], []

    with torch.no_grad():
        for batch in val_loader:
            # Transfer inputs first, then targets.
            cell_graph = batch['cell_hypergraph'].to(device)
            grid_graph = batch['grid_hypergraph'].to(device)
            cell_true = batch['cell_congestion'].to(device)
            grid_true = batch['grid_congestion'].to(device)

            outputs = model(cell_graph, grid_graph)

            losses = model.compute_loss(
                outputs,
                {'cell_congestion': cell_true, 'grid_congestion': grid_true},
            )
            loss_sum += losses['total_loss'].item()

            # Keep CPU copies so metrics can be computed after the loop.
            cell_out.append(outputs['cell_congestion'].cpu())
            cell_ref.append(cell_true.cpu())
            grid_out.append(outputs['grid_congestion'].cpu())
            grid_ref.append(grid_true.cpu())

    # Score the whole validation set at once rather than batch by batch.
    cell_metrics = evaluate_predictions(torch.cat(cell_out), torch.cat(cell_ref))
    grid_metrics = evaluate_predictions(torch.cat(grid_out), torch.cat(grid_ref))

    return loss_sum / len(val_loader), cell_metrics, grid_metrics


def main():
    """Train an MIHC model end to end from the command line.

    Parses the CLI arguments, loads the YAML configuration, builds the
    dataloaders, the model, an Adam optimizer and (when enabled in the
    config) a StepLR scheduler, then runs the epoch loop. Each epoch is
    trained, validated, logged to TensorBoard and checkpointed; the model
    with the lowest validation loss is additionally saved as
    ``best_model.pth``, and training stops early once the validation loss
    has failed to improve by more than ``min_delta`` for ``patience``
    consecutive epochs. A figure of the training curves is written to the
    log directory at the end.
    """
    args = parse_args()
    config = load_config(args.config)

    device = torch.device(args.device)
    print(f"Using device: {device}")

    dataloaders = get_dataloaders(config)
    train_loader = dataloaders['train']
    val_loader = dataloaders['val']

    # Peek at ONE batch to read both input feature widths. Building the
    # iterator a single time avoids loading a second batch just for a shape.
    sample_batch = next(iter(train_loader))
    cell_feature_dim = sample_batch['cell_hypergraph']['cell'].x.shape[1]
    grid_feature_dim = sample_batch['grid_hypergraph']['grid'].x.shape[1]

    model_cfg = config['model']
    model = MIHC(
        cell_feature_dim=cell_feature_dim,
        grid_feature_dim=grid_feature_dim,
        hidden_dim=model_cfg['hidden_dim'],
        num_layers=model_cfg['mv_hgnn']['num_layers'],
        num_heads=model_cfg['mv_hgnn']['num_attention_heads'],
        dropout=model_cfg['dropout'],
        bottleneck_enable=model_cfg['bottleneck']['enable'],
        temperature=model_cfg['contrastive']['temperature'],
        beta=model_cfg['bottleneck']['beta'],
    ).to(device)

    train_cfg = config['training']
    optimizer = optim.Adam(
        model.parameters(),
        lr=train_cfg['learning_rate'],
        weight_decay=train_cfg['weight_decay'],
    )

    scheduler_cfg = train_cfg['lr_scheduler']
    if scheduler_cfg['enable']:
        scheduler = optim.lr_scheduler.StepLR(
            optimizer,
            step_size=scheduler_cfg['step_size'],
            gamma=scheduler_cfg['gamma'],
        )
    else:
        scheduler = None

    # Read the remaining settings before anything is created on disk, so a
    # malformed config fails without leaving empty log/checkpoint folders.
    num_epochs = train_cfg['num_epochs']
    patience = train_cfg['early_stopping']['patience']
    min_delta = train_cfg['early_stopping']['min_delta']

    # One timestamp names both the log and the checkpoint folder of this run.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_dir = os.path.join(args.log_dir, f"{timestamp}")
    checkpoint_dir = os.path.join(args.checkpoint_dir, f"{timestamp}")
    best_model_path = os.path.join(checkpoint_dir, "best_model.pth")

    # History keys follow the pattern '<split>_<level>_<metric>'.
    metric_names = ('nmae', 'nrms', 'pearson', 'spearman', 'kendall')
    metrics_history = {'epoch': [], 'train_loss': [], 'val_loss': []}
    for split in ('train', 'val'):
        for level in ('cell', 'grid'):
            for metric_name in metric_names:
                metrics_history[f'{split}_{level}_{metric_name}'] = []

    best_val_loss = float('inf')
    counter = 0  # consecutive epochs without sufficient improvement

    writer = SummaryWriter(log_dir)
    try:
        os.makedirs(checkpoint_dir, exist_ok=True)

        for epoch in range(num_epochs):
            print(f"\nEpoch {epoch + 1}/{num_epochs}")

            train_loss, train_cell_metrics, train_grid_metrics = train_epoch(
                model, train_loader, optimizer, device)
            val_loss, val_cell_metrics, val_grid_metrics = validate(model, val_loader, device)

            if scheduler is not None:
                scheduler.step()

            print(f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
            print(f"Train Cell NMAE: {train_cell_metrics['nmae']:.4f}, Val Cell NMAE: {val_cell_metrics['nmae']:.4f}")
            print(f"Train Grid NMAE: {train_grid_metrics['nmae']:.4f}, Val Grid NMAE: {val_grid_metrics['nmae']:.4f}")

            writer.add_scalar('Loss/train', train_loss, epoch)
            writer.add_scalar('Loss/val', val_loss, epoch)

            # (TensorBoard group, split, history level, metrics) per result set.
            result_sets = (
                ('Cell', 'train', 'cell', train_cell_metrics),
                ('Cell', 'val', 'cell', val_cell_metrics),
                ('Grid', 'train', 'grid', train_grid_metrics),
                ('Grid', 'val', 'grid', val_grid_metrics),
            )
            for group, split, _, metrics in result_sets:
                for metric_name, metric_value in metrics.items():
                    writer.add_scalar(f'{group}/{metric_name}/{split}', metric_value, epoch)

            metrics_history['epoch'].append(epoch + 1)
            metrics_history['train_loss'].append(train_loss)
            metrics_history['val_loss'].append(val_loss)
            for _, split, level, metrics in result_sets:
                for metric_name in metric_names:
                    metrics_history[f'{split}_{level}_{metric_name}'].append(metrics[metric_name])

            # Built once and shared by the per-epoch and best-model checkpoints.
            epoch_summary = {
                'train_loss': train_loss,
                'val_loss': val_loss,
                'train_cell_metrics': train_cell_metrics,
                'train_grid_metrics': train_grid_metrics,
                'val_cell_metrics': val_cell_metrics,
                'val_grid_metrics': val_grid_metrics,
            }

            # NOTE: the file name is one-based while the stored 'epoch' value
            # is the zero-based loop index; kept as-is so existing checkpoints
            # and any code reading them stay compatible.
            checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch + 1}.pth")
            save_checkpoint(model, optimizer, epoch, epoch_summary, checkpoint_path)

            # An epoch only counts as an improvement when it beats the best
            # validation loss by more than min_delta.
            if val_loss < best_val_loss - min_delta:
                best_val_loss = val_loss
                counter = 0
                save_checkpoint(model, optimizer, epoch, epoch_summary, best_model_path)
                print(f"New best model saved with validation loss: {val_loss:.4f}")
            else:
                counter += 1
                print(f"EarlyStopping counter: {counter} out of {patience}")
                if counter >= patience:
                    print("Early stopping")
                    break

        fig = visualize_training_curves(metrics_history, title="Training Curves")
        fig.savefig(os.path.join(log_dir, "training_curves.png"), dpi=300, bbox_inches='tight')
        plt.close(fig)
    finally:
        # Close the writer even when training or plotting raises, so the
        # TensorBoard event file is not left open.
        writer.close()

    print("Training completed!")


# Start training only when this file is executed as a script, so importing
# the module (e.g. to reuse train_epoch or validate) has no side effects.
if __name__ == "__main__":
    main()