import numpy as np
from tensorflow import keras
import pandas as pd
from pathlib import Path


class RandomData:
    """Common base for the sample generators defined in this module.

    Concrete subclasses override :meth:`draw` to produce a sample. The base
    implementation performs no work.
    """

    def draw(self):
        """Produce a sample; the base class has none and returns ``None``."""
        return None

class MixedNormal(RandomData):
    """Two-component zero-mean Gaussian scale mixture in ``d`` dimensions.

    Each row is drawn independently: with probability ``prob`` (for
    ``0 <= prob <= 1``) it comes from a standard normal (scale 1), otherwise
    from a normal with scale 2.
    """

    def __init__(self, n, d, prob, seed=None):
        """
        Args:
            n: number of rows returned by :meth:`draw`.
            d: number of columns (dimension) of each row.
            prob: probability that a row uses the scale-1 component.
            seed: optional seed for ``np.random.default_rng``; ``None`` (the
                default) draws fresh OS entropy, matching the old behavior.
        """
        self.n = n
        self.d = d
        self.prob = prob
        self.rng = np.random.default_rng(seed)

    def draw(self):
        """Return an ``(n, d)`` sample using the configured row count."""
        return self.draw_n(self.n)

    def draw_n(self, n):
        """Return an ``(n, d)`` sample with an explicit row count ``n``."""
        # One Bernoulli choice per row, shaped as a column so it broadcasts
        # across all d columns. A uniform draw above ``prob`` selects the
        # scale-2 component.
        use_wide = (self.rng.uniform(size=n) > self.prob).reshape(-1, 1)
        # Both candidate matrices are generated (in this order) and then
        # merged row-wise.
        wide = self.rng.normal(scale=2, size=(n, self.d))
        narrow = self.rng.normal(size=(n, self.d))
        return np.where(use_wide, wide, narrow)

class Laplace(RandomData):
    """I.i.d. Laplace samples with location 0 and scale 2 in ``d`` dimensions."""

    def __init__(self, n, d, seed=None):
        """
        Args:
            n: number of rows returned by :meth:`draw`.
            d: number of columns (dimension) of each row.
            seed: optional seed for ``np.random.default_rng``; ``None`` (the
                default) draws fresh OS entropy, matching the old behavior.
        """
        self.n = n
        self.d = d
        self.rng = np.random.default_rng(seed)

    def draw(self):
        """Return an ``(n, d)`` sample using the configured row count."""
        return self.draw_n(self.n)

    def draw_n(self, n):
        """Return an ``(n, d)`` sample with an explicit row count ``n``.

        Mirrors ``MixedNormal.draw_n`` / ``MNIST.draw_n`` so all generators
        can be sampled at arbitrary sizes.
        """
        return self.rng.laplace(scale=2, size=(n, self.d))

class Uniform(RandomData):
    """I.i.d. samples uniform on ``[-4, 4)`` in ``d`` dimensions."""

    def __init__(self, n, d, seed=None):
        """
        Args:
            n: number of rows returned by :meth:`draw`.
            d: number of columns (dimension) of each row.
            seed: optional seed for ``np.random.default_rng``; ``None`` (the
                default) draws fresh OS entropy, matching the old behavior.
        """
        self.n = n
        self.d = d
        self.rng = np.random.default_rng(seed)

    def draw(self):
        """Return an ``(n, d)`` sample using the configured row count."""
        return self.draw_n(self.n)

    def draw_n(self, n):
        """Return an ``(n, d)`` sample with an explicit row count ``n``.

        Mirrors ``MixedNormal.draw_n`` / ``MNIST.draw_n`` so all generators
        can be sampled at arbitrary sizes.
        """
        # Scale a [-1, 1) draw by 4 (kept as-is so values are bit-identical to
        # the previous implementation).
        return self.rng.uniform(-1, 1, size=(n, self.d)) * 4

class MNIST(RandomData):
    """Random images of a single MNIST digit, as flattened pixel rows.

    The train and test splits returned by ``keras.datasets.mnist.load_data``
    are pooled. Each image is flattened to one row of ``height * width``
    pixels.
    """

    def __init__(self, n, digit, seed=None):
        """
        Args:
            n: number of images returned by :meth:`draw`.
            digit: class label whose images are sampled.
            seed: optional seed for ``np.random.default_rng``; ``None`` (the
                default) draws fresh OS entropy, matching the old behavior.
        """
        self.n = n
        self.digit = digit
        self.rng = np.random.default_rng(seed)
        # (digit, images) filled lazily on the first draw, so construction
        # stays cheap and the dataset is not reloaded on every call.
        self._images = None

    def _digit_images(self):
        """Return all pooled images of ``self.digit``, loading MNIST once per digit."""
        cached = self._images
        # Reload only if nothing is cached yet or ``digit`` was reassigned.
        if cached is None or cached[0] != self.digit:
            (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
            # ndarray.reshape works on every NumPy version; np.reshape's
            # ``shape=`` keyword only exists from NumPy 2.1 onward.
            x = np.concatenate([x_train.reshape(len(x_train), -1),
                                x_test.reshape(len(x_test), -1)])
            y = np.concatenate([y_train, y_test])
            cached = (self.digit, x[y == self.digit])
            self._images = cached
        return cached[1]

    def draw(self):
        """Return up to ``n`` distinct images of the digit, in random order."""
        return self.draw_n(self.n)

    def draw_n(self, n):
        """Return up to ``n`` distinct images of the digit, in random order.

        Rows are the first ``n`` of a random permutation (sampling without
        replacement). If fewer than ``n`` images of the digit exist, all of
        them are returned.
        """
        return self.rng.permutation(self._digit_images())[:n]
    

class HASC:
    """Simulate a change from "walking" to "staying".

    Uses the ``person101`` recordings of the HASC2011 corpus. Files under
    ``2_walk`` form the pre-change (h0) regime and files under ``1_stay``
    form the post-change (h1) regime. Each CSV has no header and columns
    ``time, x, y, z``; the time column is dropped on load.
    """

    _COLUMNS = ("time", "x", "y", "z")

    def __init__(self, data_dir="../data/HASC2011corpus"):
        """
        Args:
            data_dir: corpus root. The default is resolved relative to the
                current working directory, as before. Missing directories
                simply yield empty file lists.
        """
        root = Path(data_dir)
        # sorted() already returns a list; sorting gives a stable file index.
        self.h0_files = sorted(root.glob("2_walk/person101/*.csv"))
        self.h1_files = sorted(root.glob("1_stay/person101/*.csv"))

    @classmethod
    def _read(cls, path):
        """Load one recording as a DataFrame with columns ``x, y, z``."""
        frame = pd.read_csv(path, header=None, names=list(cls._COLUMNS))
        return frame.drop("time", axis=1)

    def h0(self, num):
        """Return the ``num``-th walking recording (files sorted by path)."""
        return self._read(self.h0_files[num])

    def h1(self, num):
        """Return the ``num``-th staying recording (files sorted by path)."""
        return self._read(self.h1_files[num])

    def change(self, num, h0_offset, len_pre_change, len_post_change):
        """Concatenate a walking segment followed by a staying segment.

        Returns a 2-D array made of rows ``[h0_offset, h0_offset +
        len_pre_change)`` of ``h0(num)``, followed by the first
        ``len_post_change`` rows of ``h1(num)``. Slices past the end of a
        recording are silently truncated (standard NumPy slicing).
        """
        pre = self.h0(num).values[h0_offset:h0_offset + len_pre_change]
        post = self.h1(num).values[:len_post_change]
        return np.concatenate((pre, post))