import os
import torch
import numpy as np

from engine.solver import Trainer
from util.metric_utils import visualization
from data_provider.build_dataloader import build_dataloader
from util.io_utils import load_yaml_config, instantiate_from_config
from models.interpretable_diffusion.model_utils import unnormalize_to_zero_to_one
from util.discriminative_metric import discriminative_score_metrics
from util.predictive_metric import predictive_score_metrics

# Dataset to run, e.g. 'sine', 'etth' or 'energy'. It selects the config file
# ./Config/{data}_cp.yaml and the output directory ./{data}_1.
data = 'etth'
# Feature (channel) count of the selected dataset; should agree with dataset.var_num.
channels = 7
class Args_Example:
    """Minimal stand-in for the command-line arguments this script needs.

    Args:
        data_name: Dataset key used to build the config path and the output
            directory. Defaults to the module-level ``data`` setting (read at
            construction time), so ``Args_Example()`` behaves as before.
        gpu: CUDA device index to use when CUDA is available.

    Attributes:
        config_path: ``./Config/{data_name}_cp.yaml``.
        save_dir: ``./{data_name}_1``; created on construction if missing.
        gpu: The ``gpu`` argument.
    """

    def __init__(self, data_name: "str | None" = None, gpu: int = 0) -> None:
        if data_name is None:
            # Fall back to the module-level dataset selection, as before.
            data_name = data
        self.config_path = f'./Config/{data_name}_cp.yaml'
        self.save_dir = f'./{data_name}_1'
        self.gpu = gpu
        # Ensure the output directory exists before anything writes into it.
        os.makedirs(self.save_dir, exist_ok=True)

# Build the run configuration, data pipeline, model and trainer, then train.
args = Args_Example()
configs = load_yaml_config(args.config_path)
device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')

dl_info = build_dataloader(configs, args)
model = instantiate_from_config(configs['model']).to(device)
trainer = Trainer(config=configs, args=args, model=model, dataloader=dl_info)

trainer.train()

# Window length and feature count come from the dataset itself; they name the
# reference file below and fix the shape of the generated samples.
dataset = dl_info['dataset']
seq_length, feature_dim = dataset.window, dataset.var_num

# Reference data: the saved normalized training windows of the selected dataset.
# Alternative (un-normalized ground truth; file name hard-coded to etth):
# ori_data = np.load(os.path.join(dataset.dir, f"etth_ground_truth_{seq_length}_train.npy"))
ori_data = np.load(os.path.join(dataset.dir, f"{data}_norm_truth_{seq_length}_train.npy"))

# Generate one synthetic window per training window. The sample shape is derived
# from the dataset (not hard-coded) so the samples line up with ori_data, which
# the metrics below compare them against.
fake_data = trainer.sample(num=len(dataset), size_every=1001, shape=[seq_length, feature_dim])
# Alternatives: reload previously saved samples, or map them back to raw scale.
# fake_data = np.load(os.path.join(args.save_dir, f'ddpm_fake_etth1.npy'))
# fake_data = dataset.unnormalize(fake_data)
if dataset.auto_norm:
    # NOTE(review): presumably maps model output back to [0, 1] so it matches
    # the normalized reference data — confirm against model_utils.
    fake_data = unnormalize_to_zero_to_one(fake_data)
np.save(os.path.join(args.save_dir, f'ddpm_fake_{data}.npy'), fake_data)

# Score on equally sized sets. Each metric is run 10 times and every result is
# printed; NOTE(review): presumably the metrics are stochastic, hence the repeats.
n_eval = min(len(ori_data), len(fake_data))
ori_eval, fake_eval = ori_data[:n_eval], fake_data[:n_eval]
for run in range(10):
    print(f'[run {run}] discriminative:', discriminative_score_metrics(ori_eval, fake_eval))
    print(f'[run {run}] predictive:', predictive_score_metrics(ori_eval, fake_eval))

# Optional qualitative comparisons; uncomment to plot.
# visualization(ori_data=ori_data, generated_data=fake_data, analysis='pca', compare=ori_data.shape[0])

# visualization(ori_data=ori_data, generated_data=fake_data, analysis='tsne', compare=ori_data.shape[0])

# visualization(ori_data=ori_data, generated_data=fake_data, analysis='kernel', compare=ori_data.shape[0])