from __future__ import annotations

import sys
sys.path.insert(0, sys.path[0]+r"/../")

import os
import random
import time
from dataclasses import dataclass, asdict, make_dataclass, field

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.cuda import amp
import tyro
import yaml
from torch.distributions.normal import Normal
from torch.utils.tensorboard import SummaryWriter
from pathlib import Path
from tqdm import tqdm
import copy

from model.mld_vae import AutoMldVae, AutoMldVaeV2, AutoMldVaeWithAdapter, AutoMldVaeWithAdapterV2
from data_loaders.humanml.data.dataset import WeightedPrimitiveSequenceDataset, WeightedPrimitiveSequenceDatasetV2
from data_loaders.humanml.data.dataset_hml3d import HML3dDataset, HML3dD262Dataset
from data_loaders.HHI.data.dataset_interhuman import InterHumanDataset, InterHumanDatasetV2, InterHumanMotion, InterHumanMotionV2, InterHumanDatasetV4, InterGenDataset, InterGenDatasetWPERT
from data_loaders.HHI.data.dataset_interx import InterXDataset
from utils.smpl_utils import get_smplx_param_from_6d
from pytorch3d import transforms
from diffusion.nn import mean_flat, sum_flat

from utils.intergen_util import process_motion_blended
from utils.model_util import CosineWarmupScheduler

# Global debug flag.
debug = 0

# Joint-index paths through the 22-joint skeleton (Trainer uses n_joints = 22).
# Each consecutive pair (chain[i], chain[i + 1]) is treated as one bone by the
# bone-length loss in Trainer.calc_loss. The first three paths start at joint 0,
# the last two branch off joint 9. Index semantics presumably follow the SMPL
# joint ordering — confirm against the body model in use.
kinematic_chain = [
    [0, 2, 5, 8, 11],
    [0, 1, 4, 7, 10],
    [0, 3, 6, 9, 12, 15],
    [9, 14, 17, 19, 21],
    [9, 13, 16, 18, 20],
]

@dataclass
class VAEArgs:
    """Motion VAE hyperparameters.

    ``asdict()`` of this object is passed as keyword arguments to the model
    constructor (see ``Trainer.__init__`` and ``load_mvae``).
    """
    # when not finetuning, selects AutoMldVae ("v1") or AutoMldVaeV2 ("v2")
    type: str = "v1"
    # The architecture hyperparameters below are forwarded verbatim to the
    # model class; their exact meaning is defined in model.mld_vae.
    arch: str = "all_encoder"
    ff_size: int = 1024
    num_layers: int = 5
    num_heads: int = 4
    dropout: float = 0.1
    normalize_before: bool = False
    activation: str = "gelu"
    position_embedding: str = "learned"
    # presumably (number of latent tokens, latent channels) — confirm in model.mld_vae
    latent_dim: tuple[int, int] = (1, 256)
    h_dim: int = 256

    # set by Trainer.__init__ to DataArgs.feature_dim before it is doubled for
    # interaction data
    nfeats: int = 0
    """feature dimension, will be auto filled"""

@dataclass
class DataArgs:
    """Dataset configuration.

    ``history_length``, ``future_length``, ``num_primitive`` and ``feature_dim``
    are filled from the training dataset by ``Trainer.__init__``.
    """
    # non-zero selects 'merged' mode (per-person p1/p2 conditioning and a doubled
    # feature_dim); 0 selects 'sep' mode
    interaction: int = 0
    
    # cfg_path: str = "./config_files/config_hydra/motion_primitive/mp_h2_f8_r1.yaml"
    cfg_path: str = "./config_files/config_hydra/motion_primitive/hml_mp_h2_f8_r4.yaml"
    """motion primitive config file"""
    
    # data_dir: str = "./data/seq_data"
    # data_dir: str = "./data/hml3d_smplh/seq_data_zero_male"
    data_dir: str = "./data/InterHuman/seq_data_zero_male"
    """processed dataset directory"""
    
    # dataset: str = "mp_seq"
    # dataset: str = "hml3d"
    dataset: str = "interhuman"
    """dataset name"""

    # forwarded unchanged to the dataset constructor(s)
    prob_static: float = 0.0
    # the string 'none' (any case) is converted to None by Trainer.__init__
    enforce_gender: str = 'male'
    """enforce all data use the specified gender"""

    enforce_zero_beta: int = 1
    """enforce all data use zero shape parameters"""

    weight_scheme: str = 'uniform_samp:0.'
    """weighting schemes determining how motion primitives are sampled during training"""    
    
    text_tolerance: float = 0.0  # accept text labels in near future within some frames
    """accept text labels in near future within some frames"""

    # auto filled from the training dataset by Trainer.__init__
    history_length: int = 0
    future_length: int = 0
    num_primitive: int = 0
    feature_dim: int = 0
    """auto filled"""

    # body_type: str = 'smplx'
    body_type: str = 'smplh'
    """body type, 'smplx' or 'smplh'"""
    
    # forwarded to the dataset; in 'merged' mode the per-sample texts are also
    # added to the conditioning dict
    use_indi_text: bool = False
    # forwarded to the dataset; with use_indi_text in 'merged' mode the text
    # embeddings are also added to the conditioning dict
    load_text_embedding: bool = False

@dataclass
class TrainArgs:
    """Optimization, loss-weighting, checkpointing and finetuning options."""
    learning_rate: float = 1e-4
    # non-zero: after warmup, decay the learning rate linearly towards 0 at the last step
    anneal_lr: int = 1
    batch_size: int = 128
    # max gradient norm passed to nn.utils.clip_grad_norm_
    grad_clip: float = 1.0
    
    ema_decay: float = 0.999
    """exponential moving average decay"""
    # NOTE(review): Trainer.train enables fp16 autocast with this flag but uses no GradScaler
    use_amp: int = 0
    """use automatic mixed precision"""

    stage1_steps: int = 100000
    """training steps for stage 1 without rollout training"""
    stage2_steps: int = 100000
    """training steps for stage 2 with linearly increasing percent of rollout training"""
    stage3_steps: int = 100000
    """training steps for stage 3 with only rollout training"""

    # Loss weights. Trainer.calc_loss combines kl, rec, smpl_joints_rec, vel, fc,
    # bl and geo; the remaining weights are not referenced by calc_loss.
    weight_rec: float = 1.0  # vae only
    weight_kl: float = 1e-4  # vae only
    weight_smpl_joints_rec: float = 0.0
    weight_joints_consistency: float = 0.0
    weight_transl_delta: float = 0.0
    weight_orient_delta: float = 0.0
    weight_joints_delta: float = 0.0
    weight_latent_rec: float = 1.0  # denoiser only
    weight_feature_rec: float = 0.0  # denoiser only
    weight_inter_rec: float = 1e-2  # denoiser only
    weight_rel_orient: float = 1e-2
    weight_joint_affinity: float = 1e-2
    weight_distance_map: float = 1e-2
    weight_vel: float = 100.0
    weight_bl: float = 5.0
    weight_fc: float = 500.0
    weight_geo: float = 1e-2

    # checkpoint to load; non-adapter weights are always loaded (non-strict), and
    # when not finetuning the optimizer state and step counter are restored too
    resume_checkpoint: str | None = None
    # logging / validation / checkpoint intervals, in optimizer steps
    log_interval: int = 1000
    val_interval: int = 10000
    save_interval: int = 100000

    use_predicted_joints: int = 0
    """if set to 1, use predicted joints to rollout, otherwise use the regressed joints from smplx body model"""
    
    finetune: bool = field(default=False, metadata={"help": "Whether to finetune the model."})
    # with finetune=True: 0 -> AutoMldVaeWithAdapter, 1 -> AutoMldVaeWithAdapterV2
    finetune_mode: int = 0
    
    # mix HML3D samples into every training batch; mix_rate is the fraction of
    # the batch drawn from HML3D
    mix_hml3d: bool = False
    mix_rate: float = 0.25

@dataclass
class Args:
    """Top-level experiment configuration (parsed from the command line by tyro)."""
    # The nested groups use default_factory so that every Args gets its own
    # instances. A dataclass-instance default is unhashable and is rejected by
    # ``dataclasses`` on Python 3.11+. On older versions it would be shared by all
    # Args objects, and Trainer.__init__ mutates data_args/model_args in place.
    train_args: TrainArgs = field(default_factory=TrainArgs)
    model_args: VAEArgs = field(default_factory=VAEArgs)
    data_args: DataArgs = field(default_factory=DataArgs)

    # exp_name: str = "mvae_hml3d"
    exp_name: str = "mvae_interhuman"
    seed: int = 0
    torch_deterministic: bool = True
    device: str = "cuda"
    # overwritten by Trainer.__init__ with ./mvae/<exp_name>
    save_dir: str = ''

    # if > 0, each history frame is zeroed with this probability during training
    his_mask_prob: float = -1

    # forwarded to the datasets (and copied onto data_args.padding)
    padding: bool = False

    # non-zero: log the run to Weights & Biases
    track: int = 0
    wandb_project_name: str = "mld_vae"
    wandb_entity: str = "Motion_Generation"

    # > 0: step a CosineWarmupScheduler during the first warmup_steps steps
    warmup_steps: int = 0
    # run in-training evaluation every eval_every_n_steps steps
    do_eval: bool = False
    eval_every_n_steps: int = 1000


def load_mvae(mvae_ckpt, device):
    """Load a frozen motion VAE together with the ``Args`` it was trained with.

    The training arguments are read from ``args.yaml`` next to the checkpoint
    (the file ``Trainer.__init__`` writes). The model class is chosen exactly as
    in ``Trainer.__init__``, the checkpoint weights are loaded, every parameter
    is frozen, and the model is switched to eval mode.

    Args:
        mvae_ckpt: path to a checkpoint file containing ``'model_state_dict'``.
        device: device to build the model on and to map the checkpoint to.

    Returns:
        ``(mvae_args, vae_model)``: the parsed ``Args`` and the frozen model.

    Raises:
        ValueError: if the stored ``finetune_mode`` (when finetuning) or
            ``model_args.type`` (otherwise) does not name a known model class.
    """
    mvae_checkpoint_dir = Path(mvae_ckpt).parent
    arg_path = mvae_checkpoint_dir / "args.yaml"
    with open(arg_path, "r") as f:
        mvae_args = tyro.extras.from_yaml(Args, yaml.safe_load(f))

    model_args = mvae_args.model_args
    train_args = mvae_args.train_args

    # load mvae model and freeze
    print('vae model args:', asdict(model_args))
    # Mirror Trainer.__init__ so a checkpoint is rebuilt with the class it was
    # trained with (including the V2 adapter for finetune_mode == 1).
    if train_args.finetune:
        if train_args.finetune_mode == 0:
            model_class = AutoMldVaeWithAdapter
        elif train_args.finetune_mode == 1:
            model_class = AutoMldVaeWithAdapterV2
        else:
            raise ValueError(f"unknown finetune_mode: {train_args.finetune_mode!r}")
    elif model_args.type == "v1":
        model_class = AutoMldVae
    elif model_args.type == "v2":
        model_class = AutoMldVaeV2
    else:
        raise ValueError(f"unknown VAE model type: {model_args.type!r}")
    vae_model = model_class(
        **asdict(model_args),
    ).to(device)

    checkpoint = torch.load(mvae_ckpt, map_location=device)
    model_state_dict = checkpoint['model_state_dict']
    # older checkpoints may lack latent statistics: default to identity scaling
    if 'latent_mean' not in model_state_dict:
        model_state_dict['latent_mean'] = torch.tensor(0)
    if 'latent_std' not in model_state_dict:
        model_state_dict['latent_std'] = torch.tensor(1)
    vae_model.load_state_dict(model_state_dict)
    vae_model.latent_mean = model_state_dict['latent_mean']  # register buffer seems to be not loaded by load_state_dict
    vae_model.latent_std = model_state_dict['latent_std']
    print(f"Loading vae checkpoint from {mvae_ckpt}")
    print(f"latent_mean: {vae_model.latent_mean}")
    print(f"latent_std: {vae_model.latent_std}")
    for param in vae_model.parameters():
        param.requires_grad = False
    vae_model.eval()
    return mvae_args, vae_model


class Trainer:
    def __init__(self, args: Args):
        """Build datasets, model, optimizer and logging for a VAE training run.

        Side effects:
            - overrides ``args.save_dir`` with ``./mvae/<exp_name>`` and creates it;
            - fills the auto-computed fields of ``args.data_args`` and
              ``args.model_args`` in place;
            - writes ``args.yaml`` and ``args_read.yaml`` into the save dir;
            - opens a TensorBoard writer under ``runs/``;
            - when ``args.track`` is set, starts a wandb run.

        Raises:
            ValueError: if ``train_args.finetune_mode`` (when finetuning) or
                ``model_args.type`` (otherwise) does not name a known model class.
        """
        self.args = args
        args.save_dir = Path('./mvae') / args.exp_name
        args.save_dir.mkdir(parents=True, exist_ok=True)
        train_args = args.train_args
        model_args = args.model_args
        data_args = args.data_args
        data_args.padding = args.padding

        # Foot joint indices plus the threshold that foot_detect applies to the
        # squared per-frame foot displacement.
        self.fids = [7, 10, 8, 11]
        self.feet_thre = 0.001
        self.n_joints = 22

        self.use_indi_text = data_args.use_indi_text
        self.load_text_embedding = data_args.load_text_embedding

        # TRY NOT TO MODIFY: seeding
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.set_default_dtype(torch.float32)
        torch.backends.cudnn.deterministic = args.torch_deterministic
        device = torch.device(args.device if torch.cuda.is_available() else "cpu")

        # load dataset
        dataset_class = self._select_dataset_class(data_args.dataset)

        self.mode = 'merged' if data_args.interaction else 'sep'
        # 'none' (any case) disables gender enforcement; an already-None value is kept
        if isinstance(data_args.enforce_gender, str) and data_args.enforce_gender.lower() == 'none':
            data_args.enforce_gender = None
        train_dataset = dataset_class(dataset_path=data_args.data_dir,
                                      dataset_name=data_args.dataset,
                                      cfg_path=data_args.cfg_path, prob_static=data_args.prob_static,
                                      enforce_gender=data_args.enforce_gender,
                                      enforce_zero_beta=data_args.enforce_zero_beta,
                                      body_type=data_args.body_type,
                                      split='train',
                                      device=device,
                                      weight_scheme=data_args.weight_scheme,
                                      mode=self.mode,
                                      padding=args.padding,
                                      use_indi_text=data_args.use_indi_text,
                                      load_text_embedding=data_args.load_text_embedding,
                                      )

        if train_args.mix_hml3d:
            if 'd262' in data_args.dataset:
                hml3d_dataset = HML3dD262Dataset(prob_static=data_args.prob_static,
                                                 body_type=data_args.body_type,
                                                 split='train',
                                                 device=device,
                                                 weight_scheme=data_args.weight_scheme,
                                                 )
            else:
                hml3d_dataset = HML3dDataset(prob_static=data_args.prob_static,
                                             body_type=data_args.body_type,
                                             split='train',
                                             device=device,
                                             weight_scheme=data_args.weight_scheme,
                                             )
            self.hml3d_dataset = hml3d_dataset

        # TODO: why use train_dataset as val_dataset?
        val_dataset = train_dataset
        # if 'text' in data_args.weight_scheme or 'samp:1' in data_args.weight_scheme:
        #     val_dataset = train_dataset
        # else:
        #     val_dataset = dataset_class(dataset_path=data_args.data_dir, dataset_name=data_args.dataset,
        #                                                    cfg_path=data_args.cfg_path, prob_static=data_args.prob_static,
        #                                                    enforce_gender=data_args.enforce_gender,
        #                                                    enforce_zero_beta=data_args.enforce_zero_beta,
        #                                                    split='val', device=device,
        #                                                    weight_scheme=data_args.weight_scheme,
        #                                                    )

        # get primitive configs
        data_args.history_length = train_dataset.history_length
        data_args.future_length = train_dataset.future_length
        data_args.num_primitive = train_dataset.num_primitive
        data_args.feature_dim = 0
        for k in train_dataset.motion_repr:
            data_args.feature_dim += train_dataset.motion_repr[k]
        # the model is built with the single-person feature size; only the data
        # feature_dim is doubled for interaction data
        model_args.nfeats = data_args.feature_dim
        if self.args.data_args.interaction:
            # model_args.nfeats *= 2
            data_args.feature_dim *= 2

        with open(args.save_dir / "args.yaml", "w") as f:
            yaml.dump(tyro.extras.to_yaml(args), f)
        with open(args.save_dir / "args_read.yaml", "w") as f:
            yaml.dump(asdict(args), f)
        run_name = f"{args.exp_name}__seed{args.seed}__{int(time.time())}"
        if args.track:
            import wandb
            wandb.init(
                project=args.wandb_project_name,
                entity=args.wandb_entity,
                sync_tensorboard=True,
                config=vars(args),
                name=run_name,
                save_code=True,
            )
            wandb.run.log_code(root=".",
                               include_fn=lambda path, root: os.path.relpath(path, root).startswith("mld/") or
                                                             os.path.relpath(path, root).startswith("model/")
                               )
        writer = SummaryWriter(f"runs/{run_name}")
        writer.add_text(
            "hyperparameters",
            "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
        )
        print('model args:', asdict(model_args))
        model_class = self._select_model_class(model_args, train_args)
        model = model_class(
            **asdict(model_args),
        ).to(device)

        optimizer = optim.AdamW(model.parameters(), lr=train_args.learning_rate)
        start_step = 1
        if train_args.resume_checkpoint is not None:
            checkpoint = torch.load(train_args.resume_checkpoint, map_location=device)
            # Adapter weights are always dropped from the checkpoint (also when
            # resuming a finetuning run); the rest is loaded non-strictly.
            model_state_dict = {k: v for k, v in checkpoint['model_state_dict'].items() if "adapter" not in k}
            if 'latent_mean' not in model_state_dict:
                model_state_dict['latent_mean'] = torch.tensor(0)
            if 'latent_std' not in model_state_dict:
                model_state_dict['latent_std'] = torch.tensor(1)
            missing, unexpected = model.load_state_dict(model_state_dict, strict=False)
            if not train_args.finetune:
                # a plain resume restores the optimizer and the step counter;
                # finetuning starts from step 1 with a fresh optimizer
                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                start_step = checkpoint['num_steps'] + 1
            # printed after start_step is final so the log shows the real resume step
            print(f"Loading checkpoint from {train_args.resume_checkpoint} at step {start_step}")
            print(f"  >> Missing keys: {missing}")
            print(f"  >> Unexpected keys: {unexpected}")
        self.model_avg = None
        if train_args.ema_decay > 0:
            self.model_avg = copy.deepcopy(model)
            self.model_avg.eval()

        self.model = model
        self.optimizer = optimizer
        self.writer = writer
        self.start_step = start_step
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.device = device
        self.batch_size = train_args.batch_size
        self.step = start_step

        self.data_args = data_args

        self.warmup_steps = args.warmup_steps
        self.scheduler = None  # only created (and only stepped) when warmup is enabled
        if self.warmup_steps > 0:
            self.scheduler = CosineWarmupScheduler(
                optimizer,
                warmup=self.warmup_steps,
                max_iters=train_args.stage1_steps + train_args.stage2_steps + train_args.stage3_steps,
            )

        self.mix_hml3d = train_args.mix_hml3d
        self.mix_rate = train_args.mix_rate

        # self.rec_criterion = torch.nn.HuberLoss(reduction='mean', delta=1.0)
        # reduction='none' so masked_loss can average per sample before masking
        self.rec_criterion = torch.nn.HuberLoss(reduction='none', delta=1.0)
        self.transf_rotmat = torch.eye(3, device=self.device).unsqueeze(0)
        self.transf_transl = torch.zeros(3, device=self.device).reshape(1, 1, 3)

    @staticmethod
    def _select_dataset_class(dataset_name):
        """Map ``DataArgs.dataset`` to a dataset class (unknown names -> WeightedPrimitiveSequenceDataset)."""
        dataset_classes = {
            'mp_seq_v2': WeightedPrimitiveSequenceDatasetV2,
            'hml3d': HML3dDataset,
            'hml3d_d262': HML3dD262Dataset,
            'interhuman': InterHumanDatasetV4,
            'interhuman_single': InterHumanMotionV2,
            'interhuman_d262': InterGenDataset,
            'interhuman_d262_wpert': InterGenDatasetWPERT,
            'interx': InterXDataset,
        }
        return dataset_classes.get(dataset_name, WeightedPrimitiveSequenceDataset)

    @staticmethod
    def _select_model_class(model_args, train_args):
        """Return the VAE class implied by the finetuning options and ``model_args.type``.

        Raises:
            ValueError: for an unknown ``finetune_mode`` or ``type``.
        """
        if train_args.finetune:
            if train_args.finetune_mode == 0:
                return AutoMldVaeWithAdapter
            if train_args.finetune_mode == 1:
                return AutoMldVaeWithAdapterV2
            raise ValueError(f"unknown finetune_mode: {train_args.finetune_mode!r}")
        if model_args.type == "v1":
            return AutoMldVae
        if model_args.type == "v2":
            return AutoMldVaeV2
        raise ValueError(f"unknown VAE model type: {model_args.type!r}")

    def masked_loss(self, pred, target, valid_mask):
        loss = self.rec_criterion(pred.squeeze(), target.squeeze())  # [B, ..., D]
        loss = loss.mean(dim=list(range(1, loss.ndim)))  # → [B]
        if valid_mask is not None:
            loss = loss[valid_mask]
            return loss.mean() if loss.numel() > 0 else torch.tensor(0.0, device=pred.device)
        return loss.mean()

    def calc_loss(self, motion, cond, history_motion, future_motion_gt, future_motion_pred, latent, dist):
        """Compute the weighted VAE training loss for one primitive.

        Args:
            motion: full primitive tensor (not used; kept for interface compatibility).
            cond: conditioning dict. Reads ``cond['y']['primitive_padding_mask']``
                if present, and ``cond['y']['gender']`` / ``cond['y']['betas']``
                when joints must be computed from features.
            history_motion: history frames (not used here).
            future_motion_gt: normalized ground-truth future frames, ``[B, T, D]``.
            future_motion_pred: normalized predicted future frames, same shape.
            latent: sampled latent (not used here).
            dist: posterior distribution returned by the encoder.

        Returns:
            A dict with the scalar terms ``kl_loss``, ``rec_loss``,
            ``joints_rec_loss``, ``vel_loss``, ``fc_loss``, ``bl_loss`` and
            ``geo_loss``, plus their weighted sum under ``'loss'``. Returns
            ``None`` when every sample is padding for this primitive.
        """
        train_args = self.args.train_args
        dataset = self.train_dataset
        utility = dataset.primitive_utility
        history_length = dataset.history_length

        primitive_padding_mask = cond['y'].get('primitive_padding_mask', None)
        if primitive_padding_mask is not None:
            # validity of a sample is taken from the last entry along dim 1 of its padding mask
            padding_mask = primitive_padding_mask[:, -1].view(-1)
            valid_mask = ~padding_mask
        else:
            valid_mask = None

        if valid_mask is not None and valid_mask.sum() == 0:
            return None

        terms = {}

        # KL divergence to a standard normal prior
        mu_ref = torch.zeros_like(dist.loc)
        scale_ref = torch.ones_like(dist.scale)
        dist_ref = torch.distributions.Normal(mu_ref, scale_ref)
        kl_loss = torch.distributions.kl_divergence(dist, dist_ref)
        if valid_mask is not None:
            # NOTE(review): squeeze() is relied on to put the batch axis at dim 0
            # (e.g. by dropping a leading size-1 latent axis if the posterior is
            # sequence-first). That breaks for B == 1 or several latent tokens —
            # confirm the layout of ``dist`` before touching this.
            kl_loss = kl_loss.squeeze()
            kl_loss = kl_loss.mean(dim=list(range(1, kl_loss.ndim)))
            kl_loss = kl_loss[valid_mask]
            kl_loss = kl_loss.mean() if kl_loss.numel() > 0 else torch.tensor(0.0, device=kl_loss.device)
        else:
            kl_loss = kl_loss.mean()
        terms['kl_loss'] = kl_loss

        # reconstruction loss
        if self.args.data_args.dataset == 'interhuman_single':
            # the last 4 feature channels are excluded for this dataset
            rec_loss = self.masked_loss(future_motion_pred[..., :-4], future_motion_gt[..., :-4], valid_mask)
        else:
            rec_loss = self.masked_loss(future_motion_pred, future_motion_gt, valid_mask)
        terms['rec_loss'] = rec_loss

        B, T, *_ = future_motion_gt.shape
        # Decode the normalized features once; the same dicts feed the joint
        # losses and the geodesic rotation loss below.
        pred_dict = utility.tensor_to_dict(dataset.denormalize(future_motion_pred))
        gt_dict = utility.tensor_to_dict(dataset.denormalize(future_motion_gt))
        if 'joints' in dataset.motion_repr:
            pred_joints = pred_dict['joints'].reshape(B, T, -1, 3)
            gt_joints = gt_dict['joints'].reshape(B, T, -1, 3)
        else:
            # no joint channels in the representation: compute joints from the
            # decoded features plus the body-shape conditioning
            body_params = {
                'gender': cond['y']['gender'],
                'betas': cond['y']['betas'][:, history_length:],
            }
            pred_features = self._flatten_batch_time({**pred_dict, **body_params}, B, T)
            gt_features = self._flatten_batch_time({**gt_dict, **body_params}, B, T)
            pred_joints = utility.calc_joints_from_features(pred_features).reshape(B, T, -1, 3)
            gt_joints = utility.calc_joints_from_features(gt_features).reshape(B, T, -1, 3)

        # joints rec loss
        terms['joints_rec_loss'] = self.masked_loss(pred_joints, gt_joints, valid_mask)

        # velocity loss on per-frame joint displacements
        pred_vel = pred_joints[:, 1:] - pred_joints[:, :-1]
        tgt_vel = gt_joints[:, 1:] - gt_joints[:, :-1]
        pred_vel = pred_vel.reshape(pred_vel.shape[:-2] + (-1,))
        tgt_vel = tgt_vel.reshape(tgt_vel.shape[:-2] + (-1,))
        terms['vel_loss'] = self.masked_loss(pred_vel, tgt_vel, valid_mask)

        # foot-contact loss: penalize predicted foot displacement on frames where
        # the ground-truth foot is detected as in contact
        feet_vel = pred_joints[:, 1:, self.fids, :] - pred_joints[:, :-1, self.fids, :]  # [B, T-1, F, 3]
        gt_feet_vel = gt_joints[:, 1:, self.fids, :] - gt_joints[:, :-1, self.fids, :]   # [B, T-1, F, 3]
        gt_feet_h = gt_joints[:, :-1, self.fids, 1]                                       # [B, T-1, F]
        contact_gt = self.foot_detect(gt_feet_vel, gt_feet_h, self.feet_thre)             # [B, T-1, F]
        terms['fc_loss'] = self.masked_loss(feet_vel * contact_gt.unsqueeze(-1), torch.zeros_like(feet_vel), valid_mask)

        # bone-length loss over consecutive joint pairs of each kinematic chain
        pred_bones, gt_bones = [], []
        for chain in kinematic_chain:
            for parent, child in zip(chain[:-1], chain[1:]):
                pred_bones.append((pred_joints[..., parent, :] - pred_joints[..., child, :]).norm(dim=-1, keepdim=True))  # [B, T, 1]
                gt_bones.append((gt_joints[..., parent, :] - gt_joints[..., child, :]).norm(dim=-1, keepdim=True))
        pred_bones = torch.cat(pred_bones, dim=-1)
        gt_bones = torch.cat(gt_bones, dim=-1)
        terms['bl_loss'] = self.masked_loss(pred_bones, gt_bones, valid_mask)

        # geodesic loss on the 6D body-pose rotations
        pred_rot = pred_dict['body_pose'].reshape(B, T, -1, 6)
        gt_rot = gt_dict['body_pose'].reshape(B, T, -1, 6)
        terms['geo_loss'] = self.cal_geo_loss(pred_rot, gt_rot, valid_mask)

        loss = train_args.weight_kl * kl_loss \
            + train_args.weight_rec * rec_loss \
            + train_args.weight_smpl_joints_rec * terms['joints_rec_loss'] \
            + train_args.weight_vel * terms['vel_loss'] \
            + train_args.weight_fc * terms['fc_loss'] \
            + train_args.weight_bl * terms['bl_loss'] \
            + train_args.weight_geo * terms['geo_loss']
        terms['loss'] = loss
        return terms

    @staticmethod
    def _flatten_batch_time(features, B, T):
        """Reshape every ``[B, T, ...]`` value of ``features`` to ``[B*T, ...]`` in place and return the dict.

        Values without a ``shape`` (e.g. lists or strings) or whose leading two
        dimensions are not ``(B, T)`` are left untouched.
        """
        for key, value in features.items():
            shape = getattr(value, 'shape', None)
            if shape is not None and len(shape) >= 2 and shape[0] == B and shape[1] == T:
                features[key] = value.reshape(B * T, *shape[2:])
        return features

    def cal_geo_loss(self, pred_rot, gt_rot, valid_mask, eps=1e-7):
        """Mean geodesic angle (radians) between predicted and target rotations.

        Both tensors hold 6D rotation representations in their last dimension, and
        their first two dimensions are treated as batch and time. The angle of the
        relative rotation ``R_gt @ R_pred^T`` is averaged within each sample, then
        over the samples selected by ``valid_mask`` (all samples when it is
        ``None``). The cosine is clamped to ``[-1 + eps, 1 - eps]`` to keep
        ``acos`` finite. Returns a zero tensor if the mask selects no sample.
        """
        batch, frames, *_ = gt_rot.shape
        rot_pred = transforms.rotation_6d_to_matrix(pred_rot.reshape(-1, 6))
        rot_gt = transforms.rotation_6d_to_matrix(gt_rot.reshape(-1, 6))
        relative = torch.bmm(rot_gt, rot_pred.transpose(1, 2))
        trace = relative[..., 0, 0] + relative[..., 1, 1] + relative[..., 2, 2]
        cos_angle = (trace - 1) / 2.0
        angle = torch.acos(torch.clamp(cos_angle, -1.0 + eps, 1.0 - eps)).reshape(batch, frames, -1)
        per_sample = angle.mean(dim=(1, 2))
        if valid_mask is None:
            return per_sample.mean()
        per_sample = per_sample[valid_mask]
        if per_sample.numel() == 0:
            return torch.tensor(0.0, device=angle.device)
        return per_sample.mean()
    
    def foot_detect(self, feet_vel, feet_h, thres):
        velfactor, heightfactor = torch.Tensor([thres, thres, thres, thres]).to(feet_vel.device), torch.Tensor(
            [0.12, 0.05, 0.12, 0.05]).to(feet_vel.device)

        feet_x = (feet_vel[..., 0]) ** 2
        feet_y = (feet_vel[..., 1]) ** 2
        feet_z = (feet_vel[..., 2]) ** 2

        contact = (((feet_x + feet_y + feet_z) < velfactor) & (feet_h < heightfactor)).float()
        return contact

    def train(self):
        """Run the staged training loop until ``total_steps`` optimizer steps.

        Each outer iteration draws one batch holding ``num_primitive`` primitives
        (optionally mixed with HML3D samples) and takes one optimizer step per
        primitive.

        Past ``stage1_steps``, the prediction for a primitive may be used to build
        the history of the next primitive (rollout training). The probability of
        doing so rises linearly to 1 over ``stage2_steps``.

        The loop also handles learning-rate warmup/annealing, the EMA model,
        TensorBoard logging, checkpointing, validation and optional in-training
        evaluation.
        """
        model = self.model
        optimizer = self.optimizer
        train_args = self.args.train_args
        writer = self.writer
        future_length = self.train_dataset.future_length
        history_length = self.train_dataset.history_length
        num_primitive = self.train_dataset.num_primitive

        model.train()
        total_steps = train_args.stage1_steps + train_args.stage2_steps + train_args.stage3_steps
        # every batch yields num_primitive steps, so size the progress bar in whole batches
        rest_steps = (total_steps - self.start_step) // num_primitive + 1
        rest_steps = rest_steps * num_primitive
        progress_bar = iter(tqdm(range(rest_steps)))
        self.step = self.start_step

        if self.args.do_eval:
            from evaluation.eval_inter_react_mvae import evaluation_during_training, EvalArgs
            eval_file = os.path.join(self.args.save_dir, 'evaluation_training.log')
            eval_args = EvalArgs()
            eval_args.load_dir = self.args.save_dir
            eval_args.dataset = self.args.data_args.dataset
            min_fid = np.inf
            max_top1 = -np.inf

        while self.step <= total_steps:
            # Learning-rate schedule: the warmup scheduler first, then optional linear decay.
            if self.step < self.warmup_steps:
                self.scheduler.step()
            elif train_args.anneal_lr:
                frac = 1.0 - (self.step - self.warmup_steps - 1.0) / (total_steps - self.warmup_steps)
                lrnow = frac * train_args.learning_rate
                optimizer.param_groups[0]["lr"] = lrnow
            # print(f"Step {self.step}, learning rate: {optimizer.param_groups[0]['lr']}")

            with amp.autocast(enabled=bool(train_args.use_amp), dtype=torch.float16):
                if self.mix_hml3d:
                    n_hml3d = int(self.batch_size * self.mix_rate)
                    batch = self.train_dataset.get_batch(self.batch_size - n_hml3d)
                    batch_hml3d = self.hml3d_dataset.get_batch(n_hml3d)

                    for i in range(len(batch)):
                        for key in batch[i].keys():
                            if key == 'primitive_padding_mask':
                                # mark every HML3D sample as unpadded
                                batch_hml3d[i][key] = torch.zeros((n_hml3d, *batch[i][key].shape[1:]), dtype=torch.bool, device=batch[i][key].device)
                                batch[i][key] = torch.cat([batch[i][key], batch_hml3d[i][key]], dim=0)
                            elif isinstance(batch[i][key], torch.Tensor):
                                batch[i][key] = torch.cat([batch[i][key], batch_hml3d[i][key]], dim=0)
                            elif isinstance(batch[i][key], list):
                                batch[i][key] += batch_hml3d[i][key]
                else:
                    batch = self.train_dataset.get_batch(self.batch_size)

            last_primitive = None
            for primitive_idx in range(num_primitive):
                with amp.autocast(enabled=bool(train_args.use_amp), dtype=torch.float16):
                    motion, cond = self.get_primitive_batch(batch, primitive_idx)
                    motion_tensor = motion.squeeze(2).permute(0, 2, 1)  # [B, T, D]
                    future_motion_gt = motion_tensor[:, -future_length:, :]
                    history_motion = motion_tensor[:, :history_length, :]
                    if last_primitive is not None:
                        # rollout: history built from the previous primitive's prediction
                        history_motion = self.get_rollout_history(last_primitive, cond)  # [B, H, D]

                    if self.args.his_mask_prob > 0:
                        # zero each history frame independently with probability his_mask_prob
                        mask = torch.bernoulli(torch.ones((history_motion.size(0), history_length), device=history_motion.device) * self.args.his_mask_prob).view(history_motion.size(0), history_length, 1)
                        history_motion = history_motion * (1.0 - mask)

                    latent, dist = model.encode(future_motion=future_motion_gt, history_motion=history_motion)
                    future_motion_pred = model.decode(latent, history_motion, nfuture=future_length)  # [B, F, D]

                    loss_dict = self.calc_loss(motion, cond, history_motion, future_motion_gt, future_motion_pred, latent, dist)
                    if loss_dict is None:
                        # Every sample is padding for this primitive: skip the update
                        # (no step is counted). Also drop any pending rollout so a later
                        # primitive never conditions on a prediction from two primitives back.
                        last_primitive = None
                        continue
                    loss = loss_dict['loss']

                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(), train_args.grad_clip)
                optimizer.step()

                # update the average model using exponential moving average
                if train_args.ema_decay > 0:
                    for param, avg_param in zip(self.model.parameters(), self.model_avg.parameters()):
                        avg_param.data.mul_(train_args.ema_decay).add_(
                            param.data, alpha=1 - train_args.ema_decay)

                last_primitive = None
                if self.step > train_args.stage1_steps:
                    rollout_prob = min(1.0, (self.step - train_args.stage1_steps) / max(
                        float(train_args.stage2_steps), 1e-6))
                    if torch.rand(1).item() < rollout_prob:
                        last_primitive = future_motion_pred.detach()

                if self.step % train_args.log_interval == 0:
                    for key in loss_dict:
                        writer.add_scalar(f"loss/{key}", loss_dict[key].item(), self.step)
                    writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], self.step)

                if self.step % train_args.save_interval == 0 or self.step == total_steps:
                    self.save()

                if self.step % train_args.val_interval == 0 or self.step == total_steps:
                    self.validate()

                if self.args.do_eval and self.step % self.args.eval_every_n_steps == 0:
                    original_mode = self.model.training
                    self.model.eval()
                    fid, mat, top1 = evaluation_during_training(self.args, eval_args, self.data_args, None, self.model, self.args, None,
                                                self.step, eval_file, device=self.device,)
                    if fid < min_fid:
                        min_fid = fid
                        self.save(self.args.save_dir / "best_fid.pt")
                        print('Best FID Model So Far!~')
                    if top1 > max_top1:
                        max_top1 = top1
                        self.save(self.args.save_dir / "best_top1.pt")
                        print('Best Top1 Model So Far!~')
                    writer.add_scalar('Test/FID', fid, self.step)
                    writer.add_scalar('Test/Matching', mat, self.step)
                    writer.add_scalar('Test/Top1', top1, self.step)
                    self.model.train(original_mode)

                self.step += 1
                next(progress_bar)

    def get_primitive_batch(self, batch, primitive_idx):
        """Select one primitive of a batch and build its model condition.

        Args:
            batch: per-primitive list of sample dicts (as produced by the dataset's
                ``get_batch``).
            primitive_idx: which primitive of ``batch`` to extract.

        Returns:
            ``(motion, cond)`` where ``motion`` is the sample's
            ``'motion_tensor_normalized'`` entry and ``cond`` is ``{'y': {...}}``
            holding body-shape/gender entries (``*_p1``/``*_p2`` in ``'merged'``
            mode), the history entries, optional text entries and, outside
            ``'merged'`` mode with ``args.padding`` set, the padding mask.
        """
        sample = batch[primitive_idx]
        motion = sample['motion_tensor_normalized']  # [bs, D, 1, T]
        merged = self.mode == 'merged'

        if merged:
            identity_keys = ('gender_p1', 'betas_p1', 'gender_p2', 'betas_p2')
        else:
            identity_keys = ('gender', 'betas')
        history_keys = ('history_motion', 'history_mask', 'history_length', 'future_length')
        y = {key: sample[key] for key in identity_keys + history_keys}

        if self.use_indi_text:
            y['text'] = sample['texts']
            if self.load_text_embedding:
                y['text_embedding'] = sample['text_embedding']
        # the padding mask is only forwarded for single-person batches
        if not merged and self.args.padding:
            y['primitive_padding_mask'] = sample['primitive_padding_mask']

        return motion, {'y': y}

    def get_rollout_history(self, last_primitive, cond,
                            return_transform=False,
                            transf_rotmat=None, transf_transl=None
                            ):
        """Update the history motion seed and the global transform for rollout.

        The last ``history_length`` frames of ``last_primitive`` are denormalized,
        re-canonicalized per gender group with ``primitive_utility`` and
        normalized again, so they can be fed back as the next primitive's
        history motion.

        Args:
            last_primitive: normalized predicted motion, sliced as ``[B, T, D]``.
            cond: condition dict as built by ``get_primitive_batch``. Reads
                ``cond['y']['gender']``/``['betas']``, or ``['gender_p1']``,
                ``['betas_p1']`` and ``['betas_p2']`` when ``self.mode == 'merged'``.
            return_transform: if True, also return the updated transforms.
            transf_rotmat, transf_transl: per-sample global transforms in batch
                order. When None, ``self.transf_rotmat``/``self.transf_transl``
                are repeated for every sample of a gender group.

        Returns:
            ``rollout_history`` (normalized, rows in the same order as
            ``last_primitive``). When ``return_transform`` is True, also the new
            ``transf_rotmat`` and ``transf_transl`` in that same row order.

        Raises:
            ValueError: if a sample's gender is not 'female', 'male' or 'neutral'.
        """
        history_length = self.train_dataset.history_length
        motion_tensor = last_primitive[:, -history_length:, :]  # [B, T, D]
        if self.mix_hml3d:
            # rows [:n_train] use train_dataset statistics, the remaining rows hml3d_dataset ones
            n_train = self.batch_size - int(self.batch_size * self.mix_rate)
            new_history_frames = torch.cat([self.train_dataset.denormalize(motion_tensor[:n_train]),
                                            self.hml3d_dataset.denormalize(motion_tensor[n_train:])], dim=0)
        else:
            new_history_frames = self.train_dataset.denormalize(motion_tensor)
        primitive_utility = self.train_dataset.primitive_utility
        use_predicted_joints = getattr(self.args.train_args, 'use_predicted_joints', False)
        merged = self.mode == 'merged'
        # In merged mode samples are grouped by person 1's gender only.
        # NOTE(review): person 2 is also processed with person 1's gender name
        # (gender_p2 is never consulted) — confirm this is intended.
        genders = cond['y']['gender_p1'] if merged else cond['y']['gender']

        def make_history_feature_dict(frames, betas, gender_name, gender_idx):
            """Wrap one gender group's frames with transform, betas and pelvis offset."""
            feature_dict = primitive_utility.tensor_to_dict(frames)
            feature_dict.update(
                {
                    'transf_rotmat': self.transf_rotmat.repeat(len(gender_idx), 1, 1) if transf_rotmat is None else transf_rotmat[gender_idx],
                    'transf_transl': self.transf_transl.repeat(len(gender_idx), 1, 1) if transf_transl is None else transf_transl[gender_idx],
                    'gender': gender_name,
                    'betas': betas[gender_idx, -history_length:, :],
                }
            )
            feature_dict['pelvis_delta'] = primitive_utility.calc_calibrate_offset({
                'betas': feature_dict['betas'][:, 0, :],  # betas of the first history frame
                'gender': gender_name,
            })
            return feature_dict

        rollout_history, new_transf_rotmat, new_transf_transl = [], [], []
        order = []  # batch index of every row, in the order the groups are concatenated
        for gender_name in ['female', 'male', 'neutral']:
            gender_idx = [idx for idx, gender in enumerate(genders) if gender == gender_name]
            if not gender_idx:
                continue
            order.extend(gender_idx)
            group_frames = new_history_frames[gender_idx]
            if merged:
                # the feature axis holds person 1 followed by person 2
                nfeats = self.args.model_args.nfeats
                feature_dict_p1 = make_history_feature_dict(
                    group_frames[..., :nfeats // 2], cond['y']['betas_p1'], gender_name, gender_idx)
                feature_dict_p2 = make_history_feature_dict(
                    group_frames[..., -nfeats // 2:], cond['y']['betas_p2'], gender_name, gender_idx)
                # Person 1 defines the canonical frame and person 2 is expressed relative to it.
                # The returned transform goes into dedicated locals: rebinding the
                # transf_rotmat/transf_transl arguments here would corrupt the
                # transforms looked up for the remaining gender groups.
                canonical, blended_p1, p1_rotmat, p1_transl = primitive_utility.get_blended_feature(
                    feature_dict_p1, use_predicted_joints=use_predicted_joints, return_transf=True)
                _, blended_p2 = primitive_utility.relative_blended_feature(
                    feature_dict_p2, p1_rotmat, p1_transl, use_predicted_joints=use_predicted_joints)
                history_tensor = torch.cat((primitive_utility.dict_to_tensor(blended_p1),
                                            primitive_utility.dict_to_tensor(blended_p2)), dim=-1)
            else:
                feature_dict = make_history_feature_dict(
                    group_frames, cond['y']['betas'], gender_name, gender_idx)
                canonical, blended = primitive_utility.get_blended_feature(
                    feature_dict, use_predicted_joints=use_predicted_joints)
                history_tensor = primitive_utility.dict_to_tensor(blended)
            new_transf_rotmat.append(canonical['transf_rotmat'])
            new_transf_transl.append(canonical['transf_transl'])
            rollout_history.append(history_tensor)

        if len(order) != len(genders):
            unknown = sorted({str(g) for g in genders if g not in ('female', 'male', 'neutral')})
            raise ValueError(f"unsupported gender(s) in batch: {unknown}")

        # Groups were concatenated by gender. Map rows back to batch order so they
        # line up with the ground truth, the mix_hml3d split below and callers
        # that index the returned transforms by batch index.
        in_batch_order = order == list(range(len(order)))
        restore = sorted(range(len(order)), key=order.__getitem__)

        def to_batch_order(rows):
            return rows if in_batch_order else rows[restore]

        rollout_history = to_batch_order(torch.cat(rollout_history, dim=0))
        if self.mix_hml3d:
            rollout_history = torch.cat([self.train_dataset.normalize(rollout_history[:n_train]),
                                         self.hml3d_dataset.normalize(rollout_history[n_train:])], dim=0)
        else:
            rollout_history = self.train_dataset.normalize(rollout_history)  # [B, T, D]

        if return_transform:
            return (rollout_history,
                    to_batch_order(torch.cat(new_transf_rotmat, dim=0)),
                    to_batch_order(torch.cat(new_transf_transl, dim=0)))
        return rollout_history

    def get_latent_scale(self, model):
        """Estimate the scale of the latent space and store it on ``model``.

        Encodes the first primitive of one training batch (ground-truth history
        and future) and registers the scalar mean and population standard
        deviation over all latent values as the ``latent_mean`` and
        ``latent_std`` buffers of ``model``.

        Args:
            model: ``self.model`` or ``self.model_avg``. It must be an
                ``nn.Module`` exposing ``encode(future_motion=..., history_motion=...)``
                that returns ``(latent, dist)``.
        """
        original_mode = model.training
        model.eval()
        future_length = self.train_dataset.future_length
        history_length = self.train_dataset.history_length
        try:
            with torch.no_grad():
                batch = self.train_dataset.get_batch(self.batch_size)
                motion, _ = self.get_primitive_batch(batch, 0)
                motion_tensor = motion.squeeze(2).permute(0, 2, 1)  # [B, T, D]
                future_motion_gt = motion_tensor[:, -future_length:, :]
                history_motion = motion_tensor[:, :history_length, :]

                latent, _ = model.encode(future_motion=future_motion_gt, history_motion=history_motion)
                all_mean = latent.mean()
                all_std = (latent - all_mean).pow(2).mean().sqrt()
                # register_buffer overwrites an existing buffer of the same name,
                # so repeated calls (one per checkpoint) just refresh the values
                model.register_buffer("latent_mean", all_mean)
                model.register_buffer("latent_std", all_std)
                print(f"latent mean: {all_mean}, latent std: {all_std}")
        finally:
            # restore the caller's train/eval mode even if encoding fails
            model.train(original_mode)

    def save(self, checkpoint_path=None):
        """Write a checkpoint of the EMA model (or the raw model if there is none).

        The latent scale buffers are refreshed via ``get_latent_scale`` first.
        The file stores the step counter, the model state dict and the
        optimizer state dict.

        Args:
            checkpoint_path: destination path; defaults to
                ``args.save_dir / "checkpoint_<step>.pt"``.
        """
        has_avg = self.model_avg is not None
        target = self.model_avg if has_avg else self.model
        print('save avg model:', has_avg)
        self.get_latent_scale(target)

        path = checkpoint_path
        if path is None:
            path = self.args.save_dir / f"checkpoint_{self.step}.pt"

        payload = {
            'num_steps': self.step,
            'model_state_dict': target.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
        }
        torch.save(payload, path)
        print(f"Saved checkpoint at {path}")

    def validate(self):
        """Log mean validation losses of ``self.model`` to TensorBoard.

        Draws ``max(128, len(val_dataset) // batch_size)`` validation batches
        and runs every primitive of each through encode/decode and
        ``self.calc_loss``. Once ``self.step`` exceeds ``stage1_steps``, the
        previous primitive's prediction is always rolled out as the next
        history. Each loss term is averaged over all evaluated primitives and
        written as ``val_loss/<name>`` at ``self.step``. The model's train/eval
        mode is restored at the end.
        """
        was_training = self.model.training
        self.model.eval()

        model = self.model
        dataset = self.train_dataset
        future_length = dataset.future_length
        history_length = dataset.history_length
        num_primitive = dataset.num_primitive
        # self.step does not change during validation, so the stage check is loop-invariant
        rollout_enabled = self.step > self.args.train_args.stage1_steps

        collected = {}
        num_batches = max(128, len(self.val_dataset) // self.batch_size)
        with torch.no_grad():
            for _ in tqdm(range(num_batches)):
                batch = self.val_dataset.get_batch(self.batch_size)
                previous_pred = None
                for primitive_idx in range(num_primitive):
                    motion, cond = self.get_primitive_batch(batch, primitive_idx)
                    motion_tensor = motion.squeeze(2).permute(0, 2, 1)  # [B, T, D]
                    future_motion_gt = motion_tensor[:, -future_length:, :]
                    if previous_pred is None:
                        history_motion = motion_tensor[:, :history_length, :]
                    else:
                        # continue from the model's own previous prediction
                        history_motion = self.get_rollout_history(previous_pred, cond)  # [B, H, D]

                    latent, dist = model.encode(future_motion=future_motion_gt, history_motion=history_motion)
                    future_motion_pred = model.decode(latent, history_motion, nfuture=future_length)

                    loss_dict = self.calc_loss(motion, cond, history_motion, future_motion_gt,
                                               future_motion_pred, latent, dist)
                    if loss_dict is None:
                        continue
                    for name, value in loss_dict.items():
                        collected.setdefault(name, []).append(value.detach())

                    previous_pred = future_motion_pred.detach() if rollout_enabled else None

        for name, values in collected.items():
            mean_value = torch.stack(values).mean().item()
            self.writer.add_scalar(f"val_loss/{name}", mean_value, self.step)
        self.model.train(was_training)

    def close(self):
        """Close the TensorBoard writer owned by this trainer."""
        writer = self.writer
        writer.close()


if __name__ == "__main__":
    # limit intra-op and inter-op CPU thread pools
    torch.set_num_threads(4)
    torch.set_num_interop_threads(4)
    args = tyro.cli(Args)
    trainer = Trainer(args)
    try:
        trainer.train()
    finally:
        # close the TensorBoard writer even if training crashes or is interrupted,
        # so already-logged scalars are flushed to disk
        trainer.close()

