import argparse
import os
import sys
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

# Add parent directory to path for imports
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
sys.path.append(parent_dir)

# Import the Localizer, the GradientExtractor, and the dattri benchmark helpers
from _Localizer.Localizer import Localizer
from _Localizer.GradientExtractor import GradientExtractor
from _dattri.benchmark.load import load_benchmark
from _dattri.benchmark.utils import SubsetSampler
from _dattri.benchmark.models.MusicTransformer.utilities.constants import TOKEN_PAD

def parse_args(argv=None):
    """Parse the command-line options of the localization script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, so existing
            ``parse_args()`` callers behave exactly as before. Passing an
            explicit list makes the parser usable from tests and notebooks.

    Returns:
        argparse.Namespace with the attributes ``device``, ``localize``,
        ``epoch``, ``log_interval``, ``loc_n``, ``learning_rate``,
        ``regularization``, ``early_stop``, ``output_dir``, ``cpu_offload``
        and ``batch_size``.
    """
    # The script loads the MusicTransformer/MAESTRO benchmark, so the
    # description names that pair (it previously said "MLP on MNIST").
    parser = argparse.ArgumentParser(
        description="Localize important gradient components for MusicTransformer on MAESTRO"
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="device to be used"
    )
    parser.add_argument(
        "--localize",
        type=int,
        default=100,
        help="Total number of localized parameters across the model."
    )
    parser.add_argument(
        "--epoch",
        type=int,
        default=2000,
        help="Number of epochs for learning the localized parameters."
    )
    parser.add_argument(
        "--log_interval",
        type=int,
        default=50,
        help="Interval for logging the training process."
    )
    parser.add_argument(
        "--loc_n",
        type=int,
        default=200,
        help="Number of training samples used for training localized parameters."
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=5e-5,
        help="Initial learning rate (after the potential warmup period) to use."
    )
    parser.add_argument(
        "--regularization",
        type=float,
        default=5.0,
        help="Lambda for the regularization term."
    )
    parser.add_argument(
        "--early_stop",
        type=float,
        default=0.9,
        help="The correlation threshold for early stopping."
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./Localize/",
        help="Directory to save the localization results"
    )
    parser.add_argument(
        "--cpu_offload",
        action="store_true",
        help="Whether to offload tensors to CPU to save GPU memory"
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=32,
        help="Batch size for data loading"
    )

    return parser.parse_args(argv)

def setup_logger():
    """Set up INFO-level logging and hand back this module's logger.

    Returns:
        The ``logging.Logger`` registered under this module's ``__name__``.
    """
    import logging

    # Records are rendered as "<date time> - <level> - <logger name> - <message>".
    logging.basicConfig(
        level=logging.INFO,
        datefmt="%m/%d/%Y %H:%M:%S",
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    )
    return logging.getLogger(__name__)

def create_dataloaders(model_details, args):
    """Build the two dataloaders used for localization.

    The first ``args.loc_n`` indices of ``model_details["train_dataset"]``
    are split 80/20: the leading 80% feed the first loader and the remaining
    20% feed the second, held-out loader. Both loaders read from the same
    training dataset and use ``args.batch_size``.

    Returns:
        A ``(train_dataloader, test_dataloader)`` tuple.
    """
    dataset = model_details["train_dataset"]

    # Position of the 80/20 cut inside the first ``loc_n`` sample indices.
    candidate_indices = list(range(args.loc_n))
    cut = int(args.loc_n * 0.8)
    index_splits = (candidate_indices[:cut], candidate_indices[cut:])

    # One loader per split, each restricted to its indices by a SubsetSampler.
    train_dataloader, test_dataloader = (
        DataLoader(
            dataset,
            batch_size=args.batch_size,
            sampler=SubsetSampler(split),
        )
        for split in index_splits
    )
    return train_dataloader, test_dataloader

def _resolve_device(device_name):
    """Translate the ``--device`` string into a ``torch.device``.

    Args:
        device_name: ``"cpu"`` or a CUDA device string such as ``"cuda"`` or
            ``"cuda:1"``.

    Returns:
        The corresponding ``torch.device``.

    Raises:
        ValueError: If a CUDA device is requested while CUDA is unavailable,
            or if ``device_name`` is neither a CUDA device nor ``"cpu"``.
    """
    if device_name.startswith("cuda"):
        if not torch.cuda.is_available():
            raise ValueError("CUDA is not available. Please check your CUDA installation.")
        device = torch.device(device_name)
        # torch.cuda.set_device needs an explicit index; a bare "cuda" device
        # has index None, so only pin the current device when one was given.
        if device.index is not None:
            torch.cuda.set_device(device)
        return device

    # Explicit exception instead of ``assert`` so the check survives ``python -O``.
    if device_name != "cpu":
        raise ValueError("Invalid device. Choose from 'cuda' or 'cpu'.")
    return torch.device("cpu")


def main():
    """Run gradient-mask localization for MusicTransformer on MAESTRO.

    Pipeline: parse the CLI options, load the benchmark model, extract
    gradients on a small train/held-out split, train a ``Localizer`` mask
    that keeps ``--localize`` gradient coordinates, report which parameters
    the kept coordinates belong to, and save everything to
    ``<output_dir>/mask_<localize>/result.pt``.

    Raises:
        ValueError: If the requested device is invalid or unavailable.
    """
    args = parse_args()
    logger = setup_logger()

    # Create output directory
    output_dir = os.path.join(args.output_dir, f"mask_{args.localize}")
    os.makedirs(output_dir, exist_ok=True)

    # Load MAESTRO + MusicTransformer benchmark
    logger.info("Loading MAESTRO + MusicTransformer benchmark...")
    model_details, _ = load_benchmark(model="musictransformer", dataset="maestro", metric="lds")

    # Get model and move to device
    model = model_details["model"]
    device = _resolve_device(args.device)
    model = model.to(device)
    model.eval()  # Set model to evaluation mode

    # Calculate total parameter count
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info(f"Model has {total_params} trainable parameters")

    # The mask cannot keep more coordinates than there are trainable parameters
    if args.localize > total_params:
        logger.warning(f"Localization target ({args.localize}) exceeds total parameters ({total_params}). Setting to {total_params}.")
        args.localize = total_params

    # Create dataloaders for localization
    train_dataloader, test_dataloader = create_dataloaders(model_details, args)

    # Define custom loss function for gradient extraction
    def custom_loss_fn(outputs, batch):
        """Return the batch-mean logit-scale score of the last-token prediction.

        ``outputs`` are the model outputs for ``batch`` (an ``(x, y)`` pair);
        only the last position of ``outputs`` and of ``y`` is used.
        """
        # The inputs are not needed here: the extractor already ran the model.
        _, y = batch
        y = y.to(device)

        # Get the last token prediction and target
        output_last = outputs[:, -1, :]
        y_last = y[:, -1]

        # Per-sample cross entropy, skipping targets equal to TOKEN_PAD
        loss_fn = nn.CrossEntropyLoss(ignore_index=TOKEN_PAD, reduction='none')

        # Log-probability of the target token (negated cross entropy)
        logp = -loss_fn(output_last, y_last)

        # Convert to logit scale: log(p / (1 - p)) = logp - log(1 - exp(logp)).
        # The clamp keeps logp <= 0 so that 1 - exp(logp) never goes negative.
        # NOTE(review): a logp of exactly 0 still gives log(0) = -inf here and
        # hence a +inf score; confirm that cannot happen (e.g. a padded last
        # token) for the data used.
        logp_safe = torch.clamp(logp, max=0.0)
        logit_func = logp_safe - torch.log(1 - torch.exp(logp_safe))

        # Return the mean for the batch to get a scalar loss
        return torch.mean(logit_func)

    # Initialize gradient extractor
    logger.info("Initializing gradient extractor...")
    extractor = GradientExtractor(
        model=model,
        device=device,
        cpu_offload=args.cpu_offload
    )

    # Extract gradients for the entire model at once
    logger.info("Extracting gradients for the entire model...")
    train_gradients, test_gradients = extractor.extract_gradients(
        train_dataloader=train_dataloader,
        test_dataloader=test_dataloader,
        custom_loss_fn=custom_loss_fn
    )

    if train_gradients is None or test_gradients is None:
        logger.error("Failed to extract gradients. Exiting.")
        return

    # Get gradient tensors
    train_grad_tensor = train_gradients['gradient']
    test_grad_tensor = test_gradients['gradient']

    # Second axis of the gradient tensor is used as the mask dimension
    gradient_dim = train_grad_tensor.shape[1]

    logger.info(f"Extracted gradients - Shape: {train_grad_tensor.shape}, Total parameters: {gradient_dim}")

    # Initialize the gradient localizer for the entire model
    logger.info("Training the gradient mask optimizer...")
    localizer = Localizer(
        gradient_dim=gradient_dim,
        lambda_reg=args.regularization,
        lr=args.learning_rate,
        min_active_gradient=args.localize,
        max_active_gradient=args.localize,  # Exact target, no flexibility
        device=device,
        logger=logger
    )

    # Train the localizer
    eval_metrics = localizer.train(
        train_gradients=train_grad_tensor,
        test_gradients=test_grad_tensor,
        batch_size=3000,
        num_epochs=args.epoch,
        log_every=args.log_interval,
        correlation_threshold=args.early_stop
    )

    # Get important indices
    logger.info("Retrieving important gradient indices...")
    important_indices = localizer.get_important_indices(
        threshold=0.5,
        min_count=args.localize
    )
    active_indices = important_indices['gradient']

    # Calculate sparsity
    effective_params = len(active_indices)
    sparsity = 100 - (effective_params / gradient_dim * 100)

    logger.info("Results:")
    logger.info(f"Gradient mask: {effective_params}/{gradient_dim} parameters ({effective_params/gradient_dim*100:.2f}%)")
    logger.info(f"Sparsity achieved: {sparsity:.2f}%")
    logger.info(f"Correlation preserved: {eval_metrics['avg_rank_correlation']:.4f}")

    # Convert to tensor
    active_indices_tensor = torch.tensor(active_indices, dtype=torch.long)

    # Save the important indices
    results = {
        "active_indices": active_indices_tensor,
        "total_params": gradient_dim,
        "active_params": effective_params,
        "sparsity": sparsity,
        "correlation": eval_metrics['avg_rank_correlation'],
        "args": vars(args)
    }

    # Get parameter mapping (for analysis)
    param_map = extractor.get_param_to_indices_map()
    results["param_map"] = param_map

    # Analyze which parameters were selected
    logger.info("\nParameter-wise analysis:")
    param_stats = {}
    for name, (start_idx, end_idx) in param_map.items():
        # Count how many active indices fall inside this parameter's range
        param_active_count = sum(1 for idx in active_indices if start_idx <= idx < end_idx)
        param_total = end_idx - start_idx
        param_sparsity = 100 - (param_active_count / param_total * 100) if param_total > 0 else 0
        # Guard against an empty selection, which would otherwise divide by zero
        share_of_active = param_active_count / effective_params * 100 if effective_params > 0 else 0.0

        param_stats[name] = {
            "active": param_active_count,
            "total": param_total,
            "sparsity": param_sparsity,
            "percentage": share_of_active
        }

        if param_active_count > 0:
            logger.info(f"Parameter {name}: {param_active_count}/{param_total} active "
                      f"({param_active_count/param_total*100:.2f}% dense, "
                      f"{share_of_active:.2f}% of all active)")

    # Add parameter stats to results
    results["param_stats"] = param_stats

    # Save results
    output_file = os.path.join(output_dir, 'result.pt')
    torch.save(results, output_file)
    logger.info(f"Results saved to {output_file}")

    # Clear memory
    del train_grad_tensor, test_grad_tensor, localizer
    del train_gradients, test_gradients
    if device.type == "cuda":
        torch.cuda.empty_cache()

    logger.info("Localization completed successfully!")

# Run the localization pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()