import os
from tqdm import tqdm
import numpy as np
import jax
import jax.numpy as jnp
import pickle
# Enable 64-bit types in JAX at import time; DatabaseJax requests jnp.float64
# for its sampled parameters and constant forcings.
jax.config.update("jax_enable_x64", True)

import logging
# Root logger setup: INFO level, 12-hour "%I:%M:%S" timestamps (no AM/PM)
# followed by the emitting file name and the message.
logging.basicConfig(
    format="%(asctime)s - %(filename)s  : %(message)s ",
    level=logging.INFO,
    datefmt="%I:%M:%S",
)



class DatabaseJax:
    """Generate, save and split a synthetic forcing/solution database with JAX.

    For every sample a random forcing term is drawn (``forcing_generation_1d``
    / ``forcing_generation_2d``). It is then processed in three steps:

    1. Map it to a transformed right-hand side ``bar_f`` with
       ``OrthogonalSystem.ForwardTransform``.
    2. Solve ``AA @ coeff = bar_f`` sample by sample.
    3. Map the coefficients back with ``OrthogonalSystem.reconstruct``.

    Args:
        args: Configuration namespace. Read here: ``N``, ``key``,
            ``DB_forcingtype``, ``train_size``, ``test_size``,
            ``DB_sin_mean``, ``DB_sin_sd``, ``DB_cos_mean``, ``DB_cos_sd``.
            Read by other methods: ``equation``, ``pde_parameter``,
            ``dimension``, ``DB_FOLDER``.
        AA: Matrix passed to ``jnp.linalg.solve`` for every sample.
        OrthogonalSystem: Object providing ``xx``, ``X_2D``, ``Y_2D``,
            ``ForwardTransform`` and ``reconstruct``.
        PL_inv: Optional matrix. When given, the 2D database stores
            ``PL_inv @ bar_f`` per sample under ``'RHS'``; otherwise ``'RHS'``
            is ``bar_f`` itself.
    """

    def __init__(self, args, AA, OrthogonalSystem, PL_inv=None):
        self.args = args
        self.N = args.N
        self.AA = AA
        self.PL_inv = PL_inv
        self.OrthogonalSystem = OrthogonalSystem

        self.key = args.key
        self.DB_forcingtype = args.DB_forcingtype
        self.train_size = args.train_size
        self.test_size = args.test_size
        self.DB_sin_mean = args.DB_sin_mean
        self.DB_sin_sd = args.DB_sin_sd
        self.DB_cos_mean = args.DB_cos_mean
        self.DB_cos_sd = args.DB_cos_sd

        # Requested dtype for sampled parameters and constant forcings
        # (float64 needs jax_enable_x64).
        self.dtype = jnp.float64

        self.xx = OrthogonalSystem.xx
        self.X_2D, self.Y_2D = OrthogonalSystem.X_2D, OrthogonalSystem.Y_2D

    # ------------------------------------------------------------------ #
    # Loading / splitting
    # ------------------------------------------------------------------ #
    def database_load(self, DB_ROOT):
        """Generate a fresh database, save it to ``DB_ROOT`` and split it.

        The database is always regenerated; an existing file at ``DB_ROOT``
        is overwritten. Samples are shuffled with ``jax.random.permutation``.
        The first ``train_size`` shuffled indices form the training set and
        the remaining ``test_size`` form the test set.

        Args:
            DB_ROOT: Path of the pickle file to write.

        Returns:
            ``(DB, DB_INFO, TRAIN_DB, TEST_DB)``:

            - ``DB``: the full array dictionary.
            - ``DB_INFO``: the metadata dictionary.
            - ``TRAIN_DB`` / ``TEST_DB``: the same keys as ``DB``, indexed
              along the first axis.

        Raises:
            NotImplementedError: If ``args.dimension`` is neither '1D' nor '2D'.
        """
        args = self.args

        logging.info(f"Generating Train Data : {DB_ROOT}")
        os.makedirs(args.DB_FOLDER, exist_ok=True)

        if args.dimension == '1D':
            print('Generate - 1D Data')
            database_generation = self.database_generation_1d
        elif args.dimension == '2D':
            print('Generate - 2D Data')
            database_generation = self.database_generation_2d
        else:
            raise NotImplementedError('Wrong Dim')

        # The generator writes the pickle AND returns the same dictionary, so
        # use it directly instead of reading the file straight back.
        database = database_generation(SAVEPATH=DB_ROOT)
        DB = database['DB']
        DB_INFO = database['DB_info']

        # Shuffle sample indices with JAX and split them into train / test.
        # NOTE(review): self.key is also the parent key that the forcing
        # generators split; JAX recommends never reusing a key, but it is kept
        # here so previously generated train/test splits stay reproducible.
        total_size = self.train_size + self.test_size
        indices = jax.random.permutation(self.key, total_size)
        train_indices = indices[:self.train_size]
        test_indices = indices[self.train_size:]

        TRAIN_DB = {name: val[train_indices] for name, val in DB.items()}
        TEST_DB = {name: val[test_indices] for name, val in DB.items()}
        return DB, DB_INFO, TRAIN_DB, TEST_DB

    # ------------------------------------------------------------------ #
    # Forcing generation
    # ------------------------------------------------------------------ #
    def _sample_uniform_parameters(self, key, shape):
        """Draw ``(h1, h2, m1, m2)`` for the 'uniform' forcing, each of ``shape``.

        ``h1``/``h2`` are ``DB_sin_mean + DB_sin_sd * U[0, 1)`` and ``m1``/``m2``
        are ``DB_cos_mean + DB_cos_sd * U[0, 1)``. Each is drawn from its own
        sub-key of ``key`` (sub-keys 1..4 -> h1, h2, m1, m2).
        """
        # JAX randomness is driven by explicit keys, so split one per draw.
        key1, key2, key3, key4 = jax.random.split(key, 4)
        dtype = self.dtype

        def draw(subkey, mean, sd):
            return mean + sd * jax.random.uniform(subkey, shape, dtype=dtype)

        h1 = draw(key1, self.DB_sin_mean, self.DB_sin_sd)
        h2 = draw(key2, self.DB_sin_mean, self.DB_sin_sd)
        m1 = draw(key3, self.DB_cos_mean, self.DB_cos_sd)
        m2 = draw(key4, self.DB_cos_mean, self.DB_cos_sd)
        return h1, h2, m1, m2

    def forcing_generation_1d(self, key):
        """Sample ``train_size + test_size`` 1D forcing vectors.

        Forcing types:
            'uniform': ``h1 * sin(m1 * x) + h2 * cos(m2 * x)`` on ``self.xx``,
                with random parameters.
            'ones' / 'zeros': constant forcing of width ``N + 1``.

        Args:
            key: JAX PRNG key.

        Returns:
            ``(forcing, forcing_parameter)``:

            - ``forcing``: shape ``(DB_size, len(xx))`` for 'uniform' and
              ``(DB_size, N + 1)`` otherwise.
            - ``forcing_parameter``: shape ``(DB_size, 4)``, laid out as
              ``[h1, m1, h2, m2]`` for 'uniform'.

        Raises:
            NotImplementedError: For an unknown ``DB_forcingtype``.
        """
        DB_size = self.train_size + self.test_size
        NN = self.N
        dtype = self.dtype
        forcingtype = self.DB_forcingtype

        if forcingtype == 'uniform':
            h1, h2, m1, m2 = self._sample_uniform_parameters(key, (DB_size, 1))
            # Row vector so the (B, 1) parameters broadcast to (B, len(xx)).
            xx = self.xx.reshape(1, -1)
            forcing = h1 * jnp.sin(m1 * xx) + h2 * jnp.cos(m2 * xx)
            # Keep all four parameters. Storing only [h1, m1] lost h2/m2 and
            # gave a different width than the constant forcing types.
            forcing_parameter = jnp.concatenate([h1, m1, h2, m2], axis=-1)

        elif forcingtype == 'ones':
            forcing = jnp.ones((DB_size, NN + 1), dtype=dtype)
            forcing_parameter = jnp.tile(jnp.array([[0., 1., 0., 0.]], dtype=dtype), (DB_size, 1))

        elif forcingtype == 'zeros':
            forcing = jnp.zeros((DB_size, NN + 1), dtype=dtype)
            forcing_parameter = jnp.zeros((DB_size, 4), dtype=dtype)

        else:
            raise NotImplementedError("Wrong Forcing Type")

        return forcing, forcing_parameter

    def forcing_generation_2d(self, key):
        """Sample ``train_size + test_size`` 2D forcing fields, flattened.

        Forcing types:
            'uniform': ``h1 * sin(m1 * (X + Y)) + h2 * cos(m2 * (X + Y))`` on
                ``self.X_2D``/``self.Y_2D``, with random parameters.
            'ones' / 'zeros': constant forcing.

        Args:
            key: JAX PRNG key.

        Returns:
            ``(forcing, forcing_parameter)``:

            - ``forcing``: reshaped to ``(DB_size, (N + 1) ** 2)``.
            - ``forcing_parameter``: shape ``(DB_size, 4)``, laid out as
              ``[h1, m1, h2, m2]`` for 'uniform'.

        Raises:
            NotImplementedError: For an unknown ``DB_forcingtype``.
        """
        DB_size = self.train_size + self.test_size
        NN = self.N
        dtype = self.dtype
        forcingtype = self.DB_forcingtype

        if forcingtype == 'uniform':
            h1, h2, m1, m2 = self._sample_uniform_parameters(key, (DB_size, 1, 1))
            # (1, H, W) grid; broadcasting with (B, 1, 1) gives (B, H, W).
            grid = (self.X_2D + self.Y_2D)[None, :, :]
            forcing = h1 * jnp.sin(m1 * grid) + h2 * jnp.cos(m2 * grid)
            # Concatenating (B, 1, 1) arrays yields (B, 1, 4); flatten to (B, 4)
            # so the shape matches the 'ones'/'zeros' branches.
            forcing_parameter = jnp.concatenate([h1, m1, h2, m2], axis=-1).reshape(DB_size, 4)

        elif forcingtype == 'ones':
            forcing = jnp.ones((DB_size, NN + 1, NN + 1), dtype=dtype)
            forcing_parameter = jnp.tile(jnp.array([[0., 1., 0., 0.]], dtype=dtype), (DB_size, 1))

        elif forcingtype == 'zeros':
            forcing = jnp.zeros((DB_size, NN + 1, NN + 1), dtype=dtype)
            forcing_parameter = jnp.zeros((DB_size, 4), dtype=dtype)

        else:
            raise NotImplementedError("Wrong Forcing Type")

        forcing = forcing.reshape(DB_size, (NN + 1) ** 2)
        return forcing, forcing_parameter

    # ------------------------------------------------------------------ #
    # Database generation
    # ------------------------------------------------------------------ #
    def _common_db_info(self, forcing_parameter):
        """Return the metadata entries shared by the 1D and 2D databases."""
        args = self.args
        return {
            'equation': args.equation,
            'pde_parameter': args.pde_parameter,
            'A': self.AA,
            'train_size': self.train_size,
            'test_size': self.test_size,
            'forcingtype': self.DB_forcingtype,
            'forcing_parameter': forcing_parameter,
            'sin_mean': self.DB_sin_mean,
            'sin_sd': self.DB_sin_sd,
            'cos_mean': self.DB_cos_mean,
            'cos_sd': self.DB_cos_sd,
        }

    def _solve_batched(self, bar_f):
        """Solve ``AA @ x = b`` for every sample ``b`` along the first axis of ``bar_f``."""
        AA = self.AA
        return jax.vmap(lambda b: jnp.linalg.solve(AA, b))(bar_f)

    @staticmethod
    def _save_database(database, SAVEPATH):
        """Pickle ``database`` to ``SAVEPATH``; do nothing when ``SAVEPATH`` is None."""
        if SAVEPATH is None:
            return
        with open(SAVEPATH, 'wb') as f:
            pickle.dump(database, f)

    def database_generation_1d(self, SAVEPATH=None):
        """Build the 1D database and optionally pickle it to ``SAVEPATH``.

        Args:
            SAVEPATH: Output pickle path, or None to skip writing.

        Returns:
            ``{'DB': {...}, 'DB_info': {...}}``:

            - ``DB``: holds ``'forcing'``, ``'bar_f'``, ``'coeff'`` and
              ``'u_true'``.
            - ``DB_info``: holds the grid ``'xx'`` plus the generation
              settings.

            NOTE(review): the exact shapes of ``bar_f``/``coeff``/``u_true``
            depend on ``OrthogonalSystem``. Earlier comments suggested
            ``bar_f`` is ``(B, N-1, 1)`` and ``u_true`` is ``(B, 1, N+1)``;
            TODO confirm.
        """
        OrthogonalSystem = self.OrthogonalSystem

        forcing, forcing_parameter = self.forcing_generation_1d(self.key)
        DB_info = {'xx': self.xx, **self._common_db_info(forcing_parameter)}

        # (1) Forcing -> transformed right-hand side.
        bar_f = OrthogonalSystem.ForwardTransform(forcing)
        # (2) Solve AA @ coeff = bar_f independently for each sample.
        coeff = self._solve_batched(bar_f)
        # (3) Coefficients -> reference solution.
        u_true = OrthogonalSystem.reconstruct(coeff)

        # (4) Assemble, save and return.
        DB = {
            'forcing': forcing,
            'bar_f': bar_f,
            'coeff': coeff,
            'u_true': u_true,
        }
        Database = {'DB': DB, 'DB_info': DB_info}
        self._save_database(Database, SAVEPATH)
        return Database

    def database_generation_2d(self, SAVEPATH=None):
        """Build the 2D database and optionally pickle it to ``SAVEPATH``.

        Args:
            SAVEPATH: Output pickle path, or None to skip writing.

        Returns:
            ``{'DB': {...}, 'DB_info': {...}}``. ``DB`` holds:

            - ``'forcing'``: ``(B, (N+1)**2)``
            - ``'bar_f'``: ``(B, (N-1)**2)``
            - ``'coeff'``: ``(B, N-1, N-1)``
            - ``'u_true'``: ``(B, N+1, N+1)``
            - ``'RHS'``: ``PL_inv @ bar_f`` per sample, or ``bar_f`` when
              ``PL_inv`` is None.

            ``DB_info`` holds ``'X_2D'``/``'Y_2D'`` plus the generation
            settings.
        """
        DB_size = self.train_size + self.test_size
        NN = self.N
        OrthogonalSystem = self.OrthogonalSystem

        forcing, forcing_parameter = self.forcing_generation_2d(self.key)
        DB_info = {'X_2D': self.X_2D, 'Y_2D': self.Y_2D, **self._common_db_info(forcing_parameter)}

        # (1) Forcing -> transformed right-hand side, flattened per sample.
        bar_f = OrthogonalSystem.ForwardTransform(forcing).reshape(DB_size, (NN - 1) ** 2)
        # (2) Solve AA @ coeff = bar_f (each b has length (N-1)**2).
        coeff = self._solve_batched(bar_f)
        # (3) Coefficients -> reference solution on the (N+1) x (N+1) grid.
        u_true = OrthogonalSystem.reconstruct(coeff).reshape(DB_size, NN + 1, NN + 1)

        # (4) Assemble, save and return.
        DB = {
            'forcing': forcing,
            'bar_f': bar_f,
            'coeff': coeff.reshape(DB_size, NN - 1, NN - 1),
            'u_true': u_true,
        }
        if self.PL_inv is not None:
            DB['RHS'] = jnp.einsum('ij,bj->bi', self.PL_inv, bar_f)
        else:
            DB['RHS'] = bar_f

        Database = {'DB': DB, 'DB_info': DB_info}
        self._save_database(Database, SAVEPATH)
        return Database
