# models/model.py

import random
import warnings

import numpy as np
import torch
import torch.autograd.profiler as profiler
import torch.nn as nn
import time
from torch.func import jacfwd, jacrev, vmap

from configs.config import (
    EPOCH_START_HARD_CONSTRAINED,
    HIDDEN_LAYERS,
    HIDDEN_NEURONS,
    INFERENCE_TOLERANCE,
    INPUT_NEURONS,
    MAX_IT,
    OUTPUT_NEURONS,
    SOFT_CONSTRAINED,
    SUPERVISED,
    TRAINING_TOLERANCE,
    VERBOSE,
    WEIGHT_LOSS_DISPLACEMENT,
    WEIGHT_LOSS_SOFT,
)


class ENFORCE(nn.Module):
    def __init__(
        self,
        scaling_input,
        scaling_output,
        c,
        constrained=True,
        weighting_option=5,
        random_seed=42,
        ssl_loss=None,
        jac=None,
    ):
        """Build the network and store the constraint/scaling configuration.

        Args:
            scaling_input: Indexable whose item 0 is the input mean and item 1
                the input standard deviation (anything ``torch.as_tensor``
                accepts).
            scaling_output: Same layout as ``scaling_input``, for the outputs.
            c: Callable ``c(x, y)`` evaluated on unscaled data; its result is
                normalised by the ``c`` method of this class.
            constrained: When true, ``predict`` applies the projection once
                ``self.epoch`` reaches ``EPOCH_START_HARD_CONSTRAINED``.
            weighting_option: Selects the weighting matrix built by ``Wi_f``
                (1, 3, 5 or 6).
            random_seed: Seed applied to the torch, numpy and ``random`` global
                generators.
            ssl_loss: Callable ``ssl_loss(x_unscaled, y_unscaled)`` used when
                ``SUPERVISED`` is false.
            jac: Optional callable that receives the unscaled outputs and
                returns the constraint Jacobian; when falsy the Jacobian is
                obtained with autograd.
        """
        # Seed every global RNG before the layers are created so that weight
        # initialisation is reproducible.
        torch.manual_seed(random_seed)
        np.random.seed(random_seed)
        random.seed(random_seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        torch.cuda.manual_seed_all(random_seed)
        super().__init__()

        self.input_layer = nn.Linear(INPUT_NEURONS, HIDDEN_NEURONS)
        # NOTE: a single hidden layer module; ``forward`` applies it
        # HIDDEN_LAYERS times, so its weights are shared across those passes.
        self.hidden_layer = nn.Linear(HIDDEN_NEURONS, HIDDEN_NEURONS)
        self.hidden_activation = nn.ReLU()
        self.output_layer = nn.Linear(HIDDEN_NEURONS, OUTPUT_NEURONS)
        self.loss_function = nn.MSELoss()
        self.ssl_loss = ssl_loss
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.losses = []
        self.jac = jac

        self.ni = INPUT_NEURONS
        self.no = OUTPUT_NEURONS

        # Scaling statistics are registered as non-persistent buffers: they are
        # still reachable as ``self.mean_input`` etc., they stay out of
        # ``state_dict`` (checkpoints are unchanged), and they now follow the
        # module through ``.to()`` / ``.cuda()`` / ``.cpu()`` instead of being
        # pinned to the device picked at construction time.
        scaling_stats = {
            "mean_input": scaling_input[0],
            "std_input": scaling_input[1],
            "mean_output": scaling_output[0],
            "std_output": scaling_output[1],
        }
        for stat_name, stat_value in scaling_stats.items():
            self.register_buffer(
                stat_name,
                torch.as_tensor(stat_value, device=self.device, dtype=torch.float32),
                persistent=False,
            )

        self._c = c
        self.constrained = constrained
        self.weighting_option = weighting_option

        # Training-progress bookkeeping read by ``predict``.
        self.epoch = 1
        self.training_iter = 0
        self.start_projection = False
        self._epochprint = None

    def check_system(self):
        # Issue a warning if the system is determined
        if self.nc == self.no:
            warnings.warn(
                "The system is determined. Maybe you already know your underlying model!"
            )

        # Ensure the number of constraints is not greater than the number of outputs
        try:
            assert self.nc <= self.no
        except AssertionError:
            raise ValueError("Too many constraints!")

    def c(self, x, y):
        ci = self._c(x, y)
        if isinstance(ci, tuple):
            self.nc = len(ci)
            self.check_system()
            return torch.stack(ci, dim=1)
        elif isinstance(ci, torch.Tensor) and ci.ndim == 1:
            self.nc = 1  # Since c returns a tensor of shape [BS]
            self.check_system()
            return ci.unsqueeze(1)
        elif (
            isinstance(ci, torch.Tensor) and ci.ndim > 1
        ):  # return a tensor of shape [BS, NC]
            self.nc = ci.shape[1]  # Number of constraints
            self.check_system()
            return ci

    def forward(self, x):
        """Return the raw (unprojected) network output for ``x``.

        The input layer and activation are applied once, then the hidden
        layer and activation are applied ``HIDDEN_LAYERS`` times, then the
        output layer.
        """
        activate = self.hidden_activation
        hidden = activate(self.input_layer(x))
        for _ in range(HIDDEN_LAYERS):
            # The same ``hidden_layer`` module is reused on every pass.
            hidden = activate(self.hidden_layer(hidden))
        return self.output_layer(hidden)

    def loss(self, x, y):
        """Compute the training loss for one batch.

        Runs ``predict`` in training mode and combines:
        the displacement between raw and projected outputs (weighted by
        ``WEIGHT_LOSS_DISPLACEMENT``), an optional mean-absolute constraint
        penalty (``SOFT_CONSTRAINED``), and either the MSE after projection
        (``SUPERVISED``) or ``self.ssl_loss`` on the unscaled prediction.

        Returns:
            Tuple ``(loss, mse_after_projection, displacement,
            mse_before_projection, ytilde, yhat, proj_iter)``.
        """
        ytilde, yhat, proj_iter = self.predict(x, y, training=True)

        mse_before = self.loss_function(y, yhat)
        mse_after = self.loss_function(y, ytilde)
        # Zero whenever the projection was skipped or rejected (ytilde is yhat).
        displacement = torch.mean((yhat - ytilde) ** 2)

        total = WEIGHT_LOSS_DISPLACEMENT * displacement

        x_phys, ytilde_phys = self.unscale(x, ytilde)
        if SOFT_CONSTRAINED:
            residuals = self.c(x=x_phys, y=ytilde_phys)
            total += WEIGHT_LOSS_SOFT * torch.mean(torch.abs(residuals))

        if SUPERVISED:
            total += mse_after
        else:
            total += self.ssl_loss(x_phys, ytilde_phys)

        return (
            total,
            mse_after,
            displacement,
            mse_before,
            ytilde,
            yhat,
            proj_iter,
        )

    def compute_residual(self, c, tolerance_mode):
        if tolerance_mode == "mean":
            tolerance_value = torch.mean(torch.abs(c))
        elif tolerance_mode == "max":
            tolerance_value = torch.max(torch.abs(c))
        else:
            raise ValueError("Invalid tolerance mode. Choose 'mean' or 'max'.")
        return tolerance_value

    def ada_np(
        self,
        x,
        y,
        tolerance_mode="mean",
        tolerance_value=TRAINING_TOLERANCE,
        max_iter=MAX_IT,
    ):
        """Repeat the projection until the constraint residual is small enough.

        Each iteration recomputes the constraint Jacobian at the current
        point (``compute_dc_dy``), projects (``project``) and re-evaluates the
        residual on the unscaled result.

        Args:
            x: Scaled inputs.
            y: Scaled outputs to refine.
            tolerance_mode: Reduction passed to ``compute_residual``
                (``"mean"`` or ``"max"``).
            tolerance_value: Iteration stops once the residual is not greater
                than this value.
            max_iter: Upper bound for the iteration counter.

        Returns:
            Tuple ``(y, proj_iter)``. The counter starts at 1 (the caller in
            ``predict`` has already performed one projection) and is
            incremented once per projection done here.
        """
        proj_iter = 1
        input_unscaled, output_unscaled = self.unscale(x, y)
        c = self.c(x=input_unscaled, y=output_unscaled)
        c_res = self.compute_residual(c, tolerance_mode)

        while c_res > tolerance_value and proj_iter < max_iter:
            # Re-linearise the constraints around the current iterate.
            self.compute_dc_dy(x, y)
            y = self.project(x, y)
            proj_iter += 1
            input_unscaled, output_unscaled = self.unscale(x, y)
            c = self.c(x=input_unscaled, y=output_unscaled)
            c_res = self.compute_residual(c, tolerance_mode)

        # Report against the limit actually in effect for this call rather
        # than the module-level default.
        if proj_iter >= max_iter:
            print(f"Max projection iteration reached ({max_iter})")
        return y, proj_iter

    def predict(self, x, y=None, training=False):
        """Run the network and, when enabled, project its output onto the constraints.

        The projection is active when ``self.constrained`` is true and
        ``self.epoch >= EPOCH_START_HARD_CONSTRAINED``; otherwise the raw
        network output is returned for both ``ytilde`` and ``yhat``.

        When active, one projection is always performed, then:

        * inference (``training=False``): ``ada_np`` refines with the ``"max"``
          residual and ``INFERENCE_TOLERANCE``;
        * supervised training: the projection is kept (and refined with
          ``ada_np``) only if it does not increase the MSE against ``y``;
        * self-supervised training with ``VERBOSE``: the projection is kept or
          rejected by comparing ``ssl_loss`` (plus the soft-constraint penalty
          when ``SOFT_CONSTRAINED``) before and after projection;
        * self-supervised training without ``VERBOSE``: ``ada_np`` is always
          applied.

        Args:
            x: Scaled inputs.
            y: Scaled reference outputs; used in training mode only.
            training: Selects the training or inference behaviour above.

        Returns:
            Tuple ``(ytilde, yhat, proj_iter)``: the (possibly projected)
            prediction, the raw network output, and the projection counter
            (0 when the projection is inactive).
        """
        yhat = self.forward(x)

        hard_constrained = (
            self.constrained and self.epoch >= EPOCH_START_HARD_CONSTRAINED
        )
        if not hard_constrained:
            # NOTE(review): training_iter is only advanced on this
            # unconstrained path — confirm whether that is intended.
            if training:
                self.training_iter += 1
            return yhat, yhat, 0

        # The first projection is always performed.
        self.compute_dc_dy(x, yhat)
        ytilde = self.project(x, yhat)
        proj_iter = 1

        if not training:
            ytilde, proj_iter = self.ada_np(
                x,
                ytilde,
                tolerance_mode="max",
                tolerance_value=INFERENCE_TOLERANCE,
                max_iter=MAX_IT,
            )
            print(f"Projection iterations inference: {proj_iter}")
            return ytilde, yhat, proj_iter

        if SUPERVISED:
            loss_data_before_projection = self.loss_function(y, yhat)
            loss_data_after_projection = self.loss_function(y, ytilde)

            if loss_data_before_projection < loss_data_after_projection:
                # The projection made the fit worse: discard it.
                ytilde = yhat
                if self.start_projection:
                    self.start_projection = False
                    print(
                        f"Stop projection: Epoch {self.epoch} Iteration: {self.training_iter} Loss before: {loss_data_before_projection}"
                    )
            else:
                if not self.start_projection:
                    print(
                        f"Start projection: Epoch {self.epoch} Iteration: {self.training_iter} Loss before: {loss_data_before_projection} Loss after: {loss_data_after_projection}"
                    )
                ytilde, proj_iter = self.ada_np(
                    x,
                    ytilde,
                    tolerance_mode="mean",
                    tolerance_value=TRAINING_TOLERANCE,
                    max_iter=MAX_IT,
                )
                self.start_projection = True
            return ytilde, yhat, proj_iter

        if not VERBOSE:
            # NOTE(review): without VERBOSE the projection is never rejected,
            # so this flag changes training behaviour, not only logging.
            ytilde, proj_iter = self.ada_np(
                x,
                ytilde,
                tolerance_mode="mean",
                tolerance_value=TRAINING_TOLERANCE,
                max_iter=MAX_IT,
            )
            self.start_projection = True
            return ytilde, yhat, proj_iter

        # ---- self-supervised training with VERBOSE ----
        x_unscaled, yhat_unscaled = self.unscale(x, yhat)
        _, ytilde_unscaled = self.unscale(x, ytilde)
        ssl_loss_hat = self.ssl_loss(x_unscaled, yhat_unscaled)
        ssl_loss_tilde = self.ssl_loss(x_unscaled, ytilde_unscaled)

        # Decide once whether this call emits the per-epoch report; the marker
        # is updated only after every report has had its chance to print
        # (updating it in between used to silence the soft-constraint report).
        report_epoch = self.epoch % 1 == 0 and self.epoch != self._epochprint

        if report_epoch:
            avg_pred_obj_value = torch.mean(ssl_loss_tilde)
            x_u, y_u = self.unscale(x, y)
            avg_computed_obj_value = torch.mean(self.ssl_loss(x_u, y_u))
            print(
                f"Epoch {self.epoch} Average computed obj value: {avg_computed_obj_value:.3f} Average predicted obj value: {avg_pred_obj_value:.3f}"
            )
            print(
                f"Difference objective: {(avg_pred_obj_value - avg_computed_obj_value):.3f}"
            )
            print(
                f"Relative difference objective: {((avg_pred_obj_value - avg_computed_obj_value) / avg_computed_obj_value * 100):.3f}%"
            )

        c_tilde = None
        if SOFT_CONSTRAINED:
            c_hat = self.c(x=x_unscaled, y=yhat_unscaled)
            c_tilde = self.c(x=x_unscaled, y=ytilde_unscaled)
            if report_epoch:
                print(
                    f"Epoch {self.epoch} Soft constraint loss before projection: {torch.mean(torch.abs(c_hat)):.3f} after projection: {torch.mean(torch.abs(c_tilde)):.3f}"
                )
            ssl_loss_hat += WEIGHT_LOSS_SOFT * torch.mean(torch.abs(c_hat))
            ssl_loss_tilde += WEIGHT_LOSS_SOFT * torch.mean(torch.abs(c_tilde))

        if report_epoch:
            self._epochprint = self.epoch

        if ssl_loss_hat < ssl_loss_tilde:
            # The projection made the objective worse: discard it.
            ytilde = yhat
            if self.start_projection:
                self.start_projection = False
                print(
                    f"Stop projection: Epoch {self.epoch} Iteration: {self.training_iter} Loss before: {ssl_loss_hat}"
                )
        elif ssl_loss_hat > ssl_loss_tilde:
            # The residual is needed here even when the soft penalty is off;
            # evaluate it on demand (it used to be undefined in that case).
            if c_tilde is None:
                c_tilde = self.c(x=x_unscaled, y=ytilde_unscaled)
            if (
                self.compute_residual(c_tilde, tolerance_mode="mean")
                > TRAINING_TOLERANCE
            ):
                if not self.start_projection:
                    print(
                        f"Start projection: Epoch {self.epoch} Iteration: {self.training_iter} Loss before: {ssl_loss_hat} Loss after: {ssl_loss_tilde}"
                    )
                ytilde, proj_iter = self.ada_np(
                    x,
                    ytilde,
                    tolerance_mode="mean",
                    tolerance_value=TRAINING_TOLERANCE,
                    max_iter=MAX_IT,
                )
                self.start_projection = True

        return ytilde, yhat, proj_iter

    def unscale(self, x_scaled, y_scaled):
        with profiler.record_function("unscale"):
            x = x_scaled * self.std_input + self.mean_input
            if (self.std_output >= 1e-5).all():
                y = y_scaled * self.std_output + self.mean_output
            else:
                # Avoid division by zero if std is zero
                y = y_scaled + self.mean_output
        return x, y

    def scale(self, x_unscaled, y_unscaled):
        x = (x_unscaled - self.mean_input) / self.std_input
        if (self.std_output >= 1e-5).all():
            y = (y_unscaled - self.mean_output) / self.std_output
        else:
            # Avoid division by zero if std is zero
            y = y_unscaled - self.mean_output
        # y = (y_unscaled - self.mean_output) / self.std_output
        return x, y

    def compute_dc_dy(self, input, output):
        with profiler.record_function("compute_dc_dy"):
            if self.jac:
                input_unscaled, output_unscaled = self.unscale(input, output)
                self.bs = input_unscaled.shape[0]
                self.nc = input_unscaled.shape[1]
                self.dc_dy = self.jac(output_unscaled)
                return self.dc_dy
            else:
                input_unscaled, output_unscaled = self.unscale(input, output)
                c = self.c(x=input_unscaled, y=output_unscaled)  # Shape: [BS, NC]

                self.bs = c.size(0)
                self.nc = c.size(1)

                # Initialize the Jacobian tensor: [BS, num_constraints, num_outputs]
                dc_dy = torch.zeros(
                    self.bs,
                    self.nc,
                    self.no,
                    dtype=output_unscaled.dtype,
                    device=output_unscaled.device,
                )

                for i in range(self.nc):  # Loop over constraint components
                    # Create grad_outputs tensor to select c_i
                    grad_outputs = torch.zeros_like(c)  # Shape: [BS, NC]
                    grad_outputs[:, i] = 1  # Set grad_outputs for c_i

                    # Compute gradients: grad_c_i_wrt_y = [BS, NO]

                    grad_c_i_wrt_y = torch.autograd.grad(
                        outputs=c,
                        inputs=output_unscaled,
                        grad_outputs=grad_outputs,
                        create_graph=False,
                        retain_graph=True,
                    )[0]  # Shape: [BS, NO]

                    # Handle the case where grad might be None
                    if grad_c_i_wrt_y is None:
                        grad_c_i_wrt_y = torch.zeros_like(output_unscaled)

                    # Store gradients in the Jacobian tensor
                    dc_dy[:, i, :] = grad_c_i_wrt_y  # Shape: [BS, NO]

                self.dc_dy = dc_dy  # Shape: [BS, NC, NO]
                return self.dc_dy

    # def compute_dc_dy(self, input, output,
    #               nc_loop_threshold: int = 5,   # tweak if you like
    #               bs_loop_threshold: int = 2000):
    #     with profiler.record_function("compute_dc_dy"):
    #         # ---------- 1. unscale ----------------------------------------------------
    #         x, y = self.unscale(input, output)          # x:[BS,…], y:[BS,NO]
    #         bs, no = y.shape
    #         nc = self.c(x[:1], y[:1]).shape[-1]         # cheap probe (no grad)

    #         # ---------- 2. choose implementation -------------------------------------
    #         if nc <= nc_loop_threshold and bs <= bs_loop_threshold:
    #             # ---- Python loop ----------------------------------------------------
    #             c = self.c(x, y)                        # [BS, NC]
    #             dc_dy = torch.zeros(bs, nc, no, dtype=y.dtype, device=y.device)

    #             for i in range(nc):
    #                 grad_out = torch.zeros_like(c)
    #                 grad_out[:, i] = 1
    #                 grad = torch.autograd.grad(
    #                         c, y, grad_out,
    #                         retain_graph=False,
    #                         create_graph=False)[0]     # [BS, NO]
    #                 dc_dy[:, i, :] = grad

    #         else:
    #             # ---- functorch path --------------------------------------------------
    #             # pick cheaper AD direction: forward if NO ≥ NC
    #             jac  = jacfwd if no >= nc else jacrev

    #             def per_sample(x_i, y_i):               # y_i: [NO] → [NC, NO]
    #                 def c_wrap(y_single):
    #                     return self.c(x_i[None], y_single[None])[0]
    #                 return jac(c_wrap)(y_i)

    #             dc_dy = vmap(per_sample)(x, y)          # [BS, NC, NO]

    #         # ---------- 3. stash & return --------------------------------------------
    #         self.bs, self.nc, self.dc_dy = bs, nc, dc_dy
    #         return dc_dy

    def B_f(self):
        B = self.dc_dy  # Shape: [BS, NC, NO]
        return B

    def vi_f(self, input_unscaled, output_unscaled):
        # one fused kernel: for each batch b and constraint i,
        # sum_k  dc_dy[b,i,k] * output_unscaled[b,k]
        vi_lin = torch.einsum("bik,bk->bi", self.dc_dy, output_unscaled)  # [BS,NC]
        vi = vi_lin - self.c(input_unscaled, output_unscaled)  # elementwise subtract
        self.vi = vi
        return vi

    def Wi_f(self):
        # shapes and dtypes
        bs, nc, no = self.dc_dy.shape  # NC == NO
        device = self.dc_dy.device
        dtype = self.dc_dy.dtype

        if self.weighting_option == 1:
            return None
            # Wi = I_{NO} for each batch
            # → allocate zeros and write 1s on the diag only (BS*NO writes)
            # Wi = torch.zeros(bs, no, no, device=device, dtype=dtype)
            # Wi.diagonal(dim1=1, dim2=2).fill_(1.0)
            # return Wi

        elif self.weighting_option == 6:
            # random in [0.1, 1.0):
            w = 0.9 * torch.rand(bs, no, device=device, dtype=dtype) + 0.1  # [BS, NO]
            Wi = torch.zeros(bs, no, no, device=device, dtype=dtype)
            Wi.diagonal(dim1=1, dim2=2).copy_(w)  # copy only BS*NO elements
            return Wi

        elif self.weighting_option == 5:
            # instance‐dependent
            # 1) mean abs derivative per constraint: [BS, NO]
            d = self.dc_dy.abs().mean(dim=1)

            # 2) replace zeros with 1.0 in-place
            d = d.clamp_min(1.0)

            # 3) invert & normalize by the per‐batch min: [BS, NO]
            inv = 1.0 / d
            inv = inv / inv.min(dim=1, keepdim=True).values

            # 4) fill diag
            Wi = torch.zeros(bs, no, no, device=device, dtype=dtype)
            Wi.diagonal(dim1=1, dim2=2).copy_(inv)
            return Wi

        elif self.weighting_option == 3:
            # batch‐averaged
            m = self.dc_dy.abs().mean(dim=(0, 1))  # [NO]
            m = m.clamp_min(1.0)  # replace zero
            inv = 1.0 / m  # [NO]
            inv = inv / inv.min()  # normalize
            # build single [NO, NO] diag matrix via broadcast‐mul
            I = torch.eye(no, device=device, dtype=dtype)
            Wi = I * inv  # broadcast inv over diag only
            return Wi

        else:
            raise ValueError(f"Unknown weighting_option: {self.weighting_option}")

    def projection_tensors(self, B, v, W_inv=None):
        with profiler.record_function("build_projection_tensors"):
            if self.weighting_option == 1:
                # Just compute BWB_T = B @ B^T
                # 1. make B^T contiguous once (avoids per-slice copy in bmm)
                B_T = B.transpose(1, 2).contiguous()
                # 2. compute BWB^T via bmm
                BWB_T = torch.bmm(B, B_T)  # [BS, NO, NO]
                # 3. invert BWB_T via cholesky (quite fast 0.005998 seconds)
                # tinv = time.time()
                ch = torch.linalg.cholesky(BWB_T)  # [BS, NO, NO]
                mid_inv = torch.cholesky_inverse(ch)
                # print(f"Cholesky time: {time.time()-tinv:.6f} seconds")
                # 4. form M = B^T @ (BWB_T)^-1
                #    B_T assumed contiguous from caller
                # t = time.time()
                M = torch.bmm(B_T, mid_inv)  # [BS, NO, NO]
                MB = torch.bmm(M, B)
                # print(f"Form MB time: {time.time()-t:.6f} seconds")

            else:
                # 1. make B^T contiguous once (avoids per-slice copy in bmm)
                B_T = B.transpose(1, 2).contiguous()  # [BS, NO, NC]

                # 2. compute BWB^T via bmm
                BWB = torch.bmm(B, W_inv)  # [BS, NC, NO]
                eps = 1e-8
                I_nc = torch.eye(self.nc, device=B.device)  # [NC, NC]
                # baddbmm: for each batch: eps*I_nc + BWB[i] @ B_T[i]
                BWB_T = torch.baddbmm(eps * I_nc, BWB, B_T)  # [BS, NC, NC], contiguous

                # 3. cholesky + inverse
                ch = torch.linalg.cholesky(BWB_T)  # [BS, NC, NC]
                mid_inv = torch.cholesky_inverse(ch)  # [BS, NC, NC]

                # 4. form M = W_inv @ B^T @ (BWB_T)^-1
                #    W_inv assumed contiguous from caller
                W_inv_B_T = torch.bmm(W_inv, B_T)  # [BS, NO, NC]
                M = torch.bmm(W_inv_B_T, mid_inv)  # [BS, NO, NC]

            # # 3) cholesky factor
            # ch = torch.linalg.cholesky(BWB_T)             # [BS, NC, NC]

            # # 4) solve for M without ever inverting explicitly
            # #    we want X such that BWB_T @ X = (W_inv @ B^T)^T
            # Winv_B_T = torch.bmm(W_inv, B_T)              # [BS, NO, NC]
            # b         = Winv_B_T.transpose(1, 2).contiguous()  # [BS, NC, NO]
            # X         = torch.cholesky_solve(b, ch)       # [BS, NC, NO]
            # M         = X.transpose(1, 2).contiguous()    # [BS, NO, NC]

            # 5. compute B_star = I – M @ B without full identity repeat
            MB = torch.bmm(M, B)  # [BS, NO, NO]
            B_star = -MB  # [BS, NO, NO]
            # add 1 to each diagonal element in-place (only BS*NO writes)
            B_star.diagonal(dim1=1, dim2=2).add_(1.0)

            # 6. compute v_star
            v_in = v.unsqueeze(2).contiguous()  # [BS, NO, 1]
            v_star = torch.bmm(M, v_in).squeeze(2)  # [BS, NO]

            return B_star, v_star

    def project(self, input, output):
        # 1) unscale once
        input_unscaled, output_unscaled = self.unscale(input, output)

        # 2) build (and invert) Wi on correct device
        self.Wi = self.Wi_f()  # should already be on self.device
        if self.weighting_option == 1:
            W_inv = None
        else:
            Wi_inv = torch.inverse(self.Wi)

            # 3) turn any 2D Wi_inv into a true [BS,NO,NO] tensor
            if Wi_inv.dim() <= 2:
                # unsqueeze+repeat gives a contiguous [BS,NO,NO]
                W_inv = Wi_inv.unsqueeze(0).repeat(self.bs, 1, 1)
            else:
                # if it’s already [BS,NO,NO], just make it contiguous once
                W_inv = Wi_inv.contiguous()

        # 4) grab B and v, and force contiguity up-front
        B = self.B_f().contiguous()  # [BS,NC,NO]
        v = self.vi_f(input_unscaled, output_unscaled).contiguous()  # [BS,NO]

        # 5) do the heavy lifting—our optimized routine already returns
        #    contiguous WB_star, Wv_star
        # t = time.time()
        WB_star, Wv_star = self.projection_tensors(B, v, W_inv=W_inv)
        # print(f"Projection tensors time: {time.time()-t:.6f} seconds")

        # 6) apply projection: make output_unscaled[:, :, None] contiguous once
        y_in = output_unscaled.unsqueeze(2).contiguous()  # [BS,NO,1]
        y_proj = torch.bmm(WB_star, y_in).squeeze(2) + Wv_star

        # 7) re-scale and return
        _, y_proj_scaled = self.scale(input_unscaled, y_proj)
        return y_proj_scaled
