import torch
import torch.nn.functional as F

class LinearVelocity:
        """Linear interpolation path between a limit distribution and one-hot targets.

        For node classes the conditional marginal is
        ``p_t = t * onehot(x1) + (1 - t) * limit_dist.X`` (and likewise for edge
        classes with ``limit_dist.E``). The methods below evaluate this marginal,
        its time derivative, and rate terms built from both.

        Shape comments follow the convention already used in this file:
        ``bs`` = batch, ``n`` = nodes, ``dx`` / ``de`` = node / edge classes.
        """

        def __init__(self, cfg, limit_dist, device):
                """Store sampling hyper-parameters and the limit distribution.

                Args:
                        cfg: Config object exposing ``cfg.sample.omega`` and
                                ``cfg.sample.eta``.
                        limit_dist: Object exposing ``X`` and ``E`` (class
                                probabilities) and a ``to_device(device)`` method.
                        device: Device handle kept on the instance.
                """
                self.omega = cfg.sample.omega
                self.eta = cfg.sample.eta
                self.limit_dist = limit_dist
                self.num_classes_X = len(self.limit_dist.X)
                self.num_classes_E = len(self.limit_dist.E)
                self.device = device

        def _limit_and_one_hot(self, X1: torch.Tensor, E1: torch.Tensor):
                """Return the limit distribution on ``X1``'s device and float one-hots of ``X1``/``E1``."""
                limit_dist = self.limit_dist.to_device(X1.device)
                # F.one_hot already allocates on the input's device; no extra move needed.
                X1_onehot = F.one_hot(X1, num_classes=len(limit_dist.X)).float()
                E1_onehot = F.one_hot(E1, num_classes=len(limit_dist.E)).float()
                return limit_dist, X1_onehot, E1_onehot

        @staticmethod
        def _normalise_rstar(numer: torch.Tensor, pt_vals: torch.Tensor, pt_vals_at: torch.Tensor):
                """Divide ``numer`` by ``count_nonzero(pt_vals) * pt_vals_at`` and zero out non-finite / huge entries."""
                support_size = torch.count_nonzero(pt_vals, dim=-1)
                denom = (support_size * pt_vals_at).unsqueeze(-1)
                # Division by a zero denominator yields nan/inf, which is mapped to 0.
                rate = torch.nan_to_num(numer / denom, nan=0.0, posinf=0.0, neginf=0.0)
                # Finite but extremely large rates are also dropped.
                return rate.masked_fill(rate > 1e5, 0.0)

        @staticmethod
        def _zero_where_no_mass(rate: torch.Tensor, pt_vals: torch.Tensor, pt_vals_at: torch.Tensor):
                """Zero ``rate`` wherever the current-state or target-state probability is exactly 0."""
                # The current-state mask broadcasts over the class axis, so no
                # explicit repeat of the mask is required.
                no_mass = (pt_vals_at == 0.0).unsqueeze(-1) | (pt_vals == 0.0)
                return rate.masked_fill(no_mass, 0.0)

        def dt_p_xt_g_x1(self, X1: torch.Tensor, E1: torch.Tensor):
                """Time derivative of the conditional marginal given ``X1`` / ``E1``.

                For the linear path this is ``onehot(x1) - limit_dist`` and does not
                depend on ``t``.

                Args:
                        X1: Integer node classes, (bs, n).
                        E1: Integer edge classes, (bs, n, n).

                Returns:
                        ``(dX, dE)`` with shapes (bs, n, dx) and (bs, n, n, de).

                Raises:
                        AssertionError: If a derivative vector does not sum to ~0 over
                                the class axis.
                """
                limit_dist, X1_onehot, E1_onehot = self._limit_and_one_hot(X1, E1)

                dX = X1_onehot - limit_dist.X[None, None, :]
                dE = E1_onehot - limit_dist.E[None, None, None, :]

                # Difference of two distributions must carry zero total mass.
                assert (dX.sum(-1).abs() < 1e-4).all(), "dX does not sum to 0 over classes"
                assert (dE.sum(-1).abs() < 1e-4).all(), "dE does not sum to 0 over classes"
                return dX, dE

        def p_xt_g_x1(self, X1: torch.Tensor, E1: torch.Tensor, t: torch.Tensor):
                """Conditional marginal ``p(x_t | x_1)`` at time ``t``.

                Args:
                        X1: Integer node classes, (bs, n).
                        E1: Integer edge classes, (bs, n, n).
                        t: Per-sample time; any tensor with ``bs`` elements (e.g.
                                (bs, 1) or (bs,)) or a single element to broadcast.

                Returns:
                        ``(Xt, Et)`` clamped to [0, 1], shapes (bs, n, dx) and
                        (bs, n, n, de).

                Raises:
                        AssertionError: If an interpolated vector does not sum to ~1
                                over the class axis.
                """
                # reshape (rather than squeeze + index) also handles a (1,)-shaped
                # ``t``, which squeeze(-1) would collapse to a 0-dim tensor that
                # cannot be sliced.
                t_time = t.reshape(-1, 1, 1)
                limit_dist, X1_onehot, E1_onehot = self._limit_and_one_hot(X1, E1)

                Xt = t_time * X1_onehot + (1 - t_time) * limit_dist.X[None, None, :]
                t_edge = t_time[:, None]  # extra axis for the second node dimension
                Et = t_edge * E1_onehot + (1 - t_edge) * limit_dist.E[None, None, None, :]

                assert ((Xt.sum(-1) - 1).abs() < 1e-4).all(), "Xt does not sum to 1 over classes"
                assert ((Et.sum(-1) - 1).abs() < 1e-4).all(), "Et does not sum to 1 over classes"

                return Xt.clamp(min=0.0, max=1.0), Et.clamp(min=0.0, max=1.0)

        def compute_pt_vals(
                self,
                t: torch.Tensor,
                X_t_label: torch.Tensor,
                E_t_label: torch.Tensor,
                X_1_pred: torch.Tensor,
                E_1_pred: torch.Tensor,
        ):
                """Evaluate the marginal and its derivative, overall and at the current state.

                Args:
                        t: Per-sample time, see ``p_xt_g_x1``.
                        X_t_label: Index tensor used to gather along the class axis of
                                the node tensors (trailing dim is squeezed afterwards).
                        E_t_label: Same for the edge tensors.
                        X_1_pred: Integer node classes of the target, (bs, n).
                        E_1_pred: Integer edge classes of the target, (bs, n, n).

                Returns:
                        8-tuple ``(pt_vals_X, pt_vals_E, pt_vals_at_Xt, pt_vals_at_Et,
                        dt_p_vals_X, dt_p_vals_E, dt_p_vals_at_Xt, dt_p_vals_at_Et)``.
                """
                dt_p_vals_X, dt_p_vals_E = self.dt_p_xt_g_x1(X_1_pred, E_1_pred)
                dt_p_vals_at_Xt = dt_p_vals_X.gather(-1, X_t_label).squeeze(-1)  # (bs, n, )
                dt_p_vals_at_Et = dt_p_vals_E.gather(-1, E_t_label).squeeze(-1)  # (bs, n, n, )

                pt_vals_X, pt_vals_E = self.p_xt_g_x1(X_1_pred, E_1_pred, t)
                pt_vals_at_Xt = pt_vals_X.gather(-1, X_t_label).squeeze(-1)  # (bs, n, )
                pt_vals_at_Et = pt_vals_E.gather(-1, E_t_label).squeeze(-1)  # (bs, n, n, )

                return (
                        pt_vals_X,
                        pt_vals_E,
                        pt_vals_at_Xt,
                        pt_vals_at_Et,
                        dt_p_vals_X,
                        dt_p_vals_E,
                        dt_p_vals_at_Xt,
                        dt_p_vals_at_Et,
                )

        def compute_Rstar(
                self,
                X_1_pred,
                E_1_pred,
                X_t_label,
                E_t_label,
                pt_vals_X,
                pt_vals_E,
                pt_vals_at_Xt,
                pt_vals_at_Et,
                dt_p_vals_X,
                dt_p_vals_E,
                dt_p_vals_at_Xt,
                dt_p_vals_at_Et,
                func,
        ):
                """Compute the R_t^* rate term for nodes and edges.

                Numerator: ``relu(dt_p - dt_p_at_current)`` plus ``omega`` on the
                target class wherever the target differs from the current label.
                Denominator: ``count_nonzero(pt_vals) * pt_vals_at_current``.
                Non-finite results and values above 1e5 are set to 0.

                Args:
                        X_1_pred, E_1_pred: Integer target classes.
                        X_t_label, E_t_label: Current-state labels with a trailing
                                singleton axis (compared against ``*_1_pred.unsqueeze(-1)``).
                        pt_vals_X, pt_vals_E, pt_vals_at_Xt, pt_vals_at_Et,
                        dt_p_vals_X, dt_p_vals_E, dt_p_vals_at_Xt, dt_p_vals_at_Et:
                                Outputs of ``compute_pt_vals``.
                        func: Name of the numerator function; only ``"relu"`` is
                                implemented.

                Returns:
                        ``(Rstar_t_X, Rstar_t_E)`` with shapes (bs, n, dx) and
                        (bs, n, n, de).

                Raises:
                        NotImplementedError: If ``func`` is not ``"relu"``.
                """
                if func != "relu":
                        raise NotImplementedError(f"Unsupported rate function: {func!r}")

                # One-hots are built on the inputs' own device, consistent with the
                # other methods of this class.
                _, X1_onehot, E1_onehot = self._limit_and_one_hot(X_1_pred, E_1_pred)

                # True where the predicted target differs from the current label.
                mask_X = X_1_pred.unsqueeze(-1) != X_t_label
                mask_E = E_1_pred.unsqueeze(-1) != E_t_label

                inner_X = dt_p_vals_X - dt_p_vals_at_Xt[:, :, None]
                inner_E = dt_p_vals_E - dt_p_vals_at_Et[:, :, :, None]

                Rstar_t_numer_X = F.relu(inner_X) + X1_onehot * self.omega * mask_X  # (bs, n, dx)
                Rstar_t_numer_E = F.relu(inner_E) + E1_onehot * self.omega * mask_E  # (bs, n, n, de)

                Rstar_t_X = self._normalise_rstar(Rstar_t_numer_X, pt_vals_X, pt_vals_at_Xt)
                Rstar_t_E = self._normalise_rstar(Rstar_t_numer_E, pt_vals_E, pt_vals_at_Et)
                return Rstar_t_X, Rstar_t_E

        def compute_R(
                self,
                Rstar_t_X,
                Rstar_t_E,
                Rdb_t_X,
                Rdb_t_E,
                pt_vals_at_Xt,
                pt_vals_at_Et,
                pt_vals_X,
                pt_vals_E,
        ):
                """Sum the two rate terms and zero entries with no probability mass.

                An entry of the summed rate is set to 0 when the probability of the
                current state (``pt_vals_at_*``) is exactly 0, or when the
                probability of the corresponding class (``pt_vals_*``) is exactly 0.
                The inputs are not modified.

                Args:
                        Rstar_t_X, Rstar_t_E: First rate term, (bs, n, dx) / (bs, n, n, de).
                        Rdb_t_X, Rdb_t_E: Second rate term, added elementwise.
                        pt_vals_at_Xt, pt_vals_at_Et: Marginal at the current state,
                                (bs, n) / (bs, n, n).
                        pt_vals_X, pt_vals_E: Full marginals, same shape as the rates.

                Returns:
                        ``(R_t_X, R_t_E)``.
                """
                R_t_X = self._zero_where_no_mass(Rstar_t_X + Rdb_t_X, pt_vals_X, pt_vals_at_Xt)
                R_t_E = self._zero_where_no_mass(Rstar_t_E + Rdb_t_E, pt_vals_E, pt_vals_at_Et)
                return R_t_X, R_t_E
