import math
from typing import Optional
from copy import deepcopy

import torch

from src.datamodules.components.generators.navier_stokes_2d import navier_stokes_2d, GaussianRF
from src.datamodules.components.ode_datamodule import ODEDataModule
from src.datamodules.components.pde_dataset import PDEDataset
from src.utils import utils, pylogger

# Module-level logger; used below to report which trajectory is being simulated.
log = pylogger.get_pylogger(__name__)


class NavierStokesDataset(PDEDataset):
    """Navier-Stokes 2D trajectories, one environment per parameter dict.

    Each entry of ``self.params_eq`` is expected to provide ``"f"`` (forcing
    amplitude passed to ``_get_f0``), ``"visc"`` (viscosity passed to the
    solver) and ``"ood"`` (flag copied into every sample). Trajectories are
    simulated lazily on first access and stored in ``self.cache``.
    NOTE(review): ``params_eq``, ``n_data_per_env``, ``t_horizon``, ``dt_eval``,
    ``dt_int``, ``n``, ``size``, ``cache``, ``cache_file`` and ``create_dataset``
    are assumed to be set up by ``PDEDataset`` — confirm there.
    """

    # Default environments: fixed forcing amplitude, viscosity swept from 8e-4 to 1.2e-3.
    default_params = [{"f": 0.1, "visc": 8e-4, "ood": False},
                      {"f": 0.1, "visc": 9e-4, "ood": False},
                      {"f": 0.1, "visc": 1.0e-3, "ood": False},
                      {"f": 0.1, "visc": 1.1e-3, "ood": False},
                      {"f": 0.1, "visc": 1.2e-3, "ood": False}]
    # When False, every global index gets its own initial condition, so the
    # environments do not share initial states.
    use_fixed_cond_index = False

    def __init__(self, *args, **kwargs):
        """Build the warm-up forcing and random-field sampler, then load or create the cache."""
        super().__init__(*args, **kwargs)
        # Forcing used only for the warm-up run that produces initial conditions.
        self.forcing_zero = self._get_f0(self.params_eq[0]['f'])
        self.sampler = GaussianRF(2, self.size, alpha=2.5, tau=7, device=kwargs.get('device', 'cpu'))
        self.load_dataset(self.cache_file, self.create_dataset)

    def _get_element(self, index):
        """Return the sample for global ``index``, simulating and caching it on first use.

        Args:
            index: Global sample index; its environment is ``index // n_data_per_env``.

        Returns:
            dict with keys ``state`` (numpy array, ``(t, h, w)`` when the solver
            output has a single channel), ``t`` (float tensor of evaluation
            times), ``env_idx``, ``param`` (the environment's parameter dict)
            and ``ood``.
        """
        env_index = int(index // self.n_data_per_env)
        if self.use_fixed_cond_index:
            cond_index = int(index % self.n_data_per_env)
        else:
            cond_index = index  # distinct initial condition for every global index

        data = self.cache.get(index)
        if data is not None:
            return data

        params = self.params_eq[env_index]
        log.info(f'Calculating index {cond_index} of env {env_index}')
        w0 = self._get_init_cond(cond_index)
        state, _ = navier_stokes_2d(w0, f=self._get_f0(params['f']), visc=params['visc'],
                                    T=self.t_horizon, delta_t=self.dt_int, record_steps=self.n)
        # Solver output layout is (h, w, t, nc); reorder to (nc, t, h, w) and keep n steps.
        state = state.permute(3, 2, 0, 1)[:, :self.n]
        # .cpu() is required before .numpy() when the solver ran on a non-CPU device.
        state = state.cpu().numpy().squeeze(0)  # (t, h, w) when nc == 1
        t = torch.arange(0, self.t_horizon, self.dt_eval).float()
        data = {'state': state, 't': t, 'env_idx': env_index, 'param': params, 'ood': params['ood']}
        self.cache[index] = data
        return data

    def _get_init_cond(self, index):
        """Return the initial condition for ``index``, generating and caching it on first use.

        A random-field sample from ``self.sampler`` is integrated for a warm-up
        period (T=30, visc=8e-4, forcing ``self.forcing_zero``). The last
        recorded snapshot becomes the initial condition. It is cached as a numpy
        array, so later calls return a CPU tensor built from that array.
        No explicit seeding is done here: fresh samples depend on the global RNG state.
        """
        key = f'init_cond_{index}'
        cached = self.cache.get(key)
        if cached is not None:
            return torch.from_numpy(cached)

        w0 = self.sampler.sample().squeeze()
        state, _ = navier_stokes_2d(w0, f=self.forcing_zero, visc=8e-4, T=30.0,
                                    delta_t=self.dt_int, record_steps=20)
        # Last recorded time step of channel 0 (layout is (h, w, t, nc)).
        init_cond = state[:, :, -1, 0]
        # .cpu() is required before .numpy() when the sampler lives on a non-CPU device.
        self.cache[key] = init_cond.cpu().numpy()
        return init_cond

    def _get_f0(self, init_gain):
        """Return ``init_gain * (sin(2*pi*(X+Y)) + cos(2*pi*(X+Y)))`` on a ``size x size`` grid.

        The grid covers [0, 1) in each direction. The endpoint 1 is dropped
        because it coincides with 0 for this 1-periodic forcing.
        """
        tt = torch.linspace(0, 1, self.size + 1)[:-1]
        # Explicit 'ij' equals the historical default and avoids the deprecation warning.
        X, Y = torch.meshgrid(tt, tt, indexing='ij')
        phase = 2 * math.pi * (X + Y)
        return init_gain * (torch.sin(phase) + torch.cos(phase))


class NavierStokesDataModule(ODEDataModule):
    """Data module that wires ``NavierStokesDataset`` into the train/val/test splits."""

    def setup(self, stage: Optional[str] = None):
        """Create the train, val and test datasets from their ``*_dataset_params`` dicts.

        ``stage`` is accepted for interface compatibility but not used: all
        three splits are built on every call, in train -> val -> test order.
        """
        for split_name in ("train", "val", "test"):
            split_kwargs = getattr(self, f"{split_name}_dataset_params")
            setattr(self, f"{split_name}_dataset", NavierStokesDataset(**split_kwargs))


if __name__ == '__main__':
    # Smoke test: build a small single-environment data module and draw one batch.
    print("Performing unit test")

    minibatch_size = 16
    size = 32
    method = "euler"

    dataset_train_params = {
        "n_data_per_env": 30, "t_horizon": 55, "dt_eval": 1, "dt_int": 1e-3, "method": method, "size": size,
        "group": "train",
        "cache_file": "navier_stokes_train.npy",
        "create_dataset": True,
        # One out-of-distribution environment; NavierStokesDataset.default_params
        # holds the in-distribution viscosity sweep.
        "params": [
            {"f": 0.1, "visc": 6e-4, 'ood': True},
        ],
    }

    # The test split must not reuse the train cache file. Otherwise it would load
    # (or overwrite) the exact trajectories generated for training.
    dataset_test_params = deepcopy(dataset_train_params)
    dataset_test_params["group"] = "test"
    dataset_test_params["cache_file"] = "navier_stokes_test.npy"

    # NOTE(review): NavierStokesDataModule.setup() also reads val_dataset_params,
    # which is not passed here; this relies on ODEDataModule supplying one — confirm.
    datamodule = NavierStokesDataModule(num_train_envs=1, num_test_envs=1,
                                        train_dataset_params=dataset_train_params,
                                        test_dataset_params=dataset_test_params,
                                        num_workers=minibatch_size,
                                        batch_size=minibatch_size)
    datamodule.setup()

    train_loader = datamodule.train_dataloader()
    prev, cur, target, *_unused = next(iter(train_loader))
    print(prev.shape, cur.shape, target.shape)
    print("Train Dataloader length: {}".format(len(train_loader)))

    