import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from collections import OrderedDict
import numpy as np
from sklearn.metrics import roc_auc_score, matthews_corrcoef
from torch.utils.data import DataLoader, Dataset

class CustomDataset(Dataset):
    """Map-style dataset pairing array-backed samples with their labels.

    Args:
        data: Indexable container of samples; each indexed item is passed to
            ``torch.from_numpy`` and must therefore be a NumPy array.
        labels: Indexable container holding one label per sample.
    """

    def __init__(self, data, labels):
        self.data, self.labels = data, labels

    def __len__(self):
        """Return the number of samples (the length of ``data``)."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(sample, label, placeholder)`` for position ``idx``.

        ``sample`` and ``label`` are float tensors. ``placeholder`` is the
        constant integer tensor ``[0]``.
        """
        # Tensor indices are converted to plain Python values before indexing.
        position = idx.tolist() if torch.is_tensor(idx) else idx
        sample = torch.from_numpy(self.data[position]).float()
        label = torch.tensor(self.labels[position]).float()
        placeholder = torch.tensor([0])
        return sample, label, placeholder

class MLP(nn.Module):
    """One-hidden-layer perceptron: ``Linear -> ReLU -> Linear``.

    Args:
        input_dim: Size of each input feature vector.
        num_classes: Number of output logits per sample.
        num_hidden_nodes: Width of the hidden layer.
    """

    def __init__(self, input_dim, num_classes=1, num_hidden_nodes=128):
        super().__init__()
        self.input_dim = input_dim
        self.num_classes = num_classes
        self.num_hidden_nodes = num_hidden_nodes

        # Hidden representation: affine map followed by an in-place ReLU.
        hidden_layers = OrderedDict()
        hidden_layers['fc'] = nn.Linear(self.input_dim, self.num_hidden_nodes)
        hidden_layers['relu1'] = nn.ReLU(inplace=True)
        self.feature_extractor = nn.Sequential(hidden_layers)

        # Width of the feature vector handed to the classifier head.
        self.size_final = self.num_hidden_nodes

        head_layers = OrderedDict()
        head_layers['fc1'] = nn.Linear(self.size_final, self.num_classes)
        self.classifier = nn.Sequential(head_layers)

    def forward(self, input):
        """Return logits of shape ``(-1, num_classes)`` for ``input``."""
        hidden = self.feature_extractor(input)
        # Flatten to (rows, size_final) before the classifier head.
        flat_hidden = hidden.view(-1, self.size_final)
        return self.classifier(flat_hidden)

class DROCCTrainer:
    """Train and apply a DROCC-style anomaly detector backed by an ``MLP``.

    Label convention used by this class: ``0`` marks a normal sample and ``1``
    an anomaly. The model's raw logit is the anomaly score, so a higher score
    means "more anomalous".

    Training has two phases. The first ``num_epochs // 2`` epochs minimise
    binary cross-entropy only; the remaining epochs add ``lamda`` times the
    adversarial loss computed by :meth:`one_class_adv_loss` on the normal
    samples of each batch.

    Args:
        in_dim: Number of input features.
        hid_dim: Hidden-layer width of the MLP.
        lamda: Weight of the adversarial loss term.
        radius: Inner radius of the shell around normal samples onto which
            adversarial points are projected.
        gamma: The outer radius of that shell is ``gamma * radius``.
        lr: Base learning rate for Adam; the per-epoch schedule scales it.
        batch_size: Mini-batch size for training and evaluation.
        num_epochs: Total number of training epochs.
        device: Torch device string on which the model and data live.
        ascent_step_size: Step size of the gradient-ascent search.
        ascent_num_steps: Number of gradient-ascent steps per batch.
    """

    name = "DROCC"

    # Percentile of the scores used as the anomaly cut-off in test()/predict():
    # samples scoring at or above it are labelled 1.
    threshold_percentile = 80

    def __init__(
        self,
        in_dim,
        hid_dim=128,
        lamda=1.0,
        radius=0.2,
        gamma=2.0,
        lr=0.001,
        batch_size=128,
        num_epochs=50,
        device='cuda' if torch.cuda.is_available() else 'cpu',
        ascent_step_size=0.001,
        ascent_num_steps=50,
    ):
        self.device = device
        self.model = MLP(input_dim=in_dim, num_hidden_nodes=hid_dim).to(self.device)
        # Keep the base rate so the schedule always scales the original value
        # rather than an already-decayed one.
        self.lr = lr
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
        self.lamda = lamda
        self.radius = radius
        self.gamma = gamma
        self.batch_size = batch_size
        self.num_epochs = num_epochs
        self.ascent_step_size = ascent_step_size
        self.ascent_num_steps = ascent_num_steps

    def one_class_adv_loss(self, x_train_data):
        """Return the adversarial loss for a batch of normal samples.

        Starting from Gaussian perturbations of ``x_train_data``, gradient
        ascent searches for nearby points that the model is least willing to
        call anomalous. Every 10 steps the perturbations are rescaled so their
        norm lies in ``[radius, gamma * radius]``. The returned loss is the
        cross-entropy for classifying the resulting points as anomalies.

        Args:
            x_train_data: Float tensor of normal samples, batch dimension
                first, on the same device as the model.

        Returns:
            Scalar tensor; gradients flow to the model parameters only.
        """
        batch_size = len(x_train_data)
        # Normal samples carry label 0 in this class, so the generated
        # near-normal points must be pushed towards the anomalous label 1.
        anomaly_targets = torch.ones(
            batch_size, dtype=torch.float, device=x_train_data.device
        )
        # Reduce over every non-batch dimension when taking per-sample norms.
        sample_dims = tuple(range(1, x_train_data.dim()))

        # Leaf tensor holding the adversarial points; it is updated in place.
        x_adv_sampled = (
            x_train_data + torch.randn_like(x_train_data)
        ).detach().requires_grad_()

        for step in range(self.ascent_num_steps):
            with torch.enable_grad():
                # squeeze(-1) keeps the batch axis even for a single sample.
                logits = self.model(x_adv_sampled).squeeze(-1)
                search_loss = F.binary_cross_entropy_with_logits(
                    logits, anomaly_targets
                )
                grad = torch.autograd.grad(search_loss, [x_adv_sampled])[0]

            with torch.no_grad():
                grad_norm = torch.norm(grad, p=2, dim=sample_dims, keepdim=True)
                grad_normalized = grad / (grad_norm + 1e-8)
                x_adv_sampled.add_(self.ascent_step_size * grad_normalized)

                if (step + 1) % 10 == 0:
                    # Project the perturbation onto the shell
                    # radius <= ||h|| <= gamma * radius.
                    h = x_adv_sampled - x_train_data
                    norm_h = torch.sqrt(
                        torch.sum(h ** 2, dim=sample_dims, keepdim=True)
                    )
                    alpha = torch.clamp(
                        norm_h, self.radius, self.gamma * self.radius
                    )
                    x_adv_sampled.copy_(
                        x_train_data + (alpha / (norm_h + 1e-8)) * h
                    )

        # The search result is a constant input for the final loss.
        adv_pred = self.model(x_adv_sampled.detach()).squeeze(-1)
        return F.binary_cross_entropy_with_logits(adv_pred, anomaly_targets)

    def _adjust_lr(self, epoch, only_ce_epochs):
        """Set the optimizer learning rate for ``epoch`` from ``self.lr``.

        The rate is the base rate until 60% of the adversarial phase has
        passed, then 0.1x, and 0.01x after 90%.
        """
        drocc_epoch = epoch - only_ce_epochs
        drocc_epochs = self.num_epochs - only_ce_epochs
        if drocc_epoch > 0.90 * drocc_epochs:
            lr = self.lr * 0.01
        elif drocc_epoch > 0.60 * drocc_epochs:
            lr = self.lr * 0.1
        else:
            lr = self.lr
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

    def fit(self, X, y=None):
        """Train the model on ``X``.

        Args:
            X: NumPy array of samples, one row per sample.
            y: Optional labels (0 = normal, 1 = anomaly). When omitted every
                sample is treated as normal.

        When ``y`` contains both classes, the AUC on ``(X, y)`` is computed
        after every epoch and the weights of the best-scoring epoch are
        restored at the end.

        Returns:
            ``self``.
        """
        labels = np.zeros(X.shape[0]) if y is None else y
        dataset = CustomDataset(X, labels)
        train_loader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True)
        only_ce_epochs = self.num_epochs // 2

        # AUC is undefined with a single class, so only select a best epoch
        # when both classes are present.
        eval_loader = None
        if y is not None and np.unique(np.asarray(y)).size > 1:
            eval_loader = DataLoader(
                dataset, batch_size=self.batch_size, shuffle=False
            )

        best_score = -np.inf
        best_state = None
        for epoch in range(self.num_epochs):
            # test() leaves the model in eval mode; switch back every epoch.
            self.model.train()
            self._adjust_lr(epoch, only_ce_epochs)

            for data, target, _ in train_loader:
                data = data.to(self.device).float()
                # reshape(-1) keeps a 1-D target even for a batch of one.
                target = target.to(self.device).float().reshape(-1)

                self.optimizer.zero_grad()
                logits = self.model(data).squeeze(-1)
                loss = F.binary_cross_entropy_with_logits(logits, target)

                if epoch >= only_ce_epochs:
                    data_normal = data[target == 0]
                    if len(data_normal) > 0:
                        adv_loss = self.one_class_adv_loss(data_normal)
                        loss = loss + adv_loss * self.lamda

                loss.backward()
                self.optimizer.step()

            if eval_loader is not None:
                auc = self.test(eval_loader)['AUC']
                if auc > best_score:
                    best_score = auc
                    # state_dict() returns references to the live parameters;
                    # clone them so later updates do not overwrite the snapshot.
                    best_state = {
                        key: value.detach().clone()
                        for key, value in self.model.state_dict().items()
                    }

        if best_state is not None:
            self.model.load_state_dict(best_state)
        return self

    def test(self, test_loader):
        """Score every batch of ``test_loader`` and compute AUC and MCC.

        Puts the model in eval mode. MCC is computed on predictions obtained
        by thresholding the scores at ``threshold_percentile``.

        Args:
            test_loader: Iterable yielding ``(data, target, _)`` batches.

        Returns:
            Dict with keys ``'AUC'`` and ``'MCC'``.

        Raises:
            ValueError: If ``test_loader`` yields no batches.
        """
        self.model.eval()
        label_chunks = []
        score_chunks = []
        with torch.no_grad():
            for data, target, _ in test_loader:
                data = data.to(self.device).float()
                logits = self.model(data)
                label_chunks.append(target.float().reshape(-1).cpu().numpy())
                score_chunks.append(logits.reshape(-1).cpu().numpy())

        if not score_chunks:
            raise ValueError("test_loader yielded no batches")

        labels = np.concatenate(label_chunks).astype(np.float64)
        scores = np.concatenate(score_chunks).astype(np.float64)
        auc_score = roc_auc_score(labels, scores)
        thresh = np.percentile(scores, self.threshold_percentile)
        y_pred = np.where(scores >= thresh, 1, 0)
        mcc_score = matthews_corrcoef(labels, y_pred)
        return {'AUC': auc_score, 'MCC': mcc_score}

    def predict(self, X):
        """Return 0/1 predictions for ``X`` as a 1-D integer array.

        A sample is labelled 1 when its score is at or above the
        ``threshold_percentile`` percentile of the scores of ``X`` itself, so
        the cut-off depends on the batch being scored.
        """
        scores = self.decision_function(X)
        if scores.size == 0:
            return scores.astype(int)
        threshold = np.percentile(scores, self.threshold_percentile)
        return (scores >= threshold).astype(int)

    def decision_function(self, X):
        """Return the raw logits for ``X`` as a 1-D NumPy array.

        ``X`` may be an array-like or a tensor; it is converted to float32 and
        moved to ``self.device`` before scoring. Puts the model in eval mode.
        """
        self.model.eval()
        inputs = torch.as_tensor(X, dtype=torch.float32, device=self.device)
        with torch.no_grad():
            # squeeze(-1) drops only the logit axis, so a single sample still
            # yields an array of shape (1,).
            return self.model(inputs).squeeze(-1).cpu().numpy()