import os
from tqdm import tqdm
import numpy as np
import jax
import jax.numpy as jnp
import pickle
jax.config.update("jax_enable_x64", True)

import logging
# Root logger: timestamp (12-hour clock via %I), originating file, then the message.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(filename)s  : %(message)s ",
    datefmt="%I:%M:%S",
)



class DatabaseJax:
    """Generate, pickle and split a 1-D dataset for the linear system ``AA @ coeff = bar_f``.

    Forcing samples are drawn on the grid ``OrthogonalSystem.xx``, projected with
    ``OrthogonalSystem.ForwardTransform_1D``, solved against ``AA`` and mapped
    back to grid values with ``OrthogonalSystem.reconstruct_1D``.

    Parameters
    ----------
    args : namespace
        Must provide ``N``, ``key`` (a JAX PRNG key), ``DB_forcingtype``,
        ``train_size``, ``test_size``, ``DB_sin_mean``, ``DB_sin_sd``,
        ``DB_cos_mean`` and ``DB_cos_sd``. ``database_load`` additionally reads
        ``DB_FOLDER``; ``database_generation_1d`` reads ``equation`` and
        ``pde_parameter``.
    AA : array, shape (N-1, N-1)
        System matrix solved for every sample.
    OrthogonalSystem : object
        Provides ``xx`` (grid points), ``ForwardTransform_1D`` and ``reconstruct_1D``.
    PL_inv : array, optional
        If given, ``RHS = PL_inv @ bar_f`` per sample; otherwise ``RHS = bar_f``.
        Presumably a preconditioner inverse -- confirm with callers.
    """

    def __init__(self, args, AA, OrthogonalSystem, PL_inv=None):
        self.args = args
        self.N = args.N
        self.AA = AA
        self.PL_inv = PL_inv
        self.OrthogonalSystem = OrthogonalSystem

        self.key = args.key
        self.DB_forcingtype = args.DB_forcingtype
        self.train_size = args.train_size
        self.test_size = args.test_size
        self.DB_sin_mean = args.DB_sin_mean
        self.DB_sin_sd = args.DB_sin_sd
        self.DB_cos_mean = args.DB_cos_mean
        self.DB_cos_sd = args.DB_cos_sd

        # float64 only takes effect when jax_enable_x64 is enabled.
        self.dtype = jnp.float64
        self.xx = OrthogonalSystem.xx

    def database_load(self, DB_ROOT):
        """Regenerate the dataset at ``DB_ROOT``, read it back and split it.

        The database is always regenerated (``DB_ROOT`` is overwritten) before
        being loaded.

        Returns
        -------
        DB : dict
            Arrays whose leading dimension is ``train_size + test_size``.
        DB_INFO : dict
            Generation metadata.
        TRAIN_DB, TEST_DB : dict
            Same keys as ``DB``; rows selected by a random permutation
            (first ``train_size`` indices for train, the rest for test).
        """
        logging.info("Generating Train Data : %s", DB_ROOT)
        os.makedirs(self.args.DB_FOLDER, exist_ok=True)
        # DB_ROOT is not guaranteed to live inside DB_FOLDER; ensure its own
        # directory exists so the pickle.dump below cannot fail on a missing dir.
        parent_dir = os.path.dirname(DB_ROOT)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        # JAX PRNG keys must not be reused: derive independent sub-keys for data
        # generation and for the train/test shuffle instead of feeding self.key
        # to both (which correlated the split with the sampled forcings).
        gen_key, perm_key = jax.random.split(self.key)
        self.database_generation_1d(SAVEPATH=DB_ROOT, key=gen_key)

        # NOTE: pickle can execute code on load; DB_ROOT must be a trusted file
        # (here it was just written by database_generation_1d).
        with open(DB_ROOT, 'rb') as f:
            database = pickle.load(f)

        DB = database['DB']
        DB_INFO = database['DB_info']

        # Shuffle row indices with JAX and split into train/test. Sizes come from
        # the same attributes forcing_generation_1d used to size the DB.
        train_size = self.train_size
        total_size = train_size + self.test_size
        indices = jax.random.permutation(perm_key, total_size)
        train_indices = indices[:train_size]
        test_indices = indices[train_size:]

        TRAIN_DB = {name: val[train_indices] for name, val in DB.items()}
        TEST_DB = {name: val[test_indices] for name, val in DB.items()}
        return DB, DB_INFO, TRAIN_DB, TEST_DB

    def forcing_generation_1d(self, key):
        """Sample ``train_size + test_size`` forcing functions on the grid ``self.xx``.

        Parameters
        ----------
        key : JAX PRNG key
            Only consumed by the ``'uniform'`` forcing type.

        Returns
        -------
        forcing : array
            Shape (B, len(xx)) for ``'uniform'``, (B, N+1) for ``'ones'``/``'zeros'``.
        forcing_parameter : array
            Shape (B, 2) for ``'uniform'``, (B, 4) for ``'ones'``/``'zeros'``.

        Raises
        ------
        NotImplementedError
            If ``DB_forcingtype`` is not ``'uniform'``, ``'ones'`` or ``'zeros'``.
        """
        forcing_type = self.DB_forcingtype
        DB_size = self.train_size + self.test_size
        NN = self.N
        dtype = self.dtype

        if forcing_type == 'uniform':
            key1, key2, key3, key4 = jax.random.split(key, 4)
            # Row vector so (B, 1) parameters broadcast to (B, len(xx)).
            xx = self.xx.reshape(1, -1)

            # jax.random.uniform samples [0, 1), so each parameter lies in
            # [low, low + width). NOTE(review): despite the names, *_mean/*_sd act
            # as lower bound/width, and the "sin" pair sets both amplitudes while
            # the "cos" pair sets both frequencies -- confirm this is intended.
            def draw(subkey, low, width):
                return low + width * jax.random.uniform(subkey, (DB_size, 1), dtype=dtype)

            h1 = draw(key1, self.DB_sin_mean, self.DB_sin_sd)
            h2 = draw(key2, self.DB_sin_mean, self.DB_sin_sd)
            m1 = draw(key3, self.DB_cos_mean, self.DB_cos_sd)
            m2 = draw(key4, self.DB_cos_mean, self.DB_cos_sd)

            forcing = h1 * jnp.sin(m1 * xx) + h2 * jnp.cos(m2 * xx)

            # Kept as (B, 2) = [h1, m1] to match existing consumers.
            # NOTE(review): h2/m2 are not recorded and the other forcing types
            # return 4 columns -- confirm what downstream code expects.
            forcing_parameter = jnp.concatenate([h1, m1], axis=-1)

        elif forcing_type == 'ones':
            forcing = jnp.ones((DB_size, NN + 1), dtype=dtype)
            # Presumably [h1, h2, m1, m2] with h2 = 1: 0*sin(0) + 1*cos(0) = 1.
            forcing_parameter = jnp.tile(jnp.array([[0., 1., 0., 0.]], dtype=dtype), (DB_size, 1))

        elif forcing_type == 'zeros':
            forcing = jnp.zeros((DB_size, NN + 1), dtype=dtype)
            forcing_parameter = jnp.zeros((DB_size, 4), dtype=dtype)

        else:
            raise NotImplementedError("Wrong Forcing Type")

        return forcing, forcing_parameter

    def database_generation_1d(self, SAVEPATH, key=None):
        """Build the full dataset, pickle it to ``SAVEPATH`` and return it.

        Parameters
        ----------
        SAVEPATH : str or path-like
            Destination of the pickled ``{'DB': ..., 'DB_info': ...}`` dict
            (overwritten if it exists).
        key : JAX PRNG key, optional
            Key used to sample the forcings; defaults to ``self.key``.

        Returns
        -------
        dict
            ``{'DB': DB, 'DB_info': DB_info}``. ``DB`` holds ``'forcing'``,
            ``'bar_f'`` (B, N-1), ``'coeff'`` (B, N-1), ``'u_true'`` (B, N+1) and
            ``'RHS'`` (``PL_inv @ bar_f`` per row when ``PL_inv`` is set, else
            ``bar_f``). ``DB_info`` records the generation settings.
        """
        if key is None:
            key = self.key
        NN = self.N
        AA = self.AA
        transform = self.OrthogonalSystem

        forcing, forcing_parameter = self.forcing_generation_1d(key)

        DB_info = {
            'xx': self.xx,
            'equation': self.args.equation,
            'pde_parameter': self.args.pde_parameter,
            'A': AA,
            'train_size': self.train_size,
            'test_size': self.test_size,
            'forcingtype': self.DB_forcingtype,
            'forcing_parameter': forcing_parameter,
            'sin_mean': self.DB_sin_mean,
            'sin_sd': self.DB_sin_sd,
            'cos_mean': self.DB_cos_mean,
            'cos_sd': self.DB_cos_sd,
        }

        # (1) Project the grid forcing; ForwardTransform_1D is expected to yield
        #     NN-1 values per sample (e.g. shape (B, NN-1, 1)).
        bar_f = transform.ForwardTransform_1D(forcing).reshape(-1, NN - 1)

        # (2) Solve AA @ coeff_b = bar_f_b for all samples at once: one
        #     multi-right-hand-side solve factorizes AA once instead of per row.
        coeff = jnp.linalg.solve(AA, bar_f.T).T  # (B, NN-1)

        # (3) Map coefficients back to grid values.
        u_true = transform.reconstruct_1D(coeff.reshape(-1, 1, NN - 1)).reshape(-1, NN + 1)

        # (4) Assemble the dataset that is saved and returned.
        DB = {
            'forcing': forcing,
            'bar_f': bar_f,
            'coeff': coeff,
            'u_true': u_true,
        }
        if self.PL_inv is not None:
            DB['RHS'] = jnp.einsum('ij,bj->bi', self.PL_inv, bar_f)
        else:
            DB['RHS'] = bar_f

        Database = {'DB': DB, 'DB_info': DB_info}

        with open(SAVEPATH, 'wb') as f:
            pickle.dump(Database, f)

        return Database
