import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss, Linear, MSELoss
from einops import rearrange, reduce, repeat
# from laplace.lllaplace import DiagLLLaplace
from models.FNO1d import FNO1d
import utils
from probconserv import apply_constraint, get_empirical_mass_rhs, apply_ortho_constraint, apply_hardc_constraint
from copy import deepcopy

from nonlinear_projection import project_and_stats, safe_project_and_stats, project_and_stats_orth


# Module-wide compute device: "cuda" when torch reports CUDA as available,
# otherwise "cpu". The wrappers in this file move freshly built base models
# and every (x, y) batch onto it.
device = "cuda" if torch.cuda.is_available() else "cpu"

# class ModifiedDiagLLLaplace(DiagLLLaplace):
#     def __init__(self, model, likelihood, *args, **kwargs):
#         super().__init__(model, likelihood, *args, **kwargs)
    
#     def _curv_closure(self, X, y, N):
#         return super()._curv_closure(X, y.reshape(-1, 1), N)

# class BayesianNO(nn.Module):
#     def __init__(self, base_model_class, base_model_params):
#         super().__init__()
#         self.base_model_class = base_model_class
#         self.base_model_params = base_model_params
#         self.loss_func = None
#         self.la = None

#         # self.la = DiagLLLaplace(base_model, likelihood='regression')
#         # self.la = ModifiedDiagLLLaplace(base_model, likelihood='regression')
    
#     def fit(self, train_loader, valid_loader, **fit_params):
#         base_model = self.base_model_class(last_layer_reshape=False, **self.base_model_params).to(device)
#         base_model.fit(train_loader, valid_loader, **fit_params)
#         base_model.last_layer_reshape=True

#         self.la = ModifiedDiagLLLaplace(base_model, likelihood='regression')
#         self.la.fit(train_loader)
#         log_prior, log_sigma = torch.ones(1, requires_grad=True), torch.ones(1, requires_grad=True)
#         hyper_optimizer = torch.optim.Adam([log_prior, log_sigma], lr=1e-1)
#         for i in range(1000):
#             hyper_optimizer.zero_grad()
#             neg_marglik = -self.la.log_marginal_likelihood(log_prior.exp(), log_sigma.exp())
#             neg_marglik.backward()
#             hyper_optimizer.step()

#         self.loss_func = base_model.loss_func
    
#     def forward(self, x):
#         b, p, t, _ = x.shape

#         # Returns mean and variance
#         mu, var = self.la(x)

#         # Assumes diag
#         mu = rearrange(mu, "(b p t) 1 -> b p t 1", b=b, p=p)
#         var = rearrange(var, "(b p t) 1 1 -> b p t 1", b=b, p=p)

#         return mu, var
    
#     def parameters(self):
#         return self.la.model.parameters()

#     def test(self, test_loader, **test_params):
#         batch_size = test_params.get("batch_size", 20)
#         test_l2 = 0.0
#         with torch.no_grad():
#             for batch in test_loader:
#                 x, y = batch
#                 x, y = x.to(device), y.to(device)
#                 out, _ = self(x)
#                 test_l2 += self.loss_func(out.view(batch_size, -1), y.view(batch_size, -1)).item()

#         test_l2 /= len(test_loader.dataset)
#         return {"loss": test_l2}


class EnsembleNO(nn.Module):
    """Deep ensemble of independently trained base models.

    The prediction is the element-wise mean of the members' outputs and the
    uncertainty is their element-wise variance (``torch.Tensor.var`` default
    settings) across members.

    NOTE: members live in a plain Python list (``self.models``), not in an
    ``nn.ModuleList``, so they are not registered as sub-modules of the
    ensemble; ``parameters()`` is overridden to expose their weights anyway.

    Args:
        base_model_class: Callable that builds one member; invoked as
            ``base_model_class(**base_model_params)``.
        base_model_params: Keyword arguments for ``base_model_class``.
        n_models: Number of members trained by each call to :meth:`fit`.
    """

    def __init__(self, base_model_class, base_model_params, n_models=10):
        super().__init__()
        self.base_model_class = base_model_class
        self.base_model_params = base_model_params
        self.n_models = n_models
        self.models = []
        self.loss_func = None

    def fit(self, train_loader, valid_loader, **fit_params):
        """Train ``n_models`` fresh members and append them to ``self.models``.

        Each member is trained through its own ``fit`` method with the same
        loaders and ``fit_params``, then switched to eval mode.
        ``train_loader`` should have ``shuffle=True``.
        """
        for i in range(self.n_models):
            print("="*20 + f" Model {i} " + "=" * 20)
            base_model = self.base_model_class(**self.base_model_params).to(device)
            base_model.fit(train_loader, valid_loader, **fit_params)
            base_model.eval()
            self.models.append(base_model)
        # Guarded so that n_models == 0 is a no-op instead of a NameError.
        if self.models:
            self.loss_func = self.models[-1].loss_func

    def forward(self, x):
        """Return ``(mean, variance)`` of the members' outputs for ``x``.

        Intended for inference only. Prints a message and returns ``None``
        when no member has been trained yet.
        """
        if not self.models:
            print("Models not trained. Use fit() function.")
            return

        member_outputs = torch.stack([base_model(x) for base_model in self.models])
        return member_outputs.mean(dim=0), member_outputs.var(dim=0)

    def parameters(self):
        """Yield the parameters of every member, one member after another."""
        for base_model in self.models:
            yield from base_model.parameters()

    def finetune(self, finetune_loader, **fit_params):
        """Fine-tune every trained member in place via its ``finetune`` method.

        Prints a message and returns when no member has been trained yet.
        """
        # The members themselves are stored (not their state dicts), so they
        # are fine-tuned directly rather than being rebuilt from a state dict.
        if not self.models:
            print("Models not trained. Use fit() function.")
            return

        for i, base_model in enumerate(self.models):
            print("="*20 + f" Model {i} " + "=" * 20)
            # fit() leaves members in eval mode; go back to train mode for the
            # update, then restore eval mode for inference.
            base_model.train()
            base_model.finetune(finetune_loader, **fit_params)
            base_model.eval()

    def test(self, test_loader, **test_params):
        """Evaluate the ensemble mean on ``test_loader``.

        Per-batch losses are summed and divided by ``len(test_loader.dataset)``.
        A ``batch_size`` entry in ``test_params`` is still accepted but no
        longer needed: tensors are flattened using each batch's real size, so
        a smaller final batch works.

        Returns:
            ``{"loss": <float>}``.
        """
        test_l2 = 0.0
        with torch.no_grad():
            for batch in test_loader:
                x, y = batch
                x, y = x.to(device), y.to(device)
                out, _ = self(x)
                n_samples = out.shape[0]
                test_l2 += self.loss_func(out.reshape(n_samples, -1), y.reshape(n_samples, -1)).item()

        test_l2 /= len(test_loader.dataset)

        return {"loss": test_l2}


class MCDropoutNO(nn.Module):
    """Monte-Carlo dropout wrapper around a single base model.

    ``forward`` runs the base model ``n_dropouts`` times on the same input and
    returns the element-wise mean and variance of those passes.

    NOTE: this class never switches train/eval mode itself; whether the passes
    actually differ depends on the mode the base model is in when called.

    Args:
        base_model_class: Callable that builds the base model; invoked as
            ``base_model_class(dropout=dropout, **base_model_params)``.
        base_model_params: Extra keyword arguments for ``base_model_class``.
        dropout: Dropout value forwarded to the base model.
        n_dropouts: Number of stochastic forward passes per prediction.
    """

    def __init__(self, base_model_class, base_model_params, dropout=0.1, n_dropouts=30):
        super().__init__()
        self.base_model_class = base_model_class
        self.base_model_params = base_model_params
        self.dropout = dropout
        self.n_dropouts = n_dropouts
        self.loss_func = None
        self.base_model = None

    def fit(self, train_loader, valid_loader, **fit_params):
        """Build the base model, train it with its own ``fit``, adopt its loss."""
        self.base_model = self.base_model_class(dropout=self.dropout, **self.base_model_params).to(device)
        self.base_model.fit(train_loader, valid_loader, **fit_params)
        self.loss_func = self.base_model.loss_func

    def forward(self, x):
        """Return ``(mean, variance)`` over ``n_dropouts`` passes on ``x``."""
        samples = torch.stack([self.base_model(x) for _ in range(self.n_dropouts)])
        return samples.mean(dim=0), samples.var(dim=0)

    def parameters(self):
        """Expose the base model's parameters."""
        return self.base_model.parameters()

    def test(self, test_loader, **test_params):
        """Evaluate the predictive mean on ``test_loader``.

        Per-batch losses are summed and divided by ``len(test_loader.dataset)``.
        A ``batch_size`` entry in ``test_params`` is still accepted but no
        longer needed: tensors are flattened using each batch's real size, so
        a smaller final batch works.

        Returns:
            ``{"loss": <float>}``.
        """
        test_l2 = 0.0
        with torch.no_grad():
            for batch in test_loader:
                x, y = batch
                x, y = x.to(device), y.to(device)
                out, _ = self(x)
                n_samples = out.shape[0]
                test_l2 += self.loss_func(out.reshape(n_samples, -1), y.reshape(n_samples, -1)).item()

        test_l2 /= len(test_loader.dataset)

        return {"loss": test_l2}


class E2EMCDropoutNO(nn.Module):
    """MC-dropout model trained end-to-end on its own (mean, variance) output.

    ``forward`` runs the base model ``n_dropouts`` times and returns the mean
    of the passes and the softplus of their variance. ``fit`` optimises
    ``utils.nll_mu_var`` on that pair and, from a configurable epoch onwards,
    first passes the pair through ``apply_constraint``.

    Args:
        base_model_class: Callable that builds the base model.
        base_model_params: Keyword arguments for ``base_model_class``.
        dropout: Dropout value forwarded to the base model unless
            ``base_model_params`` already provides ``dropout``.
        n_dropouts: Number of forward passes per prediction.
    """

    def __init__(self, base_model_class, base_model_params, dropout=0.1, n_dropouts=30):
        super().__init__()
        self.base_model_class = base_model_class
        self.base_model_params = base_model_params
        self.dropout = dropout
        self.n_dropouts = n_dropouts
        self.loss_func = None
        self.base_model = None
        # Defined up front so test() can read self.project before fit() ran.
        self.update_sigma = False
        self.project = False

    def _project_outputs(self, out, mass_rhs_func, t, tpred, grid_train):
        """Run ``apply_constraint`` on ``out = (mu, var)``; return the new pair.

        The last axis of ``mu``/``var`` is indexed at 0 before the call and
        re-added afterwards; the returned std is squared back to a variance.
        """
        mu, var = out
        std = torch.sqrt(var)
        new_mu, new_std, _, _ = apply_constraint(
            mu=mu[:, :, :, 0],
            std=std[:, :, :, 0],
            mass_rhs_func=mass_rhs_func,
            t=t,
            tpred=tpred,
            grid_train=grid_train,
            precis_g=np.inf,
            second_deriv_alpha=None,
        )
        return new_mu.unsqueeze(-1), torch.square(new_std).unsqueeze(-1)

    def fit(self, train_loader, valid_loader, **fit_params):
        """Train the model, keeping the weights with the best validation loss.

        Recognised ``fit_params`` (defaults in parentheses): ``lr`` (1e-3),
        ``step_size`` (50) and ``gamma`` (0.5) for the StepLR schedule,
        ``epochs`` (200), ``project_start_epoch`` (100; first epoch at which
        the constraint is applied), plus ``t``, ``tpred``, ``grid_train`` and
        ``dataset_class`` which are handed to the constraint. All of
        ``fit_params`` is also forwarded to :meth:`test`.

        When ``valid_loader`` is ``None`` the training loss is used for model
        selection.
        """
        # Forward the configured dropout (it used to be stored but never
        # used); an explicit "dropout" in base_model_params takes precedence.
        model_kwargs = {"dropout": self.dropout, **self.base_model_params}
        self.base_model = self.base_model_class(
                output_var=False,
                probconserv=False,  # Disable internal constraints
                **model_kwargs,
            ).to(device)

        self.loss_func = utils.nll_mu_var

        lr = fit_params.get("lr", 1e-3)
        step_size = fit_params.get("step_size", 50)
        gamma = fit_params.get("gamma", 0.5)
        epochs = fit_params.get("epochs", 200)
        project_start_epoch = fit_params.get("project_start_epoch", 100)
        tpred = fit_params.get("tpred", None)
        t = fit_params.get("t", None)
        dataset_class = fit_params.get("dataset_class", None)
        grid_train = fit_params.get("grid_train", None)

        optimizer = torch.optim.Adam(self.base_model.parameters(), lr=lr, weight_decay=1e-4)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

        best_valid_l2 = np.inf
        # Stays None if no epoch ever improves on +inf (zero epochs, NaN loss).
        best_model_state_dict = None

        self.update_sigma = False
        self.project = False

        for epoch in range(epochs):
            self.train()
            train_l2 = 0
            for batch in train_loader:
                x, y = batch
                x, y = x.to(device), y.to(device)
                optimizer.zero_grad()
                out = self(x)
                mass_rhs_func = dataset_class.get_mass_rhs_func(x=x)

                if epoch >= project_start_epoch:
                    if not self.project:
                        print("Projecting Now")
                        # Also makes test() project from here on.
                        self.project = True
                    out = self._project_outputs(out, mass_rhs_func, t, tpred, grid_train)

                l2 = self.loss_func(out, y)
                l2.backward()
                optimizer.step()
                train_l2 += l2.item()

            train_l2 /= len(train_loader.dataset)
            scheduler.step()
            if valid_loader is not None:
                valid_l2 = self.test(valid_loader, **fit_params)["loss"]
            else:
                valid_l2 = train_l2

            saved = ""
            if valid_l2 < best_valid_l2:
                best_valid_l2 = valid_l2
                best_model_state_dict = deepcopy(self.state_dict())
                saved = "(saved)"

            print(f"Epoch {epoch}: Train loss={train_l2:.6f}, Validation loss={valid_l2:.6f} {saved}")

        # Keep the current weights when no checkpoint was ever recorded.
        if best_model_state_dict is not None:
            self.load_state_dict(best_model_state_dict)
        train_l2 = self.test(train_loader, **fit_params)["loss"]
        if valid_loader is not None:
            valid_l2 = self.test(valid_loader, **fit_params)["loss"]
        else:
            valid_l2 = train_l2
        print(f"Finished training with best train loss: {train_l2:.6f} and validation loss: {valid_l2:.6f}")

    def forward(self, x):
        """Return ``(mean, softplus(variance))`` over ``n_dropouts`` passes."""
        samples = torch.stack([self.base_model(x) for _ in range(self.n_dropouts)])
        return samples.mean(dim=0), F.softplus(samples.var(dim=0))

    def parameters(self):
        """Expose the base model's parameters."""
        return self.base_model.parameters()

    def test(self, test_loader, **test_params):
        """Evaluate in eval mode; project first when ``self.project`` is set.

        Reads ``t``, ``tpred``, ``grid_train`` and ``dataset_class`` from
        ``test_params``. Per-batch losses are summed and divided by
        ``len(test_loader.dataset)``.

        Returns:
            ``{"loss": <float>}``.
        """
        self.eval()
        test_l2 = 0.0
        tpred = test_params.get("tpred", None)
        t = test_params.get("t", None)
        dataset_class = test_params.get("dataset_class", None)
        grid_train = test_params.get("grid_train", None)

        with torch.no_grad():
            for batch in test_loader:
                x, y = batch
                x, y = x.to(device), y.to(device)
                out = self(x)
                mass_rhs_func = dataset_class.get_mass_rhs_func(x=x)

                if self.project:
                    out = self._project_outputs(out, mass_rhs_func, t, tpred, grid_train)

                test_l2 += self.loss_func(out, y).item()

        test_l2 /= len(test_loader.dataset)
        return {"loss": test_l2}


class OutputVarNO(nn.Module):
    """Thin wrapper around a base model built with ``output_var=True``.

    Training, the forward pass and the loss are all delegated to the base
    model; ``constraint_context`` is expanded into the base model's
    constructor keyword arguments.
    """

    def __init__(self,
                 base_model_class,
                 base_model_params,
                 probconserv=False,
                 constraint_context=None):
        super().__init__()
        # Filled in by fit().
        self.base_model = None
        self.loss_func = None
        # Construction recipe for the base model.
        self.base_model_class = base_model_class
        self.base_model_params = base_model_params
        self.probconserv = probconserv
        # Extra constructor kwargs (t, tpred, etc.); falsy values become {}.
        self.constraint_context = constraint_context or {}

    def fit(self, train_loader, valid_loader, **fit_params):
        """Build the base model, train it with its own ``fit``, adopt its loss."""
        fixed_kwargs = {"output_var": True, "probconserv": self.probconserv}
        # Separate ** expansions: a duplicated key still raises TypeError.
        built = self.base_model_class(
            **fixed_kwargs,
            **self.base_model_params,
            **self.constraint_context,
        )
        self.base_model = built.to(device)
        self.base_model.fit(train_loader, valid_loader, **fit_params)
        self.loss_func = self.base_model.loss_func

    def forward(self, x):
        """Return whatever the base model returns for ``x``."""
        return self.base_model(x)

    def parameters(self):
        """Expose the base model's parameters."""
        return self.base_model.parameters()

    def test(self, test_loader, **test_params):
        """Run the model over ``test_loader`` without gradients.

        NOTE: no loss is accumulated here (the accumulation step is disabled
        in this class), so the reported loss is always
        ``0.0 / len(test_loader.dataset)``.
        """
        accumulated = 0.0
        with torch.no_grad():
            for inputs, targets in test_loader:
                inputs = inputs.to(device)
                targets = targets.to(device)
                # The forward pass is still executed; its result is unused.
                prediction, _ = self(inputs)

        return {"loss": accumulated / len(test_loader.dataset)}
    
class SoftOutputVarNO(nn.Module):
    """Mean/variance model trained with an additive (soft) mass penalty.

    The training and evaluation objective is
    ``loss_func(out, y) + lambda_pinn * pde_residual`` where ``pde_residual``
    is the MSE between ``get_empirical_mass_rhs(mu)`` and the fourth value
    returned by ``apply_constraint``. The model output itself is not modified.

    Args:
        base_model_class: Callable that builds the base model.
        base_model_params: Keyword arguments for ``base_model_class``.
        probconserv, probconserv_sampling, nonlinear_constraint: Stored on the
            instance; not otherwise read by this class.
        constraint_context: Extra keyword arguments expanded into the base
            model constructor (``None`` or empty becomes ``{}``).
        lambda_pinn: Weight of the residual penalty.
    """

    def __init__(self,
                 base_model_class,
                 base_model_params,
                 probconserv=False,
                 probconserv_sampling=False,
                 nonlinear_constraint="none",
                 constraint_context=None,
                 lambda_pinn = 0):
        super().__init__()
        self.base_model_class = base_model_class
        self.base_model_params = base_model_params
        self.loss_func = None
        self.base_model = None
        self.probconserv = probconserv
        self.nonlinear_constraint = nonlinear_constraint
        self.probconserv_sampling = probconserv_sampling
        self.constraint_context = constraint_context or {}
        self.lambda_pinn = lambda_pinn

    def _pde_residual(self, out, x, dataset_class, t, tpred, grid_train):
        """Return the MSE between the empirical and reference mass terms.

        ``out`` is the ``(mu, var)`` pair produced for batch ``x``.
        """
        mu, var = out
        std = torch.sqrt(var)
        mass_rhs_func = dataset_class.get_mass_rhs_func(x=x)

        # Only the fourth return value is needed; the constrained mean/std
        # are discarded.
        _, _, _, mass_rhs = apply_constraint(
            mu=mu[:, :, :, 0],
            std=std[:, :, :, 0],
            mass_rhs_func=mass_rhs_func,
            t=t,
            tpred=tpred,
            grid_train=grid_train,
            precis_g=np.inf,
            second_deriv_alpha=None,
        )

        return nn.MSELoss()(get_empirical_mass_rhs(mu[:, :,  :, 0]), mass_rhs)

    def test(self, test_loader, **test_params):
        """Evaluate loss plus weighted residual in eval mode.

        Reads ``t``, ``tpred``, ``grid_train`` and ``dataset_class`` from
        ``test_params``. Per-batch values are summed and divided by
        ``len(test_loader.dataset)``.

        Returns:
            ``{"loss": <float>}``.
        """
        self.eval()
        test_l2 = 0.0
        tpred = test_params.get("tpred", None)
        t = test_params.get("t", None)
        dataset_class = test_params.get("dataset_class", None)
        grid_train = test_params.get("grid_train", None)

        with torch.no_grad():
            for batch in test_loader:
                x, y = batch
                x, y = x.to(device), y.to(device)
                out = self(x)

                pde_residual = self._pde_residual(out, x, dataset_class, t, tpred, grid_train)

                test_l2 += self.loss_func(out, y).item() + self.lambda_pinn*pde_residual.item()

        test_l2 /= len(test_loader.dataset)
        return {"loss": test_l2}

    def fit(self, train_loader, valid_loader, **fit_params):
        """Train the base model, keeping the weights with the best validation loss.

        Recognised ``fit_params`` (defaults in parentheses): ``lr`` (1e-3),
        ``step_size`` (50) and ``gamma`` (0.5) for the StepLR schedule,
        ``epochs`` (200), plus ``t``, ``tpred``, ``grid_train`` and
        ``dataset_class`` used for the residual. All of ``fit_params`` is also
        forwarded to :meth:`test`. When ``valid_loader`` is ``None`` the
        training loss is used for model selection.
        """
        self.base_model = self.base_model_class(
                output_var=True,
                probconserv=False,  # Disable internal constraints
                **self.base_model_params,
                **self.constraint_context
            ).to(device)

        self.loss_func = self.base_model.loss_func

        lr = fit_params.get("lr", 1e-3)
        step_size = fit_params.get("step_size", 50)
        gamma = fit_params.get("gamma", 0.5)
        epochs = fit_params.get("epochs", 200)
        tpred = fit_params.get("tpred", None)
        t = fit_params.get("t", None)
        dataset_class = fit_params.get("dataset_class", None)
        grid_train = fit_params.get("grid_train", None)

        optimizer = torch.optim.Adam(self.base_model.parameters(), lr=lr, weight_decay=1e-4)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

        best_valid_l2 = np.inf
        # Stays None if no epoch ever improves on +inf (zero epochs, NaN loss).
        best_model_state_dict = None

        self.update_sigma = False

        for epoch in range(epochs):
            self.train()
            train_l2 = 0
            for batch in train_loader:
                x, y = batch
                x, y = x.to(device), y.to(device)
                optimizer.zero_grad()
                out = self.base_model(x)

                pde_residual = self._pde_residual(out, x, dataset_class, t, tpred, grid_train)

                l2 = self.loss_func(out, y) + self.lambda_pinn*pde_residual
                l2.backward()
                optimizer.step()
                train_l2 += l2.item()

            train_l2 /= len(train_loader.dataset)
            scheduler.step()
            if valid_loader is not None:
                valid_l2 = self.test(valid_loader, **fit_params)["loss"]
            else:
                valid_l2 = train_l2

            saved = ""
            if valid_l2 < best_valid_l2:
                best_valid_l2 = valid_l2
                best_model_state_dict = deepcopy(self.state_dict())
                saved = "(saved)"

            print(f"Epoch {epoch}: Train loss={train_l2:.6f}, Validation loss={valid_l2:.6f} {saved}")

        # Keep the current weights when no checkpoint was ever recorded.
        if best_model_state_dict is not None:
            self.load_state_dict(best_model_state_dict)
        train_l2 = self.test(train_loader, **fit_params)["loss"]
        if valid_loader is not None:
            valid_l2 = self.test(valid_loader, **fit_params)["loss"]
        else:
            valid_l2 = train_l2
        print(f"Finished training with best train loss: {train_l2:.6f} and validation loss: {valid_l2:.6f}")

    def forward(self, x):
        """Return whatever the base model returns for ``x``."""
        return self.base_model(x)

    def parameters(self):
        """Expose the base model's parameters."""
        return self.base_model.parameters()
    

class HardOutputVarNO(nn.Module):
    """Mean/variance model whose loss is computed on hard-constrained outputs.

    During both training and evaluation the base model's ``(mu, var)`` pair is
    passed through ``apply_hardc_constraint`` before the loss is computed.
    ``forward`` itself returns the unconstrained base-model output.

    Args:
        base_model_class: Callable that builds the base model.
        base_model_params: Keyword arguments for ``base_model_class``.
        probconserv, probconserv_sampling, nonlinear_constraint: Stored on the
            instance; not otherwise read by this class.
        constraint_context: Extra keyword arguments expanded into the base
            model constructor (``None`` or empty becomes ``{}``).
    """

    def __init__(self,
                 base_model_class,
                 base_model_params,
                 probconserv=False,
                 probconserv_sampling=False,
                 nonlinear_constraint="none",
                 constraint_context=None):
        super().__init__()
        self.base_model_class = base_model_class
        self.base_model_params = base_model_params
        self.loss_func = None
        self.base_model = None
        self.probconserv = probconserv
        self.nonlinear_constraint = nonlinear_constraint
        self.probconserv_sampling = probconserv_sampling
        self.constraint_context = constraint_context or {}

    def _constrain(self, out, x, dataset_class, t, tpred, grid_train):
        """Run ``apply_hardc_constraint`` on ``out = (mu, var)``; return the new pair.

        The last axis of ``mu``/``var`` is indexed at 0 before the call and
        re-added afterwards; the returned std is squared back to a variance.
        """
        mu, var = out
        std = torch.sqrt(var)
        mass_rhs_func = dataset_class.get_mass_rhs_func(x=x)

        new_mu, new_std, _, _ = apply_hardc_constraint(
            mu=mu[:, :, :, 0],
            std=std[:, :, :, 0],
            mass_rhs_func=mass_rhs_func,
            t=t,
            tpred=tpred,
            grid_train=grid_train,
            precis_g=np.inf,
            second_deriv_alpha=None,
        )

        return new_mu.unsqueeze(-1), torch.square(new_std).unsqueeze(-1)

    def test(self, test_loader, **test_params):
        """Evaluate the loss on constrained outputs in eval mode.

        Reads ``t``, ``tpred``, ``grid_train`` and ``dataset_class`` from
        ``test_params``. Per-batch losses are summed and divided by
        ``len(test_loader.dataset)``.

        Returns:
            ``{"loss": <float>}``.
        """
        self.eval()
        test_l2 = 0.0
        tpred = test_params.get("tpred", None)
        t = test_params.get("t", None)
        dataset_class = test_params.get("dataset_class", None)
        grid_train = test_params.get("grid_train", None)

        with torch.no_grad():
            for batch in test_loader:
                x, y = batch
                x, y = x.to(device), y.to(device)
                out = self._constrain(self(x), x, dataset_class, t, tpred, grid_train)

                test_l2 += self.loss_func(out, y).item()

        test_l2 /= len(test_loader.dataset)
        return {"loss": test_l2}

    def fit(self, train_loader, valid_loader, **fit_params):
        """Train the base model, keeping the weights with the best validation loss.

        Recognised ``fit_params`` (defaults in parentheses): ``lr`` (1e-3),
        ``step_size`` (50) and ``gamma`` (0.5) for the StepLR schedule,
        ``epochs`` (200), plus ``t``, ``tpred``, ``grid_train`` and
        ``dataset_class`` used by the constraint. All of ``fit_params`` is
        also forwarded to :meth:`test`. When ``valid_loader`` is ``None`` the
        training loss is used for model selection.
        """
        self.base_model = self.base_model_class(
                output_var=True,
                probconserv=False,  # Disable internal constraints
                **self.base_model_params,
                **self.constraint_context
            ).to(device)

        self.loss_func = self.base_model.loss_func

        lr = fit_params.get("lr", 1e-3)
        step_size = fit_params.get("step_size", 50)
        gamma = fit_params.get("gamma", 0.5)
        epochs = fit_params.get("epochs", 200)
        tpred = fit_params.get("tpred", None)
        t = fit_params.get("t", None)
        dataset_class = fit_params.get("dataset_class", None)
        grid_train = fit_params.get("grid_train", None)

        optimizer = torch.optim.Adam(self.base_model.parameters(), lr=lr, weight_decay=1e-4)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

        best_valid_l2 = np.inf
        # Stays None if no epoch ever improves on +inf (zero epochs, NaN loss).
        best_model_state_dict = None

        self.update_sigma = False

        for epoch in range(epochs):
            self.train()
            train_l2 = 0
            for batch in train_loader:
                x, y = batch
                x, y = x.to(device), y.to(device)
                optimizer.zero_grad()
                out = self._constrain(self.base_model(x), x, dataset_class, t, tpred, grid_train)

                l2 = self.loss_func(out, y)
                l2.backward()
                optimizer.step()
                train_l2 += l2.item()

            train_l2 /= len(train_loader.dataset)
            scheduler.step()
            if valid_loader is not None:
                valid_l2 = self.test(valid_loader, **fit_params)["loss"]
            else:
                valid_l2 = train_l2

            saved = ""
            if valid_l2 < best_valid_l2:
                best_valid_l2 = valid_l2
                best_model_state_dict = deepcopy(self.state_dict())
                saved = "(saved)"

            print(f"Epoch {epoch}: Train loss={train_l2:.6f}, Validation loss={valid_l2:.6f} {saved}")

        # Keep the current weights when no checkpoint was ever recorded.
        if best_model_state_dict is not None:
            self.load_state_dict(best_model_state_dict)
        train_l2 = self.test(train_loader, **fit_params)["loss"]
        if valid_loader is not None:
            valid_l2 = self.test(valid_loader, **fit_params)["loss"]
        else:
            valid_l2 = train_l2
        print(f"Finished training with best train loss: {train_l2:.6f} and validation loss: {valid_l2:.6f}")

    def forward(self, x):
        """Return the unconstrained base-model output for ``x``."""
        return self.base_model(x)

    def parameters(self):
        """Expose the base model's parameters."""
        return self.base_model.parameters()
    

class HardE2EOutputVarNO(nn.Module):
    """Mean/variance model trained end-to-end through ``apply_ortho_constraint``.

    ``fit`` builds ``base_model_class(output_var=True, probconserv=False, ...)``
    and trains it on the loss of the *projected* ``(mean, variance)`` pair.
    ``test`` applies the same projection once ``self.project`` is set.

    The ``probconserv``, ``probconserv_sampling`` and ``nonlinear_constraint``
    arguments are stored but not read by any method of this class.
    """

    def __init__(self,
                 base_model_class,
                 base_model_params,
                 probconserv=False,
                 probconserv_sampling=False,
                 nonlinear_constraint="none",
                 constraint_context=None):
        """Store the configuration; the base model itself is created in ``fit``.

        Args:
            base_model_class: Callable used in ``fit`` to build the base model.
            base_model_params: Keyword arguments forwarded to that callable.
            probconserv: Stored only.
            probconserv_sampling: Stored only.
            nonlinear_constraint: Stored only.
            constraint_context: Optional dict of extra keyword arguments that
                ``fit`` forwards to ``base_model_class``; ``None`` becomes ``{}``.
        """
        super().__init__()
        self.base_model_class = base_model_class
        self.base_model_params = base_model_params
        self.loss_func = None
        self.base_model = None
        self.probconserv = probconserv
        self.nonlinear_constraint = nonlinear_constraint
        self.probconserv_sampling = probconserv_sampling
        self.constraint_context = constraint_context or {}
        # Defined here (not only in ``fit``) so that ``test`` can read them on
        # an instance that has not been through ``fit`` yet.
        self.update_sigma = False
        self.project = False

    def _project(self, out, x, t, tpred, grid_train, dataset_class):
        """Pass a ``(mean, variance)`` pair through ``apply_ortho_constraint``.

        The trailing singleton axis of both tensors is dropped for the call and
        restored afterwards; the standard deviation handed to the constraint is
        ``sqrt(variance)`` and the returned one is squared back to a variance.
        """
        mu, var = out
        std = torch.sqrt(var)
        mass_rhs_func = dataset_class.get_mass_rhs_func(x=x)
        new_mu, new_std, _, _ = apply_ortho_constraint(
            mu=mu[:, :, :, 0],
            std=std[:, :, :, 0],
            mass_rhs_func=mass_rhs_func,
            t=t,
            tpred=tpred,
            grid_train=grid_train,
            precis_g=np.inf,
            second_deriv_alpha=None,
        )
        return new_mu.unsqueeze(-1), torch.square(new_std).unsqueeze(-1)

    def test(self, test_loader, **test_params):
        """Evaluate the summed loss over ``test_loader`` divided by the dataset size.

        Args:
            test_loader: Iterable of ``(x, y)`` batches with a ``dataset``
                attribute supporting ``len``.
            **test_params: ``t``, ``tpred``, ``grid_train`` and
                ``dataset_class`` are read and used for the projection.

        Returns:
            ``{"loss": value}``.
        """
        self.eval()
        tpred = test_params.get("tpred", None)
        t = test_params.get("t", None)
        dataset_class = test_params.get("dataset_class", None)
        grid_train = test_params.get("grid_train", None)

        test_l2 = 0.0
        with torch.no_grad():
            for x, y in test_loader:
                x, y = x.to(device), y.to(device)
                out = self(x)
                if self.project:
                    out = self._project(out, x, t, tpred, grid_train, dataset_class)
                test_l2 += self.loss_func(out, y).item()

        test_l2 /= len(test_loader.dataset)
        return {"loss": test_l2}

    def fit(self, train_loader, valid_loader, **fit_params):
        """Build the base model and train it on the projected outputs.

        Args:
            train_loader: Iterable of ``(x, y)`` training batches.
            valid_loader: Validation loader, or ``None`` to reuse the training
                loss for model selection.
            **fit_params: Optional ``lr`` (1e-3), ``step_size`` (50),
                ``gamma`` (0.5), ``epochs`` (200), plus ``t``, ``tpred``,
                ``grid_train`` and ``dataset_class`` for the projection. The
                whole dict is also forwarded to ``test``.
        """
        self.base_model = self.base_model_class(
            output_var=True,
            probconserv=False,  # Disable internal constraints
            **self.base_model_params,
            **self.constraint_context
        ).to(device)

        self.loss_func = self.base_model.loss_func

        lr = fit_params.get("lr", 1e-3)
        step_size = fit_params.get("step_size", 50)
        gamma = fit_params.get("gamma", 0.5)
        epochs = fit_params.get("epochs", 200)
        tpred = fit_params.get("tpred", None)
        t = fit_params.get("t", None)
        dataset_class = fit_params.get("dataset_class", None)
        grid_train = fit_params.get("grid_train", None)

        optimizer = torch.optim.Adam(self.base_model.parameters(), lr=lr, weight_decay=1e-4)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

        best_valid_l2 = np.inf
        # Stays None when no epoch beats ``inf`` (e.g. NaN losses, epochs=0).
        best_model_state_dict = None

        self.update_sigma = False
        self.project = False

        for epoch in range(epochs):
            self.train()
            train_l2 = 0
            for x, y in train_loader:
                x, y = x.to(device), y.to(device)
                optimizer.zero_grad()
                out = self.base_model(x)

                # Projection is active from the very first batch; the flag also
                # switches it on for every later ``test`` call.
                if not self.project:
                    print("Projecting Now")
                    self.project = True
                out = self._project(out, x, t, tpred, grid_train, dataset_class)

                l2 = self.loss_func(out, y)
                l2.backward()
                optimizer.step()
                train_l2 += l2.item()

            train_l2 /= len(train_loader.dataset)
            scheduler.step()
            if valid_loader is not None:
                valid_l2 = self.test(valid_loader, **fit_params)["loss"]
            else:
                valid_l2 = train_l2

            saved = ""
            if valid_l2 < best_valid_l2:
                best_valid_l2 = valid_l2
                best_model_state_dict = deepcopy(self.state_dict())
                saved = "(saved)"

            print(f"Epoch {epoch}: Train loss={train_l2:.6f}, Validation loss={valid_l2:.6f} {saved}")

        if best_model_state_dict is not None:
            self.load_state_dict(best_model_state_dict)
        else:
            print("Warning: no epoch improved the validation loss; keeping current weights")
        train_l2 = self.test(train_loader, **fit_params)["loss"]
        if valid_loader is not None:
            valid_l2 = self.test(valid_loader, **fit_params)["loss"]
        else:
            valid_l2 = train_l2
        print(f"Finished training with best train loss: {train_l2:.6f} and validation loss: {valid_l2:.6f}")

    def forward(self, x):
        """Return ``self.base_model(x)`` unchanged (no projection here)."""
        return self.base_model(x)

    def parameters(self, recurse=True):
        """Return the base model's parameters.

        Accepts ``recurse`` like ``torch.nn.Module.parameters`` and yields
        nothing while the base model has not been created by ``fit``.
        """
        if self.base_model is None:
            return iter(())
        return self.base_model.parameters(recurse=recurse)




class ProbHardE2E(nn.Module):
    """Mean/variance model with optional end-to-end nonlinear projection.

    ``fit`` builds ``base_model_class(output_var=True, probconserv=False, ...)``.
    When ``noneq_constraint_e2e`` is true, ``fit`` and ``test`` pass the
    ReLU-clipped mean and the variance through ``project_and_stats_orth`` with
    ``self.full_residual`` as the constraint function.

    ``_apply_constraints`` is a separate utility that is not called from
    ``fit`` or ``test``.
    """

    # Monte-Carlo sample counts used by ``_apply_constraints``.
    PROBCONSERV_NUM_SAMPLES = 100
    NONLINEAR_NUM_SAMPLES = 5

    def __init__(self,
                 base_model_class,
                 base_model_params,
                 probconserv=False,
                 probconserv_sampling=False,
                 nonlinear_constraint="none",
                 constraint_context=None,
                 noneq_constraint_e2e=False,
                 warm_start_epoch = 0):
        """Store the configuration; the base model itself is created in ``fit``.

        Args:
            base_model_class: Callable used in ``fit`` to build the base model.
            base_model_params: Keyword arguments forwarded to that callable.
            probconserv: Enables ``apply_constraint`` in ``_apply_constraints``.
            probconserv_sampling: Selects the sample-based variant there.
            nonlinear_constraint: ``"none"``, ``"sampling"`` or
                ``"deltamethod"``; read by ``_apply_constraints``.
            constraint_context: Optional dict; forwarded to ``base_model_class``
                in ``fit`` and read by ``_apply_constraints``.
            noneq_constraint_e2e: Enables the projection in ``fit``/``test``.
            warm_start_epoch: Epochs before this index train with the projected
                mean but the unprojected variance.
        """
        super().__init__()
        self.base_model_class = base_model_class
        self.base_model_params = base_model_params
        self.loss_func = None
        self.base_model = None
        self.probconserv = probconserv
        self.nonlinear_constraint = nonlinear_constraint
        self.probconserv_sampling = probconserv_sampling
        self.constraint_context = constraint_context or {}
        self.noneq_constraint_e2e = noneq_constraint_e2e
        self.warm_start_epoch = warm_start_epoch
        self.update_sigma = False

    def _apply_constraints(self, mu, std, x):
        """Apply the linear and/or normalisation constraint to ``(mu, std)``.

        Args:
            mu: Mean with four axes, the last of size one.
            std: Standard deviation shaped like ``mu``.
            x: Model input; only used to build ``mass_rhs_func`` when
                ``self.probconserv`` is set.

        Returns:
            ``(mean, variance)``, each with a trailing singleton axis.

        Raises:
            NotImplementedError: If the base model's ``cov_type`` is not
                ``"diagonal"``.
        """
        t = self.constraint_context.get("t", None)
        tpred = self.constraint_context.get("tpred", None)
        grid_train = self.constraint_context.get("grid_train", None)
        dataset_class = self.constraint_context.get("dataset_class", None)

        if self.base_model.cov_type != "diagonal":
            raise NotImplementedError("Constraint logic currently supports only diagonal covariance.")

        if self.probconserv:
            if self.probconserv_sampling:
                n_samples = self.PROBCONSERV_NUM_SAMPLES

                # Draw samples, constrain each one, then re-estimate moments.
                dist = torch.distributions.Normal(mu, std)
                samples = dist.rsample((n_samples,))  # [nb, nf, nx, nt, 1]

                samples = rearrange(samples, "nb nf nx nt 1 -> (nb nf) nx nt 1")
                rep_std = repeat(std, "nf nx nt 1 -> (nb nf) nx nt 1", nb=n_samples)
                rep_x = repeat(x, "nf nx nt 1 -> (nb nf) nx nt 1", nb=n_samples)

                mass_rhs_func = dataset_class.get_mass_rhs_func(x=rep_x)

                new_samples, _, _, _ = apply_constraint(
                    mu=samples[..., 0],
                    std=rep_std[..., 0],
                    mass_rhs_func=mass_rhs_func,
                    t=t,
                    tpred=tpred,
                    grid_train=grid_train,
                    precis_g=np.inf,
                    second_deriv_alpha=None
                )

                new_samples = rearrange(new_samples, "(nb nf) nx nt -> nb nf nx nt", nb=n_samples)
                mu = new_samples.mean(dim=0)
                std = new_samples.std(dim=0)
            else:
                mass_rhs_func = dataset_class.get_mass_rhs_func(x=x)
                mu, std, _, _ = apply_constraint(
                    mu=mu[..., 0],
                    std=std[..., 0],
                    mass_rhs_func=mass_rhs_func,
                    t=t,
                    tpred=tpred,
                    grid_train=grid_train,
                    precis_g=np.inf,
                    second_deriv_alpha=None
                )
        else:
            mu = mu[..., 0]
            std = std[..., 0]

        if self.nonlinear_constraint == "none":
            return mu.unsqueeze(-1), torch.square(std).unsqueeze(-1)

        nf, nx, nt = mu.shape
        mu = mu.reshape(nf, nx * nt)
        std = std.reshape(nf, nx * nt)

        if self.nonlinear_constraint == "sampling":
            # Normalise each sample to unit norm, then take sample moments.
            dist = torch.distributions.Normal(mu, std)
            y_samples = dist.rsample((self.NONLINEAR_NUM_SAMPLES,))
            y_samples = y_samples / torch.norm(y_samples, dim=-1, keepdim=True)
            mu = y_samples.mean(dim=0)
            std = y_samples.std(dim=0)

        elif self.nonlinear_constraint == "deltamethod":
            # First-order propagation through g(y) = y / ||y||.  Its Jacobian
            # is (I - n n^T) / ||y|| with n = y / ||y||; the previous code
            # divided the outer product of the already-normalised vector by
            # ||y||^3, which is only correct when ||y|| == 1.
            norm_mu = torch.norm(mu, dim=1, keepdim=True)            # (nf, 1)
            unit = mu / norm_mu                                      # (nf, d)
            eye = torch.eye(unit.shape[1], device=unit.device, dtype=unit.dtype)
            outer = unit.unsqueeze(-1) @ unit.unsqueeze(1)           # (nf, d, d)
            J_g = (eye - outer) / norm_mu.unsqueeze(-1)
            # Diagonal of J diag(sigma^2) J^T.
            var_diag = torch.sum(J_g ** 2 * (std ** 2).unsqueeze(-2), dim=-1)
            std = torch.sqrt(torch.clamp(var_diag, min=1e-6))
            mu = unit

        mu = mu.reshape(nf, nx, nt)
        std = std.reshape(nf, nx, nt)
        return mu.unsqueeze(-1), torch.square(std).unsqueeze(-1)

    def _project_noneq(self, mu, var, x):
        """Run ``project_and_stats_orth`` on flattened per-sample tensors.

        The mean is clipped with ReLU before projection.  Returns the projected
        mean and variance reshaped to ``(nf, nx, nt, 1)``.
        """
        nf, nx, nt, _ = mu.shape
        u_proj, u_var = project_and_stats_orth(
            torch.relu(mu.reshape(nf, -1)),
            var.reshape(nf, -1),
            x.reshape(nf, -1),
            self.full_residual,
            max_iter=30,
        )

        mu_has_nan = u_proj.isnan().any().item()
        var_has_nan = u_var.isnan().any().item()
        if mu_has_nan or var_has_nan:
            print("any NaN in new_mu?", mu_has_nan)
            print("any NaN in new_var before clamp?", var_has_nan)

        return u_proj.reshape(nf, nx, nt, 1), u_var.reshape(nf, nx, nt, 1)

    def test(self, test_loader, **test_params):
        """Evaluate the summed loss over ``test_loader`` divided by the dataset size.

        Args:
            test_loader: Iterable of ``(x, y)`` batches with a ``dataset``
                attribute supporting ``len``.
            **test_params: Accepted for interface compatibility; not read.

        Returns:
            ``{"loss": value}``.
        """
        self.eval()
        test_l2 = 0.0

        with torch.no_grad():
            for x, y in test_loader:
                x, y = x.to(device), y.to(device)
                out = self(x)

                if self.noneq_constraint_e2e:
                    mu, var = out
                    out = self._project_noneq(mu, var, x)

                test_l2 += self.loss_func(out, y).item()

        test_l2 /= len(test_loader.dataset)
        return {"loss": test_l2}

    def conservation_residual_all_times(self, u_flat, m_flat):
        """Residual of the integrated mass balance at every time after the first.

        ``u_flat`` and ``m_flat`` are viewed as ``(nx, nt)`` where ``nx`` is the
        length of ``base_model.grid_train`` and ``nt`` the length of
        ``base_model.t[slice(*base_model.tpred)]``.  Uniform grid spacing is
        taken from the first two grid/time entries.

        Returns:
            Tensor of ``nt - 1`` values:
            ``mass(t_k) - mass(t_0) + cumulative_flux(t_k)`` for ``k >= 1``.
        """
        grid = self.base_model.grid_train
        nx = len(grid)
        dx = grid[1] - grid[0]
        t_grid = self.base_model.t[slice(*self.base_model.tpred)]
        nt = len(t_grid)
        dt = t_grid[1] - t_grid[0]
        u = u_flat.view(nx, nt)
        m = m_flat.view(nx, nt)

        def trapz_space(values):
            # Trapezoidal rule along axis 0 (half weight on both end points).
            weight = torch.ones_like(values)
            weight[0] *= 0.5
            weight[-1] *= 0.5
            return torch.sum(weight * values, dim=0) * dx

        # Initial mass
        mass_0 = trapz_space(u[:, 0])  # scalar

        # One-sided differences at both ends, written as (inner-left minus
        # inner-right) / dx, i.e. the *negative* of the usual slope estimate.
        um = u ** m
        ux_left = (u[0, :] - u[1, :]) / dx
        ux_right = (u[-2, :] - u[-1, :]) / dx

        flux_diff = um[-1, :] * ux_right - um[0, :] * ux_left  # shape: (nt,)

        # Left Riemann sum: net flux accumulated up to each time step, with
        # zero accumulated flux at the first time.
        increments = flux_diff[:-1] * dt
        flux_increments = torch.cat([increments.new_zeros(1), increments], dim=0)
        net_flux = torch.cumsum(flux_increments, dim=0)

        # Mass at each time
        mass_t = trapz_space(u)

        # Residual = mass_t - mass_0 + net_flux
        residue = mass_t - mass_0 + net_flux

        return residue[1:]  # skip the first time, where it is zero by construction

    def ic_residual(self, u_flat, u0):
        """Return the interior values of ``u`` at the first time index.

        ``u0`` is accepted for a signature parallel to the other residuals but
        is not used: the returned values themselves are the residual.

        Returns:
            Tensor of ``nx - 2`` values (both boundary points are excluded).
        """
        nx = len(self.base_model.grid_train)
        t_grid = self.base_model.t[slice(*self.base_model.tpred)]
        nt = len(t_grid)
        u = u_flat.view(nx, nt)           # shape: (nx, nt)
        return u[1:-1, 0]                 # shape: (nx - 2,)

    def bc_residual_dirichlet(self, u_flat, m_flat):
        """Boundary residuals at the first and last spatial points.

        The left target is ``(m * t) ** (1 / m)`` with ``m = m_flat[0]``; the
        right target is zero.

        Returns:
            Tensor of ``2 * nt`` values: left residuals followed by right ones.
        """
        nx = len(self.base_model.grid_train)
        t_grid = self.base_model.t[slice(*self.base_model.tpred)]
        nt = len(t_grid)
        u = u_flat.view(nx, nt)
        m = m_flat[0]  # single exponent, broadcast over time

        # Keep the time grid on the same device as the solution.
        t_grid = torch.as_tensor(t_grid, device=u.device)
        left_bc_target = (m * t_grid) ** (1.0 / m)  # shape: (nt,)

        h_left = u[0, :] - left_bc_target           # (nt,)
        h_right = u[-1, :]                          # (nt,)

        return torch.cat([h_left, h_right], dim=0)  # shape: (2 * nt,)

    def full_residual(self, u_flat, m_flat):
        """Concatenate the conservation, initial-condition and boundary residuals."""
        return torch.cat([
            self.conservation_residual_all_times(u_flat, m_flat),  # (nt - 1,)
            self.ic_residual(u_flat, m_flat),                      # (nx - 2,)
            self.bc_residual_dirichlet(u_flat, m_flat),            # (2 * nt,)
        ])

    def fit(self, train_loader, valid_loader, **fit_params):
        """Build the base model and train it, keeping the best-validation weights.

        Args:
            train_loader: Iterable of ``(x, y)`` training batches.
            valid_loader: Validation loader, or ``None`` to reuse the training
                loss for model selection.
            **fit_params: Optional ``lr`` (1e-3), ``step_size`` (50),
                ``gamma`` (0.5) and ``epochs`` (200).  The whole dict is also
                forwarded to ``test``.
        """
        self.base_model = self.base_model_class(
            output_var=True,
            probconserv=False,  # Disable internal constraints
            **self.base_model_params,
            **self.constraint_context
        ).to(device)

        self.loss_func = self.base_model.loss_func

        lr = fit_params.get("lr", 1e-3)
        step_size = fit_params.get("step_size", 50)
        gamma = fit_params.get("gamma", 0.5)
        epochs = fit_params.get("epochs", 200)

        optimizer = torch.optim.Adam(self.base_model.parameters(), lr=lr, weight_decay=1e-4)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

        best_valid_l2 = np.inf
        # Stays None when no epoch beats ``inf`` (e.g. NaN losses, epochs=0).
        best_model_state_dict = None

        self.update_sigma = False

        for epoch in range(epochs):
            self.train()
            train_l2 = 0
            for x, y in train_loader:
                x, y = x.to(device), y.to(device)
                optimizer.zero_grad()
                out = self.base_model(x)

                if self.noneq_constraint_e2e:
                    mu, var = out
                    u_proj, u_var = self._project_noneq(mu, var, x)

                    if epoch < self.warm_start_epoch:
                        # Warm start: projected mean, unprojected variance.
                        out = (u_proj, var.reshape(u_proj.shape))
                    else:
                        if not self.update_sigma:
                            print("Updating Sigma now")
                        self.update_sigma = True
                        out = (u_proj, u_var)

                l2 = self.loss_func(out, y)
                l2.backward()
                optimizer.step()
                train_l2 += l2.item()

            train_l2 /= len(train_loader.dataset)
            scheduler.step()
            if valid_loader is not None:
                valid_l2 = self.test(valid_loader, **fit_params)["loss"]
            else:
                valid_l2 = train_l2

            saved = ""
            if valid_l2 < best_valid_l2:
                best_valid_l2 = valid_l2
                best_model_state_dict = deepcopy(self.state_dict())
                saved = "(saved)"

            print(f"Epoch {epoch}: Train loss={train_l2:.6f}, Validation loss={valid_l2:.6f} {saved}")

        if best_model_state_dict is not None:
            self.load_state_dict(best_model_state_dict)
        else:
            print("Warning: no epoch improved the validation loss; keeping current weights")
        train_l2 = self.test(train_loader, **fit_params)["loss"]
        if valid_loader is not None:
            valid_l2 = self.test(valid_loader, **fit_params)["loss"]
        else:
            valid_l2 = train_l2
        print(f"Finished training with best train loss: {train_l2:.6f} and validation loss: {valid_l2:.6f}")

    def forward(self, x):
        """Return ``self.base_model(x)`` unchanged (no projection here)."""
        return self.base_model(x)

    def parameters(self, recurse=True):
        """Return the base model's parameters.

        Accepts ``recurse`` like ``torch.nn.Module.parameters`` and yields
        nothing while the base model has not been created by ``fit``.
        """
        if self.base_model is None:
            return iter(())
        return self.base_model.parameters(recurse=recurse)



# Script entry point. Running this module directly is intentionally a no-op;
# the commented lines below are a usage sketch kept for reference only.
if __name__ == '__main__':
    pass
    # # Usage sketch (inactive): build a loader, wrap a base model in one of
    # # the UQ wrappers, then fit / predict / evaluate.
    # from models.FNO2d import FNO2d
    # x_train = torch.rand(1, 100, 20, 1)
    # y_train = torch.rand(1, 100, 20, 1)
    # train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train, y_train), 
    #                                             batch_size=20, shuffle=True)

    # FNO2d_params = {"modes1": 12, "modes2": 12, "width": 32}
    # # Each assignment below overwrites the previous one; keep only the wrapper to try.
    # uq_model = BayesianNO(FNO2d, FNO2d_params)
    # uq_model = OutputVar(FNO2d, FNO2d_params)
    # uq_model = EnsembleNO(FNO2d, FNO2d_params, n_models=10)
    # uq_model = MCDropoutNO(FNO2d, FNO2d_params, dropout=0.1, n_dropouts=30)

    # uq_model.fit(train_loader, train_loader, epochs=10)
    # mu, var = uq_model(x_train.to(device))
    # print(mu.shape, var.shape)
    # results = uq_model.test(train_loader)
    # print(results)

