"""
ARIMA Synthetic Dataset  
- Generate data from an ARIMA(p, d, q) process  
- Internally, generates data using an ARMA(p + d, q) process with d unit roots.  
"""
from bisect import bisect
from types import SimpleNamespace
from pytz import all_timezones_set
import torch
import numpy as np
import pandas as pd
from torch.utils.data import Dataset
from functools import partial 

import matplotlib.pyplot as plt

from dataloaders.datasets import SequenceDataset
from dataloaders.forecasting import TimeSeriesHelper, TimeSeriesDataset

try:
    import statsmodels

    # Exponential Smoothing Methods
    from statsmodels.tsa.statespace.exponential_smoothing import ExponentialSmoothing
    from statsmodels.tsa.holtwinters import ExponentialSmoothing as ExponentialSmoothingHoltWinters
    # from statsmodels.tsa.arima_process import arma_generate_sample

except ImportError:
    pass

from scipy import signal

# Names of calendar time units, listed from finest to coarsest.
# NOTE(review): not referenced anywhere else in the visible part of this file.
TIME_FEATURES = {
    'second', 'minute', 'hour', 'day',
    'week', 'month', 'quarter', 'year',
}

def arma_generate_sample(ar, ma, nsample, scale=1, distrvs=None, axis=0, burnin=0,
                         initial_x=0):
    """
    Custom function to simulate data from ARMA, starting with an initial value (default = 0).

    Noise is drawn with `distrvs`, multiplied by `scale`, and its first entry
    along `axis` is overwritten with `initial_x`. The result is then passed
    through `scipy.signal.lfilter(ma, ar, ...)` along `axis`.

    Args:
        ar: Denominator (AR) coefficients handed to `lfilter`.
        ma: Numerator (MA) coefficients handed to `lfilter`.
        nsample: Number of samples, or a sequence giving the output shape.
        scale: Multiplier applied to the drawn noise.
        distrvs: Callable accepting a `size` keyword and returning the noise
            array. Defaults to `np.random.standard_normal`.
        axis: Axis along which the filter runs (the time axis).
        burnin: Number of leading samples along `axis` that are generated and
            then discarded.
        initial_x: Value written into the first noise entry along `axis`.
            When `burnin > 0` that entry lies inside the discarded prefix.

    Returns:
        The filtered samples, with the burn-in prefix (if any) removed.
    """
    distrvs = np.random.standard_normal if distrvs is None else distrvs
    if np.ndim(nsample) == 0:
        nsample = [nsample]

    # Shape of the noise to draw, and the slice of the filtered output to keep.
    newsize = list(nsample)
    keep = [slice(None)] * len(newsize)
    if burnin:
        # handle burn-in time for nd arrays: draw extra samples along `axis`
        # and drop them again after filtering
        newsize[axis] += burnin
        keep[axis] = slice(burnin, None, None)

    eta = scale * distrvs(size=tuple(newsize))

    # Pin the first time step. Index along `axis` (not blindly along axis 0)
    # so that n-d inputs with a non-zero time axis get the initial value at
    # the start of every series rather than across one whole series.
    first = [slice(None)] * len(newsize)
    first[axis] = 0
    eta[tuple(first)] = initial_x

    return signal.lfilter(ma, ar, eta, axis=axis)[tuple(keep)]


# Base Class
class ForecastingSynthetic(Dataset):
    """
    Abstract dataset holding a collection of synthetic time series.

    Subclasses provide two hooks, both invoked from the constructor after
    NumPy's global RNG has been seeded with `seed`:
      - `_setup_process`: configure the generating process
      - `generate`: return the generated series, stored on `self.ts`
    """
    def __init__(self, n_ts, nobs_per_ts, seed):
        super().__init__()

        self.n_ts, self.nobs_per_ts, self.seed = n_ts, nobs_per_ts, seed

        # Seed before running the hooks so both see a reproducible RNG state.
        np.random.seed(self.seed)
        self._setup_process()
        self.ts = self.generate()

    def _setup_process(self, *args, **kwargs):
        # Hook: must be overridden by subclasses.
        raise NotImplementedError

    def generate(self, *args, **kwargs):
        # Hook: must be overridden by subclasses.
        raise NotImplementedError

    def __len__(self):
        # One item per time series.
        return self.n_ts

    def __getitem__(self, idx):
        series = self.ts[idx]
        return series
    
    
    
class ARIMASynthetic(ForecastingSynthetic):
    """
    Generate data from an ARIMA(p, d, q) process.
    Internally, generates data using an ARMA(p + d, q) process with d unit roots.

    Args:
        p: Number of non-unit roots of the AR polynomial.
        d: Number of unit roots appended to the AR polynomial.
        q: Number of roots of the MA polynomial.
        n_ts: Number of series to generate.
        nobs_per_ts: Number of observations per series.
        c: Constant offset term; used as the mean of the driving noise.
        initial_x: Value written into the first noise sample of every series
            (forwarded to `arma_generate_sample`).
        scale: Standard deviation of the driving noise.
        seed: Seed for NumPy's global RNG.
        verbose: If True, also print the sampled AR and MA roots.
    """

    def __init__(
        self, 
        p, 
        d,
        q, 
        n_ts, 
        nobs_per_ts, 
        c=0,
        initial_x=0,
        scale=1.,
        seed=42,
        verbose=False
    ):
        # These attributes must exist before the base constructor runs,
        # because it calls `_setup_process` and `generate` straight away.
        self.p = p
        self.d = d
        self.q = q
        self.c = c # constant offset term in the ARIMA equation 
        self.initial_x = initial_x  # first value
        self.scale = scale
        self.verbose = verbose

        super().__init__(n_ts, nobs_per_ts, seed=seed)
    
    @staticmethod
    def _sample_complex_unit_circle(n):
        """Return `n` complex numbers sampled inside the unit disk."""
        # sqrt of a uniform radius spreads the points evenly over the disk area
        r = np.sqrt(np.random.rand(n))
        theta = np.random.rand(n) * 2 * np.pi
        return r * np.cos(theta) + 1j * r * np.sin(theta)

    def _setup_process(self):
        """
        Sample AR / MA roots and store the matching polynomial coefficients
        on `self.ar_params` and `self.ma_params`.
        """
        np.random.seed(self.seed)

        # Construct complex-conjugate roots inside the unit circle for the ARIMA process
        # Both the AR / MA characteristic polynomials should satisfy this
        ar_roots = self._sample_complex_unit_circle(self.p // 2)
        ma_roots = self._sample_complex_unit_circle(self.q // 2)

        # Add unit roots as ARIMA(p, d, q) = ARMA(p + d, q)
        unit_roots = [1.0] * self.d

        print("Constructing ARIMA(%d, %d, %d) process..." % 
              (self.p, self.d, self.q))
        print("Unit roots (multiplicity):", self.d)
        if self.verbose:
            # Root dumps are only shown on request; `verbose` gates both of them.
            print("AR roots:", ar_roots)
            print("MA roots:", ma_roots)

        if self.p % 2 == 0:
            # Just keep the complex roots and add in the unit roots
            ar_roots = np.r_[ar_roots, ar_roots.conj(), unit_roots]
        else:
            # Add a real root to the p - 1 complex roots, as well as the unit roots
            ar_roots = np.r_[ar_roots, ar_roots.conj(), 2 * np.random.rand(1) - 1, unit_roots]

        if self.q % 2 == 0:
            ma_roots = np.r_[ma_roots, ma_roots.conj()]
        else:
            # Add a real root to the q - 1 complex roots
            ma_roots = np.r_[ma_roots, ma_roots.conj(), np.random.rand(1)]

        # Construct the polynomial coefficients from the roots
        # Coefficients of c[0] * z^n + c[1] * z^(n-1) + ... + c[n]
        # with c[0] always equal to 1.
        ar_coeffs = np.poly(ar_roots) 
        ma_coeffs = np.poly(ma_roots)

        self.ar_params = np.r_[ar_coeffs]
        self.ma_params = np.r_[ma_coeffs]

    def generate(self):
        """
        Simulate `n_ts` series of length `nobs_per_ts` from the stored
        AR / MA coefficients and return them stacked in one array.
        """
        # Add constant offset as mean shift of the noise.
        # The helper's own `scale` stays at 1 so the mean shift is not
        # multiplied by the noise scale.
        noise = partial(np.random.normal, loc=self.c, scale=self.scale)

        ts = [
            arma_generate_sample(
                self.ar_params,
                self.ma_params,
                self.nobs_per_ts,
                scale=1,
                burnin=0,
                distrvs=noise,
                initial_x=self.initial_x,
            )
            for _ in range(self.n_ts)
        ]
        return np.array(ts)
    
    
class ARIMASyntheticDataset(SequenceDataset):
    """
    Sequence dataset wrapping series simulated by `ARIMASynthetic`, with
    optional additive seasonal components and train / val / test splits.
    """
    _name_ = 'synthetic-arima'

    @property
    def d_input(self):
        return 1

    @property
    def d_output(self):
        return 1
    
    @property
    def l_output(self):
        return self.horizon

    @property
    def L(self):
        return self.horizon + self.lag

    # Default configuration values.
    # NOTE(review): presumably turned into instance attributes by
    # `SequenceDataset` -- confirm against the base class.
    init_defaults = {
        'p': 1,
        'd': 0,
        'q': 1,
        'n_ts': 1,
        'nobs_per_ts': 100,
        'horizon': 1,
        'lag': 1,
        'val_gap': 0,
        'test_gap': 0,
        'seed': 42,
        'c': 0,
        'initial_x': 0,
        'scale': 1.,
        'seasonal': None,
    }

    def _process_seasonality(self, ts, ts_times):
        """
        Add one ARIMA-generated seasonal component per entry of `self.seasonal`.

        `self.seasonal` maps a frequency to a dict with keys 'p', 'd', 'q',
        'seed', 'c' and 'scale'. For each frequency a seasonal series is
        generated along with timestamps at that frequency. Its values are
        added in place to the entries of `ts` whose timestamp also occurs
        among the seasonal timestamps; other entries are left untouched.

        Returns:
            The (mutated) `ts` and the unchanged `ts_times`.
        """
        freqs = self.seasonal.keys()
        # NOTE(review): `FREQS` is not defined or imported in the visible part
        # of this file, so this check presumably raises NameError as written --
        # confirm where it is meant to come from.
        assert all(freq in FREQS for freq in freqs), "Invalid frequency"

        for freq in freqs:
            cfg = self.seasonal[freq]
            seasonal_process = ARIMASynthetic(
                p=cfg['p'],
                d=cfg['d'],
                q=cfg['q'],
                n_ts=self.n_ts,
                nobs_per_ts=self.nobs_per_ts,
                seed=cfg['seed'],
                c=cfg['c'],  # per-frequency offset (previously self.c)
                scale=cfg['scale'],
            )
            seasonal_ts = seasonal_process.generate()
            seasonal_times = [TimeSeriesHelper.generate_timestamps(self.nobs_per_ts, freq) for _ in range(self.n_ts)]

            # Add values from the seasonal_ts to ts using the timestamps
            for i in range(self.n_ts):
                for j, timestamp in enumerate(ts_times[i]):
                    matches = np.where(seasonal_times[i] == timestamp)[0]
                    if len(matches) == 0:
                        # No seasonal observation at this timestamp: leave the value as is.
                        continue
                    ts[i][j] += seasonal_ts[i][matches[0]]

        return ts, ts_times

    def setup(self, val_ratio=0.2, test_ratio=0.2):
        """
        Generate the series, split them and build the three datasets.

        Args:
            val_ratio: Fraction of `nobs_per_ts` used to size the validation split.
            test_ratio: Fraction of `nobs_per_ts` used to size the test split.

        The test split is taken from the full series first; the validation
        split is then taken from the remaining training portion. Validation
        and test datasets reuse the training dataset's `standardization`.
        """
        # Generate synthetic data from ARIMA(p, d, q)
        process = ARIMASynthetic(
            p=self.p, 
            d=self.d,
            q=self.q, 
            n_ts=self.n_ts, 
            nobs_per_ts=self.nobs_per_ts, 
            seed=self.seed, 
            initial_x=self.initial_x,
            c=self.c,
            scale=self.scale,
        )
        ts = process.generate()
        # Daily timestamps; the same timestamp object is shared by every series.
        ts_times = [TimeSeriesHelper.generate_timestamps(self.nobs_per_ts, freq='D')] * self.n_ts

        # Add seasonal component
        if self.seasonal is not None:
            ts, ts_times = self._process_seasonality(ts, ts_times)

        self.ts = ts

        n_val = int(np.round(self.nobs_per_ts * val_ratio))
        n_test = int(np.round(self.nobs_per_ts * test_ratio))

        self.splits = TimeSeriesHelper.train_test_split_all(
            ts, self.lag, self.horizon, n_test=n_test, ts_times=ts_times, gap=self.test_gap)
        self.splits_val = TimeSeriesHelper.train_test_split_all(
            self.splits.train_ts, self.lag, self.horizon, n_test=n_val, ts_times=self.splits.train_ts_times, gap=self.val_gap)

        # Wrap the time series and their timestamps in a dataset
        self.dataset_train = TimeSeriesDataset(self.splits_val.train_ts, self.splits_val.train_ts_times, self.lag, self.horizon)
        self.dataset_val = TimeSeriesDataset(self.splits_val.test_ts, self.splits_val.test_ts_times, self.lag, self.horizon, standardization=self.dataset_train.standardization)
        self.dataset_test = TimeSeriesDataset(self.splits.test_ts, self.splits.test_ts_times, self.lag, self.horizon, standardization=self.dataset_train.standardization)
        
        # Keep the process around so its ar_params / ma_params can be inspected
        self.process = process

    @staticmethod
    def collate_fn(batch, resolution, **kwargs):
        """
        Stack a batch of `(x, y, time, id)` items and subsample along dim 1.

        `x`, `y` and every tensor in the `time` dicts are stacked along a new
        leading dimension and then strided by `resolution` along dimension 1.
        """
        x, y, *z = zip(*batch)
        x = torch.stack(x, dim=0)[:, ::resolution]
        y = torch.stack(y, dim=0)[:, ::resolution]
        time, ids = z
        time = {k: torch.stack([e[k] for e in time], dim=0)[:, ::resolution] for k in time[0].keys()}
        ids = torch.tensor(ids)
        return x, y, time, ids

    def train_dataloader(self, **kwargs):
        return super().train_dataloader(**kwargs)

    def val_dataloader(self, **kwargs):
        # Shuffle the val dataloader so we get random forecast horizons!
        kwargs['shuffle'] = True
        kwargs['drop_last'] = False
        return super().val_dataloader(**kwargs)

    def test_dataloader(self, **kwargs):
        kwargs['drop_last'] = False
        return super().test_dataloader(**kwargs)
    
    

def load_arima_data(config_dataset, config_loader, val_ratio, test_ratio):
    """
    Build the synthetic ARIMA dataset and its train / val / test dataloaders.

    Args:
        config_dataset: Keyword arguments for `ARIMASyntheticDataset`.
        config_loader: Keyword arguments forwarded to each `*_dataloader` call.
        val_ratio: Passed to `ARIMASyntheticDataset.setup`.
        test_ratio: Passed to `ARIMASyntheticDataset.setup`.

    Returns:
        `((train_loader, val_loader, test_loader), dataset)`.
    """
    arima = ARIMASyntheticDataset(**config_dataset)
    arima.setup(val_ratio, test_ratio)

    loaders = (
        arima.train_dataloader(**config_loader),
        # Eval loaders are dictionaries where key is resolution, value is
        # dataloader; only the entry stored under the `None` key is used here.
        arima.val_dataloader(**config_loader)[None],
        arima.test_dataloader(**config_loader)[None],
    )

    # Report the coefficients of the generating process
    process = arima.process
    print(f'AR coeffs: {process.ar_params}')
    print(f'MA coeffs: {process.ma_params}')

    return loaders, arima


def visualize_arima(dataloaders, splits=('train', 'val', 'test'),
                    save=False, args=None, title=None):
    """
    Plot every series of each split end-to-end on the current matplotlib figure.

    Args:
        dataloaders: Indexable collection of dataloaders, one per split; each
            must expose `.dataset.ts`.
        splits: Split names used as legend labels, in the same order as
            `dataloaders`. The default is an immutable tuple so the value is
            not shared mutable state between calls.
        save: If True, save the figure to a PDF named from `args`; otherwise
            show it.
        args: Object with attributes `p`, `q`, `d`, `c` and `dataset_seed`;
            required when `save` is True.
        title: Figure title; defaults to 'ARIMA'.
    """
    assert len(splits) == len(dataloaders)
    n_ts = len(dataloaders[0].dataset.ts)
    for ts_idx in range(n_ts):
        # Each split is drawn right after the previous one on the x-axis
        start_idx = 0
        for idx, split in enumerate(splits):
            y = dataloaders[idx].dataset.ts[ts_idx]
            x = np.arange(len(y)) + start_idx
            # Label only the first series so the legend has one entry per split
            label = split if ts_idx == 0 else None
            plt.plot(x, y, label=label)
            start_idx += len(x)
    title = 'ARIMA' if title is None else title
    plt.title(title)
    plt.legend()
    
    if save:
        assert args is not None
        plt.savefig(f'arima-p={args.p}-q={args.q}-d={args.d}-c={args.c}-ds={args.dataset_seed}.pdf')
    else:
        plt.show()
    
      
# Modularity - should refactor
def load_data(config_dataset, config_loader, args=None):
    """
    Generic entry point: load the ARIMA data with fixed 20% validation and
    20% test ratios. `args` is accepted but not used.
    """
    fixed_ratios = dict(val_ratio=0.2, test_ratio=0.2)
    return load_arima_data(config_dataset, config_loader, **fixed_ratios)

# Modularity - should refactor
def visualize_data(dataloaders, splits=('train', 'val', 'test'),
                   save=False, args=None, title=None):
    """
    Generic entry point: forward all arguments to `visualize_arima`.

    The default for `splits` is an immutable tuple so the value is not
    shared mutable state between calls.
    """
    visualize_arima(dataloaders, splits, save, args, title)