import os

import numpy as np
import torch
import torchaudio
import sys
from omegaconf import ListConfig, OmegaConf
from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_info
from torch import Tensor
from torch.amp.autocast_mode import autocast
from torch.nn.utils.rnn import pad_sequence

from common.config import create_object
from modules.mar.l2loss import L2Loss

from .humanexpert_s1 import humanexpertS1
from .humanexpert_s1 import humanexpertTrainer as HFMTrainer
from .utils import freeze_module, load_model_from_checkpoint


# ----------------------------- Pipeline ----------------------------- #
class humanexpertS3(humanexpertS1):
    def __init__(
        self,
        **kwargs,
    ):
        """Set up the stage-3 pipeline on top of ``humanexpertS1``.

        The parent constructor is called with ``not_loading=True`` unless the
        caller supplied ``not_loading`` itself. Afterwards this constructor
        optionally loads a pose tokenizer, builds the continuous head,
        resolves the ids of the special tokens and, when ``not_loading`` was
        not supplied and ``config.model`` has a ``model_path``, restores the
        weights from that checkpoint.

        Args:
            **kwargs: Forwarded to the parent constructor. ``not_loading`` and
                ``config_pose_tokenizer`` are additionally inspected here.
        """
        # Remember whether the caller passed the flag: it also decides below
        # whether this class restores its own checkpoint.
        caller_set_not_loading = "not_loading" in kwargs
        if caller_set_not_loading:
            super().__init__(**kwargs)
        else:
            super().__init__(not_loading=True, **kwargs)

        pose_tokenizer_cfg = kwargs.get("config_pose_tokenizer")
        self.pose_tokenize = bool(pose_tokenizer_cfg)
        if self.pose_tokenize:
            self._load_pose_tokenizer(pose_tokenizer_cfg)

        model_cfg = self.config.model
        self.freeze_lm = model_cfg.get("freeze_lm", False)
        self.load_lm = model_cfg.get("load_lm", True)
        self.eval_training = model_cfg.get("eval_training", False)
        self._load_diff_head(model_cfg.diff_head.copy())

        # FIXME: hardcoded
        # Each entry defines ``self.<name>`` (the token string) and
        # ``self.<name>_id`` (its id in the language tokenizer).
        special_tokens = (
            ("eos", "<|user|>"),
            ("m_sos", "<|reserved_151360|>"),
            ("m_eos", "<|reserved_151361|>"),
            ("m_pad", "<|reserved_151362|>"),
            ("m_special", "<|reserved_151357|>"),
        )
        for attr_name, token in special_tokens:
            setattr(self, attr_name, token)
            setattr(
                self,
                f"{attr_name}_id",
                self.language_tokenizer.convert_tokens_to_ids(token),
            )

        if not caller_set_not_loading and hasattr(self.config.model, "model_path"):
            restored_state = load_model_from_checkpoint(
                self.config.model.model_path, self.state_dict()
            )
            self.load_state_dict(restored_state)

        # Optional cache of hidden states keyed by response string.
        self.hidden_dict = {}
        use_hidden_cache = (
            self.config.model.get("hidden_dict_path", None) and self.mode == "train"
        )
        if use_hidden_cache:
            rank_zero_info(
                f"Loading hidden dict from {self.config.model.hidden_dict_path}"
            )
            self.hidden_dict = torch.load(
                self.config.model.hidden_dict_path, weights_only=False
            )
            # The language model is dropped when the cache is used.
            self.language_model = None

        if not self.load_lm:
            self.language_model = None

    def _load_pose_tokenizer(self, config_pose_tokenizer):
        """Instantiate the pose tokenizer, load its checkpoint and freeze it.

        Args:
            config_pose_tokenizer: Config passed to ``create_object``; its
                ``model_path`` entry names the checkpoint to load.
        """
        self.pose_tokenizer = create_object(config_pose_tokenizer)
        tokenizer = self.pose_tokenizer

        checkpoint_state = load_model_from_checkpoint(
            config_pose_tokenizer.model_path,
            tokenizer.state_dict(),
            "pose_tokenizer",
        )
        tokenizer.load_state_dict(checkpoint_state)

        freeze_module(tokenizer)

    def _load_diff_head(self, config_diff_head):
        """Create the continuous head and store it as ``self.diff_loss``.

        The ``type`` entry of ``config_diff_head`` (default ``"diff"``)
        selects the implementation:

        * ``"diffuser"`` / ``"mld"``: ``DiffLoss`` from the matching
          ``modules.mar`` module.
        * ``"flow"`` / ``"flow_decoder"``: ``FlowLoss``; ``self.head_type``
          becomes ``"flow"``.
        * ``"diff"``: ``modules.mar.diffloss.DiffLoss``; ``self.head_type``
          becomes ``"diff"``.
        * any other value: ``L2Loss``; ``self.head_type`` becomes ``"l2"``.

        Args:
            config_diff_head: Head configuration (supports ``.get`` and
                attribute access).

        Raises:
            ValueError: If ``type`` contains ``"flow"`` but is neither
                ``"flow"`` nor ``"flow_decoder"``.
        """
        # Settings shared by every head type.
        self.diff_hidden_layers = config_diff_head.hidden_state_layers
        self.diff_batch_mul = config_diff_head.get("diff_batch_mul", 1)
        self.diff_dropout = config_diff_head.get("dropout", 0.0)
        self.diff_num_sampling_steps = config_diff_head.get("num_sampling_steps", "100")
        self.diff_z_channels = self.lm_hidden_size * self.diff_hidden_layers
        self.cfg = config_diff_head.get("cfg", 1.0)
        head_type = config_diff_head.get("type", "diff")

        # Optional directories put at the front of the import path.
        for path_key in ("matcha_path", "cosyvoice_path"):
            extra_path = config_diff_head.get(path_key, None)
            if extra_path:
                sys.path.insert(0, extra_path)

        # The head predicts tokenizer latents when a pose tokenizer is used,
        # raw features otherwise.
        target_channels = (
            self.pose_tokenizer.hparams.width if self.pose_tokenize else self.nfeats
        )

        if head_type in ("diffuser", "mld"):
            # NOTE(review): unlike the other branches, this one does not set
            # ``self.head_type``.
            if head_type == "diffuser":
                from modules.mar.diffloss_diffuser import DiffLoss
            else:
                from modules.mar.diffloss_mld import DiffLoss

            diffuser_kwargs = dict(
                target_channels=target_channels,
                z_channels=self.diff_z_channels,
                depth=config_diff_head.depth,
                width=config_diff_head.width,
                upsample_rate=int(self.motion_fps / self.audio_fps),
                num_sampling_steps=self.diff_num_sampling_steps,
                train_scheduler=config_diff_head.get("train_scheduler", False),
                gen_scheduler=config_diff_head.get("gen_scheduler", False),
                pe_type=config_diff_head.get("pe_type", "learnable"),
                predict_epsilon=config_diff_head.get("predict_epsilon", False),
                grad_checkpointing=config_diff_head.grad_checkpointing,
                model_type=config_diff_head.get("model_type", "mlp"),
                max_len=config_diff_head.get("max_len", 500),
                dropout=self.diff_dropout,
                causal=config_diff_head.get("causal", False),
            )
            self.diff_loss = DiffLoss(**diffuser_kwargs)
            return

        if "flow" in head_type:
            self.head_type = "flow"

            if head_type == "flow":
                from modules.mar.flow import FlowLoss
            elif head_type == "flow_decoder":
                from modules.mar.flow_decoder import FlowLoss
            else:
                raise ValueError(f"Unknown head type: {head_type}")

            # The flow head always targets ``self.nfeats`` channels.
            flow_kwargs = dict(
                target_channels=self.nfeats,
                width=config_diff_head.channels[0],
                z_channels=self.diff_z_channels,
                channels=config_diff_head.channels,
                num_sampling_steps=config_diff_head.num_sampling_steps,
                cfm_params=config_diff_head.cfm_params,
                causal=config_diff_head.get("causal", True),
                dropout=config_diff_head.get("dropout", 0.0),
                attention_head_dim=config_diff_head.get("attention_head_dim", 64),
                n_blocks=config_diff_head.get("n_blocks", 4),
                num_mid_blocks=config_diff_head.get("num_mid_blocks", 12),
                num_heads=config_diff_head.get("num_heads", 8),
                act_fn=config_diff_head.get("act_fn", "gelu"),
                norm_type=config_diff_head.get("norm_type", "layer_norm"),
                upsample_rate=int(self.motion_fps / self.audio_fps),
            )
            self.diff_loss = FlowLoss(**flow_kwargs)
            return

        # Remaining types: "diff" selects DiffLoss, everything else L2Loss.
        if head_type == "diff":
            from modules.mar.diffloss import DiffLoss

            head_class = DiffLoss
            self.head_type = "diff"
        else:
            head_class = L2Loss
            self.head_type = "l2"

        default_kwargs = dict(
            target_channels=target_channels,
            z_channels=self.diff_z_channels,
            depth=config_diff_head.depth,
            width=config_diff_head.width,
            model_type=config_diff_head.get("model_type", "mlp"),
            pe_type=config_diff_head.get("pe_type", "learnable"),
            noise_schedule=config_diff_head.get("noise_schedule", "cosine"),
            noise_schedule_params=config_diff_head.get("noise_schedule_params", None),
            num_sampling_steps=self.diff_num_sampling_steps,
            grad_checkpointing=config_diff_head.grad_checkpointing,
            upsample_rate=int(self.motion_fps / self.audio_fps),
            up_type=config_diff_head.get("up_type", "conv"),
            dropout=self.diff_dropout,
        )
        self.diff_loss = head_class(**default_kwargs)

    def _extract_hidden_states(
        self, hidden_states, audio_bos_positions, audio_eos_positions, labels=None
    ):
        hidden_states_lst = []
        for layer_hidden_states in hidden_states[-self.diff_hidden_layers :]:
            batch_hidden_states = []
            for b in range(layer_hidden_states.shape[0]):
                segments = []
                for start, end in zip(audio_bos_positions[b], audio_eos_positions[b]):
                    hidden_seg = layer_hidden_states[b, start:end]

                    while hidden_seg.shape[0] > self.audio_interleave_window:
                        segments.append(hidden_seg[: self.audio_interleave_window])
                        hidden_seg = hidden_seg[self.audio_interleave_window :]

                    if hidden_seg.shape[0] > 0:
                        segments.append(hidden_seg)

                batch_hidden_states.append(
                    torch.cat(segments, dim=0) if segments else torch.tensor([])
                )
            hidden_states_lst.append(
                pad_sequence(batch_hidden_states, batch_first=True)
            )
        return torch.cat(hidden_states_lst, dim=-1)

    @autocast(device_type="cuda", enabled=True, dtype=torch.float32)
    def _encode_motion(self, poses, lengths):
        if self.pose_tokenize:
            return self.pose_tokenizer.encode(poses, lengths=lengths, mask=None)
        else:
            return poses

    @autocast(device_type="cuda", enabled=True, dtype=torch.float32)
    def _decode_motion(self, x):
        if self.pose_tokenize:
            if x.dim() == 2:
                x = x.unsqueeze(0)
            return self.pose_tokenizer.decode(x).squeeze(0)
        else:
            return x

    def _forward_dec(
        self,
        instructions: list[str],
        responses: list[str],
        audio_tokens: Tensor,
        audio_lengths: list[int],
        **kwargs,
    ):
        """Run a teacher-forced language-model pass.

        Args:
            instructions: Instruction strings, one per sample.
            responses: Response strings, one per sample.
            audio_tokens: Audio tokens; also fixes the device of the inputs.
            audio_lengths: Per-sample audio token counts.
            **kwargs: Ignored.

        Returns:
            Dict with ``all_audio_bos_positions``, ``all_audio_eos_positions``,
            ``hidden_states`` (from ``_extract_hidden_states``) and ``loss``
            (the language-model loss).
        """
        tokenizer = self.language_tokenizer
        tokenizer.padding_side = "right"

        # 1) Render the target strings and wrap them in the chat template.
        target_strings = self._construct_outputs(
            instructions, responses, audio_tokens, audio_lengths
        )
        chat_strings = self._apply_chat_template(
            instructions, target_strings, behaviour_output=False
        )

        # 2) Tokenize to a fixed length on the device of the audio tokens.
        encoded = tokenizer(
            chat_strings,
            padding="max_length",
            max_length=self.max_length,
            truncation=True,
            return_tensors="pt",
        ).to(audio_tokens.device)
        input_ids = encoded.input_ids
        attention_mask = encoded.attention_mask

        # 3) Locate the audio spans and the start of each response.
        audio_bos_positions, audio_eos_positions = self.get_audio_token_positions(
            input_ids
        )
        response_bos_positions, _ = self.get_response_token_positions(input_ids)

        # 4) Build the labels: ignore the prompt and all padding after the
        #    first pad token (which doubles as the eos token).
        label_ids = input_ids.clone()
        for i, row in enumerate(label_ids):
            row[: response_bos_positions[i]] = -100
            pad_hits = (row == tokenizer.pad_token_id).nonzero(as_tuple=True)[0]
            if pad_hits.shape[0] > 0:
                row[pad_hits[0] + 1 :] = -100

        if self.only_loss_motion:
            outside_motion = (label_ids < self.m_sos_id) | (label_ids > self.m_eos_id)
            label_ids[outside_motion] = -100

        # 5) Language-model forward pass under bfloat16 autocast.
        with autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16):
            lm_outputs = self.language_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=label_ids,
                output_hidden_states=True,
            )

        audio_hidden_states = self._extract_hidden_states(
            lm_outputs.hidden_states,
            audio_bos_positions,
            audio_eos_positions,
            label_ids,
        )

        return {
            "all_audio_bos_positions": audio_bos_positions,
            "all_audio_eos_positions": audio_eos_positions,
            "hidden_states": audio_hidden_states,
            "loss": lm_outputs.loss,
        }

    @autocast(device_type="cuda", enabled=True, dtype=torch.float32)
    def _generate_dec(
        self,
        captions: list[str],
        audio_tokens: Tensor,
        audio_lengths: list[int],
    ):
        """Sample from the language model and gather the audio hidden states.

        The prompts depend on ``self.task``:

        * ``"uncond"``: the given ``captions`` with empty response prefixes.
        * ``"audio_prefix"``: the given ``captions``, prefixed with the first
          ``self.interleave_window`` audio tokens of each sample rendered by
          ``self.audio_tokens_to_string``.
        * a path to an existing file: one caption per line of that file;
          ``self.task`` is then replaced by the resulting list.
        * a list (or ``ListConfig``): used directly as the captions.

        Args:
            captions: Prompt strings (replaced for file and list tasks).
            audio_tokens: Audio tokens; also fixes the device of the prompts.
            audio_lengths: Unused; kept so existing callers keep working.

        Returns:
            Tuple ``(texts, hidden_states, audio_tokens_out)``: the decoded
            response strings, the output of ``_extract_hidden_states`` and the
            per-sample results of ``self.extract_audio_tokens``.

        Raises:
            ValueError: If ``self.task`` is an unknown string or has an
                unsupported type.
        """
        self.language_tokenizer.padding_side = "left"

        if isinstance(self.task, ListConfig):
            self.task = list(self.task)
        if isinstance(self.task, str):
            if self.task == "uncond":
                outputs = ["" for _ in range(len(captions))]
            elif self.task == "audio_prefix":
                crop_audio_tokens = audio_tokens[:, : self.interleave_window]
                outputs = [
                    self.audio_tokens_to_string(crop_audio_tokens[i])
                    for i in range(crop_audio_tokens.shape[0])
                ]
            elif os.path.exists(self.task):
                with open(self.task, "r") as f:
                    captions = [line.strip() for line in f.readlines()]
                outputs = ["" for _ in range(len(captions))]
                self.task = captions
            else:
                raise ValueError(f"Unknown task: {self.task}")
        elif isinstance(self.task, list):
            captions = self.task
            outputs = ["" for _ in range(len(captions))]
        else:
            raise ValueError(f"Unknown task type: {self.task}")

        labels = self._apply_chat_template(captions, outputs, generation_prefix=True)

        tokenized_inputs = self.language_tokenizer(
            labels,
            padding=True,
            max_length=self.max_length,
            truncation=True,
            return_tensors="pt",
        )

        # Skip[gMASK]<sop>
        # NOTE(review): padding is on the left, so for prompts shorter than
        # the longest one the two dropped positions are presumably padding
        # rather than these tokens - confirm this is intended.
        input_ids = tokenized_inputs.input_ids[:, 2:].to(audio_tokens.device)

        chunk_size = 32  # prompts generated per ``generate`` call
        outputs_string_lst = []
        hidden_state_lst = []
        audio_tokens_out = []
        audio_bos_position_lst = []
        audio_eos_position_lst = []
        for seg in range(0, input_ids.shape[0], chunk_size):
            with autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16):
                generated = self.language_model.generate(
                    input_ids[seg : seg + chunk_size],
                    pad_token_id=self.language_tokenizer.pad_token_id,
                    do_sample=True,
                    max_new_tokens=self.max_length,
                    output_hidden_states=True,
                    return_dict_in_generate=True,
                )

            # Text: keep ids below the motion range, cut the prompt, decode.
            for output_string_ids in generated.sequences:
                output_string_ids = output_string_ids[output_string_ids < self.m_sos_id]
                response_bos_positions, _ = self.get_response_token_positions(
                    output_string_ids.unsqueeze(0)
                )
                output_string_ids = output_string_ids[response_bos_positions[0] :]
                decoded = self.language_tokenizer.batch_decode(
                    output_string_ids.unsqueeze(0),
                    skip_special_tokens=True,
                    spaces_between_special_tokens=False,
                )
                outputs_string_lst.append(
                    decoded[0].replace("streaming_transcription\n", "")
                )

            audio_bos_positions, audio_eos_positions = self.get_audio_token_positions(
                generated.sequences
            )
            audio_bos_position_lst.extend(audio_bos_positions)
            audio_eos_position_lst.extend(audio_eos_positions)

            # ``hidden_states`` is indexed [step][layer]; join the steps of
            # every layer along the sequence axis, then stack the layers.
            # -> (layers, batch, sum_of_token_lengths_in_group, hidden_size)
            per_layer = [
                torch.cat([step[layer] for step in generated.hidden_states], dim=1)
                for layer in range(len(generated.hidden_states[0]))
            ]
            hidden_state_lst.append(
                torch.stack(per_layer, dim=0)[-self.diff_hidden_layers :]
            )

            audio_tokens_out.extend(
                self.extract_audio_tokens(generated.sequences[i])
                for i in range(generated.sequences.shape[0])
            )

        # Chunks are generated independently and can end at different lengths.
        # Right-pad the sequence axis with zeros so they can be joined along
        # the batch axis; only positions inside the bos/eos spans are read
        # afterwards, so the padding does not affect the extracted states.
        longest = max(chunk.shape[2] for chunk in hidden_state_lst)
        hidden_state_lst = [
            chunk
            if chunk.shape[2] == longest
            else torch.nn.functional.pad(chunk, (0, 0, 0, longest - chunk.shape[2]))
            for chunk in hidden_state_lst
        ]
        hidden_states = torch.cat(hidden_state_lst, dim=1)

        hidden_states = self._extract_hidden_states(
            hidden_states, audio_bos_position_lst, audio_eos_position_lst
        )

        return outputs_string_lst, hidden_states, audio_tokens_out

    @autocast(device_type="cuda", enabled=True, dtype=torch.float32)
    def _forward_diff(self, target, z, lengths):
        # Crop target to match z
        target = target[:, : z.shape[1] * self.diff_loss.upsample_rate]
        z = z[:, : target.shape[1] // self.diff_loss.upsample_rate]
        target = target[:, : z.shape[1] * self.diff_loss.upsample_rate]
        lengths = [min(l, target.shape[1]) for l in lengths]

        # Repeat inputs
        target = target.repeat(self.diff_batch_mul, 1, 1)
        z = z.repeat(self.diff_batch_mul, 1, 1)
        lengths = lengths * self.diff_batch_mul

        # Compute diff loss for this segment
        loss_diff_head = self.diff_loss(target=target, z=z, lengths=lengths)

        return loss_diff_head

    @autocast(device_type="cuda", enabled=True, dtype=torch.float32)
    def _generate_diff(self, hidden_states):
        if hidden_states.shape[1] > 250:
            hidden_states = hidden_states[:, :250]
        lengths = [
            hidden_states.shape[1] * self.diff_loss.upsample_rate
        ] * hidden_states.shape[0]
        sampled = self.diff_loss.sample(hidden_states, lengths, cfg=self.cfg)
        if self.pose_tokenize:
            sampled = self.pose_tokenizer.decode(sampled)
        return sampled

    def forward_train(
        self,
        audio_tokens,
        audio_lengths,
        **kwargs,
    ):
        z = []
        responses = kwargs["responses"]

        for i, res in enumerate(responses):
            if res in self.hidden_dict:
                z.append(self.hidden_dict[res])
            elif "hidden_states" in kwargs:
                z = kwargs["hidden_states"]
            elif "hidden_states_paths" in kwargs:
                hidden_path = kwargs["hidden_states_paths"][i]
                if os.path.exists(hidden_path):
                    try:
                        hidden_states = torch.load(hidden_path)
                        z.append(hidden_states)
                    except Exception as e:
                        rank_zero_info(
                            f"Error loading hidden states from {hidden_path}:e"
                        )
            elif self.mode == "train":
                rank_zero_info(f"Missing hidden state for {res}")

        if len(z) < len(responses):
            if not self.load_lm:
                total_loss = torch.tensor(0.0, device=kwargs["poses"].device)
                return {
                    "loss": total_loss,
                    "preds": None,
                    "labels": None,
                    "lm_loss": None,
                    "diff_loss": total_loss,
                }

            results = self._forward_dec(
                audio_tokens=audio_tokens,
                audio_lengths=audio_lengths,
                **kwargs,
            )

            z = results["hidden_states"]
            lm_loss = results["loss"]

            if self.hidden_dict and self.mode == "train":
                for b in range(z.shape[0]):
                    self.hidden_dict[responses[b]] = (
                        z[b][: audio_lengths[b]].clone().detach().cpu()
                    )

            if "hidden_states_paths" in kwargs and self.mode == "train":
                for i, hidden_path in enumerate(kwargs["hidden_states_paths"]):
                    if not os.path.exists(hidden_path):
                        rank_zero_info(f"Saving hidden states to {hidden_path}")
                        torch.save(z[i][: audio_lengths[i]].detach().cpu(), hidden_path)

        else:
            z = pad_sequence(z, batch_first=True).to(audio_tokens.device)
            lm_loss = None

        if self.pose_tokenize:
            kwargs["poses"] = self._encode_motion(
                kwargs["poses"], lengths=kwargs["lengths_poses"]
            )

        diff_loss = self._forward_diff(kwargs["poses"], z, kwargs["lengths_poses"])

        if self.freeze_lm:
            total_loss = diff_loss
        else:
            total_loss = lm_loss + diff_loss

        return {
            "loss": total_loss,
            "preds": None,
            "labels": None,
            "lm_loss": lm_loss,
            "diff_loss": diff_loss,
            "hidden_states": z,
        }

    def forward_eval(
        self,
        audio_tokens,
        audio_lengths,
        **kwargs,
    ):
        empty_lst = [""] * len(audio_lengths)

        if self.mode == "train":
            return {}

        forward_dict = self.forward_train(
            audio_tokens,
            audio_lengths,
            **kwargs,
        )

        if self.eval_training:
            hidden_states = forward_dict["hidden_states"]
            output_strs = []
            audio_tokens_out = []
            for i in range(len(audio_lengths)):
                output_strs.append(
                    kwargs["instructions"][i] + "\n" + kwargs["responses"][i]
                )
                audio_tokens_out.append(audio_tokens[i][: audio_lengths[i]])
            audio_pred_lengths = audio_lengths
        else:
            output_strs, hidden_states, audio_tokens_out = self._generate_dec(
                empty_lst, audio_tokens, audio_lengths
            )

            # Get unpadded length
            audio_pred_lengths = (hidden_states.abs().sum(dim=-1) != 0).sum(
                dim=1
            )  # (batch,)

        motion_pred_lengths = audio_pred_lengths * int(self.motion_fps / self.audio_fps)
        sampled_token_latent = self._generate_diff(hidden_states)

        forward_dict.update(
            {
                "outputs": output_strs,
                "motion_preds": sampled_token_latent,
                "motion_pred_lengths": motion_pred_lengths,
                "audio_preds": audio_tokens_out,
                "audio_labels": audio_tokens,
                "audio_pred_lengths": audio_pred_lengths,
            }
        )

        if isinstance(self.task, list):
            forward_dict.update(
                {
                    "should_stop": True,
                }
            )

        return forward_dict

    def forward(self, **kwargs):
        """
        General forward method that directs to either training or evaluation mode based on the current state.

        Args:
            poses: Input pose sequences.
            lengths: Length of each pose sequence.
            ignore_mask: Mask to ignore specific parts of the sequence during loss calculation.

        Returns:
            Output dictionary from either forward_train or forward_eval method.
        """
        if self.training:
            return_dict = self.forward_train(**kwargs)
        else:
            return_dict = self.forward_eval(**kwargs)

        return_dict.update(kwargs)
        return return_dict

    def training_step(self, batch, batch_idx):
        """
        Perform a single training step, logging relevant metrics.

        Args:
            batch: Input batch containing pose data.
            batch_idx: Index of the current batch.

        Returns:
            Loss value for the current training step.
        """
        # for name, param in self.named_parameters():
        #     if param.grad is None:
        #         print(f"Layer '{name}' has no gradient (grad is None).")

        if batch is None:
            return torch.tensor(0.0, device=batch["poses"].device)

        output = self(**batch)
        self.log_progress_bar(
            {
                "lr": self.trainer.optimizers[0].param_groups[0]["lr"],
                "loss": output["loss"],
            }
        )
        log_dict = {"train/loss": output["loss"]}
        self.log_dict(
            log_dict,
            on_step=False,
            on_epoch=True,
            batch_size=len(batch["lengths_poses"]),
            sync_dist=True,
        )
        return output["loss"]

    def validation_step(self, batch, batch_idx):
        """
        Perform a single validation step and dump generated samples to disk.

        Samples are written only outside the sanity check, on rank zero, for
        batch indices up to 20, and only when the forward pass produced
        predictions. For every sample a ``.wav``, a ``.npy`` and a ``.txt``
        file are written under ``<output_dir>/samples/step-<global_step>``.

        Args:
            batch: Dict of batch entries passed to ``forward``, or ``None``.
            batch_idx: Index of the current batch.

        Returns:
            Output dictionary from the forward pass, or ``None`` for a
            ``None`` batch.
        """
        if batch is None:
            return None

        output = self(**batch)

        # No validation metrics are computed yet; the empty dict is logged as
        # in the original code.
        metric = {}
        self.log_dict(
            metric,
            on_epoch=True,
            batch_size=len(batch["lengths_poses"]),
            sync_dist=True,
        )

        if self.trainer.sanity_checking or rank_zero_only.rank != 0 or batch_idx > 20:
            return output

        # ``forward_eval`` returns an empty dict when ``self.mode == "train"``,
        # in which case there are no predictions to save.
        if "audio_preds" not in output:
            return output

        save_dir = os.path.join(
            self.hparams.output_dir, "samples", f"step-{self.global_step}"
        )
        os.makedirs(save_dir, exist_ok=True)
        print(f"Saving samples to {save_dir}")

        for i, audio_pred in enumerate(output["audio_preds"]):
            # Decode audio tokens
            tts_speech = self.decode_audio(audio_pred)
            with open(os.path.join(save_dir, f"{batch_idx}_{i}_pred.wav"), "wb") as f:
                torchaudio.save(f, tts_speech, 22050, format="wav")

            # Decode motion tokens
            # NOTE(review): ``_generate_diff`` already applies
            # ``pose_tokenizer.decode`` when ``pose_tokenize`` is set, and
            # ``_decode_motion`` decodes again - confirm this is intended.
            motion_preds = self._decode_motion(output["motion_preds"][i])
            with open(os.path.join(save_dir, f"{batch_idx}_{i}_pred.npy"), "wb") as f:
                np.save(f, motion_preds.float().cpu().numpy())

            # Save strings
            with open(os.path.join(save_dir, f"{batch_idx}_{i}_pred.txt"), "w") as f:
                f.write(output["outputs"][i])

        if "should_stop" in output and output["should_stop"]:
            self.trainer.should_stop = True

        return output

    def test_step(self, batch, batch_idx):
        """
        Perform a single test step, leveraging the validation step logic.

        Args:
            batch: Input batch containing pose data.
            batch_idx: Index of the current batch.

        Returns:
            Output dictionary from the evaluation process.
        """
        return self.validation_step(batch, batch_idx)

    # def on_train_epoch_end(self):
    #     # Get current gpu id
    #     gpu_id = torch.cuda.current_device()
    #     os.makedirs(self.hparams.output_dir, exist_ok=True)
    #     torch.save(
    #         self.hidden_dict,
    #         os.path.join(self.hparams.output_dir, f"hidden_dict_{gpu_id}.pth"),
    #     )


# ----------------------------- Trainer ----------------------------- #
class humanexpertTrainer(HFMTrainer):
    def entrypoint(self):
        """
        Run the parent entrypoint, then dispatch on the configured mode.

        ``mode`` falls back to ``"train"`` when it is absent from the config.
        ``"train"`` calls ``fit()`` and ``"eval"`` calls ``test()``.

        Raises:
            ValueError: If the mode is neither ``"train"`` nor ``"eval"``.
        """
        super().entrypoint()

        # Guard clauses: each recognised mode runs its action and returns.
        if self.config.get("mode", "train") == "train":
            self.fit()
            return

        if self.config.mode == "eval":
            self.test()
            return

        raise ValueError(f"Unknown mode: {self.config.mode}")

    def configure_datasets(self):
        """
        Build the train/eval datasets and register the data collator.

        Loader selection:
            * ``data.data_path`` set: legacy token loader
              (``data.faces.dataset_tokens.FaceDatasetCrop``), constructed
              from individually read config keys. The training dataset is
              only built when ``mode`` is ``"train"``; otherwise it is
              ``None``.
            * otherwise ``data.type`` picks the dataset class
              (``"dataset_tos"``, ``"dataset_ltc"``, ``"pose_dataset_ltc"``,
              or the ``dataset_parquet`` fallback). The whole ``data`` config
              section is converted to a plain container and passed as keyword
              arguments, with ``split`` and ``debug`` overridden. Both
              datasets are always built on this path.

        Sets ``self.train_dataset``, ``self.eval_dataset``, ``self.nfeats``
        (taken from the eval dataset) and ``self.data_collator``.
        """
        data_cfg = self.config.data
        debug_flag = self.config.get("debug", False)

        from data.faces.data_collator import data_collator_token as data_collator

        train_split = data_cfg.get("train_split", "train")
        eval_split = data_cfg.get("eval_split", "val")

        if data_cfg.get("data_path", None):
            # Legacy loader: every constructor argument is read explicitly.
            from data.faces.dataset_tokens import FaceDatasetCrop

            # Arguments common to both splits; only ``split`` differs.
            shared_kwargs = {
                "data_path": data_cfg.data_path,
                "metadata_path": data_cfg.get("metadata_path", None),
                "split_path": data_cfg.get("split_path", None),
                "instruction_path": data_cfg.get("instruction_path", None),
                "task_path": data_cfg.get("task_path", None),
                "face_token_dir": data_cfg.get("face_token_dir", None),
                "pose_type": data_cfg.get("pose_type", "coco"),
                "load_mode": data_cfg.get("load_mode", "metadata"),
                "body_parts": data_cfg.body_parts,
                "mean_path": data_cfg.get("mean_path", None),
                "std_path": data_cfg.get("std_path", None),
                "debug": debug_flag,
                "share_index": data_cfg.get("share_index", False),
                "window_size": data_cfg.get("window_size", 2048),
                "rescale": data_cfg.get("rescale", False),
                "confidence": data_cfg.get("confidence", False),
                "face_dir": data_cfg.get(
                    "face_dir",
                    "/face_only_mot/refNet_iclr/20241103-23/video_CLEAN",
                ),
                "audio_dir": data_cfg.get(
                    "audio_dir",
                    "/face_only_mot/audio_CLEAN",
                ),
                "locale": data_cfg.get("locale", "sg"),
            }

            if self.config.get("mode", "train") == "train":
                self.train_dataset = FaceDatasetCrop(
                    split=train_split, **shared_kwargs
                )
            else:
                self.train_dataset = None

            self.eval_dataset = FaceDatasetCrop(split=eval_split, **shared_kwargs)
        else:
            # Config-driven loaders: resolve the dataset class first, then
            # build both splits through one shared code path.
            dataset_type = data_cfg.get("type", None)
            if dataset_type == "dataset_tos":
                from data.faces.dataset_tos import FaceDatasetCrop as dataset_cls
            elif dataset_type == "dataset_ltc":
                if self.config.model.get("eval_training", False):
                    from data.faces.dataset_ltc import (
                        FaceLTCDatasetCrop as dataset_cls,
                    )
                else:
                    from data.faces.dataset_ltc import (
                        MergedFaceDataset as dataset_cls,
                    )
            elif dataset_type == "pose_dataset_ltc":
                from data.poses.dataset_ltc import (
                    PoseLTCDatasetCrop as dataset_cls,
                )
            else:
                from data.faces.dataset_parquet import (
                    FaceDatasetCrop as dataset_cls,
                )

            # The whole data section becomes the constructor kwargs; the same
            # dict is reused for both splits with only ``split`` swapped.
            dataset_kwargs = OmegaConf.to_container(data_cfg.copy())
            dataset_kwargs["split"] = train_split
            dataset_kwargs["debug"] = debug_flag
            self.train_dataset = dataset_cls(**dataset_kwargs)
            dataset_kwargs["split"] = eval_split
            self.eval_dataset = dataset_cls(**dataset_kwargs)

        self.nfeats = self.eval_dataset.nfeats
        self.data_collator = data_collator

    def configure_model(self):
        """
        Build the ``humanexpertS3`` model from the config and load its weights.

        The optional ``model.pose_tokenizer`` section is copied, made
        writable, and completed with ``nfeats`` and the data window size
        before being passed to the model. When the section is absent or null,
        ``None`` is passed instead.

        Reads ``self.nfeats``, ``self.eval_dataset`` and
        ``self.train_dataset``, which ``configure_datasets`` sets.
        """
        # Use ``get`` (as the rest of this file does for optional keys)
        # instead of ``hasattr``: a key explicitly set to null passes the
        # ``hasattr`` check and would then fail on ``None.copy()``.
        pose_tokenizer_section = self.config.model.get("pose_tokenizer", None)
        if pose_tokenizer_section is not None:
            config_pose_tokenizer = pose_tokenizer_section.copy()
            # The copy may inherit a read-only flag; unlock it so the fields
            # below can be assigned.
            OmegaConf.set_readonly(config_pose_tokenizer, False)
            config_pose_tokenizer.nfeats = self.nfeats
            # NOTE(review): this fallback of 64 differs from the 2048 fallback
            # used for the legacy dataset in configure_datasets; confirm that
            # this is intended.
            config_pose_tokenizer.window_size = self.config.data.get("window_size", 64)
        else:
            config_pose_tokenizer = None

        config_audio_tokenizer = self.config.model.audio_tokenizer.copy()
        config_language_model = self.config.model.language_model.copy()

        self.model = humanexpertS3(
            nfeats=self.nfeats,
            config_pose_tokenizer=config_pose_tokenizer,
            config_audio_tokenizer=config_audio_tokenizer,
            config_language_model=config_language_model,
            config=self.config,
            output_dir=self.dir_params["output_dir"],
            eval_dataset=self.eval_dataset,
            train_dataset=self.train_dataset,
        )
        # Load the model from the given path
        self.load_model()
        # print(self.model)
