import os
import yaml
import time
import random
import pickle
import numpy as np

import torch
from torch.utils.data import DataLoader, TensorDataset

from glob import glob
from tqdm import tqdm
from easydict import EasyDict

from models.diffusion import Diffusion

__all__ = [
    'sampling_eval'
]

def pde_loss(
    u: np.ndarray,
    *,
    t_span: tuple[float, float] = (0.0, 3.0),
    x_span: tuple[float, float] = (-4.0, 4.0),
) -> np.floating:
    """Return the RMS residual of the inviscid Burgers equation u_t + u * u_x = 0.

    Derivatives are forward differences on a uniform grid that spans ``t_span``
    along axis 1 and ``x_span`` along axis 2. The number of grid points on
    each axis is taken from ``u.shape``. The residual is evaluated on the
    common (n_t - 1, n_x - 1) interior where both differences are defined.

    Args:
        u: Array indexed as (batch, time, space).
        t_span: (start, end) of the time grid.
        x_span: (start, end) of the spatial grid.

    Returns:
        Scalar root-mean-square residual over all batch/time/space entries.

    Raises:
        ValueError: If ``u`` is not 3-D or has fewer than 2 points on the time
            or space axis.
    """
    u = np.asarray(u)
    if u.ndim != 3:
        raise ValueError(f'pde_loss expects a 3-D array (batch, time, space), got ndim={u.ndim}')
    _, n_t, n_x = u.shape
    if n_t < 2 or n_x < 2:
        raise ValueError(f'pde_loss needs at least 2 time and 2 space points, got shape {u.shape}')

    # Spacing is taken from linspace (not (end - start) / (n - 1)) so results match the
    # original 128x128 grid bit-for-bit.
    ts = np.linspace(t_span[0], t_span[1], n_t)
    xs = np.linspace(x_span[0], x_span[1], n_x)
    dt = ts[1] - ts[0]
    dx = xs[1] - xs[0]

    u_t = np.diff(u, axis=1) / dt  # (batch, n_t - 1, n_x)
    u_x = np.diff(u, axis=2) / dx  # (batch, n_t, n_x - 1)
    # Crop every term to the shared (n_t - 1, n_x - 1) interior before combining.
    residual = u_t[:, :, :-1] + u[:, :-1, :-1] * u_x[:, :-1]
    return np.sqrt(np.square(residual).mean())

def get_new_work_dir(root: str, config_name: str) -> str:
    """Create and return a new timestamped directory derived from ``root``.

    The directory name is ``root`` + ``config_name`` + ``-<timestamp>``.
    When ``config_name`` does not start with ``'/'``, a ``'-'`` is inserted
    between ``root`` and ``config_name``. The timestamp is local time in the
    form ``YYYY_MM_DD__HH_MM_SS``.

    Raises:
        FileExistsError: If the directory already exists (e.g. two calls
            within the same second).
    """
    stamp = time.strftime('%Y_%m_%d__%H_%M_%S', time.localtime())
    joiner = '' if config_name.startswith('/') else '-'
    work_dir = f'{root}{joiner}{config_name}-{stamp}'
    os.makedirs(work_dir)
    return work_dir
        
def seed_all(seed: int = 42) -> None:
    """Seed Python, NumPy and torch RNGs and request deterministic torch kernels.

    Args:
        seed: Seed applied to ``random``, NumPy's global RNG, and torch on the
            CPU and on all CUDA devices.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # cuDNN autotuning can select different algorithms run-to-run, defeating determinism.
    torch.backends.cudnn.benchmark = False
    # Deterministic cuBLAS needs this workspace setting; it only takes effect if set
    # before cuBLAS is first initialised in the process.
    os.environ.setdefault('CUBLAS_WORKSPACE_CONFIG', ':4096:8')
    # Must be called, not assigned: assigning would replace the torch function with a bool.
    # warn_only=True warns (instead of raising) for ops lacking a deterministic kernel.
    torch.use_deterministic_algorithms(True, warn_only=True)
        
def sampling_eval(ckpt_path: str, method: str | None = None) -> str:
    """Generate samples from a trained Diffusion checkpoint and score them with ``pde_loss``.

    The run directory is the grandparent of ``ckpt_path``
    (``<run>/<subdir>/<ckpt>``), and its first ``*.yml`` file (sorted by name)
    is the config. One sample batch is generated per batch of
    ``<config.dataset_path>/val.npy``; only each batch's shape is used.
    Results go to a fresh ``<run>/<method>_sample-<timestamp>`` directory.

    Args:
        ckpt_path: Checkpoint whose ``'model'`` entry is a state dict for ``Diffusion``.
        method: Sampling method passed to ``Diffusion.sampling``. Defaults to
            ``config.model.sampling.method``.

    Returns:
        Path of the pickle holding all samples concatenated. The file is named
        ``samples_all_<loss>.pkl``, where ``<loss>`` is ``pde_loss`` over all
        samples rounded to 5 decimals.

    Raises:
        FileNotFoundError: If the run directory contains no ``*.yml`` config.
        ValueError: If the validation set is empty.
    """
    run_dir = os.path.dirname(os.path.dirname(ckpt_path))
    # Sorted so the chosen config is deterministic when several are present.
    config_paths = sorted(glob(os.path.join(run_dir, '*.yml')))
    if not config_paths:
        raise FileNotFoundError(f'No *.yml config found in {run_dir!r}')
    with open(config_paths[0], 'r') as f:
        config = EasyDict(yaml.safe_load(f))
    seed_all(config.train.seed)
    device = torch.device(config.model.device)

    # map_location lets a checkpoint saved on another device (e.g. GPU) load onto the configured one.
    # NOTE(review): depending on the torch version, torch.load may unpickle arbitrary objects;
    # only load trusted checkpoints.
    ckpt = torch.load(ckpt_path, map_location=device)

    if method is None:
        method = config.model.sampling.method

    val_data = torch.from_numpy(
        np.load(os.path.join(config.dataset_path, 'val.npy'))
    ).to(torch.float32)
    if len(val_data) == 0:
        raise ValueError('Validation set is empty; nothing to sample.')
    test_loader = DataLoader(
        TensorDataset(val_data),
        config.train.batch_size,
        num_workers=0,
        shuffle=False,
    )

    output_dir = get_new_work_dir(run_dir, f'/{method}_sample')

    model = Diffusion(config.model).to(device)
    model.load_state_dict(ckpt['model'])
    # Disable train-mode behaviour (e.g. dropout) while sampling.
    model.eval()

    results = []
    batch_paths = []
    sum_sq, n_seen = 0.0, 0
    pbar = tqdm(test_loader, dynamic_ncols=True)
    for i, (x0, ) in enumerate(pbar):
        # NOTE(review): no torch.no_grad() in case Diffusion.sampling relies on autograd; confirm.
        batch = model.sampling(x0.shape, method).detach().cpu().numpy()
        results.append(batch)

        # Checkpoint each batch on its own so progress survives a crash without
        # re-pickling everything generated so far (quadratic I/O).
        batch_path = os.path.join(output_dir, 'samples_%d.pkl' % i)
        with open(batch_path, 'wb') as f:
            pickle.dump(batch, f)
        batch_paths.append(batch_path)

        # Running RMS over all samples so far. pde_loss is an RMS and every sample has the
        # same shape, so the squared per-batch losses are weighted by batch size.
        sum_sq += float(pde_loss(batch)) ** 2 * len(batch)
        n_seen += len(batch)
        running = np.round(np.sqrt(sum_sq / n_seen), decimals=5)
        pbar.set_description(f'Test loss: {running}')

    samples = np.concatenate(results)
    # Exact loss over the full sample set; it is also embedded in the output file name.
    test_loss = np.round(pde_loss(samples), decimals=5)

    # The per-batch checkpoints are superseded by the combined file below.
    for batch_path in batch_paths:
        os.remove(batch_path)

    samples_path = os.path.join(output_dir, f'samples_all_{test_loss}.pkl')
    with open(samples_path, 'wb') as f:
        pickle.dump(samples, f)

    return samples_path
        
if __name__ == '__main__':
    # Set the multiprocessing start method explicitly (presumably for CUDA-safe worker
    # processes); sampling_eval itself uses num_workers=0.
    torch.multiprocessing.set_start_method('spawn')

    model_path = 'logs/temp---2024_07_20__21_43_20/checkpoints/ckpt.pt'

    # Sampling method forwarded to Diffusion.sampling; None would use the config default.
    method = 'ode'
    samples_path = sampling_eval(model_path, method)
    print(samples_path)
