import os
import pickle
import sys
import datetime
import logging
import os.path as osp
import numpy as np
import time
from omegaconf import OmegaConf

import torch

from mld.config import parse_args
from mld.data.get_data import get_dataset
from mld.models.modeltype.mld import MLD
from mld.models.modeltype.vae import VAE
from mld.utils.utils import set_seed, move_batch_to_device
from mld.data.humanml.utils.plot_script import plot_3d_motion
from mld.utils.temos_utils import remove_padding

from mld.transforms.joints2rots import config

import smplx
from tqdm import tqdm
os.environ["TOKENIZERS_PARALLELISM"] = "false"

from motion import Motion
from mld.models.pose_penetration_detector import PosePenetrationDetector


def load_example_hint_input(text_path: str) -> tuple:
    """Parse a control-hint example file.

    Every non-blank line must hold exactly three whitespace-separated
    integers: ``<n_frame> <control_type_id> <control_hint_id>``. Blank or
    whitespace-only lines (for example a trailing empty line) are skipped.

    Args:
        text_path: Path of the text file to read.

    Returns:
        A tuple ``(n_frames, control_type_ids, control_hint_ids)`` of three
        equally long lists of ints, in file order.

    Raises:
        ValueError: If a non-blank line does not contain exactly three
            fields, or a field is not an integer.
    """
    n_frames, control_type_ids, control_hint_ids = [], [], []
    with open(text_path, "r") as f:
        for line in f:
            # split() without arguments tolerates tabs and repeated spaces
            # and yields [] for a blank line.
            fields = line.split()
            if not fields:
                continue
            n_frame, control_type_id, control_hint_id = fields
            n_frames.append(int(n_frame))
            control_type_ids.append(int(control_type_id))
            control_hint_ids.append(int(control_hint_id))

    return n_frames, control_type_ids, control_hint_ids


def load_example_input(text_path: str) -> tuple:
    """Parse a text-prompt example file.

    Every non-blank line has the form ``<length> <text>``: an integer
    followed by a single space and the prompt text, which is kept verbatim
    (apart from the surrounding whitespace stripped from the whole line).
    Blank or whitespace-only lines are skipped.

    Args:
        text_path: Path of the text file to read.

    Returns:
        A tuple ``(texts, lens)``: the list of prompt strings and the list of
        integer lengths, in file order.

    Raises:
        ValueError: If the first space-delimited token of a non-blank line is
            not an integer.
    """
    texts, lens = [], []
    with open(text_path, "r") as f:
        for line in f:
            s = line.strip()
            if not s:
                continue
            # Everything after the first space is the prompt; a line holding
            # only the length yields an empty prompt.
            length_str, _, prompt = s.partition(" ")
            lens.append(int(length_str))
            texts.append(prompt)
    return texts, lens


def check_penetration(joints):
    """Run the ``Motion`` penetration check on ``joints``.

    Args:
        joints: Joint data forwarded unchanged to ``Motion(joints=...)``.

    Returns:
        A tuple ``(has_penetration, stats)``. ``stats`` is the object returned
        by ``Motion.check_penetration(threshold=0.01, batch_size=4)`` and
        ``has_penetration`` is ``stats['penetration_rate'] > 0.01``.
    """
    rate_limit = 0.01

    clip = Motion(joints=joints)
    report = clip.check_penetration(threshold=0.01, batch_size=4)

    # Drop the Motion object and let PyTorch release its unused cached CUDA
    # memory before handing the result back.
    del clip
    torch.cuda.empty_cache()

    return report['penetration_rate'] > rate_limit, report


def main():
    """Generate motions for a text example file with a trained checkpoint.

    Steps, in order:

    1. Parse the config, pick the device and seed the RNGs.
    2. Create the output directory (``cfg.output_dir``, or a time-stamped
       folder under ``cfg.TEST_FOLDER``) with a ``samples`` sub-directory,
       set up logging to stdout and ``output.log``, and dump the config.
    3. Load the checkpoint and inspect its keys to decide whether to build an
       ``MLD`` or a ``VAE`` model.
    4. If ``cfg.optimize`` is set, build and load a ``PosePenetrationDetector``.
    5. If ``cfg.example`` is given, run the model ``cfg.replication`` times on
       the prompts, saving one ``.pkl`` (and, unless ``cfg.no_plot``, one
       ``.mp4``) per sample and repetition, then log a timing summary.
    6. Write the samples directory path to ``latest_output_dir.txt`` in the
       current working directory.

    Raises:
        FileExistsError: If the ``samples`` directory already exists.
    """
    cfg = parse_args()
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    set_seed(cfg.SEED_VALUE)

    # Running totals for the timing summary logged after generation.
    total_inference_time = 0.0
    total_samples = 0
    total_frames = 0

    if cfg.output_dir is None:
        name_time_str = osp.join(cfg.NAME, "demo_" + datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S"))
        cfg.output_dir = osp.join(cfg.TEST_FOLDER, name_time_str)

    vis_dir = osp.join(cfg.output_dir, 'samples')
    os.makedirs(cfg.output_dir, exist_ok=True)
    # exist_ok=False: fail instead of writing into an existing samples folder.
    os.makedirs(vis_dir, exist_ok=False)

    stream_handler = logging.StreamHandler(sys.stdout)
    file_handler = logging.FileHandler(osp.join(cfg.output_dir, 'output.log'))
    logging.basicConfig(level=logging.INFO,
                        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
                        datefmt="%m/%d/%Y %H:%M:%S",
                        handlers=[stream_handler, file_handler])
    logger = logging.getLogger(__name__)

    OmegaConf.save(cfg, osp.join(cfg.output_dir, 'config.yaml'))

    state_dict = torch.load(cfg.TEST.CHECKPOINTS, map_location="cpu")["state_dict"]
    logger.info("Loading checkpoints from {}".format(cfg.TEST.CHECKPOINTS))

    # Step 1: Check if the checkpoint is VAE-based.
    is_vae = 'vae.skel_embedding.weight' in state_dict
    logger.info(f'Is VAE: {is_vae}')

    # Step 2: Check if the checkpoint is MLD-based.
    is_mld = 'denoiser.time_embedding.linear_1.weight' in state_dict
    logger.info(f'Is MLD: {is_mld}')

    # Step 3: Check if the checkpoint is LCM-based.
    lcm_key = 'denoiser.time_embedding.cond_proj.weight'  # unique key for CFG
    is_lcm = lcm_key in state_dict
    if is_lcm:
        # Size the denoiser's conditioning projection from the stored weight.
        cfg.model.denoiser.params.time_cond_proj_dim = state_dict[lcm_key].shape[1]
    logger.info(f'Is LCM: {is_lcm}')

    # Step 4: Check if the checkpoint is Controlnet-based.
    is_controlnet = "controlnet.controlnet_cond_embedding.0.weight" in state_dict
    cfg.model.is_controlnet = is_controlnet
    logger.info(f'Is Controlnet: {is_controlnet}')

    target_model_class = MLD if (is_mld or is_lcm or is_controlnet) else VAE

    pose_discriminator = None
    if cfg.optimize:
        assert cfg.model.get('noise_optimizer') is not None
        cfg.model.noise_optimizer.params.optimize = True
        # NOTE: the batch size is only logged here; it is not changed.
        logger.info('Optimization enabled. Set the batch size to 1.')
        logger.info(f'Original batch size: {cfg.TEST.BATCH_SIZE}')
        # init a pose pene detector
        disc_cfg = OmegaConf.load(cfg.disc_cfg)
        cfg.disc_cfg = disc_cfg
        pose_discriminator = PosePenetrationDetector(
            input_dim=disc_cfg.model.input_dim,
            hidden_dims=disc_cfg.model.hidden_dims,
            output_dim=disc_cfg.model.output_dim,
            dropout=disc_cfg.model.dropout,
            activation=disc_cfg.model.activation,
            task_type=disc_cfg.model.task_type,
            scaling_factor=disc_cfg.model.scaling_factor
        ).to(device)

        # map_location lets a checkpoint saved on a GPU load on a CPU-only
        # host (and vice versa).
        disc_checkpoint = torch.load(disc_cfg.TEST.CHECKPOINTS, map_location=device)
        pose_discriminator.load_state_dict(disc_checkpoint['model_state_dict'])
        pose_discriminator.eval()

    dataset = get_dataset(cfg)
    model = target_model_class(cfg, dataset)
    model.to(device)
    model.eval()
    model.requires_grad_(False)
    logger.info(model.load_state_dict(state_dict))

    # Attribute lookup instead of eval() on a string built from the config.
    FPS = getattr(cfg.DATASET, cfg.DATASET.NAME.upper()).FRAME_RATE

    if cfg.example is not None:
        text, length = load_example_input(cfg.example)
        for t, l in zip(text, length):
            logger.info(f"{l}: {t}")

        batch = {"length": length, "text": text}
        use_cuda = device.type == 'cuda'

        for rep_i in range(cfg.replication):
            # CUDA kernels run asynchronously; synchronise around the forward
            # pass so the wall-clock interval covers the actual computation.
            if use_cuda:
                torch.cuda.synchronize()
            start_time = time.perf_counter()
            with torch.no_grad():
                joints = model(batch, pose_discriminator)[0]
            if use_cuda:
                torch.cuda.synchronize()
            inference_time = time.perf_counter() - start_time

            total_inference_time += inference_time
            total_samples += len(joints)
            total_frames += sum(length)

            for i, sample in enumerate(joints):
                # Save final results
                pkl_path = osp.join(vis_dir, f"sample_id_{i}_length_{length[i]}_rep_{rep_i}.pkl")
                assert sample.shape == (length[i], 22, 3)
                sample_np = sample.detach().cpu().numpy()
                res = {
                    'joints': sample_np,
                    'text': text[i],
                    'length': length[i],
                    'hint': batch.get('hint', None),
                }
                with open(pkl_path, 'wb') as f:
                    pickle.dump(res, f)
                logger.info(f"Motions are generated here:\n{pkl_path}")

                if not cfg.no_plot:
                    # splitext only touches the file extension, never a
                    # directory component that happens to contain '.pkl'.
                    mp4_path = osp.splitext(pkl_path)[0] + '.mp4'
                    plot_3d_motion(mp4_path, sample_np, text[i], fps=FPS)

        # Report the timing that was accumulated over all repetitions.
        if total_samples > 0:
            logger.info(f"Total inference time: {total_inference_time:.4f}s "
                        f"for {total_samples} samples ({total_frames} frames)")
            logger.info(f"Average inference time per sample: "
                        f"{total_inference_time / total_samples:.4f}s")
            if total_frames > 0:
                logger.info(f"Average inference time per frame: "
                            f"{total_inference_time / total_frames:.6f}s")

    # Record where the samples of this run live.
    with open("latest_output_dir.txt", "w") as f:
        f.write(os.path.join(cfg.output_dir, "samples"))


# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
