"""
Copyright 2025 [name of copyright owner]

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from abc import ABC, abstractmethod
import os
import pandas as pd
import torch
from torch import nn, Tensor
from torch.nn.utils.rnn import pack_sequence
import torchsde
from tqdm import tqdm



# Module-wide compute device: the GPU when CUDA is available, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')



def generate_dataset(path: str, dataset_params):
    """
    Build the dataset described by ``dataset_params`` and generate it.

    Parameters
    ----------
    path: str
        Path handed to the dataset's ``generate`` method.
    dataset_params
        Mapping forwarded to ``load_dataset``; it is read with
        ``dataset_params['name']`` and expanded as keyword arguments there.
    """
    load_dataset(dataset_params).generate(path)



def load_dataset(dataset_params):
    """
    Instantiate the dataset class selected by ``dataset_params['name']``.

    Parameters
    ----------
    dataset_params
        Mapping with a ``'name'`` entry; the whole mapping (including
        ``'name'``) is passed to the selected class as keyword arguments.

    Returns
    -------
    An instance of ``Lorenz63``, ``Exchange``, ``Weather`` or ``Solar``.

    Raises
    ------
    NotImplementedError
        If the name is not one of the supported datasets.
    """
    requested = dataset_params['name']

    # Reject unknown names up front so the lookup below cannot miss.
    if requested not in ('Lorenz63', 'Exchange', 'Weather', 'Solar'):
        raise NotImplementedError("The specified system is not supported.")

    constructors = {
        'Lorenz63': Lorenz63,
        'Exchange': Exchange,
        'Weather': Weather,
        'Solar': Solar,
    }
    return constructors[requested](**dataset_params)



class Dataset(ABC):
    """
    Abstract base class of the datasets in this module.

    Subclasses are expected to set ``self._dim`` and ``self._name`` (the
    properties below read them) and to implement ``generate``.
    """

    @property
    def name(self):
        """Value stored in ``self._name``."""
        return self._name

    @property
    def dim(self):
        """Value stored in ``self._dim``."""
        return self._dim

    @abstractmethod
    def generate(self, path: str):
        """
        Generate a dataset.

        Parameters
        ----------
        path: str
            Path to save the dataset.
        """



class Lorenz63(Dataset):
    """
    Synthetic dataset of stochastic Lorenz-63 trajectories.

    Trajectories are simulated with ``torchsde.sdeint`` (Euler scheme) from the
    fixed initial state (1, 0, 0) over the time interval [0, tf], optionally
    perturbed by Gaussian observation noise and normalized, then split into
    train/val/test sets and saved as packed sequences.

    Parameters
    ----------
    name: str
        Name of the dataset.
    n_scenario: int
        Total number of trajectories to simulate.
    ti: float
        Samples at times earlier than ``ti`` are discarded before saving.
    tf: float
        Final time of the simulation.
    dt: float
        Spacing of the sample times. The solver's own step size is not set
        here and is left at the torchsde default.
    sigma, rho, beta: float
        Coefficients of the Lorenz-63 drift.
    noise: float
        Diffusion coefficient applied to every state component.
    obs_noise: float
        Scale of the Gaussian noise added to the simulated samples.
    val_ratio, test_ratio: float
        Fractions of the scenarios assigned to the validation and test sets.
    normalize: bool
        If True, subtract (0, 0, 24) from the samples and divide by 8.
    batch_size: int
        Number of trajectories simulated per call to ``solve``.
    """

    def __init__(self, name: str, n_scenario: int, ti: float, tf: float, dt: float,
                 sigma: float=10, rho: float=28, beta: float=8/3, noise: float=1.0, obs_noise: float=0.0,
                 val_ratio: float=0.1, test_ratio: float=0.2, normalize: bool=True, batch_size: int=128) -> None:
        self._name = name
        self._dim = 3
        self.n_scenario = n_scenario
        self.ti = ti
        self.tf = tf
        self.dt = dt
        self.sigma = sigma
        self.rho = rho
        self.beta = beta
        self.noise = noise
        self.obs_noise = obs_noise
        self.val_ratio = val_ratio
        self.test_ratio = test_ratio
        self.normalize = normalize
        self.batch_size = batch_size


    class SDE(nn.Module):
        """Ito SDE with Lorenz-63 drift and constant diagonal diffusion."""

        def __init__(self, sigma, rho, beta, noise):
            super().__init__()
            self.noise_type = 'diagonal'
            self.sde_type = 'ito'
            self.sigma = sigma
            self.rho = rho
            self.beta = beta
            self.noise = noise


        def f(self, t, x):
            """Return the Lorenz-63 drift at state ``x`` (``t`` is unused)."""
            # Allocate on the device/dtype of the state itself instead of the
            # module-level device, so the SDE works wherever ``x`` lives.
            drift = torch.zeros_like(x)

            drift[...,0] = self.sigma * (x[...,1] - x[...,0])
            drift[...,1] = x[...,0] * (self.rho - x[...,2]) - x[...,1]
            drift[...,2] = x[...,0] * x[...,1] - self.beta * x[...,2]
            return drift


        def g(self, t, x):
            """Return the constant diffusion ``noise`` for every state component."""
            return torch.full_like(x, self.noise)


    def solve(self, batch_size: int):
        """
        Simulate one batch of trajectories.

        Parameters
        ----------
        batch_size: int
            Number of trajectories in the batch.

        Returns
        -------
        xs: Tensor
            Samples with the batch on the first axis and time on the second
            (the solver output is transposed from time-major order).
        ts: Tensor
            Sample times ``arange(0, tf + 1e-10, dt)``; its length always
            equals the number of time steps in ``xs``.
        """
        TMAX = 18   # Maximum integrable time of sdeint.
        sde = self.SDE(self.sigma, self.rho, self.beta, self.noise)
        x0 = Tensor([[1, 0, 0]]*batch_size).to(device)

        ts = torch.arange(0, self.tf + 1e-10, self.dt, device=device)

        # Integral interval is divided into multiple intervals to avoid RecursionError.
        # The single time grid is chunked by index (not by absolute time), so
        # the chunks always tile ``ts`` exactly, even when ``dt`` does not
        # divide TMAX, and no chunk ever contains fewer than two time points.
        steps_per_chunk = max(1, int(TMAX / self.dt))
        chunks = [x0.unsqueeze(0)]
        for start in range(0, len(ts) - 1, steps_per_chunk):
            ts_chunk = ts[start:start + steps_per_chunk + 1]
            # Restart from the last state reached so far; the first output of
            # sdeint repeats that state and is therefore dropped.
            xs_chunk = torchsde.sdeint(sde, chunks[-1][-1], ts_chunk, method='euler')
            chunks.append(xs_chunk[1:])
        xs = torch.cat(chunks, dim=0).transpose(0, 1)

        xs = xs + torch.randn_like(xs) * self.obs_noise

        if self.normalize:
            x_base = (0.0, 0.0, 24.0)
            x_scale = (8.0, 8.0, 8.0)
            xs = (xs - torch.Tensor(x_base).to(device)) / torch.Tensor(x_scale).to(device)

        return xs, ts


    def generate(self, path: str):
        """
        Simulate ``n_scenario`` trajectories and save the train/val/test splits.

        Writes ``{train,val,test}_x.pt`` and ``{train,val,test}_t.pt`` under
        ``path``, each holding a packed sequence.

        Parameters
        ----------
        path: str
            Directory in which the files are saved.
        """
        # Full batches followed by one final batch holding the remainder.
        n_batches = (self.n_scenario - 1) // self.batch_size + 1
        batch_sizes = [self.batch_size] * (n_batches - 1) + [self.n_scenario - self.batch_size*(n_batches-1)]

        xs = []
        # The context manager closes the progress bar even if a batch fails.
        with tqdm(total=n_batches, desc='Generating trajectories') as pbar:
            for batch_size in batch_sizes:
                xs_, ts = self.solve(batch_size)
                xs.append(xs_)
                pbar.update(1)
        xs = torch.cat(xs, dim=0)

        # Drop the samples that precede ``ti``.
        idx_cut = torch.where(ts >= self.ti)[0][0]
        xs = xs[:,idx_cut:]
        ts = ts[idx_cut:]

        n_train = int(self.n_scenario * (1 - self.val_ratio - self.test_ratio))
        n_val = int(self.n_scenario * self.val_ratio)

        splits = (
            ('train', xs[:n_train]),
            ('val', xs[n_train:n_train+n_val]),
            ('test', xs[n_train+n_val:]),
        )

        # Pack every split before writing anything, so a failure while packing
        # leaves no partially written dataset behind.
        packed = []
        for split, xs_split in splits:
            xs_packed = pack_sequence(list(xs_split))
            # Every trajectory shares the same time grid.
            ts_packed = pack_sequence([ts[:,None]]*xs_packed.batch_sizes[0])
            packed.append((split, xs_packed, ts_packed))

        for split, xs_packed, ts_packed in packed:
            torch.save(xs_packed, os.path.join(path, split + '_x.pt'))
            torch.save(ts_packed, os.path.join(path, split + '_t.pt'))



class Exchange(Dataset):
    """
    Dataset built from ``data/exchange_rate.txt``.

    Parameters
    ----------
    name: str
        Name of the dataset.
    val_ratio, test_ratio: float
        Fractions of the time points assigned to the validation and test sets.
    covert_real: bool
        If True, the natural logarithm of the values is taken.
        (The spelling of this keyword is part of the public interface.)
    normalize: bool
        If True, each feature is standardized with its own mean and standard
        deviation.
    """

    def __init__(self, name: str, val_ratio: float=1/3, test_ratio: float=1/3, covert_real: bool=True, normalize: bool=True) -> None:
        self._name, self._dim = name, 8
        self.val_ratio, self.test_ratio = val_ratio, test_ratio
        self.covert_real = covert_real
        self.normalize = normalize


    def generate(self, path: str):
        """
        Read the raw file, transform it, and save random train/val/test splits.

        The time indices are randomly permuted and partitioned; each subset is
        sorted back into increasing time order before it is packed and saved
        as ``{train,val,test}_x.pt`` / ``{train,val,test}_t.pt`` under ``path``.

        Parameters
        ----------
        path: str
            Directory in which the files are saved.
        """
        data_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'data')
        # NOTE(review): read_csv uses its default header handling here, so the
        # first line of the file is consumed as column names - confirm the
        # file really starts with a header row.
        frame = pd.read_csv(os.path.join(data_dir, 'exchange_rate.txt'))
        xs = torch.Tensor(frame.to_numpy()).unsqueeze(0)
        n_times = len(xs[0])
        ts = torch.arange(n_times, dtype=torch.float32)/30.0 # unit: 30 days

        if self.covert_real:
            xs = torch.log(xs)
        if self.normalize:
            mean = xs.mean(dim=(0, 1))
            std = xs.std(dim=(0, 1))
            xs = (xs - mean) / std

        indices = torch.randperm(n_times)

        n_train = int(n_times * (1 - self.val_ratio - self.test_ratio))
        n_val = int(n_times * self.val_ratio)
        split_ids = (
            indices[:n_train],
            indices[n_train:n_train+n_val],
            indices[n_train+n_val:],
        )

        packed = []
        for ids in split_ids:
            ids = torch.sort(ids).values
            xs_packed = pack_sequence(list(xs[:,ids]))
            # One copy of the selected times per packed sequence.
            ts_packed = pack_sequence([ts[ids][:,None]]*xs_packed.batch_sizes[0])
            packed.append((xs_packed, ts_packed))

        file_names = (
            ('train_x.pt', 'train_t.pt'),
            ('val_x.pt', 'val_t.pt'),
            ('test_x.pt', 'test_t.pt'),
        )
        for (x_file, t_file), (xs_packed, ts_packed) in zip(file_names, packed):
            torch.save(xs_packed, os.path.join(path, x_file))
            torch.save(ts_packed, os.path.join(path, t_file))



class Weather(Dataset):
    """
    Dataset built from ``data/mpi_roof_2023.csv``.

    Parameters
    ----------
    name: str
        Name of the dataset.
    val_ratio, test_ratio: float
        Fractions of the time points assigned to the validation and test sets.
    normalize: bool
        If True, each feature is standardized with its own mean and standard
        deviation.
    """

    def __init__(self, name: str, val_ratio: float=1/3, test_ratio: float=1/3, normalize: bool=True) -> None:
        self._name, self._dim = name, 21
        self.val_ratio, self.test_ratio = val_ratio, test_ratio
        self.normalize = normalize


    def generate(self, path: str):
        """
        Read the raw file, transform it, and save random train/val/test splits.

        Times come from the ``'Date Time'`` column and are measured relative
        to the first row; all remaining columns are used as features. The time
        indices are randomly permuted and partitioned; each subset is sorted
        back into increasing time order before it is packed and saved as
        ``{train,val,test}_x.pt`` / ``{train,val,test}_t.pt`` under ``path``.

        Parameters
        ----------
        path: str
            Directory in which the files are saved.
        """
        data_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'data')
        frame = pd.read_csv(os.path.join(data_dir, 'mpi_roof_2023.csv'), parse_dates=[0], dayfirst=True)

        stamps = pd.to_datetime(frame['Date Time'].to_numpy()).astype('int64')
        # NOTE(review): the scaling below assumes the int64 values are
        # nanoseconds since the epoch - confirm for the pandas version in use.
        ts = stamps / 10**9 / 86400 * 6 # unit: 4 hours
        ts = torch.Tensor(ts - ts[0])

        features = frame.drop(columns=['Date Time'])
        xs = torch.Tensor(features.to_numpy()).unsqueeze(0)

        if self.normalize:
            mean = xs.mean(dim=(0, 1))
            std = xs.std(dim=(0, 1))
            xs = (xs - mean) / std

        n_times = len(ts)
        indices = torch.randperm(n_times)

        n_train = int(n_times * (1 - self.val_ratio - self.test_ratio))
        n_val = int(n_times * self.val_ratio)
        split_ids = (
            indices[:n_train],
            indices[n_train:n_train+n_val],
            indices[n_train+n_val:],
        )

        packed = []
        for ids in split_ids:
            ids = torch.sort(ids).values
            xs_packed = pack_sequence(list(xs[:,ids]))
            # One copy of the selected times per packed sequence.
            ts_packed = pack_sequence([ts[ids][:,None]]*xs_packed.batch_sizes[0])
            packed.append((xs_packed, ts_packed))

        file_names = (
            ('train_x.pt', 'train_t.pt'),
            ('val_x.pt', 'val_t.pt'),
            ('test_x.pt', 'test_t.pt'),
        )
        for (x_file, t_file), (xs_packed, ts_packed) in zip(file_names, packed):
            torch.save(xs_packed, os.path.join(path, x_file))
            torch.save(ts_packed, os.path.join(path, t_file))



class Solar(Dataset):
    """
    Dataset built from ``data/solar_AL.txt``.

    Parameters
    ----------
    name: str
        Name of the dataset.
    val_ratio, test_ratio: float
        Fractions of the time points assigned to the validation and test sets.
    convert_real: bool
        If True, zeros in each feature are replaced by small random positive
        values and ``log(exp(x) - 1)`` is then applied to the feature.
    normalize: bool
        If True, each feature is standardized with its own mean and standard
        deviation.
    """

    def __init__(self, name: str, val_ratio: float=1/3, test_ratio: float=1/3, convert_real: bool=True, normalize: bool=True) -> None:
        self._name, self._dim = name, 137
        self.val_ratio, self.test_ratio = val_ratio, test_ratio
        self.convert_real = convert_real
        self.normalize = normalize


    def generate(self, path: str):
        """
        Read the raw file, transform it, and save random train/val/test splits.

        The time indices are randomly permuted and partitioned; each subset is
        sorted back into increasing time order before it is packed and saved
        as ``{train,val,test}_x.pt`` / ``{train,val,test}_t.pt`` under ``path``.

        Parameters
        ----------
        path: str
            Directory in which the files are saved.
        """
        data_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'data')
        # NOTE(review): read_csv uses its default header handling here, so the
        # first line of the file is consumed as column names - confirm the
        # file really starts with a header row.
        frame = pd.read_csv(os.path.join(data_dir, 'solar_AL.txt'))
        xs = torch.Tensor(frame.to_numpy()).unsqueeze(0)
        n_times = len(xs[0])
        ts = torch.arange(n_times, dtype=torch.float32) / 24.0 # unit: 4 hours

        if self.convert_real:
            beta_dist = torch.distributions.Beta(2, 4)
            for i in range(xs.shape[2]):
                column = xs[:,:,i]   # view: writes below land in ``xs``
                is_zero = column == 0.0
                n_zero = is_zero.sum()
                # Replace exact zeros by Beta(2, 4) samples scaled by the
                # smallest non-zero value of the feature.
                column[is_zero] = beta_dist.sample((n_zero,)) * column[~is_zero].min()
                # Evaluate log(exp(x) - 1) in float64, then store as float32.
                xs[:,:,i] = torch.log(torch.exp(column.to(torch.float64)) - 1).to(torch.float32)
        if self.normalize:
            mean = xs.mean(dim=(0, 1))
            std = xs.std(dim=(0, 1))
            xs = (xs - mean) / std

        indices = torch.randperm(n_times)

        n_train = int(n_times * (1 - self.val_ratio - self.test_ratio))
        n_val = int(n_times * self.val_ratio)
        split_ids = (
            indices[:n_train],
            indices[n_train:n_train+n_val],
            indices[n_train+n_val:],
        )

        packed = []
        for ids in split_ids:
            ids = torch.sort(ids).values
            xs_packed = pack_sequence(list(xs[:,ids]))
            # One copy of the selected times per packed sequence.
            ts_packed = pack_sequence([ts[ids][:,None]]*xs_packed.batch_sizes[0])
            packed.append((xs_packed, ts_packed))

        file_names = (
            ('train_x.pt', 'train_t.pt'),
            ('val_x.pt', 'val_t.pt'),
            ('test_x.pt', 'test_t.pt'),
        )
        for (x_file, t_file), (xs_packed, ts_packed) in zip(file_names, packed):
            torch.save(xs_packed, os.path.join(path, x_file))
            torch.save(ts_packed, os.path.join(path, t_file))
