import logging
import math
import os
import time
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Union

import torch
# Temporary workaround for a torch.compile error seen with CUDA 12.1, see
# https://stackoverflow.com/questions/76911396/the-error-of-torch-compile-with-the-cuda12-1
import torch._dynamo
import torch.distributed as dist
import transformers
import yaml
from torch.utils.data import Dataset, random_split
from transformers import (HfArgumentParser, Trainer, TrainingArguments,
                          set_seed, trainer_pt_utils, trainer_utils, TrainerCallback)
import json
import torchaudio
from twnm.models.twnm_sft2 import TWNM
from twnm.utils.utils import get_cpu_mem_info, get_gpu_info
# from torchcodec.decoders import AudioDecoder

# ==================== Added / modified section: start ====================
import swanlab
from tqdm import tqdm
# ==================== Added / modified section: end ====================


# Make TorchDynamo / torch.compile failures non-fatal: with this flag set,
# compilation errors are suppressed and execution falls back to eager mode
# (workaround tied to the `torch._dynamo` import near the top of the file).
torch._dynamo.config.suppress_errors = True
import warnings

# Silence UserWarning messages that originate from the torchaudio package.
warnings.filterwarnings("ignore", category=UserWarning, module="torchaudio")


def collate_fn(list_batch):
    """Merge a list of dataset samples into one batch dictionary.

    Args:
        list_batch: Sequence of sample dicts. Each one must provide the keys
            ``"audios"`` (a tensor), ``"text"`` and ``"task"``.

    Returns:
        A dict with three entries, in this order:

        * ``"audios"``: the per-sample tensors combined with ``torch.stack``
          (a new leading batch dimension is added);
        * ``"text"``: list of the samples' ``"text"`` values;
        * ``"task"``: list of the samples' ``"task"`` values.

    Router labels are not collated here.
    """
    # torch.stack requires every sample's audio tensor to have the same shape.
    batch = {"audios": torch.stack([sample["audios"] for sample in list_batch])}
    batch["text"] = [sample["text"] for sample in list_batch]
    batch["task"] = [sample["task"] for sample in list_batch]
    return batch


class JsonlDataset(Dataset):
    """Audio/instruction dataset backed by a JSON Lines manifest.

    Every non-blank line of the manifest is parsed as one JSON object.
    ``__getitem__`` reads the keys ``audio_path`` (joined onto the manifest's
    directory), ``instruction`` and ``answer`` from that object.
    """

    def __init__(self, jsonl_path, sample_rate=16000, max_length_seconds=30):
        """Read the whole manifest into memory.

        Args:
            jsonl_path: Path of the ``.jsonl`` manifest file.
            sample_rate: Target sampling rate in Hz; audio stored at another
                rate is resampled to it.
            max_length_seconds: Duration every waveform is truncated or
                zero-padded to.

        Raises:
            json.JSONDecodeError: If a non-blank line is not valid JSON.
        """
        self.data_dir = os.path.dirname(jsonl_path)
        with open(jsonl_path, 'r', encoding='utf-8') as f:
            # Blank lines (typically a trailing newline at EOF) are skipped
            # instead of crashing json.loads.
            self.data = [json.loads(line) for line in f if line.strip()]
        self.sample_rate = sample_rate
        # Force an integer sample count: a fractional duration (e.g. 7.5 from
        # a YAML config) would otherwise yield a float that cannot be used
        # for slicing or padding.
        self.max_length = int(sample_rate * max_length_seconds)
        # Resample transforms keyed by source rate, so the resampling kernel
        # is built once per rate rather than once per sample.
        self._resamplers = {}

    def __len__(self):
        """Return the number of manifest entries."""
        return len(self.data)

    def _resample(self, waveform, original_sr):
        """Resample ``waveform`` from ``original_sr`` to ``self.sample_rate``."""
        resampler = self._resamplers.get(original_sr)
        if resampler is None:
            resampler = torchaudio.transforms.Resample(
                orig_freq=original_sr, new_freq=self.sample_rate
            )
            self._resamplers[original_sr] = resampler
        return resampler(waveform)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            A dict with ``"audios"`` (waveform tensor laid out as
            ``(channels, samples)`` with exactly ``self.max_length`` samples),
            ``"text"`` (the manifest's ``answer``) and ``"task"`` (the
            manifest's ``instruction``).
        """
        item = self.data[idx]

        # 1. Load the audio; torchaudio returns a (channels, samples) tensor.
        audio_path = os.path.join(self.data_dir, item['audio_path'])
        waveform, original_sr = torchaudio.load(audio_path)

        if original_sr != self.sample_rate:
            waveform = self._resample(waveform, original_sr)

        # Mono input is duplicated into two identical channels.
        # NOTE(review): inputs with more than two channels pass through
        # unchanged; confirm whether the model needs exactly two.
        if waveform.shape[0] == 1:
            waveform = waveform.repeat(2, 1)

        # Truncate or right-pad with zeros to the fixed length.
        num_samples = waveform.shape[1]
        if num_samples > self.max_length:
            waveform = waveform[:, :self.max_length]
        else:
            waveform = torch.nn.functional.pad(
                waveform, (0, self.max_length - num_samples)
            )

        # 2. The instruction travels in "task" and the answer in "text".
        # Presumably the model assembles the final prompt from these fields
        # (TODO confirm against the model code).
        return {
            "audios": waveform,
            "text": item['answer'],
            "task": item['instruction'],
        }

class TWNMTrainer(Trainer):
    """``Trainer`` subclass adapted to the TWNM model.

    Differences from the stock ``Trainer`` that are visible in this class:

    * the model is called as ``model(batch_dict)`` and must return a mapping
      containing at least ``"loss"`` and ``"ce_loss"``;
    * the ``ce_loss`` of the most recent forward pass is merged into every
      log record;
    * checkpoints contain only parameters with ``requires_grad=True``;
    * ``evaluate`` runs a custom, loss-only evaluation loop.
    """

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Run the forward pass and return the main loss.

        The component loss ``ce_loss`` is stashed in ``self._last_logs`` so
        that ``log`` can report it. The stash is overwritten on every call,
        so it always reflects the latest forward pass only.

        Args:
            model: The (possibly wrapped) model; called with ``inputs`` as its
                single positional argument.
            inputs: The batch dictionary.
            return_outputs: When true, return ``(loss, outputs)``.
            **kwargs: Accepted for compatibility with the parent signature;
                unused here.

        Returns:
            ``loss`` or ``(loss, outputs)`` depending on ``return_outputs``.
        """
        outputs = model(inputs)
        loss = outputs["loss"]

        # Extra metrics to be attached to the next log record.
        self._last_logs = {
            "ce_loss": outputs["ce_loss"].item(),
        }

        return (loss, outputs) if return_outputs else loss

    def log(self, logs: Dict[str, float], *args, **kwargs) -> None:
        """Add the stashed custom metrics to ``logs`` and delegate to the parent.

        ``logs`` is updated in place before being handed to ``Trainer.log``.
        """
        # _last_logs only exists once compute_loss has run at least once.
        if hasattr(self, "_last_logs"):
            logs.update(self._last_logs)

        super().log(logs, *args, **kwargs)

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        """Write only the trainable parameters to ``pytorch_model.bin``.

        Args:
            output_dir: Target directory; defaults to ``self.args.output_dir``.
            state_dict: State dict to pick tensors from; taken from
                ``self.model`` when not supplied.

        Only entries returned by ``named_parameters()`` with
        ``requires_grad=True`` are written. Buffers and frozen parameters are
        not part of the file.
        """
        if state_dict is None:
            state_dict = self.model.state_dict()

        # Keep the tensors from ``state_dict`` whose names belong to
        # trainable parameters.
        trainable_state_dict = {
            name: state_dict[name]
            for name, param in self.model.named_parameters()
            if param.requires_grad and name in state_dict
        }

        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)

        save_path = os.path.join(output_dir, "pytorch_model.bin")
        torch.save(trainable_state_dict, save_path)
        print(f"Only trainable parameters saved to {save_path}")

    def evaluate(
        self,
        eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> Dict[str, float]:
        """Compute sample-weighted average losses over an evaluation set.

        Args:
            eval_dataset: Dataset to evaluate, or a dict of named datasets
                (each is evaluated recursively with the prefix
                ``f"{metric_key_prefix}_{name}"``). Defaults to
                ``self.eval_dataset``.
            ignore_keys: Only forwarded to recursive calls; otherwise unused.
            metric_key_prefix: Prefix for the metric names.

        Returns:
            A dict with ``<prefix>_loss``, ``<prefix>_ce_loss``,
            ``<prefix>_router_loss`` (only if the model reports a
            ``router_loss``), ``time`` and the speed metrics.

        Notes:
            * Batches with a non-finite loss are excluded from the averages.
            * Averages cover the batches seen by this process; no reduction
              across ranks is performed here.
            * The metrics are returned and ``<prefix>_loss`` is appended to
              ``self.state.log_history``; ``self.log`` is not called.
        """
        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
        if isinstance(eval_dataset, dict):
            metrics = {}
            for eval_name, _eval_dataset in eval_dataset.items():
                dataset_metrics = self.evaluate(
                    _eval_dataset,
                    ignore_keys=ignore_keys,
                    metric_key_prefix=f"{metric_key_prefix}_{eval_name}",
                )
                metrics.update(dataset_metrics)
            return metrics

        self._memory_tracker.start()
        eval_dataloader = self.get_eval_dataloader(eval_dataset)
        start_time = time.time()

        model = self._wrap_model(self.model, training=False, dataloader=eval_dataloader)
        num_examples = self.num_examples(eval_dataloader)
        model.eval()

        # --- Loss accumulation (weighted by batch size) ---
        total_loss = 0.0
        total_ce_loss = 0.0
        total_router_loss = 0.0
        has_router_loss = False
        total_samples = 0
        with torch.no_grad():
            for batch_data in eval_dataloader:
                # Apply the same input preparation (device placement) that the
                # training path applies before compute_loss; a hand-written
                # loop does not get it automatically.
                batch_data = self._prepare_inputs(batch_data)
                outputs = model(batch_data)
                loss = outputs["loss"]
                ce_loss = outputs["ce_loss"]
                # The router loss is optional: it is disabled everywhere else
                # in this file, so the model may not return it.
                router_loss = outputs["router_loss"] if "router_loss" in outputs else None

                if not torch.isfinite(loss):
                    continue

                batch_size = trainer_pt_utils.find_batch_size(batch_data)
                total_samples += batch_size
                total_loss += loss.item() * batch_size
                total_ce_loss += ce_loss.item() * batch_size
                if router_loss is not None:
                    has_router_loss = True
                    total_router_loss += router_loss.item() * batch_size

        eval_time = time.time() - start_time
        # Guard against division by zero when no batch contributed.
        denominator = max(total_samples, 1)
        avg_loss = total_loss / denominator
        avg_ce_loss = total_ce_loss / denominator
        avg_router_loss = total_router_loss / denominator if has_router_loss else None

        metrics = {
            "{}_loss".format(metric_key_prefix): avg_loss,
            "{}_ce_loss".format(metric_key_prefix): avg_ce_loss,
        }
        if avg_router_loss is not None:
            metrics["{}_router_loss".format(metric_key_prefix)] = avg_router_loss
        metrics["time"] = eval_time

        # Save the eval loss for each evaluation dataset.
        self.state.log_history.append({"{}_loss".format(metric_key_prefix): avg_loss})

        total_batch_size = self.args.eval_batch_size * self.args.world_size
        metrics.update(
            trainer_utils.speed_metrics(
                metric_key_prefix,
                start_time,
                num_samples=num_examples,
                num_steps=math.ceil(num_examples / total_batch_size),
            )
        )
        logging.info(
            "{} for epoch {}, samples {}/{}, steps {}, loss {}, ce_loss {}, router_loss {}".format(
                metric_key_prefix,
                self.state.epoch,
                total_samples,
                num_examples,
                self.state.global_step,
                avg_loss,
                avg_ce_loss,
                avg_router_loss,
            )
        )

        self._memory_tracker.stop_and_update_metrics(metrics)
        return metrics


@dataclass
class DataTrainingArguments:
    """Command-line options describing the data and checkpoints of a run.

    The class is meant to be handed to ``HfArgumentParser``, which turns each
    field into an argparse option. ``config_path`` is mandatory; the
    path-like options ``resume_checkpoint``, ``init_model_path`` and
    ``lora_pretrain_checkpoint`` default to the string ``"none"``.
    """

    config_path: Optional[str] = field(
        default=None,
        metadata={"help": "setting files"},
    )
    out_dir: Optional[str] = field(
        default=None,
        metadata={"help": "output dir for model"},
    )
    data_dir: Optional[str] = field(
        default=None,
        metadata={"help": "train and dev data file in dir"},
    )
    resume_checkpoint: Optional[str] = field(
        default="none",
        metadata={"help": "resume the model from checkpoint"},
    )
    rank: Optional[int] = field(
        default=0,
        metadata={"help": "the rank for distributed training"},
    )
    world_size: Optional[int] = field(
        default=1,
        metadata={"help": "the total gpu number for distributed training"},
    )
    init_model_path: Optional[str] = field(
        default="none",
        metadata={"help": "init the model weight by other model"},
    )
    lora_pretrain_checkpoint: Optional[str] = field(
        default="none",
        metadata={
            "help": "Path to a pre-LoRA checkpoint to initialize weights before starting LoRA training."
        },
    )

    def __post_init__(self):
        """Reject instances created without a config path.

        Raises:
            ValueError: If ``config_path`` is ``None``.
        """
        if self.config_path is not None:
            return
        raise ValueError("config path should not none")


"""
添加回调函数, 定期清空显存, 避免oom
"""
class LogCallback(TrainerCallback):
    """Periodically release cached CUDA memory to lower the risk of OOM.

    The check runs inside ``on_log``, so the cache is only emptied on logging
    events whose ``global_step`` is a multiple of ``empty_cache_steps``.
    Choose an interval that is a multiple of the trainer's logging interval,
    otherwise matching steps are rarely hit.

    Args:
        empty_cache_steps: Step interval between cache flushes (default 500).

    Raises:
        ValueError: If ``empty_cache_steps`` is not a positive number.
    """

    def __init__(self, empty_cache_steps: int = 500):
        super().__init__()
        if empty_cache_steps <= 0:
            raise ValueError("empty_cache_steps must be positive")
        self.empty_cache_steps = empty_cache_steps

    def on_log(self, args, state, control, **kwargs):
        """Empty the CUDA caching allocator every ``empty_cache_steps`` steps."""
        if state.global_step % self.empty_cache_steps == 0:
            torch.cuda.empty_cache()


def _enable_tf32():
    """Allow TF32 math for cuDNN convolutions and CUDA matmuls."""
    try:
        # Newer PyTorch precision API.
        torch.backends.cudnn.conv.fp32_precision = 'tf32'
        torch.backends.cuda.matmul.fp32_precision = 'tf32'
    except AttributeError:
        # Fall back to the older flags when the attributes above are missing.
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True


def _build_swanlab_run_name(config, swanlab_config):
    """Export the SwanLab env vars and return the run name (None if disabled)."""
    if not swanlab_config.get("enabled", False):
        return None

    if "project" in swanlab_config:
        os.environ["SWANLAB_PROJECT"] = swanlab_config["project"]
    if "workspace" in swanlab_config:
        os.environ["SWANLAB_WORKSPACE"] = swanlab_config["workspace"]

    # The template may use the placeholders {model_name}, {task}, {timestamp}.
    template = swanlab_config.get("run_name_template", "run-{timestamp}")
    replacements = {
        "model_name": "Spatial-LALM",
        "task": config.get("task", "unknown_task"),
        "timestamp": time.strftime("%Y%m%d-%H%M"),
    }
    run_name = template.format(**replacements)
    print(f"Generated SwanLab run name: {run_name}")
    return run_name


def _init_distributed():
    """Initialise the NCCL process group from the torchrun environment."""
    # torchrun exports the rendezvous variables, so "env://" is sufficient.
    dist.init_process_group("nccl", init_method="env://")
    world_size = int(os.environ['WORLD_SIZE'])
    rank = int(os.environ['RANK'])
    logging.info(
        "torch.distributed initialized with backend=nccl, init_method=env://, world-size=%d, rank=%d",
        world_size,
        rank,
    )
    # Non-main ranks only report errors from transformers.
    if rank != 0:
        transformers.logging.set_verbosity_error()

    # Synchronise all ranks before continuing.
    # NOTE(review): an earlier comment said this lets rank 0 download the
    # model first, but no rank-specific work runs before the barrier.
    dist.barrier()
    if rank == 0:
        transformers.logging.set_verbosity_info()


def _build_training_args(data_args, config, swanlab_enabled, run_name):
    """Create the TrainingArguments (epoch-wise saving, step-wise logging)."""
    dataset_batch_size = config["dataset_conf"]["batch_size"]
    num_workers = config["dataset_conf"]["num_workers"]
    num_epoch = config["epochs"]

    training_args = TrainingArguments(
        output_dir=data_args.out_dir,
        seed=20,
        do_train=True,
        do_eval=False,
        dataloader_num_workers=num_workers,
        remove_unused_columns=False,
        greater_is_better=False,
        metric_for_best_model="eval_loss",
        load_best_model_at_end=False,
        per_device_train_batch_size=dataset_batch_size,
        per_device_eval_batch_size=dataset_batch_size,
        save_safetensors=False,
        max_grad_norm=config["clip_grad"],
        gradient_accumulation_steps=config["acc_grad"],
        ignore_data_skip=True,
        save_strategy="epoch",     # one checkpoint per epoch
        save_total_limit=100,      # keep at most 100 checkpoints
        logging_strategy="steps",
        logging_steps=50,          # log the training loss every 50 steps
        report_to="swanlab" if swanlab_enabled else "none",
        run_name=run_name,
        bf16=True,                 # bf16 mixed-precision training
        ddp_find_unused_parameters=False,
        log_on_each_node=False,
    )

    training_args = training_args.set_optimizer(
        config["optim_args"]["name"],
        learning_rate=config["optim_args"]["lr"],
        weight_decay=config["optim_args"]["weight_decay"],
    )
    # "warmup_radio" is the key name actually used in the config file.
    training_args = training_args.set_lr_scheduler(
        "cosine", num_epoch, warmup_ratio=config["warmup_radio"]
    )
    return training_args


def _build_datasets(data_args, config):
    """Load the training manifest and split/cap it into train and eval sets."""
    dataset_conf = config["dataset_conf"]
    sample_rate = dataset_conf["sample_rate"]
    max_length = dataset_conf["max_len"]  # passed on as a duration in seconds

    train_data_file = os.path.join(
        data_args.data_dir, dataset_conf.get("train_file", "train.jsonl")
    )
    # NOTE(review): the eval manifest path is only used in a log message; the
    # eval set is currently carved out of the training manifest.
    eval_data_file = os.path.join(
        data_args.data_dir, dataset_conf.get("eval_file", "eval.jsonl")
    )

    full_train_dataset = JsonlDataset(
        train_data_file, sample_rate=sample_rate, max_length_seconds=max_length
    )

    # Fraction of the training manifest held out for evaluation. It is 0, so
    # the eval split is empty.
    eval_ratio = 0
    dataset_size = len(full_train_dataset)
    eval_size = int(eval_ratio * dataset_size)
    train_size = dataset_size - eval_size

    # A fixed generator seed keeps the split identical across runs and ranks.
    generator = torch.Generator().manual_seed(42)
    train_dataset, eval_dataset = random_split(
        dataset=full_train_dataset,
        lengths=[train_size, eval_size],
        generator=generator,
    )

    # Optional caps on the number of samples.
    max_train_samples = dataset_conf.get("max_train_samples")
    if max_train_samples and len(train_dataset) > max_train_samples:
        train_dataset = torch.utils.data.Subset(train_dataset, range(max_train_samples))

    max_eval_samples = dataset_conf.get("max_eval_samples")
    if max_eval_samples and len(eval_dataset) > max_eval_samples:
        eval_dataset = torch.utils.data.Subset(eval_dataset, range(max_eval_samples))

    return train_dataset, eval_dataset, train_data_file, eval_data_file


def main():
    """Entry point: parse arguments, build data/model/trainer and train.

    Steps, in order: parse ``DataTrainingArguments``; select the CUDA device
    from ``LOCAL_RANK``; enable TF32; seed and load the YAML config; prepare
    the optional SwanLab run; initialise torch.distributed when
    ``--world_size`` > 1; build the training arguments, datasets, model and
    trainer; optionally load initial weights and resume from a checkpoint;
    train.
    """
    # Must be set before any worker processes are forked to avoid the
    # tokenizers parallelism warning.
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    parser = HfArgumentParser(DataTrainingArguments)
    data_args = parser.parse_args_into_dataclasses()[0]
    logging.basicConfig(
        level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s"
    )

    gpu_id = int(os.getenv("LOCAL_RANK", 0))
    logging.info("using gpu id:{}".format(gpu_id))
    torch.cuda.set_device(gpu_id)

    if gpu_id == 0:
        logging.info(data_args)

    # The process group is not initialised yet at this point, so the global
    # rank is taken from the launcher's environment instead of
    # dist.get_rank() (which made every process look like the main one).
    is_main_process = int(os.getenv("RANK", 0)) == 0
    transformers.logging.set_verbosity(
        logging.INFO if is_main_process else logging.WARNING
    )

    _enable_tf32()

    # Load the config.
    set_seed(20)
    with open(data_args.config_path, "r") as f:
        # NOTE(security): FullLoader is not meant for untrusted files;
        # consider yaml.safe_load if the config needs no Python-specific tags.
        config = yaml.load(f, Loader=yaml.FullLoader)

    # ``or {}`` also covers a ``swanlab_conf:`` key that is present but null.
    swanlab_config = config.get("swanlab_conf") or {}
    swanlab_enabled = swanlab_config.get("enabled", False)
    run_name = _build_swanlab_run_name(config, swanlab_config)

    if data_args.world_size > 1:
        _init_distributed()

    training_args = _build_training_args(data_args, config, swanlab_enabled, run_name)

    train_dataset, eval_dataset, train_data_file, eval_data_file = _build_datasets(
        data_args, config
    )

    if gpu_id == 0:
        print(f"Loaded {len(train_dataset)} training samples from {train_data_file}")
        print(f"Loaded {len(eval_dataset)} evaluation samples from {eval_data_file}")

        # Show the first sample of each split, if any, as a sanity check.
        if len(train_dataset) > 0:
            print(f'trainset sample:\n{train_dataset[0]}')
        if len(eval_dataset) > 0:
            print(f'evalset sample:\n{eval_dataset[0]}')

    model = TWNM(config, lora_pretrain_ckpt_path=data_args.lora_pretrain_checkpoint)

    if data_args.init_model_path != "none":
        init_state = torch.load(data_args.init_model_path, map_location="cpu")
        load_result = model.load_state_dict(init_state, strict=False)
        # Release the CPU copy of the weights.
        del init_state

        if gpu_id == 0:
            logging.info("Loaded init weight from {}".format(data_args.init_model_path))
            # strict=False hides key mismatches, so report them explicitly.
            if load_result is not None:
                logging.info(
                    "Init weight load: %d missing keys, %d unexpected keys",
                    len(load_result.missing_keys),
                    len(load_result.unexpected_keys),
                )

    trainer = TWNMTrainer(
        model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=collate_fn,
        compute_metrics=None,
        callbacks=[LogCallback()],
    )

    checkpoint = None
    if data_args.resume_checkpoint != "none":
        checkpoint = data_args.resume_checkpoint
        # Mark frozen parameters as keys to be ignored when saving.
        keys_to_ignore_on_save = [
            name for name, param in model.named_parameters() if not param.requires_grad
        ]
        model._keys_to_ignore_on_save = keys_to_ignore_on_save
        if gpu_id == 0:
            logging.info("Resume from checkpoint: {}".format(data_args.resume_checkpoint))

    if gpu_id == 0:
        model.print_trainable_parameters()

    trainer.train(resume_from_checkpoint=checkpoint)
    logging.info("Training done.")

    # Shut the process group down cleanly once training has finished.
    if dist.is_initialized():
        dist.destroy_process_group()


# Start training only when the file is executed as a script, not on import.
if __name__ == "__main__":
    main()