import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from torch.utils.data import DataLoader, Dataset

class CustomDataset(Dataset):
    """Map-style dataset that pairs feature rows with labels.

    ``data`` and ``labels`` are stored as given and indexed lazily in
    ``__getitem__``. They may be NumPy arrays, torch tensors or plain
    sequences; anything ``torch.as_tensor`` can convert is accepted.
    """

    def __init__(self, data, labels):
        """Store the feature container and the label container.

        Args:
            data: Indexable collection of samples; ``len(data)`` defines the
                dataset size.
            labels: Indexable collection of labels, indexed with the same
                index as ``data``.
        """
        self.data = data
        self.labels = labels

    def __len__(self):
        """Return the number of samples, taken from ``len(self.data)``."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(features, label, placeholder)`` for ``idx``.

        Args:
            idx: Index into ``data``/``labels``. A tensor index is converted
                with ``tolist()`` first, so a 0-d tensor becomes a Python int
                and a 1-d tensor becomes a list of ints (list indices require
                containers that support fancy indexing).

        Returns:
            A 3-tuple of the features as a float32 tensor, the label as a
            float32 tensor, and a constant integer tensor ``[0]``.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()
        # as_tensor (rather than from_numpy / torch.tensor) accepts arrays,
        # tensors and sequences alike and does not warn on tensor inputs.
        features = torch.as_tensor(self.data[idx]).float()
        label = torch.as_tensor(self.labels[idx]).float()
        return features, label, torch.tensor([0])

class NeuTraLAD(nn.Module):
    """Anomaly scorer built from learned transformations and a shared encoder.

    The model holds ``n_transforms`` small mask networks and one encoder.
    Each mask network produces a transformed view of the input; the encoder
    embeds the input and every view, and the score is a contrastive loss
    computed from temperature-scaled cosine similarities between those
    embeddings. ``fit`` minimises the mean score; ``decision_function`` and
    ``predict`` expose the per-sample score.
    """

    name = "NeuTraLAD"

    def __init__(
        self,
        in_features,
        fc_1_out=128,
        fc_last_out=32,
        compression_unit=16,
        n_transforms=4,
        n_layers=3,
        trans_type='mlp',
        temperature=0.07,
        trans_fc_in=None,
        trans_fc_out=None,
        device='cuda' if torch.cuda.is_available() else 'cpu'
    ):
        """Build the encoder and the mask networks and move them to ``device``.

        Args:
            in_features: Size of the last dimension of the input.
            fc_1_out: Width of the first encoder layer.
            fc_last_out: Width of the last encoder layer (embedding size).
            compression_unit: Amount by which each successive hidden encoder
                layer is narrower than the previous one.
            n_transforms: Number of mask networks.
            n_layers: Number of linear layers in the encoder (at least 1).
            trans_type: ``'res'`` adds the mask output to the input; any
                other value multiplies the input by the mask output.
            temperature: Divisor applied to cosine similarities before
                exponentiation.
            trans_fc_in: Output width of the first mask layer; falls back to
                ``in_features`` when falsy or non-positive.
            trans_fc_out: Output width of the second mask layer; falls back
                to ``in_features`` when falsy or non-positive. It has to
                broadcast against the input in ``_computeX_k``.
            device: Device the parameters are moved to and inputs are sent
                to in ``score``.
        """
        super().__init__()
        self.device = device
        self.in_features = in_features
        self.compression_unit = compression_unit
        self.fc_1_out = fc_1_out
        self.fc_last_out = fc_last_out
        self.n_layers = n_layers
        self.n_transforms = n_transforms
        self.temperature = temperature
        self.trans_type = trans_type
        self.trans_fc_in = trans_fc_in if trans_fc_in and trans_fc_in > 0 else self.in_features
        self.trans_fc_out = trans_fc_out if trans_fc_out and trans_fc_out > 0 else self.in_features
        # Kept for backward compatibility; the methods below call
        # F.cosine_similarity directly.
        self.cosim = nn.CosineSimilarity()
        self._build_network()
        self.to(self.device)

    def _create_network(self, D: int, out_dims: list, bias=True) -> list:
        """Return a flat list of ``Linear`` + ``ReLU`` pairs.

        Args:
            D: Input width of the first linear layer.
            out_dims: Output width of each successive linear layer.
            bias: Whether the linear layers carry a bias term.

        Returns:
            ``[Linear, ReLU, Linear, ReLU, ...]`` with one pair per entry of
            ``out_dims``; the list ends with a ``ReLU``.
        """
        net_layers = []
        previous_dim = D
        for dim in out_dims:
            net_layers.append(nn.Linear(previous_dim, dim, bias=bias))
            net_layers.append(nn.ReLU())
            previous_dim = dim
        return net_layers

    def _create_masks(self) -> list:
        """Return ``n_transforms`` bias-free mask networks ending in a sigmoid."""
        masks = []
        for _ in range(self.n_transforms):
            net_layers = self._create_network(self.in_features, self.trans_layers, bias=False)
            # Swap the trailing ReLU for a sigmoid so the output lies in (0, 1).
            net_layers[-1] = nn.Sigmoid()
            masks.append(nn.Sequential(*net_layers).to(self.device))
        return masks

    def _build_network(self):
        """Create ``self.enc`` and ``self.masks`` from the stored hyper-parameters.

        Raises:
            ValueError: If ``n_layers`` is smaller than 1.
        """
        if self.n_layers < 1:
            raise ValueError("n_layers must be at least 1")
        # Hidden widths shrink by ``compression_unit`` per layer; the final
        # layer always has width ``fc_last_out``.
        out_dims = [
            self.fc_1_out - i * self.compression_unit
            for i in range(self.n_layers - 1)
        ]
        out_dims.append(self.fc_last_out)
        self.trans_layers = [self.trans_fc_in, self.trans_fc_out]
        # Drop the trailing ReLU so the embedding is a plain linear output.
        enc_layers = self._create_network(self.in_features, out_dims)[:-1]
        self.enc = nn.Sequential(*enc_layers).to(self.device)
        # ModuleList registers the mask networks as submodules. With a plain
        # list their weights are invisible to parameters(), state_dict() and
        # to(), so the optimizer in fit() would never update them.
        self.masks = nn.ModuleList(self._create_masks())

    def _computeX_k(self, X: torch.Tensor):
        """Apply every mask network to ``X`` and stack the transformed views.

        Returns:
            A tensor with the transform index as a new leading dimension.
        """
        if self.trans_type == 'res':
            views = [mask(X) + X for mask in self.masks]
        else:
            views = [mask(X) * X for mask in self.masks]
        return torch.stack(views, dim=0)

    def _computeH_ij(self, Z):
        """Return ``exp(cos(Z_i, Z_j) / temperature)`` for all row pairs of a 2-D ``Z``."""
        hij = F.cosine_similarity(Z.unsqueeze(1), Z.unsqueeze(0), dim=2)
        return torch.exp(hij / self.temperature)

    def _computeBatchH_ij(self, Z):
        """Batched pairwise similarity between views.

        For a 3-D ``Z`` the similarity is taken over the last dimension
        between every pair of entries along dimension 1, separately for each
        entry along dimension 0.
        """
        hij = F.cosine_similarity(Z.unsqueeze(2), Z.unsqueeze(1), dim=3)
        return torch.exp(hij / self.temperature)

    def _computeBatchH_x_xk(self, z, zk):
        """Return ``exp(cos(z, zk) / temperature)`` between each sample and its views.

        ``z`` is 2-D and is broadcast against the 3-D ``zk`` along
        dimension 1; the similarity is taken over the last dimension.
        """
        hij = F.cosine_similarity(z.unsqueeze(1), zk, dim=2)
        return torch.exp(hij / self.temperature)

    def score(self, X: torch.Tensor):
        """Compute the per-sample contrastive score.

        Args:
            X: 2-D float tensor whose last dimension is ``in_features``.

        Returns:
            A 1-D tensor with one score per row of ``X``. This runs with
            gradients enabled, so it doubles as the training loss.
        """
        X = X.to(self.device)
        # Stack is (transform, sample, feature); put the sample axis first.
        Xk = self._computeX_k(X).permute((1, 0, 2))
        Zk = F.normalize(self.enc(Xk), dim=-1)
        Z = F.normalize(self.enc(X), dim=-1)
        Hij = self._computeBatchH_ij(Zk)
        Hx_xk = self._computeBatchH_x_xk(Z, Zk)
        # Zero the diagonal so a view is not compared with itself.
        mask_not_k = (~torch.eye(self.n_transforms, dtype=torch.bool, device=Hij.device)).float()
        numerator = Hx_xk
        denominator = Hx_xk + (mask_not_k * Hij).sum(dim=2)
        scores_V = numerator / denominator
        return (-torch.log(scores_V)).sum(dim=1)

    def forward(self, X: torch.Tensor):
        """Alias for :meth:`score`."""
        return self.score(X)

    def fit(self, X, y=None, epochs=100, batch_size=128, lr=0.001):
        """Train the encoder and the mask networks by minimising the mean score.

        Args:
            X: 2-D array-like of training samples; converted to a float32
                NumPy array.
            y: Optional labels. They are handed to the dataset but never read
                by the training loop; zeros are used when omitted.
            epochs: Number of passes over the data.
            batch_size: Mini-batch size.
            lr: Adam learning rate.

        Returns:
            ``self``. The mean batch loss of every epoch is stored in
            ``self.loss_history_``.
        """
        X = np.asarray(X, dtype=np.float32)
        labels = np.zeros(X.shape[0]) if y is None else y
        # CustomDataset is defined at module level in this file.
        dataset = CustomDataset(X, labels)
        train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
        optimizer = optim.Adam(self.parameters(), lr=lr)
        self.train()
        self.loss_history_ = []
        for _ in range(epochs):
            total_loss = 0.0
            n_batches = 0
            for data, _, _ in train_loader:
                data = data.float().to(self.device)
                optimizer.zero_grad()
                loss = self.score(data).mean()
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
                n_batches += 1
            # max() guards against an empty loader.
            self.loss_history_.append(total_loss / max(n_batches, 1))
        return self

    def predict(self, X, threshold=None, percentile=95):
        """Return binary labels: 1 where the score exceeds the threshold.

        Args:
            X: Samples to score; see :meth:`decision_function`.
            threshold: Explicit score cut-off. When ``None`` the cut-off is
                the ``percentile``-th percentile of the scores of ``X``
                itself, so the flagged fraction depends on the batch rather
                than on any fitted reference.
            percentile: Percentile used when ``threshold`` is ``None``.

        Returns:
            Integer NumPy array with one 0/1 entry per sample.
        """
        scores = self.decision_function(X)
        if threshold is None and scores.size > 0:
            threshold = np.percentile(scores, percentile)
        if threshold is None:
            # Empty input: nothing to rank and nothing to flag.
            return np.zeros(scores.shape, dtype=int)
        return (scores > threshold).astype(int)

    def decision_function(self, X):
        """Return the score of every sample as a NumPy array.

        Args:
            X: 2-D tensor or array-like; converted to a float32 tensor.

        Returns:
            1-D NumPy array of scores. The module is left in eval mode.
        """
        self.eval()
        X = torch.as_tensor(X, dtype=torch.float32)
        with torch.no_grad():
            return self.score(X).cpu().numpy()