import matplotlib.pyplot as plt
import torch
import numpy as np
import os
import subprocess

def _tensor_to_numpy(tensor):
    """Detach *tensor*, upcast to float32 and return it as a CPU numpy array.

    The upcast lets half/bfloat16 tensors be plotted (numpy has no bfloat16).
    """
    return tensor.detach().float().cpu().numpy()


def _extract_adaptive_weight(layer):
    """Return a 2-D numpy weight matrix for an adaptive layer, or None if none is found.

    Lookup order: factorized ``U``/``V`` (restricted to the first ``active_dims``
    rank components), ``get_weights()``, ``layer.layer.weight``,
    ``layer.linear.weight``, ``layer.weight``, then the first 2-D parameter whose
    name contains 'weight'. Errors raised while accessing attributes propagate.
    """
    if hasattr(layer, 'U') and hasattr(layer, 'V'):
        rank = layer.active_dims
        # NOTE(review): V is sliced on its FIRST axis, i.e. assumed to be stored as
        # [rank, in_features], so U[:, :rank] @ V[:rank, :] is [out_features, in_features].
        u_active = layer.U[:, :rank].detach().float().cpu()
        v_active = layer.V[:rank, :].detach().float().cpu()
        return torch.matmul(u_active, v_active).numpy()
    if hasattr(layer, 'get_weights'):
        return _tensor_to_numpy(layer.get_weights())
    if hasattr(layer, 'layer') and hasattr(layer.layer, 'weight'):
        return _tensor_to_numpy(layer.layer.weight)
    if hasattr(layer, 'linear') and hasattr(layer.linear, 'weight'):
        return _tensor_to_numpy(layer.linear.weight)
    if hasattr(layer, 'weight'):
        return _tensor_to_numpy(layer.weight)
    for name, param in layer.named_parameters():
        if 'weight' in name and param.dim() == 2:
            return _tensor_to_numpy(param)
    return None


def _plot_diverging_heatmap(fig, ax, matrix):
    """Draw a 2-D *matrix* on *ax* with a zero-centred 'seismic' colormap and a colorbar."""
    if matrix.size == 0:
        ax.text(0.5, 0.5, 'Empty matrix', ha='center', va='center',
                transform=ax.transAxes, fontsize=10)
        return
    vmax = float(np.nanmax(np.abs(matrix)))
    if not np.isfinite(vmax) or vmax == 0.0:
        # An all-zero (or non-finite) matrix would give vmin == vmax, collapsing the
        # colour scale so that zero is no longer the neutral (white) midpoint.
        vmax = 1.0
    im = ax.imshow(matrix, cmap='seismic', aspect='auto', vmin=-vmax, vmax=vmax)
    fig.colorbar(im, ax=ax)


def _layer_rank(layers, index):
    """Return ``layers[index].active_dims`` for panel titles, or 'N/A' if unavailable."""
    try:
        return layers[index].active_dims
    except (IndexError, AttributeError, TypeError):
        return 'N/A'


def _label_sort_indices(labels, n_samples):
    """Return indices ordering samples by label, column 0 being the primary key.

    With ``labels`` None the identity order ``arange(n_samples)`` is returned.
    """
    if labels is None:
        return np.arange(n_samples)
    labels_np = labels.cpu().numpy() if torch.is_tensor(labels) else np.asarray(labels)
    if labels_np.ndim > 1:
        # np.lexsort treats the LAST key as primary, so pass the columns reversed.
        keys = [labels_np[:, col] for col in reversed(range(labels_np.shape[1]))]
        return np.lexsort(keys)
    # Stable so samples sharing a label keep their batch order (as lexsort does).
    return np.argsort(labels_np, kind='stable')


def _first_linear_weight(decoder_module):
    """Return the first 2-D ``weight`` tensor found in *decoder_module*, or None.

    ``nn.Module`` decoders are searched depth-first via ``modules()`` so layers
    nested inside sub-containers are found; other iterables are scanned directly.
    """
    candidates = decoder_module.modules() if hasattr(decoder_module, 'modules') else decoder_module
    for sub in candidates:
        weight = getattr(sub, 'weight', None)
        if torch.is_tensor(weight) and weight.dim() == 2:
            return weight
    return None


def plot_training_state(model, last_batch_data, last_batch_labels, epoch, multi_gpu=False,
                       save_dir="./03_results/plots/temp_latent_plots/", device='cuda', verbose=False):
    """
    Plot current training state: weight matrices, embeddings and decoder weights.

    Saves a 3x3 figure to ``<save_dir>/training_state_epoch_<epoch:04d>.png``:
      * row 1 - weight matrices of up to 3 ``model.adaptive_layers``;
      * row 2 - outputs of ``model.encode`` (shared space, then specific spaces),
        with samples sorted by label;
      * row 3 - first 2-D weight of up to 2 ``model.decoders``.
    Panels without content are hidden. The model's train/eval mode is restored
    after encoding, so calling this mid-training does not disable dropout or
    batch-norm updates for later epochs.

    Args:
        model: The model being trained (uses ``adaptive_layers``, ``encode`` and ``decoders``)
        last_batch_data: Last training batch data (list of tensors)
        last_batch_labels: Labels for the last batch (1-D, or one column per label
            level with column 0 as the primary sort key), or None to keep batch order
        epoch: Current epoch number (used in the title and file name)
        multi_gpu: Whether ``model`` is wrapped in DataParallel (``model.module`` is used)
        save_dir: Directory to save plots (created if missing)
        device: Device the batch is moved to before encoding
        verbose: If True, print the path of the saved plot

    Errors while reading individual weight matrices are printed and shown on the
    affected panel; errors from ``model.encode`` propagate. The figure is always closed.
    """
    os.makedirs(save_dir, exist_ok=True)

    # Unwrap DataParallel so attribute access reaches the real model
    model_ref = model.module if multi_gpu else model

    fig, axes = plt.subplots(3, 3, figsize=(15, 15))
    try:
        fig.suptitle(f'Training State - Epoch {epoch}', fontsize=16)

        # Row 1: weight matrices of the adaptive layers (at most 3)
        shown_layers = list(model_ref.adaptive_layers[:3])
        for i, layer in enumerate(shown_layers):
            ax = axes[0, i]
            try:
                weight_matrix = _extract_adaptive_weight(layer)
                if weight_matrix is None:
                    ax.text(0.5, 0.5, f'No weight matrix\nfound for layer {i}',
                            ha='center', va='center', transform=ax.transAxes, fontsize=10)
                else:
                    _plot_diverging_heatmap(fig, ax, weight_matrix)
            except Exception as e:
                # Best effort: report on the panel rather than abort the whole plot
                print(f"Error accessing layer {i} weights: {e}")
                ax.text(0.5, 0.5, f'Error accessing\nweight matrix\nfor layer {i}\n{str(e)[:50]}...',
                        ha='center', va='center', transform=ax.transAxes, fontsize=8)
            ax.set_title(f'Layer {i} Weights\n(Rank: {_layer_rank(shown_layers, i)})')
            ax.set_xlabel('Input Dimensions')
            ax.set_ylabel('Output Dimensions')

        # Row 2: embeddings as heatmaps (samples x latent dimensions)
        was_training = model_ref.training
        model_ref.eval()
        try:
            with torch.no_grad():
                batch_data = [x.to(device) for x in last_batch_data]
                encoded_shared, encoded_specific = model_ref.encode(batch_data)
                embeddings = [encoded_shared] + list(encoded_specific)
        finally:
            model_ref.train(was_training)

        sort_indices = _label_sort_indices(last_batch_labels, len(batch_data[0]))
        shown_embeddings = embeddings[:3]
        for i, embedding in enumerate(shown_embeddings):
            ax = axes[1, i]
            emb_np = _tensor_to_numpy(embedding)[sort_indices]
            _plot_diverging_heatmap(fig, ax, emb_np)
            space_name = 'Shared' if i == 0 else f'Specific {i}'
            ax.set_title(f'{space_name} Space\n(Rank: {_layer_rank(model_ref.adaptive_layers, i)})')
            ax.set_xlabel('Latent Dimensions')
            ax.set_ylabel('Samples (sorted by labels)')

        # Row 3: first-layer weight matrices of the first 2 decoders
        n_decoders = min(2, len(model_ref.decoders))
        for i in range(n_decoders):
            ax = axes[2, i]
            try:
                first_weight = _first_linear_weight(model_ref.decoders[i])
                if first_weight is None:
                    ax.text(0.5, 0.5, f'No linear layer\nfound in\nDecoder {i+1}',
                            ha='center', va='center', transform=ax.transAxes, fontsize=10)
                    ax.set_title(f'Decoder {i+1} - No Linear Layer')
                else:
                    weight_matrix = _tensor_to_numpy(first_weight)
                    _plot_diverging_heatmap(fig, ax, weight_matrix)
                    ax.set_title(f'Decoder {i+1} First Layer\n({weight_matrix.shape[0]}×{weight_matrix.shape[1]})')
                    ax.set_xlabel('Input Dimensions')
                    ax.set_ylabel('Output Dimensions')
            except Exception as e:
                print(f"Error accessing decoder {i} first layer weights: {e}")
                ax.text(0.5, 0.5, f'Error accessing\nDecoder {i+1}\nweight matrix\n{str(e)[:30]}...',
                        ha='center', va='center', transform=ax.transAxes, fontsize=8)
                ax.set_title(f'Decoder {i+1} - Error')

        # Hide every panel that received no content (row 3 column 3 is always unused)
        for row, n_used in enumerate((len(shown_layers), len(shown_embeddings), n_decoders)):
            for col in range(n_used, 3):
                axes[row, col].set_visible(False)

        fig.tight_layout()
        save_path = os.path.join(save_dir, f'training_state_epoch_{epoch:04d}.png')
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
    finally:
        # Always release the figure; this runs every epoch and would otherwise leak
        plt.close(fig)

    if verbose:
        print(f"Saved training state plot to {save_path}")

def create_training_movie(plot_dir="./03_results/plots/temp_latent_plots/",
                         output_path=None,
                         fps=2):
    """
    Create an animated GIF from the training state plots.

    Frames are the ``training_state_epoch_<N>.png`` files in *plot_dir*,
    ordered numerically by the integer epoch ``N``. Files whose suffix is not
    an integer are ignored. After the GIF has been written, the frame files are
    deleted and *plot_dir* is removed if that leaves it empty. Other files in
    *plot_dir* are never touched.

    Args:
        plot_dir: Directory containing the plots
        output_path: Output path for the GIF animation. If None it is
            ``./03_results/plots/latent/<name>_training_progression.gif``, where
            ``<name>`` is the basename of *plot_dir*, or ``training_animation``
            when that basename is ``temp_latent_plots``.
        fps: Frames per second

    Returns:
        The GIF path on success; None if no frames were found or the animation
        could not be written (the error is printed, not raised).
    """
    import glob
    import re

    # Collect frames and sort them by their integer epoch (not lexicographically).
    frame_re = re.compile(r'training_state_epoch_(-?\d+)\.png')
    pattern = os.path.join(glob.escape(plot_dir), 'training_state_epoch_*.png')
    frames = []
    for path in glob.glob(pattern):
        match = frame_re.fullmatch(os.path.basename(path))
        if match:
            frames.append((int(match.group(1)), path))
    if not frames:
        print(f"No training state plots found in {plot_dir}")
        return None
    frames.sort()
    plot_files = [path for _, path in frames]
    print(f"Found {len(plot_files)} training state plots")

    # Auto-generate output path if not provided
    if output_path is None:
        # normpath strips trailing separators on any OS before taking the basename
        model_name = os.path.basename(os.path.normpath(plot_dir))
        if model_name == 'temp_latent_plots':
            model_name = "training_animation"
        output_path = f"./03_results/plots/latent/{model_name}_training_progression.gif"

    # Imported lazily: only needed once there are frames to render
    from matplotlib import animation
    import matplotlib.image as mpimg

    fig = None
    try:
        fig, ax = plt.subplots(figsize=(15, 10))
        ax.axis('off')
        im = ax.imshow(mpimg.imread(plot_files[0]))

        def animate(frame):
            im.set_array(mpimg.imread(plot_files[frame]))
            return [im]

        anim = animation.FuncAnimation(
            fig, animate, frames=len(plot_files),
            interval=1000//fps, blit=True, repeat=True
        )

        # dirname is '' for a bare file name, and os.makedirs('') would raise
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        anim.save(output_path, writer='pillow', fps=fps)
    except Exception as e:
        print(f"Failed to create GIF animation: {e}")
        print(f"Plots are available individually in {plot_dir}")
        return None
    finally:
        # Close the figure on failure too, so repeated calls don't leak figures
        if fig is not None:
            plt.close(fig)

    print(f"Created training animation at {output_path}")

    # Delete only the frames consumed above. shutil.rmtree(plot_dir) would also
    # wipe unrelated files, including the GIF itself if it was saved inside plot_dir.
    try:
        for path in plot_files:
            os.remove(path)
        if not os.listdir(plot_dir):
            os.rmdir(plot_dir)
        print(f"Cleaned up temporary plots from {plot_dir}")
    except OSError as cleanup_error:
        print(f"Warning: Could not clean up temporary plots: {cleanup_error}")

    return output_path
