import torch

from transformers import TrainingArguments
from trl import SFTTrainer
from unsloth import is_bfloat16_supported
from transformers import DataCollatorForSeq2Seq
from transformers.trainer import _is_peft_model
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES

from unsloth.chat_templates import train_on_responses_only
from ldifs.trainer import LDIFSTrainer_v1, LDIFSTrainer_normalized, KLTrainer, InterleavedKLTrainer
from utils import load_model_and_tokenizer




def get_instruct_response_part(tokenizer):
    """Find the chat-template markers that open a user turn and an assistant turn.

    A three-turn probe conversation (user, assistant, user) is rendered with
    ``tokenizer.apply_chat_template``. If the rendered text contains one of the
    known (instruction, response) marker pairs, that pair is returned. Otherwise
    the markers are guessed from the rendered text itself and a warning is printed.

    Args:
        tokenizer: Object exposing ``apply_chat_template(conversation,
            add_generation_prompt=..., tokenize=...)`` that returns a string
            when ``tokenize=False``.

    Returns:
        Tuple ``(instruction_part, response_part)`` of marker strings.

    Raises:
        ValueError: In the fallback path, if the text left after removing the
            rendered two-turn prefix does not split into exactly two pieces
            around the placeholder user message.
    """
    placeholder = "<user message content>"
    warmup_turns = [
        {"role": "user", "content": "ignore"},
        {"role": "assistant", "content": "ignore"},
    ]
    probe_turns = warmup_turns + [{"role": "user", "content": placeholder}]
    rendered = tokenizer.apply_chat_template(
        probe_turns, add_generation_prompt=False, tokenize=False
    )

    # Known (instruction marker, response marker) pairs, tried in this order.
    known_markers = (
        (
            "<|start_header_id|>user<|end_header_id|>\n\n",
            "<|start_header_id|>assistant<|end_header_id|>\n\n",
        ),
        (
            "<|start_header_id|>user<|end_header_id|>\n",
            "<|start_header_id|>assistant<|end_header_id|>\n",
        ),
        ("[INST]", "[/INST]"),
        ("<｜User｜>", "<｜Assistant｜>"),
        ("<|User|>", "<|Assistant|>"),
    )
    matched = next(
        (
            pair
            for pair in known_markers
            if pair[0] in rendered and pair[1] in rendered
        ),
        None,
    )
    if matched is not None:
        return matched

    # Fallback: derive the markers from the template output.
    print("Warning: guessing how to train on responses only")
    rendered_warmup = tokenizer.apply_chat_template(warmup_turns, tokenize=False)
    # Dropping the rendered two-turn prefix leaves only the final user turn;
    # whatever precedes the placeholder is taken as the instruction marker.
    final_turn = rendered.replace(rendered_warmup, "")
    instruction_part, _ = final_turn.split(placeholder)
    # The response marker is whatever the generation prompt adds to the render.
    rendered_with_prompt = tokenizer.apply_chat_template(
        probe_turns, add_generation_prompt=True, tokenize=False
    )
    response_part = rendered_with_prompt.replace(rendered, "")
    return instruction_part, response_part


def _load_frozen_reference_model(training_cfg):
    """Load a separate copy of ``training_cfg.model`` and freeze it.

    The copy is loaded through ``load_model_and_tokenizer`` (its tokenizer is
    discarded), switched to eval mode, and every parameter has
    ``requires_grad`` turned off. It is passed to the regularized trainers as
    ``frozen_model``.
    """
    frozen_model, _ = load_model_and_tokenizer(
        training_cfg.model,
        load_in_4bit=training_cfg.load_in_4bit,
    )
    frozen_model.eval()
    for param in frozen_model.parameters():
        param.requires_grad = False
    return frozen_model


def _resolve_interleave_interval(training_cfg):
    """Return the interleave interval from the config, else from the file name.

    ``training_cfg.interleave_interval`` wins when present and not ``None``.
    Otherwise ``training_cfg.training_file`` (if the attribute exists) is
    searched for ``_interleaved_n<digits>``, e.g. ``..._interleaved_n5.jsonl``.
    Returns ``None`` when neither source yields a value.
    """
    import re

    interleave_interval = getattr(training_cfg, "interleave_interval", None)
    if interleave_interval is None and hasattr(training_cfg, "training_file"):
        match = re.search(r"_interleaved_n(\d+)", training_cfg.training_file)
        if match:
            interleave_interval = int(match.group(1))
            print(f"Auto-detected interleave_interval={interleave_interval} from filename")
    return interleave_interval


def sft_train(training_cfg, dataset, model, tokenizer, test_dataset, **kwargs):
    """Build (but do not run) the trainer selected by ``training_cfg``.

    The datasets are converted to a ``"text"`` column via the tokenizer's chat
    template, ``TrainingArguments`` are assembled from the config, and the
    trainer class is chosen from ``training_cfg.training_method``:

    * ``"ldifs_v1"`` -> ``LDIFSTrainer_v1``
    * ``"ldifs_normalized"`` -> ``LDIFSTrainer_normalized``
    * ``"kl"`` -> ``KLTrainer``
    * ``"kl_interleaved"`` -> ``InterleavedKLTrainer`` (datasets additionally
      get an ``is_interleaved`` column from ``create_dataset_mask``)
    * anything else, or attribute missing -> ``SFTTrainer`` (cross entropy)

    The four regularized methods also receive a frozen copy of the base model
    plus ``ldifs_lambda`` and ``num_intermediate_layers`` from the config.

    Args:
        training_cfg: Config object read by attribute. ``learning_rate`` may
            be a number or a string expression; a negative value ``x`` is
            interpreted as ``10**x``.
        dataset: Training dataset exposing ``map``; batches must contain
            either a ``"text"`` or a ``"messages"`` column.
        model: Model to fine-tune, passed through to the trainer.
        tokenizer: Tokenizer providing ``apply_chat_template`` and
            ``eos_token``.
        test_dataset: Evaluation dataset, or ``None`` for no evaluation set.
        **kwargs: Extra keyword arguments forwarded to ``TrainingArguments``.

    Returns:
        The constructed trainer. When ``training_cfg.train_on_responses_only``
        is truthy it is the result of wrapping the trainer with
        ``train_on_responses_only``.
    """

    def apply_chat_template(examples):
        # Batches that already carry rendered text are passed through as-is.
        if "text" in examples:
            return examples
        # NOTE(review): add_generation_prompt=True is applied to complete
        # training conversations, and return_tensors is passed together with
        # tokenize=False - confirm both are intended.
        texts = [
            tokenizer.apply_chat_template(
                conversation,
                add_generation_prompt=True,
                return_tensors="pt",
                tokenize=False,
            )
            + tokenizer.eos_token
            for conversation in examples["messages"]
        ]
        result = {"text": texts}

        # Preserve is_interleaved field if it exists
        if "is_interleaved" in examples:
            result["is_interleaved"] = examples["is_interleaved"]

        return result

    dataset = dataset.map(apply_chat_template, batched=True)
    # The rest of this function accepts test_dataset=None, so the chat-template
    # mapping must tolerate it too instead of failing on ``None.map``.
    if test_dataset is not None:
        test_dataset = test_dataset.map(apply_chat_template, batched=True)

    learning_rate = training_cfg.learning_rate
    if isinstance(learning_rate, str):
        # SECURITY NOTE(review): eval() executes arbitrary code. This is only
        # acceptable while the training config comes from a trusted source.
        learning_rate = eval(learning_rate)
    if learning_rate < 0:
        # Negative values are treated as a base-10 exponent (e.g. -5 -> 1e-5).
        learning_rate = 10**learning_rate

    trainer_kwargs = dict(
        model=model,
        tokenizer=tokenizer,
        train_dataset=dataset,
        dataset_text_field="text",
        max_seq_length=training_cfg.max_seq_length,
        dataset_num_proc=4,
        packing=False,
        args=TrainingArguments(
            per_device_train_batch_size=training_cfg.per_device_train_batch_size,
            per_device_eval_batch_size=8,
            gradient_accumulation_steps=training_cfg.gradient_accumulation_steps,
            warmup_steps=training_cfg.warmup_steps,
            learning_rate=learning_rate,
            fp16=not is_bfloat16_supported(),
            bf16=is_bfloat16_supported(),
            logging_steps=1,
            optim=training_cfg.optim,
            weight_decay=training_cfg.weight_decay,
            lr_scheduler_type=training_cfg.lr_scheduler_type,
            seed=training_cfg.seed,
            # NOTE(review): confirm that None (rather than the string "none")
            # gives the intended reporting behaviour.
            report_to=None,
            num_train_epochs=training_cfg.epochs,
            save_steps=500000,
            output_dir=training_cfg.output_dir,
            ddp_find_unused_parameters=False,
            # NOTE(review): the path is hard-coded; training_cfg.deepspeed_config
            # is only used as an on/off switch here - confirm this is intended.
            deepspeed="ds_config.json"
                if hasattr(training_cfg, "deepspeed_config")
                and training_cfg.deepspeed_config else None,
            local_rank=-1,
            **kwargs,
        ),
        callbacks=[],
        eval_dataset=test_dataset,
    )

    # ------- LDIFS-specific code STARTS HERE -------
    training_method = getattr(training_cfg, "training_method", "cross_entropy")

    # Methods that regularize against a frozen copy of the base model:
    # method name -> (banner to print, trainer class).
    regularized_methods = {
        "ldifs_v1": (
            "Training method: LDIFS version 1",
            LDIFSTrainer_v1,
        ),
        "ldifs_normalized": (
            "Training method: LDIFS normalized on vector length",
            LDIFSTrainer_normalized,
        ),
        "kl": (
            "Training method: KL divergence loss",
            KLTrainer,
        ),
        "kl_interleaved": (
            "Training method: KL divergence loss with interleaved samples",
            InterleavedKLTrainer,
        ),
    }

    if training_method in regularized_methods:
        banner, trainer_class = regularized_methods[training_method]
        print(banner)

        if training_method == "kl_interleaved":
            # Get interleave interval from config or try to detect from filename
            interleave_interval = _resolve_interleave_interval(training_cfg)

            # Add the is_interleaved mask and hand the updated datasets to the trainer.
            dataset = create_dataset_mask(dataset, interleave_interval=interleave_interval)
            trainer_kwargs["train_dataset"] = dataset
            if test_dataset is not None:
                test_dataset = create_dataset_mask(
                    test_dataset, interleave_interval=interleave_interval
                )
                trainer_kwargs["eval_dataset"] = test_dataset

        # Make a frozen copy of the base model before LoRA/PEFT
        frozen_model = _load_frozen_reference_model(training_cfg)
        trainer_kwargs.update(
            dict(
                frozen_model=frozen_model,
                ldifs_lambda=training_cfg.ldifs_lambda,
                num_intermediate_layers=training_cfg.num_intermediate_layers,
            )
        )
    else:
        print("Training method: Cross entropy loss")
        trainer_class = SFTTrainer

    print(f"Model: {training_cfg.model}")

    if "ldifs" in training_method:
        print(f"ldifs_lambda: {training_cfg.ldifs_lambda}")
        print(f"num_intermediate_layers: {training_cfg.num_intermediate_layers}")

    # ------- LDIFS-specific code ENDS HERE -------

    if training_cfg.train_on_responses_only:
        instruction_part, response_part = get_instruct_response_part(tokenizer)
        trainer_kwargs["data_collator"] = DataCollatorForSeq2Seq(tokenizer=tokenizer)
        trainer = train_on_responses_only(
            trainer_class(**trainer_kwargs),
            instruction_part=instruction_part,
            response_part=response_part,
        )
    else:
        trainer = trainer_class(**trainer_kwargs)
    return trainer

def _report_interleaved_stats(dataset):
    """Print how many rows are flagged ``is_interleaved``; return ``(flagged, total)``."""
    num_interleaved = sum(dataset['is_interleaved'])
    total = len(dataset)
    # An empty dataset would otherwise raise ZeroDivisionError in the percentage.
    percent = 100 * num_interleaved / total if total else 0.0
    print(f"Found {num_interleaved}/{total} interleaved samples ({percent:.1f}%)")
    return num_interleaved, total


def create_dataset_mask(dataset, interleave_interval=None):
    """
    Add 'is_interleaved' field to dataset to indicate which samples should receive KL regularization.

    The mask is taken from the first of these sources that applies:

    1. An existing 'is_interleaved' column (dataset returned unchanged).
    2. An existing 'interleaved' column, converted with int().
    3. An existing 'has_context' column, converted with int().
    4. No interleave_interval given: every sample is marked 0 and a warning is printed.
    5. Positional pattern: sample i is marked 1 when (i + 1) % (n + 1) == 0.

    The positional pattern matches datasets created with create_interleaved_datasets.py,
    where after every n misaligned samples, 1 wildguard sample is inserted, so every
    (n+1)th position is an interleaved sample. Existing columns take precedence over
    interleave_interval.

    Args:
        dataset: HuggingFace Dataset object
        interleave_interval: Integer n where wildguard samples appear after every n misaligned samples.
                           Only used when the dataset has none of the columns listed above.

    Returns:
        dataset: Dataset with 'is_interleaved' field added (0 or 1)

    Raises:
        ValueError: If the positional pattern is needed and interleave_interval is negative.
    """

    # If is_interleaved already exists, return as-is
    if 'is_interleaved' in dataset.column_names:
        print("Dataset already has 'is_interleaved' field")
        _report_interleaved_stats(dataset)
        return dataset

    # Check for alternative field names
    for alias in ('interleaved', 'has_context'):
        if alias in dataset.column_names:
            print(f"Using '{alias}' field as mask")
            dataset = dataset.map(lambda x: {'is_interleaved': int(x[alias])})
            _report_interleaved_stats(dataset)
            return dataset

    # If no interleave_interval provided, default to no interleaving
    if interleave_interval is None:
        print("WARNING: No interleave_interval provided and no 'is_interleaved' field in dataset.")
        print("Defaulting to NO interleaved samples (all samples will be treated as standard).")
        print("To use interleaved KL training:")
        print("  1. Add 'interleave_interval' parameter to create_dataset_mask()")
        print("  2. Or add 'is_interleaved' field to your dataset")
        dataset = dataset.map(lambda x, idx: {'is_interleaved': 0}, with_indices=True)
        return dataset

    n = interleave_interval
    # n == -1 would divide by zero below and other negatives give a meaningless
    # pattern, so reject them up front with a clear message.
    if n < 0:
        raise ValueError(f"interleave_interval must be non-negative, got {n}")

    # Apply the simple pattern: every (n+1)th position is interleaved
    # Pattern from create_interleaved_datasets.py:
    # - n misaligned samples (positions 0 to n-1)
    # - 1 wildguard sample (position n) <- INTERLEAVED
    # - n misaligned samples (positions n+1 to 2n)
    # - 1 wildguard sample (position 2n+1) <- INTERLEAVED
    # Formula: position i is interleaved if (i+1) % (n+1) == 0
    print(f"Creating mask with interleave interval n={n}")
    print(f"Every {n+1}th sample (positions {n}, {2*n+1}, {3*n+2}, ...) will be marked as interleaved")

    def add_mask(example, idx):
        # Check if this position should be interleaved
        return {'is_interleaved': 1 if (idx + 1) % (n + 1) == 0 else 0}

    dataset = dataset.map(add_mask, with_indices=True)

    # Report statistics
    _, total = _report_interleaved_stats(dataset)
    expected = total // (n + 1)
    print(f"Expected approximately {expected} interleaved samples")

    return dataset