import gc
import time
import pandas as pd
from tqdm import trange
from jax import grad, jacfwd, random, jit, partial
import jax.numpy as np
from jax.config import config
#config.update("jax_enable_x64", True)
config.update('jax_platform_name', 'cpu')
import jax.profiler
from jax.interpreters import xla
from sklearn.datasets import load_diabetes
from sklearn.model_selection import train_test_split
from sklearn.decomposition import PCA
from lib import pseudo_l1

# Experiment identifier; logs, CSV results and parameter dumps are all
# written below result/<data_name>/.
data_name = "diabetes_mlp_dad_2level"

class Experiment:
    """Bilevel optimisation of an MLP's penalty weight on the diabetes data.

    * Lower level: ``x_2`` is the flattened MLP parameter vector
      ``[W_1, W_2, b]`` (length ``d_2``). It minimises ``f_2`` = training
      MSE + ``exp(x_1[0]) * pseudo_l1(x_2) / d_2``.
    * Upper level: ``x_1`` (length ``d_1``) minimises ``f_1`` = validation
      MSE evaluated at the lower-level iterate.

    The hypergradient of ``f_1`` w.r.t. ``x_1`` is computed by forward-mode
    differentiation through ``T_2`` unrolled gradient-descent steps on
    ``f_2`` (see ``g_1``). Metrics are accumulated in ``self.result`` and
    written below ``result/<data_name>/``.
    """

    def __init__(
            self,
            seed=100,
            train_size=40, validation_size=100):
        """Load and split the data and set up the MLP parameter layout.

        Args:
            seed: ``random_state`` for both data splits; also the PRNG seed
                used for weight initialisation in ``initialize``.
            train_size: Number of samples in the training split.
            validation_size: Number of samples in the validation split. The
                remaining samples form the test split.
        """
        self.seed = seed
        data = load_diabetes()
        X = data.data
        y = data.target.reshape(-1, 1)
        # Standardise the target (zero mean, unit std) over the whole dataset.
        y_normal = (y - np.mean(y)) / np.std(y)
        (self.X_train, X_valtest,
         self.y_train, y_valtest) = train_test_split(
            X, y_normal, train_size=train_size, random_state=seed)
        (self.X_val, self.X_test,
         self.y_val, self.y_test) = train_test_split(
            X_valtest, y_valtest, train_size=validation_size,
            random_state=seed)

        self.n_features = X.shape[1]
        # Optional PCA preprocessing (currently disabled):
        # self.n_features = 5
        # pca = PCA(n_components=self.n_features)
        # self.X_train = pca.fit_transform(self.X_train)
        # self.X_val = pca.transform(self.X_val)
        # self.X_test = pca.transform(self.X_test)

        self.n_hidden_units = 3  # number of units in 2nd layer of MLP

        # W_1 maps the inputs plus a constant-1 column to the hidden layer.
        self.shape_W_1 = [(self.n_features + 1), self.n_hidden_units]
        self.len_W_1 = (self.n_features + 1) * self.n_hidden_units
        # W_2 maps the hidden layer to the single output.
        self.shape_W_2 = [self.n_hidden_units, 1]
        self.len_W_2 = self.n_hidden_units * 1

        self.train_size = len(self.X_train)
        self.validation_size = len(self.X_val)
        self.test_size = len(self.X_test)

        # Dimension of the upper-level variable x_1 (one penalty weight).
        self.d_1 = 1
        # Dimension of the lower-level variable x_2 = [W_1, W_2, b].
        self.d_2 = self.len_W_1 + self.len_W_2 + 1

    def initialize(
            self,
            T_1, T_2, eta_1, eta_2,
            noise_weight=100., note='pseudo_l1'):
        """Initialise the variables, reset the result table and log the run.

        Args:
            T_1: Number of upper-level (x_1) steps.
            T_2: Number of lower-level (x_2) gradient steps per x_1 step.
            eta_1: Step size for x_1.
            eta_2: Step size for x_2.
            noise_weight: Not used by this class; kept for call compatibility.
            note: Free-text note appended to the log file.
        """
        key = random.PRNGKey(self.seed)
        key, key_W_1 = random.split(key)
        _, key_W_2 = random.split(key)
        # Weights are standard normal scaled by 1/sqrt(fan_in).
        W_1 = random.normal(
            key_W_1, self.shape_W_1) / np.sqrt(self.n_features + 1)
        W_2 = random.normal(
            key_W_2, self.shape_W_2) / np.sqrt(self.n_hidden_units)
        b = 0
        self.X_1 = np.zeros([self.d_1])
        self.X_2 = np.hstack([W_1.reshape(-1), W_2.reshape(-1), b])
        self.T_1, self.T_2 = T_1, T_2
        self.eta_1, self.eta_2 = eta_1, eta_2
        self.result_cols = [
            't_1', 't_2',
            'elapsed_time', 'updated_var',
            'x_1', 'norm_g_1',
            'norm_x_2',
            'f_1', 'f_2',
            'train_error', 'validation_error', 'test_error']
        self.result = pd.DataFrame(columns=self.result_cols)
        self.time_start = int(time.time())
        with open(f'result/{data_name}/log.txt', mode='a') as log:
            log.writelines([
                f'timestamp: {self.time_start}\n',
                f'random seed: {self.seed}\n',
                f'T_1: {self.T_1}, T_2: {self.T_2}\n',
                f'train_size: {self.train_size}, ',
                f'validation_size: {self.validation_size}\n',
                f'd_2: {self.d_2}\n',
                f'eta_1: {eta_1}, eta_2: {eta_2}\n',
                f'{note}\n',
                '------------------------------\n',
            ])
        self.update_result(0, 0, None, None)

    def update_result(
            self,
            t_1, t_2, updated_var, grad_1=None, save_params=False):
        """Append one row of metrics for the current (X_1, X_2).

        Args:
            t_1: Upper-level iteration index.
            t_2: Lower-level iteration index.
            updated_var: Which variable was just updated (1 or 2; None for
                the initial row).
            grad_1: Hypergradient w.r.t. x_1, if available. Its norm is
                recorded in ``norm_g_1``.
            save_params: If True, also dump X_1 and X_2 via ``save_params``.
        """
        time_elapsed = time.time() - self.time_start
        grad_norm = (
            np.linalg.norm(grad_1) if grad_1 is not None else None)
        result_row = pd.DataFrame([[
            t_1, t_2,
            time_elapsed, updated_var,
            self.X_1.tolist(), grad_norm,
            np.linalg.norm(self.X_2),
            self.f_1(self.X_1, self.X_2),
            self.f_2(self.X_1, self.X_2),
            self.train_error(self.X_1, self.X_2),
            self.validation_error(self.X_1, self.X_2),
            self.test_error(self.X_1, self.X_2)]
        ], columns=self.result_cols)
        # DataFrame.append was removed in pandas 2.0, so use pd.concat.
        # Start from the row itself when the table is still empty; recent
        # pandas versions warn when concatenating empty frames.
        self.result = (
            result_row if self.result.empty
            else pd.concat([self.result, result_row]))
        if save_params:
            self.save_params(t_1)

    def save_params(self, t_1):
        """Save X_1 and X_2 to result/<data_name>/params/<time_start>_<t_1>."""
        np.savez(
            f'result/{data_name}/params/{self.time_start}_{t_1}',
            x_1=self.X_1, x_2=self.X_2)

    def mlp(self, param_vec, X, activate=np.tanh):
        """Evaluate the one-hidden-layer MLP.

        Args:
            param_vec: Flat vector ``[W_1, W_2, b]`` of length ``d_2``.
            X: 2-D input array with ``n_features`` columns.
            activate: Hidden-layer activation (default ``tanh``).

        Returns:
            Predictions of shape ``(len(X), 1)``.
        """
        # The appended column of ones acts as the hidden-layer bias.
        X_with_1 = np.hstack([X, np.ones([len(X), 1])])
        W_1 = param_vec[0:self.len_W_1].reshape(self.shape_W_1)
        W_2 = param_vec[
            self.len_W_1:self.len_W_1 + self.len_W_2].reshape(self.shape_W_2)
        b = param_vec[-1]
        out_1 = activate(X_with_1.dot(W_1))
        out_2 = out_1.dot(W_2) + b
        return out_2

    def f_1(self, x_1, x_2):
        """Upper-level objective: validation MSE of the MLP with x_2."""
        return self.validation_error(x_1, x_2)

    def f_2(self, x_1, x_2):
        """Lower-level objective: training MSE plus a penalty on x_2.

        The penalty is ``exp(x_1[0]) * pseudo_l1(x_2) / d_2``. ``pseudo_l1``
        comes from ``lib``; presumably it is a smooth L1 surrogate
        (TODO confirm).
        """
        return (
            self.train_error(x_1, x_2)
            + np.exp(x_1[0]) * pseudo_l1(x_2) / self.d_2
        )

    # The three error functions ignore x_1; it is accepted so that every
    # objective shares the (x_1, x_2) signature.
    def train_error(self, x_1, x_2):
        """Mean squared error on the training split."""
        return (
            np.linalg.norm(
                self.y_train - self.mlp(x_2, self.X_train)) ** 2
            / self.train_size
        )

    def validation_error(self, x_1, x_2):
        """Mean squared error on the validation split."""
        return (
            np.linalg.norm(
                self.y_val - self.mlp(x_2, self.X_val)) ** 2
            / self.validation_size
        )

    def test_error(self, x_1, x_2):
        """Mean squared error on the test split."""
        return (
            np.linalg.norm(
                self.y_test - self.mlp(x_2, self.X_test)) ** 2
            / self.test_size
        )

    def g_2(self, x_1, x_2):
        """Gradient of f_2 with respect to x_2."""
        return grad(self.f_2, 1)(x_1, x_2)

    def phi_2(self, x_1, x_2):
        """One gradient-descent step on x_2 with step size eta_2."""
        return x_2 - self.eta_2 * self.g_2(x_1, x_2)

    def g_1(self, x_1, x_2, t_1, t_2=0, update=True):
        """Hypergradient of f_1 w.r.t. x_1 through T_2 unrolled phi_2 steps.

        Starting from ``x_2``, the method applies ``x_2 <- phi_2(x_1, x_2)``
        ``T_2`` times. Along the way it propagates ``Z = d x_2 / d x_1`` in
        forward mode via ``Z <- A Z + B``, where ``A = d phi_2 / d x_2`` and
        ``B = d phi_2 / d x_1`` are taken at the pre-step iterate. It returns
        ``grad_{x_1} f_1 + grad_{x_2} f_1 @ Z`` at the final iterate.

        Args:
            x_1: Upper-level variable, shape ``(d_1,)``.
            x_2: Starting lower-level variable, shape ``(d_2,)``.
            t_1: Current upper-level iteration (used only for logging).
            t_2: Unused; kept so existing positional calls still work.
            update: If True, store each iterate in ``self.X_2``, log it with
                ``update_result``, and save parameters after the last step.
        """
        Z_2 = np.zeros([self.d_2, self.d_1])
        x_2_ = x_2
        for step in trange(self.T_2, desc='x_2 in g_1', leave=False):
            # Jacobians of phi_2 w.r.t. x_1 (B_2, shape (d_2, d_1)) and
            # x_2 (A_2, shape (d_2, d_2)) at the current iterate.
            B_2, A_2 = jacfwd(self.phi_2, [0, 1])(x_1, x_2_)
            x_2_ = self.phi_2(x_1, x_2_)
            if update:
                self.X_2 = x_2_
                self.update_result(
                    t_1, step + 1, 2,
                    save_params=(step == self.T_2 - 1))
            Z_2 = np.dot(A_2, Z_2) + B_2
        g = grad(self.f_1, 0)(x_1, x_2_)
        g += np.dot(grad(self.f_1, 1)(x_1, x_2_), Z_2)
        return g

    def phi_1(self, x_1, x_2, t_1, update=True):
        """One hypergradient-descent step on x_1.

        Returns:
            Tuple ``(new_x_1, hypergradient)``.
        """
        # Pass `update` by keyword: positionally it would bind to g_1's
        # `t_2` slot and be silently ignored.
        g = self.g_1(x_1, x_2, t_1, update=update)
        return x_1 - self.eta_1 * g, g

    def optimize(self, save=True, update=True):
        """Run T_1 upper-level steps, then T_2 final lower-level steps.

        Args:
            save: If True, write ``self.result`` to CSV after every x_1 step
                and once more at the end.
            update: If True, log iterates to ``self.result``. The
                hypergradient unroll also writes back to ``self.X_2``, so
                each x_1 step warm-starts from the previous lower-level
                solution. If False, nothing is logged and every unroll
                starts from the current ``self.X_2``.
        """
        for t_1 in trange(self.T_1, desc='x_1', leave=False):
            self.X_1, g = self.phi_1(
                self.X_1, self.X_2, t_1, update=update)
            if update:
                self.update_result(t_1 + 1, 0, 1, g)
            # Drop cached XLA executables and force a GC pass between outer
            # steps, presumably to keep memory from growing over long runs.
            xla._xla_callable.cache_clear()
            gc.collect()
            if save:
                self.save_result()
        # Final lower-level refinement of x_2 at the last x_1.
        for t_2 in trange(self.T_2, desc='x_2', leave=False):
            self.X_2 = self.phi_2(self.X_1, self.X_2)
            if update:
                self.update_result(self.T_1, t_2 + 1, 2)
            # Clear the XLA cache roughly ten times over this loop.
            if t_2 % max(int(self.T_2 / 10), 1) == 0:
                xla._xla_callable.cache_clear()
        if save:
            self.save_result()

    def run(
            self,
            T_1, T_2, eta_1, eta_2,
            save=True, update=True, note='pseudo_l1'):
        """Initialise and optimise; prints the run's timestamp identifier."""
        # `note` must be passed by keyword; positionally it would land in
        # `noise_weight` and never reach the log.
        self.initialize(T_1, T_2, eta_1, eta_2, note=note)
        print(self.time_start)
        self.optimize(save, update)

    def save_result(self):
        """Write ``self.result`` to result/<data_name>/<time_start>.csv."""
        self.result.to_csv(
            f'result/{data_name}/{self.time_start}.csv', index=False)

