import os
import torch
import matplotlib.pyplot as plt
from tqdm import tqdm
from utils.get_dataloader import get_dataloader
from models.get_model import get_model
from utils.get_args import get_args
import numpy as np
import yaml
import argparse

def plot_vanilla_pc_free_energy(output_dict, log_dir, prefix=""):
    """Save two figures describing the vanilla PC free energy.

    ``output_dict['vanilla_pc_free_energy_list']`` is indexed ``[t][layer]``.
    The first figure is a 3D surface over (layer, t). The second is a line
    plot of the free energy summed over layers, against t. Both figures are
    written as PNG files into ``log_dir``. ``prefix`` is prepended to the
    figure titles and to the file names.
    """
    energies = output_dict['vanilla_pc_free_energy_list']
    n_layers = len(energies[0])
    n_steps = len(energies)
    energy_grid = np.array(energies)

    # --- 3D surface: layer index on x, step t on y, energy on z -----------
    layer_mesh, step_mesh = np.meshgrid(np.arange(n_layers), np.arange(n_steps))
    surface_fig = plt.figure(figsize=(10, 8))
    axes3d = surface_fig.add_subplot(111, projection='3d')
    axes3d.view_init(elev=20, azim=-110)
    axes3d.plot_surface(layer_mesh, step_mesh, energy_grid, cmap='viridis')
    axes3d.set_xlabel('layer')
    axes3d.set_ylabel('t')
    axes3d.set_zlabel('vanilla_pc_free_energy')
    plt.title(f'{prefix}Vanilla PC Free Energy 3D Plot')
    plt.savefig(os.path.join(log_dir, f'{prefix}vanilla_pc_free_energy_list.png'))
    plt.close()

    # --- Line plot: total energy (summed over layers) at every step t ------
    step_axis = np.arange(energy_grid.shape[0])
    energy_totals = energy_grid.sum(axis=1)

    plt.figure(figsize=(10, 6))
    plt.plot(step_axis, energy_totals, 'b-', linewidth=2)
    plt.xlabel('t')
    plt.ylabel('vanilla_pc_free_energy_sum')
    plt.title(f'{prefix}Sum of Vanilla PC Free Energy vs t')
    plt.grid(True, alpha=0.3)
    plt.savefig(os.path.join(log_dir, f'{prefix}vanilla_pc_free_energy_sum_vs_t.png'))
    plt.close()

def plot_meta_pc_free_energy(output_dict, log_dir, prefix=""):
    """Save two figures describing the meta PC free energy.

    ``output_dict['meta_pc_free_energy_list']`` is indexed ``[t][layer]``.
    The first figure is a 3D surface over (layer, t). The second is a line
    plot of the free energy summed over layers, against t. Both figures are
    written as PNG files into ``log_dir``. ``prefix`` is prepended to the
    figure titles and to the file names.

    NOTE: these values were previously documented as "currently zeros";
    whether that still holds depends on the model, so verify.
    """
    energies = output_dict['meta_pc_free_energy_list']
    n_layers = len(energies[0])
    n_steps = len(energies)
    energy_grid = np.array(energies)

    # --- 3D surface: layer index on x, step t on y, energy on z -----------
    layer_mesh, step_mesh = np.meshgrid(np.arange(n_layers), np.arange(n_steps))
    surface_fig = plt.figure(figsize=(10, 8))
    axes3d = surface_fig.add_subplot(111, projection='3d')
    axes3d.view_init(elev=20, azim=-110)
    axes3d.plot_surface(layer_mesh, step_mesh, energy_grid, cmap='plasma')
    axes3d.set_xlabel('layer')
    axes3d.set_ylabel('t')
    axes3d.set_zlabel('meta_pc_free_energy')
    plt.title(f'{prefix}Meta PC Free Energy 3D Plot')
    plt.savefig(os.path.join(log_dir, f'{prefix}meta_pc_free_energy_list.png'))
    plt.close()

    # --- Line plot: total energy (summed over layers) at every step t ------
    step_axis = np.arange(energy_grid.shape[0])
    energy_totals = energy_grid.sum(axis=1)

    plt.figure(figsize=(10, 6))
    plt.plot(step_axis, energy_totals, 'r-', linewidth=2)
    plt.xlabel('t')
    plt.ylabel('meta_pc_free_energy_sum')
    plt.title(f'{prefix}Sum of Meta PC Free Energy vs t')
    plt.grid(True, alpha=0.3)
    plt.savefig(os.path.join(log_dir, f'{prefix}meta_pc_free_energy_sum_vs_t.png'))
    plt.close()

def run_model_on_data(model, data_loader, args, num_batches=None):
    """Evaluate ``model`` on ``data_loader`` and collect its per-batch outputs.

    Runs under ``torch.no_grad()`` with the model in eval mode. Each batch is
    passed as ``model(x, y, optimizer, return_per_sample_metric=True)``. The
    optimizer is a fresh AdamW with ``lr=0.0``, so no learning should happen.

    Args:
        model: module whose returned dict contains ``'pred'``,
            ``'vanilla_pc_free_energy_list'``, ``'meta_pc_free_energy_list'``
            and ``'vanilla_pc_free_energy_per_sample'``.
        data_loader: iterable yielding ``(x, y)`` tensor pairs.
        args: namespace providing ``device``.
        num_batches: stop after this many batches; ``None`` processes all.

    Returns:
        Tuple ``(all_outputs, accuracy, per_sample_metrics_full)``:

        - ``all_outputs`` is a list with one dict per processed batch.
        - ``accuracy`` is the fraction of samples whose argmax prediction
          equals the target. It is ``0.0`` when no samples were processed,
          instead of raising ZeroDivisionError.
        - ``per_sample_metrics_full`` is the per-sample metrics of all
          batches concatenated along axis 1 as a NumPy array, or ``None`` if
          no batch ran.
    """
    model.eval()
    all_outputs = []
    all_per_sample_metrics = []  # one (L, batch_size) entry per batch
    total_samples = 0
    correct = 0

    # The context manager closes the progress bar even when we `break` early.
    with torch.no_grad(), tqdm(data_loader) as pbar:
        for batch_idx, (x, y) in enumerate(pbar):
            if num_batches is not None and batch_idx >= num_batches:
                break

            x, y = x.to(args.device), y.to(args.device)

            # Dummy optimizer with lr=0 (no learning should happen). It is
            # recreated per batch so that no optimizer state carries over
            # between batches.
            dummy_optimizer = torch.optim.AdamW(model.parameters(), lr=0.0)

            output_dict = model(x, y, dummy_optimizer, return_per_sample_metric=True)

            preds = output_dict['pred'].argmax(dim=1)
            correct += (preds == y).sum().item()
            total_samples += x.size(0)

            per_sample = output_dict['vanilla_pc_free_energy_per_sample']
            if torch.is_tensor(per_sample):
                # np.concatenate cannot consume device (e.g. CUDA) tensors;
                # move them to host memory as NumPy first.
                per_sample = per_sample.detach().cpu().numpy()
            all_per_sample_metrics.append(per_sample)

            all_outputs.append({
                'vanilla_pc_free_energy_list': output_dict['vanilla_pc_free_energy_list'],
                'meta_pc_free_energy_list': output_dict['meta_pc_free_energy_list'],
                'pred': output_dict['pred'],
                'targets': y,
                'batch_idx': batch_idx
            })

            pbar.set_description(f"Processing batch {batch_idx+1}, Acc: {correct/total_samples:.4f}")

    # Concatenate along the batch dimension to get (L, dataset_len).
    if all_per_sample_metrics:
        per_sample_metrics_full = np.concatenate(all_per_sample_metrics, axis=1)
        print(f"Per-sample metrics shape: {per_sample_metrics_full.shape}")  # Should be (L, dataset_len)
    else:
        per_sample_metrics_full = None

    # Guard against an empty run (e.g. --num_batches 0 or an empty loader),
    # which previously raised ZeroDivisionError before the caller could
    # report "No data processed!".
    accuracy = correct / total_samples if total_samples else 0.0
    print(f"Final accuracy: {accuracy:.4f}")

    return all_outputs, accuracy, per_sample_metrics_full

def _load_saved_args(log_dir):
    """Load ``<log_dir>/args.yaml`` into an ``argparse.Namespace``.

    Legacy files containing ``!!python/tuple`` tags are rejected by
    ``yaml.safe_load``. For those files the tags are stripped and the text is
    re-parsed, and ``img_shape`` is converted back from list to tuple.

    Raises:
        FileNotFoundError: if ``args.yaml`` does not exist.
        ValueError: if the file does not contain a YAML mapping.
    """
    args_path = os.path.join(log_dir, 'args.yaml')
    if not os.path.exists(args_path):
        raise FileNotFoundError(f"args.yaml not found in {log_dir}")

    # Read once so the legacy fallback can re-parse the same text without
    # re-opening the file.
    with open(args_path, 'r') as f:
        content = f.read()

    try:
        saved_args_dict = yaml.safe_load(content)
    except yaml.constructor.ConstructorError:
        # Handle legacy YAML files with Python objects: drop the tuple tags.
        saved_args_dict = yaml.safe_load(content.replace('!!python/tuple', ''))
        if 'img_shape' in saved_args_dict and isinstance(saved_args_dict['img_shape'], list):
            saved_args_dict['img_shape'] = tuple(saved_args_dict['img_shape'])

    if not isinstance(saved_args_dict, dict):
        raise ValueError(f"{args_path} does not contain a YAML mapping")

    args = argparse.Namespace()
    args.__dict__.update(saved_args_dict)
    return args


def _softmax(x, axis=0):
    """Numerically stable softmax of ``x`` along ``axis``."""
    # Subtracting the max leaves the result mathematically unchanged. It
    # keeps np.exp from overflowing to inf (and inf/inf -> nan) when the
    # inputs are large.
    e = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return e / np.sum(e, axis=axis, keepdims=True)


def _entropy(p, base=2, axis=0):
    """Shannon entropy of the distributions in ``p`` along ``axis``, in log ``base``."""
    # Use the convention 0 * log(0) = 0. A plain p * np.log(p) yields nan
    # wherever a probability underflowed to exactly 0.
    plogp = p * np.log(np.where(p > 0, p, 1.0))
    return -np.sum(plogp, axis=axis) / np.log(base)


def _per_sample_stats(per_sample_metrics):
    """Summarise how each sample's metric is spread across layers.

    ``per_sample_metrics`` is laid out as (L, num_samples). The last row is
    excluded from every statistic; NOTE(review): presumably this is the
    output layer, so confirm.

    Returns a dict with the mean/std over samples of:

    - the per-sample std across layers (``ddof=1``);
    - the per-sample coefficient of variation (std / mean);
    - the base-2 entropy of the per-sample softmax over layers.
    """
    d = per_sample_metrics[:-1, :]
    d_std = np.std(d, axis=0, ddof=1)
    d_mean = np.mean(d, axis=0)
    d_coef_var = d_std / d_mean
    ent = _entropy(_softmax(d, axis=0), base=2, axis=0)
    return {
        'd_std_mean': np.mean(d_std),
        'd_std_std': np.std(d_std),
        'd_coef_mean': np.mean(d_coef_var),
        'd_coef_std': np.std(d_coef_var),
        'd_ent_mean': np.mean(ent),
        'd_ent_std': np.std(ent),
    }


def _write_summary(summary_path, plot_args, args, accuracy, num_batches, per_sample_stats):
    """Write accuracy, key model configuration and per-sample stats to ``summary_path``."""
    lines = [
        f"Dataset: {plot_args.dataset_type}",
        f"Accuracy: {accuracy:.4f}",
        f"Number of batches processed: {num_batches}",
        "Model configuration:",
        f"  - update_latent_rule: {args.update_latent_rule}",
        f"  - energy_option: {args.energy_option}",
        f"  - update_param_rule: {args.update_param_rule}",
        f"  - T: {args.T}",
        f"  - eta: {args.eta}",
        f"  - lr: {args.lr}",
        f"  - batch_size: {args.batch_size}",
        f"  - backbone: {args.backbone}",
        f"  - dataset: {args.dataset}",
        "Per-sample statistics:",
    ]
    lines += [f"  - {key}: {value:.4f}" for key, value in per_sample_stats.items()]
    with open(summary_path, 'w') as f:
        f.write("\n".join(lines) + "\n")


def main():
    """Command-line entry point: evaluate a saved model and write plots and a summary.

    Reads ``args.yaml`` and ``model.pth`` from ``--log_dir`` and runs the model
    on the chosen split. It then saves free-energy plots for the last batch,
    plus batch-averaged plots when more than one batch ran. Finally it writes
    ``<prefix>summary.txt`` to the output directory.
    """
    parser = argparse.ArgumentParser(description='Plot PCN model results')
    parser.add_argument('--log_dir', type=str, required=True,
                        help='Directory containing the saved model and args.yaml')
    parser.add_argument('--dataset_type', type=str, default='test', choices=['train', 'test'],
                        help='Dataset to use for plotting')
    parser.add_argument('--num_batches', type=int, default=None,
                        help='Number of batches to process (None for all)')
    parser.add_argument('--output_dir', type=str, default=None,
                        help='Directory to save plots (default: same as log_dir)')

    plot_args = parser.parse_args()

    args = _load_saved_args(plot_args.log_dir)

    if plot_args.output_dir is None:
        output_dir = plot_args.log_dir
    else:
        output_dir = plot_args.output_dir
        os.makedirs(output_dir, exist_ok=True)

    print(f"Loading model from: {plot_args.log_dir}")
    print(f"Using dataset: {plot_args.dataset_type}")
    print(f"Output directory: {output_dir}")
    print("Model configuration:")
    print(f"  - update_latent_rule: {args.update_latent_rule}")
    print(f"  - energy_option: {args.energy_option}")
    print(f"  - update_param_rule: {args.update_param_rule}")
    print(f"  - T: {args.T}")
    print(f"  - eta: {args.eta}")

    # Load data
    dataloader_dict = get_dataloader(args)
    loader_key = 'train_loader' if plot_args.dataset_type == 'train' else 'test_loader'
    data_loader = dataloader_dict[loader_key]

    # Check for the checkpoint before building the model so a missing file
    # fails fast.
    model_path = os.path.join(plot_args.log_dir, 'model.pth')
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"model.pth not found in {plot_args.log_dir}")
    model = get_model(args).to(args.device)
    model.load_state_dict(torch.load(model_path, map_location=args.device))
    print(f"Model loaded from: {model_path}")

    print(f"Running model on {plot_args.dataset_type} data...")
    all_outputs, accuracy, per_sample_metrics_full = run_model_on_data(
        model, data_loader, args, plot_args.num_batches)

    if not all_outputs:
        print("No data processed!")
        return

    # Summary plots from the last batch.
    last_output = all_outputs[-1]
    prefix = f"plotting_{plot_args.dataset_type}_"

    print("Creating plots...")
    plot_vanilla_pc_free_energy({
        'vanilla_pc_free_energy_list': last_output['vanilla_pc_free_energy_list']
    }, output_dir, prefix)
    plot_meta_pc_free_energy({
        'meta_pc_free_energy_list': last_output['meta_pc_free_energy_list']
    }, output_dir, prefix)

    # Unweighted average over batches (each batch counts equally).
    if len(all_outputs) > 1:
        print("Creating average plots across all batches...")

        avg_vanilla_pc_free_energy = np.mean(
            [output['vanilla_pc_free_energy_list'] for output in all_outputs], axis=0)
        plot_vanilla_pc_free_energy({
            'vanilla_pc_free_energy_list': avg_vanilla_pc_free_energy
        }, output_dir, f"{prefix}avg_")

        avg_meta_pc_free_energy = np.mean(
            [output['meta_pc_free_energy_list'] for output in all_outputs], axis=0)
        plot_meta_pc_free_energy({
            'meta_pc_free_energy_list': avg_meta_pc_free_energy
        }, output_dir, f"{prefix}avg_")

    np.set_printoptions(precision=5)
    if per_sample_metrics_full is not None:
        per_sample_stats = _per_sample_stats(per_sample_metrics_full)
    else:
        per_sample_stats = {}

    summary_path = os.path.join(output_dir, f'{prefix}summary.txt')
    _write_summary(summary_path, plot_args, args, accuracy, len(all_outputs), per_sample_stats)

    print(f"Summary saved to: {summary_path}")
    print(f"Plots saved to: {output_dir}")
    print("Done!")

if __name__ == "__main__":
    # Run the CLI only when this file is executed as a script, not on import.
    main()