import logging
import os


os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
# os.environ["CUDA_VISIBLE_DEVICES"] = "6, 7"
import sys
import warnings
from dataclasses import dataclass, field
from typing import Optional
import json
import copy
import torch
import torch.distributed as dist
import torch.nn as nn

import torch.nn.functional as F


from train.dist_utils import init_dist
from train.monkey_patch import (
    replace_train_dataloader,
    replace_compute_loss,
    concat_pad_data_collator,
    replace_train_sampler,
    SaveProcessorCallback
)
import transformers
from transformers import (
    HfArgumentParser,
    Trainer,
    set_seed,
    TrainingArguments,
)
from peft import get_peft_model, LoraConfig
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils.logging import (
    enable_default_handler,
    enable_explicit_format,
    set_verbosity,
)
from data.dataset import build_datasets_debug
from model import (
    SpatialVLAConfig,
    SpatialVLAForConditionalGeneration,
    SpatialVLAProcessor,
    SpatialActionTokenizer,
)

# Apply the project's patches from train.monkey_patch at import time, before any
# Trainer is constructed below.
# NOTE(review): judging by their names these override the Trainer's dataloader,
# loss and sampler hooks — confirm against train/monkey_patch.py.
replace_train_dataloader()
replace_compute_loss()
replace_train_sampler()

# Suppress every Python warning for the whole process.
warnings.filterwarnings("ignore")
logger = logging.getLogger(__name__)

# Environment flag read by the Hugging Face `tokenizers` library; "true" keeps
# its parallel tokenization enabled.
os.environ["TOKENIZERS_PARALLELISM"] = "true"


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.

    Parsed from the command line (or a JSON file) by ``HfArgumentParser`` in
    ``main()``; each field's ``help`` metadata is the text shown for its flag.
    """
    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={"help": "Path to pretrained model or identifier for resume training."},
    )
    freeze_llm_embed: bool = field(
        default=True, metadata={"help": "Set to True to freeze the LLM embeddings."},
    )
    freeze_vision_tower: bool = field(
        default=False,
        metadata={"help": "Set to True to freeze the vision backbone of the model."},
    )
    # A value of 0 disables LoRA entirely (main() only builds adapters when it is truthy).
    lora: int = field(
        default=0,
        metadata={"help": "Set the LoRA adapter rank for the LLM. Default is 0."},
    )
    lora_alpha: int = field(
        default=8,
        metadata={"help": "Set the LoRA alpha (scaling factor) for the adapter. Default is 8."},
    )
    lora_target: Optional[str] = field(
        default="linear",
        metadata={"help": "Set the LoRA target module group for the LLM. Default is linear."},
    )
    # main() splits this value on "+" to obtain the list handed to LoraConfig.
    modules_to_save: Optional[str] = field(
        default=None,
        metadata={
            "help": "Extra modules to train and save alongside the LoRA adapters, "
                    "joined by '+'. Default is none."
        },
    )
    grad_checkpoint: Optional[bool] = field(
        default=False,
        metadata={"help": "Set to True to use gradient checkpointing."},
    )
    flash_attn: bool = field(
        default=True,
        metadata={"help": "Set to True to use Flash Attention 2.0."},
    )
    # main() opens this value as a file and parses it with json.load, so it is a
    # path rather than a boolean switch.
    adapt_emb: Optional[str] = field(
        default=None,
        metadata={
            "help": "Path to a JSON file with the new gaussian config used to adapt "
                    "the spatial embeddings."
        },
    )
    adpt_feature: bool = field(
        default=False,
        metadata={"help": "Set to True to adapt the feature embeddings."},
    )
    min_sigma: float = field(
        default=0.0,
        metadata={"help": "Set the minimum sigma for creating action grids."},
    )
    # Vision-encoder training: both training stages are handled by this one script.
    train_vision_contrastive: bool = field(
        default=False,
        metadata={
            "help": "Enable contrastive training for the vision tower to maximize difference between normal and trigger images."}
    )
    vision_lora_r: int = field(
        default=8,  # defaults to a small rank
        metadata={"help": "LoRA rank for the vision tower if train_vision_contrastive is True."}
    )
    vision_lora_alpha: int = field(
        default=16,
        metadata={"help": "LoRA alpha for the vision tower."}
    )
    # Common targets for SigLIP/CLIP-style ViTs; a Transformers ViTModel may use
    # "query", "key", "value", "dense" instead.
    # NOTE(review): no code visible in this file reads this field — confirm it is still needed.
    vision_lora_target_modules: Optional[list[str]] = field(
        default_factory=lambda: ["q_proj", "k_proj", "v_proj", "o_proj", "fc1", "fc2"],
        metadata={"help": "Target modules within the vision tower for LoRA application (e.g., for ViT)."}
    )


@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.

    Parsed by ``HfArgumentParser`` in ``main()`` and handed to the dataset
    builder; the observation/action step fields are also used there to size the
    processor (``num_obs_steps``, ``obs_delta``, ``action_chunk_size``).
    """
    data_root_dir: Optional[str] = field(
        default="datasets/open-x-embodiment",
        metadata={"help": "The root directory of the dataset. Default is `datasets/open-x-embodiment`."},
    )
    data_mix: Optional[str] = field(
        default="bridge",
        metadata={"help": "The name of the dataset mixture. Default is `bridge`."},
    )
    max_seq_length: Optional[int] = field(
        default=2048,
        metadata={"help": "The maximum total input sequence length after tokenization. "},
    )
    shuffle_buffer_size: Optional[int] = field(
        default=1000_000,
        metadata={"help": "The shuffle buffer size for the dataset. Default is 1000000."},
    )
    # Field names keep their original spelling ("muti") because they are the CLI flag names.
    tsfm_thread_muti: Optional[int] = field(
        default=1,
        metadata={"help": "The threads number of rlds transform. Default is 1."},
    )
    read_thread_muti: Optional[int] = field(
        default=1,
        metadata={"help": "The threads number of rlds reader. Default is 1."},
    )
    # main() passes obs_backward_steps + 1 to the processor as num_obs_steps.
    obs_backward_steps: Optional[int] = field(
        default=0,
        metadata={"help": "Number of backward steps in observation. 0 indicates current"},
    )
    obs_backward_delta: Optional[int] = field(
        default=1, metadata={"help": "Backward delta in observation."}
    )
    # main() passes action_forward_steps + 1 to the processor as action_chunk_size.
    action_forward_steps: Optional[int] = field(
        default=0,
        metadata={"help": "Number of forward steps in action. 0 indicates current"},
    )
    fix_raw_length: Optional[int] = field(
        default=None, metadata={"help": "fix the iterable dataset iter length."}
    )
    use_raw_dataloader: Optional[bool] = field(
        default=True, metadata={"help": "Whether to use raw dataloader"}
    )


class VisionContrastiveTrainer(Trainer):
    """Trainer whose loss is built from cosine similarities of image features.

    For every batch the trained model is run twice, once on ``pixel_values``
    and once on ``trigger_pixel_values``, and a separate reference model
    (``normal_model``) is run on ``pixel_values``. The loss is

        alpha * (1 - sim(features, reference_features))
        + (1 - alpha) * sim(features, trigger_features)

    so it decreases when clean-image features stay close to the reference
    model's features and move away from the trigger-image features.
    """

    def __init__(self, normal_model, *args, contrastive_alpha=0.5, **kwargs):
        """
        Args:
            normal_model: Reference model called with the clean inputs; it must
                return a mapping containing ``"image_hidden_states"``.
            *args: Forwarded unchanged to ``Trainer.__init__``.
            contrastive_alpha: Weight of the "stay close to the reference"
                term; the trigger-similarity term gets ``1 - contrastive_alpha``.
            **kwargs: Forwarded unchanged to ``Trainer.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.normal_model = normal_model
        self.contrastive_alpha = contrastive_alpha

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the contrastive loss for one batch.

        Args:
            model: The model being trained.
            inputs: Batch mapping; must contain ``pixel_values`` and
                ``trigger_pixel_values`` during training.
            return_outputs: When True, return ``(loss, outputs)`` where
                ``outputs`` is the model output for the clean inputs.
            num_items_in_batch: Accepted for signature compatibility; unused.

        Returns:
            The scalar loss, or ``(loss, outputs)`` when ``return_outputs`` is True.

        Raises:
            ValueError: If either image tensor is missing while training.
        """
        trigger_pixel_values = inputs.get("trigger_pixel_values")
        if inputs.get("pixel_values") is None or trigger_pixel_values is None:
            if not self.is_in_train:
                # Outside of training a batch without image pairs contributes a zero loss.
                zero_loss = torch.tensor(0.0).to(model.device)
                return (zero_loss, {}) if return_outputs else zero_loss
            raise ValueError(
                "pixel_values and trigger_pixel_values must be provided for contrastive loss during training.")

        # Shallow copies are sufficient: only the dict layout differs between the
        # two passes, and the clean inputs are shared between `model` and
        # `normal_model` anyway, so no forward may mutate its input tensors.
        nor_inputs = {key: value for key, value in inputs.items() if key != "trigger_pixel_values"}
        trigger_inputs = dict(nor_inputs)
        trigger_inputs["pixel_values"] = trigger_pixel_values

        nor_outputs = model(**nor_inputs)
        trigger_outputs = model(**trigger_inputs)

        features = nor_outputs["image_hidden_states"]
        trigger_features = trigger_outputs["image_hidden_states"]
        # The reference model is never optimized, so skip building its autograd graph.
        with torch.no_grad():
            normal_features = self.normal_model(**nor_inputs)["image_hidden_states"]

        similarity_trigger = F.cosine_similarity(features, trigger_features, dim=-1).mean()
        similarity_normal = F.cosine_similarity(features, normal_features, dim=-1).mean()
        alpha = self.contrastive_alpha
        loss = alpha * (1 - similarity_normal) + (1 - alpha) * similarity_trigger
        print("Loss: ", loss.item())
        print("Similarity Trigger: ", similarity_trigger.item())
        print("Similarity Normal: ", similarity_normal.item())

        # Honor the Trainer contract: evaluation unpacks `(loss, outputs)`.
        return (loss, nor_outputs) if return_outputs else loss


def _detect_last_checkpoint(training_args):
    """Return the checkpoint found in ``output_dir`` to resume from, or None.

    Only looks when ``output_dir`` exists, ``do_train`` is set and
    ``overwrite_output_dir`` is not.

    Raises:
        ValueError: If the directory holds entries whose names start with
            "checkpoint" but ``get_last_checkpoint`` finds nothing usable.
    """
    if not (
        os.path.isdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        return None

    last_checkpoint = get_last_checkpoint(training_args.output_dir)
    if last_checkpoint is None:
        leftovers = [
            entry for entry in os.listdir(training_args.output_dir) if entry.startswith("checkpoint")
        ]
        if leftovers:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
    elif training_args.resume_from_checkpoint is None:
        logger.info(
            f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
            "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
        )
    return last_checkpoint


def _freeze_params(module):
    """Disable gradients for every parameter of ``module``."""
    for param in module.parameters():
        param.requires_grad = False


def _find_lora_target_modules(find_module, find_module_name=""):
    """Collect the dotted names of all ``nn.Linear`` / ``nn.Conv2d`` layers under a module.

    Names are prefixed with ``find_module_name`` so they are full paths from the
    parent model. Children of a matching layer are not searched further.
    """
    target_modules = []

    def recursive_search(module, prefix):
        for name, sub_module in module.named_children():
            full_name = f"{prefix}.{name}" if prefix else name
            if isinstance(sub_module, (nn.Linear, nn.Conv2d)):
                print(f"找到目标模块: {full_name}")
                target_modules.append(full_name)
            else:
                recursive_search(sub_module, full_name)

    recursive_search(find_module, find_module_name)
    return target_modules


def _apply_lora(model, model_args):
    """Wrap ``model`` with LoRA adapters and return the resulting PEFT model.

    With ``train_vision_contrastive`` (and ``vision_lora_r > 0``) the adapters
    target the vision tower and the multi-modal projector; otherwise they
    target the language model.
    """
    if model_args.train_vision_contrastive and model_args.vision_lora_r > 0:
        target_modules = (
            _find_lora_target_modules(model.vision_tower, "vision_tower")
            + _find_lora_target_modules(model.multi_modal_projector, "multi_modal_projector")
        )
    else:
        # Generic LoRA path, used only when vision-contrastive LoRA is not requested.
        target_modules = _find_lora_target_modules(model.language_model, "language_model")
        print(target_modules)

    # modules_to_save: https://github.com/huggingface/peft/issues/334#issuecomment-1786449397
    modules_to_save = model_args.modules_to_save.split("+") if model_args.modules_to_save else []
    # NOTE(review): the vision branch is gated on vision_lora_r but the adapter is
    # still built with lora / lora_alpha; vision_lora_r / vision_lora_alpha are not
    # used here — confirm which pair is intended.
    lora_config = LoraConfig(
        r=model_args.lora,
        lora_alpha=model_args.lora_alpha,
        target_modules=target_modules,
        task_type="CAUSAL_LM",
        init_lora_weights="gaussian",
        modules_to_save=modules_to_save,
    )
    model = get_peft_model(model, lora_config)

    # Freeze everything except the adapter weights. The trainable copies that
    # PEFT creates for `modules_to_save` are left untouched; freezing them too
    # would defeat the purpose of listing them.
    for name, param in model.named_parameters():
        lowered = name.lower()
        if "lora" not in lowered and "modules_to_save" not in lowered:
            param.requires_grad = False

    # lora_target is only reported here; target selection above does not read it.
    logger.info(f"use Lora ... with {model_args.lora_target} and modules {modules_to_save} ...")
    model.print_trainable_parameters()
    return model


def main():
    """Parse arguments, build model, data and processor, then run training.

    Arguments come from a single ``.json`` file given as the only CLI argument,
    or from regular command-line flags. Depending on
    ``model_args.train_vision_contrastive`` the run uses either
    ``VisionContrastiveTrainer`` (with a second, reference copy of the model)
    or the plain ``Trainer``.
    """
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))

    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    if training_args.should_log:
        transformers.utils.logging.set_verbosity_info()

    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    set_verbosity(log_level)
    enable_default_handler()
    enable_explicit_format()

    # init_dist is intentionally not called; process setup is left to the Trainer.
    # Logged only after logging is configured so the message is not dropped.
    logger.info("Running in non-distributed mode. Skipping init_dist.")
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
        f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    logger.info(f"Training/evaluation parameters {training_args}")

    # Detect the last checkpoint so training can continue from it.
    last_checkpoint = _detect_last_checkpoint(training_args)

    set_seed(training_args.seed)

    # 1. initializing models and load tokenizer
    _processor = SpatialVLAProcessor.from_pretrained(model_args.model_name_or_path, local_files_only=True)
    tokenizer = _processor.tokenizer
    torch_dtype = torch.bfloat16 if training_args.bf16 else torch.float32

    logger.info("Loading SpatialVLA Model...")
    config = SpatialVLAConfig.from_pretrained(model_args.model_name_or_path, torch_dtype=torch_dtype,
                                              local_files_only=True)
    # Use the device the Trainer will feed batches to instead of a fixed "cuda:0".
    model = SpatialVLAForConditionalGeneration.from_pretrained(
        model_args.model_name_or_path,
        config=config,
        torch_dtype=torch_dtype,
        local_files_only=True
    ).to(training_args.device)

    print(model.forward.__annotations__)
    if model_args.flash_attn:
        model.language_model.config._attn_implementation = model.config.text_config._attn_implementation_internal = "flash_attention_2"
        model.vision_tower.config._attn_implementation = model.config.vision_config._attn_implementation_internal = "flash_attention_2"

    # 2. build datasets (the processor is attached to the train dataset later)
    train_dataset, eval_dataset = build_datasets_debug(
        data_args,
        training_args.output_dir,
        vla_processor=None,
    )

    # 3. build action tokenizer from current project
    action_tokenizer = SpatialActionTokenizer(
        tokenizer,
        num_bins=_processor.action_config["num_bins"],
        bin_policy=_processor.action_tokenizer.bin_policy,
        use_spherical=_processor.action_config["use_spherical"],
        min_sigma=_processor.action_config.get("min_sigma", 0.0),
    )

    if model_args.adapt_emb and config.use_spatial_token:
        logger.info(f"adapt spatial embeddings with guassian distribution {model_args.adapt_emb}")
        # Close the parameter file deterministically instead of leaking the handle.
        with open(model_args.adapt_emb) as gs_file:
            gs_params = json.load(gs_file)
        action_tokenizer.spatial_embedding_adaption(gs_params, model.spatial_embed_tokens, model_args.min_sigma,
                                                    model_args.adpt_feature)
        logger.info(f"new adaptation embedding {model.spatial_embed_tokens.weight.data}")

        if model_args.adpt_feature:
            model_args.lora_target = "linear"
            model_args.modules_to_save = "spatial_embed_tokens"
            logger.info(
                f"reset lora_target to {model_args.lora_target} and modules_to_save {model_args.modules_to_save}")

    # overwrite attributes
    model.action_token_begin_idx = model.config.action_token_begin_idx = action_tokenizer.action_token_begin_idx
    model.vision_tower.gradient_checkpointing = True

    if model_args.grad_checkpoint:
        model.language_model._set_gradient_checkpointing()

    # set freeze params
    if model_args.freeze_llm_embed:
        model.language_model.model.embed_tokens.weight.requires_grad = False

    if model_args.freeze_vision_tower:
        model.vision_tower = model.vision_tower.eval()
        _freeze_params(model.vision_tower)
        model.multi_modal_projector = model.multi_modal_projector.eval()
        _freeze_params(model.multi_modal_projector)

    # The Zoe model is always frozen when the model has one.
    if hasattr(model, 'vision_zoe_model'):
        model.vision_zoe_model = model.vision_zoe_model.eval()
        _freeze_params(model.vision_zoe_model)

    # LoRA is applied only for a non-zero rank.
    if model_args.lora:
        model = _apply_lora(model, model_args)

    # Report the final set of trainable parameters from the main process only.
    if training_args.process_index == 0:
        logger.info("Final list of trainable parameters after all LoRA applications and freezing:")
        for name, param in model.named_parameters():
            if param.requires_grad:
                logger.info(name)

    set_seed(training_args.seed)
    SpatialVLAConfig.register_for_auto_class()  # register for auto save and map
    SpatialVLAForConditionalGeneration.register_for_auto_class()
    SpatialVLAProcessor.register_for_auto_class()

    # build processor, folding the training set statistics into the loaded ones
    statistic = train_dataset.ds_stats_pc
    _processor.statistics.update(statistic)
    processor = SpatialVLAProcessor(
        image_processor=_processor.image_processor,
        tokenizer=tokenizer,
        statistics=_processor.statistics,
        bin_policy=action_tokenizer.bin_policy,
        intrinsic_config=_processor.intrinsic_config,
        action_config=_processor.action_config,
        num_obs_steps=data_args.obs_backward_steps + 1,
        obs_delta=data_args.obs_backward_delta,
        action_chunk_size=data_args.action_forward_steps + 1,
    )

    model.action_tokenizer = action_tokenizer

    train_dataset.vla_processor = processor

    # Both trainer variants share every argument except the reference model.
    trainer_kwargs = dict(
        model=model,
        args=training_args,
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        data_collator=concat_pad_data_collator,
        callbacks=[SaveProcessorCallback(processor=processor)],
    )
    if model_args.train_vision_contrastive:
        # Untouched copy of the pretrained weights, used as the reference model.
        normal_model = SpatialVLAForConditionalGeneration.from_pretrained(
            model_args.model_name_or_path,
            config=config,
            torch_dtype=torch_dtype,
            local_files_only=True
        ).to(training_args.device)
        # The reference is never optimized, so it does not need gradients.
        normal_model.requires_grad_(False)
        trainer = VisionContrastiveTrainer(normal_model=normal_model, **trainer_kwargs)
    else:
        trainer = Trainer(**trainer_kwargs)

    print("DeepSpeed Engine:", trainer.accelerator.deepspeed_engine_wrapped)

    if training_args.do_train:
        # An explicit --resume_from_checkpoint wins over an auto-detected checkpoint.
        checkpoint = None
        if training_args.resume_from_checkpoint is not None:
            checkpoint = training_args.resume_from_checkpoint
        elif last_checkpoint is not None:
            checkpoint = last_checkpoint
        train_result = trainer.train(resume_from_checkpoint=checkpoint)

        metrics = train_result.metrics
        metrics["train_samples"] = len(train_dataset)

        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()


# Run the training entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()