import os
from tqdm import tqdm
import numpy as np
import jax
import jax.numpy as jnp
import pickle
# Turn on 64-bit precision globally. DatabaseJax requests jnp.float64 arrays,
# which JAX would otherwise truncate to float32 (emitting a warning).
jax.config.update("jax_enable_x64", True)

import logging

# Root-logger configuration for the logging.info calls below.
# datefmt uses a 12-hour clock (%I) without an AM/PM marker.
logging.basicConfig(
    format="%(asctime)s - %(filename)s  : %(message)s ",
    level=logging.INFO,
    datefmt="%I:%M:%S",
)


class DatabaseJax:
    """Generate, persist and split a batched 1D PDE database with JAX.

    For each of ``train_size + test_size`` samples, a forcing term and a scalar
    parameter ``p`` are drawn. The forcing is projected with
    ``OrthogonalSystem.ForwardTransform_1D`` to ``bar_f``, and the linear
    system ``(B_matrix + p * C_matrix) @ coeff = bar_f`` is solved. The
    solution is mapped back to grid values with
    ``OrthogonalSystem.reconstruct_1D``.

    Args:
        args: configuration namespace. Attributes read: ``N``, ``key``,
            ``DB_forcingtype``, ``train_size``, ``test_size``,
            ``DB_sin_mean``, ``DB_sin_sd``, ``DB_cos_mean``, ``DB_cos_sd``,
            ``equation``, ``pde_parameter`` and, in ``database_load``,
            ``DB_FOLDER``.
        OrthogonalSystem: object providing ``xx`` (grid nodes),
            ``ForwardTransform_1D`` and ``reconstruct_1D``.
        B_matrix, C_matrix: square matrices combined as ``B + p * C``. They
            are solved against rows of ``bar_f`` of length ``N - 1``, so
            they must be ``(N - 1, N - 1)``.

    Raises:
        NotImplementedError: if ``args.equation`` is neither ``'RD_1D'`` nor
            ``'Helmholtz_1D'``.
    """

    def __init__(self, args, OrthogonalSystem, B_matrix, C_matrix):
        self.args = args
        self.N = args.N
        self.OrthogonalSystem = OrthogonalSystem

        self.B_matrix = B_matrix
        self.C_matrix = C_matrix

        # NOTE(review): this one key is used three times: for pde_param below,
        # for the forcing draws (forcing_generation_1d splits it) and for the
        # train/test shuffle (database_load). JAX advises never reusing a
        # key. It is left unchanged because deriving separate subkeys would
        # change every previously generated database.
        self.key = args.key
        self.DB_forcingtype = args.DB_forcingtype

        self.train_size = args.train_size
        self.test_size = args.test_size
        self.DB_sin_mean = args.DB_sin_mean
        self.DB_sin_sd = args.DB_sin_sd
        self.DB_cos_mean = args.DB_cos_mean
        self.DB_cos_sd = args.DB_cos_sd

        # float64 is only honoured when jax_enable_x64 is enabled.
        self.dtype = jnp.float64
        self.xx = OrthogonalSystem.xx

        DB_size = self.train_size + self.test_size
        # The system matrix is always assembled as A = B_matrix + p * C_matrix
        # (see database_generation_1d). The author's operator notes:
        #   RD_1D:        A = -epsilon * S + M  (pde_parameter = 0.05 by default)
        #   Helmholtz_1D: A = S + k^2 * M
        # Both draw one p per sample from args.pde_parameter + U[0, 1).
        if args.equation in ('RD_1D', 'Helmholtz_1D'):
            self.pde_param = args.pde_parameter + jax.random.uniform(
                self.key, (DB_size, 1), dtype=self.dtype
            )
        else:
            raise NotImplementedError('Wrong Equation')

    def database_load(self, DB_ROOT):
        """Regenerate the database at ``DB_ROOT``, read it back and split it.

        The database is always regenerated, overwriting ``DB_ROOT``. It is
        then unpickled from disk, and its samples are shuffled with
        ``jax.random.permutation(self.key, ...)``.

        Args:
            DB_ROOT: path of the pickle file to write and read.

        Returns:
            ``(DB, DB_INFO, TRAIN_DB, TEST_DB)``:

            - ``DB`` holds all samples.
            - ``DB_INFO`` holds the generation metadata.
            - ``TRAIN_DB`` and ``TEST_DB`` have the same keys as ``DB``.
              They take the first ``train_size`` and the remaining
              ``test_size`` shuffled indices, respectively.
        """
        args = self.args

        logging.info(f"Generating Train Data : {DB_ROOT}")
        os.makedirs(args.DB_FOLDER, exist_ok=True)
        # DB_ROOT need not live inside DB_FOLDER, so also ensure its parent exists.
        parent_dir = os.path.dirname(DB_ROOT)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        self.database_generation_1d(SAVEPATH=DB_ROOT)

        # This process just wrote the file, so unpickling it is not a trust issue.
        with open(DB_ROOT, 'rb') as f:
            database = pickle.load(f)

        DB = database['DB']
        DB_INFO = database['DB_info']

        train_size = self.train_size
        total_size = train_size + self.test_size

        # Shuffle sample indices with jax.random (same key as generation) and split.
        indices = jax.random.permutation(self.key, total_size)
        train_indices = indices[:train_size]
        test_indices = indices[train_size:]

        TRAIN_DB = {name: val[train_indices] for name, val in DB.items()}
        TEST_DB = {name: val[test_indices] for name, val in DB.items()}
        return DB, DB_INFO, TRAIN_DB, TEST_DB

    def forcing_generation_1d(self, key):
        """Draw one forcing term per database sample on the grid ``self.xx``.

        Forcing types (``self.DB_forcingtype``):

        - ``'uniform'``: ``f(x) = h1 * sin(m1 * x) + h2 * cos(m2 * x)``, with
          ``h1, h2 = DB_sin_mean + DB_sin_sd * U[0, 1)`` and
          ``m1, m2 = DB_cos_mean + DB_cos_sd * U[0, 1)``. The "mean"/"sd"
          settings therefore act as the lower bound and width of a uniform
          range, not as a mean and a standard deviation.
        - ``'ones'``: ``f(x) = 1``.
        - ``'zeros'``: ``f(x) = 0``.

        Args:
            key: JAX PRNG key. It is split into four subkeys for h1, h2, m1, m2.

        Returns:
            ``(forcing, forcing_parameter)``. Let ``B = train_size + test_size``.
            ``forcing`` is ``(B, len(self.xx))`` for ``'uniform'`` and
            ``(B, N + 1)`` otherwise; presumably ``len(self.xx) == N + 1``.
            ``forcing_parameter`` is ``(B, 2)`` for ``'uniform'`` and
            ``(B, 4)`` otherwise.

        Raises:
            NotImplementedError: for any other forcing type.
        """
        DB_forcingtype = self.DB_forcingtype
        DB_size = self.train_size + self.test_size
        NN = self.N
        dtype = self.dtype

        # Row vector (1, len(self.xx)) so it broadcasts against (DB_size, 1).
        xx = self.xx.reshape(1, -1)

        key1, key2, key3, key4 = jax.random.split(key, 4)

        if DB_forcingtype == 'uniform':
            h1 = self.DB_sin_mean + self.DB_sin_sd * jax.random.uniform(key1, (DB_size, 1), dtype=dtype)
            h2 = self.DB_sin_mean + self.DB_sin_sd * jax.random.uniform(key2, (DB_size, 1), dtype=dtype)
            m1 = self.DB_cos_mean + self.DB_cos_sd * jax.random.uniform(key3, (DB_size, 1), dtype=dtype)
            m2 = self.DB_cos_mean + self.DB_cos_sd * jax.random.uniform(key4, (DB_size, 1), dtype=dtype)

            # (DB_size, 1) * (1, len(xx)) -> (DB_size, len(xx))
            forcing = h1 * jnp.sin(m1 * xx) + h2 * jnp.cos(m2 * xx)

            # NOTE(review): only [h1, m1] are recorded (2 columns); h2 and m2
            # are lost. The 'ones' row [0, 1, 0, 0] below reproduces f = 1 only
            # under a [h1, h2, m1, m2] layout, which this does not match. The
            # 2-column shape is kept because callers may rely on it; confirm
            # the intended layout.
            forcing_parameter = jnp.concatenate([h1, m1], axis=-1)

        elif DB_forcingtype == 'ones':
            forcing = jnp.ones((DB_size, NN + 1), dtype=dtype)
            forcing_parameter = jnp.tile(jnp.array([[0., 1., 0., 0.]], dtype=dtype), (DB_size, 1))

        elif DB_forcingtype == 'zeros':
            forcing = jnp.zeros((DB_size, NN + 1), dtype=dtype)
            forcing_parameter = jnp.zeros((DB_size, 4), dtype=dtype)

        else:
            raise NotImplementedError("Wrong Forcing Type")

        return forcing, forcing_parameter

    def database_generation_1d(self, SAVEPATH):
        """Build the full database, pickle it to ``SAVEPATH`` and return it.

        Args:
            SAVEPATH: path of the pickle file to (over)write.

        Returns:
            ``{'DB': DB, 'DB_info': DB_info}``.

            - ``DB`` maps ``'forcing'``, ``'bar_f'``, ``'coeff'``, ``'RHS'``
              (the same array as ``'bar_f'``), ``'pde_param'``,
              ``'A_matrix'`` and ``'u_true'`` to arrays with a leading
              sample axis.
            - ``DB_info`` holds the grid, the B/C matrices and the generation
              settings.
        """
        NN = self.N
        OrthogonalSystem = self.OrthogonalSystem
        B_matrix = self.B_matrix
        C_matrix = self.C_matrix
        pde_param = self.pde_param

        forcing, forcing_parameter = self.forcing_generation_1d(self.key)

        # (1) Project the forcing. The transform is expected to return
        #     (B, NN-1, 1); it is flattened here to (B, NN-1).
        bar_f = OrthogonalSystem.ForwardTransform_1D(forcing).reshape(-1, NN - 1)

        # (2) Solve A_i @ coeff_i = bar_f_i with A_i = B + p_i * C, batched by vmap.
        def solve_single(p, rhs):
            A = B_matrix + p * C_matrix
            return A, jnp.linalg.solve(A, rhs)

        # A_matrix: (B, NN-1, NN-1); coeff: (B, NN-1)
        A_matrix, coeff = jax.vmap(solve_single, in_axes=(0, 0))(pde_param.squeeze(-1), bar_f)

        # (3) Map coefficients back to grid values: (B, 1, NN-1) -> (B, NN+1).
        u_true = OrthogonalSystem.reconstruct_1D(coeff.reshape(-1, 1, NN - 1))
        u_true = u_true.reshape(-1, NN + 1)

        DB = {
            'forcing': forcing,
            'bar_f': bar_f,
            'coeff': coeff,
            'RHS': bar_f,
            'pde_param': pde_param,
            'A_matrix': A_matrix,
            'u_true': u_true,
        }

        DB_info = {
            'x_1d': self.xx,
            'B_matrix': B_matrix,
            'C_matrix': C_matrix,
            'equation': self.args.equation,
            'train_size': self.train_size,
            'test_size': self.test_size,
            'forcingtype': self.DB_forcingtype,
            'forcing_parameter': forcing_parameter,
            'sin_mean': self.DB_sin_mean,
            'sin_sd': self.DB_sin_sd,
            'cos_mean': self.DB_cos_mean,
            'cos_sd': self.DB_cos_sd,
        }

        Database = {'DB': DB, 'DB_info': DB_info}

        with open(SAVEPATH, 'wb') as f:
            pickle.dump(Database, f)

        return Database
