import os
from os.path import join as pjoin

os.environ["TOKENIZERS_PARALLELISM"] = "false"

import torch
from torch.utils.data import DataLoader

from model.vq.rvq_model import RVQVAE, HRVQVAE
from model.vq.hvq_model import HVQVAE
from model.evaluator.evaluator_wrapper import EvaluatorWrapper
from model.transformer.transformer import MoMaskPlus
from trainers.transformer_trainer import MaskTransformerTrainer
from config.load_config import load_config

from dataset.dataset import MotionDataset, TextMotionDataset
from utils.paramUtil import kinematic_chain
from utils import bvh_io
from utils.utils import plot_3d_motion
from common.skeleton import Skeleton
from utils.motion_process_bvh import recover_pos_from_rot
import numpy as np
from utils.fixseeds import *

import shutil

def forward_kinematic_func(data):
    """Convert a batch of motion features into global joint positions.

    ``data`` is first passed through ``train_dataset.inv_transform``
    (presumably undoing the dataset's normalization -- confirm against the
    dataset class) and the result is handed to ``recover_pos_from_rot``.

    Relies on the module-level names ``train_dataset``, ``cfg`` and
    ``skeleton``, which are only defined when this file runs as a script.
    """
    restored_motions = train_dataset.inv_transform(data)
    return recover_pos_from_rot(
        restored_motions,
        joints_num=cfg.data.joint_num,
        skeleton=skeleton,
    )

def plot_t2m(data, save_dir, captions, m_lengths):
    """Render every motion of a batch to ``<save_dir>/NN.mp4``.

    ``data`` is turned into joint positions by ``forward_kinematic_func`` and
    moved to a NumPy array. Sample ``i`` is cut to its first ``m_lengths[i]``
    entries along its leading axis (presumably frames) and plotted with
    ``captions[i]`` as the title.

    Args:
        data: Batch of motion features accepted by ``forward_kinematic_func``.
        save_dir: Directory the ``.mp4`` files are written to.
        captions: Per-sample titles, indexed like the batch.
        m_lengths: Per-sample lengths, indexed like the batch.
    """
    joint_positions = forward_kinematic_func(data)
    joint_positions = joint_positions.detach().cpu().numpy()
    for idx, sample in enumerate(joint_positions):
        # Files are numbered by batch index, zero-padded to two digits.
        out_file = pjoin(save_dir, '%02d.mp4' % idx)
        plot_3d_motion(
            out_file,
            kinematic_chain,
            sample[:m_lengths[idx]],
            title=captions[idx],
            fps=30,
            radius=100,
        )


def load_vq_model(cfg, device):
    """Build the VQ model selected by ``cfg.vq_name`` and load its checkpoint.

    The architecture is picked by substring match on ``cfg.vq_name``, tested
    in the order ``'hvq'``, ``'hrvq'``, ``'rvq'``. The order matters because
    ``'rvq'`` is itself a substring of ``'hrvq'``.

    Args:
        cfg: Training config; reads ``cfg.vq_name``, ``cfg.vq_ckpt``,
            ``cfg.exp.root_ckpt_dir`` and ``cfg.data.name``.
        device: Device the checkpoint tensors are mapped to and the model is
            moved to.

    Returns:
        A tuple ``(vq_model, vq_cfg)``: the model in eval mode on ``device``
        and the config loaded from the model's ``residual_vqvae.yaml``.

    Raises:
        ValueError: If ``cfg.vq_name`` contains none of the supported
            substrings. Raised before any file is read.
    """
    vq_name = cfg.vq_name

    # Resolve the model family up front so an unsupported name fails fast with
    # a clear message instead of an AttributeError on ``None`` after the
    # config and checkpoint have already been loaded.
    if 'hvq' in vq_name:
        family = 'hvq'
    elif 'hrvq' in vq_name:
        family = 'hrvq'
    elif 'rvq' in vq_name:
        family = 'rvq'
    else:
        raise ValueError(
            f"Unsupported vq_name {vq_name!r}: expected it to contain "
            "'hvq', 'hrvq' or 'rvq'."
        )

    vq_cfg = load_config(pjoin(cfg.exp.root_ckpt_dir, cfg.data.name, 'vq', cfg.vq_name, 'residual_vqvae.yaml'))

    if family == 'hvq':
        vq_model = HVQVAE(
            vq_cfg,
            vq_cfg.data.dim_pose,
            vq_cfg.model.output_emb_width_b,
            vq_cfg.model.down_t,
            vq_cfg.model.down_b,
            vq_cfg.model.stride_t,
            vq_cfg.model.width,
            vq_cfg.model.depth,
            vq_cfg.model.dilation_growth_rate,
            vq_cfg.model.vq_act,
            vq_cfg.model.vq_norm)
    elif family == 'hrvq':
        vq_model = HRVQVAE(
            vq_cfg,
            vq_cfg.data.dim_pose,
            vq_cfg.model.down_t,
            vq_cfg.model.stride_t,
            vq_cfg.model.width,
            vq_cfg.model.depth,
            vq_cfg.model.dilation_growth_rate,
            vq_cfg.model.vq_act,
            vq_cfg.model.use_attn,
            vq_cfg.model.vq_norm)
    else:  # family == 'rvq'
        vq_model = RVQVAE(
            vq_cfg,
            vq_cfg.data.dim_pose,
            vq_cfg.model.down_t,
            vq_cfg.model.stride_t,
            vq_cfg.model.width,
            vq_cfg.model.depth,
            vq_cfg.model.dilation_growth_rate,
            vq_cfg.model.vq_act,
            vq_cfg.model.vq_norm)

    # NOTE(review): the YAML above is located via ``cfg`` (root dir, dataset
    # name, ``cfg.vq_name``) while the checkpoint is located via the values
    # stored inside ``vq_cfg``. Presumably both point at the same experiment
    # directory -- confirm if checkpoints are ever moved or renamed.
    ckpt = torch.load(pjoin(vq_cfg.exp.root_ckpt_dir, vq_cfg.data.name, 'vq', vq_cfg.exp.name, 'model', cfg.vq_ckpt),
                      map_location=device, weights_only=True)

    # The weights are read from 'vq_model' when that key exists, else 'model'.
    model_key = 'vq_model' if 'vq_model' in ckpt else 'model'
    vq_model.load_state_dict(ckpt[model_key])
    print(f'Loading VQ Model {vq_cfg.exp.name} from epoch {ckpt["ep"]}')
    vq_model.to(device)
    vq_model.eval()
    return vq_model, vq_cfg



# Script entry point: trains the MoMaskPlus text-to-motion transformer on top
# of a pre-trained (frozen, eval-mode) VQ model.
#
# The names assigned here (`cfg`, `skeleton`, `train_dataset`, ...) are module
# globals; `forward_kinematic_func` above reads `train_dataset`, `cfg` and
# `skeleton` from them, so this section has to stay at module level.
if __name__ == "__main__":
    # torch.autograd.set_detect_anomaly(True)
    cfg = load_config('config/train_momaskplus.yaml')
    cfg.exp.checkpoint_dir = pjoin(cfg.exp.root_ckpt_dir, cfg.data.name, 'momask_plus', cfg.exp.name)

    if cfg.exp.is_continue:
        # Resuming: use the config snapshot saved in the checkpoint directory,
        # but keep the resume flag, device and checkpoint dir of this launch.
        n_cfg = load_config(pjoin(cfg.exp.checkpoint_dir, 'train_momaskplus.yaml'))
        n_cfg.exp.is_continue = True
        n_cfg.exp.device = cfg.exp.device
        n_cfg.exp.checkpoint_dir = cfg.exp.checkpoint_dir
        cfg = n_cfg
        # print(cfg)
    else:
        # Fresh run: snapshot the config next to the checkpoints so a later
        # resume can reload exactly these settings.
        os.makedirs(cfg.exp.checkpoint_dir, exist_ok=True)
        shutil.copy('config/train_momaskplus.yaml', cfg.exp.checkpoint_dir)

    fixseed(cfg.exp.seed)

    if cfg.exp.device != 'cpu':
        torch.cuda.set_device(cfg.exp.device)

    # NOTE(review): anomaly detection is enabled unconditionally. The PyTorch
    # docs describe it as a debugging aid that slows execution; consider
    # gating it behind a config flag for regular training runs.
    torch.autograd.set_detect_anomaly(True)

    device = torch.device(cfg.exp.device)

    # Output locations: model checkpoints, rendered animations, and logs.
    cfg.exp.model_dir = pjoin(cfg.exp.checkpoint_dir, 'model')
    cfg.exp.eval_dir = pjoin(cfg.exp.checkpoint_dir, 'animation')
    cfg.exp.log_dir = pjoin(cfg.exp.root_log_dir, cfg.data.name, 'momask_plus',cfg.exp.name)

    os.makedirs(cfg.exp.model_dir, exist_ok=True)
    os.makedirs(cfg.exp.eval_dir, exist_ok=True)
    os.makedirs(cfg.exp.log_dir, exist_ok=True)

    # Dataset layout, all relative to cfg.data.root_dir.
    cfg.data.feat_dir = pjoin(cfg.data.root_dir, 'renamed_feats')
    meta_dir = pjoin(cfg.data.root_dir, 'meta_data')
    data_split_dir = pjoin(cfg.data.root_dir, 'data_split_info1')
    all_caption_path = pjoin(cfg.data.root_dir, 'all_caption_clean.json')

    train_mid_split_file = pjoin(data_split_dir, 'train_fnames.txt')
    train_cid_split_file = pjoin(data_split_dir, 'train_ids.txt')

    val_mid_split_file = pjoin(data_split_dir, 'val_fnames.txt')
    val_cid_split_file = pjoin(data_split_dir, 'val_ids.txt')

    # The skeleton is built from the offsets and parents of one fixed template
    # BVH file; it is used by forward_kinematic_func when rendering.
    template_anim = bvh_io.load(pjoin(cfg.data.root_dir, 'renamed_bvhs', 'm_ep2_00086.bvh'))
    skeleton = Skeleton(template_anim.offsets, template_anim.parents, device=device)

    # Statistics passed to the datasets below.
    mean = np.load(pjoin(meta_dir, 'mean.npy'))
    std = np.load(pjoin(meta_dir, 'std.npy'))
    
    vq_model, vq_cfg = load_vq_model(cfg, device=device)

    # Expose the quantizer settings of the VQ model as cfg.vq.
    # 'rvq' is a substring of 'hrvq', so the first branch covers both.
    if 'rvq' in cfg.vq_name:
        cfg.vq = vq_cfg.quantizer
    elif 'hvq' in cfg.vq_name:
        # For HVQ the '_t' variants of codebook size and code dim are used.
        # NOTE(review): if the config container does not copy on assignment,
        # cfg.vq aliases vq_cfg.quantizer and these writes also modify
        # vq_cfg -- confirm this is intended.
        cfg.vq = vq_cfg.quantizer
        cfg.vq.nb_code = vq_cfg.quantizer.nb_code_t
        cfg.vq.code_dim = vq_cfg.quantizer.code_dim_t
    
    # cfg.vq.nb_code = vq_cfg.quantizer.nb_code

    t2m_transformer = MoMaskPlus(
        code_dim=cfg.vq.code_dim,
        latent_dim=cfg.model.latent_dim,
        ff_size=cfg.model.ff_size,
        num_layers=cfg.model.n_layers,
        num_heads=cfg.model.n_heads,
        dropout=cfg.model.dropout,
        text_dim=cfg.text_embedder.dim_embed,
        cond_drop_prob=cfg.training.cond_drop_prob,
        device=device,
        cfg=cfg,
        # presumably 4 is the temporal downsampling factor of the VQ model --
        # TODO confirm against the VQ config
        full_length=cfg.data.max_motion_length//4,
        scales=vq_cfg.quantizer.scales
    )

    # Parameter count of the transformer only; the VQ model is not included
    # even though the message below says "all models".
    pc_vq = sum(param.numel() for param in t2m_transformer.parameters())
    print(t2m_transformer)
    # print("Total parameters of discriminator net: {}".format(pc_vq))
    # all_params += pc_vq_dis

    print('Total parameters of all models: {}M'.format(pc_vq/1000_000))
    print(device)

    trainer = MaskTransformerTrainer(cfg, t2m_transformer, vq_model=vq_model, device=device)


    # val_dataset and eval_dataset are built from the same validation split
    # files; they differ only in the batch size of their loaders below.
    train_dataset = TextMotionDataset(cfg, mean, std, train_mid_split_file, train_cid_split_file, all_caption_path)
    val_dataset = TextMotionDataset(cfg, mean, std, val_mid_split_file, val_cid_split_file, all_caption_path)
    eval_dataset = TextMotionDataset(cfg, mean, std, val_mid_split_file, val_cid_split_file, all_caption_path)

    # NOTE(review): the validation and evaluation loaders also shuffle and
    # drop the last incomplete batch -- confirm this is intended.
    train_loader = DataLoader(train_dataset, batch_size=cfg.training.batch_size, drop_last=True, num_workers=8,
                              shuffle=True, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=cfg.training.batch_size, drop_last=True, num_workers=8,
                              shuffle=True, pin_memory=True)

    # The evaluator config path is hard-coded and relative to the working
    # directory (pjoin with a single argument returns it unchanged).
    eval_cfg = load_config(pjoin('checkpoint_dir/snapmotion/evaluator/eval_klde-5_late-5_nlayer6/evaluator.yaml'))
    eval_wrapper = EvaluatorWrapper(eval_cfg, device=device)

    eval_loader = DataLoader(eval_dataset, batch_size=eval_cfg.matching_pool_size, drop_last=True, num_workers=8,
                              shuffle=True, pin_memory=True)

    # plot_t2m is handed to the trainer as the rendering callback.
    trainer.train(train_loader, val_loader, eval_loader, eval_wrapper, plot_t2m)
