# This code is based on https://github.com/openai/guided-diffusion
"""
Train a diffusion model on images.
"""

import os
import json
from utils.fixseed import fixseed
from utils.parser_util import train_args
from utils import dist_util
from train.training_loop import TrainLoop
from data_loaders.get_data import get_dataset_loader
from utils.model_util import create_model_and_diffusion, load_model_wo_clip
from train.train_platforms import ClearmlPlatform, TensorboardPlatform, NoPlatform  # required for the eval operation
import sys
import torch

# NOTE(review): this runs *after* the project imports above (utils, train,
# data_loaders), so it cannot help resolve them; it only affects modules
# imported later in the process. The path is machine-specific — consider
# running from the repo root or setting PYTHONPATH instead.
sys.path.append('/export/home/Working/shiyu-project/motion')

class set_args():
    """Hard-coded training configuration, used in place of the CLI parser.

    The script below builds its arguments from this class instead of
    ``train_args()``. Every field is a plain instance attribute, so
    ``vars(instance)`` yields the full configuration (the script serializes it
    to ``args.json``).

    Keyword arguments override individual defaults, e.g.
    ``set_args(seed=0, batch_size=32)``. Unknown names raise ``TypeError`` so a
    typo cannot silently add an attribute that nothing reads.
    """

    def __init__(self, **overrides):
        # Read directly by this script: seed -> fixseed, dataset/batch_size/
        # num_frames -> get_dataset_loader, train_platform_type/save_dir/
        # overwrite -> output setup, device -> dist_util.setup_dist.
        self.seed = 10
        self.dataset = "humanml"
        self.batch_size = 64
        self.num_frames = 60
        self.train_platform_type = "NoPlatform"
        self.save_dir = "save/2_ft/"
        self.overwrite = True
        self.device = 0
        # The remaining fields are presumably consumed by
        # create_model_and_diffusion / TrainLoop (not visible here) — confirm
        # there before renaming or removing any of them.
        self.latent_dim = 512
        self.layers = 8
        self.cond_mask_prob = 0.1
        self.arch = "trans_enc"
        self.emb_trans_dec = False
        self.noise_schedule = "cosine"
        self.sigma_small = True
        self.lambda_vel = 0.0
        self.lambda_rcxyz = 0.0
        self.lambda_fc = 0.0
        self.lr = 1e-4
        self.log_interval = 1_000
        self.save_interval = 50_000
        self.resume_checkpoint = ""
        self.weight_decay = 0.0
        self.lr_anneal_steps = 0
        self.num_steps = 600_000
        self.eval_during_training = False
        self.eval_split = "test"
        self.eval_batch_size = 32
        self.eval_num_samples = 1_000
        self.eval_rep_times = 3
        self.unconstrained = False
        self.uncond = False

        # Apply caller overrides only to fields defined above; checking the
        # instance dict (not hasattr) keeps methods/dunders from being clobbered.
        known = vars(self)
        for name, value in overrides.items():
            if name not in known:
                raise TypeError(
                    f"set_args() got an unexpected keyword argument {name!r}")
            setattr(self, name, value)


args = set_args()
# To read arguments from the command line instead, use: args = train_args()

fixseed(args.seed)

# Resolve the platform class through an explicit registry rather than eval():
# only the imported platform classes can be selected, and a bad name fails
# with a clear message before anything is written to disk.
_TRAIN_PLATFORMS = {
    'ClearmlPlatform': ClearmlPlatform,
    'TensorboardPlatform': TensorboardPlatform,
    'NoPlatform': NoPlatform,
}
try:
    train_platform_type = _TRAIN_PLATFORMS[args.train_platform_type]
except KeyError:
    raise ValueError('Unknown train_platform_type [{}]; expected one of {}.'.format(
        args.train_platform_type, sorted(_TRAIN_PLATFORMS))) from None

# Validate/create save_dir *before* handing it to the platform, so a missing
# save_dir raises the intended error instead of failing inside the platform
# constructor, and anything the platform writes there cannot trip the
# "already exists" check.
if args.save_dir is None:
    raise FileNotFoundError('save_dir was not specified.')
if os.path.exists(args.save_dir) and not args.overwrite:
    raise FileExistsError('save_dir [{}] already exists.'.format(args.save_dir))
os.makedirs(args.save_dir, exist_ok=True)

# Persist the exact configuration of this run next to its outputs.
args_path = os.path.join(args.save_dir, 'args.json')
with open(args_path, 'w', encoding='utf-8') as fw:
    json.dump(vars(args), fw, indent=4, sort_keys=True)

train_platform = train_platform_type(args.save_dir)
train_platform.report_args(args, name='Args')

dist_util.setup_dist(args.device)

print("creating data loader...")
data = get_dataset_loader(name=args.dataset, batch_size=args.batch_size, num_frames=args.num_frames)

print("creating model and diffusion...")
model, diffusion = create_model_and_diffusion(args, data)
model.to(dist_util.dev())
# Switch to train mode first: train() may propagate to the SMPL model, so the
# explicit eval() must come afterwards to actually keep SMPL in eval mode.
model.train()
model.rot2xyz.smpl_model.eval()

print('Total params: %.2fM' % (sum(p.numel() for p in model.parameters_wo_clip()) / 1000000.0))
print("Training...")

# NOTE(review): `rate` is a project-specific TrainLoop argument whose meaning
# is not visible here — confirm against train/training_loop.py.
try:
    TrainLoop(args, train_platform, model, diffusion, data, rate=0.4).run_loop()
finally:
    # Release platform resources even if training raises or is interrupted.
    train_platform.close()