import os
import yaml
import time
import random
import pickle
import numpy as np

import torch

from torch.utils.data import DataLoader

from glob import glob
from tqdm import tqdm
from easydict import EasyDict

from torch import Tensor
from utils.SWDataset import SWDataset
from models.diffusion import Diffusion

# Public API of this module: only the evaluation entry point is exported.
__all__ = ['sampling_eval']

@torch.no_grad()
def sw_pde_loss(
        data: Tensor,
        dx: float, dy: float, dt: float,
        gdr: Tensor, coriolis: Tensor
    ) -> float:
    """Return a scaled RMS residual of linear shallow-water equations evaluated on ``data``.

    ``data`` is permuted with ``(2, 0, 1, 3, 4)`` and unpacked into three fields
    ``h, u, v``, so it must be 5-D with exactly 3 entries on axis 2, i.e.
    ``(batch, n_time, 3, ny, nx)`` with channels ordered (h, u, v).

    Three residuals are evaluated on interior grid points with a first-order
    forward difference in time (step n -> n+1). There are no advection terms,
    and continuity uses the per-sample constant ``depth``:

    * u-momentum: du/dt - (f * v_avg - g * dh/dx), right-hand side at step n;
    * v-momentum: dv/dt + f * u_avg + g * dh/dy, with u_avg at step n+1 and
      dh/dy at step n;
    * continuity: dh/dt + depth * (du/dx + dv/dy), divergence at step n+1.

    Runs under ``torch.no_grad`` and returns a Python float, so the result is
    not differentiable.

    Args:
        data: Tensor of shape ``(batch, n_time, 3, ny, nx)``.
        dx: Grid spacing along the last axis.
        dy: Grid spacing along the second-to-last axis.
        dt: Time step between consecutive entries on axis 1.
        gdr: 2-D tensor with at least two columns: column 0 is gravity,
            column 1 is depth (further columns are ignored here).
        coriolis: Coriolis parameter. It must broadcast against
            ``(batch, n_time - 1, ny - 2, nx - 2)``.
            NOTE(review): its actual shape is not visible here; confirm.

    Returns:
        ``1e5 * sqrt(mean(residual ** 2))`` over all three residuals.
    """
    # gravity.shape, depth.shape == (batch_size, )
    gravity, depth = gdr[:, :2].permute(1, 0)
    # Reshape to (batch, 1, 1, 1) so they broadcast over (time, y, x).
    gravity = gravity.reshape(-1, 1, 1, 1)
    depth = depth.reshape(-1, 1, 1, 1)


    # h.shape, u.shape, v.shape == (batch_size, n_time, img_size, img_size)
    h, u, v = data.permute(2, 0, 1, 3, 4)
    # NOTE(review): depth is constant over (time, y, x) for each sample, so this
    # offset cancels in every difference below and only affects rounding.
    h = h + depth

    # Four-point averages around each interior point (i, j): v from rows i-1..i,
    # cols j..j+1; u from rows i..i+1, cols j-1..j. The offsets suggest a
    # staggered (C-grid-like) layout. NOTE(review): confirm the grid convention.
    v_avg = 0.25 * (v[:, :, 1:-1, 1:-1] + v[:, :, :-2, 1:-1] + v[:, :, 1:-1, 2:] + v[:, :, :-2, 2:])
    u_avg = 0.25 * (u[:, :, 1:-1, 1:-1] + u[:, :, 1:-1, :-2] + u[:, :, 2:, 1:-1] + u[:, :, 2:, :-2])

    # Forward differences in time: shape (batch, n_time - 1, ny - 2, nx - 2).
    dudt = torch.diff(u, dim=1)[:, :, 1:-1, 1:-1] / dt
    dvdt = torch.diff(v, dim=1)[:, :, 1:-1, 1:-1] / dt
    dhdt = torch.diff(h, dim=1)[:, :, 1:-1, 1:-1] / dt

    # Spatial differences: forward for h, backward for u and v.
    dhdx = (h[:, :, 1:-1, 2:] - h[:, :, 1:-1, 1:-1]) / dx
    dhdy = (h[:, :, 2:, 1:-1] - h[:, :, 1:-1, 1:-1]) / dy
    dudx = (u[:, :, 1:-1, 1:-1] - u[:, :, 1:-1, :-2]) / dx
    dvdy = (v[:, :, 1:-1, 1:-1] - v[:, :, :-2, 1:-1]) / dy

    # [:, :-1] selects step n and [:, 1:] selects step n+1 to match d/dt above.
    loss1 = dudt - (coriolis * v_avg - gravity * dhdx)[:, :-1]
    loss2 = dvdt + coriolis * u_avg[:, 1:] + gravity * dhdy[:, :-1]
    loss3 = dhdt + depth * (dudx + dvdy)[:, 1:]
    loss = torch.stack([loss1, loss2, loss3], dim=1).square()
    # Scaled RMS over every residual entry.
    return 1e5 * loss.mean().sqrt().item()

def get_new_work_dir(root: str, config_name: str) -> str:
    fn = time.strftime('%Y_%m_%d__%H_%M_%S', time.localtime())
    if config_name.startswith('/'):
        log_dir = root + config_name + '-' + fn
    else:
        log_dir = root + '-'  + config_name + '-' + fn
    os.makedirs(log_dir)
    return log_dir
        
def seed_all(seed: int = 42) -> None:
    """Seed Python, NumPy and PyTorch RNGs and request deterministic kernels.

    Args:
        seed: Seed applied to ``random``, ``numpy.random`` and torch (CPU and CUDA).
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    # Autotuning can pick different (non-deterministic) kernels run to run.
    torch.backends.cudnn.benchmark = False
    # This must be a call. Assigning ``torch.use_deterministic_algorithms = True``
    # replaces the torch function with a bool and never enables determinism.
    # warn_only avoids hard RuntimeErrors for ops lacking a deterministic kernel.
    torch.use_deterministic_algorithms(True, warn_only=True)
        
def listtensor2numpy(x) -> np.ndarray:
    """Concatenate a sequence of tensors along dim 0 and return it as a NumPy array.

    The result is detached from autograd and copied to host memory first.
    """
    joined = torch.cat(x, dim=0)
    host = joined.detach().cpu()
    return host.numpy()

def sampling_eval(ckpt_path: str) -> str:
    """Sample from a trained diffusion checkpoint over the test set and save the results.

    The run directory is taken to be two levels above ``ckpt_path``
    (e.g. ``<run>/checkpoints/ckpt.pt``). The first ``*.yml`` file there, in
    sorted order, is used as the config. Samples go into a new
    ``<run>/sample-<timestamp>`` directory, together with ``gdrs.pkl``.

    Args:
        ckpt_path: Path to a checkpoint dict containing a ``'model'`` state dict.

    Returns:
        Path of the pickle holding all generated samples. Its filename embeds
        the final PDE loss.

    Raises:
        FileNotFoundError: if no ``*.yml`` config exists in the run directory.
        ValueError: if the test loader yields no batches.
    """
    run_dir = os.path.dirname(os.path.dirname(ckpt_path))
    # Sorted so the chosen config is deterministic when several are present.
    config_files = sorted(glob(os.path.join(run_dir, '*.yml')))
    if not config_files:
        raise FileNotFoundError(f'No *.yml config found in {run_dir!r}')
    with open(config_files[0], 'r') as f:
        config = EasyDict(yaml.safe_load(f))
    seed_all(config.train.seed)
    device = torch.device(config.model.device)

    # map_location lets a checkpoint saved on another device load on `device`.
    # NOTE(review): torch.load unpickles; only load trusted checkpoints.
    ckpt = torch.load(ckpt_path, map_location=device)

    output_dir = get_new_work_dir(run_dir, '/sample')
    # NOTE(review): pickle can execute arbitrary code; only use trusted test files.
    with open(config.test_path, 'rb') as f:
        test_data = pickle.load(f)

    test_loader = DataLoader(
        SWDataset(test_data),
        config.train.batch_size,
        num_workers=0,
        shuffle=False
    )

    model = Diffusion(config.model, **test_data).to(device)
    model.load_state_dict(ckpt['model'])
    # Evaluation: disable train-only behaviour (e.g. dropout) during sampling.
    model.eval()

    gdrs, results = [], []
    test_loss = None
    pbar = tqdm(test_loader, dynamic_ncols=True)
    for i, (gdr, func) in enumerate(pbar):
        gdr = gdr.to(device)
        gdrs.append(gdr)
        results.append(model.sampling(func.shape, gdr))

        # Snapshot of all samples so far (presumably for crash recovery).
        # These temporary files are deleted once the loop completes.
        with open(os.path.join(output_dir, 'samples_%d.pkl' % i), 'wb') as f:
            pickle.dump(listtensor2numpy(results), f)

        # Running PDE residual over every sample generated so far.
        test_loss = np.round(
            sw_pde_loss(
                torch.concatenate(results, dim=0),
                model.dx, model.dy, model.dt,
                torch.concatenate(gdrs, dim=0), model.coriolis
            ),
            decimals=5
        )
        pbar.set_description(f'Test loss: {test_loss}')

    if test_loss is None:
        raise ValueError(f'Test set {config.test_path!r} produced no batches')

    for temp_file in os.listdir(output_dir):
        os.remove(os.path.join(output_dir, temp_file))

    samples_path = os.path.join(output_dir, f'samples_all_pde_{test_loss}.pkl')
    with open(samples_path, 'wb') as f:
        pickle.dump(listtensor2numpy(results), f)

    with open(os.path.join(output_dir, 'gdrs.pkl'), 'wb') as f:
        pickle.dump(listtensor2numpy(gdrs), f)

    return samples_path
        
if __name__ == '__main__':
    # NOTE(review): 'spawn' is presumably chosen so subprocesses can use CUDA
    # safely. sampling_eval's DataLoader uses num_workers=0; confirm it is still needed.
    torch.multiprocessing.set_start_method('spawn')

    checkpoint = 'logs/pde_1.0---hidden_size-16---2024_08_02__20_16_09/checkpoints/ckpt.pt'
    sampling_eval(checkpoint)
