"""Convert Llama 3.2 1B pretrained weights from PyTorch to Flax format."""

import argparse
from pathlib import Path
from typing import Dict

import jax
import jax.numpy as jnp
import numpy as np
import torch
from flax.traverse_util import flatten_dict, unflatten_dict
from transformers import AutoModelForCausalLM
import pickle

from fma_llama.model.config import LlamaConfig
from fma_llama.model.llama import LlamaForCausalLM


def initialize_flax_model(config: LlamaConfig):
    """Build the Flax model and abstractly trace its ``init`` to get the parameter tree.

    ``jax.eval_shape`` evaluates ``model.init`` without allocating weights, so
    the returned tree holds abstract leaves (shape/dtype descriptors, possibly
    wrapped in partitioning boxes) instead of real arrays.

    Args:
        config: Llama configuration used to construct ``LlamaForCausalLM``.

    Returns:
        A ``(model, params)`` tuple: the Flax module and its abstract
        variables as returned by ``jax.eval_shape(model.init, ...)``.
    """
    from jax import ShapeDtypeStruct

    flax_model = LlamaForCausalLM(config)

    # Abstract int32 token-id input of shape (1, 10) — presumably
    # (batch, seq_len); only the shape/dtype matter for tracing.
    abstract_tokens = ShapeDtypeStruct((1, 10), jnp.int32)
    init_key = jax.random.PRNGKey(0)

    abstract_params = jax.eval_shape(flax_model.init, init_key, abstract_tokens)
    return flax_model, abstract_params


# Per-decoder-layer PyTorch parameter suffix (the part after ``layers.{i}.``)
# -> (Flax path inside ``LlamaModel_0/LlamaDecoderLayer_{i}``, transpose?).
# PyTorch ``nn.Linear`` stores weights as (out_features, in_features) while
# Flax ``Dense`` kernels are (in_features, out_features), hence the transpose.
_LAYER_PARAM_MAP = {
    'self_attn.q_proj.weight': ('FMAAttention_0/Dense_0/kernel', True),
    'self_attn.k_proj.weight': ('FMAAttention_0/Dense_1/kernel', True),
    'self_attn.v_proj.weight': ('FMAAttention_0/Dense_2/kernel', True),
    'self_attn.o_proj.weight': ('FMAAttention_0/Dense_3/kernel', True),
    'mlp.gate_proj.weight': ('LlamaMLP_0/Dense_0/kernel', True),
    'mlp.up_proj.weight': ('LlamaMLP_0/Dense_1/kernel', True),
    'mlp.down_proj.weight': ('LlamaMLP_0/Dense_2/kernel', True),
    'input_layernorm.weight': ('RMSNorm_0/weight', False),
    'post_attention_layernorm.weight': ('RMSNorm_1/weight', False),
}

# Non-layer PyTorch parameter names (after stripping ``model.``)
# -> (full Flax path, transpose?).
_TOP_LEVEL_PARAM_MAP = {
    # Embedding table (inside LlamaModel).
    'embed_tokens.weight': ('LlamaModel_0/Embed_0/embedding', False),
    # Final layer norm (inside LlamaModel).
    'norm.weight': ('LlamaModel_0/RMSNorm_0/weight', False),
    # LM head (outside LlamaModel, directly in LlamaForCausalLM).
    'lm_head.weight': ('Dense_0/kernel', True),
}


def _pytorch_to_flax_key(name: str) -> 'tuple[str, bool] | None':
    """Map a PyTorch parameter name to ``(flax_path, transpose)``, or ``None`` if unmapped."""
    if name.startswith('model.'):
        name = name[len('model.'):]

    if name.startswith('layers.'):
        parts = name.split('.', 2)
        if len(parts) != 3 or not parts[1].isdecimal():
            return None
        _, layer_idx, suffix = parts
        mapped = _LAYER_PARAM_MAP.get(suffix)
        if mapped is None:
            return None
        flax_subpath, transpose = mapped
        return f'LlamaModel_0/LlamaDecoderLayer_{int(layer_idx)}/{flax_subpath}', transpose

    return _TOP_LEVEL_PARAM_MAP.get(name)


def _abstract_shape(param):
    """Return the shape of an array/abstract array or of a box exposing ``.value``; else ``None``."""
    if hasattr(param, 'shape'):
        return tuple(param.shape)
    # Partitioned-style boxes wrap the array in a ``.value`` attribute.
    value = getattr(param, 'value', None)
    if hasattr(value, 'shape'):
        return tuple(value.shape)
    return None


def _to_numpy(tensor) -> np.ndarray:
    """Convert a torch tensor to a host NumPy array.

    Detaches first (so tensors requiring grad work) and upcasts bfloat16 to
    float32, because NumPy has no bfloat16 dtype and ``Tensor.numpy()`` would fail.
    """
    tensor = tensor.detach().cpu()
    if tensor.dtype == torch.bfloat16:
        tensor = tensor.float()
    return tensor.numpy()


def convert_pytorch_state_dict_to_flax(
    pytorch_state_dict: Dict,
    config: LlamaConfig,
    flax_params: Dict,
    *,
    strict: bool = False,
) -> Dict:
    """Convert a PyTorch Llama state dict to a Flax parameter dictionary.

    Every PyTorch name is mapped to a '/'-separated Flax path, and projection
    weights are transposed. Each result is then checked against the shapes in
    ``flax_params``. The following problems are printed as warnings:

    - PyTorch keys with no mapping (these are skipped);
    - Flax parameters that received no weight;
    - converted keys that the Flax model does not declare;
    - shape mismatches.

    Args:
        pytorch_state_dict: Mapping of parameter name -> ``torch.Tensor``
            (e.g. ``model.state_dict()`` of a HF Llama model).
        config: Llama config. It is currently unused by the conversion and is
            kept for interface compatibility.
        flax_params: Initialized (possibly abstract, from ``jax.eval_shape``)
            Flax variables with a top-level ``'params'`` collection; used for
            the expected structure and shapes.
        strict: If True, raise when any Flax parameter is missing, any
            converted key is unexpected, or any shape mismatches. Unmapped
            PyTorch keys only produce a warning.

    Returns:
        ``{'params': nested_dict}``. The leaves are plain NumPy arrays and
        carry no partitioning metadata.

    Raises:
        ValueError: If ``strict`` is True and validation fails.
    """
    flat_flax = flatten_dict(flax_params['params'], sep='/')
    expected_shapes = {key: _abstract_shape(p) for key, p in flat_flax.items()}

    print("\nFlax parameter structure (sample):")
    for key in list(flat_flax)[:10]:
        shape = expected_shapes[key]
        print(f"  {key}: {shape if shape is not None else type(flat_flax[key])}")

    converted = {}
    unmapped = []
    mismatched = []
    for name, tensor in pytorch_state_dict.items():
        mapping = _pytorch_to_flax_key(name)
        if mapping is None:
            unmapped.append(name)
            continue
        flax_key, transpose = mapping
        array = _to_numpy(tensor)
        if transpose:
            array = array.T
        expected = expected_shapes.get(flax_key)
        if expected is not None and expected != array.shape:
            mismatched.append(
                f"{name} -> {flax_key}: got {array.shape}, expected {expected}"
            )
        converted[flax_key] = array

    print(f"\nConverted {len(converted)} parameters")
    print(f"Expected {len(flat_flax)} parameters")

    missing = sorted(set(flat_flax) - set(converted))
    unexpected = sorted(set(converted) - set(flat_flax))
    report = {
        'PyTorch parameters with no Flax mapping (skipped)': unmapped,
        'Flax parameters not filled by any PyTorch weight': missing,
        'Converted parameters not present in the Flax model': unexpected,
        'Shape mismatches': mismatched,
    }
    for title, items in report.items():
        if items:
            print(f"\nWARNING: {title} ({len(items)}):")
            for item in items:
                print(f"  {item}")

    if strict and (missing or unexpected or mismatched):
        raise ValueError(
            f"Weight conversion failed validation: {len(missing)} missing, "
            f"{len(unexpected)} unexpected, {len(mismatched)} shape mismatches"
        )

    return {'params': unflatten_dict(converted, sep='/')}


def main():
    """Command-line entry point: convert a HF Llama checkpoint to pickled Flax params.

    Loads the HF config and PyTorch weights for ``--model_name`` and converts
    them. It then writes ``flax_params.pkl`` and ``config.pkl`` into
    ``--output_dir``. Both files are pickles, so load them only from trusted
    locations, because unpickling can execute arbitrary code.
    """
    parser = argparse.ArgumentParser(description='Convert Llama weights to Flax')
    parser.add_argument(
        '--model_name',
        type=str,
        default='meta-llama/Llama-3.2-1B',
        help='HuggingFace model name',
    )
    parser.add_argument(
        '--output_dir',
        type=str,
        default='checkpoints/llama-3.2-1b-flax',
        help='Output directory for converted weights',
    )
    parser.add_argument(
        '--use_auth_token',
        action='store_true',
        help='Use HuggingFace auth token for gated models',
    )

    args = parser.parse_args()

    # Load config
    print(f"Loading config from {args.model_name}")
    config = LlamaConfig.from_pretrained(args.model_name)

    # Disable FMA for initial testing
    config.use_fma_attention = False

    # Dummy 1x1 mesh for model initialization (needed by sharding constraints).
    # Build it from the first device only. ``mesh_utils.create_device_mesh((1, 1))``
    # requires the mesh size to equal the number of visible devices, so it
    # raises on any multi-GPU/TPU host.
    from jax.sharding import Mesh
    devices = np.array(jax.devices()[:1]).reshape(1, 1)
    mesh = Mesh(devices, axis_names=('data', 'model'))
    jax.set_mesh(mesh)
    print(f"Created dummy mesh for initialization: {mesh}")

    print("Initializing Flax model to get parameter structure...")
    _flax_model, init_params = initialize_flax_model(config)

    print(f"Loading PyTorch model: {args.model_name}")

    # Only pass ``token`` when requested; omitting it keeps HF's default lookup.
    auth_kwargs = {'token': True} if args.use_auth_token else {}
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name,
        torch_dtype=torch.float32,
        **auth_kwargs,
    )

    pytorch_state_dict = model.state_dict()

    print(f"Converting {len(pytorch_state_dict)} PyTorch parameters to Flax format...")
    flax_params = convert_pytorch_state_dict_to_flax(pytorch_state_dict, config, init_params)

    output_path = Path(args.output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    params_file = output_path / "flax_params.pkl"
    print(f"Saving Flax parameters to {params_file}")
    with open(params_file, 'wb') as f:
        pickle.dump(flax_params, f)

    config_file = output_path / "config.pkl"
    print(f"Saving config to {config_file}")
    with open(config_file, 'wb') as f:
        pickle.dump(config, f)

    print("Conversion complete!")
    print(f"Converted parameters saved to {output_path}")

    # Print parameter shapes for verification. ``flatten_dict`` yields only
    # leaves, so a single '/'-joined path per parameter is enough.
    print("\nParameter shapes:")
    for name, param in flatten_dict(flax_params, sep='/').items():
        print(f"  {name}: {param.shape if hasattr(param, 'shape') else type(param)}")


if __name__ == "__main__":
    # Run the converter only when executed as a script, not on import.
    main()
