import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import numpy as np
import os
import json
import argparse
import logging
from tqdm import tqdm
import time
import random
from collections import defaultdict
import torch.nn.functional as F
import transformers

def setup_logging(log_file="neuron_analysis.log", log_level=logging.INFO):
    """Configure root logging to write to both ``log_file`` and the console.

    Args:
        log_file: Path of the log file opened by a ``logging.FileHandler``.
        log_level: Minimum level passed to ``logging.basicConfig``.

    Returns:
        The logger named after this module.
    """
    file_handler = logging.FileHandler(log_file)
    console_handler = logging.StreamHandler()
    logging.basicConfig(
        level=log_level,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[file_handler, console_handler],
    )
    return logging.getLogger(__name__)

def set_seed(seed=42):
    """Seed Python, NumPy and PyTorch (CPU + all CUDA devices) RNGs and make cuDNN deterministic."""
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

class MultiGPUGradientBasedAnalyzer:
    """Rank individual activation elements of a causal LM by d(loss)/d(activation).

    Forward hooks record the outputs of leaf modules whose names mention MLP /
    attention keywords; the gradient of the language-modeling loss with respect to
    each recorded output is then computed and flattened into per-element records.
    """

    def __init__(self, model_path, device="auto", seed=42, logger=None):
        """
        Initialize multi-GPU gradient-based neuron analyzer

        Args:
            model_path: Model path or hub id passed to ``from_pretrained``.
            device: Stored on ``self.device``; model placement always uses
                ``device_map="auto"`` regardless of this value.
            seed: Random seed applied via ``set_seed``.
            logger: Logger instance (defaults to this module's logger).
        """
        self.logger = logger or logging.getLogger(__name__)

        set_seed(seed)
        self.seed = seed

        self.device = device
        self.num_gpus = torch.cuda.device_count()
        self.logger.info(f"Available GPUs: {self.num_gpus}")

        self._load_model_multi_gpu(model_path)

        # name -> {'activation': tensor, 'device': torch.device}, filled by forward hooks
        self.layer_activations = {}
        # name -> {'gradient': tensor, 'device': torch.device}, filled after the forward pass
        self.layer_gradients = {}
        self.hooks = []
        self.hooks_setup = False

    def _load_model_multi_gpu(self, model_path):
        """Load tokenizer and fp32 model, sharded across visible GPUs with ``device_map="auto"``."""
        self.logger.info(f"Loading model with multi-GPU: {model_path}")

        self.logger.info(f"Transformers version: {transformers.__version__}")
        self.logger.info(f"PyTorch version: {torch.__version__}")
        self.logger.info(f"CUDA available: {torch.cuda.is_available()}")

        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # With no GPUs an empty max_memory dict leaves no device to place weights on;
        # pass None so the loader picks limits itself.
        max_memory = {i: "80GB" for i in range(self.num_gpus)} if self.num_gpus > 0 else None

        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.float32,
            trust_remote_code=True,
            device_map="auto",
            low_cpu_mem_usage=True,
            max_memory=max_memory,
        )

        # eval() disables dropout so gradients are not perturbed by random masks;
        # autograd still tracks the graph because parameters keep requires_grad.
        self.model.eval()
        self.logger.info("Multi-GPU model loading completed")

        self._print_device_map()

    def _print_device_map(self):
        """Log the module -> device mapping when the model exposes ``hf_device_map``."""
        self.logger.info("Model device allocation:")
        if hasattr(self.model, 'hf_device_map'):
            for layer_name, device in self.model.hf_device_map.items():
                self.logger.info(f"  {layer_name}: {device}")
        else:
            self.logger.info("  Device map not available")

    def _setup_hooks(self):
        """Register forward hooks that record outputs of MLP/attention leaf modules (once)."""
        if self.hooks_setup:
            return

        for hook in self.hooks:
            hook.remove()
        self.hooks = []

        def get_activation(name):
            def hook(module, inputs, output):
                if isinstance(output, torch.Tensor):
                    # Record the module's own output WITHOUT detaching it, so the autograd
                    # graph stays connected and upstream hooked modules also get gradients.
                    # A graph-connected clone is handed downstream instead, so later
                    # in-place ops cannot overwrite the recorded tensor.
                    self.layer_activations[name] = {
                        'activation': output,
                        'device': output.device
                    }
                    return output.clone()
            return hook

        module_count = 0

        for name, module in self.model.named_modules():
            if any(keyword in name.lower() for keyword in ['mlp', 'ffn', 'feed_forward', 'attention', 'self_attn']):
                if not list(module.children()) or isinstance(module, (torch.nn.Linear, torch.nn.LayerNorm)):
                    self.hooks.append(module.register_forward_hook(get_activation(name)))
                    module_count += 1

        self.logger.info(f"Set up hooks for {module_count} key modules")
        self.hooks_setup = True

    def _sync_data_across_gpus(self, data_dict):
        """Return ``{name: tensor_on_cpu}`` from an activation or gradient record dict."""
        synced_data = {}
        for name, data_info in data_dict.items():
            data = data_info['activation'] if 'activation' in data_info else data_info['gradient']
            synced_data[name] = data.cpu()
        return synced_data

    def _compute_activation_gradients(self, loss):
        """Fill ``self.layer_gradients`` with d(loss)/d(activation) for every recorded activation.

        Uses ``torch.autograd.grad`` so only the requested gradients are produced;
        parameter ``.grad`` buffers are not populated. Activations that do not
        influence the loss are skipped.
        """
        names, tensors = [], []
        for name, act_info in self.layer_activations.items():
            activation = act_info['activation']
            if activation.requires_grad:
                names.append(name)
                tensors.append(activation)

        if not tensors or not loss.requires_grad:
            return

        grads = torch.autograd.grad(loss, tensors, allow_unused=True)
        for name, grad in zip(names, grads):
            if grad is not None:
                self.layer_gradients[name] = {'gradient': grad.detach(), 'device': grad.device}

    def _collect_neurons(self, layer_name, gradient, threshold=1e-8):
        """Convert one gradient tensor into per-element records with ``abs(value) > threshold``.

        Index keys depend on rank:
          * rank >= 3: ``batch_idx``, ``seq_idx``, ``hidden_idx`` (all dims after the
            second are flattened into ``hidden_idx``; for rank 3 this is the last dim).
          * rank 2: ``dim1_idx``, ``dim2_idx``.
          * rank 0/1: ``flat_idx``.
        Records come out in row-major order; the gradient value is stored under
        ``activation_diff`` (key kept for output-schema compatibility).
        """
        if gradient.numel() == 0:
            return []

        if gradient.dim() >= 3:
            grid = gradient.reshape(gradient.shape[0], gradient.shape[1], -1)
            keys = ("batch_idx", "seq_idx", "hidden_idx")
        elif gradient.dim() == 2:
            grid = gradient
            keys = ("dim1_idx", "dim2_idx")
        else:
            grid = gradient.reshape(-1)
            keys = ("flat_idx",)

        # Compare in float64 so the threshold test matches Python float comparison exactly.
        mask = grid.abs().to(torch.float64) > threshold
        coords = mask.nonzero().tolist()
        values = grid[mask].tolist()

        records = []
        for idx, value in zip(coords, values):
            record = {"layer_name": layer_name}
            record.update(zip(keys, (int(i) for i in idx)))
            record["activation_diff"] = float(value)
            records.append(record)
        return records

    def analyze_gradient_importance(self, input_text, output_file="neuron_gradient_importance.json"):
        """
        Analyze neuron importance based on gradients of the LM loss w.r.t. module outputs

        Args:
            input_text: Input text (also used as labels).
            output_file: JSON file receiving metadata and the top 10000 records.

        Returns:
            List of neuron records sorted by signed gradient value (descending).
        """
        self.logger.info("Starting multi-GPU gradient-based neuron importance analysis")
        self._setup_hooks()
        input_ids = self.tokenizer.encode(input_text, return_tensors="pt")
        input_device = next(self.model.parameters()).device
        input_ids = input_ids.to(input_device)

        self.logger.info(f"Input text length: {len(input_ids[0])} tokens")
        self.logger.info(f"Input device: {input_device}")

        labels = input_ids.clone()

        self.layer_activations.clear()
        self.layer_gradients.clear()

        self.logger.info("Computing forward pass and collecting activations")

        outputs = self.model(input_ids, labels=labels)
        loss = outputs.loss
        perplexity = torch.exp(loss.detach())

        self.logger.info(f"Loss: {loss.item():.4f}, Perplexity: {perplexity.item():.4f}")

        self.logger.info("Computing backward pass to get gradients")
        self._compute_activation_gradients(loss)

        gradients = self._sync_data_across_gpus(self.layer_gradients)

        self.logger.info(f"Collected activations from {len(self.layer_activations)} modules")
        self.logger.info(f"Collected gradients from {len(gradients)} modules")

        neuron_importance = []

        self.logger.info("Computing neuron importance based on gradient values")
        for layer_name, gradient in tqdm(gradients.items(), desc="Processing layers"):
            neuron_importance.extend(self._collect_neurons(layer_name, gradient))

        self.logger.info(f"Found {len(neuron_importance)} neurons with meaningful gradients")

        self.logger.info("Sorting neurons by gradient value (descending)")
        sorted_neurons = sorted(neuron_importance, key=lambda x: x["activation_diff"], reverse=True)

        self.logger.info(f"Saving results to {output_file}")
        result_data = {
            "method": "multi_gpu_gradient_based_analysis",
            "input_text": input_text,
            "noise_scale": None,
            "num_samples": 1,
            "num_gpus_used": self.num_gpus,
            "loss": float(loss.item()),
            "perplexity": float(perplexity.item()),
            "total_neurons": len(sorted_neurons),
            "top_neurons": sorted_neurons[:10000]
        }

        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(result_data, f, indent=2)

        self.logger.info("Multi-GPU gradient-based analysis completed")
        self.logger.info(f"Total discovered neurons: {len(sorted_neurons)}")
        self.logger.info(f"Used {self.num_gpus} GPUs for computation")

        self.logger.info("Top 10 neurons with highest gradient values:")
        for i, neuron in enumerate(sorted_neurons[:10]):
            neuron_info = f"Rank {i+1}: Layer {neuron['layer_name']}, Gradient value: {neuron['activation_diff']:.8f}"
            for k, v in neuron.items():
                if k not in ['layer_name', 'activation_diff']:
                    neuron_info += f", {k}: {v}"
            self.logger.info(neuron_info)

        return sorted_neurons

def main():
    """CLI entry point: parse arguments, run the gradient analysis, and log a summary."""
    parser = argparse.ArgumentParser(description="Multi-GPU Gradient-Based Neuron Importance Analysis")
    add = parser.add_argument
    add("--model_path", type=str, required=True, help="Model path")
    add("--input_text", type=str, required=True, help="Input text")
    add("--seed", type=int, default=42, help="Random seed (default: 42)")
    add("--output_file", type=str, default="neuron_gradient_importance.json", help="Output file name")
    add("--log_file", type=str, default="neuron_gradient_analysis.log", help="Log file name")
    add("--log_level", type=str, default="INFO", choices=["DEBUG", "INFO", "WARNING", "ERROR"], help="Log level")
    add("--device", type=str, default="auto", help="Computing device (use 'auto' for multi-GPU)")
    args = parser.parse_args()

    logger = setup_logging(args.log_file, getattr(logging, args.log_level.upper()))

    visible_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
    header_lines = (
        "=== Multi-GPU Gradient-Based Neuron Importance Analyzer ===",
        f"Model path: {args.model_path}",
        f"Input text: {args.input_text[:100]}...",
        f"Random seed: {args.seed}",
        f"Device: {args.device}",
        f"Available GPUs: {visible_gpus}",
    )
    for line in header_lines:
        logger.info(line)

    analyzer = MultiGPUGradientBasedAnalyzer(
        model_path=args.model_path,
        device=args.device,
        seed=args.seed,
        logger=logger,
    )
    neurons = analyzer.analyze_gradient_importance(
        input_text=args.input_text,
        output_file=args.output_file,
    )

    for line in (
        "=== Multi-GPU Gradient-Based Analysis Completed ===",
        f"Discovered {len(neurons)} neurons",
        f"Results saved to: {args.output_file}",
        f"Log saved to: {args.log_file}",
    ):
        logger.info(line)

# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
