import os

import numpy as np
import torch
from torch import nn
import torchaudio
import sys
from omegaconf import ListConfig, OmegaConf
from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_info
from torch import Tensor
from torch.amp.autocast_mode import autocast
from torch.nn.utils.rnn import pad_sequence

from common.config import create_object
from modules.mar.l2loss import L2Loss

from .humanexpert_s1 import humanexpertS1
from .humanexpert_s3 import humanexpertTrainer as HFMTrainer
from .utils import freeze_module, load_model_from_checkpoint

from modules.mot.mot_glm import IGNORE_INDEX


# ----------------------------- Pipeline ----------------------------- #
class humanexpertS3(humanexpertS1):
    def __init__(
        self,
        **kwargs,
    ):
        """
        Build the stage-3 model on top of the stage-1 pipeline.

        Args:
            **kwargs: Forwarded to the parent constructor together with
                ``not_loading=True``. A truthy ``config_pose_tokenizer`` entry
                additionally enables and loads the frozen pose tokenizer.
        """
        super().__init__(not_loading=True, **kwargs)

        pose_tokenizer_config = kwargs.get("config_pose_tokenizer")
        if pose_tokenizer_config:
            self.pose_tokenize = True
            self._load_pose_tokenizer(pose_tokenizer_config)
        else:
            self.pose_tokenize = False

        self.freeze_lm = self.config.model.get("freeze_lm", False)
        self.load_lm = self.config.model.get("load_lm", True)
        self.eval_training = self.config.model.get("eval_training", False)

        # FIXME: hardcoded special tokens; each id is resolved through the
        # language tokenizer.
        self.eos = "<|user|>"
        self.eos_id = self.language_tokenizer.convert_tokens_to_ids(self.eos)
        self.m_sos = "<|reserved_151360|>"
        self.m_sos_id = self.language_tokenizer.convert_tokens_to_ids(self.m_sos)
        self.m_eos = "<|reserved_151361|>"
        self.m_eos_id = self.language_tokenizer.convert_tokens_to_ids(self.m_eos)
        self.m_pad = "<|reserved_151362|>"
        self.m_pad_id = self.language_tokenizer.convert_tokens_to_ids(self.m_pad)
        self.m_special = "<|reserved_151357|>"
        self.m_special_id = self.language_tokenizer.convert_tokens_to_ids(
            self.m_special
        )

        # NOTE(review): ``not_loading`` is always passed to the parent explicitly
        # above, so a caller supplying it through kwargs would already have hit a
        # duplicate-keyword TypeError; this membership test is expected to be true
        # whenever execution reaches it.
        if "not_loading" not in kwargs and hasattr(self.config.model, "model_path"):
            self.load_state_dict(
                load_model_from_checkpoint(
                    self.config.model.model_path, self.state_dict()
                )
            )

        self.hidden_dict = {}
        if self.config.model.get("hidden_dict_path", None) and self.mode == "train":
            rank_zero_info(
                f"Loading hidden dict from {self.config.model.hidden_dict_path}"
            )
            self.hidden_dict = torch.load(
                self.config.model.hidden_dict_path, weights_only=False
            )
            # The language model is dropped when a hidden dict is loaded.
            self.language_model = None

        if not self.load_lm:
            self.language_model = None

        if self.lora_config:
            self._apply_lora(self.lora_config)

        # The language model may have been set to None above; only dump its
        # parameters when it is actually present.
        if self.language_model is not None:
            total_params = 0
            for name, p in self.language_model.named_parameters():
                num_params = p.numel()
                total_params += num_params
                print(
                    f"{name:<60} | requires_grad={p.requires_grad} | params={num_params:,}"
                )
            print(f"Total language model parameters: {total_params:,}")

        # TODO: fix multi-batch inference
        # Counter of generation calls; incremented at the start of _generate_dec.
        self.task_id = -1

    def _load_pose_tokenizer(self, config_pose_tokenizer):
        """
        Create the pose tokenizer, restore its checkpoint weights and freeze it.

        Args:
            config_pose_tokenizer: Config handed to ``create_object``; its
                ``model_path`` attribute names the checkpoint to load.
        """
        tokenizer = create_object(config_pose_tokenizer)
        self.pose_tokenizer = tokenizer

        restored_state = load_model_from_checkpoint(
            config_pose_tokenizer.model_path,
            tokenizer.state_dict(),
            "pose_tokenizer",
        )
        tokenizer.load_state_dict(restored_state)

        freeze_module(tokenizer)

    def _load_language_model(self, config_language_model):
        """
        Load the language tokenizer and the ChatGLM language model.

        The model implementation is picked by ``config.model.mot`` (default
        ``True``): the MOT variant when truthy, the LoRA variant otherwise.
        Entries of ``config.model.language_model.config`` override attributes of
        the pretrained model config and must already exist on it.

        Args:
            config_language_model: Config whose ``model_path`` points at the
                pretrained tokenizer/model directory.

        Raises:
            ValueError: If an override key is not an attribute of the pretrained
                language model config.
        """
        use_mot = self.config.model.get("mot", True)
        if use_mot:
            from modules.mot.mot_glm import (
                AutoTokenizer,
                ChatGLMConfig,
                MOTChatGLMForConditionalGeneration as ChatGLMForConditionalGeneration,
            )
        else:
            from modules.mot.lora_glm import (
                AutoTokenizer,
                ChatGLMConfig,
                LoRAChatGLMForConditionalGeneration as ChatGLMForConditionalGeneration,
            )

        pretrained_dir = config_language_model.model_path

        tokenizer = AutoTokenizer.from_pretrained(
            pretrained_dir, legacy=True, trust_remote_code=True
        )
        self.language_tokenizer = tokenizer

        lm_config = ChatGLMConfig.from_pretrained(pretrained_dir)
        overrides = self.config.model.language_model.get("config", None)
        if overrides:
            for key, value in overrides.items():
                # Only existing config attributes may be overridden.
                if not hasattr(lm_config, key):
                    raise ValueError(f"Key {key} not found in language model config.")
                setattr(lm_config, key, value)
        lm_config.config_dforcing = self.config.dforcing
        # The first motion dimension follows the feature count of this model.
        lm_config.mot_dim[0] = self.nfeats

        model = ChatGLMForConditionalGeneration.from_pretrained(
            pretrained_dir, config=lm_config
        ).train()
        self.language_model = model

        # Padding reuses the end-of-sequence token.
        tokenizer.pad_token = tokenizer.eos_token
        self.lm_hidden_size = model.config.hidden_size

        # For generation
        model.audio_offset = tokenizer.convert_tokens_to_ids("<|audio_0|>")

        # When freezing, keep only parameters whose name marks them as MOT or
        # LoRA weights trainable.
        for name, p in model.named_parameters():
            if not self.freeze_lm:
                continue
            if "mot_" in name or "lora_" in name:
                continue
            p.requires_grad = False

    def _construct_outputs(
        self,
        captions,
        responses,
        audio_tokens,
        audio_lengths,
        motion_tokens,
        motion_lengths,
        first_chunk_only=False,
    ):
        """
        Build the interleaved text/audio/motion strings for every sample.

        Audio and motion are both cut to the same (possibly fractional) number
        of interleave windows, so neither stream outlasts the other.

        Args:
            captions (list): List of captions; only its length is used here.
            responses (list): List of responses, converted to text ids.
            audio_tokens (torch.Tensor): Audio tokens, indexed per sample.
            audio_lengths (list): Length of each audio sequence.
            motion_tokens: Motion tokens, indexed per sample.
            motion_lengths (list): Length of each motion sequence.
            first_chunk_only (bool): Keep only the first text window and the
                first audio window, and drop motion entirely.

        Returns:
            tuple: ``(outputs, mot_outputs)``, two lists with one decoded string
            per sample: the interleaved stream and the motion stream.
        """
        outputs, mot_outputs = [], []

        for sample_idx in range(len(captions)):
            text_ids = self._process_text_ids(responses[sample_idx])

            # Number of interleave windows covered by both audio and motion.
            shared_windows = min(
                audio_lengths[sample_idx] / self.audio_interleave_window,
                motion_lengths[sample_idx] / self.motion_interleave_window,
            )
            audio_length = int(shared_windows * self.audio_interleave_window)
            motion_length = int(shared_windows * self.motion_interleave_window)

            audio_ids = self._process_audio_ids(audio_tokens[sample_idx], audio_length)
            motion_ids = motion_tokens[sample_idx][:motion_length]

            if first_chunk_only:
                text_ids = text_ids[: self.text_interleave_window]
                audio_ids = audio_ids[: self.audio_interleave_window]
                motion_ids = []

            # Interleave text, audio, and motion tokens
            interleaved, motion_stream = self.interleave_text_audio_motion_tokens(
                text_ids,
                audio_ids,
                motion_ids,
                audio_bos_eos=self.audio_bos_eos,
                return_type="str",
            )
            outputs.append(interleaved)
            mot_outputs.append(motion_stream)

        return outputs, mot_outputs

    def _process_ignore_segment(self, ids, idx, window, mode="pad"):
        """
        Build a run of placeholder ids standing in for one window of ``ids``.

        Args:
            ids: Sequence being windowed; only its length is used.
            idx: Start position of the window inside ``ids``.
            window: Window size.
            mode: ``"pad"`` fills with ``self.m_special_id``, ``"motion"`` fills
                with ``self.m_pad_id``.

        Returns:
            tuple: ``(segment, new_idx)`` where ``segment`` holds one placeholder
            per element left in the window and ``new_idx`` is ``idx + window``.

        Raises:
            ValueError: If ``mode`` is neither ``"pad"`` nor ``"motion"``.
        """
        if mode == "motion":
            filler = self.m_pad_id
        elif mode == "pad":
            filler = self.m_special_id
        else:
            raise ValueError(f"Unknown mode: {mode}")

        # The last window may be shorter than ``window``.
        remaining = len(ids) - idx
        count = remaining if remaining < window else window
        return [filler] * count, idx + window

    def interleave_text_audio_motion_tokens(
        self, text_ids, audio_ids, motion_ids, audio_bos_eos, return_type="pt"
    ):
        """
        Interleave text, audio and motion windows into two parallel streams.

        Each round emits, in order, one text window, one audio window and one
        motion window for whichever inputs still have elements left. The first
        stream carries the text and audio tokens and special-token placeholders
        for motion; the second stream carries special-token placeholders for
        text and audio and motion-pad placeholders for motion.

        Args:
            text_ids: Text token ids.
            audio_ids: Audio token ids.
            motion_ids: Motion tokens; only their count is used.
            audio_bos_eos: Forwarded as ``add_bos_eos`` for audio windows.
            return_type: ``"pt"`` for long tensors, ``"str"`` for strings decoded
                by the language tokenizer; any other value returns plain lists.

        Returns:
            tuple: ``(all_tokens, mot_tokens)`` in the requested representation.
        """
        text_window = self.text_interleave_window
        audio_window = self.audio_interleave_window
        motion_window = self.motion_interleave_window

        main_stream = []
        motion_stream = []
        text_pos = audio_pos = motion_pos = 0

        while True:
            text_left = text_pos < len(text_ids)
            audio_left = audio_pos < len(audio_ids)
            motion_left = motion_pos < len(motion_ids)
            if not (text_left or audio_left or motion_left):
                break

            # Text window: real tokens in the main stream, placeholders in the
            # motion stream.
            if text_left:
                chunk, next_pos = self._process_segment(
                    text_ids, text_pos, text_window
                )
                main_stream.extend(chunk)
                placeholders, _ = self._process_ignore_segment(
                    text_ids, text_pos, text_window
                )
                motion_stream.extend(placeholders)
                text_pos = next_pos

            # Audio window: same layout as text, optionally with BOS/EOS markers.
            if audio_left:
                chunk, next_pos = self._process_segment(
                    audio_ids,
                    audio_pos,
                    audio_window,
                    add_bos_eos=audio_bos_eos,
                    bos_id=self.a_sos_id,
                    eos_id=self.a_eos_id,
                )
                main_stream.extend(chunk)
                placeholders, _ = self._process_ignore_segment(
                    audio_ids, audio_pos, audio_window
                )
                motion_stream.extend(placeholders)
                audio_pos = next_pos

            # Motion window: placeholders in the main stream, motion-pad
            # placeholders in the motion stream.
            if motion_left:
                placeholders, next_pos = self._process_ignore_segment(
                    motion_ids, motion_pos, motion_window
                )
                main_stream.extend(placeholders)
                motion_slots, _ = self._process_ignore_segment(
                    motion_ids, motion_pos, motion_window, "motion"
                )
                motion_stream.extend(motion_slots)
                motion_pos = next_pos

        # Convert to desired return type
        if return_type == "pt":
            return (
                torch.tensor(main_stream, dtype=torch.long),
                torch.tensor(motion_stream, dtype=torch.long),
            )
        if return_type == "str":
            decoded_main = self.language_tokenizer.decode(
                main_stream, spaces_between_special_tokens=False
            )
            decoded_motion = self.language_tokenizer.decode(
                motion_stream, spaces_between_special_tokens=False
            )
            return decoded_main, decoded_motion

        return main_stream, motion_stream

    @autocast(device_type="cuda", enabled=True, dtype=torch.float32)
    def _encode_motion(self, poses, lengths):
        """
        Encode poses with the pose tokenizer, or pass them through unchanged.

        Args:
            poses: Pose input.
            lengths: Forwarded to the tokenizer as ``lengths``.

        Returns:
            The tokenizer's encoding when ``self.pose_tokenize`` is set,
            otherwise ``poses`` itself.
        """
        if not self.pose_tokenize:
            return poses
        return self.pose_tokenizer.encode(poses, lengths=lengths, mask=None)

    @autocast(device_type="cuda", enabled=True, dtype=torch.float32)
    def _decode_motion(self, x):
        """
        Decode motion with the pose tokenizer, or pass it through unchanged.

        A 2-D input gets a leading dimension added before decoding; the decoded
        result is squeezed on dimension 0 in every case.

        Args:
            x: Motion representation to decode.

        Returns:
            The decoded motion when ``self.pose_tokenize`` is set, otherwise
            ``x`` itself.
        """
        if not self.pose_tokenize:
            return x

        batched = x.unsqueeze(0) if x.dim() == 2 else x
        decoded = self.pose_tokenizer.decode(batched)
        return decoded.squeeze(0)

    def _forward_dec(
        self,
        instructions: list[str],
        responses: list[str],
        audio_tokens: Tensor,
        audio_lengths: list[int],
        motion_tokens: Tensor,
        lengths_poses: list[int],
        **kwargs,
    ):
        """
        Run a teacher-forced language-model pass and return its loss.

        The text/audio stream and the motion stream are rendered to strings,
        wrapped by the chat template, tokenized to ``self.max_length`` and fed to
        ``self.language_model`` together with the motion embeddings.

        Args:
            instructions: Per-sample prompts passed to the chat template.
            responses: Per-sample target responses.
            audio_tokens: Audio tokens per sample; its device receives the
                tokenized tensors.
            audio_lengths: Per-sample audio lengths.
            motion_tokens: Per-sample motion tokens written into the motion
                label stream.
            lengths_poses: Per-sample motion lengths.
            **kwargs: Must contain ``poses``, the motion embeddings (indexed
                here as ``[batch, time, feature]``).

        Returns:
            dict: ``all_audio_bos_positions``, ``all_audio_eos_positions`` and
            ``loss`` (the loss reported by the language model).

        Raises:
            KeyError: If ``poses`` is missing from ``kwargs``.
            IndexError: If a sample has more motion slots than motion tokens.
            AssertionError: If a sample has more motion slots than its entry in
                ``lengths_poses``.
        """
        mot_input_embs = kwargs["poses"]

        self.language_tokenizer.padding_side = "right"

        # Step 1: Construct the output strings (interleaved and motion stream)
        outputs, mot_outputs = self._construct_outputs(
            instructions,
            responses,
            audio_tokens,
            audio_lengths,
            motion_tokens,
            lengths_poses,
        )

        # Step 2: Generate labels
        labels = self._apply_chat_template(
            instructions, outputs, behaviour_output=False
        )
        mot_labels = self._apply_chat_template(
            instructions, mot_outputs, behaviour_output=False
        )

        # Step 3: Tokenization. Placeholder tokens become -100 in both streams.
        tokenized_inputs = self.language_tokenizer(
            labels,
            padding="max_length",
            max_length=self.max_length,
            truncation=True,
            return_tensors="pt",
        ).to(audio_tokens.device)
        tokenized_input_ids = tokenized_inputs.input_ids
        tokenized_input_ids[tokenized_input_ids == self.m_special_id] = -100
        tokenized_attention_mask = tokenized_inputs.attention_mask

        mot_tokenized_inputs = self.language_tokenizer(
            mot_labels,
            padding="max_length",
            max_length=self.max_length,
            truncation=True,
            return_tensors="pt",
        ).to(audio_tokens.device)
        mot_input_ids = mot_tokenized_inputs.input_ids
        mot_input_ids[mot_input_ids == self.m_special_id] = -100

        # Step 4: Get audio and response token positions
        audio_bos_positions, audio_eos_positions = self.get_audio_token_positions(
            tokenized_input_ids
        )
        response_bos_positions, response_eos_positions = (
            self.get_response_token_positions(tokenized_input_ids)
        )

        # Step 5: Prepare labels for loss calculation
        label_ids = tokenized_input_ids.clone()
        for i in range(len(label_ids)):
            # Nothing before the response start contributes to the loss.
            label_ids[i, : response_bos_positions[i]] = -100
            # Mask padding token with -100 except for first padding token which is also eos token
            first_pad_idx = (
                label_ids[i] == self.language_tokenizer.pad_token_id
            ).nonzero(as_tuple=True)[0]
            if first_pad_idx.shape[0] > 0:
                label_ids[i, first_pad_idx[0] + 1 :] = -100

        # Step 6: Turn the motion stream into motion labels. Every motion slot
        # receives the next motion token of its sample, in order; every other
        # position is ignored. Boolean masks replace a per-token Python loop
        # that issued one tensor operation per sequence position.
        motion_slots = mot_input_ids == self.m_pad_id
        for b in range(len(mot_input_ids)):
            slot_mask = motion_slots[b]
            num_slots = int(slot_mask.sum())
            sample_tokens = torch.as_tensor(
                motion_tokens[b],
                dtype=mot_input_ids.dtype,
                device=mot_input_ids.device,
            ).reshape(-1)
            if num_slots > sample_tokens.numel():
                raise IndexError(
                    f"sample {b} has {num_slots} motion slots but only "
                    f"{sample_tokens.numel()} motion tokens"
                )
            mot_input_ids[b, ~slot_mask] = -100
            mot_input_ids[b, slot_mask] = sample_tokens[:num_slots]

        # Step 7: Scatter the motion embeddings onto the motion label positions;
        # every other position stays zero.
        padded_mot_input_embs = torch.zeros(
            mot_input_ids.shape[0],
            mot_input_ids.shape[1],
            mot_input_embs.shape[2],
            device=mot_input_embs.device,
            dtype=mot_input_embs.dtype,
        )

        for b in range(len(mot_input_ids)):
            valid_indices = mot_input_ids[b] != -100
            # Fewer slots than poses is allowed (the sequence may be truncated).
            assert valid_indices.sum() <= lengths_poses[b], (
                f"valid_index {valid_indices.sum()} is not match length {lengths_poses[b]}"
            )
            padded_mot_input_embs[b, valid_indices] = mot_input_embs[
                b, : valid_indices.sum()
            ]

        label_ids = [label_ids, mot_input_ids]
        # Padding positions of the input ids are replaced in place (the labels
        # were cloned before this point).
        tokenized_input_ids[tokenized_input_ids == 151329] = IGNORE_INDEX  # padding id

        # Step 8: Language model forward pass
        with autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16):
            lm_outputs = self.language_model(
                input_ids=tokenized_input_ids,
                mot_input_ids=[mot_input_ids],
                mot_input_embs=[padded_mot_input_embs],
                attention_mask=tokenized_attention_mask,
                labels=label_ids,
                output_hidden_states=True,
            )

        # Step 9: Return results
        return_dict = {
            "all_audio_bos_positions": audio_bos_positions,
            "all_audio_eos_positions": audio_eos_positions,
            "loss": lm_outputs.loss,
        }

        return return_dict

    @autocast(device_type="cuda", enabled=True, dtype=torch.float32)
    def _generate_dec(
        self,
        captions: list[str],
        responses: list[str],
        audio_tokens: Tensor,
        audio_lengths: list[int],
        motion_tokens: Tensor,
        lengths_poses: list[int],
        motion_embs: Tensor,
    ):
        """
        Generate text, audio tokens and motion features with the language model.

        With ``self.eval_training`` set, the ground-truth interleaved sequence is
        tokenized and handed to ``generate`` as ``labels`` alongside
        ``motion_embs``. Otherwise the prompt is selected by ``self.task``:
        ``"uncond"``, ``"audio_prefix"``, a path to a text file with one caption
        per line, or a list of captions. In that mode only the caption at index
        ``self.task_id`` is used and ``motion_embs`` is not passed on.

        Side effects: increments ``self.task_id``, sets the tokenizer padding
        side to ``"left"`` and may replace ``self.task`` with a plain list.

        Args:
            captions: Per-sample prompts (replaced by the task captions when the
                task provides its own).
            responses: Per-sample target responses (``eval_training`` only).
            audio_tokens: Audio tokens; its device receives the prompt ids.
            audio_lengths: Per-sample audio lengths (``eval_training`` only).
            motion_tokens: Per-sample motion tokens (``eval_training`` only).
            lengths_poses: Per-sample motion lengths (``eval_training`` only).
            motion_embs: Motion embeddings forwarded to ``generate`` when
                ``eval_training`` is set.

        Returns:
            tuple: ``(outputs_string_lst, audio_tokens_out, motion_out)``, the
            decoded strings, the audio tokens extracted from every generated
            sequence, and the generated motion features of every sequence
            reshaped to ``(-1, self.nfeats)``.

        Raises:
            ValueError: If ``self.task`` is neither a supported string nor a list.
        """
        self.language_tokenizer.padding_side = "left"

        # Count the number of inference calls; also used below to pick a caption.
        self.task_id += 1
        if self.eval_training:
            label_outputs, _ = self._construct_outputs(
                captions,
                responses,
                audio_tokens,
                audio_lengths,
                motion_tokens,
                lengths_poses,
            )
            labels = self._apply_chat_template(
                captions, label_outputs, generation_prefix=True
            )
            labels = self.language_tokenizer(
                labels,
                padding=True,
                max_length=self.max_length,
                truncation=True,
                return_tensors="pt",
            ).input_ids
            labels[labels == self.m_special_id] = -100
            outputs = ["" for _ in range(len(captions))]
        else:
            labels = None
            if isinstance(self.task, ListConfig):
                self.task = list(self.task)
            if isinstance(self.task, str):
                if self.task == "uncond":
                    outputs = ["" for _ in range(len(captions))]
                elif self.task == "audio_prefix":
                    # NOTE(review): other window attributes used in this class
                    # are text_/audio_/motion_interleave_window — confirm that
                    # ``interleave_window`` is defined by the parent class.
                    crop_audio_tokens = audio_tokens[:, : self.interleave_window]
                    outputs = [
                        self.audio_tokens_to_string(crop_audio_tokens[i])
                        for i in range(crop_audio_tokens.shape[0])
                    ]
                elif os.path.exists(self.task):
                    # One caption per line; the parsed list replaces the path so
                    # later calls take the list branch.
                    with open(self.task, "r") as f:
                        captions = [line.strip() for line in f]
                    outputs = ["" for _ in range(len(captions))]
                    self.task = captions
                else:
                    raise ValueError(f"Unknown task: {self.task}")
            elif isinstance(self.task, list):
                captions = self.task
                outputs = ["" for _ in range(len(captions))]
            else:
                raise ValueError(f"Unknown task type: {self.task}")

            # Change to single batch.
            # NOTE(review): ``outputs`` keeps its full length while ``captions``
            # is cut to one entry — confirm ``_apply_chat_template`` tolerates
            # the mismatch.
            captions = [captions[self.task_id]]
            motion_embs = None

        prompt = self._apply_chat_template(captions, outputs, generation_prefix=True)

        tokenized_inputs = self.language_tokenizer(
            prompt,
            padding=True,
            max_length=self.max_length,
            truncation=True,
            return_tensors="pt",
        )
        input_ids = tokenized_inputs.input_ids.to(audio_tokens.device)

        # The motion stream of the prompt carries no motion tokens.
        mot_input_ids = [
            torch.full_like(
                input_ids,
                IGNORE_INDEX,
                dtype=torch.long,
                device=input_ids.device,
            )
        ]

        with autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16):
            generated = self.language_model.generate(
                input_ids,
                mot_input_ids=mot_input_ids,
                labels=labels,
                pad_token_id=self.language_tokenizer.pad_token_id,
                do_sample=True,
                max_new_tokens=self.max_length,
                output_hidden_states=True,
                return_dict_in_generate=True,
                motion_embs=motion_embs,
                eval_first_chunk=self.config.eval.get("eval_first_chunk", False),
            )

        outputs_string_lst = self.language_tokenizer.batch_decode(
            generated.sequences,
            skip_special_tokens=True,
            spaces_between_special_tokens=False,
        )

        audio_tokens_out = []
        motion_out = []
        # TODO: just one modality now
        for i in range(generated.sequences.shape[0]):
            audio_tokens_out.append(self.extract_audio_tokens(generated.sequences[i]))

            # Keep only the positions that carry a motion token.
            motion_tokens_i = generated.mot_output_ids[0][i]
            motion_i = generated.mot_output_embs[0][i]
            motion_i = motion_i[motion_tokens_i != IGNORE_INDEX].reshape(
                -1, self.nfeats
            )
            motion_out.append(motion_i)

        return outputs_string_lst, audio_tokens_out, motion_out

    def forward_train(
        self,
        audio_tokens,
        audio_lengths,
        **kwargs,
    ):
        """
        Compute the training loss for one batch.

        Args:
            audio_tokens: Audio tokens of the batch.
            audio_lengths: Per-sample audio lengths.
            **kwargs: Remaining batch fields, forwarded to ``_forward_dec``.

        Returns:
            dict: ``loss`` and ``lm_loss`` (both the language-model loss), plus
            ``preds`` and ``labels`` set to ``None``.
        """
        lm_loss = self._forward_dec(
            audio_tokens=audio_tokens,
            audio_lengths=audio_lengths,
            **kwargs,
        )["loss"]

        return dict(loss=lm_loss, preds=None, labels=None, lm_loss=lm_loss)

    def forward_eval(
        self,
        audio_tokens,
        audio_lengths,
        **kwargs,
    ):
        """
        Evaluate one batch: compute the loss, then generate predictions.

        Args:
            audio_tokens: Audio tokens of the batch.
            audio_lengths: Per-sample audio lengths.
            **kwargs: Remaining batch fields; ``motion_tokens``, ``poses``,
                ``lengths_poses``, ``instructions`` and ``responses`` are
                required outside train mode.

        Returns:
            dict: Empty when ``self.mode`` is ``"train"``. Otherwise the
            ``forward_train`` result extended with ``outputs``,
            ``motion_preds``, ``motion_pred_lengths``, ``audio_preds``,
            ``audio_labels`` and ``audio_pred_lengths``, plus
            ``should_stop=True`` when ``self.task`` is a list.
        """
        # Nothing is evaluated in train mode, so bail out before touching any
        # evaluation-only batch fields.
        if self.mode == "train":
            return {}

        motion_tokens = kwargs["motion_tokens"]
        motion_embs = kwargs["poses"]
        lengths_poses = kwargs["lengths_poses"]

        forward_dict = self.forward_train(
            audio_tokens,
            audio_lengths,
            **kwargs,
        )

        output_strs, audio_tokens_out, motion_out = self._generate_dec(
            kwargs["instructions"],
            kwargs["responses"],
            audio_tokens,
            audio_lengths,
            motion_tokens,
            lengths_poses,
            motion_embs,
        )

        # Unpadded lengths of the generated motion and audio sequences.
        motion_pred_lengths = [len(motion) for motion in motion_out]
        audio_pred_lengths = [len(audio) for audio in audio_tokens_out]

        forward_dict.update(
            {
                "outputs": output_strs,
                "motion_preds": motion_out,
                "motion_pred_lengths": motion_pred_lengths,
                "audio_preds": audio_tokens_out,
                "audio_labels": audio_tokens,
                "audio_pred_lengths": audio_pred_lengths,
            }
        )

        # _generate_dec turns a task file or ListConfig into a list, so this
        # covers every caption-list task.
        if isinstance(self.task, list):
            forward_dict["should_stop"] = True

        return forward_dict

    def forward(self, **kwargs):
        """
        Dispatch a batch to ``forward_train`` or ``forward_eval``.

        Args:
            **kwargs: Batch fields; ``audio_tokens`` and ``audio_lengths`` are
                required by both handlers.

        Returns:
            The handler's output dictionary, updated with every batch field
            (batch entries win when a key appears in both).
        """
        handler = self.forward_train if self.training else self.forward_eval
        return_dict = handler(**kwargs)
        return_dict.update(kwargs)
        return return_dict

    def training_step(self, batch, batch_idx):
        """
        Perform a single training step, logging relevant metrics.

        Args:
            batch: Input batch dictionary, or ``None`` for a batch to skip.
            batch_idx: Index of the current batch.

        Returns:
            Loss value for the current training step, or ``None`` when the
            batch is ``None`` (Lightning's documented skip value under
            automatic optimization; not supported on every multi-device
            strategy).
        """
        if batch is None:
            # A missing batch has no tensors to take a device from; looking up
            # batch["poses"] here can only raise a TypeError.
            return None

        output = self(**batch)
        loss = output["loss"]

        self.log_progress_bar(
            {
                "lr": self.trainer.optimizers[0].param_groups[0]["lr"],
                "loss": loss,
            }
        )
        self.log_dict(
            {"train/loss": loss},
            on_step=False,
            on_epoch=True,
            batch_size=len(batch["lengths_poses"]),
            sync_dist=True,
        )
        return loss

    def validation_step(self, batch, batch_idx):
        """
        Perform a single validation step and save generated samples on rank zero.

        For every sample the predicted and ground-truth audio (``.wav``), motion
        (``.npy``) and text (``.txt``) are written below
        ``<output_dir>/samples/step-<global_step>``. Once ``self.task_id``
        reaches ``config.eval.num_samples`` (default 5) the render script is run
        and the trainer is asked to stop.

        Args:
            batch: Input batch dictionary, or ``None`` for a batch to skip.
            batch_idx: Index of the current batch.

        Returns:
            Output dictionary from the evaluation process, or ``None`` when the
            batch is skipped or any stage fails.
        """

        if batch is None:
            return None

        if self.trainer.sanity_checking:
            return None

        if hasattr(self.config.model, "skip_batch"):
            if batch_idx < self.config.model.skip_batch:
                return None

        if hasattr(self.config.model, "skip_videos"):
            # Skip a batch whose first video already has a file of the same
            # stem in the configured directory.
            video_path = self.config.model.skip_videos
            videos = os.listdir(video_path)
            videos = [fname.split(".")[0] for fname in videos]
            video_name = os.path.basename(batch["video_path"][0]).split(".")[0]
            if video_name in videos:
                print(f"Skip video {video_name}")
                return None

        try:
            output = self(**batch)
        except Exception as e:
            # Evaluation is best-effort: a failing batch is reported and
            # skipped, while KeyboardInterrupt/SystemExit still propagate.
            print(f"Error in validation forward pass: {e}")
            return None

        try:
            metric = {}

            self.log_dict(
                metric,
                on_epoch=True,
                batch_size=len(batch["lengths_poses"]),
                sync_dist=True,
            )

            # Only rank zero writes samples (sanity checking already returned
            # at the top of this method).
            if rank_zero_only.rank != 0:
                return output

            global_step = self.global_step

            save_dir = os.path.join(
                self.hparams.output_dir, "samples", f"step-{global_step}"
            )
            os.makedirs(save_dir, exist_ok=True)
            print(f"Saving samples to {save_dir}")

            for i in range(len(output["audio_preds"])):
                if self.eval_training:
                    # In eval_training mode the label audio is decoded as the
                    # "prediction" and files are named after the source video.
                    tts_speech = self.decode_audio(output["audio_labels"][i])
                    video_name = batch["video_path"][i].split("/")[-1].split(".")[0]
                else:
                    tts_speech = self.decode_audio(output["audio_preds"][i])
                    video_name = f"{batch_idx}_{i}"

                gt_tts_speech = self.decode_audio(output["audio_labels"][i])

                # Save predicted and ground-truth audio
                with open(os.path.join(save_dir, f"{video_name}.wav"), "wb") as f:
                    torchaudio.save(f, tts_speech, 22050, format="wav")

                with open(os.path.join(save_dir, f"{video_name}_gt.wav"), "wb") as f:
                    torchaudio.save(f, gt_tts_speech, 22050, format="wav")

                # Save predicted and ground-truth motion as .npy
                motion_out = output["motion_preds"][i]
                motion_out = motion_out.cpu().numpy()
                npy_file = os.path.join(save_dir, f"{video_name}.npy")
                with open(npy_file, "wb") as f:
                    np.save(f, motion_out)
                gt_npy_file = os.path.join(save_dir, f"{video_name}_gt.npy")
                gt_motion = batch["poses"][i].cpu().numpy()
                with open(gt_npy_file, "wb") as f:
                    np.save(f, gt_motion)

                # Save strings
                with open(os.path.join(save_dir, f"{video_name}.txt"), "w") as f:
                    f.write(output["outputs"][i])

                with open(os.path.join(save_dir, f"{video_name}_gt.txt"), "w") as f:
                    f.write(batch["instructions"][i])
                    f.write("\n")
                    f.write(batch["responses"][i])

            if "should_stop" in output and output["should_stop"]:
                self.trainer.should_stop = True

            # Render and stop once enough samples have been generated.
            if self.task_id >= self.config.eval.get("num_samples", 5):
                render_script = self.config.eval.render_path
                # run rendering script
                print("Starting rendering script")
                os.system(f"bash {render_script} {save_dir}")

                self.trainer.should_stop = True
        except Exception as e:
            print(f"Error in validation step: {e}")
            return None
        return output

    def test_step(self, batch, batch_idx):
        """
        Run one test batch through the validation logic.

        Args:
            batch: Input batch dictionary.
            batch_idx: Index of the current batch.

        Returns:
            Whatever ``validation_step`` returns for this batch.
        """
        result = self.validation_step(batch, batch_idx)
        return result


# ----------------------------- Trainer ----------------------------- #
class humanexpertTrainer(HFMTrainer):
    """Trainer that builds a ``humanexpertS3`` model and runs fit or test."""

    def entrypoint(self):
        """
        Run the parent entry point, then train or evaluate.

        ``config.mode`` selects the action: ``"train"`` (also the default when
        the key is absent) calls ``fit``, ``"eval"`` calls ``test``.

        Raises:
            ValueError: If the mode is neither ``"train"`` nor ``"eval"``.
        """
        super().entrypoint()

        if self.config.get("mode", "train") == "train":
            self.fit()
            return

        if self.config.mode != "eval":
            raise ValueError(f"Unknown mode: {self.config.mode}")
        self.test()

    def configure_model(self):
        """
        Build the ``humanexpertS3`` model from the config and load its weights.

        The pose tokenizer config is optional; when present, a writable copy is
        extended with ``nfeats`` and ``window_size`` (default 64).
        """
        model_cfg = self.config.model

        config_pose_tokenizer = None
        if hasattr(model_cfg, "pose_tokenizer"):
            config_pose_tokenizer = model_cfg.pose_tokenizer.copy()
            # The copy must be writable before the extra fields can be set.
            OmegaConf.set_readonly(config_pose_tokenizer, False)
            config_pose_tokenizer.nfeats = self.nfeats
            config_pose_tokenizer.window_size = self.config.data.get("window_size", 64)

        config_audio_tokenizer = model_cfg.audio_tokenizer.copy()
        config_language_model = model_cfg.language_model.copy()

        self.model = humanexpertS3(
            nfeats=self.nfeats,
            config_pose_tokenizer=config_pose_tokenizer,
            config_audio_tokenizer=config_audio_tokenizer,
            config_language_model=config_language_model,
            config=self.config,
            output_dir=self.dir_params["output_dir"],
            eval_dataset=self.eval_dataset,
            train_dataset=self.train_dataset,
        )

        # Load the model from the given path
        self.load_model()
