import gc
import time
import pandas as pd
from tqdm import trange
from jax import grad, jacfwd, random, jit, partial
import jax.numpy as np
from jax.config import config
#config.update("jax_enable_x64", True)
import jax.profiler
from jax.interpreters import xla
from sklearn.datasets import load_diabetes
from sklearn.model_selection import train_test_split


class Experiment:
    """Nested gradient-descent experiment on three 2-D points x_1, x_2, x_3.

    Each x_k is updated by gradient descent on its own objective f_k:

    * f_1 = (x_1 - x_3)^T Q_1 (x_1 - x_3) + ||x_1 - p||^2
    * f_2 = (x_2 - x_1)^T Q_2 (x_2 - x_1)
    * f_3 = (x_3 - x_2)^T Q_3 (x_3 - x_2)

    The gradient used for x_1 (``g_1``) is a total derivative that is
    propagated through ``T_2`` unrolled updates of x_2 and ``T_3`` unrolled
    updates of x_3. The gradient used for x_2 (``g_2``) is likewise
    propagated through ``T_3`` unrolled updates of x_3.

    The sensitivities are accumulated in forward mode with ``jacfwd`` via the
    recursion ``Z <- A Z + B``. Here A is the Jacobian of an update w.r.t. its
    own variable and B is the Jacobian w.r.t. the upper-level variable.

    Every update appends a row to ``self.result`` (a pandas DataFrame), which
    is written to ``<RESULT_DIR>/<time_start>.csv``. Run settings are
    appended to ``<RESULT_DIR>/log.txt``.
    """

    # Directory for the per-run CSV files and the shared log file.
    RESULT_DIR = 'result/three_points'

    def __init__(self):
        # Dimensions of x_1, x_2, x_3.
        self.d_1 = 2
        self.d_2 = 2
        self.d_3 = 2
        # R_k are only used by the disabled alternative below, where
        # Q_k = R_k^T R_k normalized by its (Frobenius) norm.
        R_1 = np.array([
            [1., -0.2],
            [0.3, -4.]
        ])
        R_2 = np.array([
            [-2., 0.3],
            [-0.4, 1.]
        ])
        R_3 = np.array([
            [4., 0.2],
            [0.3, 1.]
        ])
        #self.Q_1 = R_1.T.dot(R_1)
        #self.Q_1 /= np.linalg.norm(self.Q_1)
        #self.Q_2 = R_2.T.dot(R_2)
        #self.Q_2 /= np.linalg.norm(self.Q_2)
        #self.Q_3 = R_3.T.dot(R_3)
        #self.Q_3 /= np.linalg.norm(self.Q_3)
        # Metric matrices of the quadratic distances in f_1, f_2, f_3.
        self.Q_1 = np.eye(2)
        self.Q_2 = np.eye(2)
        self.Q_3 = np.eye(2)
        # Anchor point that x_1 is pulled towards in f_1.
        self.p = np.array([0., 0.])

    def initialize(self, T_1, T_2, T_3, eta_1, eta_2, eta_3, note=''):
        """Reset the iterates, store the schedule and log the run header.

        Args:
            T_1, T_2, T_3: number of updates of x_1, x_2, x_3. T_2 and T_3
                are also the unroll depths used when differentiating through
                the lower-level updates.
            eta_1, eta_2, eta_3: gradient-descent step sizes for x_1, x_2, x_3.
            note: free-form text appended to the log entry.
        """
        self.X_1 = np.array([1., 1.])
        self.X_2 = np.array([-1., 0.])
        self.X_3 = np.array([0.5, -1])
        self.T_1, self.T_2, self.T_3 = T_1, T_2, T_3
        self.eta_1, self.eta_2, self.eta_3 = eta_1, eta_2, eta_3
        self.result_cols = [
            't_1', 't_2', 't_3',
            'elapsed_time', 'updated_var',
            'x_1', 'x_2', 'x_3',
            'f_1', 'f_2', 'f_3']
        self.result = pd.DataFrame(columns=self.result_cols)
        # Keep the precise start time for elapsed-time measurements. The
        # truncated integer is the run ID used in file names.
        self._clock_start = time.time()
        self.time_start = int(self._clock_start)
        self._ensure_result_dir()
        with open(f'{self.RESULT_DIR}/log.txt', mode='a') as log:
            log.writelines([
                f'timestamp: {self.time_start}\n',
                f'T_1: {self.T_1}, T_2: {self.T_2}, T_3: {self.T_3}\n',
                f'eta_1: {eta_1}, eta_2: {eta_2}, eta_3: {eta_3}\n',
                f'x_1: {self.X_1}, x_2: {self.X_2}, x_3: {self.X_3}\n',
                f'Q_1: {self.Q_1}\n', f'Q_2: {self.Q_2}\n', f'Q_3: {self.Q_3}\n',
                f'p: {self.p}\n',
                f'{note}\n',
                '------------------------------\n'
            ])
        self.update_result(0, 0, 0, None)

    def update_result(self, t_1, t_2, t_3, updated_var):
        """Append one row describing the current iterates to ``self.result``.

        Args:
            t_1, t_2, t_3: iteration counters recorded in the row.
            updated_var: index (1, 2 or 3) of the variable just updated, or
                None for the initial row.
        """
        # Fall back to the integer run ID if initialize() was bypassed.
        start = getattr(self, '_clock_start', self.time_start)
        time_elapsed = time.time() - start
        x_1, x_2, x_3 = self.X_1, self.X_2, self.X_3
        result_row = pd.DataFrame([[
            t_1, t_2, t_3,
            time_elapsed, updated_var,
            x_1.tolist(), x_2.tolist(), x_3.tolist(),
            # Store plain Python floats rather than 0-d JAX array objects.
            float(self.f_1(x_1, x_2, x_3)),
            float(self.f_2(x_1, x_2, x_3)),
            float(self.f_3(x_1, x_2, x_3)),
        ]], columns=self.result_cols)
        # DataFrame.append was removed in pandas 2.0, so use concat instead.
        # Concatenating onto the initial empty frame is avoided because
        # pandas deprecates dtype inference from empty entries.
        if self.result.empty:
            self.result = result_row
        else:
            self.result = pd.concat([self.result, result_row])

    def f_1(self, x_1, x_2, x_3):
        """Objective of x_1: Q_1-distance to x_3 plus squared distance to p."""
        x_13 = x_1 - x_3
        dist_13 = x_13.dot(self.Q_1).dot(x_13)
        dist_1p = np.linalg.norm(x_1 - self.p) ** 2
        return dist_13 + dist_1p

    def f_2(self, x_1, x_2, x_3):
        """Objective of x_2: Q_2-distance to x_1."""
        x_21 = x_2 - x_1
        dist_21 = x_21.dot(self.Q_2).dot(x_21)
        return dist_21

    def f_3(self, x_1, x_2, x_3):
        """Objective of x_3: Q_3-distance to x_2."""
        x_32 = x_3 - x_2
        dist_32 = x_32.dot(self.Q_3).dot(x_32)
        return dist_32

    def g_3(self, x_1, x_2, x_3):
        """Gradient of f_3 w.r.t. x_3."""
        g = jacfwd(self.f_3, 2)(x_1, x_2, x_3)
        return g

    def phi_3(self, x_1, x_2, x_3):
        """One gradient-descent step on x_3 with step size eta_3."""
        g = self.g_3(x_1, x_2, x_3)
        return x_3 - self.eta_3 * g
        # Disabled alternative: Newton step on x_3.
        #H_3 = jacfwd(jacfwd(self.f_3, 2), 2)(x_1, x_2, x_3)
        #H_3_inv = np.linalg.inv(H_3)
        #return x_3 - np.dot(H_3_inv, g)

    def g_2(self, x_1, x_2, x_3, t_1, t_2, log=True):
        """Total gradient of f_2 w.r.t. x_2 through T_3 unrolled x_3 updates.

        Runs ``T_3`` steps of ``phi_3`` starting from ``x_3`` while
        accumulating ``Z_3 = d x_3 / d x_2`` (shape ``(d_3, d_2)``). It then
        returns ``df_2/dx_2 + df_2/dx_3 @ Z_3`` evaluated at the final x_3.

        Side effects: ``self.X_3`` is overwritten with each new x_3 iterate,
        and ``g_1`` reads it back afterwards. When ``log`` is true, a result
        row is appended after every x_3 step.
        """
        Z_3 = np.zeros([self.d_3, self.d_2])
        x_3_ = x_3
        for t_3 in trange(self.T_3, desc='x_3 in g_2', leave=False):
            # Jacobians of phi_3 w.r.t. x_2 (B_3) and x_3 (A_3).
            J = jacfwd(self.phi_3, [1, 2])(x_1, x_2, x_3_)
            x_3_ = self.phi_3(x_1, x_2, x_3_)
            # NOTE(review): when g_2 runs under jacfwd (as in g_1), this
            # stores a tracer in self.X_3. g_1 overwrites it with a concrete
            # value by calling phi_2 again right afterwards.
            self.X_3 = x_3_
            if log:
                self.update_result(t_1, t_2, t_3+1, 3)
            B_3 = J[0]
            A_3 = J[1]
            # Forward-mode chain rule through one x_3 update.
            Z_3 = np.dot(A_3, Z_3) + B_3
        g = jacfwd(self.f_2, 1)(x_1, x_2, x_3_)
        g += np.dot(jacfwd(self.f_2, 2)(x_1, x_2, x_3_), Z_3)
        return g

    def phi_2(self, x_1, x_2, x_3, t_1, t_2, log=True):
        """One gradient-descent step on x_2 using the total gradient g_2."""
        g = self.g_2(x_1, x_2, x_3, t_1, t_2, log)
        return x_2 - self.eta_2 * g

    def g_1(self, x_1, x_2, x_3, t_1, log=True):
        """Total gradient of f_1 w.r.t. x_1.

        Stage k=2 runs ``T_2`` steps of ``phi_2``, each of which itself runs
        ``T_3`` x_3 steps, while accumulating ``Z_2 = d x_2 / d x_1``.

        Stage k=3 then runs ``T_3`` steps of ``phi_3`` from the resulting x_2
        and x_3, accumulating ``Z_3 = d x_3 / d x_1``. This includes the
        indirect path through x_2 (``C_32 @ Z_2``).

        Returns ``df_1/dx_1 + df_1/dx_2 @ Z_2 + df_1/dx_3 @ Z_3``.

        Side effects: overwrites ``self.X_2`` and ``self.X_3`` with the
        iterates, appends result rows when ``log`` is true, and periodically
        clears JAX's compilation cache.
        """
        # k = 2
        Z_2 = np.zeros([self.d_2, self.d_1])
        x_2_ = x_2
        x_3_ = x_3
        clear_every = max(self.T_2 // 10, 1)
        for t_2 in trange(self.T_2, desc='x_2 in g_1', leave=False):
            # Jacobians of phi_2 w.r.t. x_1 (B_2) and x_2 (A_2). The traced
            # call runs without logging. The concrete call below repeats the
            # same computation to obtain the new x_2.
            # NOTE(review): x_3_ is not a differentiated argument here, so
            # its dependence on x_1 via earlier iterations is not in Z_2.
            J = jacfwd(self.phi_2, [0, 1])(
                x_1, x_2_, x_3_, t_1, t_2, log=False)
            x_2_ = self.phi_2(x_1, x_2_, x_3_, t_1, t_2, log=log)
            self.X_2 = x_2_
            if log:
                self.update_result(t_1, t_2+1, 0, 2)
            B_2 = J[0]
            A_2 = J[1]
            Z_2 = np.dot(A_2, Z_2) + B_2
            # phi_2 advanced x_3 as a side effect (via g_2); continue from it.
            x_3_ = self.X_3
            if t_2 % clear_every == 0:
                self._clear_jit_cache()
        # k = 3
        Z_3 = np.zeros([self.d_3, self.d_1])
        for t_3 in trange(self.T_3, desc='x_3 in g_1', leave=False):
            # Jacobians of phi_3 w.r.t. x_1 (B_3), x_2 (C_32) and x_3 (A_3).
            J = jacfwd(self.phi_3, [0, 1, 2])(x_1, x_2_, x_3_)
            x_3_ = self.phi_3(x_1, x_2_, x_3_)
            self.X_3 = x_3_
            if log:
                self.update_result(t_1, self.T_2, t_3+1, 3)
            C_32 = J[1]
            A_3 = J[2]
            # x_2 depends on x_1 through Z_2, so add that indirect path.
            B_3 = J[0] + np.dot(C_32, Z_2)
            Z_3 = np.dot(A_3, Z_3) + B_3
        g = jacfwd(self.f_1, 0)(x_1, x_2_, x_3_)
        g += np.dot(jacfwd(self.f_1, 1)(x_1, x_2_, x_3_), Z_2)
        g += np.dot(jacfwd(self.f_1, 2)(x_1, x_2_, x_3_), Z_3)
        return g

    def phi_1(self, x_1, x_2, x_3, t_1, log=True, print_grad=False):
        """One gradient-descent step on x_1 using the total gradient g_1."""
        g = self.g_1(x_1, x_2, x_3, t_1, log)
        if print_grad:
            print(f"t_1: {t_1}, gradient of f_1 w.r.t. x_1: {g}")
        return x_1 - self.eta_1 * g

    def optimize(self, save=True, log=True):
        """Run T_1 x_1 updates, then T_2 x_2 updates, then T_3 x_3 updates.

        Each x_1 step goes through g_1, which also advances ``self.X_2`` and
        ``self.X_3``. The next outer step therefore continues from the
        latest lower-level iterates. Results are saved every 10 outer steps
        and once at the end when ``save`` is true.
        """
        for t_1 in trange(self.T_1, desc='x_1', leave=False):
            self.X_1 = self.phi_1(self.X_1, self.X_2, self.X_3, t_1, log=log)
            if log:
                self.update_result(t_1+1, 0, 0, 1)
            self._clear_jit_cache()
            gc.collect()
            if save and t_1 % 10 == 0:
                self.save_result()
        clear_every = max(self.T_2 // 10, 1)
        for t_2 in trange(self.T_2, desc='x_2', leave=False):
            self.X_2 = self.phi_2(
                self.X_1, self.X_2, self.X_3, self.T_1, t_2, log=log)
            if log:
                self.update_result(self.T_1, t_2+1, 0, 2)
            if t_2 % clear_every == 0:
                self._clear_jit_cache()
        for t_3 in trange(self.T_3, desc='x_3', leave=False):
            self.X_3 = self.phi_3(self.X_1, self.X_2, self.X_3)
            if log:
                self.update_result(self.T_1, self.T_2, t_3+1, 3)
        if save:
            self.save_result()

    def run(
            self,
            T_1, T_2, T_3, eta_1, eta_2, eta_3,
            save=True, log=True, note=''):
        """Initialize a run, print its ID (time_start) and optimize."""
        self.initialize(T_1, T_2, T_3, eta_1, eta_2, eta_3, note)
        print(self.time_start)
        self.optimize(save, log)
        #jax.profiler.save_device_memory_profile("memory.prof")

    def save_result(self):
        """Write ``self.result`` to ``<RESULT_DIR>/<time_start>.csv``."""
        self._ensure_result_dir()
        self.result.to_csv(
            f'{self.RESULT_DIR}/{self.time_start}.csv', index=False)

    def _ensure_result_dir(self):
        """Create ``RESULT_DIR`` (and parents) if it does not exist yet."""
        import os
        os.makedirs(self.RESULT_DIR, exist_ok=True)

    @staticmethod
    def _clear_jit_cache():
        """Clear JAX's compiled-function cache.

        ``xla._xla_callable`` is a private JAX API that is absent from newer
        releases. In that case, fall back to the public ``jax.clear_caches()``.
        """
        cache = getattr(xla, '_xla_callable', None)
        if cache is not None:
            cache.cache_clear()
        else:
            jax.clear_caches()
