from typing import List, Optional

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
from jax import grad, jit
from jax.example_libraries import optimizers
from scipy.optimize import minimize, LinearConstraint
from tqdm import tqdm

from ..utils import NDArray, solve
# wandb is an optional dependency used only for experiment logging.
# Both names are always bound after this block so callers can test
# WANDB_AVAILABLE (or `wandb is None`) instead of hitting a NameError.
try:
    import wandb
except ImportError:
    wandb = None
    WANDB_AVAILABLE = False
else:
    WANDB_AVAILABLE = True
    
import os, sys
if not __package__:
    # Make CLI runnable from source tree with
    #    python src/package
    # The grandparent directory of this file is pushed to the front of
    # sys.path so that the absolute `pi_lr...` import below can resolve.
    # NOTE(review): the relative import `from ..utils import ...` near the top
    # of this file executes before this guard and needs a parent package, so
    # when __package__ is empty that import has presumably already raised and
    # this branch looks unreachable -- confirm and reorder if so.
    package_source_path = os.path.dirname(os.path.dirname(__file__))
    sys.path.insert(0, package_source_path)
    
from pi_lr.data.physics import *


def get_optimizer(opt_name: str, learning_rate: float, scheduler_name: Optional[str] = None, steps: Optional[int]=None):
    """Build an optax optimizer, optionally driven by a learning-rate schedule.

    Args:
        opt_name: ``"adam"`` or ``"sgd"``.
        learning_rate: constant learning rate, or the initial value of the
            schedule when ``scheduler_name`` is given.
        scheduler_name: ``None`` for a constant rate, ``"cosine"`` for
            ``optax.schedules.cosine_decay_schedule`` or ``"exponential"`` for
            ``optax.exponential_decay``.
        steps: number of optimizer updates the schedule spans. Required
            whenever ``scheduler_name`` is not ``None``.

    Returns:
        The optax gradient transformation returned by ``optax.adam`` /
        ``optax.sgd``.

    Raises:
        ValueError: on an unknown optimizer or scheduler name, or when a
            scheduler is requested without ``steps``.
    """
    # Resolve the learning rate first (a float or an optax schedule), then
    # pick the optimizer exactly once instead of duplicating the dispatch.
    if scheduler_name is None:
        lr = learning_rate
    elif scheduler_name not in ("cosine", "exponential"):
        raise ValueError(f"Invalid scheduler name: {scheduler_name}")
    elif steps is None:
        # Fail here with a clear message rather than deep inside optax.
        raise ValueError(
            f"`steps` must be provided when scheduler_name={scheduler_name!r}"
        )
    elif scheduler_name == "cosine":
        lr = optax.schedules.cosine_decay_schedule(learning_rate, steps)
    else:
        # NOTE(review): with transition_steps=steps and decay_rate=0.99 the
        # rate only drops to 0.99 * learning_rate after `steps` updates --
        # confirm this very slow decay is intended.
        lr = optax.exponential_decay(learning_rate, transition_steps=steps, decay_rate=0.99)

    if opt_name == "adam":
        return optax.adam(learning_rate=lr)
    if opt_name == "sgd":
        return optax.sgd(learning_rate=lr)
    raise ValueError(f"Invalid optimizer name: {opt_name}")

class LinearRegression:
    """Linear regression in a fixed basis with an optional equation penalty.

    Predictions are ``Phi(X) @ weights`` where ``Phi`` is produced by the
    basis-function object built in ``__init__``. The weights can be obtained
    in closed form (``fit_exact``), as a constrained minimum-norm problem
    (``fit_exact_QP``) or by mini-batch gradient descent
    (``fit_gradient_descent``).
    """

    # fit_gradient_descent evaluates validation metrics every this many epochs.
    _EVAL_EVERY = 10

    def __init__(
        self,
        basis_function,
        test_function,
        domain: List[int],
        dim_out_base: int,
        dim_out_test: int,
        dim_in: int,
        sparse: bool,
    ):
        """Instantiate the basis and test function objects.

        Args:
            basis_function: factory called as
                ``basis_function(dim_in=..., dim_out=dim_out_base, domain=...)``.
                The returned object is later called on inputs to build the
                design matrix.
            test_function: factory called the same way with
                ``dim_out=dim_out_test``; the result is only handed to the
                equation object.
            domain: forwarded unchanged to both factories.
            dim_out_base: ``dim_out`` for the basis functions.
            dim_out_test: ``dim_out`` for the test functions.
            dim_in: ``dim_in`` for both factories.
            sparse: forwarded as ``sparse=`` when evaluating the basis and the
                equation matrices.
        """
        self.basis_function = basis_function(dim_in=dim_in, dim_out=dim_out_base, domain=domain)
        self.test_function = test_function(dim_in=dim_in, dim_out=dim_out_test, domain=domain)
        self.weights = None
        self.sparse = sparse

    def _mse_loss(self, predictions, targets):
        """Return the root of the mean squared residual.

        Despite the method name a square root is applied, so the value is an
        RMSE. This is kept as-is because every loss in this class is built on it.
        """
        residual = predictions - targets
        return jnp.sqrt(jnp.mean(residual * residual))

    def fit_exact(self, X: NDArray, y: NDArray, equation: DE, lambda_L2:float, lambda_eq: float):
        """Closed-form regularized least-squares fit.

        Solves ``(Phi^T Phi + n * M) w = Phi^T y`` with
        ``M = lambda_L2 * I + lambda_eq * equation.M`` (the equation term is
        dropped when ``lambda_eq`` is not positive) and stores the result in
        ``self.weights``.

        Args:
            X (NDArray): input data of shape (n_samples, n_features)
            y (NDArray): target data of shape (n_samples,)
            equation (DE): equation object; only used when ``lambda_eq > 0``.
            lambda_L2 (float): weight of the L2 penalty.
            lambda_eq (float): weight of the equation penalty.

        Returns:
            The fitted weights (also stored on ``self.weights``).

        Raises:
            ValueError: if an equation penalty is requested for a nonlinear
                equation.
        """
        # The equation is ignored when lambda_eq == 0, so a plain ridge fit is
        # still valid for nonlinear equations (this is what the default "OLS"
        # initialisation of fit_gradient_descent relies on).
        if lambda_eq > 0.0 and isinstance(equation, NonlinearDE):
            raise ValueError("Exact fitting is not supported for nonlinear equations")

        n = len(X)
        Phi = self.basis_function(X, sparse=self.sparse)
        n_basis = Phi.shape[1]
        if lambda_eq > 0.0:
            # Build and cache the equation matrix only once per equation object.
            if equation.M is None:
                D, G = equation.get_matrix(self.basis_function, self.test_function, sparse=self.sparse)
                equation.set_matrix(D, G)
            # NOTE(review): a dense identity is used here even when
            # self.sparse is True -- confirm this is intended.
            M = lambda_L2 * jnp.eye(n_basis) + lambda_eq * equation.M
        elif self.sparse:
            # Import the submodule explicitly: `import jax` alone does not
            # guarantee that `jax.experimental.sparse` is loaded.
            from jax.experimental import sparse as jsparse

            M = lambda_L2 * jsparse.eye(n_basis)
        else:
            M = lambda_L2 * jnp.eye(n_basis)

        weights = solve(Phi.T @ Phi + n * M, Phi.T @ y)
        self.weights = weights.squeeze()
        return self.weights

    def fit_exact_QP(self, X: NDArray, y: NDArray, equation: DE, lambda_eq: float):
        """Minimum-norm fit subject to the regularized normal equations.

        Minimizes ``w . w`` subject to ``A w = b`` with
        ``A = Phi^T Phi + n * lambda_eq * D^T G D`` and ``b = Phi^T y``, using
        scipy's ``trust-constr`` solver, and stores the result in
        ``self.weights``.

        Args:
            X (NDArray): input data of shape (n_samples, n_features)
            y (NDArray): target data of shape (n_samples,)
            equation (DE): equation object; only used when ``lambda_eq > 0``.
            lambda_eq (float): weight of the equation term in ``A``.

        Returns:
            The fitted weights as a NumPy array.
        """
        n = len(X)
        Phi = self.basis_function(X)
        if lambda_eq > 0.0:
            D, G = equation.get_matrix(self.basis_function, self.test_function)
            M = (D.T @ G) @ D
        else:
            M = jnp.zeros((Phi.shape[1], Phi.shape[1]))

        # scipy works on float64 NumPy data; convert once and flatten b so a
        # column-vector target does not produce mismatched constraint bounds.
        A = np.asarray(Phi.T @ Phi + n * lambda_eq * M, dtype=float)
        b = np.asarray(Phi.T @ y, dtype=float).ravel()

        # Pure NumPy objective with analytic derivatives: evaluating it through
        # jax.numpy returns float32 by default, which is too coarse for the
        # finite-difference gradients scipy would otherwise fall back to.
        def objective(w):
            return float(np.dot(w, w))

        def objective_grad(w):
            return 2.0 * w

        def objective_hess(w):
            return 2.0 * np.eye(w.size)

        linear_constraint = LinearConstraint(A, b, b)
        # Initial guess for w
        w0 = np.zeros(A.shape[1])
        # Solve the QP problem using 'trust-constr' method
        result = minimize(
            objective,
            w0,
            method='trust-constr',
            jac=objective_grad,
            hess=objective_hess,
            constraints=[linear_constraint],
        )
        self.weights = result.x
        return self.weights

    def fit_gradient_descent(
            self,
            X: NDArray,
            y: NDArray,
            X_val: NDArray,
            y_val: NDArray,
            lambda_L2: float,
            lambda_eq: float,
            equation: DE,
            optimizer: str = "adam",
            learning_rate: float = 1e-3,
            scheduler: Optional[str]=None,
            epochs: int = 1000,
            batch_size: int = 50,
            weight_init: str = "OLS",
            random_seed: Optional[int] = None,
            use_wandb: Optional[bool] = False,
            patience: int = 1000,
    ):
        """Fit the weights by mini-batch gradient descent with early stopping.

        The loss is ``_mse_loss + lambda_L2 * sum(w**2) + lambda_eq *
        equation.pi_loss(...)``. Every 10 epochs the loss is evaluated on the
        training and validation sets, and ``self.weights`` is updated whenever
        the validation loss improves, so the model ends up holding the best
        validated weights.

        Args:
            X (NDArray): input data of shape (n_samples, n_features)
            y (NDArray): target data of shape (n_samples,)
            X_val (NDArray): validation inputs.
            y_val (NDArray): validation targets.
            lambda_L2 (float): weight of the L2 penalty.
            lambda_eq (float): weight of the equation penalty.
            equation (DE): equation object providing ``pi_loss``.
            optimizer (str, optional): ``"adam"`` or ``"sgd"``.
            learning_rate (float, optional): initial learning rate.
            scheduler (str, optional): ``None``, ``"cosine"`` or
                ``"exponential"``.
            epochs (int, optional): number of epoch. Defaults to 1000.
            batch_size (int, optional): mini-batch size; capped at the number
                of samples. Trailing samples that do not fill a batch are
                skipped in each epoch.
            weight_init (str, optional): ``"OLS"`` (ridge fit without the
                equation term), ``"random"`` or ``"zeros"``.
            random_seed (int, optional): seed for the JAX PRNG; 0 when None.
            use_wandb (bool, optional): log metrics to Weights & Biases.
            patience (int, optional): number of consecutive validation checks
                (one check every 10 epochs) without improvement after which
                training stops. Defaults to 1000.

        Returns:
            The weights with the best validation loss (``self.weights``).

        Raises:
            ValueError: on an invalid ``batch_size``, empty ``X`` or unknown
                ``weight_init``.
            ImportError: if ``use_wandb`` is set but wandb is not installed.
        """
        num_samples = X.shape[0]
        if num_samples == 0:
            raise ValueError("X must contain at least one sample")
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")

        if use_wandb:
            # Resolve the optional dependency up front so a missing package
            # yields a clear error instead of a NameError mid-training.
            try:
                import wandb
            except ImportError as err:
                raise ImportError(
                    "use_wandb=True requires the optional 'wandb' package"
                ) from err

        key = jax.random.PRNGKey(0 if random_seed is None else random_seed)

        sparse = self.sparse
        # Set up the basis functions
        Phi = self.basis_function(X, sparse=sparse)
        Phi_val = self.basis_function(X_val, sparse=sparse)
        # Predictions are 1-D, so flatten both target arrays; an (n, 1) target
        # would otherwise broadcast against them inside the loss.
        y_train = y.squeeze()
        y_val = y_val.squeeze()

        # Set up the loss function
        if isinstance(equation, LinearDE):
            D, G = equation.get_matrix(self.basis_function, self.test_function, sparse=sparse)
            equation.set_matrix(D, G)

        def loss_fn(weights, design, targets):
            mse_loss = self._mse_loss(design @ weights, targets)
            L2_loss = lambda_L2 * jnp.sum(weights ** 2)
            DE_loss = lambda_eq * equation.pi_loss(self.basis_function, self.test_function, weights)
            return mse_loss + L2_loss + DE_loss

        grad_loss_fn = jit(grad(loss_fn))

        # weight initialization
        if weight_init == "OLS":
            weights = self.fit_exact(X, y, equation, lambda_L2, lambda_eq=0.0)
        elif weight_init == "random":
            key1, key = jax.random.split(key)
            weights = jax.random.normal(key1, (Phi.shape[1],))
        elif weight_init == "zeros":
            weights = jnp.zeros(Phi.shape[1])
        else:
            raise ValueError(f"Invalid weight initialization method: {weight_init}")

        # Never leave weights from an earlier fit (or None) on the model, even
        # if no validation check ever improves on the initial state.
        self.weights = weights

        # Cap the batch size so small datasets still get one batch per epoch
        # instead of silently performing no updates.
        batch_size = min(batch_size, num_samples)
        num_batches = num_samples // batch_size

        # Build the optimizer; the schedule (if any) spans every update step.
        opt = get_optimizer(optimizer, learning_rate, scheduler, epochs * num_batches)
        opt_state = opt.init(weights)

        best_val_loss = float('inf')
        epochs_no_improve = 0  # counted in validation checks, see `patience`

        if use_wandb:
            wandb.init(project="linear_regression", name="gradient_descent")
        try:
            for epoch in tqdm(range(epochs)):
                key, subkey = jax.random.split(key)
                perm = jax.random.permutation(subkey, jnp.arange(num_samples))
                for i in range(num_batches):
                    batch_indices = perm[i * batch_size : (i + 1) * batch_size]
                    grads = grad_loss_fn(weights, Phi[batch_indices], y_train[batch_indices])
                    updates, opt_state = opt.update(grads, opt_state)
                    weights = optax.apply_updates(weights, updates)

                if epoch % self._EVAL_EVERY != 0:
                    continue

                current_loss = loss_fn(weights, Phi, y_train)
                val_loss = loss_fn(weights, Phi_val, y_val)
                val_mse_loss = self._mse_loss(Phi_val @ weights, y_val)
                print(f"Epoch: {epoch}, Loss: {current_loss}, Val Loss: {val_loss} Val MSE Loss: {val_mse_loss}")
                if use_wandb:
                    wandb.log({"epoch": epoch, "loss": current_loss, "val_loss": val_loss, "val_mse_loss": val_mse_loss})

                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    epochs_no_improve = 0
                    self.weights = weights
                else:
                    epochs_no_improve += 1
                    if epochs_no_improve >= patience:
                        print(f"Early stopping at epoch {epoch}")
                        break
        finally:
            # Close the wandb run even if training is interrupted.
            if use_wandb:
                wandb.finish()

        return self.weights

    def predict(self, X: NDArray):
        """Return ``Phi(X) @ self.weights``.

        Raises:
            RuntimeError: if no fit method has set ``self.weights`` yet.
        """
        if self.weights is None:
            raise RuntimeError(
                "Model has not been fitted: call a fit_* method before predict()"
            )
        Phi = self.basis_function(X, sparse=self.sparse)
        return Phi @ self.weights

    def score(self, X: NDArray, y: NDArray):
        """Return ``_mse_loss`` (an RMSE) between ``predict(X)`` and squeezed ``y``."""
        predictions = self.predict(X)
        return self._mse_loss(predictions, y.squeeze())

    def visualize_basis(self, X: NDArray, save_path: str = "basis.png"):
        """Plot every basis function evaluated on ``X`` and save the figure.

        Args:
            X (NDArray): points at which the basis is evaluated (x-axis).
            save_path (str, optional): output image path.
        """
        Phi = self.basis_function(X, sparse=self.sparse)
        # Draw on a dedicated figure and always close it, so repeated calls do
        # not pile curves onto (or leak) matplotlib's global current figure.
        fig, ax = plt.subplots()
        try:
            for i in range(Phi.shape[1]):
                column = Phi[:, i]
                if self.sparse:
                    # Sparse containers differ in their densify method name.
                    if hasattr(column, "toarray"):
                        column = column.toarray()
                    else:
                        column = column.todense()
                ax.plot(X, column)
            fig.savefig(save_path, dpi=300)
        finally:
            plt.close(fig)
        
# Usage example
# NOTE(review): this demo does not match the API defined above and will fail
# as written: LinearRegression() is called without the arguments __init__
# requires, fit_exact(X, y) omits equation / lambda_L2 / lambda_eq, and
# fit_gradient_descent receives an optimizer object positionally where X_val
# is expected (the method takes the optimizer as a name string). It needs to
# be rewritten against the project's basis-function and equation classes.
if __name__ == "__main__":
    import numpy as np

    # Generate data
    X = np.random.rand(100, 1)
    y = 3 * X.squeeze() + np.random.randn(100) * 0.1

    # Instantiate the model
    model = LinearRegression()

    # Fit with the exact (closed-form) method
    model.fit_exact(X, y)
    print("Exact Weights:", model.weights)

    # Use SGD
    sgd_optimizer = optimizers.sgd(learning_rate=0.01)
    model.fit_gradient_descent(X, y, sgd_optimizer)
    print("SGD Weights:", model.weights)

    # Use Adam
    adam_optimizer = optimizers.adam(learning_rate=0.01)
    model.fit_gradient_descent(X, y, adam_optimizer)
    print("Adam Weights:", model.weights)

    # Predict
    predictions = model.predict(X)
    print("Predictions:", predictions[:5])