import os
import torch
from transformers import AutoTokenizer, AutoProcessor, TrainingArguments, EarlyStoppingCallback, Trainer
from transformers import AutoConfig, AutoModelForCausalLM
from dataloaders.data_collator import DataCollator, DataCollatorCF
from criterion import CausalCriterion, make_compute_metrics, make_compute_metrics_counterfactual
from config import CausalConfig
from peft.tuners.lora.layer import LoraLayer
import re
from typing import Tuple
from models.processor_wrapper import Phi4CausalQAProcessor
# from causal_trainer import CausalTrainer, CausalTrainerCF
from utils.tasks import CausalQATask
from transformers.trainer import get_last_checkpoint
from dataloaders.example_dataset import example_dataset_base as REPLACE_YOUR_DATASET

# MAKE SURE to import your trainer for your task
from causal_trainer import SampleTrainer as CausalTrainer

# from models.model_alignment_a_logits_ablation import Phi4CausalQA
from models.model_alignment import Phi4CausalQA

def freeze_lora_layers(model: torch.nn.Module, lead_last_trainable_layer_idx: int, tail_first_trainable_layer_idx: int):
    '''
    Choose which decoder-layer LoRA adapters stay trainable.

    Every ``LoraLayer`` whose qualified module name contains ``model.layers.<idx>.``
    keeps its adapter parameters trainable only if ``idx`` lies in the leading range
    ``[0, lead_last_trainable_layer_idx]`` or in the trailing range
    ``[tail_first_trainable_layer_idx, ...)``. All other matching adapters get
    ``requires_grad=False``. Adapters are only frozen, not disabled, so they still
    take part in the forward pass. LoRA modules whose names do not match the pattern
    (for example vision-encoder adapters) are left untouched.

    Args:
        model (torch.nn.Module): The model containing LoRA layers.
        lead_last_trainable_layer_idx (int): Last layer index (inclusive) of the leading
            trainable range. Pass a negative value to train no leading layers.
        tail_first_trainable_layer_idx (int): First layer index (inclusive) of the trailing
            trainable range. Pass a value >= the number of decoder layers to train no
            trailing layers.
    '''
    decoder_layer_re = re.compile(r"model\.layers\.(\d+)\.")

    for name, module in model.named_modules():
        if not isinstance(module, LoraLayer):
            continue
        match = decoder_layer_re.search(name)
        if match is None:
            continue

        idx = int(match.group(1))
        trainable = idx <= lead_last_trainable_layer_idx or idx >= tail_first_trainable_layer_idx
        for adapter_attr in module.adapter_layer_names:
            getattr(module, adapter_attr).requires_grad_(trainable)

def extend_vision_lora(hf_config: AutoConfig, include_siglip: bool = True, include_img_proj: bool = True) -> AutoConfig:
    '''
    Widen the vision-LoRA target-module regex stored in ``hf_config.vision_lora["layer"]``.

    The existing pattern (if any) is kept and OR-ed with:
      * SigLIP encoder projections: ``self_attn.{q,k,v,out}_proj`` and ``mlp.fc{1,2}``
        in every ``img_processor.encoder.layers.<n>``;
      * the image-projection linears ``img_projection.0`` and ``img_projection.2``.

    ``hf_config`` is modified in place and also returned.

    Args:
        hf_config (AutoConfig): Config exposing a dict-like ``vision_lora`` attribute.
        include_siglip (bool): Add the SigLIP encoder projection pattern.
        include_img_proj (bool): Add the image-projection pattern.

    Returns:
        AutoConfig: The same ``hf_config`` object.
    '''
    base = hf_config.vision_lora.get('layer', '')
    siglip = (
        r"embed_tokens_extend\.image_embed\.img_processor\.encoder\.layers\.\d+\."
        r"(?:self_attn\.(?:q_proj|k_proj|v_proj|out_proj)|mlp\.(?:fc1|fc2))"
    )
    imgproj = r"embed_tokens_extend\.image_embed\.img_projection\.(?:0|2)$"

    # Skip an empty base: "(?:)" would add an alternative that matches the empty
    # string, i.e. every module name under re.search-style matching.
    parts = [f"(?:{base})"] if base else []
    if include_siglip:
        parts.append(f"(?:{siglip})")
    if include_img_proj:
        parts.append(f"(?:{imgproj})")

    hf_config.vision_lora["layer"] = "|".join(parts)
    return hf_config

def freeze_lora_siglip(model: torch.nn.Module, last_n_layers: int, train_img_proj: bool = True):
    '''
    Restrict vision-encoder LoRA training to the last ``last_n_layers`` encoder layers.

    The encoder depth ``L`` is inferred from module names containing
    ``embed_tokens_extend.image_embed.img_processor.encoder.layers.<idx>.``.
    Let ``cutoff = max(0, L - last_n_layers)``. For every ``LoraLayer`` inside an
    encoder layer ``idx``:
      * ``idx >= cutoff``: adapters are set trainable and enabled;
      * otherwise: adapters are frozen (``requires_grad=False``) AND disabled
        (``_disable_adapters = True``), so they do not contribute to the forward pass.

    If ``train_img_proj`` is True, LoRA layers under
    ``embed_tokens_extend.image_embed.img_projection`` are set trainable and enabled.
    If it is False, they are left unchanged.

    If no encoder layers are found, the function returns immediately without
    touching any module, including the image-projection step.

    Args:
        model (torch.nn.Module): The model containing LoRA layers.
        last_n_layers (int): Number of trailing encoder layers to keep trainable.
        train_img_proj (bool): Whether to make the image-projection LoRA layers trainable.
    '''
    encoder_layer_re = re.compile(r"embed_tokens_extend\.image_embed\.img_processor\.encoder\.layers\.(\d+)\.")
    img_proj_key = "embed_tokens_extend.image_embed.img_projection"

    def _set_adapters(lora_module, trainable: bool) -> None:
        # Frozen adapters are also disabled; trainable ones are (re-)enabled.
        for adapter_attr in lora_module.adapter_layer_names:
            getattr(lora_module, adapter_attr).requires_grad_(trainable)
        lora_module._disable_adapters = not trainable

    # Depth is taken from all modules, not only LoRA ones, so it reflects the real encoder.
    num_layers = 0
    for name, _ in model.named_modules():
        match = encoder_layer_re.search(name)
        if match:
            num_layers = max(num_layers, int(match.group(1)) + 1)
    if num_layers == 0:
        return

    cutoff = max(0, num_layers - last_n_layers)

    for name, module in model.named_modules():
        if not isinstance(module, LoraLayer):
            continue

        match = encoder_layer_re.search(name)
        if match:
            _set_adapters(module, int(match.group(1)) >= cutoff)

        if train_img_proj and img_proj_key in name:
            _set_adapters(module, True)

def load_baseline(cfg: CausalConfig) -> Tuple[AutoProcessor, AutoTokenizer, AutoModelForCausalLM]:
    '''
    Load the unmodified base checkpoint (model + processor) from ``cfg.local_dir``.

    Both are loaded with ``trust_remote_code=True``; the model uses ``torch_dtype='auto'``.

    Returns:
        (processor, tokenizer, model), where ``tokenizer`` is ``processor.tokenizer``.
    '''
    source_dir = cfg.local_dir

    base_model = AutoModelForCausalLM.from_pretrained(source_dir, trust_remote_code=True, torch_dtype='auto')
    base_processor = AutoProcessor.from_pretrained(source_dir, trust_remote_code=True)

    return base_processor, base_processor.tokenizer, base_model

def load_model(cfg: CausalConfig, device: str = 'cuda') -> Tuple[Phi4CausalQAProcessor, AutoTokenizer, Phi4CausalQA]:
    '''
    Build the causal-QA processor and model and configure which LoRA adapters train.

    Steps:
      1. Load ``Phi4CausalQAProcessor`` and the HF config from ``cfg.model_name``
         and force ``flash_attention_2``.
      2. If ``cfg.siglip_lora_enable``, widen the vision-LoRA target regex to cover
         the SigLIP encoder (and optionally the image projection).
      3. Load ``Phi4CausalQA`` weights from ``cfg.local_dir`` and move the model to ``device``.
      4. Disable ``use_cache`` and set LoRA trainability for decoder (and SigLIP) layers.

    Args:
        cfg (CausalConfig): Experiment configuration.
        device (str): Device the model is moved to after loading. Defaults to
            ``'cuda'``, which was previously hard-coded.

    Returns:
        (processor, tokenizer, model), where ``tokenizer`` is ``processor.tokenizer``.
    '''
    processor = Phi4CausalQAProcessor.from_pretrained(
        cfg.model_name,
        trust_remote_code=True,
        num_graph_tokens=cfg.num_graph_tokens,
    )
    tokenizer = processor.tokenizer

    hf_config = AutoConfig.from_pretrained(
        cfg.model_name,
        trust_remote_code=True,
    )
    hf_config._attn_implementation = "flash_attention_2"

    # The LoRA target regex must be widened on the config *before* the model is built.
    if cfg.siglip_lora_enable:
        hf_config = extend_vision_lora(
            hf_config,
            include_siglip=True,
            include_img_proj=cfg.siglip_lora_include_img_proj,
        )

    model = Phi4CausalQA.from_pretrained(
        cfg.local_dir,
        config=hf_config,
        trust_remote_code=True,
        torch_dtype='auto',
        d_max=cfg.d_max,
        rank_r=cfg.rank_r,
        num_graph_tokens=cfg.num_graph_tokens,
        use_positional_encoding=True
    ).to(device)

    # Turn off KV caching on both the inner and outer configs. Presumably this is
    # for training, where caching is unneeded and conflicts with gradient checkpointing.
    model.model.config.use_cache = False
    model.config.use_cache = False

    # NOTE(review): ``decoder_train_lora_last_n_layers`` reads like a layer *count*
    # (compare ``siglip_lora_last_n_layers``, which freeze_lora_siglip treats as a count),
    # but it is passed here as the first trainable layer *index*. Confirm against CausalConfig.
    freeze_lora_layers(
        model,
        lead_last_trainable_layer_idx=cfg.decoder_train_lora_first_upto_layer,
        tail_first_trainable_layer_idx=cfg.decoder_train_lora_last_n_layers,
    )
    if cfg.siglip_lora_enable:
        freeze_lora_siglip(
            model,
            last_n_layers=cfg.siglip_lora_last_n_layers,
            train_img_proj=cfg.siglip_lora_include_img_proj,
        )

    return processor, tokenizer, model

def reload_model_checkpoints(ckpt_dir: str, model: torch.nn.Module, trainer: CausalTrainer):
    '''
    Reload consolidated fp32 weights from ``<ckpt_dir>/consolidated_fp32`` into ``model``.

    Two layouts are supported:
      * sharded: a ``pytorch_model.bin.index.json`` whose ``weight_map`` lists the shard
        files; every referenced shard is loaded;
      * single file: ``pytorch_model.bin`` with no index.

    Weights are loaded with ``strict=False``, so frozen weights absent from the
    checkpoint are tolerated. Any parameter with ``requires_grad=True`` that is missing
    from the checkpoint is treated as an error. Only model weights are restored;
    optimizer and scheduler state are not.

    Args:
        ckpt_dir (str): Checkpoint directory (e.g. the result of ``get_last_checkpoint``).
        model (torch.nn.Module): Model to update in place.
        trainer (CausalTrainer): Returned unchanged.

    Returns:
        Tuple of ``(model, trainer)``.

    Raises:
        FileNotFoundError: If neither the shard index nor ``pytorch_model.bin`` exists.
        ValueError: If a trainable parameter has no entry in the checkpoint.
    '''
    import json

    bin_dir = os.path.join(ckpt_dir, "consolidated_fp32")
    index_file = os.path.join(bin_dir, "pytorch_model.bin.index.json")
    single_file_name = "pytorch_model.bin"

    if os.path.exists(index_file):
        with open(index_file, "r", encoding="utf-8") as f:
            weight_map = json.load(f)["weight_map"]  # param_name -> shard_file
        shard_files = sorted(set(weight_map.values()))
    elif os.path.exists(os.path.join(bin_dir, single_file_name)):
        # Unsharded consolidation: no index file, just one weights file.
        shard_files = [single_file_name]
    else:
        raise FileNotFoundError(
            f"No consolidated weights found in {bin_dir} "
            f"(expected pytorch_model.bin.index.json or {single_file_name})."
        )

    state_dict = {}
    for shard in shard_files:
        # The shards are plain tensor state dicts; weights_only avoids unpickling arbitrary objects.
        state_dict.update(
            torch.load(os.path.join(bin_dir, shard), map_location='cpu', weights_only=True)
        )

    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    del state_dict  # release the CPU copy of the weights as soon as they are loaded

    print(f"Reloaded model from {ckpt_dir}.")
    if unexpected:
        print(f"  Ignored {len(unexpected)} unexpected key(s).")

    missing_set = set(missing)
    missing_required_grads = [
        name for name, param in model.named_parameters()
        if param.requires_grad and name in missing_set
    ]
    if missing_required_grads:
        raise ValueError(f"Missing required gradients for parameters: {missing_required_grads}")

    return model, trainer
    
if __name__ == "__main__":
    # Training entry point. This is a template: the dataset path, checkpoint
    # directories, task list and criterion arguments are placeholders to fill in.
    import datetime
    # Timestamp currently only referenced by the commented-out predictions_file_prefix option.
    current_time = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    cfg = CausalConfig()
    # Modify your target tasks here. This list is passed to the datasets, collator,
    # criterion, metrics and trainer.
    tasks = []
    processor, tokenizer, model = load_model(cfg)

    # List every parameter left trainable after load_model's LoRA freezing, for a sanity check.
    param_names = []

    for name, param in model.named_parameters():
        if param.requires_grad:
            param_names.append(name)

    print("Trainable parameters:")
    for name in param_names:
        print(" -", name)

    # Placeholder: point this at your dataset directory.
    path = 'your_dataset_directory'
    # path = None
    train_ds = REPLACE_YOUR_DATASET(
        dir=path, 
        d_max=cfg.d_max, split='train', shuffle=True, 
        exp_tasks=tasks,
        return_eval_prompts=False,
    )
    # The eval split differs only in split name and return_eval_prompts=True.
    eval_ds = REPLACE_YOUR_DATASET(
        dir=path, 
        d_max=cfg.d_max, split='eval', shuffle=True, 
        exp_tasks=tasks,
        return_eval_prompts=True,
    )

    print("Train samples:", len(train_ds))
    print("Eval samples:", len(eval_ds))

    # Swap in DataCollatorCF (commented line) for the counterfactual variant
    # (presumably paired with make_compute_metrics_counterfactual below).
    data_collator = DataCollator(
    # data_collator = DataCollatorCF(
        processor=processor,
        tokenizer=tokenizer,
        max_length=cfg.max_length,
        tasks=tasks,
    )

    # NOTE(review): newer transformers releases renamed `evaluation_strategy` to
    # `eval_strategy` and later dropped the old name. Confirm against the installed version.
    # NOTE: `cfg.wramup_ratio` must match the attribute spelling in CausalConfig.
    # remove_unused_columns=False keeps every dataset field for the custom collator.
    args = TrainingArguments(
        output_dir="your_checkpoint_directory",
        eval_on_start=False,
        per_device_train_batch_size=cfg.per_device_train_batch_size,
        per_device_eval_batch_size=cfg.per_device_eval_batch_size,
        gradient_accumulation_steps=cfg.gradient_accumulation_steps,
        # max_steps=cfg.max_steps,
        num_train_epochs=cfg.num_train_epochs,
        learning_rate=cfg.lr,
        weight_decay=cfg.weight_decay,
        warmup_ratio=cfg.wramup_ratio,
        bf16=cfg.bf16,
        evaluation_strategy=cfg.evaluation_strategy,
        save_strategy=cfg.save_strategy,
        logging_steps=cfg.logging_steps,
        save_steps=cfg.save_steps,
        eval_steps=cfg.eval_steps,
        report_to=cfg.report_to,
        deepspeed=cfg.deepspeed,
        gradient_checkpointing=cfg.gradient_checkpointing,
        remove_unused_columns=False,     
        dataloader_num_workers=4,
        dataloader_pin_memory=True,
        dataloader_persistent_workers=False,
        dataloader_prefetch_factor=2,
    )

    print(">>> Args:", args)

    print(">>> Using collator:", type(data_collator).__name__, data_collator)

    # `...` (Ellipsis) is a placeholder: replace it with the criterion's real arguments.
    criterion = CausalCriterion(
        ...,
        exp_tasks = tasks,
    )

    # CausalTrainer is the task-specific trainer imported at the top of the file.
    trainer = CausalTrainer(
    # trainer = CausalTrainerCF(
        model=model,
        args=args,
        save_dir="your_checkpoint_directory",
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        data_collator=data_collator,
        tokenizer=tokenizer,
        criterion=criterion,
        processor=processor,
        

        # Select appropriate compute_metrics function for your task
        compute_metrics=make_compute_metrics(
        # compute_metrics=make_compute_metrics_counterfactual(
            processor=processor,
            exp_tasks=tasks,
        ),
        # predictions_file_prefix="predictions_run_"+current_time,
        exp_tasks=tasks,
        # callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
    )

    # print("Evaluation before training:", trainer.evaluate(eval_ds))

    # Manual resume toggle: change to True (and set the path) to load consolidated
    # weights from the latest checkpoint before training. Only model weights are
    # reloaded. trainer.train() below is not given resume_from_checkpoint, so
    # optimizer/scheduler state and the step counter start fresh.
    if False:
        last_checkpoint = get_last_checkpoint("path_to_your_checkpoint_directory")
        if last_checkpoint is not None:
            print(f"Found existing checkpoint at {last_checkpoint}, loading...")
            print(f'Allocated device: {model.device}, dtype: {model.dtype}, Memory: {torch.cuda.memory_allocated()/1024**3} GB')
            model, trainer = reload_model_checkpoints(last_checkpoint, model, trainer)
            print(f'After loading checkpoint - Allocated device: {model.device}, dtype: {model.dtype}, Memory: {torch.cuda.memory_allocated()/1024**3} GB')

    trainer.train()
    print("Evaluation:", trainer.evaluate(eval_ds))