# This code is based on https://github.com/openai/guided-diffusion
"""
Generate a large batch of image samples from a model and save them as a large
numpy array. This can be used to produce samples for FID evaluation.
"""
import torch
from utils.fixseed import fixseed
from utils.parser_util import generate_args
from utils.model_util import create_model_and_diffusion, load_model_wo_clip
from utils import dist_util
from model.cfg_sampler import ClassifierFreeSampleModel
from model.traj_plan import *
from dataset.get_data import get_dimop3d_dataset_loader
from utils.joint2hml import recover_from_ric_to_raw
from utils.vis_joint2mesh import Joints2Obj
from utils.metrics import *


def main():
    """Sample scene-aware motion continuations for every test item and export them as OBJ meshes.

    For each item of the test split:
      1. score the scene objects with ``interest_net`` and sample a target object,
      2. predict an end pose for that object with ``hoi_estimator``,
      3. plan a root trajectory from the observed frames towards the end pose and
         turn it into a normalized trajectory-overwrite signal for the model,
      4. run the diffusion sampler for ``args.num_repetitions`` samples in one batch,
      5. recover joint positions and write them out with ``Joints2Obj``.

    Outputs are written to ``<args.output_dir>/<item index>`` (``./vis`` is used
    as the root when no output directory is given).
    """
    import os

    args = generate_args()
    # NOTE(review): seeding is disabled, so object selection (torch.multinomial)
    # and diffusion sampling differ between runs; re-enable for reproducibility.
    # fixseed(args.seed)
    out_path = args.output_dir or "./vis"

    dist_util.setup_dist(args.device)

    print("creating data loader...")
    data_root = args.data_dir or None  # an empty string means "not given"
    data = get_dimop3d_dataset_loader(data_root, batch_size=args.batch_size, split="test")
    dataset = data.dataset  # items are fetched one at a time below; the loader itself is not iterated

    print("Creating model and diffusion...")
    model, diffusion = create_model_and_diffusion(args)

    print(f"Loading checkpoints from [{args.model_path}]...")
    state_dict = torch.load(args.model_path, map_location='cpu')
    load_model_wo_clip(model, state_dict)

    if args.guidance_param != 1:
        model = ClassifierFreeSampleModel(model)   # wrapping model with the classifier-free sampler
    model.to(dist_util.dev())
    model.eval()  # disable random masking

    # These files are loaded as whole objects and then called directly (not as state dicts).
    # torch.load unpickles arbitrary code, so only point these paths at trusted files.
    # NOTE(review): PyTorch >= 2.6 defaults to weights_only=True, which rejects
    # pickled modules; pass weights_only=False there if loading fails.
    interest_net = torch.load(args.interestnet_path, map_location='cpu').to(dist_util.dev()).eval()
    hoi_estimator = torch.load(args.estimator_path, map_location='cpu').to(dist_util.dev()).eval()

    sample_length = dataset.fixed_length
    input_length = args.input_frames
    output_length = args.output_frames
    # One diffusion batch holds all repetitions for a single test item.
    args.batch_size = args.num_repetitions
    bs = args.batch_size
    assert input_length + output_length == sample_length

    model_kwargs = {'y': {"lengths": torch.Tensor([sample_length]),
                          "masks": torch.BoolTensor([[1 for _ in range(sample_length)]]),
                          }}
    if args.guidance_param != 1:
        model_kwargs['y']['scale'] = torch.ones(args.batch_size, device=dist_util.dev()) * args.guidance_param
    torch.set_printoptions(precision=4, sci_mode=False)

    heightmap_grid_spacing = 0.02            # DO NOT MODIFY
    for i in range(len(dataset)):
        (motion, refering_joints, scene_height_rgb, scene_base, recover,
         seq, scene_points, scene_feats, objects) = dataset[i]

        motion = motion.unsqueeze(0).to(dist_util.dev())                        # [1, 263, 1, 159]
        scene_points = scene_points.unsqueeze(0).to(dist_util.dev())
        scene_feats = scene_feats.unsqueeze(0).to(dist_util.dev())
        refering_joints = refering_joints.unsqueeze(0)
        recover = recover.numpy()

        model_kwargs['y']["observed"] = motion[:, :, :, :input_length].clone().detach()
        model_kwargs['y']["sf"] = scene_feats.clone().detach()

        ######################### HOI Estimate ########################
        # Pure inference: no gradients are needed through these networks.
        with torch.no_grad():
            interest = interest_net(model_kwargs['y']["observed"], scene_points)
            interest_object = torch.Tensor([interest[obj].mean().item() for obj in objects])
            # Sample (rather than argmax) the target object in proportion to its softmaxed interest.
            target_object = torch.multinomial(interest_object.softmax(dim=-1), num_samples=1)
            end_pose = hoi_estimator(scene_points[objects[target_object]])
        model_kwargs['y']["end_pose"] = end_pose.clone().detach()
        ###############################################################

        ####################### Action Planning #######################
        gt_traj = recover_from_ric_to_raw(motion, dataset, 22, recover)[:, 0]    # [T, 3], raw root positions
        # NOTE(review): 59 observed xz points are hard-coded; together with the
        # densified plan below the trajectory must have sample_length points —
        # confirm how 59 relates to input_length.
        traj_observed = gt_traj[:59, [0, 2]]

        traj_predict = plan_trajectory(scene_height_rgb, traj_observed, end_pose[:, 0, [0, 2]], input_length, output_length)
        # Presumably heightmap cell indices -> world xz coordinates (offset by the scene base).
        traj_predict = (traj_predict + scene_base[None, [0, 2]].long()) * heightmap_grid_spacing
        ###############################################################

        ################# Calculate Overwrite Vector ##################
        # All repetitions share the same planned path, so densify it once and
        # copy it for every batch entry.
        dense_traj = torch.cat([traj_observed, bezier_11to101(traj_predict)[1:]], dim=0)   # 159, 2
        traj_predict = dense_traj[None, :, :].repeat(bs, 1, 1)                             # bs, 159, 2
        unnorm_motion = dataset.inv_transform(motion[:, :, 0].permute(0, 2, 1).detach().cpu().clone())
        rotation_at_observed_end = unnorm_motion[:, :input_length, 0]                      # 1, input_length

        traj_overwrite = calculate_overwrite(traj_predict, refering_joints, rotation_at_observed_end, input_length, output_length)
        traj_overwrite = (traj_overwrite - dataset.mean[None, None, :3]) / dataset.std[None, None, :3]  # 1, 159, 3

        # The Bezier-smoothed trajectory may deviate slightly from the observed route,
        # so the whole sequence is overwritten; the deviation over the observed frames
        # is small enough to ignore.
        model_kwargs['y']["trajectory"] = traj_overwrite.permute(0, 2, 1).unsqueeze(2).contiguous().detach().to(dist_util.dev())  # 1, 3, 1, 159
        ###############################################################

        sample = diffusion.p_sample_loop(
            model,
            (args.batch_size, model.njoints, model.nfeats, output_length + input_length),
            cond=model_kwargs,
            skip_timesteps=0,  # 0 is the default value - i.e. don't skip any step
            init_image=None,
            progress=True,
            dump_steps=None,
            noise=None,
            const_noise=False,
        )
        assert sample.shape[-1] == input_length + output_length

        # Recover XYZ *positions* from HumanML3D vector representation
        sample = [recover_from_ric_to_raw(sample[n:n+1], dataset, 22, recover) for n in range(bs)]
        sample = torch.stack(sample, dim=0)

        # Give each test item its own directory so items do not share one output folder.
        save_dir = os.path.join(out_path, f"{i:04d}")
        os.makedirs(save_dir, exist_ok=True)
        joint2obj = Joints2Obj(sample.permute(0, 2, 3, 1))
        joint2obj.save_obj(save_path=save_dir)
