import argparse
import torch
import time
import datetime
import json
import yaml
import os
import pickle
import numpy as np
import logging

from dataset import *

from diff_func import CSDI_ETT
from pr_func import CSDI_PR

from utils import train, train_pr_em, evaluate, evaluate_pr

# Root logger used by the whole script. Handlers are attached further down,
# once the output folder (and therefore the log file path) is known.
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

# ---------------------------------------------------------------------------
# Command-line interface
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description="CSDI")
# Dataset; also selects the config file name and the window length below.
parser.add_argument("--data_name", default="ett", choices=['ett', 'stock', 'pems'], type=str)

# Runtime settings.
parser.add_argument('--device', default='cuda:0', help='Device for Attack')
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--testmissingratio", type=float, default=0.1, help='artificial missing ratio for evalution')

# Model / diffusion options; these are copied into the loaded config later.
parser.add_argument("--schedule", default="quad", choices=['linear', 'quad', 'cosine'], type=str)
parser.add_argument("--unconditional", action="store_true")
parser.add_argument("--ort", action="store_true")
parser.add_argument("--maskinfo", action="store_true")
parser.add_argument("--target", default="x0", choices=['epsilon', 'x0'], type=str)
parser.add_argument("--missingpattern", choices=['random', 'near'], default='random', type=str, help='how to make missing target')

# Existing run folder under ./save (omit to train from scratch) and the
# number of samples passed to the evaluation routines.
parser.add_argument("--modelfolder", type=str, default=None)
parser.add_argument("--nsample", type=int, default=10)
parser.add_argument("--mitratio", type=float, default=0.1, help='ratio of making missing target artificially in training procedure')

# Pattern recognizer (PR) stage.
parser.add_argument("--pr", action="store_true")
parser.add_argument("--prscale", type=float, default=1)

# Optional override for config['train']['epochs_pr'].
parser.add_argument("--emstep_epoch", type=int, default=None)

# --eval skips the diffusion-only evaluation and enables the mean/std summary;
# --ckpt selects which EM checkpoint to restore for the PR stage;
# --gen_iter is the number of repeated PR evaluation passes when a
# pre-trained model folder is used.
parser.add_argument("--eval", action="store_true")
parser.add_argument("--ckpt", type=int, default=None)
parser.add_argument("--gen_iter", type=int, default=5)

args = parser.parse_args()

# Fail fast on an option combination that cannot succeed: without
# --modelfolder a brand-new timestamped folder is created, so the PR stage
# would try to load model_pr_em{ckpt}.pth from a folder that never contained
# it -- and only after the full diffusion training has already run.
if args.pr and args.ckpt is not None and args.modelfolder is None:
    parser.error("--ckpt requires --modelfolder when --pr is set")

# Derived settings: per-dataset config file name and evaluation window length.
args.config = f"config_{args.data_name}.yaml"
args.eval_length = 12 if args.data_name == 'pems' else 24
print(args)

# Load the per-dataset YAML configuration chosen on the command line.
path = f"config/{args.config}"
with open(path, "r") as f:
    config = yaml.safe_load(f)

# Overlay command-line choices on the values read from the file.
# Plain item assignment is used (in this order) so that a malformed config
# fails in exactly the same place as a sequence of direct assignments would.
for key, value in (
    ("is_unconditional", args.unconditional),
    ("test_missing_ratio", args.testmissingratio),
    ("is_ort", args.ort),
    ("mitratio", args.mitratio),
    ("is_maskinfo", args.maskinfo),
    ("target", args.target),
    ("target_strategy", args.missingpattern),
):
    config["model"][key] = value

config["diffusion"]["schedule"] = args.schedule

for key, value in (("is_pr", args.pr), ("scale", args.prscale)):
    config["pr"][key] = value

# The PR epoch count is only overridden when it was given explicitly.
if args.emstep_epoch is not None:
    config["train"]["epochs_pr"] = args.emstep_epoch

# Echo the effective configuration.
print(json.dumps(config, indent=4))

# Timestamp used to give a fresh run its own output folder.
current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

if args.modelfolder is not None:
    # Reuse an existing run folder under ./save.
    foldername = f"./save/{args.modelfolder}"
else:
    # New run: create the folder and snapshot the effective config into it.
    foldername = f"./save/{args.data_name}_seed{args.seed}_{current_time}"
    print(foldername)
    os.makedirs(foldername, exist_ok=True)
    with open(f"{foldername}/config.json", "w") as f:
        json.dump(config, f, indent=4)

# Log file name encodes the eval flag and, if given, the checkpoint and scale.
log_name = f'/log_eval_{args.eval}'
if args.ckpt is not None:
    log_name += f'_ckpt_{args.ckpt}_scale_{args.prscale}'

# File handler: timestamped records, file is overwritten on every run.
file_handler = logging.FileHandler(f"{foldername}{log_name}.txt", mode='w')
file_formatter = logging.Formatter(
    '[%(asctime)s] %(levelname)s - %(message)s', '%Y-%m-%d %H:%M:%S'
)
file_handler.setFormatter(file_formatter)

# Console handler: shorter format without timestamps.
console_handler = logging.StreamHandler()
console_formatter = logging.Formatter('%(levelname)s: %(message)s')
console_handler.setFormatter(console_formatter)

for handler in (file_handler, console_handler):
    logger.addHandler(handler)

# ---------------------------------------------------------------------------
# Data: train / test / validation loaders share the same construction args.
# ---------------------------------------------------------------------------
logger.info("Get Training Information...")
loader_kwargs = dict(
    data_name=args.data_name,
    eval_length=args.eval_length,
    seed=args.seed,
    batch_size=config["train"]["batch_size"],
    missing_ratio=config["model"]["test_missing_ratio"],
)

train_info = get_train_dataloader(**loader_kwargs)
train_dataset = train_info['dataset']
train_loader = train_info['dataloader']
# Only the first value returned by get_dataset_info() is reported for training.
train_omr, _, _ = train_dataset.get_dataset_info()
logger.info(f"Original Missing ratio of Training Dataset is {train_omr*100:.4f}%.")
train_dataset.save_eval_info(foldername)

logger.info("Get Test Information")
test_info = get_test_dataloader(**loader_kwargs)
test_dataset = test_info['dataset']
test_loader = test_info['dataloader']
test_omr, test_amr, test_target_num = test_dataset.get_dataset_info()
logger.info(f"Original Missing ratio of Test Dataset is {test_omr*100:.4f}%.")
logger.info(f"Artificial Missing ratio of Test Dataset is {test_amr*100:.4f}%.")
logger.info(f"The Number of Evaluating Target Entries of Test Dataset is {int(test_target_num)}.")
test_dataset.save_eval_info(foldername)

# Validation split (used as valid_loader during diffusion training).
val_info = get_val_dataloader(**loader_kwargs)
val_dataset = val_info['dataset']
val_loader = val_info['dataloader']
val_omr, val_amr, val_target_num = val_dataset.get_dataset_info()
logger.info(f"Original Missing ratio of Valid Dataset is {val_omr*100:.4f}%.")
logger.info(f"Artificial Missing ratio of Valid Dataset is {val_amr*100:.4f}%.")
logger.info(f"The Number of Evaluating Target Entries of Valid Dataset is {int(val_target_num)}.")
val_dataset.save_eval_info(foldername)

# target_dim passed to the model, per dataset. The keys mirror the
# --data_name choices, so the lookup always succeeds for parsed arguments.
target_dim = {
    'ett': 7,
    'stock': 6,
    'pems': 325,
}[args.data_name]

# Build the diffusion model and move it to the requested device.
model = CSDI_ETT(config, args.device, target_dim=target_dim, logger=logger)
model = model.to(args.device)
logger.info('Diffusion Model Loaded.')

# ---------------------------------------------------------------------------
# Stage 1: obtain the diffusion model, either by training or from a checkpoint.
# ---------------------------------------------------------------------------
if args.modelfolder is None:
    # No run folder given: train the diffusion model from scratch.
    logger.info('Train the diffusion model from scratch.')
    start_time = time.time()
    train(
        model,
        config["train"],
        train_loader,
        valid_loader=val_loader,
        foldername=foldername,
        logger=logger,
    )
    # After a fresh training run the PR evaluation loop is executed once.
    gen_iter = 1
    logger.info(f'First step diffusion training for {time.time() - start_time}')
else:
    # Restore the weights stored in the given run folder.
    foldername = os.path.join("./save", args.modelfolder)
    logger.info('Load the pre-trained diffusion model.')
    checkpoint_path = os.path.join(foldername, f"model_{args.target}.pth")
    # map_location remaps the stored tensors onto the requested device, so a
    # checkpoint saved on another device (e.g. a GPU that is not present on
    # this machine) can still be loaded.
    model.load_state_dict(torch.load(checkpoint_path, map_location=args.device))
    # With a pre-trained model the PR evaluation is repeated --gen_iter times.
    gen_iter = args.gen_iter

# Diffusion-only evaluation; skipped entirely when --eval is set.
if not args.eval:
    # Each entry: (log message, data loader, data_mode, target_mode).
    baseline_runs = (
        (f'Impute Artificial Missing Test dataset with Iter {args.nsample}....', test_loader, 'test', 'artificial'),
        (f'Impute Original Missing Test dataset with Iter {args.nsample}....', test_loader, 'test', 'original'),
        (f'Impute Original Missing Training dataset with Iter {args.nsample}....', train_loader, 'train', 'original'),
    )
    for message, loader, data_mode, target_mode in baseline_runs:
        logger.info(message)
        # The three returned values are not used here.
        _, _, _ = evaluate(
            model,
            loader,
            config,
            nsample=args.nsample,
            scaler=1,
            foldername=foldername,
            data_mode=data_mode,
            target_mode=target_mode,
            logger=logger,
        )

# ---------------------------------------------------------------------------
# Stage 2 (only with --pr): pattern recognizer training/loading, followed by
# PR-guided evaluation repeated gen_iter times.
# ---------------------------------------------------------------------------
if args.pr:
    logger.info('Pattern Recognizer Loaded.')
    model_pr = CSDI_PR(config, args.device, target_dim=target_dim, logger=logger).to(args.device)

    if args.ckpt is None:
        # No checkpoint requested: run the joint training routine.
        logger.info('Train the pattern recognizer from scratch.')
        start_time = time.time()
        train_pr_em(
            model,
            model_pr,
            config["train"],
            train_loader,
            foldername=foldername,
            logger=logger,
            scale=config['pr']['scale'],
        )
        logger.info(f'Second step diffusion & pattern recognizer training for {time.time() - start_time}')
    else:
        # Restore both models from the requested EM checkpoint.
        # map_location remaps stored tensors onto the requested device so that
        # checkpoints saved on a different device remain loadable.
        logger.info(f'Load the pre-trained pattern recognizer model at EM {args.ckpt}.')
        pr_checkpoint = os.path.join(foldername, f"model_pr_em{args.ckpt}.pth")
        model_pr.load_state_dict(torch.load(pr_checkpoint, map_location=args.device))
        logger.info(f'Load the pre-trained fine-tuned diffusion model at EM {args.ckpt}.')
        ft_checkpoint = os.path.join(foldername, f"model_ft_em{args.ckpt}.pth")
        model.load_state_dict(torch.load(ft_checkpoint, map_location=args.device))

    # Per-setting metric accumulators, one list entry per evaluation pass.
    artf_test = {'RMSE': [], 'MAE': [], 'MRE': []}
    orig_test = {'RMSE': [], 'MAE': [], 'MRE': []}
    orig_train = {'RMSE': [], 'MAE': [], 'MRE': []}

    # Each entry: (log label, data loader, data_mode, target_mode, accumulator).
    pr_eval_runs = (
        ('Artificial Missing Test', test_loader, 'test', 'artificial', artf_test),
        ('Original Missing Test', test_loader, 'test', 'original', orig_test),
        ('Original Missing Training', train_loader, 'train', 'original', orig_train),
    )

    for _ in range(gen_iter):
        for label, loader, data_mode, target_mode, results in pr_eval_runs:
            logger.info(f'Impute {label} dataset with Iter {args.nsample} + PR guidance {args.prscale}')
            rmse, mae, mre = evaluate_pr(
                model,
                model_pr,
                config,
                loader,
                nsample=args.nsample,
                scaler=1,
                foldername=foldername,
                data_mode=data_mode,
                target_mode=target_mode,
                scale=config['pr']['scale'],
                logger=logger,
            )
            results['RMSE'].append(rmse)
            results['MAE'].append(mae)
            results['MRE'].append(mre)

    # Mean/std summary is only reported for multi-pass evaluation-only runs.
    if args.eval and gen_iter != 1:
        for tag, results in (
            ('artf test', artf_test),
            ('orig test', orig_test),
            ('orig train', orig_train),
        ):
            # Output order per line: RMSE mean, RMSE std, MAE mean, MAE std,
            # MRE mean, MRE std.
            summary = ", ".join(
                f"{np.mean(results[metric]):.4f}, {np.std(results[metric]):.4f}"
                for metric in ('RMSE', 'MAE', 'MRE')
            )
            logger.info(f"{tag}: {summary}")
