import gc
import time
import pandas as pd
from tqdm import trange
from jax import grad, jacfwd, random, jit, partial
import jax.numpy as np
from jax.config import config
config.update("jax_debug_nans", True)
#config.update("jax_enable_x64", True)
config.update('jax_platform_name', 'cpu')
import jax.profiler
from jax.interpreters import xla
from sklearn.datasets import load_diabetes
from sklearn.model_selection import train_test_split
from sklearn.decomposition import PCA
from lib import pseudo_l1

# Experiment name; used as the sub-directory of `result/` into which the log,
# the result CSV and the parameter snapshots are written.
data_name = 'diabetes_mlp_dad'

class Experiment:
    def __init__(
            self,
            seed=100,
            train_size=40, validation_size=100):
        """Load the diabetes data, split it and fix the problem dimensions.

        Args:
            seed: Stored as ``self.seed`` and used as ``random_state`` for
                both data splits.
            train_size: ``train_size`` argument of the first split
                (training set vs. the rest).
            validation_size: ``train_size`` argument of the second split,
                which takes the validation set out of the rest; what remains
                becomes the test set.
        """
        self.seed = seed

        dataset = load_diabetes()
        features = dataset.data
        targets = dataset.target.reshape(-1, 1)
        # Targets are standardised with the statistics of the full dataset.
        targets_normal = (targets - np.mean(targets)) / np.std(targets)

        # First split: training set vs. everything else.
        train_vs_rest = train_test_split(
            features, targets_normal,
            train_size=train_size, random_state=seed)
        self.X_train, features_rest, self.y_train, targets_rest = train_vs_rest

        # Second split: validation set vs. test set.
        val_vs_test = train_test_split(
            features_rest, targets_rest,
            train_size=validation_size, random_state=seed)
        self.X_val, self.X_test, self.y_val, self.y_test = val_vs_test

        self.n_features = features.shape[1]
        # Optional PCA reduction of the inputs (currently disabled):
        #self.n_features = 5
        #pca = PCA(n_components=self.n_features)
        #self.X_train = pca.fit_transform(self.X_train)
        #self.X_val = pca.transform(self.X_val)
        #self.X_test = pca.transform(self.X_test)

        # Number of units in the 2nd layer of the MLP.
        self.n_hidden_units = 3

        # W_1: weights from the 1st to the 2nd layer; the extra input row
        # corresponds to the constant-one column appended in `mlp`.
        n_inputs = self.n_features + 1
        self.shape_W_1 = [n_inputs, self.n_hidden_units]
        self.len_W_1 = n_inputs * self.n_hidden_units
        # W_2: weights from the 2nd layer to the scalar output.
        self.shape_W_2 = [self.n_hidden_units, 1]
        self.len_W_2 = self.n_hidden_units * 1

        # Actual sizes of the three subsets after splitting.
        self.train_size, self.validation_size, self.test_size = (
            len(self.X_train), len(self.X_val), len(self.X_test))

        # Dimensions of X_1 (hyperparameter), X_2 (one noise entry per
        # training feature value) and X_3 (W_1, W_2 and the output bias).
        self.d_1 = 1
        self.d_2 = int(self.train_size * self.n_features)
        self.d_3 = self.len_W_1 + self.len_W_2 + 1
    
    def initialize(
            self,
            T_1, T_2, T_3, eta_1, eta_2, eta_3,
            noise_weight=100., note='pseudo_l1'):
        """Set up variables, step sizes and logging for a fresh run.

        Draws the initial MLP weights and the initial training-data noise,
        stores the iteration counts and step sizes, appends a header to
        ``result/<data_name>/log.txt`` and records the initial state as the
        first row of ``self.result``.

        Args:
            T_1, T_2, T_3: Number of iterations for x_1, x_2 and x_3.
            eta_1, eta_2, eta_3: Step sizes for x_1, x_2 and x_3.
            noise_weight: Stored as ``self.c``, the weight of the penalty on
                x_2 in ``f_2``.
            note: Free-form text written to the log header.
        """
        # Local import: only needed here to prepare the output directories.
        import os

        key = random.PRNGKey(self.seed)
        key, key_W_1 = random.split(key)
        # Each random draw gets its own key.  Previously the noise X_2 was
        # drawn with `key_W_2` (the key already consumed for W_2) while the
        # second key of this split was left unused; a JAX key must not be
        # reused for more than one draw.  W_1 and W_2 are drawn exactly as
        # before, only X_2 now uses the formerly unused key.
        key_X_2, key_W_2 = random.split(key)
        W_1 = random.normal(
            key_W_1, self.shape_W_1) / np.sqrt(self.n_features + 1)
        W_2 = random.normal(
            key_W_2, self.shape_W_2) / np.sqrt(self.n_hidden_units)
        b = 0
        self.X_1 = np.zeros([self.d_1])  # Hyperparameter
        self.X_2 = random.normal(key_X_2, [self.d_2])  # Noise to poison training data
        self.X_3 = np.hstack([W_1.reshape(-1), W_2.reshape(-1), b])  # Parameter of the MLP
        self.T_1, self.T_2, self.T_3 = T_1, T_2, T_3
        self.eta_1, self.eta_2, self.eta_3 = eta_1, eta_2, eta_3
        self.c = noise_weight
        self.result_cols = [
            't_1', 't_2', 't_3',
            'elapsed_time', 'updated_var',
            'x_1', 'norm_g_1',
            'norm_x_2', 'norm_x_3',
            'f_1', 'f_2', 'f_3',
            'train_error', 'validation_error', 'test_error']
        self.result = pd.DataFrame(columns=self.result_cols)
        # Integer timestamp; doubles as the identifier in output file names.
        self.time_start = int(time.time())
        # Make sure the log directory and the parameter-snapshot directory
        # exist, so that neither the log below nor later saves fail on a
        # fresh checkout.
        os.makedirs(f'result/{data_name}/params', exist_ok=True)
        with open(f'result/{data_name}/log.txt', mode='a') as log:
            log.writelines([
                f'timestamp: {self.time_start}\n',
                f'random seed: {self.seed}\n',
                f'T_2: {self.T_2}, T_3: {self.T_3}\n',
                f'train_size: {self.train_size}, ',
                f'validation_size: {self.validation_size}\n',
                f'd_2: {self.d_2}, d_3: {self.d_3}\n',
                f'eta_1: {eta_1}, eta_2: {eta_2}, eta_3: {eta_3}\n',
                f'c: {self.c}, random seed: {self.seed}\n',
                f'{note}\n',
                f'------------------------------\n',
            ])
        # Record the state before any update has been made.
        self.update_result(0, 0, 0, None)
        
    def update_result(
            self,
            t_1, t_2, t_3, updated_var, grad_1=None, save_params=False):
        time_now = time.time()
        time_elapsed = time_now - self.time_start
        grad_norm = (
            np.linalg.norm(grad_1) if grad_1 is not None
            else None 
        )
        result_row = pd.DataFrame([[
            t_1, t_2, t_3,
            time_elapsed, updated_var,
            self.X_1.tolist(), grad_norm,
            np.linalg.norm(self.X_2), np.linalg.norm(self.X_3),
            self.f_1(self.X_1, self.X_2, self.X_3),
            self.f_2(self.X_1, self.X_2, self.X_3),
            self.f_3(self.X_1, self.X_2, self.X_3),
            self.train_error(self.X_1, self.X_2, self.X_3),
            self.validation_error(self.X_1, self.X_2, self.X_3),
            self.test_error(self.X_1, self.X_2, self.X_3)]
        ], columns=self.result_cols)
        self.result = self.result.append(result_row)
        if save_params:
            self.save_params(t_1)
        
    def save_params(self, t_1):
        """Store the current X_1, X_2 and X_3 in one ``savez`` archive.

        The archive is written under ``result/<data_name>/params/`` and named
        after the run's start timestamp and ``t_1``; the arrays are stored
        under the keys ``x_1``, ``x_2`` and ``x_3``.
        """
        target = f'result/{data_name}/params/{self.time_start}_{t_1}'
        snapshot = {'x_1': self.X_1, 'x_2': self.X_2, 'x_3': self.X_3}
        np.savez(target, **snapshot)

    def mlp(self, param_vec, X, activate=np.tanh):
        X_with_1 = np.hstack([X, np.ones([len(X), 1])])
        W_1 = param_vec[0:self.len_W_1].reshape(self.shape_W_1)
        W_2 = param_vec[self.len_W_1:self.len_W_1+self.len_W_2].reshape(self.shape_W_2)
        b = param_vec[-1]
        out_1 = activate(X_with_1.dot(W_1))
        out_2 = out_1.dot(W_2) + b
        return out_2
        
    def f_1(self, x_1, x_2, x_3):
        return self.validation_error(x_1, x_2, x_3)
    
    def f_2(self, x_1, x_2, x_3):
        """Objective of x_2: negated training error plus a penalty on x_2.

        The penalty is ``self.c * pseudo_l1(x_2) / self.d_3``.
        """
        # NOTE(review): the penalty on x_2 is normalised by d_3 (the length
        # of x_3), not by d_2 (the length of x_2) -- confirm this is intended.
        fit = self.train_error(x_1, x_2, x_3)
        noise_penalty = self.c * pseudo_l1(x_2) / self.d_3
        return -fit + noise_penalty
    
    def f_3(self, x_1, x_2, x_3):
        """Objective of x_3: training error plus a penalty on x_3.

        The penalty is ``pseudo_l1(x_3) / self.d_3`` scaled by
        ``exp(x_1[0])``, so x_1 controls the penalty strength on a log scale.
        """
        fit = self.train_error(x_1, x_2, x_3)
        penalty_scale = np.exp(x_1[0])
        return fit + penalty_scale * pseudo_l1(x_3) / self.d_3
    
    def train_error(self, x_1, x_2, x_3):
        P = x_2.reshape([self.train_size, self.n_features])
        return (
            np.linalg.norm(
                self.y_train - self.mlp(x_3, self.X_train + P)) ** 2
            / self.train_size
        )
    
    def validation_error(self, x_1, x_2, x_3):
        return (
            np.linalg.norm(
                self.y_val - self.mlp(x_3, self.X_val)) ** 2
            / self.validation_size
        )
    
    def test_error(self, x_1, x_2, x_3):
        return (
            np.linalg.norm(
                self.y_test - self.mlp(x_3, self.X_test)) ** 2
            / self.test_size
        )
    

    def g_3(self, x_1, x_2, x_3):
        """Gradient of ``f_3`` with respect to its third argument, x_3."""
        grad_wrt_x_3 = grad(self.f_3, argnums=2)
        return grad_wrt_x_3(x_1, x_2, x_3)

    def phi_3(self, x_1, x_2, x_3):
        """One gradient-descent step on x_3 for ``f_3`` with step size ``eta_3``."""
        step = self.eta_3 * self.g_3(x_1, x_2, x_3)
        return x_3 - step

    def g_2(self, x_1, x_2, x_3, t_1, t_2, update=True):
        """Gradient of ``f_2`` w.r.t. x_2, through ``T_3`` steps of ``phi_3``.

        Starting from ``x_3``, ``phi_3`` is applied ``T_3`` times while the
        matrix ``Z_3`` (shape ``[d_3, d_2]``) is propagated with
        ``Z_3 <- A_3 Z_3 + B_3``, where ``B_3`` and ``A_3`` are the Jacobians
        of ``phi_3`` with respect to x_2 and x_3.  The result is the gradient
        of ``f_2`` w.r.t. x_2 plus the gradient w.r.t. x_3 multiplied by
        ``Z_3``, both evaluated at the final x_3.

        Args:
            x_1, x_2, x_3: Current values of the three variables.
            t_1, t_2: Outer iteration counters; only used for logging.
            update: If true, ``self.X_3`` is overwritten after every inner
                step and a result row is recorded; a parameter snapshot is
                requested on the last inner step.

        Returns:
            The gradient with respect to x_2.
        """
        Z_3 = np.zeros([self.d_3, self.d_2])
        x_3_ = x_3
        save_params = False
        for t_3 in trange(self.T_3, desc='x_3 in g_2', leave=False):
            # Jacobians of phi_3 w.r.t. x_2 (J[0]) and x_3 (J[1]), taken at
            # the point before the step below.
            J = jacfwd(self.phi_3, [1, 2])(x_1, x_2, x_3_)
            x_3_ = self.phi_3(x_1, x_2, x_3_)
            if update:
                self.X_3 = x_3_
                # Only the last inner step asks for a parameter snapshot.
                if t_3 == self.T_3 - 1:
                    save_params = True
                self.update_result(
                    t_1, t_2, t_3+1, 3, save_params=save_params)
            B_3 = J[0]  # derivative w.r.t. x_2
            A_3 = J[1]  # derivative w.r.t. x_3
            # Propagate the accumulated derivative through this step.
            Z_3 = np.dot(A_3, Z_3) + B_3
        # Direct dependence on x_2 plus the dependence through x_3.
        g = grad(self.f_2, 1)(x_1, x_2, x_3_)
        g += np.dot(grad(self.f_2, 2)(x_1, x_2, x_3_), Z_3)
        return g

    def phi_2(self, x_1, x_2, x_3, t_1, t_2, update=True):
        """One descent step on x_2 along ``g_2`` with step size ``eta_2``.

        ``t_1``, ``t_2`` and ``update`` are forwarded unchanged to ``g_2``.
        """
        direction = self.g_2(x_1, x_2, x_3, t_1, t_2, update)
        step = self.eta_2 * direction
        return x_2 - step

    def g_1(self, x_1, x_2, x_3, t_1, update=True):
        """Gradient of ``f_1`` w.r.t. x_1, through the x_2 and x_3 updates.

        First ``phi_2`` is applied ``T_2`` times to x_2 while ``Z_2`` (shape
        ``[d_2, d_1]``) is propagated with ``Z_2 <- A_2 Z_2 + B_2``; then
        ``phi_3`` is applied ``T_3`` times to x_3 while ``Z_3`` (shape
        ``[d_3, d_1]``) is propagated with
        ``Z_3 <- A_3 Z_3 + (B_3 + C_32 Z_2)``.  The returned gradient is the
        sum of the gradient of ``f_1`` w.r.t. x_1, its gradient w.r.t. x_2
        times ``Z_2`` and its gradient w.r.t. x_3 times ``Z_3``.

        Args:
            x_1, x_2, x_3: Current values of the three variables.
            t_1: Outer iteration counter; only used for logging.
            update: If true, ``self.X_2`` is overwritten after every x_2 step
                and result rows are recorded; it is also forwarded to
                ``phi_2``.

        Returns:
            The gradient with respect to x_1.
        """
        # k = 2
        Z_2 = np.zeros([self.d_2, self.d_1])
        x_2_ = x_2
        x_3_ = x_3
        for t_2 in trange(self.T_2, desc='x_2 in g_1', leave=False):
            # Jacobians of phi_2 w.r.t. x_1 (B_2) and x_2 (A_2).  This call
            # always uses update=False, so it neither logs nor touches
            # self.X_3.
            J = jacfwd(self.phi_2, [0, 1])(
                x_1, x_2_, x_3_, t_1, t_2, update=False)
            x_2_ = self.phi_2(x_1, x_2_, x_3_, t_1, t_2, update=update)
            if update:
                self.X_2 = x_2_
                self.update_result(t_1, t_2+1, 0, 2)
            B_2 = J[0]
            A_2 = J[1]
            Z_2 = np.dot(A_2, Z_2) + B_2
            # Continue from the x_3 stored on the instance; phi_2 -> g_2
            # overwrites self.X_3 only when update is true.
            # NOTE(review): with update=False self.X_3 is left unchanged by
            # the step above, so x_3_ is reset to whatever self.X_3 held
            # before -- confirm this is intended.
            x_3_ = self.X_3
            # Roughly every tenth iteration, clear JAX's `_xla_callable`
            # cache (presumably to limit memory growth -- TODO confirm).
            if t_2 % max(int(self.T_2 / 10), 1) == 0:
                xla._xla_callable.cache_clear()
        # k = 3
        Z_3 = np.zeros([self.d_3, self.d_1])
        for t_3 in trange(self.T_3, desc='x_3 in g_1', leave=False):
            # Jacobians of phi_3 w.r.t. x_1 (B_3), x_2 (C_32) and x_3 (A_3).
            J = jacfwd(self.phi_3, [0, 1, 2])(x_1, x_2_, x_3_)
            x_3_ = self.phi_3(x_1, x_2_, x_3_)
            # NOTE(review): self.X_3 is overwritten here even when update is
            # false, unlike in g_2 -- confirm this is intended.
            self.X_3 = x_3_
            if update:
                self.update_result(t_1, self.T_2, t_3+1, 3)
            B_3 = J[0]
            C_32 = J[1]
            A_3 = J[2]
            # Add the dependence on x_1 that goes through x_2.
            B_3 += np.dot(C_32, Z_2)
            Z_3 = np.dot(A_3, Z_3) + B_3
        # Direct dependence on x_1 plus the dependence through x_2 and x_3.
        g = grad(self.f_1, 0)(x_1, x_2_, x_3_)
        g += np.dot(grad(self.f_1, 1)(x_1, x_2_, x_3_), Z_2)
        g += np.dot(grad(self.f_1, 2)(x_1, x_2_, x_3_), Z_3)
        return g

    def phi_1(self, x_1, x_2, x_3, t_1, update=True):
        """One descent step on x_1 along ``g_1`` with step size ``eta_1``.

        Returns:
            A pair ``(new_x_1, gradient)``, where ``gradient`` is the value
            returned by ``g_1`` that was used for the step.
        """
        gradient = self.g_1(x_1, x_2, x_3, t_1, update)
        new_x_1 = x_1 - self.eta_1 * gradient
        return new_x_1, gradient
    
    def optimize(self, save=True, update=True):
        """Run the three optimisation phases one after another.

        Performs ``T_1`` updates of x_1 via ``phi_1``, then ``T_2`` updates
        of x_2 via ``phi_2`` and finally ``T_3`` updates of x_3 via
        ``phi_3``, storing the new values on the instance.

        Args:
            save: If true, write the result table to CSV after every x_1
                update and once more at the end.
            update: Forwarded to ``phi_1`` and ``phi_2``; if true, a result
                row is also recorded after every update made here.
        """
        # Phase 1: updates of x_1.
        for step_1 in trange(self.T_1, desc='x_1', leave=False):
            next_x_1, grad_1 = self.phi_1(
                self.X_1, self.X_2, self.X_3, step_1, update=update)
            self.X_1 = next_x_1
            if update:
                self.update_result(step_1 + 1, 0, 0, 1, grad_1)
            # Clear JAX's `_xla_callable` cache and run the garbage
            # collector after every x_1 step.
            xla._xla_callable.cache_clear()
            gc.collect()
            if save:
                self.save_result()

        # Phase 2: updates of x_2 with x_1 fixed; the cache is cleared on
        # roughly every tenth step.
        clear_period = max(int(self.T_2 / 10), 1)
        for step_2 in trange(self.T_2, desc='x_2', leave=False):
            self.X_2 = self.phi_2(
                self.X_1, self.X_2, self.X_3, self.T_1, step_2, update=update)
            if update:
                self.update_result(self.T_1, step_2 + 1, 0, 2)
            if step_2 % clear_period == 0:
                xla._xla_callable.cache_clear()

        # Phase 3: updates of x_3 with x_1 and x_2 fixed; every recorded
        # step also requests a parameter snapshot.
        for step_3 in trange(self.T_3, desc='x_3', leave=False):
            self.X_3 = self.phi_3(self.X_1, self.X_2, self.X_3)
            if update:
                self.update_result(
                    self.T_1, self.T_2, step_3 + 1, 3, save_params=True)

        if save:
            self.save_result()
    
    def run(
            self,
            T_1, T_2, T_3, eta_1, eta_2, eta_3, noise_weight,
            save=True, update=True, note='pseudo_l1'):
        """Initialise the experiment, print its timestamp and optimise.

        All arguments are forwarded: the iteration counts, step sizes,
        ``noise_weight`` and ``note`` to ``initialize``; ``save`` and
        ``update`` to ``optimize``.
        """
        self.initialize(
            T_1, T_2, T_3, eta_1, eta_2, eta_3,
            noise_weight=noise_weight, note=note)
        # The start timestamp identifies this run's output files.
        print(self.time_start)
        self.optimize(save=save, update=update)

    def save_result(self):
        """Write the result table to ``result/<data_name>/<time_start>.csv``.

        The DataFrame index is not written.
        """
        csv_path = f'result/{data_name}/{self.time_start}.csv'
        self.result.to_csv(csv_path, index=False)

