import copy
import torch
import numpy as np
import pandas as pd
import pickle
from torch.utils.data import Dataset, DataLoader
import sys

# True when a trace function is installed (e.g. the process runs under a debugger).
isDebug = bool(sys.gettrace())


def construct_loader(Z, U, Z_eval, U_eval, batch_size, seq_len):
    """Build the training and validation ``DataLoader`` pair.

    Both datasets are ``Data_ARMA`` instances created with ``n = 0`` and one
    lag per column of ``U`` (``m = [1] * U.shape[1]``).

    Args:
        Z: Training targets, forwarded to ``Data_ARMA``.
        U: Training inputs; ``U.shape[1]`` fixes the length of ``m``.
        Z_eval: Validation targets.
        U_eval: Validation inputs.
        batch_size: Batch size of the training loader.
        seq_len: Window length forwarded to ``Data_ARMA``.

    Returns:
        ``(train_loader, val_loader)``. The validation loader yields the whole
        validation dataset as one batch.
    """
    ar_order = 0
    lags = [1] * U.shape[1]

    train_set = Data_ARMA(Z, U, ar_order, lags, seq_len)
    loader_opts = {"pin_memory": False, "drop_last": False}
    if isDebug:
        # Under a tracer: keep sample order fixed and use a single worker.
        loader_opts.update(shuffle=False, num_workers=1)
    else:
        loader_opts["shuffle"] = True
    train_loader = DataLoader(train_set, batch_size, **loader_opts)

    val_set = Data_ARMA(Z_eval, U_eval, ar_order, lags, seq_len, keep_fore=False)
    val_loader = DataLoader(
        val_set,
        batch_size=len(val_set),
        shuffle=False,
        drop_last=False,
    )
    return train_loader, val_loader


def rest(a, b):
    """Return ``a % b`` when ``0 <= a < b``; return ``0`` for any other ``a``."""
    in_range = a >= 0 and a < b
    if not in_range:
        return 0
    return a % b

class StandardScaler:
    """Standardise data with statistics computed along axis 0.

    ``transform`` computes ``(x - mean) / (std + eps)`` and
    ``inverse_transform`` applies the exact inverse,
    ``x * (std + eps) + mean``, so a round trip recovers the input.
    """

    def __init__(self):
        # Identity-like statistics until ``fit`` is called.
        self.mean = np.array(0.0)
        self.std = np.array(1.0)
        # Added to ``std`` so that constant columns do not divide by zero.
        self.eps = 1e-6

    def fit(self, data):
        """Store the axis-0 mean and standard deviation of ``data``.

        Args:
            data: A pandas object exposing ``to_numpy()`` or anything
                ``np.asarray`` can convert.
        """
        values = data.to_numpy() if hasattr(data, "to_numpy") else np.asarray(data)
        self.mean = values.mean(0)
        self.std = values.std(0)

    def transform(self, data, indx=None):
        """Standardise ``data`` with the fitted statistics.

        Args:
            data: Values to scale.
            indx: Optional index into the fitted ``mean``/``std``; when given,
                only those statistics are used.

        Returns:
            The scaled data.
        """
        if indx is None:
            return (data - self.mean) / (self.std + self.eps)
        return (data - self.mean[indx]) / (self.std[indx] + self.eps)

    def fit_transform(self, data, indx=None):
        """Fit on ``data`` and return ``transform(data, indx)``."""
        self.fit(data)
        return self.transform(data, indx)

    def inverse_transform(self, data, indx=None):
        """Undo ``transform``.

        One-dimensional input is reshaped to a single column before scaling.

        Args:
            data: Scaled values (must expose ``ndim`` and ``reshape``).
            indx: Optional index into the fitted ``mean``/``std``, matching
                the one used in ``transform``.

        Returns:
            The data mapped back to the original scale.
        """
        if data.ndim == 1:
            data = data.reshape(-1, 1)
        # Use the same ``std + eps`` denominator as ``transform`` so the two
        # operations are exact inverses of each other.
        if indx is None:
            return data * (self.std + self.eps) + self.mean
        return data * (self.std[indx] + self.eps) + self.mean[indx]


def read_data(data_name='debutanizer_7i1o.csv', io='7i1o', seq_len=16, n_verify=800):
    """Load a CSV, standardise it and split it into training / verification parts.

    The first CSV column is dropped (presumably an index or timestamp column;
    verify against the data file). The last ``n_verify + seq_len`` rows form
    the verification split and everything before them the training split.
    Scaling statistics are fitted on the training rows only and then applied
    to the whole table.

    Args:
        data_name: Path of the CSV file passed to ``pandas.read_csv``.
        io: Layout string of the form ``'<inputs>i<outputs>o'``, e.g. ``'7i1o'``.
            The first ``<inputs>`` remaining columns are inputs and the last
            ``<outputs>`` columns are targets.
        seq_len: Window length; added to ``n_verify`` to size the hold-out.
        n_verify: Number of hold-out rows in addition to ``seq_len``.

    Returns:
        ``(Z, U, Z_verify, U_verify)`` as NumPy arrays: training targets,
        training inputs, verification targets, verification inputs.

    Raises:
        ValueError: If the file has too few rows to leave any training data.
    """
    data_raw = pd.read_csv(data_name)
    in_dims = int(io.split('i')[0])
    out_dims = int(io.split('i')[-1].split('o')[0])

    N = data_raw.shape[0]
    holdout = n_verify + seq_len
    split = N - holdout
    if split <= 0:
        # A non-positive split point would make the slices below wrap around
        # or come out empty, and the scaler would be fitted on no rows at all.
        raise ValueError(
            f'{data_name!r} has {N} rows, but more than {holdout} are required '
            f'(n_verify={n_verify} + seq_len={seq_len}).'
        )

    data = data_raw.iloc[:, 1:]
    scaler = StandardScaler()
    # Fit on the training rows only so hold-out statistics do not leak in.
    scaler.fit(data.iloc[:split])
    data = scaler.transform(data)

    U = data.iloc[:split, :in_dims].values
    Z = data.iloc[:split, -out_dims:].values

    U_verify = data.iloc[split:, :in_dims].values
    Z_verify = data.iloc[split:, -out_dims:].values

    return Z, U, Z_verify, U_verify


class Data_ARMA(Dataset):
    """Sliding-window dataset built from an input array ``U`` and a target array ``Z``.

    At construction a one-step-lagged copy of ``Z`` is appended to ``U`` as an
    extra column, and both arrays are stacked into ``seq_len`` progressively
    shifted copies so that one index yields a whole window.

    Args:
        Z: Target array. It is rolled and reshaped to a single column, so it is
            presumably 1-D or ``(N, 1)`` (assumption; confirm with callers).
        U: 2-D input array with one row per element of ``Z``.
        n: Number of lagged-target values taken per sample.
        m: Per-input-column lag counts. ``None`` means ``[0]``.
        seq_len: Number of consecutive time steps per sample.
        keep_fore: When true, ``__getitem__`` returns an accumulated history
            that grows by one row per call instead of a fixed-size window.
        num_nonstat: Optional count that is multiplied by ``m[0]`` and stored;
            it drives the ``ns_num`` / ``s_num`` / ``in_dims`` properties.
    """

    def __init__(self, Z, U, n=0, m=None, seq_len=1, keep_fore=False, num_nonstat=None):
        super().__init__()
        # ``None`` replaces the former mutable default ``[0]``; copying into a
        # list also guarantees that ``self.m + [n]`` below is a concatenation.
        m = [0] if m is None else list(m)
        self.n = copy.deepcopy(n)
        self.m = m
        self.Z = copy.deepcopy(Z)
        self.U = copy.deepcopy(U)
        self.maxnm = max(n, max(m))
        self.seq_len = seq_len
        # overlap_ZU also extends self.U and sets self.m_overlap, which the
        # residual computation below relies on.
        self.U_overlap, self.Z_overlap = self.overlap_ZU()
        self.N = self.Z_overlap.shape[1]
        assert self.N - self.maxnm > 0
        self.d = sum(m) + n
        self.residuals = [self.maxnm - mm for mm in self.m_overlap]
        self.keep_fore = keep_fore
        self.num_nonstat = num_nonstat * m[0] if num_nonstat is not None else None
        if self.keep_fore:
            # __getitem__ appends to the storage buffers, so they must exist
            # even when set_start() is never called.
            self.reset_storage()

    def reset_storage(self, ind=0):
        """Re-seed the history buffers with all but the last row of sample ``ind``."""
        retval, labels = self.construct_seq_batch(ind)
        # The last row is left out because __getitem__ appends it itself.
        self.storage_val = retval[:-1].copy()
        self.storage_label = labels[:-1].copy()

    def set_start(self, ind):
        """Drop the first ``ind`` time steps so that index 0 refers to old index ``ind``.

        The history buffers are reset when ``keep_fore`` is enabled.
        """
        assert ind >= 0 and ind < len(self), rf'{ind} not between 0 and {len(self)}'
        self.U_overlap = self.U_overlap[:, ind:, :]
        self.Z_overlap = self.Z_overlap[:, ind:]
        self.U = self.U[ind:, :]
        self.Z = self.Z[ind:]
        self.N -= ind
        if self.keep_fore:
            self.reset_storage()

    def overlap_ZU(self):
        """Build the stacked, shifted views of the inputs and targets.

        Side effects: appends the lagged-target column to ``self.U`` and sets
        ``self.m_overlap`` (``m`` followed by ``n``).

        Returns:
            ``(U_overlap, Z_overlap)`` where slice ``s`` along axis 0 is the
            data shifted ``s`` steps ahead; the trailing ``seq_len - 1``
            positions, which would wrap around, are cut off.
        """
        seq_len = self.seq_len
        # Target delayed by one step; the wrapped-around first entry is zeroed.
        lagged_Z = np.roll(self.Z, 1)
        lagged_Z[0] = 0
        lagged_Z = lagged_Z.reshape(-1, 1)
        self.m_overlap = self.m + [self.n]
        self.U = np.concatenate([self.U, lagged_Z], axis=1)
        U_lap = [self.U]
        Z_lap = [self.Z]
        for shift in range(1, seq_len):
            U_lap.append(np.roll(self.U, -shift, axis=0))
            Z_lap.append(np.roll(self.Z, -shift))
        # seq N inputs_dims, seq N
        U_overlap = np.stack(U_lap, axis=0)
        Z_overlap = np.stack(Z_lap, axis=0)
        if seq_len > 1:
            U_overlap = U_overlap[:, :-(seq_len - 1), :]
            Z_overlap = Z_overlap[:, :-(seq_len - 1)]
        return U_overlap, Z_overlap

    @property
    def ns_num(self):
        """``num_nonstat`` when it is set, otherwise ``sum(m)``."""
        return sum(self.m) if self.num_nonstat is None else self.num_nonstat

    @property
    def s_num(self):
        """``sum(m) - num_nonstat`` when ``num_nonstat`` is set, otherwise ``0``."""
        return 0 if self.num_nonstat is None else sum(self.m) - self.num_nonstat

    @property
    def in_dims(self):
        """``num_nonstat`` when it is set, otherwise ``d`` (``sum(m) + n``)."""
        return self.d if self.num_nonstat is None else self.num_nonstat

    def __len__(self):
        return self.N - self.maxnm

    def construct_seq_batch(self, ind):
        """Assemble the window that starts at position ``ind``.

        Returns:
            ``(retval, labels)``: ``retval`` holds ``seq_len`` rows formed by
            concatenating, per column ``j``, ``m_overlap[j]`` consecutive
            values; ``labels`` is the matching target column.
        """
        maxnm = self.maxnm
        assert ind + maxnm < self.N
        # Column j starts `residuals[j]` steps late so that every column's lag
        # window ends at the same time step.
        pieces = [
            self.U_overlap[:, ind + res:ind + res + self.m_overlap[j], j]
            for j, res in enumerate(self.residuals)
        ]
        retval = np.concatenate(pieces, axis=1)
        labels = self.Z_overlap[:, ind + maxnm - 1].reshape(-1, 1)
        # seq M, seq 1
        return retval, labels

    def __getitem__(self, ind):
        retval, labels = self.construct_seq_batch(ind)
        if not self.keep_fore:
            return retval, labels

        # History mode: append only the newest row and return everything
        # accumulated so far.
        self.storage_val = np.concatenate([self.storage_val, retval[-1:, :]], axis=0)
        self.storage_label = np.concatenate([self.storage_label, labels[-1:, :]], axis=0)
        return self.storage_val, self.storage_label
