import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

def _load_causal_lm(name, load_in_8bit):
    """Load the causal LM ``name`` with ``device_map="auto"``.

    When ``load_in_8bit`` is true, 8-bit quantization is attempted first and
    a full-precision load is used as a fallback if that fails.
    """
    if load_in_8bit:
        try:
            model = AutoModelForCausalLM.from_pretrained(
                name,
                quantization_config=BitsAndBytesConfig(load_in_8bit=True),
                device_map="auto",
            )
        except Exception as e:
            print(f"Model loading failed: {e}")
            print("Falling back to full precision...")
        else:
            print("Model Loaded")
            return model
        model = AutoModelForCausalLM.from_pretrained(name, device_map="auto")
        print("Model Loaded with full precision.")
        return model

    model = AutoModelForCausalLM.from_pretrained(name, device_map="auto")
    print("Model Loaded")
    return model


def _load_gemma(model_names, load_in_8bit):
    """Return ``(tokenizer, model)`` for the first candidate that loads.

    Raises:
        ValueError: if no candidate could be loaded. The last underlying
            exception, if any, is chained as the cause.
    """
    last_error = None
    for candidate_name in model_names:
        try:
            print(f"Attempting to load Gemma model: {candidate_name}")
            tokenizer = AutoTokenizer.from_pretrained(candidate_name)
            print("Tokenizer Loaded.")
            model = _load_causal_lm(candidate_name, load_in_8bit)
        except Exception as e:
            # Best-effort: report the failure and try the next candidate.
            print(f"Failed to load {candidate_name}: {e}")
            last_error = e
            continue
        model.eval()
        print(f"Successfully loaded {candidate_name}")
        return tokenizer, model

    error_msg = "Could not load any Gemma model."
    raise ValueError(error_msg) from last_error


def _register_head_hooks(model, head_outputs):
    """Attach a forward hook to every decoder layer's ``self_attn`` module.

    Each hook stores the module output, split into
    ``(batch, n_heads, seq_len, C // n_heads)``, in ``head_outputs[layer_id]``.
    Returns the hook handles so callers may remove them.
    """
    # Hoisted: the head count does not change between forward passes.
    n_heads = model.config.num_attention_heads

    def make_hook(layer_id):
        def hook(module, inputs, output):
            # Attention modules may return a tuple; the first element is the output.
            if isinstance(output, tuple):
                output = output[0]
            batch, seq_len, width = output.size()
            head_dim = width // n_heads
            # NOTE(review): this splits the *output* of self_attn. If that output
            # is taken after the o_proj projection (which mixes heads), these
            # chunks are not true per-head outputs. A forward pre-hook on
            # self_attn.o_proj would capture the pre-projection heads. TODO confirm.
            head_outputs[layer_id] = output.reshape(
                batch, seq_len, n_heads, head_dim
            ).permute(0, 2, 1, 3)

        return hook

    return [
        block.self_attn.register_forward_hook(make_hook(layer_id))
        for layer_id, block in enumerate(model.model.layers)
    ]


def _pad_or_truncate(ids, n, pad_id):
    """Truncate ``ids`` (batch, seq) to ``n`` columns, right-padding with ``pad_id`` if shorter."""
    ids = ids[:, :n]
    missing = n - ids.size(1)
    if missing > 0:
        pad = torch.full(
            (ids.size(0), missing), pad_id, dtype=ids.dtype, device=ids.device
        )
        ids = torch.cat([ids, pad], dim=1)
    return ids


def _model_samples(model, tokenizer, n, device, num_samples, prompt):
    """Sample ``num_samples`` continuations of ``prompt`` and embed them as leaf tensors.

    Returns ``(samples, attention_mask)``. The mask belongs to the last sample,
    or is ``None`` when ``num_samples`` is 0.
    """
    import os

    pad_id = tokenizer.pad_token_id
    samples = []
    attention_mask = None

    # Disable torchdynamo because it causes issues with model sampling (unclear why).
    # Remember whether the variable existed so it can be restored exactly.
    original_torchdynamo = os.environ.get('TORCHDYNAMO_DISABLE')
    os.environ['TORCHDYNAMO_DISABLE'] = '1'
    try:
        # The prompt is the same for every sample, so tokenize it once.
        prompt_inputs = tokenizer(prompt, return_tensors="pt").to(device)
        for i in range(num_samples):
            print(f"Generating sample {i+1} of {num_samples}")
            generated_ids = model.generate(
                **prompt_inputs,
                max_length=n,
                do_sample=True,
                top_k=50,
                temperature=0.7,
                pad_token_id=pad_id,
            )
            # Pad with the same id the mask checks against, so padding is masked out.
            generated_ids = _pad_or_truncate(generated_ids, n, pad_id)

            # Embed without building a graph to the embedding weights; the
            # sample becomes a leaf tensor whose .grad is populated by backward().
            with torch.no_grad():
                x = model.model.embed_tokens(generated_ids)
            samples.append(x.to(device).requires_grad_())

            # 1 for real tokens, 0 for padding.
            attention_mask = (generated_ids != pad_id).long().to(device)
    finally:
        if original_torchdynamo is None:
            os.environ.pop('TORCHDYNAMO_DISABLE', None)
        else:
            os.environ['TORCHDYNAMO_DISABLE'] = original_torchdynamo

    return samples, attention_mask


def _uniform_samples(model, n, device, vocab_size, num_samples):
    """Embed ``num_samples`` uniformly random token sequences of length ``n`` as leaf tensors.

    Returns ``(samples, attention_mask)``. The mask is all ones, or ``None``
    when ``num_samples`` is 0.
    """
    samples = []
    for _ in range(num_samples):
        input_ids = torch.randint(0, vocab_size, (1, n), device=device)
        # Detached leaf, consistent with the "model" branch, so gradients
        # w.r.t. the input embedding accumulate in x.grad.
        with torch.no_grad():
            x = model.model.embed_tokens(input_ids)
        samples.append(x.to(device).requires_grad_())
    # For uniform sampling, all tokens are "real".
    attention_mask = (
        torch.ones((1, n), dtype=torch.long, device=device) if num_samples > 0 else None
    )
    return samples, attention_mask


def gemma_generate_hooks_and_model(
    n,
    device,
    head_outputs,
    sampling="uniform",
    num_samples=10,
    *,
    model_names=None,
    prompt="The weather today is",
    load_in_8bit=False,
):
    """
    Load a Gemma model, register per-layer attention hooks, and build input samples.

    Note: Gemma models are gated and require authentication. You need to:
    1. Request access at https://huggingface.co/google/gemma-2b
    2. Log in with: huggingface-cli login

    Args:
        n: The length (in tokens) of every sample sequence.
        device: The device the samples and attention mask are placed on.
        head_outputs: A dictionary the registered hooks fill.
            Keys are layer indices. Values are the layer's self-attention
            output reshaped to (batch, n_heads, seq_len, C // n_heads).
        sampling: "model" samples continuations of ``prompt`` from the model.
            "uniform" draws token ids uniformly from the vocabulary.
        num_samples: The number of samples to generate.
        model_names: Candidate checkpoints, tried in order.
            Defaults to ("google/gemma-2-2b",).
        prompt: Prompt used when ``sampling == "model"``.
        load_in_8bit: Try 8-bit quantization first, falling back to full
            precision on failure.

    Returns:
        model: The loaded model (in eval mode).
        samples: A list of embedded samples. Each is a leaf tensor with
            ``requires_grad=True``.
        attention_mask: The attention mask, or None if ``num_samples`` is 0.
            NOTE: only the last sample's mask is returned; with
            ``sampling == "model"``, samples may differ in padding.
        params: A tuple (n_heads, n_layers, vocab_size) from the model config.

    Raises:
        ValueError: on an unknown ``sampling`` method (checked before any
            model is loaded), or if no candidate model could be loaded.
    """
    # Validate cheap arguments before downloading/loading a multi-GB model.
    if sampling not in ("model", "uniform"):
        raise ValueError("Invalid sampling method. Use 'model' or 'uniform'.")

    if model_names is None:
        model_names = ("google/gemma-2-2b",)
    tokenizer, model = _load_gemma(model_names, load_in_8bit)

    # Ensure the tokenizer has a pad token.
    if tokenizer.pad_token is None:
        tokenizer.add_special_tokens({'pad_token': tokenizer.eos_token})
        model.config.pad_token_id = tokenizer.eos_token_id

    _register_head_hooks(model, head_outputs)

    config = model.config
    vocab_size = config.vocab_size
    params = (config.num_attention_heads, config.num_hidden_layers, vocab_size)

    if sampling == "model":
        samples, attention_mask = _model_samples(
            model, tokenizer, n, device, num_samples, prompt
        )
    else:
        samples, attention_mask = _uniform_samples(
            model, n, device, vocab_size, num_samples
        )

    return model, samples, attention_mask, params
