# Usage: python run_uncertainty_analysis.py --data_path /home/haohuawang/blmc/data/half-circle.pkl

import os
import sys
import torch
import torch.nn.functional as F
import torch.nn as nn
import numpy as np
import pickle
import argparse
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import json

# Add the parent directory to the path to import modules
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from engine.inits import get_algorithm, get_original_dataset
# from engine.hparams_registry import get_hparams
# from engine.configs import configs
from datasets.fast_data_loader import FastDataLoader
# import network
import datasets

# from engine.utils.visualizer import Visualizer
from engine.utils.checkpointer import Checkpointer
from engine.utils.checkpointer import remove_modules_for_DataParallel, add_modules_for_DataParallel


class UncertaintyAnalyzer:
    """
    Analyzer for calculating uncertainty in BayesShift model using Bayesian statistics.
    Calculates p(y|x) using posterior sampling and quantifies uncertainty.
    """
    
    def __init__(self, model_path, data_path, device='cuda'):
        self.device = device
        self.model_path = model_path
        self.data_path = data_path
        
        # Load model and data
        self.model = self._load_model()
        self.dataset = self._load_dataset()
        
        # MCMC sampling parameters
        self.num_samples = 100
        self.num_mcmc_steps = 10
        self.step_size = 0.1
    
    def _safe_to_numpy(self, tensor):
        """
        Safely convert a tensor to numpy array, handling gradients.
        """
        if tensor is None:
            return None
        if tensor.requires_grad:
            return tensor.detach().cpu().numpy()
        else:
            return tensor.cpu().numpy()
    
    def _load_model(self):
        """
        Load the trained BayesShift model from the checkpoint at ``self.model_path``.

        The algorithm is built from a minimal ToyCircle config, moved to
        ``self.device``, and populated with the ``'algorithm'`` state dict
        stored in the checkpoint. Only checkpoint entries whose keys exist in
        the freshly built model are loaded; skipped and missing keys are
        reported on stdout rather than silently ignored.

        Returns:
            The model in eval mode.

        Raises:
            KeyError: if the checkpoint has no ``'algorithm'`` entry.
        """
        # Honour the path given to the constructor; fall back to the previous
        # hard-coded location only when none was supplied.
        data_path = self.data_path or '/home/haohuawang/blmc/data/half-circle-cs.pkl'

        # Minimal stand-in for the project's config object.
        class DummyConfig:
            def __init__(self):
                self.data_name = 'ToyCircle'
                self.data_path = data_path
                self.num_classes = 2
                self.data_size = [2]  # ToyCircle has 2D data
                self.source_domains = 15
                self.intermediate_domains = 5
                self.target_domains = 10
                self.model_func = 'Toy_Linear_FE'
                self.feature_dim = 512  # 128
                self.cla_func = 'Linear_Cla'
                self.algorithm = 'bayesshift'
                self.zc_dim = 20
                self.zw_dim = 20
                self.seed = 0
                self.gpu_ids = '0'

        config = DummyConfig()

        # Initialize model
        model = get_algorithm(config)
        model.to(self.device)

        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # that come from a trusted source.
        checkpoint = torch.load(self.model_path, map_location=self.device)
        pretrained_algorithm_dict = checkpoint['algorithm']

        # Reconcile the 'module.' key prefix between the checkpoint and the
        # current model. next(..., '') keeps an empty state dict from raising
        # IndexError.
        first_key = next(iter(pretrained_algorithm_dict), '')
        has_module_prefix = 'module.' in first_key
        is_data_parallel = isinstance(model, nn.DataParallel)
        if has_module_prefix and not is_data_parallel:
            pretrained_algorithm_dict = remove_modules_for_DataParallel(pretrained_algorithm_dict)
        elif is_data_parallel and not has_module_prefix:
            pretrained_algorithm_dict = add_modules_for_DataParallel(pretrained_algorithm_dict)

        model_dict = model.state_dict()

        # Surface key mismatches: a checkpoint that does not match the model
        # would otherwise leave randomly initialised weights in place.
        skipped_keys = [k for k in pretrained_algorithm_dict if k not in model_dict]
        missing_keys = [k for k in model_dict if k not in pretrained_algorithm_dict]
        if skipped_keys:
            print(f"Warning: {len(skipped_keys)} checkpoint keys not present in the model were skipped")
        if missing_keys:
            print(f"Warning: {len(missing_keys)} model keys were not found in the checkpoint")

        # Keep only the entries the model knows about, then load the merged dict.
        matching = {k: v for k, v in pretrained_algorithm_dict.items() if k in model_dict}
        model_dict.update(matching)
        model.load_state_dict(model_dict)
        model.eval()

        return model
    
    def _load_dataset(self):
        """
        Load the ToyCircle dataset through ``get_original_dataset``.

        The dataset is read from ``self.data_path``, the path given to the
        constructor. The previous hard-coded location is used only when that
        path is empty or ``None``.

        Returns:
            Whatever ``get_original_dataset`` returns for the ToyCircle config.
        """
        # Honour the caller-supplied path instead of silently ignoring it.
        data_path = self.data_path or '/home/haohuawang/blmc/data/half-circle-cs.pkl'

        # Minimal stand-in for the project's config object.
        class DummyConfig:
            def __init__(self):
                self.data_name = 'ToyCircle'
                self.data_path = data_path
                self.num_classes = 2
                self.data_size = [2]
                self.source_domains = 15
                self.intermediate_domains = 5
                self.target_domains = 10
                self.seed = 0

        config = DummyConfig()
        return get_original_dataset(config)
    
    def _safe_sampling(self, encoder, batch_size, expected_shape):
        """
        Safely sample from an encoder, ensuring consistent output shapes.
        Handles domain dimension issues in the original sampling methods.
        """
        try:
            # Get the raw sample
            raw_sample = encoder.sampling(batch_size=batch_size)
            
            # Handle different possible shapes
            if len(raw_sample.shape) == 2:
                # Already correct shape [batch_size, latent_dim]
                if raw_sample.shape == expected_shape:
                    return raw_sample
                else:
                    print(f"  Warning: Raw sample shape {raw_sample.shape} != expected {expected_shape}")
                    # Try to reshape if numel matches
                    if raw_sample.numel() == expected_shape[0] * expected_shape[1]:
                        return raw_sample.reshape(expected_shape)
                    else:
                        raise ValueError(f"Cannot reshape {raw_sample.shape} to {expected_shape}")
            
            elif len(raw_sample.shape) == 3:
                # Has domain dimension [batch_size, domains, latent_dim]
                if raw_sample.shape[1] == 1:
                    # Single domain, just squeeze
                    squeezed = raw_sample.squeeze(1)
                    if squeezed.shape == expected_shape:
                        return squeezed
                    else:
                        raise ValueError(f"Squeezed shape {squeezed.shape} != expected {expected_shape}")
                else:
                    # Multiple domains, take the first one
                    first_domain = raw_sample[:, 0, :]
                    if first_domain.shape == expected_shape:
                        return first_domain
                    else:
                        raise ValueError(f"First domain shape {first_domain.shape} != expected {expected_shape}")
            
            else:
                raise ValueError(f"Unexpected raw sample shape: {raw_sample.shape}")
                
        except Exception as e:
            print(f"  Error in safe sampling: {e}")
            # Return a fallback sample with correct shape
            return torch.randn(expected_shape).to(self.device)
    
    def sample_posterior(self, x, domain_idx, num_samples=100):
        """
        Sample from the posterior p(z|x) using multiple forward passes.
        Returns samples from the latent space.
        """
        self.model.eval()
        batch_size = x.shape[0]
        
        print(f"Debug: Sampling posterior for batch_size={batch_size}, domain_idx={domain_idx}")
        
        # Get initial samples from the encoder using safe sampling
        with torch.no_grad():
            try:
                # Static encoder
                _ = self.model.static_encoder(x.unsqueeze(1))
                zc_init = self._safe_sampling(
                    self.model.static_encoder, 
                    batch_size, 
                    (batch_size, self.model.zc_dim)
                )
                print(f"Debug: zc_init shape: {zc_init.shape}, numel: {zc_init.numel()}")
                print(f"Debug: Expected zc shape: [{batch_size}, {self.model.zc_dim}], numel: {batch_size * self.model.zc_dim}")
                
                # Dynamic encoders
                _ = self.model.dynamic_w_encoder(x.unsqueeze(1), None)
                zw_init = self._safe_sampling(
                    self.model.dynamic_w_encoder, 
                    batch_size, 
                    (batch_size, self.model.zw_dim)
                )
                print(f"Debug: zw_init shape: {zw_init.shape}, numel: {zw_init.numel()}")
                print(f"Debug: Expected zw shape: [{batch_size}, {self.model.zw_dim}], numel: {batch_size * self.model.zw_dim}")
                
                # For categorical latent, we need to handle differently
                # Create one-hot encoding for y (we'll use dummy labels for now)
                dummy_y = torch.zeros(batch_size, self.model.num_classes).to(self.device)
                dummy_y[:, 0] = 1  # Assume class 0
                one_hot_y = dummy_y.unsqueeze(1)  # Add domain dimension
                print(f"Debug: one_hot_y shape: {one_hot_y.shape}")
                
                _ = self.model.dynamic_v_encoder(one_hot_y, None)
                zv_init = self._safe_sampling(
                    self.model.dynamic_v_encoder, 
                    batch_size, 
                    (batch_size, self.model.zv_dim)
                )
                print(f"Debug: zv_init shape: {zv_init.shape}, numel: {zv_init.numel()}")
                print(f"Debug: Expected zv shape: [{batch_size}, {self.model.zv_dim}], numel: {batch_size * self.model.zv_dim}")
                
                # Verify all shapes are correct
                expected_shapes = {
                    'zc': (batch_size, self.model.zc_dim),
                    'zw': (batch_size, self.model.zw_dim),
                    'zv': (batch_size, self.model.zv_dim)
                }
                
                actual_shapes = {
                    'zc': zc_init.shape,
                    'zw': zw_init.shape,
                    'zv': zv_init.shape
                }
                
                print(f"Debug: Shape verification:")
                for name, expected in expected_shapes.items():
                    actual = actual_shapes[name]
                    if actual == expected:
                        print(f"  ✓ {name}: {actual} == {expected}")
                    else:
                        print(f"  ✗ {name}: {actual} != {expected}")
                        raise ValueError(f"{name} has wrong shape: {actual} != {expected}")
                
            except Exception as e:
                print(f"Error in encoder sampling: {e}")
                # Create fallback initial values
                zc_init = torch.randn(batch_size, self.model.zc_dim).to(self.device)
                zw_init = torch.randn(batch_size, self.model.zw_dim).to(self.device)
                zv_init = torch.randn(batch_size, self.model.zv_dim).to(self.device)
        
        # Instead of complex MCMC, just add small noise to create multiple samples
        samples_zc = []
        samples_zw = []
        samples_zv = []
        
        for _ in range(num_samples):
            # Add small noise to create variation
            noise_scale = 0.01
            zc_sample = zc_init + torch.randn_like(zc_init) * noise_scale
            zw_sample = zw_init + torch.randn_like(zw_init) * noise_scale
            zv_sample = zv_init + torch.randn_like(zv_init) * noise_scale
            
            samples_zc.append(zc_sample)
            samples_zw.append(zw_sample)
            samples_zv.append(zv_sample)
        
        # Stack samples: [num_samples, batch_size, latent_dim]
        samples_zc = torch.stack(samples_zc)
        samples_zw = torch.stack(samples_zw)
        samples_zv = torch.stack(samples_zv)
        
        print(f"Debug: Sample shapes - zc: {samples_zc.shape}, zw: {samples_zw.shape}, zv: {samples_zv.shape}")
        
        return samples_zc, samples_zw, samples_zv
    
    def calculate_uncertainty(self, x, domain_idx, num_samples=100):
        """
        Calculate uncertainty for given input x.
        Returns prediction probabilities and uncertainty measures.
        """
        # Ensure model is in eval mode and disable gradients
        self.model.eval()
        
        with torch.no_grad():
            try:
                samples_zc, samples_zw, samples_zv = self.sample_posterior(x, domain_idx, num_samples)
                
                # Generate predictions for each sample
                predictions = []
                for i in range(num_samples):
                    try:
                        # Get the actual shapes of the samples
                        zc_sample = samples_zc[i]  # Should be [batch_size, zc_dim]
                        zw_sample = samples_zw[i]  # Should be [batch_size, zw_dim]
                        zv_sample = samples_zv[i]  # Should be [batch_size, zv_dim]
                        
                        print(f"Debug sample {i}:")
                        print(f"  zc_sample shape: {zc_sample.shape}, numel: {zc_sample.numel()}")
                        print(f"  zw_sample shape: {zw_sample.shape}, numel: {zw_sample.numel()}")
                        print(f"  zv_sample shape: {zv_sample.shape}, numel: {zv_sample.numel()}")
                        print(f"  Expected: batch_size={x.shape[0]}, zc_dim={self.model.zc_dim}, zv_dim={self.model.zv_dim}")
                        
                        # Verify shapes are correct
                        expected_zc_shape = (x.shape[0], self.model.zc_dim)
                        expected_zv_shape = (x.shape[0], self.model.zv_dim)
                        
                        if zc_sample.shape != expected_zc_shape:
                            print(f"  Error: zc_sample shape {zc_sample.shape} != expected {expected_zc_shape}")
                            raise ValueError(f"zc_sample has wrong shape: {zc_sample.shape}")
                        
                        if zv_sample.shape != expected_zv_shape:
                            print(f"  Error: zv_sample shape {zv_sample.shape} != expected {expected_zv_shape}")
                            raise ValueError(f"zv_sample has wrong shape: {zv_sample.shape}")
                        
                        print(f"  ✓ Shapes are correct")
                        
                        # Generate prediction
                        y_logit = self.model.category_cla_func(torch.cat([zv_sample, zc_sample], dim=1))
                        y_prob = F.softmax(y_logit, dim=-1)
                        predictions.append(y_prob)
                        
                    except Exception as e:
                        print(f"Error in sample {i}: {e}")
                        # Use uniform prediction as fallback
                        batch_size = x.shape[0]
                        num_classes = self.model.num_classes
                        fallback_prob = torch.ones(batch_size, num_classes).to(self.device) / num_classes
                        predictions.append(fallback_prob)
                
                if not predictions:
                    raise ValueError("No valid predictions generated")
                    
                predictions = torch.stack(predictions)  # [num_samples, batch_size, num_classes]
                
                # Calculate uncertainty measures
                mean_pred = torch.mean(predictions, dim=0)
                var_pred = torch.var(predictions, dim=0)
                entropy = -torch.sum(mean_pred * torch.log(mean_pred + 1e-8), dim=-1)
                
                # Predictive entropy (total uncertainty)
                pred_entropy = -torch.sum(mean_pred * torch.log(mean_pred + 1e-8), dim=-1)
                
                # Expected entropy (aleatoric uncertainty)
                exp_entropy = torch.mean(-torch.sum(predictions * torch.log(predictions + 1e-8), dim=-1), dim=0)
                
                # Mutual information (epistemic uncertainty)
                mutual_info = pred_entropy - exp_entropy
                
                return {
                    'mean_prediction': mean_pred,
                    'prediction_variance': var_pred,
                    'entropy': entropy,
                    'predictive_entropy': pred_entropy,
                    'expected_entropy': exp_entropy,
                    'mutual_information': mutual_info,
                    'samples': predictions
                }
            except Exception as e:
                print(f"Error in calculate_uncertainty: {e}")
                # Return default uncertainty values
                batch_size = x.shape[0]
                num_classes = self.model.num_classes
                
                default_pred = torch.ones(batch_size, num_classes).to(self.device) / num_classes
                default_entropy = torch.log(torch.tensor(num_classes, dtype=torch.float32))
                
                return {
                    'mean_prediction': default_pred,
                    'prediction_variance': torch.zeros_like(default_pred),
                    'entropy': torch.full((batch_size,), default_entropy, device=self.device),
                    'predictive_entropy': torch.full((batch_size,), default_entropy, device=self.device),
                    'expected_entropy': torch.full((batch_size,), default_entropy, device=self.device),
                    'mutual_information': torch.zeros(batch_size, device=self.device),
                    'samples': default_pred.unsqueeze(0).repeat(num_samples, 1, 1)
                }
    
    def analyze_real_vs_wrong_data(self, num_samples=1000):
        """
        Compare uncertainty on real data vs wrong/perturbed data.

        Source domains 0..14 are processed twice: once unchanged ("real") and
        once with Gaussian noise of std 0.5 added to the inputs ("wrong").
        At most 10 batches of 32 are used per domain, and 50 latent samples
        are drawn per batch.

        Args:
            num_samples: Currently unused; kept so existing callers keep
                working. NOTE(review): its intended meaning is unclear, so it
                is deliberately not wired to the per-batch sample count.

        Returns:
            dict with keys ``'real_data'`` and ``'wrong_data'``; each maps to
            ``{'uncertainties': [...], 'predictions': [...]}`` holding one
            record per successfully processed batch.
        """
        source_domains = list(range(15))  # Source domains

        print("Analyzing uncertainty on real data...")
        real_records = self._collect_uncertainty_records(source_domains, perturb=False)

        print("\nAnalyzing uncertainty on wrong/perturbed data...")
        wrong_records = self._collect_uncertainty_records(source_domains, perturb=True)

        results = {'real_data': real_records, 'wrong_data': wrong_records}

        print(f"\nOverall analysis summary:")
        print(f"  Real data results: {len(results['real_data']['uncertainties'])}")
        print(f"  Wrong data results: {len(results['wrong_data']['uncertainties'])}")

        if len(results['real_data']['uncertainties']) == 0 and len(results['wrong_data']['uncertainties']) == 0:
            print("❌ CRITICAL: No results collected from either real or wrong data!")
            print("   This indicates a fundamental issue with the analysis pipeline.")
            print("   Check the error messages above for specific issues.")

        return results

    def _collect_uncertainty_records(self, domains, perturb, posterior_samples=50,
                                     max_batches=10, noise_std=0.5):
        """
        Run ``calculate_uncertainty`` over batches of the given domains.

        Args:
            domains: Indices into ``self.dataset.datasets``.
            perturb: When True, Gaussian noise of std ``noise_std`` is added to
                each batch before analysis (the "wrong data" pass).
            posterior_samples: ``num_samples`` passed to ``calculate_uncertainty``.
            max_batches: Maximum number of batches processed per domain.
            noise_std: Perturbation scale used when ``perturb`` is True.

        Returns:
            ``{'uncertainties': [...], 'predictions': [...]}`` with one record
            per successfully processed batch. Failures are printed and counted,
            never raised.
        """
        records = {'uncertainties': [], 'predictions': []}
        succeeded = 0
        failed = 0

        # Message fragments differ between the two passes; keep them verbatim.
        if perturb:
            domain_title = 'Wrong data domain'
            calc_context = 'wrong data, domain'
            domain_context = 'wrong data for domain'
            summary_title = 'Wrong data'
        else:
            domain_title = 'Domain'
            calc_context = 'domain'
            domain_context = 'domain'
            summary_title = 'Real data'

        for domain_idx in tqdm(domains):
            try:
                dataset = self.dataset.datasets[domain_idx]
                dataloader = FastDataLoader(dataset, batch_size=32, num_workers=0)

                domain_success = 0
                for batch_idx, (x, y) in enumerate(dataloader):
                    if batch_idx >= max_batches:  # Limit samples per domain
                        break

                    x, y = x.to(self.device), y.to(self.device)

                    if perturb:
                        # Create wrong data by adding noise
                        x = x + torch.randn_like(x) * noise_std
                        print(f"Debug: Processing wrong data for domain {domain_idx}, batch {batch_idx}, x_wrong shape: {x.shape}")
                    else:
                        print(f"Debug: Processing domain {domain_idx}, batch {batch_idx}, x shape: {x.shape}, y shape: {y.shape}")

                    try:
                        uncertainty = self.calculate_uncertainty(x, domain_idx, num_samples=posterior_samples)

                        records['uncertainties'].append({
                            'domain': domain_idx,
                            'predictive_entropy': self._safe_to_numpy(uncertainty['predictive_entropy']),
                            'mutual_information': self._safe_to_numpy(uncertainty['mutual_information']),
                            'prediction_variance': self._safe_to_numpy(uncertainty['prediction_variance'])
                        })
                        records['predictions'].append({
                            'domain': domain_idx,
                            'mean_prediction': self._safe_to_numpy(uncertainty['mean_prediction']),
                            'true_labels': self._safe_to_numpy(y)
                        })

                        domain_success += 1
                        succeeded += 1

                    except Exception as e:
                        # Best effort: a failing batch is reported and skipped.
                        print(f"  Error in uncertainty calculation for {calc_context} {domain_idx}, batch {batch_idx}: {e}")
                        failed += 1
                        continue

                if domain_success > 0:
                    print(f"✓ {domain_title} {domain_idx}: {domain_success} batches processed successfully")
                else:
                    print(f"✗ {domain_title} {domain_idx}: All batches failed")

            except Exception as e:
                # A failure while setting up or iterating a domain is counted
                # as one failure and the next domain is tried.
                print(f"Error processing {domain_context} {domain_idx}: {e}")
                failed += 1
                continue

        print(f"\n{summary_title} analysis summary:")
        print(f"  Successful batches: {succeeded}")
        print(f"  Failed batches: {failed}")
        print(f"  Total results collected: {len(records['uncertainties'])}")

        return records
    
    def visualize_uncertainty(self, results, save_path='uncertainty_analysis'):
        """
        Create visualizations for uncertainty analysis.
        """
        os.makedirs(save_path, exist_ok=True)
        
        # Check if we have any results to visualize
        if not results['real_data']['uncertainties'] and not results['wrong_data']['uncertainties']:
            print("⚠️  WARNING: No uncertainty data available for visualization!")
            print("   This suggests the analysis failed to process any data.")
            print("   Check the analysis logs for errors.")
            
            # Create a simple error report
            error_report = {
                'error': 'No uncertainty data available',
                'real_data_count': len(results['real_data']['uncertainties']),
                'wrong_data_count': len(results['wrong_data']['uncertainties']),
                'suggestion': 'Run the analysis again and check for errors'
            }
            
            with open(os.path.join(save_path, 'error_report.json'), 'w') as f:
                json.dump(error_report, f, indent=2)
            
            print(f"Error report saved to {save_path}/error_report.json")
            return
        
        # Check if we have real data
        if not results['real_data']['uncertainties']:
            print("⚠️  WARNING: No real data uncertainty results available!")
            print("   Only wrong data results will be visualized.")
            real_entropy = np.array([])
            real_mi = np.array([])
            real_domains = []
            real_entropy_per_domain = np.array([])
            real_mi_per_domain = np.array([])
        else:
            # Extract real data
            real_entropy = np.concatenate([r['predictive_entropy'] for r in results['real_data']['uncertainties']])
            real_mi = np.concatenate([r['mutual_information'] for r in results['real_data']['uncertainties']])
            real_domains = [r['domain'] for r in results['real_data']['uncertainties']]
            
            # Create domain-level statistics for scatter plots
            real_entropy_per_domain = []
            real_mi_per_domain = []
            for r in results['real_data']['uncertainties']:
                try:
                    real_entropy_per_domain.append(np.mean(r['predictive_entropy']))
                    real_mi_per_domain.append(np.mean(r['mutual_information']))
                except Exception as e:
                    print(f"  Warning: Error processing domain {r.get('domain', 'unknown')}: {e}")
                    # Use fallback values
                    real_entropy_per_domain.append(0.0)
                    real_mi_per_domain.append(0.0)
            
            real_entropy_per_domain = np.array(real_entropy_per_domain)
            real_mi_per_domain = np.array(real_mi_per_domain)
            
            print(f"✓ Real data: {len(results['real_data']['uncertainties'])} domains, {len(real_entropy)} samples")
            print(f"  Domain-level stats: {len(real_entropy_per_domain)} domains")
        
        # Check if we have wrong data
        if not results['wrong_data']['uncertainties']:
            print("⚠️  WARNING: No wrong data uncertainty results available!")
            print("   Only real data results will be visualized.")
            wrong_entropy = np.array([])
            wrong_mi = np.array([])
            wrong_domains = []
            wrong_entropy_per_domain = np.array([])
            wrong_mi_per_domain = np.array([])
        else:
            # Extract wrong data
            wrong_entropy = np.concatenate([r['predictive_entropy'] for r in results['wrong_data']['uncertainties']])
            wrong_mi = np.concatenate([r['mutual_information'] for r in results['wrong_data']['uncertainties']])
            wrong_domains = [r['domain'] for r in results['wrong_data']['uncertainties']]
            
            # Create domain-level statistics for scatter plots
            wrong_entropy_per_domain = []
            wrong_mi_per_domain = []
            for r in results['wrong_data']['uncertainties']:
                try:
                    wrong_entropy_per_domain.append(np.mean(r['predictive_entropy']))
                    wrong_mi_per_domain.append(np.mean(r['mutual_information']))
                except Exception as e:
                    print(f"  Warning: Error processing wrong data domain {r.get('domain', 'unknown')}: {e}")
                    # Use fallback values
                    wrong_entropy_per_domain.append(0.0)
                    wrong_mi_per_domain.append(0.0)
            
            wrong_entropy_per_domain = np.array(wrong_entropy_per_domain)
            wrong_mi_per_domain = np.array(wrong_mi_per_domain)
            
            print(f"✓ Wrong data: {len(results['wrong_data']['uncertainties'])} domains, {len(wrong_entropy)} samples")
            print(f"  Domain-level stats: {len(wrong_entropy_per_domain)} domains")
        
        # Only proceed if we have at least some data
        if len(real_entropy) == 0 and len(wrong_entropy) == 0:
            print("❌ ERROR: No data available for visualization!")
            return
        
        print(f"\nVisualizing uncertainty results:")
        print(f"  Real data: {len(real_entropy)} samples from {len(real_domains)} domains")
        print(f"  Wrong data: {len(wrong_entropy)} samples from {len(wrong_domains)} domains")
        
        # Debug: Show data structure
        if len(real_domains) > 0:
            print(f"  Real data structure:")
            print(f"    Domains: {real_domains[:5]}{'...' if len(real_domains) > 5 else ''}")
            print(f"    Entropy per domain: {len(real_entropy_per_domain)} values")
            print(f"    Sample entropy: {len(real_entropy)} values")
        
        if len(wrong_domains) > 0:
            print(f"  Wrong data structure:")
            print(f"    Domains: {wrong_domains[:5]}{'...' if len(wrong_domains) > 5 else ''}")
            print(f"    Entropy per domain: {len(wrong_entropy_per_domain)} values")
            print(f"    Sample entropy: {len(wrong_entropy)} values")
        
        # Plot 1: Distribution of predictive entropy (only if we have data)
        plt.figure(figsize=(12, 5))
        
        plt.subplot(1, 2, 1)
        if len(real_entropy) > 0:
            plt.hist(real_entropy, bins=30, alpha=0.7, label='Real Data', density=True)
        if len(wrong_entropy) > 0:
            plt.hist(wrong_entropy, bins=30, alpha=0.7, label='Wrong Data', density=True)
        plt.xlabel('Predictive Entropy')
        plt.ylabel('Density')
        plt.title('Distribution of Predictive Entropy')
        plt.legend()
        
        plt.subplot(1, 2, 2)
        if len(real_mi) > 0:
            plt.hist(real_mi, bins=30, alpha=0.7, label='Real Data', density=True)
        if len(wrong_mi) > 0:
            plt.hist(wrong_mi, bins=30, alpha=0.7, label='Wrong Data', density=True)
        plt.xlabel('Mutual Information (Epistemic Uncertainty)')
        plt.ylabel('Density')
        plt.title('Distribution of Epistemic Uncertainty')
        plt.legend()
        
        plt.tight_layout()
        plt.savefig(os.path.join(save_path, 'uncertainty_distributions.png'), dpi=300, bbox_inches='tight')
        plt.close()
        
        # Plot 2: Uncertainty vs domain (only if we have domain data)
        if (len(real_domains) > 0 and len(real_entropy_per_domain) > 0) or (len(wrong_domains) > 0 and len(wrong_entropy_per_domain) > 0):
            plt.figure(figsize=(12, 5))
            
            plt.subplot(1, 2, 1)
            if len(real_domains) > 0:
                print(f"Domain plot - Real data: {len(real_domains)} domains, {len(real_entropy_per_domain)} domain-level entropies")
                plt.scatter(real_domains, real_entropy_per_domain, alpha=0.6, label='Real Data', s=50)
            if len(wrong_domains) > 0:
                plt.scatter(wrong_domains, wrong_entropy_per_domain, alpha=0.6, label='Wrong Data', s=50)
            plt.xlabel('Domain Index')
            plt.ylabel('Mean Predictive Entropy per Domain')
            plt.title('Uncertainty vs Domain')
            plt.legend()
            
            plt.subplot(1, 2, 2)
            if len(real_domains) > 0:
                plt.scatter(real_domains, real_mi_per_domain, alpha=0.6, label='Real Data', s=50)
            if len(wrong_domains) > 0:
                plt.scatter(wrong_domains, wrong_mi_per_domain, alpha=0.6, label='Wrong Data', s=50)
            plt.xlabel('Domain Index')
            plt.ylabel('Mean Mutual Information per Domain')
            plt.title('Epistemic Uncertainty vs Domain')
            plt.legend()
            
            plt.tight_layout()
            plt.savefig(os.path.join(save_path, 'uncertainty_vs_domain.png'), dpi=300, bbox_inches='tight')
            plt.close()
        
        # Save numerical results
        summary_stats = {}
        
        if len(real_entropy) > 0:
            summary_stats['real_data'] = {
                'mean_predictive_entropy': float(np.mean(real_entropy)),
                'std_predictive_entropy': float(np.std(real_entropy)),
                'mean_mutual_information': float(np.mean(real_mi)),
                'std_mutual_information': float(np.std(real_mi)),
                'sample_count': len(real_entropy),
                'domain_count': len(real_domains),
                'mean_entropy_per_domain': float(np.mean(real_entropy_per_domain)),
                'std_entropy_per_domain': float(np.std(real_entropy_per_domain)),
                'mean_mi_per_domain': float(np.mean(real_mi_per_domain)),
                'std_mi_per_domain': float(np.std(real_mi_per_domain))
            }
        
        if len(wrong_entropy) > 0:
            summary_stats['wrong_data'] = {
                'mean_predictive_entropy': float(np.mean(wrong_entropy)),
                'std_predictive_entropy': float(np.std(wrong_entropy)),
                'mean_mutual_information': float(np.mean(wrong_mi)),
                'std_mutual_information': float(np.std(wrong_mi)),
                'sample_count': len(wrong_entropy),
                'domain_count': len(wrong_domains),
                'mean_entropy_per_domain': float(np.mean(wrong_entropy_per_domain)),
                'std_entropy_per_domain': float(np.std(wrong_entropy_per_domain)),
                'mean_mi_per_domain': float(np.mean(wrong_mi_per_domain)),
                'std_mi_per_domain': float(np.std(wrong_mi_per_domain))
            }
        
        # Add overall summary
        summary_stats['overall'] = {
            'total_real_samples': len(real_entropy),
            'total_wrong_samples': len(wrong_entropy),
            'analysis_success': len(real_entropy) > 0 or len(wrong_entropy) > 0
        }
        
        with open(os.path.join(save_path, 'uncertainty_summary.json'), 'w') as f:
            json.dump(summary_stats, f, indent=2)
        
        print(f"\nUncertainty analysis results saved to {save_path}")
        
        if len(real_entropy) > 0:
            print(f"Real data - Mean predictive entropy: {summary_stats['real_data']['mean_predictive_entropy']:.4f}")
        
        if len(wrong_entropy) > 0:
            print(f"Wrong data - Mean predictive entropy: {summary_stats['wrong_data']['mean_predictive_entropy']:.4f}")
        
        if len(real_entropy) > 0 and len(wrong_entropy) > 0:
            uncertainty_increase = summary_stats['wrong_data']['mean_predictive_entropy'] - summary_stats['real_data']['mean_predictive_entropy']
            print(f"Uncertainty increase: {uncertainty_increase:.4f}")
            
            if uncertainty_increase > 0:
                print("✓ SUCCESS: Wrong data shows higher uncertainty than real data")
            else:
                print("⚠️  WARNING: Wrong data shows similar or lower uncertainty than real data")
        else:
            print("⚠️  Cannot compare uncertainty: missing data from one or both categories")


def main():
    """
    Command-line entry point for the uncertainty analysis.

    Parses the command-line options, constructs an UncertaintyAnalyzer from the
    model checkpoint path, dataset path and device, runs the real-vs-wrong data
    analysis with the requested number of samples, and hands the results to the
    analyzer's visualization step together with the output path.

    Options:
        --model_path   (required) path to the model checkpoint
        --data_path    (required) path to the dataset
        --device       device to use, defaults to 'cuda'
        --num_samples  number of MCMC samples, defaults to 100
        --save_path    where results are saved, defaults to 'uncertainty_analysis'
    """
    # Declared as a table so each flag and its settings sit side by side;
    # registration order is preserved, so the generated --help is unchanged.
    option_table = (
        ('--model_path', dict(type=str, required=True, help='Path to model checkpoint')),
        ('--data_path', dict(type=str, required=True, help='Path to dataset')),
        ('--device', dict(type=str, default='cuda', help='Device to use')),
        ('--num_samples', dict(type=int, default=100, help='Number of MCMC samples')),
        ('--save_path', dict(type=str, default='uncertainty_analysis', help='Path to save results')),
    )

    cli = argparse.ArgumentParser(description='Uncertainty Analysis for BayesShift')
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)
    opts = cli.parse_args()

    # The analyzer's constructor loads the model and the dataset.
    analyzer = UncertaintyAnalyzer(opts.model_path, opts.data_path, opts.device)

    # Compute the uncertainty results, then plot and summarize them.
    print("Starting uncertainty analysis...")
    analysis = analyzer.analyze_real_vs_wrong_data(num_samples=opts.num_samples)
    analyzer.visualize_uncertainty(analysis, opts.save_path)

    print("Uncertainty analysis completed!")


# Run the command-line workflow only when this file is executed directly as a
# script; importing it as a module does not call main().
if __name__ == '__main__':
    main()
