import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import os
import warnings

from hr2r.model.hr2r_config import HR2RConfig
from hr2r.model.recurrent_transformer import HR2RForCausalLM
from hr2r.model.tracker import HR2RTracker
from hr2r.model.utils import IterCountColors, HR2RForCasualLM_generate

# Seed PyTorch's global random number generator so that repeated runs of this
# script produce the same random draws. This executes at import time.
torch.manual_seed(42)

def main():
    """Run an end-to-end HR2R demo: build or load a model, then generate.

    Steps:
      1. Wrap a base causal LM in ``HR2RForCausalLM`` (or load a saved HR2R
         checkpoint when ``USE_SAVED_MODEL`` is set).
      2. Round-trip the model through ``test_save_load_functionality`` and
         continue with the reloaded copy.
      3. Attach an ``HR2RTracker``, run greedy generation on the prompts and
         print token statistics and the generated text.
      4. Write the tracker records to
         ``output/analysis/tracker_records.csv``.

    All settings are local constants at the top of the function. The model
    and router paths are placeholders that must be edited before running.
    """
    # ------------------------------------------------------------------
    # Settings
    # ------------------------------------------------------------------
    USE_SAVED_MODEL = False
    LOAD_R2R_ROUTER = True

    save_model_name = "./path/to/model"

    iter_decider_name = "PluginNeuralIterDecider"
    input_updater_name = "AdditiveUpdater"
    max_iter = 3
    # device_map = "auto"
    device_map = "cuda:0"

    # Keyword arguments for every supported iteration decider; only the entry
    # selected by ``iter_decider_name`` is passed to the config.
    iter_decider_options = {
        "EntropyIterDecider": {
            "threshold": 1.0,
            "max_iter": max_iter,
        },
        "MLPIterDecider": {
            "topk": 100,
            "hidden_dim": 128,
            "num_layers": 5,
            "threshold": 0.5,
            "max_iter": max_iter,
        },
        "TrivialIterDecider": {},
        "RandomIterDecider": {
            "final_probs": [1, 1, 1],
            "max_iter": max_iter,
        },
        "PluginNeuralIterDecider": {
            "module_cls": "r2r.models.router.HiddenStatesLogitsClassifier",
            "init_kwargs": {
                "logits_size": 151936,
                "hidden_dims": [256, 256, 256],
                "expansion_factor": 4,
                "dropout_rate": 0.3,
                "use_position_embedding": False,
                "max_position_embeddings": 1024,
                "normalize_input": False,
            },
            "input_mapping": {"logits": "logits", "hidden_states": "hidden_states"},
            "threshold": 0.5,
            "max_iter": max_iter,
            "topk_logits": 100,
        },
    }

    # Keyword arguments for every supported input updater; only the entry
    # selected by ``input_updater_name`` is passed to the config.
    input_updater_options = {
        "AdditiveUpdater": {
            "topk": 100,
        },
        "NeuralUpdater": {
            "embed_dim": 896,
            "topk": 10,
            "hidden_dim": 896,
            "num_layers": 2,
        },
    }

    # ------------------------------------------------------------------
    # Model and tokenizer
    # ------------------------------------------------------------------
    if not USE_SAVED_MODEL:
        model_name = "./path/to/model"

        tokenizer = AutoTokenizer.from_pretrained(model_name)
        if tokenizer.pad_token is None:
            # Batched padding needs a pad token; reuse EOS when none is set.
            tokenizer.pad_token = tokenizer.eos_token
        base_model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype="auto",
            attn_implementation="sdpa",
        )  # note that we cannot use device_map=auto for simple base model
        config = HR2RConfig(
            embedding_key="model.embed_tokens",
            max_iter=max_iter,
            iter_decider=iter_decider_name,
            iter_decider_kwargs=iter_decider_options[iter_decider_name],
            input_updater=input_updater_name,
            input_updater_kwargs=input_updater_options[input_updater_name],
        )
        hr2r_model = HR2RForCausalLM(
            base_model=base_model,
            config=config,
            device_map=device_map,
        )

        if iter_decider_name == "PluginNeuralIterDecider" and LOAD_R2R_ROUTER:
            # Load the R2R router checkpoint into the iteration decider and
            # store the config object that the call returns.
            R2R_router_path = "./path/to/R2R/classifier.pt"
            hr2r_model.hr2r_config = hr2r_model.iter_decider.R2R_update_model_and_config(
                R2R_router_path, hr2r_model.hr2r_config
            )
    else:
        tokenizer = AutoTokenizer.from_pretrained(save_model_name)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # Set to an HR2RConfig to override the configuration of the checkpoint.
        override_config = None

        hr2r_model = HR2RForCausalLM.from_pretrained(
            save_model_name,
            torch_dtype=torch.bfloat16,
            device_map=device_map,
            attn_implementation="sdpa",
            hr2r_config=override_config,
        )

    # Test save/load functionality before inference; inference continues with
    # the reloaded model returned by the helper.
    hr2r_model = test_save_load_functionality(hr2r_model, tokenizer, device_map)

    device = hr2r_model.device
    dtype = hr2r_model.dtype
    print(f"Device: {device}, Dtype: {dtype}")

    # Cast the model to its own reported dtype; presumably this aligns any
    # submodule that was loaded in a different precision -- TODO confirm.
    hr2r_model = hr2r_model.to(dtype=dtype)

    # Attach tracker
    tracker = HR2RTracker(top_k=10)
    tracker.attach(hr2r_model)

    # ------------------------------------------------------------------
    # Input and run
    # ------------------------------------------------------------------
    # One entry per batch element; add prompts here to run a larger batch.
    prompts = [
        "who are you?"
    ]

    # Process each prompt through chat template
    texts = []
    for prompt in prompts:
        messages = [{"role": "user", "content": prompt}]
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=True,  # Switches between thinking and non-thinking modes. Default is True.
        )
        texts.append(text)

    model_inputs = tokenizer(
        texts, return_tensors="pt", padding=True, padding_side="left"
    ).to(device=device)

    print("Initial input:")
    for i, prompt in enumerate(prompts):
        print(f"\nSample {i+1}:")
        print(f"Prompt: {prompt[:100]}{'...' if len(prompt) > 100 else ''}")
    print(f"Input IDs shape: {model_inputs.input_ids.shape}")

    print(IterCountColors.get_legend())

    # Greedy decoding. For sampling, use e.g. do_sample=True, temperature=0.6,
    # top_p=0.95, top_k=20, min_p=0.0 instead.
    output_tokens, final_texts = HR2RForCasualLM_generate(
        hr2r_model=hr2r_model,
        tokenizer=tokenizer,
        model_inputs=model_inputs,
        iter_count=None,  # Use automatic iteration from iter_decider
        max_new_tokens=32768,
        do_sample=False,
        temperature=0.0,
        top_p=1.0,
        top_k=0,
        verbose=True,
    )

    # ------------------------------------------------------------------
    # Token count analysis
    # ------------------------------------------------------------------
    batch_size, max_input_length = model_inputs.input_ids.shape

    print("\n" + "=" * 50)
    print("TOKEN COUNT ANALYSIS")
    print("=" * 50)
    print(f"Batch size: {batch_size}")
    print(f"Max input length (with padding): {max_input_length}")

    # Calculate actual input lengths (excluding padding). Prefer the attention
    # mask: when the pad token is the EOS token, comparing ids against
    # pad_token_id would also discard genuine EOS tokens inside the prompt.
    attention_mask = model_inputs.get("attention_mask")
    if attention_mask is not None:
        actual_input_lengths = attention_mask.sum(dim=1).tolist()
    else:
        actual_input_lengths = (
            (model_inputs.input_ids != tokenizer.pad_token_id).sum(dim=1).tolist()
        )
    for i, actual_length in enumerate(actual_input_lengths):
        print(f"Sample {i+1} actual input length: {actual_length}")

    # For generated tokens, we get one token list per batch item
    generated_counts = [len(seq) for seq in output_tokens]
    total_generated = sum(generated_counts)
    print(f"Generated tokens per sample: {generated_counts}")
    print(f"Total generated tokens: {total_generated}")
    print("=" * 50)

    # Print final generated texts for each sample in the batch
    print("\nFINAL GENERATED TEXTS:")
    print("=" * 50)
    for i, text in enumerate(final_texts):
        print(f"\nSample {i+1} output:")
        print("-" * 30)
        print(text)

    # Display recorded information from tracker
    record_pd = tracker.to_pandas()
    print(record_pd)

    # Save tracker records to CSV; exist_ok avoids the check-then-create race.
    os.makedirs("output/analysis", exist_ok=True)
    record_pd.to_csv("output/analysis/tracker_records.csv", index=False)

    tracker.detach()

    print(f"\nBatch inference completed successfully for {batch_size} samples!")



def test_save_load_functionality(hr2r_model: HR2RForCausalLM, tokenizer, device_map: str):
    """Round-trip an HR2R model through ``save_pretrained``/``from_pretrained``.

    The model is saved to ``output/test_save_load`` and loaded back in
    bfloat16 with SDPA attention and without an explicit ``hr2r_config``, so
    the configuration has to come from the saved files. Selected config
    fields of the reloaded model are then compared with the original.

    Args:
        hr2r_model: Model to save; its ``hr2r_config`` is the reference for
            the comparison.
        tokenizer: Not used by this function; kept so existing callers keep
            working.
        device_map: Forwarded to ``HR2RForCausalLM.from_pretrained``.

    Returns:
        The reloaded model (not the instance that was passed in).

    Warns:
        RuntimeWarning: If any compared config field differs after the
            reload. The reloaded model is still returned in that case.
    """
    test_save_dir = "output/test_save_load"

    print("\n" + "=" * 50)
    print("Testing Save/Load Functionality")
    print("=" * 50)

    print(f"Step 1: Saving HR2R model to {test_save_dir}...")
    hr2r_model.save_pretrained(test_save_dir)
    print("✓ Model saved successfully")

    print(f"\nStep 2: Loading HR2R model from {test_save_dir}...")
    # Test automatic config loading (no hr2r_config provided)
    loaded_model = HR2RForCausalLM.from_pretrained(
        test_save_dir,
        torch_dtype=torch.bfloat16,
        device_map=device_map,
        attn_implementation="sdpa",
    )
    print("✓ Model loaded successfully")

    print("\nStep 3: Verifying loaded model configuration...")
    print(f"  - Max iterations: {loaded_model.max_iter}")
    print(f"  - Iter decider: {type(loaded_model.iter_decider).__name__}")
    print(f"  - Input updater: {type(loaded_model.input_updater).__name__}")
    print(f"  - Device: {loaded_model.device}")

    # Verify config matches original, remembering which fields differ so a
    # failure can be reported precisely.
    original_config = hr2r_model.hr2r_config
    loaded_config = loaded_model.hr2r_config
    compared_fields = (
        "max_iter",
        "iter_decider",
        "iter_decider_kwargs",
        "input_updater",
        "input_updater_kwargs",
    )
    mismatched_fields = [
        field
        for field in compared_fields
        if getattr(original_config, field) != getattr(loaded_config, field)
    ]
    config_matches = not mismatched_fields
    print(f"  - Config auto-loaded correctly: {config_matches}")

    # Only claim success when the comparison actually passed.
    if config_matches:
        print("✓ Configuration verified")
        print("\n✓ Save/Load test completed successfully!")
    else:
        mismatch_list = ", ".join(mismatched_fields)
        print(f"✗ Configuration mismatch in: {mismatch_list}")
        warnings.warn(
            f"Reloaded HR2R config differs from the original in: {mismatch_list}",
            RuntimeWarning,
            stacklevel=2,
        )
        print("\n✗ Save/Load test completed with configuration mismatches")
    print("=" * 50)

    return loaded_model



# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
