# code modified from https://github.com/Nicolas-Pinon/uad_ocsvm_guided_repr_learning/blob/main/models_xp1/models_coupled_torch.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from cvxpylayers.torch import CvxpyLayer
import cvxpy as cp


class OCSVMguidedAutoencoder(nn.Module):
    """Dense autoencoder whose latent codes are guided by a one-class SVM term.

    Every batch is split in two along dim 0.  The first half is standardised
    and used to solve the OC-SVM dual through a differentiable CVXPY layer;
    the second half is scored against the resulting decision function and the
    hinge of that score is added to the reconstruction MSE.

    Args:
        batch_size_train: Batch size used when ``training=True``.  The CVXPY
            problem is built for ``batch_size_train // 2`` points, so batches
            given to ``compute_loss`` must match that size.
        batch_size_valid: Same as above for ``training=False``.
        input_dim: Number of input features.
        latent_dim: Size of the latent code.
        ocsvm_coeff: Weight of the OC-SVM term in the total loss.
        nu_ocsvm_coeff: The OC-SVM ``nu`` parameter.
        gamma_rbf_coeff: ``"scale"``, ``"auto"`` or any value accepted by
            ``float()``; selects the RBF kernel width.
        jz_mode: Gradient routing mode.  The substrings ``"StopGradSV"`` and
            ``"StopGradLoss"`` are recognised; anything else lets gradients
            flow through both halves of the batch.
        jean_zad_linear: Use a linear kernel instead of the RBF kernel.
    """

    # Extra diagonal jitter tried, in order, when the Cholesky factorisation
    # of the kernel matrix fails.
    _CHOLESKY_JITTERS = (0.0, 1e-12, 1e-5)

    def __init__(self, batch_size_train, batch_size_valid, input_dim, latent_dim=32, ocsvm_coeff=0.1,
                 nu_ocsvm_coeff=0.03, gamma_rbf_coeff="scale", jz_mode="StopGradLoss",
                 jean_zad_linear=False):
        super().__init__()
        self.latent_dim = latent_dim
        self.ocsvm_coeff = ocsvm_coeff
        self.nu = nu_ocsvm_coeff
        self.gamma_mode = gamma_rbf_coeff
        self.jz_mode = jz_mode
        self.linear = jean_zad_linear

        # Encoder
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 16),
            nn.LeakyReLU(),
            nn.Linear(16, 8),
            nn.LeakyReLU(),
            nn.Linear(8, latent_dim)
        )

        # Decoder (sigmoid output, so reconstructions lie in (0, 1))
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 8),
            nn.LeakyReLU(),
            nn.Linear(8, 16),
            nn.LeakyReLU(),
            nn.Linear(16, input_dim),
            nn.Sigmoid()
        )

        # CvxpyLayer parameters have a fixed shape, hence one layer per batch size.
        self.ocsvm_layer_train = self._build_ocsvm_layer(batch_size_train // 2)
        self.ocsvm_layer_valid = self._build_ocsvm_layer(batch_size_valid // 2)

    def _build_ocsvm_layer(self, n):
        """Build the differentiable OC-SVM dual QP for ``n`` points.

        The problem is ``min 0.5 * ||k_sqrt @ alpha||^2`` subject to
        ``sum(alpha) == nu * n`` and ``0 <= alpha <= 1``.
        """
        alpha = cp.Variable(n)
        # NOTE(review): the parameter is declared PSD but receives a triangular
        # Cholesky factor at run time; kept as in the original -- confirm the
        # attribute is not relied upon by the CVXPY canonicalisation.
        k_sqrt = cp.Parameter((n, n), PSD=True)
        constraints = [cp.sum(alpha) == self.nu * n, alpha >= 0, alpha <= 1]
        objective = cp.Minimize(0.5 * cp.sum_squares(k_sqrt @ alpha))
        prob = cp.Problem(objective, constraints)
        return CvxpyLayer(prob, parameters=[k_sqrt], variables=[alpha])

    def _rbf_gamma(self, z_sv):
        """Return the RBF width selected by ``self.gamma_mode``.

        ``"scale"`` gives ``1 / (latent_dim * mean_feature_variance + 1e-6)``
        (a 0-dim tensor), ``"auto"`` gives ``1 / (latent_dim * 4)`` and any
        other value is converted with ``float()``.
        """
        if self.gamma_mode == "scale":
            var = z_sv.var(dim=0, unbiased=False).mean()
            return 1 / (self.latent_dim * var + 1e-6)
        if self.gamma_mode == "auto":
            return 1 / (self.latent_dim * 4)
        return float(self.gamma_mode)

    @staticmethod
    def _kernel_factor(K, eps):
        """Return an upper-triangular ``U`` with ``U.T @ U ~= K + eps * I``.

        ``torch.linalg.cholesky`` returns the lower factor ``L`` with
        ``K = L @ L.T``.  The QP objective is ``||k_sqrt @ alpha||^2``, which
        equals ``alpha.T @ K @ alpha`` only when ``k_sqrt = L.T``, so the
        transposed factor is returned.

        The identity is created on ``K``'s own device and dtype, so the call
        works wherever the model lives.  If the factorisation fails it is
        retried with progressively larger diagonal jitter.

        Raises:
            RuntimeError: (``torch.linalg.LinAlgError``) if the matrix is still
                not positive-definite with the largest jitter.
        """
        eye = torch.eye(K.shape[0], device=K.device, dtype=K.dtype)
        jitters = OCSVMguidedAutoencoder._CHOLESKY_JITTERS
        for attempt, jitter in enumerate(jitters):
            shifted = K + eps * eye
            if jitter:
                shifted = shifted + jitter * eye
            try:
                lower = torch.linalg.cholesky(shifted)
            except RuntimeError:
                if attempt == len(jitters) - 1:
                    raise
                print('linalg.cholesky: The factorization could not be completed because '
                      'the input is not positive-definite. '
                      f'Retrying with extra jitter {jitters[attempt + 1]:g}.')
            else:
                return lower.t()

    def forward(self, x):
        """Encode ``x`` and decode it again; returns ``(x_hat, z)``."""
        z = self.encoder(x)
        x_hat = self.decoder(z)
        return x_hat, z

    def solve_ocsvm(self, z, training=True):
        """Solve the OC-SVM dual on the first half of the latent batch.

        Args:
            z: Latent codes; the first ``z.shape[0] // 2`` rows are used.
            training: Selects the CVXPY layer built for the training batch
                size (``True``) or the validation batch size (``False``).

        Returns:
            ``(alpha, K)`` where ``alpha`` is the layer solution divided by
            ``nu * n`` (float32) and ``K`` is the kernel matrix of the
            standardised first half, without the diagonal regularisation.
        """
        n = z.shape[0] // 2
        z_sv = standardize(z[:n])
        gamma = self._rbf_gamma(z_sv)

        if self.linear:
            K = torch.matmul(z_sv, z_sv.t())
            eps = 1e-8
        else:
            dist_sq = torch.cdist(z_sv, z_sv, p=2).pow(2)
            K = torch.exp(-gamma * dist_sq)
            eps = 1e-8 / gamma

        K_sqrt = self._kernel_factor(K, eps)

        layer = self.ocsvm_layer_train if training else self.ocsvm_layer_valid
        alpha, = layer(K_sqrt.double())
        alpha = alpha.float() / (self.nu * n)
        return alpha, K

    def ocsvm_objective(self, alpha, z, K_sv):
        """Hinge of the OC-SVM decision function on the second half of ``z``.

        Args:
            alpha: Normalised dual coefficients returned by ``solve_ocsvm``.
            z: Full latent batch; rows ``[:n]`` are the support-vector half and
                rows ``[n:]`` the scored half, with ``n = z.shape[0] // 2``
                (the same split as ``solve_ocsvm``).
            K_sv: Kernel matrix of the support-vector half.

        Returns:
            ``relu(-decision)`` with one column per scored sample.
        """
        n = z.shape[0] // 2
        z_sv, z_loss = standardize(z[:n]), standardize(z[n:])
        gamma = self._rbf_gamma(z_sv)

        if self.linear:
            # NOTE(review): the stop-gradient modes are not applied to the
            # linear kernel -- confirm this is intended.
            K_sv_loss = torch.matmul(z_sv, z_loss.t())
        else:
            if "StopGradSV" in self.jz_mode:
                dists = torch.cdist(z_sv.detach(), z_loss, p=2).pow(2)
            elif "StopGradLoss" in self.jz_mode:
                dists = torch.cdist(z_sv, z_loss.detach(), p=2).pow(2)
            else:
                dists = torch.cdist(z_sv, z_loss, p=2).pow(2)
            K_sv_loss = torch.exp(-gamma * dists)

        # Mask of coefficients with |alpha - C| < |C - 1e-6|, C = 1 / (nu * n).
        # NOTE(review): this keeps entries sitting at the upper bound C as
        # well -- confirm that is the intended support-vector selection.
        upper = 1 / (self.nu * n)
        sv_mask = ((alpha - upper) ** 2 < (upper - 1e-6) ** 2).float()
        # rho: mean of (alpha^T K_sv) over the masked entries.
        rho = torch.sum(alpha.view(1, -1) @ K_sv @ sv_mask.view(-1, 1)) / (sv_mask.sum() + 1e-6)

        decision = (alpha.view(1, -1) @ K_sv_loss - rho) * self.nu * n
        if "StopGradSV" in self.jz_mode:
            # NOTE(review): detaching here removes every gradient from the
            # returned hinge term in this mode -- confirm this is intended.
            decision = decision.detach()

        return F.relu(-decision)

    def compute_loss(self, x, training=True):
        """Return ``(total, mse, ocsvm_obj)`` for one batch.

        ``total = mse + ocsvm_coeff * ocsvm_obj`` where ``ocsvm_obj`` is the
        summed hinge term scaled by ``1 / nu``.
        """
        x_hat, z = self.forward(x)
        mse = F.mse_loss(x_hat, x, reduction='mean')
        alpha, K_sv = self.solve_ocsvm(z, training=training)
        obj = self.ocsvm_objective(alpha, z, K_sv)
        ocsvm_obj = (1 / self.nu) * obj.sum()
        total = mse + self.ocsvm_coeff * ocsvm_obj.squeeze()
        return total, mse, ocsvm_obj
    

def standardize(x):
    """Centre and scale each column of ``x`` along dim 0.

    Every column has its mean subtracted and is divided by its standard
    deviation (``torch.Tensor.std`` defaults) plus ``1e-6``, which keeps the
    division finite for constant columns.
    """
    column_mean = x.mean(dim=0)
    column_scale = x.std(dim=0) + 1e-6
    centred = x - column_mean
    return centred / column_scale


class EncoderBlock(nn.Module):
    """Linear -> BatchNorm1d -> LeakyReLU -> MaxPool1d(2) stage.

    The 1-D layers stand in for the 2-D convolutional variants
    (Conv2d / BatchNorm2d / MaxPool2d) this block previously used.

    Args:
        in_channels: Input feature size of the linear layer.
        out_channels: Output feature size of the linear layer and the number
            of features normalised by the batch norm.
        use_bias: Whether the linear layer has a bias term.
    """

    def __init__(self, in_channels, out_channels, use_bias=True):
        super().__init__()
        stages = [
            nn.Linear(in_channels, out_channels, bias=use_bias),
            nn.BatchNorm1d(out_channels),
            nn.LeakyReLU(),
            nn.MaxPool1d(2),
        ]
        # Positional Sequential keeps the submodule indices 0..3.
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the four stages in order."""
        out = self.block(x)
        return out


class DecoderBlock(nn.Module):
    """Upsample(x2) -> Linear -> BatchNorm1d -> LeakyReLU stage.

    Args:
        in_channels: Input feature size of the linear layer (the size of the
            last dimension after upsampling).
        out_channels: Output feature size of the linear layer and the number
            of features given to the batch norm.
        use_bias: Whether the linear layer has a bias term.  This flag used to
            be accepted but ignored; it is now forwarded to ``nn.Linear``, as
            ``EncoderBlock`` does.

    NOTE(review): ``nn.Upsample`` needs an input with at least 3 dimensions
    and ``BatchNorm1d(out_channels)`` normalises dim 1 of a 3-D input, so this
    block appears to expect input shaped ``(N, out_channels, in_channels // 2)``
    -- confirm the intended layout.
    """

    def __init__(self, in_channels, out_channels, use_bias=True):
        super().__init__()
        self.block = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Linear(in_channels, out_channels, bias=use_bias),
            nn.BatchNorm1d(out_channels),
            nn.LeakyReLU()
        )

    def forward(self, x):
        """Run ``x`` through the four stages in order."""
        return self.block(x)


