"""
Script to analyze the routing behavior of the trained model.
"""
import os
import sys
import argparse
import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from torch.utils.data import DataLoader, TensorDataset
import pandas as pd
from sklearn.preprocessing import StandardScaler
from matplotlib.patches import Rectangle, FancyArrowPatch
from scipy.cluster.hierarchy import linkage, leaves_list
import matplotlib.patches as mpatches

# Add the project root directory to Python path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from src.utils.simulate_multitask_gp import simulate_multitask_dataset
from src.models.modality_router import RoutingModel


def plot_routing_heatmap(modality_probs, task_probs, save_path=None):
    """Plot per-sample routing probabilities as two side-by-side heatmaps.

    The left panel shows the modality routing probabilities (one row per
    sample, one column per modality path). The right panel shows, for every
    modality path, the entry at index 0 of the task routing probabilities,
    which this script labels "STL".

    Args:
        modality_probs: Tensor indexed as (n_samples, n_modality_paths); the
            tick labels assume four paths (T1, N1, T2, N2).
        task_probs: Sequence with one tensor per modality path, each indexed
            as (n_samples, 2).
        save_path: If truthy, the figure is written there and closed;
            otherwise it is shown interactively.
    """
    # Set style (note: this updates matplotlib's global rcParams)
    plt.style.use('seaborn-v0_8-whitegrid')
    plt.rcParams.update({
        'font.size': 12,
        'axes.labelsize': 14,
        'axes.titlesize': 16,
        'xtick.labelsize': 12,
        'ytick.labelsize': 12,
        'legend.fontsize': 12,
        'figure.titlesize': 18
    })

    # Detach and copy to host memory first so that CUDA tensors and tensors
    # that track gradients can be converted, as the other helpers here do.
    modality_np = modality_probs.detach().cpu().numpy()
    # Stack to (n_samples, n_modality_paths, 2) and keep the index-0 (STL) slice.
    task_np = np.stack([p.detach().cpu().numpy() for p in task_probs], axis=1)
    stl_probs = task_np[..., 0]

    path_labels = ['T1', 'N1', 'T2', 'N2']
    panels = [
        (modality_np, 'Modality Routing Probabilities'),
        (stl_probs, 'Task Routing Probabilities (STL)'),
    ]

    fig, axes = plt.subplots(1, 2, figsize=(15, 6), dpi=300)
    for ax, (matrix, title) in zip(axes, panels):
        sns.heatmap(
            pd.DataFrame(matrix),
            ax=ax,
            cmap='viridis',
            cbar_kws={'label': 'Probability'}
        )
        ax.set_title(title)
        ax.set_xlabel('Modality Path')
        ax.set_ylabel('Sample')
        ax.set_xticklabels(path_labels)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight', pad_inches=0.1)
        plt.close()
    else:
        plt.show()


def plot_routing_distribution(modality_probs, task_probs, save_path=None):
    """Plot histograms of the hard (argmax) routing decisions.

    Each sample is assigned to its most probable modality path and then to
    the most probable task paradigm of that path (index 0 is labelled STL,
    index 1 MTL). The left panel counts samples per modality path, the right
    panel counts samples per task paradigm.

    Args:
        modality_probs: Tensor indexed as (n_samples, n_modality_paths); the
            tick labels assume four paths (T1, N1, T2, N2).
        task_probs: Sequence with one tensor per modality path, each indexed
            as (n_samples, 2).
        save_path: If truthy, the figure is written there and closed;
            otherwise it is shown interactively.
    """
    # Set publication-quality font sizes (updates global rcParams)
    plt.style.use('seaborn-v0_8-whitegrid')
    plt.rcParams.update({
        'font.size': 18,
        'axes.labelsize': 20,
        'axes.titlesize': 22,
        'xtick.labelsize': 18,
        'ytick.labelsize': 18,
        'legend.fontsize': 18,
        'figure.titlesize': 24
    })

    # Detach and copy to host memory so CUDA / grad-tracking tensors work too.
    modality_prefs = modality_probs.detach().cpu().argmax(dim=1).numpy()
    # Shape: (n_samples, n_modality_paths, 2)
    task_probs_array = np.stack(
        [p.detach().cpu().numpy() for p in task_probs], axis=1
    )
    # For every sample pick the task distribution of its chosen modality path,
    # then take the most probable paradigm within that distribution.
    sample_idx = np.arange(len(modality_prefs))
    most_prob_task = task_probs_array[sample_idx, modality_prefs, :].argmax(axis=1)

    fig, axes = plt.subplots(1, 2, figsize=(15, 6), dpi=300)

    # Modality routing distribution
    sns.histplot(
        modality_prefs,
        ax=axes[0],
        discrete=True,
        color='#1f77b4',
        alpha=0.7
    )
    axes[0].set_title('Modality Routing Distribution', fontsize=22)
    axes[0].set_xlabel('Modality Path', fontsize=20)
    axes[0].set_ylabel('Count', fontsize=20)
    axes[0].set_xticks(range(4))
    axes[0].set_xticklabels(['T1', 'N1', 'T2', 'N2'], fontsize=18)

    # Task routing distribution (STL/MTL)
    sns.histplot(
        most_prob_task,
        ax=axes[1],
        discrete=True,
        color='#2ca02c',
        alpha=0.7
    )
    axes[1].set_title('Task Routing Distribution', fontsize=22)
    axes[1].set_xlabel('Task Paradigm', fontsize=20)
    axes[1].set_ylabel('Count', fontsize=20)
    axes[1].set_xticks([0, 1])
    axes[1].set_xticklabels(['STL', 'MTL'], fontsize=18)

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight', pad_inches=0.1)
        plt.close()
    else:
        plt.show()


def plot_ground_truth_signals(signal_T2, signal_N2, signal_T1, signal_N1, save_path=None):
    """Draw a 30-bin histogram for each of the four ground-truth signals.

    The panels appear left to right in the order T2, N2, T1, N1. When
    ``save_path`` is truthy the figure is written there and closed; otherwise
    it is shown interactively.
    """
    plt.style.use('seaborn-v0_8-whitegrid')
    font_settings = {
        'font.size': 18,
        'axes.labelsize': 20,
        'axes.titlesize': 22,
        'xtick.labelsize': 18,
        'ytick.labelsize': 18,
        'legend.fontsize': 18,
        'figure.titlesize': 24
    }
    plt.rcParams.update(font_settings)

    named_signals = (
        ('T2', signal_T2),
        ('N2', signal_N2),
        ('T1', signal_T1),
        ('N1', signal_N1),
    )
    _, panel_axes = plt.subplots(1, 4, figsize=(14, 3), dpi=300)
    for panel_idx, (label, samples) in enumerate(named_signals):
        panel = panel_axes[panel_idx]
        # One colour of the default cycle per panel: C0, C1, C2, C3
        panel.hist(samples, bins=30, color=f'C{panel_idx}', alpha=0.7)
        panel.set_title(label)
        panel.set_xlabel('Signal')
        panel.set_ylabel('Count')
    plt.tight_layout()

    if not save_path:
        plt.show()
        return
    plt.savefig(save_path, dpi=300, bbox_inches='tight', pad_inches=0.05)
    plt.close()


def weighted_boxplot(ax, values, weights, labels, title, ylabel, no_outliers=False):
    """Draw one box per group from weight-proportional resamples of its values.

    For every group, 1000 values are drawn with replacement from ``vals``
    with probabilities proportional to ``wts`` (using NumPy's global random
    state), and the resulting synthetic sample is what the box summarises.
    A group with no values or with a zero weight sum is represented by a
    single NaN.

    Args:
        ax: Matplotlib axes to draw on.
        values: Sequence of per-group value arrays.
        weights: Sequence of per-group weight arrays, aligned with ``values``.
        labels: Tick labels, one per group. Exactly two labels are treated as
            the task-paradigm grouping; anything else as modality paths.
        title: Axes title.
        ylabel: Y-axis label.
        no_outliers: If True, fliers are not drawn.
    """
    n_synth = 1000  # Number of synthetic samples per group
    synthetic_samples = []
    for vals, wts in zip(values, weights):
        vals = np.asarray(vals)
        wts = np.asarray(wts)
        total = np.sum(wts)
        if total == 0 or len(vals) == 0:
            # Nothing to resample from; keep a placeholder so box positions
            # stay aligned with ``labels``.
            synthetic_samples.append(np.array([np.nan]))
            continue
        synthetic_samples.append(
            np.random.choice(vals, size=n_synth, p=wts / total)
        )

    # Tick labels are applied through set_xticklabels below, so the
    # (deprecated in recent matplotlib) ``labels=`` keyword is not needed here.
    ax.boxplot(synthetic_samples, patch_artist=True, showfliers=not no_outliers)
    ax.set_title(title)
    ax.set_ylabel(ylabel)
    # Two groups means the STL/MTL split; label the axis with the grouping
    # name rather than with the first tick label.
    ax.set_xlabel('Task Paradigm' if len(labels) == 2 else 'Modality Path')
    ax.set_xticklabels(labels)


def analyze_predictions(predictions, targets, modality_probs, task_probs, save_path=None, output_dir=None):
    """Box-plot absolute prediction errors grouped by hard routing decisions.

    Every sample is assigned to its most probable modality path and, within
    that path, to its most probable task paradigm (index 0 is labelled STL,
    index 1 MTL). Four panels are drawn: task-1 and task-2 absolute error by
    modality path, then task-1 and task-2 absolute error by task paradigm.
    Groups without any task-1 error value are left out of the plots. The
    per-paradigm sample counts are printed to stdout.

    Args:
        predictions: Pair ``(pred1, pred2)`` of prediction tensors.
        targets: Pair ``(target1, target2)``; each is flattened and truncated
            to the number of predictions.
        modality_probs: Tensor indexed as (n_samples, 4).
        task_probs: Sequence of four tensors, each indexed as (n_samples, 2).
        save_path: If truthy, the figure is written there and closed;
            otherwise it is shown interactively.
        output_dir: Unused; kept so existing callers keep working.
    """
    plt.style.use('seaborn-v0_8-whitegrid')
    plt.rcParams.update({
        'font.size': 18,
        'axes.labelsize': 20,
        'axes.titlesize': 22,
        'xtick.labelsize': 18,
        'ytick.labelsize': 18,
        'legend.fontsize': 18,
        'figure.titlesize': 24
    })

    def _abs_error(pred, target):
        # Flatten both sides on the CPU so the subtraction is element-wise
        # regardless of device or of a trailing singleton dimension.
        pred = pred.detach().cpu().reshape(-1)
        target = torch.as_tensor(target).detach().cpu().reshape(-1)[:pred.shape[0]]
        return (pred - target).abs().numpy()

    pred1, pred2 = predictions
    target1, target2 = targets
    errors1 = _abs_error(pred1, target1)
    errors2 = _abs_error(pred2, target2)

    modality_prefs = modality_probs.detach().cpu().argmax(dim=1).numpy()  # (n_samples,)
    task_probs_array = np.stack(
        [p.detach().cpu().numpy() for p in task_probs], axis=1
    )  # (n_samples, 4, 2)
    # Task distribution of the selected modality path, then its argmax
    # (STL=0, MTL=1).
    sample_idx = np.arange(len(modality_prefs))
    task_prefs = task_probs_array[sample_idx, modality_prefs, :].argmax(axis=1)

    df = pd.DataFrame({
        'modality': modality_prefs,
        'task': task_prefs,
        'error1': errors1,
        'error2': errors2
    })
    # Print counts for each task paradigm
    print('Counts per task paradigm (hard routing):')
    print(df['task'].value_counts())

    # Keep only groups that have at least one task-1 error value. A group with
    # no rows counts as empty here, so no placeholder rows are required.
    modality_labels = ['T1', 'N1', 'T2', 'N2']
    task_labels = ['STL', 'MTL']
    non_empty_modalities = [
        m for m in range(4)
        if df.loc[df['modality'] == m, 'error1'].notna().any()
    ]
    non_empty_tasks = [
        t for t in range(2)
        if df.loc[df['task'] == t, 'error1'].notna().any()
    ]
    modality_labels_plot = [modality_labels[m] for m in non_empty_modalities]
    task_labels_plot = [task_labels[t] for t in non_empty_tasks]

    modality_rows = df[df['modality'].isin(non_empty_modalities)]
    task_rows = df[df['task'].isin(non_empty_tasks)]
    # (group column, error column, rows, order, title, x label, tick labels)
    panels = [
        ('modality', 'error1', modality_rows, non_empty_modalities,
         'Task 1 Error by Modality', 'Modality Path', modality_labels_plot),
        ('modality', 'error2', modality_rows, non_empty_modalities,
         'Task 2 Error by Modality', 'Modality Path', modality_labels_plot),
        ('task', 'error1', task_rows, non_empty_tasks,
         'Task 1 Error by Task', 'Task Paradigm', task_labels_plot),
        ('task', 'error2', task_rows, non_empty_tasks,
         'Task 2 Error by Task', 'Task Paradigm', task_labels_plot),
    ]

    fig, axes = plt.subplots(1, 4, figsize=(14, 3), dpi=300)
    for ax, (group_col, error_col, rows, order, title, xlabel, tick_labels) in zip(axes, panels):
        sns.boxplot(x=group_col, y=error_col, data=rows, ax=ax, order=order, palette='viridis')
        ax.set_title(title, fontsize=14)
        ax.set_xlabel(xlabel)
        ax.set_ylabel('Abs Error')
        ax.set_xticklabels(tick_labels)

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight', pad_inches=0.05)
        plt.close()
    else:
        plt.show()


def analyze_predictions_soft(predictions, targets, modality_probs, task_probs, save_path=None, no_outliers=False):
    """Box-plot absolute errors with every sample weighted by its routing probability.

    Instead of assigning each sample to a single route, all samples
    contribute to every group. For the modality panels the weight of a sample
    is its modality routing probability; for the task panels it is the joint
    probability ``P(modality) * P(task | modality)`` summed over modality
    paths. Drawing is delegated to ``weighted_boxplot``. The total soft weight
    per task paradigm is printed to stdout.

    Args:
        predictions: Pair ``(pred1, pred2)`` of prediction tensors.
        targets: Pair ``(target1, target2)`` of target tensors.
        modality_probs: Tensor indexed as (n_samples, 4).
        task_probs: Sequence of four tensors, each indexed as (n_samples, 2).
        save_path: If truthy, the figure is written there and closed;
            otherwise it is shown interactively.
        no_outliers: Forwarded to ``weighted_boxplot``.
    """
    plt.style.use('seaborn-v0_8-whitegrid')
    plt.rcParams.update({
        'font.size': 10,
        'axes.labelsize': 20,
        'axes.titlesize': 22,
        'xtick.labelsize': 18,
        'ytick.labelsize': 18,
        'legend.fontsize': 18,
        'figure.titlesize': 24
    })
    pred1, pred2 = predictions
    target1, target2 = targets
    # reshape(-1) rather than squeeze(): a single-sample batch must stay 1-D
    # so that it can still be resampled downstream.
    pred1 = pred1.detach().cpu().numpy().reshape(-1)
    pred2 = pred2.detach().cpu().numpy().reshape(-1)
    target1 = target1.cpu().numpy().reshape(-1)
    target2 = target2.cpu().numpy().reshape(-1)
    errors1 = np.abs(pred1 - target1)
    errors2 = np.abs(pred2 - target2)

    modality_probs_np = modality_probs.detach().cpu().numpy()  # [n, 4]
    task_probs_np = np.stack(
        [p.detach().cpu().numpy() for p in task_probs], axis=1
    )  # [n, 4, 2]
    joint_probs = modality_probs_np[:, :, None] * task_probs_np  # [n, 4, 2]

    # The same weights apply to both prediction tasks; only the errors differ.
    modality_weights = [modality_probs_np[:, m] for m in range(4)]
    task_weights = [joint_probs[:, :, t].sum(axis=1) for t in range(2)]

    # Print total weights for each task paradigm
    print('Total soft weights per task paradigm:')
    for t, label in enumerate(['STL', 'MTL']):
        print(f'{label}: {np.sum(task_weights[t]):.2f}')

    modality_labels = ['T1', 'N1', 'T2', 'N2']
    task_labels = ['STL', 'MTL']
    # (errors, weights, tick labels, title)
    panels = [
        (errors1, modality_weights, modality_labels, 'Task 1 Error by Modality (Soft)'),
        (errors2, modality_weights, modality_labels, 'Task 2 Error by Modality (Soft)'),
        (errors1, task_weights, task_labels, 'Task 1 Error by Task (Soft)'),
        (errors2, task_weights, task_labels, 'Task 2 Error by Task (Soft)'),
    ]

    fig, axes = plt.subplots(1, 4, figsize=(14, 3), dpi=300)
    for ax, (errors, group_weights, tick_labels, title) in zip(axes, panels):
        # Every group resamples from the full error vector, just with its own weights.
        group_errors = [errors for _ in group_weights]
        weighted_boxplot(ax, group_errors, group_weights, tick_labels, title, 'Abs Error', no_outliers=no_outliers)
        # Re-apply the title with a smaller font than the rcParams default.
        ax.set_title(title, fontsize=14)

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight', pad_inches=0.05)
        plt.close()
    else:
        plt.show()


def debug_print_routing(modality_probs, task_probs):
    """Print the first five rows of the routing outputs for quick inspection.

    Prints the first five modality probability rows, their argmax indices,
    and the first five task probability rows of every modality path.

    Args:
        modality_probs: 2-D tensor of modality routing outputs (one row per
            sample).
        task_probs: Iterable of tensors, one per modality path.
    """
    # Detach and copy to host memory so that CUDA tensors and tensors that
    # track gradients can be converted to NumPy for printing.
    modality_cpu = modality_probs.detach().cpu()
    print('First 5 modality_probs (softmax outputs):')
    print(modality_cpu[:5].numpy())
    print('First 5 modality argmax indices:')
    print(modality_cpu.argmax(dim=1)[:5].numpy())
    print('First 5 task_probs (per modality, softmax outputs):')
    for i, p in enumerate(task_probs):
        print(f'Modality {i} task_probs (first 5):')
        print(p[:5].detach().cpu().numpy())
    print('---')


def plot_joint_pmf_heatmap(modality_probs, task_probs, save_path=None):
    """Render the sample-averaged joint routing PMF as an annotated 4x2 heatmap.

    Cell (i, j) holds the mean over samples of
    ``modality_probs[:, i] * task_probs[i][:, j]``; the table is then divided
    by its total so the cells sum to one. Rows are modality paths, columns are
    task paradigms. When ``save_path`` is truthy the figure is written there
    and closed; otherwise it is shown interactively.
    """
    path_names = ['T1', 'N1', 'T2', 'N2']
    paradigm_names = ['STL', 'MTL']

    averaged_rows = []
    for path_idx in range(len(path_names)):
        p_path = modality_probs[:, path_idx].cpu().numpy()   # [n_samples]
        p_paradigm = task_probs[path_idx].cpu().numpy()      # [n_samples, 2]
        averaged_rows.append([
            np.mean(p_path * p_paradigm[:, col])
            for col in range(len(paradigm_names))
        ])
    avg_joint = np.array(averaged_rows, dtype=float)
    # Renormalise so the displayed cells sum to one
    avg_joint = avg_joint / avg_joint.sum()

    plt.figure(figsize=(6, 4), dpi=300)
    sns.heatmap(
        avg_joint,
        annot=True,
        fmt='.3f',
        cmap='viridis',
        xticklabels=paradigm_names,
        yticklabels=path_names,
    )
    plt.xlabel('Task Paradigm')
    plt.ylabel('Modality Path')
    plt.title('Average Joint Routing PMF (Soft)')
    plt.tight_layout()

    if not save_path:
        plt.show()
        return
    plt.savefig(save_path, dpi=300, bbox_inches='tight', pad_inches=0.1)
    plt.close()


def plot_routing_summary_subplot(modality_probs, task_probs, output_dir):
    """Save a three-panel routing summary figure.

    Panels, left to right:
      1. Annotated heatmap of the sample-averaged joint PMF over
         (modality path, task paradigm).
      2. Heatmap of the per-sample modality probabilities with rows reordered
         by the leaf order of a Ward hierarchical clustering.
      3. A hand-drawn Sankey-like diagram: Input -> modality paths -> routes.

    The figure is written to ``<output_dir>/routing_summary_subplot.png`` and
    then closed.

    Args:
        modality_probs: Tensor indexed as ``[:, i]`` for i in 0..3 (one
            column per modality path).
        task_probs: Sequence indexed by modality path; each entry is a tensor
            indexed as ``[:, j]`` for j in 0..1.
        output_dir: Directory the PNG is written to.
    """
    plt.style.use('seaborn-v0_8-whitegrid')
    plt.rcParams.update({
        'font.size': 18,
        'axes.labelsize': 20,
        'axes.titlesize': 22,
        'xtick.labelsize': 18,
        'ytick.labelsize': 18,
        'legend.fontsize': 18,
        'figure.titlesize': 24
    })
    # Joint PMF: mean over samples of P(modality=i) * P(task=j | modality=i),
    # renormalised so the 4x2 table sums to one.
    n_samples = modality_probs.shape[0]  # NOTE(review): not used below
    joint_pmf = np.zeros((4, 2))
    for i in range(4):
        p_mod = modality_probs[:, i].cpu().numpy()
        p_task = task_probs[i].cpu().numpy()
        for j in range(2):
            joint_pmf[i, j] = np.mean(p_mod * p_task[:, j])
    joint_pmf /= joint_pmf.sum()
    # Clustered heatmap (no dendrogram): reorder samples by the leaf order of
    # a Ward linkage so similar routing profiles sit next to each other.
    modality_probs_np = modality_probs.cpu().numpy()
    row_linkage = linkage(modality_probs_np, method='ward', metric='euclidean')
    row_order = leaves_list(row_linkage)
    modality_probs_clustered = modality_probs_np[row_order, :]
    # Figure layout: three columns with unequal widths
    fig = plt.figure(figsize=(24, 7), dpi=300)
    gs = fig.add_gridspec(1, 3, width_ratios=[1.1, 1.2, 1.7], wspace=0.25)
    # Joint PMF (left)
    ax0 = fig.add_subplot(gs[0, 0])
    sns.heatmap(joint_pmf, annot=True, fmt='.2f', cmap='viridis', xticklabels=['STL', 'MTL'], yticklabels=['T1', 'N1', 'T2', 'N2'], ax=ax0, cbar=True, cbar_kws={'label': 'Probability'})
    ax0.set_xlabel('Task Paradigm')
    ax0.set_ylabel('Modality Path')
    ax0.set_title('Average Joint Routing PMF')
    # Clustered heatmap (middle)
    ax1 = fig.add_subplot(gs[0, 1])
    sns.heatmap(modality_probs_clustered, cmap='viridis', ax=ax1, cbar=True, cbar_kws={'label': 'Probability'})
    ax1.set_title('Modality Routing Probabilities')
    ax1.set_xlabel('Modality Path')
    ax1.set_ylabel('Sample (clustered)')
    ax1.set_xticklabels(['T1', 'N1', 'T2', 'N2'])
    ax1.set_yticks([])
    # Sankey (right), drawn manually from rectangles and straight connectors
    ax2 = fig.add_subplot(gs[0, 2])
    ax2.axis('off')
    mod_labels = ['T1', 'N1', 'T2', 'N2']
    task_labels = ['STL', 'MTL']
    # x positions of the three node columns: input, modality paths, routes
    x0, x1, x2 = 0.1, 0.4, 0.95
    # Marginal modality probabilities, floored at min_width so that rarely
    # used paths still get a visible node, then renormalised to fill [0, 1].
    mod_probs = joint_pmf.sum(axis=1)
    min_width = 0.01
    mod_probs_adj = np.maximum(mod_probs, min_width)
    mod_probs_adj /= mod_probs_adj.sum()
    y_mod_cum = np.cumsum(np.concatenate([[0], mod_probs_adj]))
    mod_y_centers = []
    for i in range(4):
        y0 = y_mod_cum[i]
        y1_ = y_mod_cum[i+1]
        yc = (y0 + y1_)/2
        mod_y_centers.append(yc)
        ax2.add_patch(Rectangle((x1-0.03, y0), 0.06, y1_-y0, color='C'+str(i), alpha=0.5))
        ax2.text(x1, yc, mod_labels[i], va='center', ha='center', fontsize=18)
        # Connector width scales with the node height, with a floor of 2
        lw = max(20*(y1_-y0), 2)
        ax2.add_patch(FancyArrowPatch((x0, 0.5), (x1-0.03, yc), connectionstyle="arc3,rad=0.0", arrowstyle='-', linewidth=lw, color='C'+str(i), alpha=0.4))
    # Input node
    ax2.add_patch(Rectangle((x0-0.05, 0.4), 0.05, 0.2, color='gray', alpha=0.5))
    ax2.text(x0-0.08, 0.5, 'Input', va='center', ha='right', fontsize=22)
    # Sankey order: T1-STL, T1-MTL, N1-STL, N1-MTL, T2-STL, T2-MTL, N2-STL, N2-MTL
    sankey_order = [(0,0),(0,1),(1,0),(1,1),(2,0),(2,1),(3,0),(3,1)]
    task_heights = np.maximum([joint_pmf[i,j] for (i,j) in sankey_order], min_width)
    task_heights = np.array(task_heights)
    task_heights /= task_heights.sum()
    # Stack the route nodes with a gap after each one, then rescale the
    # cumulative positions so the last edge lands on 1.
    spacing = 0.02
    y_task_cum = [0]
    for h in task_heights:
        y_task_cum.append(y_task_cum[-1] + h + spacing)
    y_task_cum = np.array(y_task_cum)
    y_task_cum = y_task_cum / y_task_cum[-1]
    # Small per-route vertical offsets to stagger the text labels
    label_offsets = np.linspace(-0.02, 0.02, 8)
    for idx, (i, j) in enumerate(sankey_order):
        y0 = y_task_cum[idx]
        # NOTE(review): `spacing` is subtracted unscaled although y_task_cum
        # was rescaled above, so the visible gap is 0.02 in axes units.
        y1_ = y_task_cum[idx+1] - spacing
        yc = (y0 + y1_)/2
        color = 'C'+str(i)
        ax2.add_patch(Rectangle((x2-0.03, y0), 0.06, y1_-y0, color=color, alpha=0.5))
        prob = joint_pmf[i, j]
        label_y = yc + label_offsets[idx]
        ax2.text(x2+0.08, label_y, f'{mod_labels[i]}-{task_labels[j]} ({prob:.2f})', va='center', ha='left', fontsize=15)
        lw = max(20*(y1_-y0), 2)
        # Connect from the centre of the owning modality node to this route
        ax2.add_patch(FancyArrowPatch((x1+0.03, mod_y_centers[i]), (x2-0.03, yc), connectionstyle="arc3,rad=0.0", arrowstyle='-', linewidth=lw, color=color, alpha=0.7))
    ax2.set_xlim(0, 1.2)
    ax2.set_ylim(0, 1)
    ax2.set_title('Routing Diagram', fontsize=22)
    plt.savefig(os.path.join(output_dir, 'routing_summary_subplot.png'), dpi=300, bbox_inches='tight', pad_inches=0.1)
    plt.close()


def plot_routing_sankey(joint_pmf, output_dir):
    """Draw a static Sankey-like routing diagram and save it as a PNG.

    Three node columns are drawn: a single input node, one node per modality
    path (height proportional to the row sums of ``joint_pmf``) and one node
    per (modality path, task paradigm) route (height proportional to the
    corresponding ``joint_pmf`` cell). Heights are floored at 0.01 before
    being renormalised. The result is written to
    ``<output_dir>/routing_diagram.png`` and the figure is closed.
    """
    path_names = ['T1', 'N1', 'T2', 'N2']
    paradigm_names = ['STL', 'MTL']
    x_input, x_path, x_route = 0.1, 0.4, 0.95
    floor = 0.01
    gap = 0.02

    def _link(start, end, thickness, colour, opacity):
        # Straight connector between two node columns
        return FancyArrowPatch(start, end, connectionstyle="arc3,rad=0.0", arrowstyle='-', linewidth=thickness, color=colour, alpha=opacity)

    _, canvas = plt.subplots(figsize=(16, 7))
    canvas.axis('off')

    # Vertical extents of the modality nodes
    path_share = np.maximum(joint_pmf.sum(axis=1), floor)
    path_share /= path_share.sum()
    path_edges = np.cumsum(np.concatenate([[0], path_share]))

    # Input node
    canvas.add_patch(Rectangle((x_input-0.05, 0.4), 0.05, 0.2, color='gray', alpha=0.5))
    canvas.text(x_input-0.08, 0.5, 'Input', va='center', ha='right', fontsize=22)

    # Modality nodes and the flows reaching them from the input
    path_centres = []
    for path_idx, path_name in enumerate(path_names):
        bottom, top = path_edges[path_idx], path_edges[path_idx+1]
        centre = (bottom + top)/2
        path_centres.append(centre)
        colour = 'C'+str(path_idx)
        canvas.add_patch(Rectangle((x_path-0.03, bottom), 0.06, top-bottom, color=colour, alpha=0.5))
        canvas.text(x_path, centre, path_name, va='center', ha='center', fontsize=18)
        thickness = max(20*(top-bottom), 2)
        canvas.add_patch(_link((x_input, 0.5), (x_path-0.03, centre), thickness, colour, 0.4))

    # Vertical extents of the route nodes: stack height + gap, then rescale
    # so that the final edge sits at 1
    route_share = np.maximum(joint_pmf.flatten(), floor)
    route_share /= route_share.sum()
    running = [0]
    for share in route_share:
        running.append(running[-1] + share + gap)
    route_edges = np.array(running)
    route_edges = route_edges / route_edges[-1]

    # Route nodes, staggered labels and flows from their modality node
    stagger = np.linspace(-0.02, 0.02, 8)
    for slot in range(len(path_names) * len(paradigm_names)):
        path_idx, paradigm_idx = divmod(slot, len(paradigm_names))
        bottom = route_edges[slot]
        top = route_edges[slot+1] - gap
        centre = (bottom + top)/2
        colour = 'C'+str(path_idx)
        canvas.add_patch(Rectangle((x_route-0.03, bottom), 0.06, top-bottom, color=colour, alpha=0.5))
        prob = joint_pmf[path_idx, paradigm_idx]
        canvas.text(x_route+0.08, centre + stagger[slot], f'{path_names[path_idx]}-{paradigm_names[paradigm_idx]} ({prob:.2f})', va='center', ha='left', fontsize=15)
        thickness = max(20*(top-bottom), 2)
        canvas.add_patch(_link((x_path+0.03, path_centres[path_idx]), (x_route-0.03, centre), thickness, colour, 0.7))

    canvas.set_xlim(0, 1.2)
    canvas.set_ylim(0, 1)
    plt.title('Routing Diagram', fontsize=26)
    plt.savefig(os.path.join(output_dir, 'routing_diagram.png'), bbox_inches='tight', dpi=300)
    plt.close()


def plot_clustered_routing_heatmap(modality_probs, output_dir):
    """Save a Ward-clustered heatmap of the modality routing probabilities.

    The seaborn clustermap is written to ``<output_dir>/routing_heatmap.png``
    and the figure is closed afterwards.
    """
    probs = modality_probs.cpu().numpy()
    clustermap_options = {
        'method': 'ward',
        'metric': 'euclidean',
        'cmap': 'viridis',
        'figsize': (10, 8),
        'yticklabels': False,
        'xticklabels': ['T1', 'N1', 'T2', 'N2'],
        'cbar_kws': {'label': 'Probability'},
    }
    sns.clustermap(probs, **clustermap_options)

    destination = os.path.join(output_dir, 'routing_heatmap.png')
    plt.suptitle('Modality Routing Probabilities', fontsize=20)
    plt.savefig(destination, bbox_inches='tight', dpi=300)
    plt.close()


def plot_paired_routing_error_boxplot(predictions, targets, modality_probs, task_probs, output_dir):
    """Save paired hard/soft error box plots for every (modality, task) route.

    For each of the 4 x 2 routes two boxes are drawn side by side:

    * Hard: absolute errors of the samples whose argmax modality path and
      argmax task paradigm (within that path) equal the route.
    * Soft: 1000 absolute errors resampled from all samples with probability
      proportional to the joint routing probability of the route (drawn with
      NumPy's global random state).

    One panel is produced per prediction task and the figure is written to
    ``<output_dir>/paired_routing_error_boxplot.png``.

    Args:
        predictions: Pair ``(pred1, pred2)`` of prediction tensors.
        targets: Pair ``(target1, target2)`` of target tensors.
        modality_probs: Tensor indexed as (n_samples, 4).
        task_probs: Sequence of four tensors, each indexed as (n_samples, 2).
        output_dir: Directory the PNG is written to.
    """
    # Set publication-quality font sizes (updates global rcParams)
    plt.rcParams.update({
        'font.size': 18,
        'axes.labelsize': 20,
        'axes.titlesize': 22,
        'xtick.labelsize': 18,
        'ytick.labelsize': 18,
        'legend.fontsize': 18,
        'figure.titlesize': 24
    })
    # Prepare data; reshape(-1) keeps a single-sample batch one-dimensional.
    pred1, pred2 = predictions
    target1, target2 = targets
    pred1 = pred1.detach().cpu().numpy().reshape(-1)
    pred2 = pred2.detach().cpu().numpy().reshape(-1)
    target1 = target1.cpu().numpy().reshape(-1)
    target2 = target2.cpu().numpy().reshape(-1)
    errors1 = np.abs(pred1 - target1)
    errors2 = np.abs(pred2 - target2)
    modality_probs_np = modality_probs.detach().cpu().numpy()  # [n, 4]
    task_probs_np = np.stack(
        [p.detach().cpu().numpy() for p in task_probs], axis=1
    )  # [n, 4, 2]
    joint_probs = modality_probs_np[:, :, None] * task_probs_np  # [n, 4, 2]

    # Hard routing assignments
    hard_modality = modality_probs_np.argmax(axis=1)
    sample_idx = np.arange(len(hard_modality))
    hard_task = task_probs_np[sample_idx, hard_modality, :].argmax(axis=1)

    # For each route (modality, task), collect hard and soft errors
    mod_labels = ['T1', 'N1', 'T2', 'N2']
    task_labels = ['STL', 'MTL']
    data = []
    for i, mod_label in enumerate(mod_labels):
        for j, task_label in enumerate(task_labels):
            # Hard: only samples routed to (i, j)
            mask = (hard_modality == i) & (hard_task == j)
            # Soft: all samples, weighted by joint_probs[:, i, j]
            soft_weights = joint_probs[:, i, j]
            total_weight = np.sum(soft_weights)
            if total_weight > 0:
                p = soft_weights / total_weight
                soft_err1 = np.random.choice(errors1, size=1000, p=p)
                soft_err2 = np.random.choice(errors2, size=1000, p=p)
            else:
                soft_err1 = np.array([])
                soft_err2 = np.array([])
            data.append({
                'route': f'{mod_label}-{task_label}',
                'hard_err1': errors1[mask],
                'soft_err1': soft_err1,
                'hard_err2': errors2[mask],
                'soft_err2': soft_err2
            })

    # Plot
    fig, axes = plt.subplots(1, 2, figsize=(24, 7), dpi=300)
    box_colors = ['#1f77b4', '#ff7f0e']  # blue for hard, orange for soft
    n_routes = len(data)
    boxes_per_modality = 2 * len(task_labels)  # hard + soft box per route
    width = 0.35
    gap = 0.22
    for k, (err_key, task_name) in enumerate([('err1', 'Task 1'), ('err2', 'Task 2')]):
        ax = axes[k]
        box_data = []
        box_positions = []
        pos = 1
        for d in data:
            # Hard box, then the soft box directly to its right
            box_data.append(d[f'hard_{err_key}'])
            box_positions.append(pos)
            box_data.append(d[f'soft_{err_key}'])
            box_positions.append(pos+width)
            pos += 2*width + gap
        bplots = ax.boxplot(box_data, positions=box_positions, widths=width, patch_artist=True, showfliers=False)
        # Color boxes: even indices are hard, odd indices are soft
        for i, patch in enumerate(bplots['boxes']):
            patch.set_facecolor(box_colors[i % 2])
        # One x-tick per route, centred between its hard and soft box
        xtick_pos = [np.mean([box_positions[2*i], box_positions[2*i+1]]) for i in range(n_routes)]
        xtick_labels = [d['route'] for d in data]
        ax.set_xticks(xtick_pos)
        ax.set_xticklabels(xtick_labels, rotation=30, ha='right', fontsize=20)
        ax.set_ylabel('Abs Error', fontsize=22)
        ax.set_title(f'{task_name} Error by Route (Hard, Soft)', fontsize=24)
        # Add legend
        hard_patch = mpatches.Patch(color=box_colors[0], label='Hard')
        soft_patch = mpatches.Patch(color=box_colors[1], label='Soft')
        ax.legend(handles=[hard_patch, soft_patch], loc='upper right', fontsize=20)
        # Vertical separators between modalities, placed halfway between the
        # last box of one modality and the first box of the next. They are
        # derived from the actual box positions so they cannot drift from
        # the layout above.
        for first_box in range(boxes_per_modality, len(box_positions), boxes_per_modality):
            boundary = (box_positions[first_box - 1] + box_positions[first_box]) / 2
            ax.axvline(x=boundary, color='gray', linestyle='--', alpha=0.5)
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'paired_routing_error_boxplot.png'), bbox_inches='tight', dpi=300)
    plt.close()


def train_no_routing_baseline_all_routes(X_numeric, X_textual, y1, y2, output_dir, device):
    """Train one no-routing baseline network per route label and save its predictions.

    For each of the eight ``<modality>-<task>`` labels a fresh two-layer MLP
    is trained for 30 epochs (Adam, lr 1e-3, MSE loss, batch size 32) on the
    concatenated numeric and textual features to predict both targets
    jointly. Predictions on the full input are stored in
    ``<output_dir>/no_routing_baseline_preds_<route>.npz`` under the keys
    ``pred1`` and ``pred2``.

    Args:
        X_numeric: Tensor of numeric features, one row per sample.
        X_textual: Tensor of textual features, one row per sample.
        y1: Target tensor for task 1, shaped (n_samples, 1) so it can be
            concatenated with ``y2`` along dim 1.
        y2: Target tensor for task 2, same shape as ``y1``.
        output_dir: Directory for the ``.npz`` files; created if missing.
        device: Torch device used for training and inference.

    Returns:
        Dict mapping each route label to a ``(pred1, pred2)`` pair of 1-D
        NumPy arrays.
    """
    # Not imported at module level, so bring them in locally.
    import torch.nn as nn
    import torch.optim as optim

    class NoRoutingNet(nn.Module):
        """Single-hidden-layer MLP over the concatenated inputs."""

        def __init__(self, input_dim_numeric, input_dim_text, hidden_dim=32):
            super().__init__()
            self.fc1 = nn.Linear(input_dim_numeric + input_dim_text, hidden_dim)
            self.relu = nn.ReLU()
            self.fc2 = nn.Linear(hidden_dim, 2)

        def forward(self, x_num, x_text):
            x = torch.cat([x_num, x_text], dim=1)
            x = self.relu(self.fc1(x))
            return self.fc2(x)

    mod_labels = ['T1', 'N1', 'T2', 'N2']
    task_labels = ['STL', 'MTL']

    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # The dataset does not depend on the route, so build it once.
    dataset = TensorDataset(X_numeric, X_textual, y1, y2)
    criterion = nn.MSELoss()

    preds_all = {}
    for mod_label in mod_labels:
        for task_label in task_labels:
            # NOTE(review): the route label only names the output; the
            # architecture and the training data are the same for all eight
            # routes, so the models differ only through random initialisation
            # and batch shuffling. Confirm this is the intended baseline.
            route = f'{mod_label}-{task_label}'
            dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
            model = NoRoutingNet(X_numeric.shape[1], X_textual.shape[1]).to(device)
            optimizer = optim.Adam(model.parameters(), lr=1e-3)

            model.train()
            for _ in range(30):
                for xb_num, xb_text, yb1, yb2 in dataloader:
                    xb_num, xb_text = xb_num.to(device), xb_text.to(device)
                    yb = torch.cat([yb1, yb2], dim=1).to(device)
                    loss = criterion(model(xb_num, xb_text), yb)
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

            model.eval()
            with torch.no_grad():
                pred = model(X_numeric.to(device), X_textual.to(device)).cpu().numpy()
            pred1 = pred[:, 0]
            pred2 = pred[:, 1]
            preds_all[route] = (pred1, pred2)
            np.savez(
                os.path.join(output_dir, f'no_routing_baseline_preds_{route}.npz'),
                pred1=pred1,
                pred2=pred2,
            )
    return preds_all


def parse_args():
    """Build the command-line parser for this script and parse ``sys.argv``.

    Options: ``--model_path`` (required checkpoint path), ``--n_samples``
    (default 1000), ``--batch_size`` (default 32) and ``--output_dir``
    (default ``experiments/results``).
    """
    cli = argparse.ArgumentParser(description='Analyze routing behavior')
    option_specs = (
        ('--model_path', {
            'type': str,
            'required': True,
            'help': 'Path to trained model checkpoint',
        }),
        ('--n_samples', {
            'type': int,
            'default': 1000,
            'help': 'Number of samples to generate for analysis',
        }),
        ('--batch_size', {
            'type': int,
            'default': 32,
            'help': 'Batch size',
        }),
        ('--output_dir', {
            'type': str,
            'default': 'experiments/results',
            'help': 'Directory to save analysis plots',
        }),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli.parse_args()


def main():
    """Run the full routing analysis for a trained checkpoint.

    Simulates a dataset, standardises the inputs, loads the routing model
    from ``--model_path``, collects predictions and routing probabilities
    batch by batch, and then writes all analysis figures (plus the
    no-routing baseline predictions) into ``--output_dir``.
    """
    args = parse_args()

    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Generate data
    data = simulate_multitask_dataset(
        n_samples=args.n_samples
    )
    X_numeric = data['X_numeric']
    X_textual = data['X_text']
    y1 = data['y1']
    y2 = data['y2']

    # Normalize X_numeric and X_textual
    X_numeric = StandardScaler().fit_transform(X_numeric)
    X_textual = StandardScaler().fit_transform(X_textual)

    # Convert to torch tensors
    X_numeric = torch.FloatTensor(X_numeric)
    X_textual = torch.FloatTensor(X_textual)
    y1 = torch.FloatTensor(y1).reshape(-1, 1)
    y2 = torch.FloatTensor(y2).reshape(-1, 1)

    # Create dataset and dataloader (no shuffling: outputs must stay aligned
    # with y1 / y2 for the error analyses below)
    dataset = TensorDataset(X_numeric, X_textual, y1, y2)
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False)

    # Load model. map_location lets a checkpoint saved on a GPU be loaded on
    # a CPU-only machine (and vice versa).
    checkpoint = torch.load(args.model_path, map_location=device)
    model = RoutingModel(
        input_dim_numeric=X_numeric.shape[1],
        input_dim_text=X_textual.shape[1],
        hidden_dim=32
    ).to(device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    # Collect predictions and routing probabilities
    all_preds1 = []
    all_preds2 = []
    all_modality_probs = []
    all_task_probs = []

    with torch.no_grad():
        for X_numeric_batch, X_textual_batch, _, _ in dataloader:
            X_numeric_batch = X_numeric_batch.to(device)
            X_textual_batch = X_textual_batch.to(device)

            preds1, preds2, modality_probs, task_probs = model(
                X_numeric_batch, X_textual_batch
            )

            all_preds1.append(preds1.cpu())
            all_preds2.append(preds2.cpu())
            all_modality_probs.append(modality_probs.cpu())
            all_task_probs.append([p.cpu() for p in task_probs])

    # Concatenate predictions and probabilities across batches
    predictions = (
        torch.cat(all_preds1),
        torch.cat(all_preds2)
    )
    modality_probs = torch.cat(all_modality_probs)
    # all_task_probs is a list (per batch) of lists (per modality path);
    # regroup it into one tensor per modality path.
    task_probs = [
        torch.cat([batch[i] for batch in all_task_probs])
        for i in range(len(all_task_probs[0]))
    ]

    # Create output directory if it doesn't exist
    os.makedirs(args.output_dir, exist_ok=True)

    # Generate plots.
    # The clustered heatmap written further below also uses
    # 'routing_heatmap.png'; this one gets its own file name so that it is
    # not overwritten.
    plot_routing_heatmap(
        modality_probs,
        task_probs,
        save_path=os.path.join(args.output_dir, 'routing_probability_heatmaps.png')
    )

    plot_routing_distribution(
        modality_probs,
        task_probs,
        save_path=os.path.join(args.output_dir, 'routing_distribution.png')
    )

    analyze_predictions(
        predictions,
        (y1, y2),
        modality_probs,
        task_probs,
        save_path=os.path.join(args.output_dir, 'prediction_analysis.png'),
        output_dir=args.output_dir
    )

    analyze_predictions_soft(
        predictions,
        (y1, y2),
        modality_probs,
        task_probs,
        save_path=os.path.join(args.output_dir, 'prediction_analysis_soft.png')
    )

    analyze_predictions_soft(
        predictions,
        (y1, y2),
        modality_probs,
        task_probs,
        save_path=os.path.join(args.output_dir, 'prediction_analysis_soft_no_outliers.png'),
        no_outliers=True
    )

    debug_print_routing(modality_probs, task_probs)
    plot_joint_pmf_heatmap(modality_probs, task_probs, save_path=os.path.join(args.output_dir, 'joint_pmf_heatmap.png'))

    plot_routing_summary_subplot(modality_probs, task_probs, args.output_dir)

    # Compute joint_pmf before calling plot_routing_sankey:
    # mean over samples of P(modality=i) * P(task=j | modality=i)
    joint_pmf = np.zeros((4, 2))
    for i in range(4):
        p_mod = modality_probs[:, i].cpu().numpy()  # [n_samples]
        p_task = task_probs[i].cpu().numpy()        # [n_samples, 2]
        for j in range(2):
            joint_pmf[i, j] = np.mean(p_mod * p_task[:, j])
    # Normalize (should sum to 1)
    joint_pmf /= joint_pmf.sum()

    plot_routing_sankey(joint_pmf, args.output_dir)

    plot_clustered_routing_heatmap(modality_probs, args.output_dir)

    plot_paired_routing_error_boxplot(predictions, (y1, y2), modality_probs, task_probs, args.output_dir)

    # Train all no-routing baselines. Their predictions are persisted as .npz
    # files in the output directory; nothing else in this script consumes
    # the returned dict.
    train_no_routing_baseline_all_routes(X_numeric, X_textual, y1, y2, args.output_dir, device)

# Run the analysis only when executed as a script, not when imported.
if __name__ == '__main__':
    main()