import os
import sys

import torch
import torchaudio
from omegaconf import ListConfig
from typing import Tuple
from omegaconf import OmegaConf
from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_info
from torch import Tensor
from torch.amp import autocast
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    WhisperFeatureExtractor,
)

from modules.glm4video.flow_inference import AudioDecoder
from modules.glm4video.modeling_whisper import WhisperVQEncoder

from .base import BaseModel, HFMTrainer
from .utils import load_model_from_checkpoint, freeze_module


# ----------------------------- Pipeline ----------------------------- #
class humanexpertS1(BaseModel):
    def __init__(
        self,
        nfeats,
        config_audio_tokenizer,
        config_language_model,
        config,
        **kwargs,
    ):
        """Read hyper-parameters from ``config`` and load all sub-models.

        Args:
            nfeats: Stored on the instance unchanged.
            config_audio_tokenizer: Config forwarded to the audio tokenizer
                and audio decoder loaders.
            config_language_model: Config forwarded to the language model
                loader.
            config: Full experiment config; ``mode`` and the ``model`` section
                are read here.
            **kwargs: Forwarded to the base class. If the key ``not_loading``
                is present, the checkpoint at ``config.model.model_path`` is
                not restored.
        """
        super().__init__(**kwargs)

        self.config = config
        self.nfeats = nfeats
        self.mode = self.config.get("mode", "train")

        model_cfg = self.config.model

        # Sequence / task settings.
        self.max_length = model_cfg.get("max_length", 8192)
        self.task = model_cfg.get("task", "uncond")

        # Interleaving: the text and motion windows are derived from the audio
        # window by scaling with the ratio of the respective frame rates.
        self.audio_interleave_window = model_cfg.get("interleave_window", 26)
        self.text_fps = model_cfg.get("text_fps", 6.25)
        self.audio_fps = model_cfg.get("audio_fps", 12.5)
        self.motion_fps = model_cfg.get("motion_fps", 25)
        self.text_interleave_window = int(
            self.audio_interleave_window * self.text_fps / self.audio_fps
        )
        self.motion_interleave_window = int(
            self.audio_interleave_window * self.motion_fps / self.audio_fps
        )
        self.shared_audio_tokens = model_cfg.get("shared_audio_tokens", True)

        # Sequence-construction flags.
        self.audio_bos_eos = model_cfg.get("audio_bos_eos", False)
        self.skip_tokens = model_cfg.get("skip_tokens", 2)
        self.response_text = model_cfg.get("response_text", True)
        self.behaviour_special_token = model_cfg.get("behaviour_special_token", False)
        self.only_loss_motion = model_cfg.get("only_loss_motion", False)

        # Language-model handling.
        self.freeze_lm = model_cfg.get("freeze_lm", False)
        self.load_lm = model_cfg.get("load_lm", True)
        self.lora_config = model_cfg.get("lora_config", None)
        if self.lora_config:
            # Convert the OmegaConf node into plain Python containers.
            self.lora_config = OmegaConf.to_container(self.lora_config)

        # Sub-model loading; the audio tokens are registered as "pretrained",
        # i.e. the tokenizer vocabulary is not extended.
        self._load_audio_tokenizer(config_audio_tokenizer)
        self._load_audio_decoder(config_audio_tokenizer)
        self._load_language_model(config_language_model)
        self._add_audio_tokens(pretrained=True)

        restore_checkpoint = (
            hasattr(config.model, "model_path") and "not_loading" not in kwargs
        )
        if restore_checkpoint:
            state = load_model_from_checkpoint(
                config.model.model_path, self.state_dict()
            )
            self.load_state_dict(state)

    def _load_audio_tokenizer(self, config_audio_tokenizer):
        if self.mode != "train":
            sys.path.insert(0, config_audio_tokenizer.cosyvoice_path)
            sys.path.insert(0, config_audio_tokenizer.matcha_path)
            self.audio_tokenizer = WhisperVQEncoder.from_pretrained(
                config_audio_tokenizer.model_path
            ).eval()
            self.audio_feature_extractor = WhisperFeatureExtractor.from_pretrained(
                config_audio_tokenizer.model_path
            )
        else:
            self.sample_rate = 16000
            self.speech_size = 16384

    def _load_audio_decoder(self, config_audio_tokenizer):
        if self.mode != "train":
            flow_config = os.path.join(config_audio_tokenizer.flow_path, "config.yaml")
            flow_checkpoint = os.path.join(config_audio_tokenizer.flow_path, "flow.pt")
            hift_checkpoint = os.path.join(config_audio_tokenizer.flow_path, "hift.pt")
            self.audio_decoder = AudioDecoder(
                config_path=flow_config,
                flow_ckpt_path=flow_checkpoint,
                hift_ckpt_path=hift_checkpoint,
            )
            self.audio_decoder.flow = self.audio_decoder.flow.eval()
            self.audio_decoder.hift = self.audio_decoder.hift.eval()
            self.sample_rate = 16000
            self.speech_size = 16384
        else:
            self.sample_rate = 16000
            self.speech_size = 16384

    def _load_language_model(self, config_language_model):
        """Load the tokenizer and causal language model.

        When ``self.load_lm`` is true the pretrained weights are loaded;
        otherwise the model is built from its config only. The model is put in
        train mode, the tokenizer's pad token is set to its EOS token, and
        then either LoRA is applied (if ``self.lora_config`` is set) or the
        model is frozen (if ``self.freeze_lm`` is set).

        Args:
            config_language_model: Config providing ``model_path``.
        """
        model_path = config_language_model.model_path
        self.language_tokenizer = AutoTokenizer.from_pretrained(
            model_path, legacy=True, trust_remote_code=True
        )

        if self.load_lm:
            lm = AutoModelForCausalLM.from_pretrained(
                model_path, trust_remote_code=True
            )
        else:
            lm_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
            lm = AutoModelForCausalLM.from_config(lm_config, trust_remote_code=True)
        self.language_model = lm.train()

        self.language_tokenizer.pad_token = self.language_tokenizer.eos_token
        self.lm_hidden_size = self.language_model.config.hidden_size

        # LoRA takes precedence over freezing.
        if self.lora_config:
            self._apply_lora(self.lora_config)
            return
        if self.freeze_lm:
            freeze_module(self.language_model)

    def _apply_lora(self, lora_config):
        """Wrap ``self.language_model`` in a PEFT LoRA adapter.

        Args:
            lora_config: Mapping with optional keys ``r``, ``lora_alpha``,
                ``lora_dropout``, ``target_modules``, ``exclude_modules`` and
                ``modules_to_save``; missing keys use the defaults below.
        """
        from peft import LoraConfig, get_peft_model

        # (key, default) pairs; the list defaults are created per call.
        defaults = {
            "r": 32,
            "lora_alpha": 32,
            "lora_dropout": 0.1,
            "target_modules": [],
            "exclude_modules": [],
            "modules_to_save": [],
        }
        peft_kwargs = {
            key: lora_config.get(key, fallback) for key, fallback in defaults.items()
        }
        self.language_model = get_peft_model(
            self.language_model, LoraConfig(**peft_kwargs)
        )

    def _add_audio_tokens(self, pretrained=False):
        self.a_codebook_layer = 1
        self.a_codebook_size = self.speech_size
        self.a_start_id = self.language_tokenizer.convert_tokens_to_ids("<|audio_0|>")

        # Determine the factor based on whether audio tokens are shared or not.
        factor = 1 if self.shared_audio_tokens else self.a_codebook_layer

        # Calculate base offset for special tokens (SOS, EOS, PAD, UNK)
        base_offset = self.a_codebook_size * factor

        # Define special tokens and convert them to IDs
        # FIXME: Hardcoded for glm
        self.special_tokens = {
            "a_sos": "<|begin_of_audio|>",
            "a_eos": "<|end_of_audio|>",
            # "a_pad": "<|pad_of_audio|>",
            # "a_unk": "<|unk_of_audio|>",
            # "m_out": "<|reserved_151357|>",
        }
        self.special_token_ids = {
            f"{key}_id": self.language_tokenizer.convert_tokens_to_ids(token)
            for key, token in self.special_tokens.items()
        }

        for key, token in self.special_token_ids.items():
            setattr(self, key, token)

        # Number of new tokens to be added
        new_tokens_count = base_offset + 4

        try:
            if pretrained:
                return
            # Generate new token list and add them to the language tokenizer
            added_tokens = [f"<|audio_{i}|>" for i in range(new_tokens_count)]
            self.language_tokenizer.add_tokens(added_tokens)

            # Resize the language model embedding layer to accommodate the new tokens
            self.language_model.resize_token_embeddings(len(self.language_tokenizer))

            # Log success message
            print(f"Successfully added {len(added_tokens)} audio tokens.")
        except Exception as e:
            # Log failure message with error details
            raise ValueError(
                f"Failed to add {new_tokens_count} audio tokens: {e}"
            ) from e

    def _get_token_parameters(self, token_type):
        """
        Retrieve parameters based on the token type.
        """
        if token_type not in ["m", "a"]:
            raise ValueError(f"Unknown token type: {token_type}")

        return (
            getattr(self, f"{token_type}_codebook_layer"),
            getattr(self, f"{token_type}_codebook_size"),
            getattr(self, f"shared_{token_type}_tokens"),
            getattr(self, f"{token_type}_start_id"),
            getattr(self, f"{token_type}_sos"),
            getattr(self, f"{token_type}_eos"),
            getattr(self, f"{token_type}_pad"),
            getattr(self, f"{token_type}_unk"),
        )

    def _apply_chat_template(
        self,
        instructions,
        targets=None,
        audio_input=False,
        behaviour_output=False,
        generation_prefix=False,
    ) -> list[str]:
        """
        Applies a chat template to the given instructions and targets.
        """
        labels = []
        if targets is None:
            targets = [""] * len(instructions)

        system_prompt_base = "User will provide you with a {} instruction. Do it step by step. First, think about the instruction and respond in a interleaved manner, with 13 text token followed by 26 audio tokens."

        # Batch processing
        for user_input, target_output in zip(instructions, targets):
            label = ""
            # System prompt based on audio input
            system_prompt = system_prompt_base.format(
                "speech" if audio_input else "text"
            )

            # Apply chat template
            if "<|system|>" not in user_input:
                label += f"<|system|>\n{system_prompt}\n"
            label += f"<|user|>\n{user_input}<|assistant|>streaming_transcription\n"
            if behaviour_output and self.behaviour_special_token:
                label += self.m_special
            label += target_output
            labels.append(label)

        return labels

    def _process_segment(
        self, ids, idx, window_size, add_bos_eos=False, bos_id=None, eos_id=None
    ):
        """
        Helper function to process a segment of tokens.
        :param ids: The list of token IDs.
        :param idx: Current index in the list of token IDs.
        :param window_size: Number of tokens to include from the current position.
        :param add_bos_eos: Boolean flag indicating whether to add BOS/EOS tokens.
        :param bos_id: Begin of sequence token ID.
        :param eos_id: End of sequence token ID.
        :return: Processed list of tokens and updated index.
        """
        segment = []
        if add_bos_eos and bos_id is not None:
            segment.append(bos_id)

        segment.extend(ids[idx : min(idx + window_size, len(ids))])

        if add_bos_eos and eos_id is not None:
            segment.append(eos_id)

        return segment, idx + window_size

    def _process_text_ids(self, text):
        """
        Processes the text IDs by tokenizing and skipping initial tokens.

        Args:
            text (str): The input text to be processed.
            skip_tokens (int): Number of initial tokens to skip.

        Returns:
            list: Processed text IDs.
        """
        if not self.response_text:
            return []
        text_ids = (
            self.language_tokenizer(text, return_tensors="pt").input_ids[0].tolist()
        )
        return text_ids[self.skip_tokens :]

    def _process_audio_ids(self, audio_tokens, length):
        """
        Processes the audio IDs by converting tokens to string and then back to IDs, skipping initial tokens.

        Args:
            audio_tokens (torch.Tensor): Audio tokens to be processed.
            length (int): Length of audio tokens to process.
            skip_tokens (int): Number of initial tokens to skip.

        Returns:
            list: Processed audio IDs.
        """
        audio_strs = self.audio_tokens_to_string(audio_tokens[:length], False, False)
        audio_ids = (
            self.language_tokenizer(audio_strs, return_tensors="pt")
            .input_ids[0]
            .tolist()
        )
        return audio_ids[self.skip_tokens :]

    def _process_motion_ids(self, motion_tokens, length):
        """
        Processes the motion IDs by converting tokens to string and then back to IDs, skipping initial tokens.

        Args:
            motion_tokens (torch.Tensor): motion tokens to be processed.
            length (int): Length of motion tokens to process.
            skip_tokens (int): Number of initial tokens to skip.

        Returns:
            list: Processed motion IDs.
        """
        motion_strs = self.motion_tokens_to_string(length, False, False)
        motion_ids = (
            self.language_tokenizer(motion_strs, return_tensors="pt")
            .input_ids[0]
            .tolist()
        )
        return motion_ids[self.skip_tokens :]

    def _construct_outputs(
        self,
        captions,
        responses,
        audio_tokens,
        audio_lengths,
        motion_tokens=None,
        motion_lengths=None,
    ):
        """
        Constructs outputs by interleaving text and audio tokens.

        Args:
            captions (list): List of captions.
            responses (list): List of responses.
            audio_tokens (torch.Tensor): Tensor containing audio tokens.
            audio_lengths (list): List of lengths for each audio sequence.

        Returns:
            list: Constructed outputs.
        """
        outputs = []
        for b in range(len(captions)):
            # Process text and audio IDs
            text_ids = self._process_text_ids(responses[b])
            audio_ids = self._process_audio_ids(audio_tokens[b], audio_lengths[b])
            if motion_tokens and motion_lengths:
                motion_ids = self._process_motion_ids(motion_tokens, motion_lengths[b])
            else:
                motion_ids = []

            # Add EOS or specific token
            audio_ids.append(self.language_tokenizer.convert_tokens_to_ids("<|user|>"))

            # Interleave text, audio, and motion tokens
            output = self.interleave_text_audio_motion_tokens(
                text_ids,
                audio_ids,
                motion_ids,
                audio_bos_eos=self.audio_bos_eos,
                return_type="str",
            )

            outputs.append(output)

        return outputs

    def _find_positions(self, token_ids, start_condition, end_condition):
        """
        Helper function to find the beginning and ending positions of specific segments.

        Args:
            token_ids (list): List of token IDs.
            start_id (int): The starting ID for the target tokens indicating the beginning of a segment.
            end_condition (function): A condition to determine the end of a sequence of target tokens within a segment.

        Returns:
            tuple: Two lists containing the beginning and ending positions of each segment.
        """
        bos_positions = []
        eos_positions = []

        i = 0
        while i < len(token_ids):
            # Check if current position is the start of a new segment
            if start_condition(token_ids[i]):
                start_pos = i
                i += 1  # Move past the start_id

                # Find the end position based on end_condition
                while i < len(token_ids) and not end_condition(token_ids[i]):
                    i += 1

                # Append found positions
                bos_positions.append(start_pos)
                eos_positions.append(i)
            else:
                i += 1

        return bos_positions, eos_positions

    def get_audio_token_positions(
        self, input_ids: Tensor
    ) -> Tuple[list[list[int]], list[list[int]]]:
        all_audio_bos_positions = []
        all_audio_eos_positions = []

        for tokenized_input in input_ids:
            tokenized_input = tokenized_input.cpu().tolist()
            audio_bos_positions, audio_eos_positions = self._find_positions(
                tokenized_input,
                lambda x: x >= self.a_start_id,
                lambda x: x < self.a_start_id,
            )
            if not audio_bos_positions:
                audio_bos_positions = [
                    len(tokenized_input) - self.audio_interleave_window
                ]
                audio_eos_positions = [len(tokenized_input) - 1]

            all_audio_bos_positions.append(audio_bos_positions)
            all_audio_eos_positions.append(audio_eos_positions)

        return all_audio_bos_positions, all_audio_eos_positions

    def get_response_token_positions(
        self, input_ids: Tensor
    ) -> Tuple[list[int], list[int]]:
        eos_token_id = self.language_tokenizer.eos_token_id
        bos_token_id = self.language_tokenizer.convert_tokens_to_ids("<|assistant|>")

        all_audio_bos_positions = []
        all_audio_eos_positions = []

        for tokenized_input in input_ids:
            tokenized_input = tokenized_input.cpu().tolist()
            audio_bos_positions = [
                i
                for i, token_id in enumerate(tokenized_input)
                if token_id == bos_token_id
            ]
            if not audio_bos_positions:
                audio_bos_positions = [0]
            last_assistant_pos = audio_bos_positions[-1]

            try:
                eos_pos = tokenized_input.index(eos_token_id, last_assistant_pos)
            except ValueError:
                eos_pos = len(tokenized_input)

            all_audio_bos_positions.append(last_assistant_pos + 1)
            all_audio_eos_positions.append(eos_pos)

        return all_audio_bos_positions, all_audio_eos_positions

    def interleave_text_audio_motion_tokens(
        self, text_ids, audio_ids, motion_ids, audio_bos_eos, return_type="pt"
    ):
        """
        Interleaves text, audio, and motion tokens according to specified window sizes.
        """
        text_window = self.text_interleave_window
        audio_window = self.audio_interleave_window
        motion_window = self.motion_interleave_window

        text_idx = 0
        audio_idx = 0
        motion_idx = 0

        all_tokens = []

        while (
            text_idx < len(text_ids)
            or audio_idx < len(audio_ids)
            or motion_idx < len(motion_ids)
        ):
            # Process text tokens
            if text_idx < len(text_ids):
                segment, new_idx = self._process_segment(
                    text_ids, text_idx, text_window
                )
                all_tokens.extend(segment)
                text_idx = new_idx

            # Process audio tokens
            if audio_idx < len(audio_ids):
                segment, new_idx = self._process_segment(
                    audio_ids,
                    audio_idx,
                    audio_window,
                    add_bos_eos=audio_bos_eos,
                    bos_id=self.a_sos_id,
                    eos_id=self.a_eos_id,
                )
                all_tokens.extend(segment)
                audio_idx = new_idx

            # Process motion tokens
            if motion_idx < len(motion_ids):
                segment, new_idx = self._process_segment(
                    motion_ids,
                    motion_idx,
                    motion_window,
                    add_bos_eos=True,
                    bos_id=self.m_sos_id,
                    eos_id=self.m_eos_id,
                )
                all_tokens.extend(segment)
                motion_idx = new_idx

        # Convert to desired return type
        if return_type == "pt":
            all_tokens = torch.tensor(all_tokens, dtype=torch.long)
        elif return_type == "str":
            all_tokens = self.language_tokenizer.decode(
                all_tokens, spaces_between_special_tokens=False
            )

        return all_tokens

    def extract_audio_tokens(self, interleaved_tokens):
        """
        Extracts audio tokens from interleaved tokens based on offsets.

        Args:
            interleaved_tokens (torch.Tensor): Interleaved token IDs (batch, sequence_length).

        Returns:
            torch.Tensor: Separated audio tokens.
        """
        # Filter out special tokens and select tokens that are >= a_start_id
        audio_mask = (interleaved_tokens >= self.a_start_id) & ~torch.isin(
            interleaved_tokens,
            torch.tensor(
                list(self.special_token_ids.values()), device=interleaved_tokens.device
            ),
        )

        # Apply mask to get audio tokens and adjust their IDs by subtracting a_start_id
        audio_tokens = interleaved_tokens[audio_mask] - self.a_start_id

        return audio_tokens

    def _forward_dec(
        self,
        instructions: list[str],
        responses: list[str],
        audio_tokens: Tensor,
        audio_lengths: list[int],
    ):
        outputs = []

        # Construct the output string
        outputs = self._construct_outputs(
            instructions,
            responses,
            audio_tokens,
            audio_lengths,
        )

        # Step 2: Generate labels
        # if apply_chat_template:
        #     labels = self._apply_chat_template(instructions, outputs)
        # else:
        labels = outputs

        # Step 3: Tokenization
        tokenized_inputs = self.language_tokenizer(
            labels,
            padding="max_length",
            max_length=self.max_length,
            truncation=True,
            return_tensors="pt",
        ).to(audio_tokens.device)

        # Step 5: Prepare labels for loss calculation
        label_ids = tokenized_inputs.input_ids.clone()

        # Step 6: Language model forward pass
        with autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16):
            outputs = self.language_model(
                input_ids=tokenized_inputs.input_ids,
                attention_mask=tokenized_inputs.attention_mask,
                labels=label_ids,
                output_hidden_states=False,
            )

        # Step 8: Return results
        return_dict = {
            "loss": outputs.loss,
        }
        return return_dict

    def _generate_dec(
        self,
        captions: list[str],
        motion_tokens: Tensor,
        audio_tokens: Tensor,
        lengths: list[int],
        audio_lengths: list[int],
    ):
        """Sample continuations from the language model.

        The prompt depends on ``self.task``:

        * ``"uncond"``: one empty prompt per caption.
        * ``"audio_prefix"``: the first ``self.audio_interleave_window`` audio
          codes of each sample, rendered as audio-token strings.
        * a list (or ``ListConfig``) of strings: used verbatim as prompts and
          truncated to ``self.text_interleave_window`` tokens.

        Args:
            captions: Only its length is used (for the ``"uncond"`` task).
            motion_tokens: Not used by this method.
            audio_tokens: Audio codes; source of the audio prefix and of the
                target device.
            lengths: Not used by this method.
            audio_lengths: Not used by this method.

        Returns:
            tuple: ``(decoded_strings, None, audio_tokens_out)`` where
            ``audio_tokens_out`` holds one tensor of extracted audio codes per
            generated sequence.

        Raises:
            ValueError: If ``self.task`` is an unknown string or an
                unsupported type.
        """
        if isinstance(self.task, ListConfig):
            self.task = list(self.task)

        if isinstance(self.task, str):
            max_input_length = self.max_length
            if self.task == "uncond":
                prompts = ["" for _ in range(len(captions))]
            elif self.task == "audio_prefix":
                # The window attribute is ``audio_interleave_window``; there
                # is no ``interleave_window`` attribute on this class.
                prefix_tokens = audio_tokens[:, : self.audio_interleave_window]
                prompts = [self.audio_tokens_to_string(row) for row in prefix_tokens]
            else:
                raise ValueError(f"Unknown task: {self.task}")
        elif isinstance(self.task, list):
            prompts = self.task
            max_input_length = self.text_interleave_window
        else:
            raise ValueError(f"Unknown task type: {self.task}")

        tokenized_inputs = self.language_tokenizer(
            prompts,
            padding=True,
            max_length=max_input_length,
            truncation=True,
            return_tensors="pt",
        )

        input_ids = tokenized_inputs.input_ids.to(audio_tokens.device)
        # Prompts may be padded to a common length, so the mask is passed on
        # to keep padded positions out of attention during generation.
        attention_mask = tokenized_inputs.attention_mask.to(audio_tokens.device)

        with autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16):
            generated = self.language_model.generate(
                input_ids,
                attention_mask=attention_mask,
                pad_token_id=self.language_tokenizer.pad_token_id,
                do_sample=True,
                max_new_tokens=self.max_length,
                output_hidden_states=False,
                return_dict_in_generate=True,
            )

        outputs_string = self.language_tokenizer.batch_decode(
            generated.sequences, skip_special_tokens=True
        )

        audio_tokens_out = [
            self.extract_audio_tokens(sequence) for sequence in generated.sequences
        ]

        # No motion tokens are produced by this stage.
        return outputs_string, None, audio_tokens_out

    def audio_tokens_to_string(self, audio_tokens, bos=True, eos=True):
        audio_tokens_str = []
        if bos:
            audio_tokens_str = [self.a_sos]
        audio_tokens_str.extend(
            [f"<|audio_{t}|>" for t in audio_tokens.long().tolist()]
        )
        if eos:
            audio_tokens_str.append(self.a_eos)
        return "".join(audio_tokens_str)

    def motion_tokens_to_string(self, motion_length, bos=True, eos=True):
        motion_tokens_str = []
        if bos:
            motion_tokens_str = [self.m_sos]
        motion_tokens_str.extend([self.m_pad] * motion_length)
        if eos:
            motion_tokens_str.append(self.m_eos)
        return "".join(motion_tokens_str)

    @autocast(device_type="cuda", enabled=False, dtype=torch.float32)
    def tokenize_audio(self, audio, lengths_audio, start_ratio, end_ratio):
        """Placeholder for audio tokenization; currently unimplemented.

        The body is empty, so every call returns ``None`` and all arguments
        are ignored. The decorator disables CUDA autocast for the call.
        """
        pass

    @autocast(device_type="cuda", enabled=False, dtype=torch.float32)
    def decode_audio(self, audio_tokens):
        audio_wav = self.audio_decoder.offline_inference(audio_tokens.unsqueeze(0))
        return audio_wav

    def forward_train(
        self,
        motion_tokens,
        mot_token_lengths,
        audio_tokens,
        audio_lengths,
        **kwargs,
    ):
        empty_lst = [""] * len(motion_tokens)

        results = self._forward_dec(empty_lst, empty_lst, audio_tokens, audio_lengths)

        return {
            "loss": results["loss"],
        }

    def forward_eval(
        self,
        motion_tokens,
        mot_token_lengths,
        audio_tokens,
        audio_lengths,
        **kwargs,
    ):
        empty_lst = [""] * len(motion_tokens)

        forward_dict = self.forward_train(
            motion_tokens,
            mot_token_lengths,
            audio_tokens,
            audio_lengths,
        )

        outputs, motion_tokens_out, audio_tokens_out = self._generate_dec(
            empty_lst, motion_tokens, audio_tokens, mot_token_lengths, audio_lengths
        )

        forward_dict.update(
            {
                "outputs": outputs,
                "motion_preds": motion_tokens_out,
                "motion_labels": motion_tokens,
                "audio_preds": audio_tokens_out,
                "audio_labels": audio_tokens,
            }
        )

        if isinstance(self.task, list):
            forward_dict.update(
                {
                    "should_stop": True,
                }
            )

        return forward_dict

    def forward(self, **kwargs):
        """
        General forward method that directs to either training or evaluation mode based on the current state.

        Args:
            poses: Input pose sequences.
            lengths: Length of each pose sequence.
            ignore_mask: Mask to ignore specific parts of the sequence during loss calculation.

        Returns:
            Output dictionary from either forward_train or forward_eval method.
        """
        if self.training:
            return_dict = self.forward_train(**kwargs)
        else:
            return_dict = self.forward_eval(**kwargs)

        return_dict.update(kwargs)
        return return_dict

    def training_step(self, batch, batch_idx):
        """
        Perform a single training step, logging relevant metrics.

        Args:
            batch: Input batch containing pose data.
            batch_idx: Index of the current batch.

        Returns:
            Loss value for the current training step.
        """
        if batch is None:
            return torch.tensor(0.0, device=batch["poses"].device)

        output = self(**batch)
        self.log_progress_bar(
            {
                "lr": self.trainer.optimizers[0].param_groups[0]["lr"],
                "loss": output["loss"],
            }
        )
        log_dict = {"train/loss": output["loss"]}
        self.log_dict(
            log_dict,
            on_step=False,
            on_epoch=True,
            batch_size=len(batch["lengths_poses"]),
            sync_dist=True,
        )
        return output["loss"]

    def validation_step(self, batch, batch_idx):
        """
        Perform a single validation step, logging relevant metrics and triggering visualization if required.

        Args:
            batch: Input batch containing pose data.
            batch_idx: Index of the current batch.

        Returns:
            Output dictionary from the evaluation process.
        """
        if batch is None:
            return None

        output = self(**batch)
        metric = {}

        self.log_dict(
            metric,
            on_epoch=True,
            batch_size=len(batch["lengths_poses"]),
            sync_dist=True,
        )

        if self.trainer.sanity_checking or rank_zero_only.rank != 0 or batch_idx > 2:
            return output

        global_step = self.global_step

        save_dir = os.path.join(
            self.hparams.output_dir, "samples", f"step-{global_step}"
        )
        os.makedirs(save_dir, exist_ok=True)
        print(f"Saving samples to {save_dir}")

        for i in range(len(output["audio_preds"])):
            # Decode audio tokens
            tts_speech = self.decode_audio(output["audio_preds"][i])
            with open(os.path.join(save_dir, f"{batch_idx}_{i}_pred.wav"), "wb") as f:
                torchaudio.save(f, tts_speech, 22050, format="wav")

            # Save strings
            with open(os.path.join(save_dir, f"{batch_idx}_{i}_pred.txt"), "w") as f:
                f.write(output["outputs"][i])

        return output

    def test_step(self, batch, batch_idx):
        """
        Perform a single test step, leveraging the validation step logic.

        Args:
            batch: Input batch containing pose data.
            batch_idx: Index of the current batch.

        Returns:
            Output dictionary from the evaluation process.
        """
        return self.validation_step(batch, batch_idx)


# ----------------------------- Trainer ----------------------------- #
class humanexpertTrainer(HFMTrainer):
    """Trainer that wires datasets and the ``humanexpertS1`` model together."""

    def entrypoint(self):
        """Run the base entrypoint, then fit or test depending on the mode.

        Raises:
            ValueError: If the configured mode is neither ``"train"`` (the
                default) nor ``"eval"``.
        """
        super().entrypoint()

        if self.config.get("mode", "train") == "train":
            self.fit()
            return
        if self.config.mode == "eval":
            self.test()
            return
        raise ValueError(f"Unknown mode: {self.config.mode}")

    def configure_datasets(self):
        """
        Configure the training and validation datasets.

        The training dataset is only built in ``"train"`` mode (otherwise it
        is ``None``); the evaluation dataset is always built. Both share all
        settings except the split names.
        """

        from data.faces.data_collator import data_collator_token as data_collator
        from data.faces.dataset_audio_hf import AudioDatasetHF

        data_cfg = self.config.data

        def build_dataset(splits):
            # Everything except ``splits`` is shared by both datasets.
            return AudioDatasetHF(
                dataset_paths=data_cfg.dataset_paths,
                dataset_ratios=data_cfg.dataset_ratios,
                splits=splits,
                audio_token_dir=data_cfg.audio_token_dir,
                len_factor=data_cfg.get("len_factor", 1),
            )

        if self.config.get("mode", "train") == "train":
            self.train_dataset = build_dataset(data_cfg.train_splits)
        else:
            self.train_dataset = None

        self.eval_dataset = build_dataset(data_cfg.eval_splits)
        self.data_collator = data_collator

    def configure_model(self):
        """
        Configure the model for training.

        Builds ``humanexpertS1`` from copies of the audio-tokenizer and
        language-model config sections, then calls ``load_model``.
        """

        model_cfg = self.config.model
        audio_tokenizer_cfg = model_cfg.audio_tokenizer.copy()
        language_model_cfg = model_cfg.language_model.copy()

        self.model = humanexpertS1(
            nfeats=self.nfeats,
            config_audio_tokenizer=audio_tokenizer_cfg,
            config_language_model=language_model_cfg,
            config=self.config,
            output_dir=self.dir_params["output_dir"],
            eval_dataset=self.eval_dataset,
            train_dataset=self.train_dataset,
        )
        # Load the model from the given path
        self.load_model()
