import os
import sys
import yaml
import time
import random
import pickle
import numpy as np

import torch
from torch.utils.data import TensorDataset, DataLoader

from glob import glob
from tqdm import tqdm
from easydict import EasyDict

from models.diffusion import Diffusion

# Public API of this module: `from <module> import *` exports only `sampling`.
__all__ = [
    'sampling'
]

def get_new_work_dir(root: str, config_name: str) -> str:
    """Create a new timestamp-suffixed directory derived from *root* and return its path.

    When *config_name* starts with ``'/'`` it is appended to *root* as-is;
    otherwise a ``'-'`` is inserted between *root* and *config_name*.  Either
    way the name ends with ``'-'`` followed by the local time formatted as
    ``YYYY_MM_DD__HH_MM_SS``.

    Raises:
        FileExistsError: if the computed directory already exists (e.g. two
            calls with identical arguments within the same second).
    """
    stamp = time.strftime('%Y_%m_%d__%H_%M_%S', time.localtime())
    joiner = '' if config_name.startswith('/') else '-'
    # str.join (like `+`) rejects non-str parts instead of silently coercing them.
    work_dir = ''.join((root, joiner, config_name, '-', stamp))
    os.makedirs(work_dir)
    return work_dir
        
def seed_all(seed: int = 42) -> None:
    """Seed every RNG this script uses and request deterministic torch kernels.

    Seeds Python's ``random``, NumPy's global RNG and torch (CPU and all CUDA
    devices), then configures cuDNN/torch to prefer deterministic algorithms
    so repeated runs with the same *seed* draw the same samples.

    Args:
        seed: Seed value applied to every RNG.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seeds every visible GPU, not only the current one; silently ignored without CUDA.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # cuDNN autotuning selects kernels by timing, which can vary between runs.
    torch.backends.cudnn.benchmark = False
    # This must be *called*: assigning to it would replace the function on the
    # torch module without enabling anything. warn_only=True makes ops that
    # lack a deterministic implementation warn instead of raising mid-run.
    torch.use_deterministic_algorithms(True, warn_only=True)
        
def _find_run_config(run_dir: str) -> str:
    """Return the first (in sorted order) ``*.yml`` file directly inside *run_dir*.

    Raises:
        FileNotFoundError: if *run_dir* contains no ``*.yml`` file.
    """
    # sorted() makes the choice reproducible; glob order is unspecified.
    candidates = sorted(glob(os.path.join(run_dir, '*.yml')))
    if not candidates:
        raise FileNotFoundError(f'No *.yml config found in {run_dir!r}')
    return candidates[0]


def _build_val_loader(config) -> DataLoader:
    """Build an unshuffled DataLoader over the validation arrays named in *config*.

    Each batch is ``(x, edge, energy)``, loaded with ``np.load`` from
    ``config.dataset.val_x`` / ``val_edge`` / ``val_energy`` and cast to float32.
    """
    def _load_f32(path):
        return torch.from_numpy(np.load(path)).to(torch.float32)

    dataset = TensorDataset(
        _load_f32(config.dataset.val_x),
        _load_f32(config.dataset.val_edge),
        _load_f32(config.dataset.val_energy),
    )
    return DataLoader(dataset, config.train.batch_size, num_workers=0, shuffle=False)


def sampling(ckpt_path: str, method=None) -> str:
    """Sample from a trained Diffusion checkpoint for every validation batch.

    The run directory is taken to be two levels above *ckpt_path*; the first
    (sorted) ``*.yml`` file there is read as the run config. Output goes to a
    fresh ``<run_dir>/<method>_sample-<timestamp>`` directory. While sampling,
    the latest ``samples_<i>.pkl`` holds every sample drawn through batch *i*,
    so progress survives an interruption.

    Args:
        ckpt_path: Path to a checkpoint whose ``'model'`` entry is a state dict
            loadable into ``Diffusion``.
        method: Sampling method name forwarded to ``Diffusion.sampling``;
            defaults to ``config.model.sampling.method``.

    Returns:
        Path of ``samples_all.pkl``, a pickled NumPy array made by concatenating
        the per-batch sample arrays along axis 0.

    Raises:
        FileNotFoundError: if the run directory holds no ``*.yml`` config.
        ValueError: if the validation loader yields no batches.
    """
    run_dir = os.path.dirname(os.path.dirname(ckpt_path))
    with open(_find_run_config(run_dir), 'r') as f:
        config = EasyDict(yaml.safe_load(f))
    seed_all(config.train.seed)
    device = torch.device(config.model.device)

    if method is None:
        method = config.model.sampling.method

    # torch.load unpickles the file: only load checkpoints you trust.
    # map_location places tensors on the configured device, so a checkpoint
    # saved on GPU can still be loaded on a CPU-only machine (and vice versa).
    ckpt = torch.load(ckpt_path, map_location=device)
    model = Diffusion(config.model).to(device)
    model.load_state_dict(ckpt['model'])
    # Inference: put layers such as dropout / batch-norm (if any) in eval mode.
    model.eval()

    output_dir = get_new_work_dir(run_dir, f'/{method}_sample')
    test_loader = _build_val_loader(config)

    results = []
    partial_path = None
    # NOTE(review): if Diffusion.sampling never needs autograd, wrapping this
    # loop in torch.no_grad() would cut memory use — confirm before adding.
    for i, (feature, edge, energy) in enumerate(tqdm(test_loader, dynamic_ncols=True)):
        results.append(model.sampling(
            feature.shape,
            edge.to(device), energy.to(device), method
        ).detach().cpu().numpy())

        # Rolling crash checkpoint: write the new cumulative file before
        # deleting the previous one, so a complete file always exists on disk
        # and at most two partial files exist at any time.
        new_partial = os.path.join(output_dir, 'samples_%d.pkl' % i)
        with open(new_partial, 'wb') as f:
            pickle.dump(np.concatenate(results), f)
        if partial_path is not None:
            os.remove(partial_path)
        partial_path = new_partial

    if partial_path is None:
        raise ValueError(f'Validation loader yielded no batches; nothing sampled for {ckpt_path!r}')

    # The last partial file already holds the concatenation of all batches;
    # an atomic rename avoids re-serializing and never leaves zero copies.
    samples_path = os.path.join(output_dir, 'samples_all.pkl')
    os.replace(partial_path, samples_path)
    return samples_path
        
if __name__ == '__main__':
    torch.multiprocessing.set_start_method('spawn') 

    # Usage: <script> [RUN_DIR] [METHOD]
    #   RUN_DIR: directory containing checkpoints/ckpt.pt (optional)
    #   METHOD:  sampling method passed to `sampling` (optional)
    # Explicit length checks replace the bare `except:`, which would also
    # have swallowed KeyboardInterrupt/SystemExit.
    run_dir = sys.argv[1] if len(sys.argv) > 1 else 'logs/EGNN_GRU/5spring---naive---EGNN_GRU---rnn_hidden_size-512--t_embed_size-64--gnn_hidden_size-256--n_layers-3---2024_08_24__22_22_20'
    model_path = os.path.join(run_dir, 'checkpoints', 'ckpt.pt')

    method = sys.argv[2] if len(sys.argv) > 2 else 'dpm3_30'

    sampling(model_path, method)

