#!/usr/bin/env python
"""
ARQ rotation utilities with real activation capture
This properly extracts activations from the model during calibration
"""

import torch
import torch.nn as nn
import numpy as np
from typing import Optional, Dict, List
import logging
from tqdm import tqdm
import utils
import model_utils
from arq.butterfly_transform_fake import FakeButterflyTransform
from arq.rotation_learning_sgd import RotationLearnerSGD
from arq.losses import MultiObjectiveLoss


def capture_model_activations(model, calibration_loader, args, device='cuda'):
    """
    Capture real activations from the model's embedding layer

    The embedding layer is moved to ``device``, run on at most
    ``args.arq_calib_batches`` batches while a forward hook records its
    output, and is then moved back to the device it started on. The hook is
    always removed, even if a batch fails.

    Args:
        model: The model to capture activations from. Must expose
            ``config.hidden_size`` and be classified by
            ``model_utils.model_type_extractor`` as a LLaMA or OPT model.
        calibration_loader: Iterable of calibration batches. A batch may be a
            dict with an ``'input_ids'`` entry, a list/tuple whose first item
            is the input ids, or the input ids tensor itself. It does not
            need to support ``len()``.
        args: Arguments containing ARQ hyperparameters
            (``arq_calib_batches`` is read here).
        device: Device to use

    Returns:
        List of activation tensors on ``device``, one per processed batch,
        each reshaped to ``(-1, hidden_size)``.

    Raises:
        ValueError: If the model type is neither LLaMA nor OPT.

    Note:
        ``model.eval()`` is called and the model is left in eval mode.
    """
    logging.info("Capturing real model activations...")

    # Get model type and config
    model_type = model_utils.model_type_extractor(model)
    hidden_size = model.config.hidden_size

    # Locate the embedding layer for the supported architectures
    if model_type == model_utils.LLAMA_MODEL:
        embed_layer = model.model.embed_tokens
    elif model_type == model_utils.OPT_MODEL:
        embed_layer = model.model.decoder.embed_tokens
    else:
        raise ValueError(f'Unknown model type {model_type}')

    # Cap the number of batches; loaders without __len__ (e.g. iterable
    # datasets) are bounded by arq_calib_batches alone.
    max_samples = args.arq_calib_batches
    try:
        max_samples = min(max_samples, len(calibration_loader))
    except TypeError:
        pass

    # Remember where the layer lives so it can be restored afterwards instead
    # of being forced onto the CPU. A parameter-less layer falls back to CPU.
    first_param = next(embed_layer.parameters(), None)
    original_device = first_param.device if first_param is not None else torch.device('cpu')

    # Storage for activations (kept on CPU while capturing)
    activations = []

    def hook_fn(module, inputs, output):
        # Store a detached CPU copy of the embedding output
        activations.append(output.detach().cpu())

    hook_handle = None
    try:
        # nn.Module.to() moves the module in place
        embed_layer.to(device)
        hook_handle = embed_layer.register_forward_hook(hook_fn)

        # Disable gradient computation
        model.eval()
        with torch.no_grad():
            progress = tqdm(calibration_loader, total=max_samples, desc="Capturing activations")
            for i, batch in enumerate(progress):
                if i >= max_samples:
                    break

                # Extract input tensors
                if isinstance(batch, dict):
                    input_ids = batch['input_ids']
                elif isinstance(batch, (list, tuple)):
                    input_ids = batch[0]
                else:
                    input_ids = batch

                # Move to device and run the embedding; the hook records the output
                embed_layer(input_ids.to(device))
    finally:
        # Always detach the hook and put the layer back where it was
        if hook_handle is not None:
            hook_handle.remove()
        embed_layer.to(original_device)

    logging.info(f"Captured {len(activations)} batches of activations")
    if not activations:
        logging.warning("No activations were captured; check arq_calib_batches and the calibration loader")

    # Move activations to device and flatten batch and sequence dimensions
    return [act.to(device).reshape(-1, hidden_size) for act in activations]


def learn_adaptive_rotation_real_acts(
    dim: int,
    device: str,
    model,
    calibration_loader,
    args
) -> torch.Tensor:
    """
    Learn an adaptive rotation matrix using real model activations

    Steps: pick a butterfly transform for ``dim``, capture embedding
    activations via ``capture_model_activations``, optimise the transform
    with either the SGD learner (``args.arq_optimizer == 'sgd'``) or the
    Adam-based learner, then log the orthogonality error of the result.

    Args:
        dim: Dimension of rotation matrix
        device: Device to use
        model: The model whose embedding activations are captured
        calibration_loader: Calibration data loader
        args: Arguments containing ARQ hyperparameters. Read directly:
            ``arq_steps``, ``arq_lr``, ``arq_lambda_ortho``,
            ``arq_lambda_entropy``, ``w_bits`` and (via
            ``capture_model_activations``) ``arq_calib_batches``. All other
            ``arq_*`` options are optional and fall back to defaults.

    Returns:
        Learned rotation matrix Q

    Raises:
        ValueError: In butterfly mode, if ``dim`` is neither a positive
            power of two nor 5120.
    """
    logging.info(f"Learning adaptive rotation for dimension {dim}")

    # Choose transform type
    if getattr(args, 'arq_butterfly_mode', False):
        init_mode = getattr(args, 'arq_butterfly_init', 'hadamard')

        # dim > 0 guard: 0 would otherwise pass the bit trick below
        is_pow2 = dim > 0 and (dim & (dim - 1)) == 0

        if is_pow2:
            # Use standard fake butterfly for power-of-2 dimensions
            transform = FakeButterflyTransform(dim, device=device, init_mode=init_mode)
            logging.info(f"Using Fake Butterfly transform with {init_mode} initialization")
        elif dim == 5120:
            # For 5120, use fixed composite butterfly (40×128 like QuaRot)
            from arq.butterfly_transform_composite_fixed import CompositeButterflyFixed
            transform = CompositeButterflyFixed(dim, device=device, init_mode=init_mode)
            logging.info(f"Using Fixed Composite Butterfly (40×128) for dim={dim} with {init_mode} initialization")
        else:
            # For other non-power-of-2 dimensions, raise error (not supported yet)
            raise ValueError(f"Butterfly transform not supported for dimension {dim}. Only power-of-2 and 5120 are supported.")

        num_params = transform.count_parameters()
        logging.info(f"  Parameters: {num_params} (compression: {dim*dim/num_params:.1f}x)")
    else:
        # If not butterfly mode, just use standard fake butterfly as fallback
        transform = FakeButterflyTransform(dim, device=device, init_mode='hadamard')
        logging.info("Using standard Fake Butterfly transform (fallback)")

    # Capture real activations from model
    calibration_data = capture_model_activations(model, calibration_loader, args, device)

    # Check if we should use SGD optimizer
    use_sgd = getattr(args, 'arq_optimizer', None) == 'sgd'

    if use_sgd:
        # Use SGD-based learner
        momentum = getattr(args, 'arq_momentum', 0.9)
        logging.info(f"Using SGD optimizer with lr={args.arq_lr}, momentum={momentum}")

        # Setup loss function
        loss_fn = MultiObjectiveLoss(
            lambda_quant=getattr(args, 'arq_lambda_quant', 1.0),
            lambda_ortho=args.arq_lambda_ortho,
            lambda_entropy=args.arq_lambda_entropy,
            lambda_sparsity=getattr(args, 'arq_lambda_sparsity', 0.0),
            bits=args.w_bits,
            sym=True,
            sparsity_type=getattr(args, 'arq_sparsity_type', 'l1'),
            # Uniformity regularization
            gamma_uni=getattr(args, 'arq_gamma_uni', 0.0),
            curriculum=getattr(args, 'arq_curriculum', False),
            warmup_steps=getattr(args, 'arq_dist_warmup', 50),
            total_steps=args.arq_steps
        )

        # Create SGD learner
        learner = RotationLearnerSGD(
            transform=transform,
            loss_fn=loss_fn,
            lr=args.arq_lr,
            momentum=momentum,
            steps=args.arq_steps,
            lr_scheduler=getattr(args, 'arq_lr_scheduler', 'none'),
            lr_min=getattr(args, 'arq_lr_min', 0.01),
            warmup_steps=getattr(args, 'arq_warmup_steps', 0)
        )

        # Create data loader. The batch size is clamped to at least 1:
        # with fewer than four rows, N // 4 is 0 and DataLoader rejects it.
        all_data = torch.cat(calibration_data, dim=0)
        dataset = torch.utils.data.TensorDataset(all_data)
        data_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=max(1, min(256, all_data.shape[0] // 4)),
            shuffle=True
        )

        # Learn rotation
        logging.info(f"Learning rotation with {args.arq_steps} steps...")
        Q = learner.learn(data_loader)

        # learn() only hands back Q here, so there is no loss history to
        # report; do not fabricate zeros.
        history = None
    else:
        # Use original Adam-based learner
        # NOTE(review): LearnableRotationFixed is not imported in the visible
        # part of this file; unless it is injected elsewhere this branch
        # raises NameError. Confirm its module and add the import.
        rotation = LearnableRotationFixed(transform)

        logging.info(f"Learning rotation with {args.arq_steps} steps...")
        history = rotation.learn_rotation(
            calibration_data,
            num_steps=args.arq_steps,
            lr=args.arq_lr,
            constraint_type='soft',  # We handle orthogonality in the transform itself
            bits=args.w_bits,
            # Honour --arq_lambda_quant like the SGD path (default stays 1.0)
            lambda_quant=getattr(args, 'arq_lambda_quant', 1.0),
            lambda_ortho=args.arq_lambda_ortho,
            lambda_entropy=args.arq_lambda_entropy
        )

        Q = rotation.get_matrix()

    # Log final statistics
    if history is not None:
        logging.info("\nFinal losses:")
        logging.info(f"  Total: {history['total_loss'][-1]:.6f}")
        logging.info(f"  Quantization: {history['quant_loss'][-1]:.6f}")
        logging.info(f"  Orthogonality: {history['ortho_loss'][-1]:.6f}")
        logging.info(f"  Entropy: {history['entropy_loss'][-1]:.6f}")
    else:
        logging.info("\nFinal losses: not available from the SGD learner's return value")

    # Verify orthogonality
    with torch.no_grad():
        QTQ = Q.t() @ Q
        # Build the identity on Q's own device so the subtraction cannot
        # fail on a device mismatch.
        identity = torch.eye(dim, device=Q.device, dtype=Q.dtype)
        ortho_error = torch.norm(QTQ - identity).item()
        logging.info(f"  Final orthogonality error: ||Q^T Q - I|| = {ortho_error:.8f}")

        # Check distance from Hadamard (only if transform has hadamard_base)
        if hasattr(transform, 'hadamard_base'):
            H = transform.hadamard_base
            distance = torch.norm(Q - H).item()
            logging.info(f"  Distance from Hadamard: {distance:.6f}")

    return Q


def integrate_arq_rotation_real_acts(args):
    """
    Fill in any ARQ settings that are missing from ``args``.

    Attributes already present on ``args`` are left untouched; absent ones
    receive the fallback values listed below. ``args`` is modified in place
    and nothing is returned. Per the original author's note, this is meant
    to run before model loading.
    """
    # (attribute, fallback) pairs, applied in this order. 'rotate_mode' comes
    # first because ARQ is treated as just another rotation mode.
    fallback_values = (
        ('rotate_mode', 'arq'),
        ('arq_steps', 20),
        ('arq_lr', 0.01),
        ('arq_transform_type', 'hybrid'),
        ('arq_calib_batches', 4),
        ('arq_lambda_ortho', 0.1),
        ('arq_lambda_entropy', 0.01),
    )
    for attr_name, fallback in fallback_values:
        if hasattr(args, attr_name):
            continue
        setattr(args, attr_name, fallback)

    logging.info(f"ARQ rotation mode enabled with {args.arq_steps} learning steps")
    logging.info("Using real model activations for calibration")


def add_arq_args(parser):
    """
    Register the ARQ command line options on ``parser``.

    All options are placed in a dedicated argument group titled
    'ARQ (Adaptive Rotation Quantization)'. The same parser object is
    returned so the call can be chained.
    """
    group = parser.add_argument_group('ARQ (Adaptive Rotation Quantization)')

    # Each entry is (flag, keyword arguments for add_argument); the order
    # here is the order in which the options are registered.
    option_specs = [
        # Core learning options
        ('--arq_steps', dict(type=int, default=20,
                             help='Number of rotation learning steps')),
        ('--arq_lr', dict(type=float, default=0.01,
                          help='Learning rate for rotation optimization')),
        ('--arq_transform_type', dict(type=str, default='hybrid',
                                      choices=['simple', 'hybrid'],
                                      help='Type of butterfly transform to use')),
        ('--arq_lambda_ortho', dict(type=float, default=0.1,
                                    help='Weight for orthogonality loss')),
        ('--arq_lambda_entropy', dict(type=float, default=0.01,
                                      help='Weight for entropy loss')),
        ('--arq_optimizer', dict(type=str, default='adam',
                                 choices=['adam', 'sgd'],
                                 help='Optimizer to use for rotation learning')),
        ('--arq_momentum', dict(type=float, default=0.9,
                                help='Momentum for SGD optimizer')),
        ('--arq_init_scale', dict(type=float, default=0.0,
                                  help='Scale for initial random perturbation (0 = no perturbation)')),
        ('--arq_lambda_quant', dict(type=float, default=1.0,
                                    help='Weight for quantization loss')),
        ('--arq_calib_batches', dict(type=int, default=4,
                                     help='Number of calibration batches for learning')),
        # Learning rate scheduler
        ('--arq_lr_scheduler', dict(type=str, default='none',
                                    choices=['none', 'cosine', 'warmup_cosine', 'linear'],
                                    help='Learning rate scheduler type')),
        ('--arq_lr_min', dict(type=float, default=0.01,
                              help='Minimum learning rate for scheduler')),
        ('--arq_warmup_steps', dict(type=int, default=0,
                                    help='Warmup steps for warmup_cosine scheduler')),
        # Butterfly mode
        ('--arq_butterfly_mode', dict(action='store_true',
                                      help='Use fake butterfly transform with O(n log n) parameters')),
        ('--arq_butterfly_init', dict(type=str, default='hadamard',
                                      choices=['identity', 'random', 'hadamard'],
                                      help='Initialization mode for butterfly transform')),
        # Sparsity regularization
        ('--arq_lambda_sparsity', dict(type=float, default=0.0,
                                       help='Weight for sparsity loss (L1/L2 regularization on angles)')),
        ('--arq_sparsity_type', dict(type=str, default='l1',
                                     choices=['l1', 'l2', 'huber'],
                                     help='Type of sparsity regularization')),
        # Uniformity regularization
        ('--arq_gamma_uni', dict(type=float, default=0.0,
                                 help='Weight for uniformity loss (encourages even bin usage)')),
        ('--arq_curriculum', dict(action='store_true',
                                  help='Enable curriculum learning for uniformity regularization')),
        ('--arq_dist_warmup', dict(type=int, default=50,
                                   help='Warmup steps before uniformity regularization starts')),
    ]

    for flag, spec in option_specs:
        group.add_argument(flag, **spec)

    return parser