import os
import sys
import yaml
import time
import random
import pickle
import numpy as np

import torch
from torch.utils.data import TensorDataset, DataLoader

from glob import glob
from tqdm import tqdm
from easydict import EasyDict

from eval_sample import eval_sample
from models.diffusion import Diffusion

# Public API of this module: only the sampling entry point is exported.
__all__ = ['sampling']

def get_new_work_dir(root: str, config_name: str) -> str:
    fn = time.strftime('%Y_%m_%d__%H_%M_%S', time.localtime())
    if config_name.startswith('/'):
        log_dir = root + config_name + '-' + fn
    else:
        log_dir = root + '-'  + config_name + '-' + fn
    os.makedirs(log_dir)
    return log_dir
        
def seed_all(seed: int = 42) -> None:
    """Seed every RNG this script relies on and request deterministic kernels.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch's
    generators (CPU and CUDA), then switches PyTorch into deterministic mode.

    Args:
        seed: Seed value applied to all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    # cuDNN autotuning selects kernels by timing, which can differ run to run.
    torch.backends.cudnn.benchmark = False
    # cuBLAS needs this workspace setting to behave deterministically
    # (CUDA >= 10.2). It only takes effect if set before cuBLAS is first
    # initialised; setdefault preserves any value the user exported.
    os.environ.setdefault('CUBLAS_WORKSPACE_CONFIG', ':4096:8')
    # This is a function and must be *called*: assigning to it would replace
    # the torch API with a bool and never enable deterministic mode.
    # warn_only=True lets ops without a deterministic implementation run
    # with a warning instead of raising RuntimeError.
    torch.use_deterministic_algorithms(True, warn_only=True)
        
def _find_run_config(run_dir: str) -> str:
    """Return the YAML config stored directly in ``run_dir`` (first in sorted order)."""
    # Sorted so the choice is reproducible when several *.yml files exist.
    candidates = sorted(glob(os.path.join(run_dir, '*.yml')))
    if not candidates:
        raise FileNotFoundError(f'no *.yml config found in {run_dir!r}')
    return candidates[0]


def _dump_pickle(path: str, obj) -> None:
    """Pickle ``obj`` to ``path``, overwriting any existing file."""
    with open(path, 'wb') as f:
        pickle.dump(obj, f)


def sampling(ckpt_path: str, method: str | None = None, device: str | None = None) -> str:
    """Generate samples from a trained diffusion checkpoint.

    The run directory is taken to be two levels above ``ckpt_path``. Its
    ``*.yml`` config supplies the seed, the default sampling method and
    device, the validation data path and the batch size. One batch of samples
    is drawn per validation batch, matching that batch's shape. All batches
    are concatenated along axis 0 and pickled to
    ``<run_dir>/<method>_sample-<timestamp>/samples_all.pkl``.

    While sampling runs, a cumulative snapshot of the results so far is kept
    in that directory as ``samples_<i>.pkl``. Only the latest snapshot is
    kept, and it is removed once the final file has been written.

    Args:
        ckpt_path: Path to the checkpoint file; it must contain a ``'model'``
            state dict.
        method: Sampling method name; defaults to ``config.model.sampling.method``.
        device: Torch device string; defaults to ``config.model.device``.

    Returns:
        Path of the written ``samples_all.pkl``.

    Raises:
        FileNotFoundError: If the run directory contains no ``*.yml`` config.
        ValueError: If the validation set yields no batches.
    """
    run_dir = os.path.dirname(os.path.dirname(ckpt_path))
    with open(_find_run_config(run_dir), 'r', encoding='utf-8') as f:
        config = EasyDict(yaml.safe_load(f))
    seed_all(config.train.seed)

    if method is None:
        method = config.model.sampling.method
    device = torch.device(device if device else config.model.device)

    # Always map onto the target device, so a GPU-saved checkpoint still loads
    # on a CPU-only host. NOTE: torch.load unpickles, so only load trusted files.
    ckpt = torch.load(ckpt_path, map_location=device)

    output_dir = get_new_work_dir(run_dir, f'/{method}_sample')

    val_data = torch.from_numpy(np.load(config.dataset.val)).to(torch.float32)
    test_loader = DataLoader(
        TensorDataset(val_data),
        config.train.batch_size,
        num_workers=0,
        shuffle=False,
    )
    model = Diffusion(config.model).to(device)
    model.device = device
    model.load_state_dict(ckpt['model'])
    # NOTE(review): the model is left in its default (train) mode, as before.
    # If Diffusion contains dropout/batch-norm, consider model.eval() here.

    results = []
    snapshot_path = None
    # Sampling needs no gradients; without no_grad the whole multi-step
    # sampling graph is kept in memory until .detach().
    with torch.no_grad():
        for i, (batch,) in enumerate(tqdm(test_loader, dynamic_ncols=True)):
            results.append(model.sampling(batch.shape, method).detach().cpu().numpy())

            # Crash-recovery snapshot of everything sampled so far; drop the
            # previous one so disk use does not grow quadratically.
            new_snapshot = os.path.join(output_dir, 'samples_%d.pkl' % i)
            _dump_pickle(new_snapshot, np.concatenate(results))
            if snapshot_path is not None:
                os.remove(snapshot_path)
            snapshot_path = new_snapshot

    if not results:
        raise ValueError(f'validation set {config.dataset.val!r} produced no batches')

    # Write the final file before removing the last snapshot, so a failed
    # write never loses the sampled data.
    samples_path = os.path.join(output_dir, 'samples_all.pkl')
    _dump_pickle(samples_path, np.concatenate(results))
    if snapshot_path is not None:
        os.remove(snapshot_path)

    return samples_path
        
if __name__ == '__main__':
    # NOTE(review): 'spawn' start method is presumably chosen for CUDA safety
    # in worker processes; this file itself uses num_workers=0.
    torch.multiprocessing.set_start_method('spawn')

    # Positional CLI arguments, all optional: <run_dir> [method] [device]
    args = sys.argv[1:]
    run_dir = (
        args[0]
        if len(args) > 0
        else 'logs/naive/ParaGRU/naive---ParaGRU---hidden_size-256--n_layers-4--t_embed_size-128---2024_08_21__19_30_49'
    )
    model_path = os.path.join(run_dir, 'checkpoints', 'ckpt.pt')
    method = args[1] if len(args) > 1 else 'dpm3_30'
    # None lets sampling() fall back to the device named in the run config.
    device = args[2] if len(args) > 2 else None

    eval_sample(sampling(model_path, method, device))

