import collections
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import math
import argparse
from utils_tool import *
from dataprocess import *
INF = 1e20
VERY_SMALL_NUMBER = 1e-12
from typing import Optional, Callable, Tuple, Dict, Union
from itertools import islice, chain
import re
from collections import OrderedDict, defaultdict
from imblearn.under_sampling import RandomUnderSampler
from imblearn.over_sampling import RandomOverSampler
import sys
from sklearn.cluster import KMeans
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA, KernelPCA
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from sklearn.preprocessing import normalize

def sequence_mask(X, valid_len, value=0):
    """Overwrite entries of ``X`` beyond each row's valid length.

    Column ``j`` of row ``i`` is replaced by ``value`` when
    ``j >= valid_len[i]``. The assignment is done in place and ``X`` itself is
    returned.

    Args:
        X: tensor whose second axis is the sequence axis.
        valid_len: 1-D tensor with one length per row of ``X``.
        value: fill value for masked entries.
    """
    positions = torch.arange(X.size(1), dtype=torch.float32, device=X.device)
    keep = positions.unsqueeze(0) < valid_len.unsqueeze(1)
    X[~keep] = value
    return X


def masked_softmax(X, valid_lens):
    """Softmax over the last axis that ignores positions past each valid length.

    Args:
        X: score tensor; masking is applied along its last axis.
        valid_lens: ``None`` for a plain softmax; a 1-D tensor with one
            length per entry of ``X``'s first axis (repeated ``X.shape[1]``
            times); or any tensor with one length per row of
            ``X.reshape(-1, X.shape[-1])``.

    Masked positions are filled with -1e6 before the softmax, so they get
    (near-)zero weight. ``sequence_mask`` assigns in place, so when the
    reshape below is a view, ``X`` itself is overwritten.
    """
    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)
    shape = X.shape
    if valid_lens.dim() == 1:
        per_row_lens = torch.repeat_interleave(valid_lens, shape[1])
    else:
        per_row_lens = valid_lens.reshape(-1)
    flat_scores = X.reshape(-1, shape[-1])
    masked = sequence_mask(flat_scores, per_row_lens, value=-1e6)
    return nn.functional.softmax(masked.reshape(shape), dim=-1)


class DotProductAttention(nn.Module):
    """Scaled dot-product attention with optional length masking.

    Scores are ``queries @ keys^T / sqrt(d)`` (``d`` = query width), normalised
    with ``masked_softmax``. The weights are stored on
    ``self.attention_weights``, passed through dropout, and applied to
    ``values`` with a batched matmul.
    """

    def __init__(self, dropout, **kwargs):
        super().__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, valid_lens=None):
        scale = math.sqrt(queries.shape[-1])
        scores = queries.bmm(keys.transpose(1, 2)) / scale
        # Kept as an attribute so callers can inspect the last attention map.
        self.attention_weights = masked_softmax(scores, valid_lens)
        dropped = self.dropout(self.attention_weights)
        return dropped.bmm(values)


def transpose_qkv(X, num_heads):
    """Split the last axis into heads and fold the heads into the batch axis.

    ``(batch, steps, num_heads * d)`` -> ``(batch * num_heads, steps, d)``,
    with all heads of batch element 0 first, then batch element 1, and so on.
    """
    batch, steps = X.shape[0], X.shape[1]
    split = X.reshape(batch, steps, num_heads, -1).permute(0, 2, 1, 3)
    return split.reshape(-1, split.shape[2], split.shape[3])


def transpose_output(X, num_heads):
    """Inverse of ``transpose_qkv``: unfold heads from the batch axis.

    ``(batch * num_heads, steps, d)`` -> ``(batch, steps, num_heads * d)``,
    concatenating the heads' features along the last axis.
    """
    steps, head_dim = X.shape[1], X.shape[2]
    grouped = X.reshape(-1, num_heads, steps, head_dim).permute(0, 2, 1, 3)
    return grouped.reshape(grouped.shape[0], grouped.shape[1], -1)


class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Queries, keys and values are projected to ``num_hiddens`` features, split
    into ``num_heads`` heads, attended independently with
    ``DotProductAttention``, re-concatenated and projected by ``W_o``.
    ``num_hiddens`` must be divisible by ``num_heads``.
    """

    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 num_heads, dropout, bias=False, **kwargs):
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.attention = DotProductAttention(dropout)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)

    def forward(self, queries, keys, values, valid_lens):
        heads = self.num_heads
        q = transpose_qkv(self.W_q(queries), heads)
        k = transpose_qkv(self.W_k(keys), heads)
        v = transpose_qkv(self.W_v(values), heads)
        if valid_lens is not None:
            # Each head of a batch element shares that element's lengths.
            valid_lens = valid_lens.repeat_interleave(heads, dim=0)
        attended = self.attention(q, k, v, valid_lens)
        return self.W_o(transpose_output(attended, heads))


class MLP(nn.Module):
    """Stack of Linear layers, each optionally followed by BatchNorm1d, ReLU and Dropout.

    Submodules are named ``dense{i}``, ``batchnorm{i}``, ``activation{i}`` and
    ``dropout{i}`` inside ``self.mlp``. Dropout is added only when ``dropout``
    is truthy.
    """

    def __init__(self, input_size, hidden_layers,
                 dropout=0.0, batchnorm=True, activation=True):
        super().__init__()
        layers = []
        in_features = input_size
        for idx, out_features in enumerate(hidden_layers):
            layers.append((f"dense{idx}", nn.Linear(in_features, out_features)))
            if batchnorm:
                layers.append((f"batchnorm{idx}", nn.BatchNorm1d(out_features)))
            if activation:
                layers.append((f"activation{idx}", nn.ReLU()))
            if dropout:
                layers.append((f"dropout{idx}", nn.Dropout(dropout)))
            in_features = out_features
        self.mlp = nn.Sequential(OrderedDict(layers))

    def forward(self, x):
        return self.mlp(x)


class PositionalEmbedding(nn.Module):
    """Learned (trainable) positional embedding table.

    This is an ``nn.Embedding`` of ``max_len`` rows, not the fixed sinusoidal
    encoding. ``forward`` ignores the contents of ``x`` and ``x_len``. It
    returns the whole table, one copy per batch element, with shape
    ``(x.size(0), max_len, d_model)``; callers must slice it to their sequence
    length themselves.
    """

    def __init__(self, d_model, max_len=300):
        super().__init__()
        self.pe = nn.Embedding(max_len, d_model)

    def forward(self, x, x_len):
        table = self.pe.weight[None]
        return table.repeat(x.size(0), 1, 1)

class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding ("Attention Is All You Need").

    ``P[0, pos, 2i] = sin(pos / 10000^(2i/num_hiddens))`` and
    ``P[0, pos, 2i+1] = cos(pos / 10000^(2i/num_hiddens))``. ``forward`` adds
    the first ``X.shape[1]`` rows of ``P`` to a batch-first input ``X`` and
    applies dropout.

    Args:
        num_hiddens: feature width (odd widths are supported).
        dropout: dropout probability applied after adding the encoding.
        max_len: maximum supported sequence length.
    """

    def __init__(self, num_hiddens, dropout, max_len=1000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(dropout)
        self.P = torch.zeros((1, max_len, num_hiddens))
        X = torch.arange(max_len, dtype=torch.float32).reshape(
            -1, 1) / torch.pow(
                10000,
                torch.arange(0, num_hiddens, 2, dtype=torch.float32) /
                num_hiddens)
        self.P[:, :, 0::2] = torch.sin(X)
        # With odd num_hiddens there is one fewer odd column than even column,
        # so only the first num_hiddens // 2 frequencies feed the cos slots.
        self.P[:, :, 1::2] = torch.cos(X[:, : num_hiddens // 2])

    def forward(self, X):
        X = X + self.P[:, :X.shape[1], :].to(X.device)
        return self.dropout(X)

from torch.nn.init import xavier_uniform_
def get_key_padding_mask(tokens, padding_token=0):
    """Build an additive float mask: ``-inf`` where ``tokens == padding_token``, else 0.

    The mask is created on the same device as ``tokens``, so the boolean
    indexing below also works for CUDA inputs.

    Args:
        tokens: tensor of token ids.
        padding_token: id treated as padding.

    Returns:
        Float tensor with the same shape and device as ``tokens``.
    """
    key_padding_mask = torch.zeros(tokens.size(), device=tokens.device)
    key_padding_mask[tokens == padding_token] = -float('inf')
    return key_padding_mask

class TransformerEncoderwithPE(nn.Module):
    """Sinusoidal positional encoding followed by an 8-head ``nn.TransformerEncoder``.

    The input is batch-first, ``(batch, steps, d_model)``. It is transposed to
    sequence-first for the encoder (which is built with the default
    ``batch_first=False``) and transposed back on return. ``d_model`` must be
    divisible by 8.
    """

    def __init__(self, d_model, num_encoder_layers, dropout):
        super().__init__()
        self.d_model = d_model
        self.positional_embedding = PositionalEncoding(d_model, dropout)
        layer = nn.TransformerEncoderLayer(d_model, 8, dim_feedforward=4 * d_model, dropout=dropout)
        self.encoder = nn.TransformerEncoder(layer, num_encoder_layers, norm=nn.LayerNorm(d_model))

    def forward(self, x, valid_lens):
        steps = x.size(1)
        positions = torch.arange(steps).expand(x.size(0), steps).to(x.device)
        # True marks padding: positions at or beyond each sequence's length.
        key_padding_mask = positions >= valid_lens.unsqueeze(1)
        seq_first = self.positional_embedding(x).transpose(0, 1)
        encoded = self.encoder(seq_first, src_key_padding_mask=key_padding_mask)
        return encoded.transpose(0, 1)


def extract_last_valid_output(output, sequence_lengths):
    """Pick, for each batch element, the row at its last valid time step.

    Args:
        output: ``(batch, max_len, hidden)`` tensor.
        sequence_lengths: per-element lengths; element ``b`` contributes
            ``output[b, sequence_lengths[b] - 1]``.

    Returns:
        ``(batch, hidden)`` tensor.
    """
    batch_size, max_len, hidden_size = output.size(0), output.size(1), output.size(2)
    flat = output.reshape(batch_size * max_len, hidden_size)
    # Offset of each element's first row in the flattened view.
    row_offsets = torch.arange(batch_size).to(output.device) * max_len
    picked = flat[row_offsets + (sequence_lengths - 1)]
    return picked.reshape(batch_size, hidden_size)

class AttentionNetPooling(nn.Module):
    """Learned attention pooling over the time axis.

    A small MLP scores each position. Positions where ``attention_mask == 0``
    get a score of ``-inf``, the scores are softmax-normalised over dim 1, and
    the weighted sum of ``last_hidden_state`` is returned together with the
    weights.
    """

    def __init__(self, in_dim):
        super().__init__()
        self.attention = nn.Sequential(
            nn.Linear(in_dim, in_dim),
            nn.LayerNorm(in_dim),
            nn.GELU(),
            nn.Linear(in_dim, 1),
        )

    def forward(self, last_hidden_state, attention_mask):
        """Return ``(pooled, weights)``; ``weights`` has a trailing singleton axis."""
        logits = self.attention(last_hidden_state).float()
        padded = attention_mask == 0
        logits[padded] = float('-inf')
        weights = torch.softmax(logits, 1)
        pooled = (weights * last_hidden_state).sum(dim=1)
        return pooled, weights

class Seq(nn.Module):
    """Real-time sequence encoder over ``x['hist_activity']``.

    Pipeline: embedding (padding index 0) -> unidirectional LSTM/GRU/RNN
    (packed by length) or Transformer -> optional LayerNorm -> optional
    reduction over time.

    Args:
        num_embeddings: vocabulary size of the activity embedding.
        seq_embedding_size: embedding width fed to the encoder. The transformer
            encoder is built with width ``hidden_size``, so it requires
            ``seq_embedding_size == hidden_size``.
        hidden_size: encoder width; also exposed as ``d_model``.
        dropout: dropout inside the encoder / attention modules.
        reduction: falsy for no reduction, or one of ``True``/'avgpooling',
            'selfattention+avgpooling', 'lastpositionattention',
            'clsattention', 'attentionnetpooling'.
        LayerNorm: apply LayerNorm to the encoder output.
        encoder_type: 'lstm', 'gru', 'rnn' or 'transformer'.
        num_layers: number of encoder layers.
        use_att_cons: stored on the instance; not used by this class.
    """

    def __init__(self, num_embeddings, seq_embedding_size, hidden_size, dropout, reduction=False,
                 LayerNorm=True, encoder_type='lstm', num_layers=2, use_att_cons=False):
        super(Seq, self).__init__()
        self.reduction = reduction
        self.hidden_size = hidden_size
        # Width of the per-position features produced by forward (all encoders
        # here are unidirectional with hidden_size units); read by wrappers such
        # as SeqwithClassifier.
        self.d_model = hidden_size
        self.LayerNorm = LayerNorm
        self.use_att_cons = use_att_cons
        self.embeddings = OrderedDict()
        self.num_layers = num_layers
        self.num_embeddings = num_embeddings
        self.embeddings['hist_activity'] = nn.Embedding(self.num_embeddings, seq_embedding_size,
                                                        padding_idx=0)
        self.add_module(f"embedding:{'hist_activity'}", self.embeddings['hist_activity'])
        self.encoder_type = encoder_type
        if self.encoder_type == "lstm":
            self.encoder = nn.LSTM(input_size=seq_embedding_size, hidden_size=hidden_size,
                                   num_layers=num_layers, batch_first=True, bidirectional=False, dropout=dropout)
        elif self.encoder_type == "gru":
            self.encoder = nn.GRU(input_size=seq_embedding_size, hidden_size=hidden_size,
                                  num_layers=num_layers, batch_first=True, bidirectional=False, dropout=dropout)
        elif self.encoder_type == "rnn":
            self.encoder = nn.RNN(input_size=seq_embedding_size, hidden_size=hidden_size,
                                  num_layers=num_layers, batch_first=True, bidirectional=False, dropout=dropout)
        elif self.encoder_type == "transformer":
            self.encoder = TransformerEncoderwithPE(d_model=hidden_size, num_encoder_layers=num_layers, dropout=dropout)
        else:
            raise NotImplementedError
        if LayerNorm:
            self.ln = nn.LayerNorm(hidden_size)
        if self.reduction == 'selfattention+avgpooling' or self.reduction == 'lastpositionattention':
            self.SA = MultiHeadAttention(hidden_size, hidden_size, hidden_size, hidden_size, 8,
                                         dropout)
        if self.reduction == 'clsattention':
            # Learned query vector shared by all sequences.
            q_t = np.random.normal(loc=0.0, scale=1, size=(1, self.hidden_size))
            self.q = nn.Parameter(torch.from_numpy(q_t).float())
            self.MHA = MultiHeadAttention(hidden_size, hidden_size, hidden_size, hidden_size, 8,
                                          dropout)
        if self.reduction == 'attentionnetpooling':
            self.ANP = AttentionNetPooling(self.hidden_size)

    @staticmethod
    def _valid_mask(output, seqs_length):
        """Boolean ``(batch, steps)`` mask that is True where position < sequence length."""
        batch_size, steps = output.size(0), output.size(1)
        lens = seqs_length.unsqueeze(1)
        return torch.arange(steps).expand(batch_size, steps).to(output.device) < lens

    @classmethod
    def _masked_mean(cls, output, seqs_length):
        """Average ``output`` over valid time steps only."""
        mask = cls._valid_mask(output, seqs_length).unsqueeze(-1).float()
        return torch.sum(output * mask, dim=1) / seqs_length.unsqueeze(1)

    def forward(self, x):
        """Encode a batch.

        Args:
            x: dict with 'hist_activity' (padded, batch-first token ids) and
                'max_len' (per-sequence valid lengths, despite the key name).

        Returns:
            ``(output, last_hidden)``. ``output`` is either the per-position
            features ``(batch, steps, hidden_size)`` or the reduced tensor.
            ``last_hidden`` is the top layer's final hidden state
            ``(batch, hidden_size)`` for RNN encoders, and ``None`` for the
            transformer.
        """
        seqs_length = x['max_len']
        output = self.embeddings['hist_activity'](x['hist_activity'])
        if self.encoder_type in ("lstm", "rnn", "gru"):
            packed_out = pack_padded_sequence(
                output,
                lengths=seqs_length.cpu(),
                batch_first=True,
                enforce_sorted=False)
            if self.encoder_type == "lstm":
                output, (h_n, _) = self.encoder(packed_out)
            else:
                # GRU/RNN return (output, h_n) without a cell state; unpacking
                # h_n as a pair would split it along the layer axis instead.
                output, h_n = self.encoder(packed_out)
            output, _ = pad_packed_sequence(output, batch_first=True)
        elif self.encoder_type == "transformer":
            output = self.encoder(output, valid_lens=seqs_length)
            h_n = [None, None]
        else:
            raise NotImplementedError
        if self.LayerNorm:
            output = self.ln(output)
        if self.reduction:
            if self.reduction == True or self.reduction == 'avgpooling':
                output = self._masked_mean(output, seqs_length)
            elif self.reduction == 'selfattention+avgpooling':
                output = self.SA(output, output, output, seqs_length)
                output = self._masked_mean(output, seqs_length)
            elif self.reduction == 'lastpositionattention':
                output = self.SA(output, output, output, seqs_length)
                output = extract_last_valid_output(output, seqs_length)
            elif self.reduction == 'clsattention':
                query = self.q.repeat(output.size(0), 1).unsqueeze(1)
                output = self.MHA(query, output, output, seqs_length)
                # Drop only the single-query axis so batch size 1 keeps its batch dim.
                output = output.squeeze(1)
            elif self.reduction == 'attentionnetpooling':
                attention_mask = self._valid_mask(output, seqs_length)
                # NOTE(review): AttentionNetPooling returns (pooled, weights), so
                # here `output` is that tuple rather than a tensor.
                output = self.ANP(output, attention_mask)
            else:
                raise NotImplementedError
        return output, h_n[-1]

def CreateSeq(num_embeddings, seq_embedding_size=128, hidden_size=128, dropout=0.3, reduction=False,
              LayerNorm=True, encoder_type='lstm', num_layers=2):
    """Factory for ``Seq`` with this module's default hyper-parameters."""
    return Seq(
        num_embeddings=num_embeddings,
        seq_embedding_size=seq_embedding_size,
        hidden_size=hidden_size,
        dropout=dropout,
        reduction=reduction,
        LayerNorm=LayerNorm,
        encoder_type=encoder_type,
        num_layers=num_layers,
    )

class Seq_Post_hoc_Backup(nn.Module):
    """Deprecated bidirectional sequence encoder, kept for reference.

    Embeds ``x['hist_activity']`` and encodes it with either a bidirectional
    RNN/GRU/LSTM (whose ``2 * hidden_size`` outputs are projected back to
    ``hidden_size`` by ``self.fc``) or a Transformer. It then applies optional
    LayerNorm and an optional reduction. ``use_attention=True`` is rejected in
    ``forward``.
    """

    def __init__(self, cat_nums, seq_embedding_size, hidden_size, dropout, reduction=False, use_attention=False,
                 LayerNorm=True, encoder_type='lstm', num_layers=2):
        # Must name this class: super(Seq_Post_hoc, self) raises TypeError
        # because this class does not inherit from Seq_Post_hoc.
        super(Seq_Post_hoc_Backup, self).__init__()
        self.reduction = reduction
        self.hidden_size = hidden_size
        self.use_attention = use_attention
        self.LayerNorm = LayerNorm
        self.embeddings = OrderedDict()
        self.num_layers = num_layers
        self.cat_nums = cat_nums
        # NOTE(review): vocabulary size derived from cat_nums['hist_activity'];
        # the meaning of the 3/24/25 constants is not visible here — confirm.
        self.embeddings['hist_activity'] = nn.Embedding((cat_nums['hist_activity'] - 3) * 24 + 25 + 1, seq_embedding_size,
                                                        padding_idx=0)
        self.add_module(f"embedding:{'hist_activity'}", self.embeddings['hist_activity'])
        self.encoder_type = encoder_type
        if self.encoder_type == "lstm":
            self.encoder = nn.LSTM(input_size=seq_embedding_size, hidden_size=hidden_size,
                                   num_layers=num_layers, batch_first=True, bidirectional=True, dropout=dropout)
        elif self.encoder_type == "gru":
            self.encoder = nn.GRU(input_size=seq_embedding_size, hidden_size=hidden_size,
                                  num_layers=num_layers, batch_first=True, bidirectional=True, dropout=dropout)
        elif self.encoder_type == "rnn":
            self.encoder = nn.RNN(input_size=seq_embedding_size, hidden_size=hidden_size,
                                  num_layers=num_layers, batch_first=True, bidirectional=True, dropout=dropout)
        elif self.encoder_type == "transformer":
            self.encoder = TransformerEncoderwithPE(d_model=hidden_size, num_encoder_layers=num_layers, dropout=dropout)
        else:
            raise NotImplementedError
        if self.encoder_type == "rnn" or self.encoder_type == "lstm" or self.encoder_type == "gru":
            # Project the concatenated forward/backward outputs back to hidden_size.
            self.fc = nn.Linear(hidden_size*2, hidden_size)
        if LayerNorm:
            self.ln = nn.LayerNorm(hidden_size)
        if use_attention or self.reduction == 'selfattention+avgpooling' or self.reduction == 'lastpositionattention':
            self.SA = MultiHeadAttention(hidden_size, hidden_size, hidden_size, hidden_size, 8,
                                         dropout)
        if self.reduction == 'clsattention':
            q_t = np.random.normal(loc=0.0, scale=1, size=(1, self.hidden_size))
            self.q = nn.Parameter(torch.from_numpy(q_t).float())
            self.MHA = MultiHeadAttention(hidden_size, hidden_size, hidden_size, hidden_size, 8,
                                          dropout)
        if self.reduction == 'attentionnetpooling':
            self.ANP = AttentionNetPooling(self.hidden_size)

    def forward(self, x):
        """Encode ``x['hist_activity']`` using lengths ``x['max_len']``.

        Returns ``(output, h_n[-1])``. For RNN encoders ``h_n[-1]`` is the last
        entry of the final hidden state (for a bidirectional encoder, the
        backward direction of the top layer). For the transformer it is None.
        """
        seqs_length = x['max_len']
        output = self.embeddings['hist_activity'](x['hist_activity'])
        if self.encoder_type == "lstm" or self.encoder_type == "rnn" or self.encoder_type == "gru":
            packed_out = pack_padded_sequence(
                output,
                lengths=seqs_length.cpu(),
                batch_first=True,
                enforce_sorted=False)
            if self.encoder_type == "lstm":
                output, (h_n, c_n) = self.encoder(packed_out)
            elif self.encoder_type == "rnn" or self.encoder_type == "gru":
                output, h_n = self.encoder(packed_out)
            output, _ = pad_packed_sequence(output, batch_first=True)
            output = self.fc(output)
        elif self.encoder_type == "transformer":
            output = self.encoder(output, valid_lens=seqs_length)
            h_n = [None, None]
        else:
            raise NotImplementedError
        if self.use_attention:
            # Deprecated option.
            print("已弃用")
            raise NotImplementedError
        if self.LayerNorm:
            output = self.ln(output)
        if self.reduction:
            # NOTE(review): `mask_loc` is not defined in this method or anywhere
            # visible in this module; both branches below raise NameError unless
            # it is supplied elsewhere (e.g. via a star import) — confirm.
            if self.reduction == True or self.reduction == 'avgpooling':
                output = output.gather(1, mask_loc.unsqueeze(1).unsqueeze(1).repeat(1, 1, output.shape[-1])).squeeze()
            elif self.reduction == 'lastpositionattention':
                query = output.gather(1, mask_loc.unsqueeze(1).unsqueeze(1).repeat(1, 1, output.shape[-1]))
                output = self.SA(query, output, output, seqs_length).squeeze()
            else:
                raise NotImplementedError
        return output, h_n[-1]

class Seq_Post_hoc(nn.Module):
    """Post-hoc sequence encoder that returns per-position features.

    Pipeline: embedding of token ids (padding index 0) -> a bidirectional
    RNN/GRU/LSTM (feature width ``d_model = 2 * hidden_size``) or a
    Transformer (``d_model = hidden_size``) -> optional LayerNorm -> optional
    reduction. Only a falsy ``reduction`` or 'lastpositionattention' is
    supported by ``forward``; other values raise ``NotImplementedError``
    there.

    Args:
        num_embeddings: embedding vocabulary size.
        seq_embedding_size: embedding width fed to the encoder.
        hidden_size: per-direction encoder width.
        dropout: dropout inside encoder / attention modules.
        reduction: see above.
        LayerNorm: apply LayerNorm to the encoder output.
        encoder_type: 'lstm', 'gru', 'rnn' or 'transformer'.
        num_layers: number of encoder layers.
        use_att_cons: additionally run ``AttentionNetPooling`` over the raw
            encoder output and return its pooled vector and weights.
    """

    def __init__(self, num_embeddings, seq_embedding_size, hidden_size, dropout, reduction=None,
                 LayerNorm=True, encoder_type='lstm', num_layers=2, use_att_cons=False):
        super(Seq_Post_hoc, self).__init__()
        self.reduction = reduction
        self.hidden_size = hidden_size
        self.LayerNorm = LayerNorm
        self.embeddings = OrderedDict()
        self.num_layers = num_layers
        self.num_embeddings = num_embeddings
        self.embeddings['hist_activity'] = nn.Embedding(num_embeddings, seq_embedding_size, padding_idx=0)
        self.add_module(f"embedding:{'hist_activity'}", self.embeddings['hist_activity'])
        self.encoder_type = encoder_type
        if self.encoder_type == "lstm" or self.encoder_type == "rnn" or self.encoder_type == "gru":
            # Bidirectional RNNs concatenate both directions.
            self.d_model = hidden_size*2
        elif self.encoder_type == "transformer":
            self.d_model = hidden_size
        else:
            raise NotImplementedError
        if self.encoder_type == "lstm":
            self.encoder = nn.LSTM(input_size=seq_embedding_size, hidden_size=hidden_size,
                                   num_layers=num_layers, batch_first=True, bidirectional=True, dropout=dropout)
        elif self.encoder_type == "gru":
            self.encoder = nn.GRU(input_size=seq_embedding_size, hidden_size=hidden_size,
                                  num_layers=num_layers, batch_first=True, bidirectional=True, dropout=dropout)
        elif self.encoder_type == "rnn":
            self.encoder = nn.RNN(input_size=seq_embedding_size, hidden_size=hidden_size,
                                  num_layers=num_layers, batch_first=True, bidirectional=True, dropout=dropout)
        elif self.encoder_type == "transformer":
            self.encoder = TransformerEncoderwithPE(d_model=hidden_size, num_encoder_layers=num_layers, dropout=dropout)
        else:
            raise NotImplementedError
        if LayerNorm:
            self.ln = nn.LayerNorm(self.d_model)
        if self.reduction == 'selfattention+avgpooling' or self.reduction == 'lastpositionattention':
            self.SA = MultiHeadAttention(self.d_model, self.d_model, self.d_model, self.d_model, 8,
                                         dropout)
        if self.reduction == 'clsattention':
            # NOTE(review): q has width hidden_size while MHA is built for
            # d_model (2 * hidden_size for RNN encoders); forward does not
            # support 'clsattention' in any case.
            q_t = np.random.normal(loc=0.0, scale=1, size=(1, self.hidden_size))
            self.q = nn.Parameter(torch.from_numpy(q_t).float())
            self.MHA = MultiHeadAttention(self.d_model, self.d_model, self.d_model, self.d_model, 8,
                                          dropout)
        if self.reduction == 'attentionnetpooling':
            self.ANP = AttentionNetPooling(self.d_model)
        self.use_att_cons = use_att_cons
        if self.use_att_cons:
            self.ANP = AttentionNetPooling(self.d_model)

    def forward(self, x, seqs_length):
        """Encode padded token ids ``x`` (batch-first) with per-sequence ``seqs_length``.

        Returns:
            ``(output_enc, output)``, where ``output_enc`` is the raw encoder
            output and ``output`` is after LayerNorm/reduction. When
            ``use_att_cons`` is set, returns
            ``(output_enc, output, output_seq, w)``: the attention-pooled
            vector and its weights, computed on ``output_enc``.
        """
        output = self.embeddings['hist_activity'](x)
        if self.encoder_type in ("lstm", "rnn", "gru"):
            packed_out = pack_padded_sequence(
                output,
                lengths=seqs_length.cpu(),
                batch_first=True,
                enforce_sorted=False)
            # LSTM returns (out, (h_n, c_n)) and GRU/RNN return (out, h_n);
            # only the per-step outputs are used here.
            output, _ = self.encoder(packed_out)
            output, _ = pad_packed_sequence(output, batch_first=True)
        elif self.encoder_type == "transformer":
            output = self.encoder(output, valid_lens=seqs_length)
        else:
            raise NotImplementedError
        output_enc = output
        if self.LayerNorm:
            output = self.ln(output)
        if self.reduction:
            if self.reduction == 'lastpositionattention':
                output = self.SA(output, output, output, seqs_length)
            else:
                raise NotImplementedError
        if self.use_att_cons:
            # Take the device from the encoder output rather than an undefined
            # name, and move the lengths there so the comparison is legal.
            device = output_enc.device
            batch_size, steps = output_enc.size(0), output_enc.size(1)
            lens = seqs_length.to(device).unsqueeze(1)
            attention_mask = torch.arange(steps, device=device).expand(batch_size, steps) < lens
            output_seq, w = self.ANP(output_enc, attention_mask)
            return output_enc, output, output_seq, w
        return output_enc, output

def CreateSeq_Post_hoc(num_embeddings, seq_embedding_size=128, hidden_size=128, dropout=0.3, reduction=False,
              LayerNorm=True, encoder_type='lstm', num_layers=2):
    """Factory for ``Seq_Post_hoc`` with this module's default hyper-parameters."""
    return Seq_Post_hoc(
        num_embeddings=num_embeddings,
        seq_embedding_size=seq_embedding_size,
        hidden_size=hidden_size,
        dropout=dropout,
        reduction=reduction,
        LayerNorm=LayerNorm,
        encoder_type=encoder_type,
        num_layers=num_layers,
    )

class SeqwithClassifier(nn.Module):
    """Sequence encoder followed by a linear classifier applied at every position.

    ``prediction_mode`` selects the encoder: 'real_time' -> ``Seq``,
    'post_hoc' -> ``Seq_Post_hoc``. Any other value raises
    ``NotImplementedError``. With ``use_att_cons`` a second linear head scores
    the encoder's attention-pooled sequence vector. With ``use_prototypes``
    a ``multi_hypersphere_prototypes`` module is attached (not used in
    ``forward``).

    NOTE(review): ``forward`` calls ``self.seq(x, seqs_length)`` and unpacks 2
    or 4 values, which matches ``Seq_Post_hoc.forward``. ``Seq.forward`` takes
    a single dict argument, so confirm 'real_time' mode works.
    """

    def __init__(self, num_embeddings, seq_embedding_size, hidden_size, dropout, reduction=None,
                 LayerNorm=True, encoder_type='lstm', num_layers=2, num_class=1, prediction_mode='post_hoc', num_spheres=-1, use_prototypes=False, similarity_type='l2', use_att_cons=False):
        super().__init__()
        if prediction_mode == 'real_time':
            encoder_cls = Seq
        elif prediction_mode == 'post_hoc':
            encoder_cls = Seq_Post_hoc
        else:
            raise NotImplementedError
        self.seq = encoder_cls(num_embeddings, seq_embedding_size, hidden_size, dropout, reduction,
                               LayerNorm, encoder_type, num_layers, use_att_cons=use_att_cons)
        self.d_model = self.seq.d_model
        self.final_layer = nn.Linear(self.d_model, num_class)
        self.final_layer.apply(init_weights)
        self.use_prototypes = use_prototypes
        self.use_att_cons = use_att_cons
        if use_att_cons:
            # Separate head for the pooled (whole-sequence) representation.
            self.final_layer2 = nn.Linear(self.d_model, num_class)
            self.final_layer2.apply(init_weights)
        if use_prototypes:
            assert num_spheres > 0
            self.num_spheres = num_spheres
            self.similarity_type = similarity_type
            self.multi_hyperspheres = multi_hypersphere_prototypes(num_spheres, self.d_model, similarity_type)

    def forward(self, x, seqs_length):
        """Return a dict with 'logits' ``(batch, steps, num_class)`` and 'feature'.

        When ``use_att_cons`` is set, it also contains 'logits_seq' (from the
        pooled vector) and 'attention_score' (the pooling weights).
        """
        batch_size = x.shape[0]
        encoded = self.seq(x, seqs_length)
        if self.use_att_cons:
            out_enc, out_final, output_seq, w = encoded
        else:
            out_enc, out_final = encoded
        flat_logits = self.final_layer(out_final.reshape(-1, out_final.shape[-1]))
        output = {
            'logits': flat_logits.reshape(batch_size, -1, flat_logits.shape[-1]),
            'feature': out_enc,
        }
        if self.use_att_cons:
            output['logits_seq'] = self.final_layer2(output_seq.reshape(-1, output_seq.shape[-1]))
            output['attention_score'] = w
        return output

def CreateSeqwithClassifier(num_embeddings, seq_embedding_size=128, hidden_size=128, dropout=0.3, reduction=None,
                            LayerNorm=True, encoder_type='lstm', num_layers=2, num_class=2, prediction_mode='post_hoc'):
    """Factory for ``SeqwithClassifier`` with this module's default hyper-parameters."""
    return SeqwithClassifier(
        num_embeddings=num_embeddings,
        seq_embedding_size=seq_embedding_size,
        hidden_size=hidden_size,
        dropout=dropout,
        reduction=reduction,
        LayerNorm=LayerNorm,
        encoder_type=encoder_type,
        num_layers=num_layers,
        num_class=num_class,
        prediction_mode=prediction_mode,
    )


def init_weights(model):
    """Re-initialise a single module in place; meant for ``Module.apply``.

    ``nn.Linear`` gets a Kaiming-uniform weight and a standard-normal bias.
    ``BatchNorm1d/2d/3d`` get weight ~ N(1, 0.02) and a zero bias. All other
    modules are left untouched.
    """
    is_linear = isinstance(model, nn.Linear)
    is_batchnorm = isinstance(model, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d))
    if not (is_linear or is_batchnorm):
        return
    weight, bias = model.weight, model.bias
    if is_linear:
        if weight is not None:
            nn.init.kaiming_uniform_(weight.data)
        if bias is not None:
            nn.init.normal_(bias.data)
        return
    # BatchNorm; weight/bias are None when affine=False.
    if weight is not None:
        nn.init.normal_(weight.data, mean=1, std=0.02)
    if bias is not None:
        nn.init.constant_(bias.data, 0)

def get_cosine_schedule_with_warmup(optimizer,
                                    num_training_steps,
                                    num_cycles=7. / 16.,
                                    num_warmup_steps=0,
                                    last_epoch=-1):
    """Build a ``LambdaLR`` with linear warmup followed by cosine decay.

    The multiplier is ``step / num_warmup_steps`` during warmup. After that it
    is ``max(0, cos(pi * num_cycles * progress))``, where ``progress`` runs
    from 0 to 1 over the remaining ``num_training_steps - num_warmup_steps``
    steps. Set ``num_warmup_steps > 0`` to enable warmup.
    """
    from torch.optim.lr_scheduler import LambdaLR

    def _lr_lambda(current_step):
        """Return the learning-rate multiplier for integer step ``current_step``."""
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        decay_span = float(max(1, num_training_steps - num_warmup_steps))
        progress = float(current_step - num_warmup_steps) / decay_span
        return max(0.0, math.cos(math.pi * num_cycles * progress))

    return LambdaLR(optimizer, _lr_lambda, last_epoch)

def generate_padding_mask(max_len, valid_len, device):
    """Boolean ``(len(valid_len), max_len)`` mask, True at positions before each length."""
    positions = torch.arange(max_len, dtype=torch.float32).to(device)
    return positions.unsqueeze(0) < valid_len.unsqueeze(1)

def smooth(arr, valid_len, weight):
    """Temporal smoothness penalty: ``weight * sum((arr[t+1] - arr[t])**2)`` over valid pairs.

    A pair ``(t, t+1)`` of row ``b`` counts only when ``t + 1 < valid_len[b]``.
    Tensors are created on ``arr``'s device instead of relying on a
    module-level ``device`` name.

    Args:
        arr: ``(batch, max_len)`` tensor of per-step values.
        valid_len: 1-D tensor of per-row lengths.
        weight: scalar multiplier for the loss.
    """
    bs, max_len = arr.shape
    valid_len = valid_len.to(arr.device).unsqueeze(1)
    positions = torch.arange(max_len, device=arr.device).expand(bs, max_len)
    # Same mask for both sides: t < len - 1 is equivalent to t + 1 < len.
    pair_mask = (positions[:, :-1] < (valid_len - 1)).to(torch.float)
    prev_vals = arr[:, :-1] * pair_mask
    next_vals = arr[:, 1:] * pair_mask
    loss = torch.sum((next_vals - prev_vals) ** 2)
    return weight * loss


def sparsity(arr, valid_len, weight):
    """Sparsity penalty: ``weight`` times the mean of ``arr`` over valid positions.

    Position ``t`` of row ``b`` is valid when ``t < valid_len[b]``. The mask is
    created on ``arr``'s device instead of relying on a module-level
    ``device`` name. If no position is valid, the result is NaN (0 / 0).

    Args:
        arr: ``(batch, max_len)`` tensor.
        valid_len: 1-D tensor of per-row lengths.
        weight: scalar multiplier.
    """
    bs, max_len = arr.shape
    valid_len = valid_len.to(arr.device).unsqueeze(1)
    positions = torch.arange(max_len, device=arr.device).expand(bs, max_len)
    padding_mask = (positions < valid_len).to(torch.float)
    loss = torch.sum(arr * padding_mask) / padding_mask.sum()
    return weight * loss

class ContrastiveLoss(nn.Module):
    """Margin-based contrastive loss on the Euclidean distance between two inputs.

    For ``label == 0`` the loss is ``d**2`` (pull together). For
    ``label == 1`` it is ``max(margin - d, 0)**2`` (push apart beyond
    ``margin``). The result is averaged over the batch.
    """

    def __init__(self, margin=200.0):
        super().__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        distance = F.pairwise_distance(output1, output2, keepdim=True)
        similar_term = (1 - label) * distance.pow(2)
        dissimilar_term = label * (self.margin - distance).clamp(min=0.0).pow(2)
        return (similar_term + dissimilar_term).mean()


class SigmoidCrossEntropyLoss(nn.Module):
    """Numerically stable binary cross-entropy on logits, averaged over all elements.

    Uses the identity
    ``-x*t + max(x, 0) + log(1 + exp(-|x|))``; the absolute value of the mean
    is returned.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, target):
        softplus_neg_abs = torch.log(1 + torch.exp(-torch.abs(x)))
        per_element = -x * target + torch.clamp(x, min=0) + softplus_neg_abs
        return torch.abs(torch.mean(per_element))


class MILloss(torch.nn.Module):
    """BCE classification loss plus weighted contrastive terms on feature magnitudes.

    The normal/abnormal scores are concatenated and fed to ``BCELoss``, so
    they must already lie in [0, 1]. Feature magnitudes are L1 norms over
    dim 2 of the feature tensors (rank >= 3 is assumed). Three contrastive
    terms are added, scaled by ``const_loss_weight``:
    abnormal vs. normal (label 1, pushed apart), and the two halves of each
    group against each other (label 0, pulled together). Both groups are
    split at ``len(abn_feamagnitude) / 2``.
    """

    def __init__(self, const_loss_weight):
        super().__init__()
        self.const_loss_weight = const_loss_weight
        self.sigmoid = torch.nn.Sigmoid()
        self.criterion = torch.nn.BCELoss()
        self.contrastive = ContrastiveLoss()

    def forward(self, score_normal, score_abnormal, nlabel, alabel, nor_feamagnitude, abn_feamagnitude):
        """Return a dict of the total loss and each of its components."""
        device = score_abnormal.device
        label = torch.cat((nlabel, alabel), 0).to(device)
        score = torch.cat((score_normal, score_abnormal), 0).squeeze()
        half = int(len(abn_feamagnitude) / 2)

        def l1_magnitude(feats):
            return torch.norm(feats, p=1, dim=2)

        loss_cls = self.criterion(score, label.float())
        loss_con = self.contrastive(l1_magnitude(abn_feamagnitude), l1_magnitude(nor_feamagnitude), 1)
        loss_con_n = self.contrastive(l1_magnitude(nor_feamagnitude[half:]),
                                      l1_magnitude(nor_feamagnitude[:half]), 0)
        loss_con_a = self.contrastive(l1_magnitude(abn_feamagnitude[half:]),
                                      l1_magnitude(abn_feamagnitude[:half]), 0)
        contrastive_loss = self.const_loss_weight * (loss_con + loss_con_a + loss_con_n)
        return {
            'total_loss': loss_cls + contrastive_loss,
            'cls_loss': loss_cls,
            'loss_con': loss_con,
            'loss_con_a': loss_con_a,
            'loss_con_n': loss_con_n,
            'contrastive_loss': contrastive_loss,
        }

def topk_rank_loss(score_normal, score_abnormal):
    """Pairwise hinge ranking loss: abnormal scores should exceed each normal score by 1.

    Sums ``relu(1 - score_abnormal + s)`` for every entry ``s`` of
    ``score_normal`` (iterated along dim 0), then divides by
    ``len(score_normal) ** 2``.
    """
    n_normal = score_normal.shape[0]
    total = sum(torch.sum(F.relu(1 - score_abnormal + score_normal[idx])) for idx in range(n_normal))
    return total / n_normal ** 2

class multi_hypersphere_prototypes(nn.Module):
    """Learnable prototype centres ("hyperspheres") in feature space.

    Holds ``num_spheres`` centres of size ``feature_size``. Features are scored
    against every centre either by inner product or by Euclidean (L2) distance,
    depending on ``similarity_type`` ('inner_product' or 'l2').
    """

    def __init__(self, num_spheres, feature_size, similarity_type):
        super(multi_hypersphere_prototypes, self).__init__()
        self.num_spheres = num_spheres
        self.feature_size = feature_size
        self.prototypes_centers = nn.Parameter(torch.empty(self.num_spheres, self.feature_size))
        self.reset_parameters()
        self.similarity_type = similarity_type

    def reset_parameters(self):
        """Initialise centres uniformly in ``[-1/sqrt(feature_size), 1/sqrt(feature_size)]``."""
        stdv = 1. / math.sqrt(self.prototypes_centers.size(-1))
        self.prototypes_centers.data.uniform_(-stdv, stdv)

    def _centers(self, use_detach):
        """Return the centres, cut from the autograd graph when ``use_detach`` is set."""
        return self.prototypes_centers.detach() if use_detach else self.prototypes_centers

    @staticmethod
    def _nearest_prototype_indices(score, k):
        """Indices of the ``k`` smallest scores per row.

        Ranks the raw scores rather than their row-softmax: softmax preserves
        the order but underflows to exact 0 for widely spread scores, which
        turns the choice of nearest prototype into an arbitrary tie.
        """
        return torch.topk(score, k=k, dim=1, largest=False)[1]

    def get_score(self, x, use_detach=False):
        """Score 2-D features ``x`` ``[N, D]`` against every centre.

        Returns ``[N, num_spheres]``: inner products for 'inner_product',
        Euclidean distances for 'l2'.

        Raises:
            ValueError: if ``x`` is not 2-D.
            NotImplementedError: for an unknown ``similarity_type``.
        """
        if x.dim() != 2:
            raise ValueError("get_score expects a 2-D [N, D] tensor, got shape {}".format(tuple(x.shape)))
        centers = self._centers(use_detach)
        if self.similarity_type == 'inner_product':
            return torch.matmul(x, torch.t(centers))
        if self.similarity_type == 'l2':
            return torch.cdist(x, centers, p=2)
        raise NotImplementedError

    def top1_score(self, x, use_detach=False):
        """Per-step smallest prototype score for ``x`` of shape ``[bs, max_len, D]``; returns ``[bs, max_len]``."""
        bs, _, feature_size = x.size()
        score = self.get_score(x.reshape(-1, feature_size), use_detach=use_detach)
        # NOTE(review): ``largest=False`` is the nearest centre for 'l2'; for
        # 'inner_product' it is the *least* similar centre — confirm intended.
        top1_prototype_score, _ = torch.topk(score, k=1, dim=1, largest=False)
        return top1_prototype_score.reshape(bs, -1)

    def cal_softmax_score(self, score):
        """Return ``(softmax over dim 0, softmax over dim 1)`` of a ``[N, num_spheres]`` score matrix."""
        dist_prototype2querys = F.softmax(score, dim=0)
        dist_query2prototypes = F.softmax(score, dim=1)
        return dist_prototype2querys, dist_query2prototypes

    def compactness_loss(self, x, use_detach=False):
        """Mean squared distance from each feature vector to its nearest centre."""
        x_reshape = x.reshape(-1, x.shape[-1])
        score = self.get_score(x_reshape, use_detach=use_detach)
        nearest = self._nearest_prototype_indices(score, 1).squeeze(1)
        assigned_centers = self._centers(use_detach)[nearest]
        return torch.sum((x_reshape - assigned_centers) ** 2, dim=-1).mean()

    def gcpct_loss(self, x, use_detach=False):
        """Mean squared distance from each feature vector to the mean of all centres."""
        x_reshape = x.reshape(-1, x.shape[-1])
        mean_prototypes_centers = self._centers(use_detach).mean(dim=0)
        return torch.sum((x_reshape - mean_prototypes_centers) ** 2, dim=-1).mean()

    def separateness_loss(self, x, use_detach=False):
        """Soft triplet loss pulling features to the nearest centre and away from the second nearest.

        Requires at least two centres (``topk`` with ``k=2``).
        """
        x_reshape = x.reshape(-1, x.shape[-1])
        score = self.get_score(x_reshape, use_detach=use_detach)
        top2 = self._nearest_prototype_indices(score, 2)
        centers = self._centers(use_detach)
        pos = centers[top2[:, 0]]
        neg = centers[top2[:, 1]]
        return SoftTripletLoss()(x_reshape, pos, neg)

    def neg_feedback_loss(self, x, use_detach=False):
        """Mean of ``1 / (score**2 + 1e-8)`` over all feature/centre pairs; 0 for empty input.

        Large when features sit close to (score near 0 for 'l2') any centre.
        """
        x_reshape = x.reshape(-1, x.shape[-1])
        if x_reshape.shape[0] == 0:
            return torch.tensor(0.0).to(x.device)
        score = self.get_score(x_reshape, use_detach=use_detach)
        epsilon = 1e-8
        return ((score ** 2 + epsilon) ** (-1)).mean()

class SoftTripletLoss(nn.Module):
    """Soft triplet loss ``-log(exp(d_neg) / (exp(d_neg) + exp(d_pos)))``.

    ``d_pos`` / ``d_neg`` are squared Euclidean distances (summed over dim 1)
    from the anchor to the positive / negative. The ratio equals
    ``sigmoid(d_neg - d_pos)``. It is evaluated in that form because taking
    ``exp`` of squared distances directly overflows float32 once a distance
    exceeds ~88, which turns the loss into ``inf / inf = nan``. The loss is
    small when the anchor is much closer to the positive than to the negative.
    """

    def __init__(self):
        super(SoftTripletLoss, self).__init__()
        # Guards log(0); also caps each sample's loss at -log(eps) ~= 20.7.
        self.eps = 1e-9

    def forward(self, anchor, positive, negative):
        """Return the mean soft-triplet loss over the rows (dim 0) of the inputs."""
        d_pos = (anchor - positive).pow(2).sum(1)
        d_neg = (anchor - negative).pow(2).sum(1)
        # Numerically stable form of exp(d_neg) / (exp(d_neg) + exp(d_pos)).
        softmax_triplet = torch.sigmoid(d_neg - d_pos)
        return torch.mean(-torch.log(softmax_triplet + self.eps))

def smooth_targets(logits, targets, smoothing=0.1):
    """Build a label-smoothed target distribution shaped like ``logits``.

    Every class gets ``smoothing / (num_classes - 1)``, then the target class
    of each row (index ``targets[i]`` along dim 1) is overwritten with
    ``1 - smoothing``. Computed without gradient tracking.
    """
    num_classes = logits.shape[-1]
    with torch.no_grad():
        smoothed = torch.full_like(logits, smoothing / (num_classes - 1))
        target_index = targets.data.unsqueeze(1)
        smoothed.scatter_(1, target_index, (1 - smoothing))
    return smoothed

def compute_prob(self, logits):
    """Return the softmax of ``logits`` over the last dimension.

    ``self`` is accepted but ignored (kept for call-site compatibility).
    """
    probabilities = logits.softmax(dim=-1)
    return probabilities

class FixThresholdingMaskGenerator:
    """Confidence mask with a constant cut-off.

    This is the fixed threshold used by FixMatch, UDA, pseudo-labelling, and similar methods.
    """

    def __init__(self, p_cutoff):
        # Minimum max-class probability for a sample to be kept.
        self.p_cutoff = p_cutoff

    @torch.no_grad()
    def masking(self, logits_x_ulb, softmax_x_ulb=True):
        """Return a float mask (1 keep / 0 drop) per sample.

        Args:
            logits_x_ulb: scores over the last dimension; treated as logits
                when ``softmax_x_ulb`` is True, otherwise as probabilities.
            softmax_x_ulb: whether to apply softmax first.
        """
        probs = logits_x_ulb.detach()
        if softmax_x_ulb:
            probs = probs.softmax(dim=-1)
        confidence = probs.max(dim=-1).values
        return (confidence >= self.p_cutoff).to(confidence.dtype)

class AdaptiveThresholdingMaskGenerator:
    """Configuration holder for an adaptive confidence threshold.

    Only the constructor is defined here. In the binary case
    (``num_classes == 1``) it stores ``binary_thresh`` and ``tau_c``.
    Otherwise it stores a uniform class prior ``p_model`` and its mean ``time_p``.
    """

    def __init__(self, num_classes=1, momentum=0.999, use_quantile=True, clip_thresh=False, binary_thresh=0.5, tau_c=0.5):
        self.num_classes = num_classes
        self.m = momentum
        self.use_quantile = use_quantile
        self.clip_thresh = clip_thresh
        if self.num_classes != 1:
            # Uniform prior over the classes.
            self.p_model = torch.ones((self.num_classes)) / self.num_classes
            self.time_p = self.p_model.mean()
            return
        self.binary_thresh = binary_thresh
        self.tau_c = tau_c

def param_groups_weight_decay(
        model: nn.Module,
        weight_decay=1e-5,
        no_weight_decay_list=()
):
    """Split trainable parameters into a no-decay group and a decay group.

    A parameter is exempt from weight decay when it has at most one dimension,
    its name ends in ``.bias``, or its name is listed in ``no_weight_decay_list``.
    Frozen parameters (``requires_grad=False``) are left out.

    Returns:
        ``[{'params': exempt, 'weight_decay': 0.}, {'params': rest, 'weight_decay': weight_decay}]``.
    """
    exempt_names = set(no_weight_decay_list)
    decay, no_decay = [], []
    trainable = ((n, p) for n, p in model.named_parameters() if p.requires_grad)
    for name, param in trainable:
        exempt = param.ndim <= 1 or name.endswith(".bias") or name in exempt_names
        (no_decay if exempt else decay).append(param)

    return [
        {'params': no_decay, 'weight_decay': 0.},
        {'params': decay, 'weight_decay': weight_decay},
    ]


def _group(it, size):
    """Return an iterator over consecutive tuples of at most ``size`` items from ``it``.

    The last tuple may be shorter. ``iter(it)`` is taken eagerly; chunks are produced lazily.
    """
    source = iter(it)

    def _chunks():
        while True:
            chunk = tuple(islice(source, size))
            if not chunk:
                return
            yield chunk

    return _chunks()

def _layer_map(model, layers_per_group=12, num_groups=None):
    """Map every parameter name of ``model`` to a layer-group index.

    Names that start with ``model.pretrained_cfg['classifier']`` (a prefix or
    a list/tuple of prefixes) form the head. When no such prefix is available,
    every parameter counts as head. The remaining "trunk" names, in
    ``named_parameters`` order, are chunked into groups of ``layers_per_group``
    (or into ``num_groups`` roughly equal groups when given) and numbered from 0.
    All head parameters get the index after the last trunk group.
    """
    head_prefix = getattr(model, 'pretrained_cfg', {}).get('classifier', None)

    def _in_head(name):
        if not head_prefix:
            return True
        if isinstance(head_prefix, (tuple, list)):
            return any(name.startswith(prefix) for prefix in head_prefix)
        return name.startswith(head_prefix)

    names_trunk, names_head = [], []
    for name, _ in model.named_parameters():
        if _in_head(name):
            names_head.append(name)
        else:
            names_trunk.append(name)

    if num_groups is not None:
        # Integer ceil(len(names_trunk) / num_groups).
        layers_per_group = -(len(names_trunk) // -num_groups)
    trunk_groups = list(_group(names_trunk, layers_per_group))

    mapping = {}
    for group_id, group in enumerate(trunk_groups):
        for name in group:
            mapping[name] = group_id
    mapping.update(dict.fromkeys(names_head, len(trunk_groups)))
    return mapping

# Sentinel suffix: a group key ending in this value joins the previous layer id.
MATCH_PREV_GROUP = (99999,)

def group_with_matcher(
        named_objects,
        group_matcher: Union[Dict, Callable],
        output_values: bool = False,
        reverse: bool = False
):
    """Group ``(name, obj)`` pairs into consecutive layer ids.

    ``group_matcher`` is either
      * a dict ``{group_name: spec}`` where ``spec`` is ``None`` (skipped), a
        regex string, or a list/tuple of ``(regex, suffix)`` pairs. The group
        key is ``(ordinal, *captured_groups, *suffix)`` as floats, and names
        matching no regex get ``(inf,)``.
      * a callable ``name -> key`` (scalar or iterable).

    Keys are sorted and each distinct key opens a new layer id, except keys
    whose last element equals ``MATCH_PREV_GROUP[0]``, which merge into the
    previous id.

    Returns:
        ``{layer_id: [names or objects]}``, or ``{name: layer_id}`` when ``reverse``.
    """
    if isinstance(group_matcher, dict):
        compiled = []
        for ordinal, spec in enumerate(group_matcher.values()):
            if spec is None:
                continue
            if isinstance(spec, (tuple, list)):
                compiled.extend((re.compile(sub[0]), (ordinal,), sub[1]) for sub in spec)
            else:
                compiled.append((re.compile(spec), (ordinal,), None))
        group_matcher = compiled

    def _get_grouping(name):
        if not isinstance(group_matcher, (list, tuple)):
            order = group_matcher(name)
            if isinstance(order, collections.abc.Iterable):
                return tuple(order)
            return (order,)
        for pattern, prefix, suffix in group_matcher:
            found = pattern.match(name)
            if found:
                parts = (prefix, found.groups(), suffix)
                return tuple(float(v) for part in parts if part for v in part)
        return (float('inf'),)

    grouping = defaultdict(list)
    for key, value in named_objects:
        grouping[_get_grouping(key)].append(value if output_values else key)

    layer_id_to_param = defaultdict(list)
    layer_id = -1
    for group_key in sorted(k for k in grouping if k is not None):
        if layer_id < 0 or group_key[-1] != MATCH_PREV_GROUP[0]:
            layer_id += 1
        layer_id_to_param[layer_id].extend(grouping[group_key])

    if not reverse:
        return layer_id_to_param

    assert not output_values, "reverse mapping only sensible for name output"
    return {name: lid for lid, names in layer_id_to_param.items() for name in names}

def group_parameters(
        module: nn.Module,
        group_matcher,
        output_values=False,
        reverse=False,
):
    """Apply ``group_with_matcher`` to ``module.named_parameters()``.

    Returns layer id -> parameter names (or parameters when ``output_values``),
    or parameter name -> layer id when ``reverse``.
    """
    named = module.named_parameters()
    return group_with_matcher(named, group_matcher, output_values=output_values, reverse=reverse)

def param_groups_layer_decay(
        model: nn.Module,
        lr : float=1e-3,
        weight_decay: float = 0.05,
        no_weight_decay_list: Tuple[str] = (),
        layer_decay: float = .75,
        end_layer_decay: Optional[float] = None,
):
    """Build optimizer parameter groups with layer-wise learning-rate decay.

    Each parameter gets a layer index, from ``model.group_matcher(coarse=False)``
    when the model defines it and from ``_layer_map`` otherwise. A parameter in
    layer ``i`` receives ``lr * layer_decay ** (layer_max - i)``, so the last
    layer keeps ``lr`` and earlier layers are scaled down geometrically.
    Weight decay is disabled for tensors with at most one dimension, names
    ending in ``.bias``, and names in ``no_weight_decay_list``. This is the same
    rule as ``param_groups_weight_decay``.

    Args:
        model: module whose trainable parameters are grouped.
        lr: base learning rate of the last layer.
        weight_decay: decay applied to non-exempt parameters.
        no_weight_decay_list: parameter names exempt from weight decay.
        layer_decay: per-layer multiplicative learning-rate factor.
        end_layer_decay: accepted for API compatibility; not used.

    Returns:
        A list of ``{'lr', 'weight_decay', 'params'}`` dicts, one per
        (layer, decay/no-decay) pair, ready for a ``torch.optim`` optimizer.
        The list is empty when the model has no parameters.
    """
    no_weight_decay_list = set(no_weight_decay_list)

    if hasattr(model, 'group_matcher'):
        layer_map = group_parameters(model, model.group_matcher(coarse=False), reverse=True)
    else:
        layer_map = _layer_map(model)
    if not layer_map:
        # No parameters at all; ``max`` below would fail on an empty sequence.
        return []

    num_layers = max(layer_map.values()) + 1
    layer_max = num_layers - 1
    layer_scales = [layer_decay ** (layer_max - i) for i in range(num_layers)]

    param_groups = {}
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        # Scalars (ndim 0) are exempt too, matching param_groups_weight_decay.
        if param.ndim <= 1 or name.endswith(".bias") or name in no_weight_decay_list:
            g_decay, this_decay = "no_decay", 0.
        else:
            g_decay, this_decay = "decay", weight_decay

        layer_id = layer_map.get(name, layer_max)
        group_name = "layer_%d_%s" % (layer_id, g_decay)
        group = param_groups.get(group_name)
        if group is None:
            group = param_groups[group_name] = {
                "lr": layer_scales[layer_id] * lr,
                "weight_decay": this_decay,
                "params": [],
            }
        group["params"].append(param)
    return list(param_groups.values())

def SoftCrossEntropy(logits, target, reduction='none'):
    """Cross-entropy against soft (probability) targets.

    Computes ``-log_softmax(logits, dim=1) * target`` and reduces it as follows:
    ``'sum'`` sums every element, ``'mean'`` divides that sum by the batch size
    ``logits.shape[0]``, and any other value returns the per-sample sum over dim 1.
    """
    weighted_nll = torch.mul(-F.log_softmax(logits, dim=1), target)
    if reduction == 'sum':
        return torch.sum(weighted_nll)
    if reduction == 'mean':
        return torch.sum(weighted_nll) / logits.shape[0]
    return torch.sum(weighted_nll, dim=1)

def CrossEntropyfn(logits, targets, reduction='none', type='CE', use_sigmoid=False):
    """Dispatch to a cross-entropy variant.

    Args:
        logits: model outputs.
        targets: class indices, or soft targets when ``type == 'CE'`` and the
            shape equals ``logits.shape`` (handled by ``SoftCrossEntropy``).
        reduction: forwarded to the underlying loss.
        type: ``'CE'`` for softmax cross-entropy, ``'BCE'`` for binary cross-entropy.
        use_sigmoid: for ``'BCE'``, apply a sigmoid to ``logits`` first;
            otherwise ``logits`` must already be probabilities.

    Raises:
        NotImplementedError: for any other ``type``.
    """
    if type == 'BCE':
        probs = torch.sigmoid(logits) if use_sigmoid else logits
        return F.binary_cross_entropy(probs, targets, reduction=reduction)
    if type != 'CE':
        raise NotImplementedError
    if logits.shape == targets.shape:
        return SoftCrossEntropy(logits, targets, reduction)
    return F.cross_entropy(logits, targets, reduction=reduction)

class MILtrainer:
    def __init__(self, args, model, train_loader_a, train_loader_n, train_loader_random, val_loader, val_loader_a, device, logger, ema_model=None):
        """Store loaders, model(s) and every training hyper-parameter read from ``args``.

        Also builds the optimizer (``self.set_optimizer()``), an optional cosine
        warm-up scheduler (``args.use_scheduler``) and the checkpoint paths.

        Args:
            args: configuration namespace; every attribute read below must exist.
            model: network being trained; must expose ``d_model``.
            train_loader_a / train_loader_n / train_loader_random: abnormal,
                normal and random training loaders.
            val_loader / val_loader_a: validation loaders.
            device: device batches are moved to.
            logger: object with ``.info``; ``None`` falls back to ``print``.
            ema_model: optional EMA copy of ``model``. Always stored (``None``
                when absent) so the attribute exists on every instance.
        """
        self.args = args
        self.continuetrain_MIL = args.continuetrain_MIL
        self.continuetrain_debias = args.continuetrain_debias
        self.train_loader_dict = {}
        self.train_loader_dict['normal_seq'] = train_loader_n
        self.train_loader_dict['abnormal_seq'] = train_loader_a
        self.train_loader_dict['random'] = train_loader_random
        self.val_loader = val_loader
        self.val_loader_a = val_loader_a
        self.device = device
        self.print_fn = print if logger is None else logger.info
        self.model = model
        # Always define the attribute. A truthiness guard would leave it unset
        # for None, and also for module containers whose __len__ is 0.
        self.ema_model = ema_model
        self.d_model = self.model.d_model
        self.occ_epochs = args.occ_epochs
        self.start_mil_epoch = 0
        self.mil_epochs = args.mil_epochs
        self.mil_cur_epoch = -1
        self.start_debias_epoch = 0
        self.debias_epochs = args.debias_epochs
        self.debias_cur_epoch = -1
        self.it = 0
        # Iteration budgets are counted in batches of the normal loader; warm-up is 1/40 of them.
        self.num_train_iter = self.mil_epochs * len(self.train_loader_dict['normal_seq'])
        self.num_warmup_iter = self.num_train_iter // 40
        self.debias_it = 0
        self.num_debias_train_iter = self.debias_epochs * len(self.train_loader_dict['normal_seq'])
        # Attribute name (sic) kept as-is; other code may read it.
        self.num_debais_warmup_iter = self.num_debias_train_iter // 40
        self.maxauc = 0
        self.best_it = 0

        self.use_prototypes = args.use_prototypes
        self.forward_passes = args.forward_passes
        self.high_confidence_threshold_rate = args.high_confidence_threshold_rate
        self.mid_confidence_threshold_rate = args.mid_confidence_threshold_rate
        self.binary_thresh = args.binary_thresh
        self.tau_sim = args.tau_sim
        self.tau_c = args.tau_c
        self.T = args.T
        self.ema_p = args.ema_p
        self.use_quantile = args.use_quantile
        self.clip_thresh = args.clip_thresh
        self.mid_loss_weight = args.mid_loss_weight
        self.num_classes = args.num_classes
        self.thresholding_mask_generator = AdaptiveThresholdingMaskGenerator(num_classes=self.num_classes, momentum=self.args.ema_p, use_quantile=self.use_quantile, clip_thresh=self.clip_thresh, binary_thresh=self.binary_thresh, tau_c=self.tau_c)

        self.optimizer = self.set_optimizer()
        self.print_fn("Create mil scheduler")
        if args.use_scheduler:
            self.scheduler = get_cosine_schedule_with_warmup(self.optimizer,
                                                             self.num_train_iter,
                                                             num_warmup_steps=self.num_warmup_iter)
        else:
            self.scheduler = None
        self.cls_loss = CrossEntropyfn
        self.clip_grad = args.clip_grad

        self.final_score_type = args.final_score_type
        self.select_score_type = args.select_score_type
        self.lambda_cls = args.lambda_cls
        self.lambda_multi_hypersphere = args.lambda_multi_hypersphere
        self.lambda_magnitude = args.lambda_magnitude
        self.sigmoid = nn.Sigmoid()
        self.mil_dropout = args.mil_dropout
        # Dropout applied to the step-validity mask before top-k selection (see mil_score_generator).
        self.drop_out = nn.Dropout(self.mil_dropout) if self.mil_dropout > 0 else None
        self.batch_size = args.batch_size
        self.bsdiv2 = args.bsdiv2
        self.miltopk = args.miltopk
        self.smooth_loss_weight = args.smooth_loss_weight
        self.sparse_loss_weight = args.sparse_loss_weight
        self.const_loss_weight = args.const_loss_weight
        self.separateness_loss_weight = args.separateness_loss_weight
        self.neg_feedback_loss_weight = args.neg_feedback_loss_weight
        self.prototypes_loss_weight = args.prototypes_loss_weight
        self.gcpct_loss_weight = args.gcpct_loss_weight
        self.hard_loss_weight = args.hard_loss_weight
        self.all_loss_weight = args.all_loss_weight

        src_tag = args.srcroot[-4:]
        norm_tag = 'ln' if args.use_layernorm else 'noln'
        self.MIL_model_save_path = os.path.join(args.root, '{}_{}_{}_{}_{}_level_miltop{}_checkpoint.pth'.format(
            src_tag, args.encoder_type, args.reduction,
            args.classification_level, norm_tag, self.miltopk))
        self.OC_model_save_path = os.path.join(args.root, '{}_{}_None_{}_{}_level_occ_checkpoint.pth'.format(
            src_tag, args.encoder_type,
            args.classification_level, norm_tag))
        self.final_model_save_path = os.path.join(args.root, '{}_{}_{}_{}_{}_level_checkpoint_miltop{}_final.pth'.format(
            src_tag, args.encoder_type, args.reduction,
            args.classification_level, norm_tag, self.miltopk))

    def load_ema_model(self):
        self.ema_model.load_state_dict(self.model.state_dict())

    def mil_score_generator(self, features, cls_scores, padding_mask, select_score_type='both'):
        """Score every step of a normal+abnormal batch and keep the top-k steps per sequence.

        Rows ``[:self.bsdiv2]`` are the normal sequences and rows
        ``[self.bsdiv2:]`` the abnormal ones. Per-step scores:

        * ``cls``: ``sigmoid`` of the classifier logits,
        * ``magnitude``: ``tanh`` of the L2 norm of each feature vector,
        * ``multi_hypersphere``: ``tanh`` of the nearest-prototype score,
        * ``feat``: ``lambda_cls * cls + lambda_multi_hypersphere * multi_hypersphere
          + lambda_magnitude * magnitude``.

        ``select_score_type`` picks the score that ranks steps. For ``'cls'`` the
        raw logits are used, and for ``'multi-hypersphere'`` the raw prototype
        score. The ``self.miltopk`` best valid steps of each row are kept, and
        every score is averaged over them.

        Args:
            features: step features indexed ``[batch, step, dim]``.
            cls_scores: classifier logits indexed ``[batch, step, C]``. The
                ``gather`` calls assume C == 1 (TODO confirm).
            padding_mask: ``[batch, >=step]`` mask, truthy for real steps.
            select_score_type: ``'cls'``, ``'multi-hypersphere'`` or ``'both'``.

        Returns:
            dict with
            * ``{cls,magnitude,multi_hypersphere,feat}_score_{abnormal,normal}``: top-k means ``[half_batch, C]``,
            * ``select_abn_feature`` / ``select_nor_feature``: the gathered top-k step features,
            * ``cls_score``, ``magnitude_score``, ``multi_hypersphere_score``, ``feat_score``: full per-step
              maps with padded steps zeroed.

        Raises:
            NotImplementedError: for an unknown ``select_score_type``.
        """
        half = self.bsdiv2
        magnitude_feat_scores = torch.norm(features, p=2, dim=2)
        multi_hypersphere_feat_scores = self.model.multi_hyperspheres.top1_score(
            features, use_detach=self.args.use_detach)
        feat_scores = (self.lambda_cls * self.sigmoid(cls_scores.squeeze(-1))
                       + self.lambda_multi_hypersphere * torch.tanh(multi_hypersphere_feat_scores)
                       + self.lambda_magnitude * torch.tanh(magnitude_feat_scores))

        # Abnormal half first: the optional dropout consumes RNG, so keep this order.
        abnormal = self._select_topk_steps(
            features[half:], cls_scores[half:], magnitude_feat_scores[half:],
            multi_hypersphere_feat_scores[half:], feat_scores[half:],
            padding_mask[half:], select_score_type)
        normal = self._select_topk_steps(
            features[:half], cls_scores[:half], magnitude_feat_scores[:half],
            multi_hypersphere_feat_scores[:half], feat_scores[:half],
            padding_mask[:half], select_score_type)

        def masked(per_step):
            # Zero the padded steps of a full [batch, step] score map.
            return per_step * padding_mask[:, :per_step.shape[1]].to(torch.float32)

        return {
            'cls_score_abnormal': abnormal['cls'],
            'magnitude_score_abnormal': abnormal['magnitude'],
            'multi_hypersphere_score_abnormal': abnormal['multi_hypersphere'],
            'feat_score_abnormal': abnormal['feat'],
            'cls_score_normal': normal['cls'],
            'magnitude_score_normal': normal['magnitude'],
            'multi_hypersphere_score_normal': normal['multi_hypersphere'],
            'feat_score_normal': normal['feat'],
            'select_abn_feature': abnormal['feature'],
            'select_nor_feature': normal['feature'],
            'cls_score': masked(self.sigmoid(cls_scores.squeeze(-1))),
            'magnitude_score': masked(torch.tanh(magnitude_feat_scores)),
            'multi_hypersphere_score': masked(torch.tanh(multi_hypersphere_feat_scores)),
            'feat_score': masked(feat_scores),
        }

    def _select_topk_steps(self, features, cls_logits, magnitude, hypersphere, fused, padding_mask, select_score_type):
        """Keep the ``self.miltopk`` best-ranked valid steps per row and average each score over them.

        Returns a dict with top-k means under 'cls', 'magnitude',
        'multi_hypersphere' and 'feat', and the gathered features under 'feature'.
        """
        keep = padding_mask[:, :fused.shape[1]].to(torch.float32)
        if self.drop_out is not None:
            # Randomly hides some valid steps; kept entries are uniformly rescaled, which leaves ranks unchanged.
            keep = self.drop_out(keep)

        if select_score_type == 'cls':
            ranking = cls_logits.squeeze(-1)
        elif select_score_type == 'multi-hypersphere':
            ranking = hypersphere
        elif select_score_type == 'both':
            ranking = fused
        else:
            raise NotImplementedError
        # Padded/dropped steps get -inf rather than 0, so they never outrank a
        # real step whose score is negative (e.g. raw classifier logits).
        ranking = (ranking * keep).masked_fill(keep == 0, float('-inf'))

        top_idx = torch.topk(ranking, self.miltopk, dim=1)[1]
        selected = torch.gather(features, 1, top_idx.unsqueeze(2).expand(-1, -1, features.shape[2]))
        score_idx = top_idx.unsqueeze(2).expand(-1, -1, cls_logits.shape[2])

        def topk_mean(per_step):
            return torch.mean(torch.gather(per_step, 1, score_idx), dim=1)

        return {
            'cls': topk_mean(self.sigmoid(cls_logits)),
            'magnitude': topk_mean(torch.tanh(magnitude).unsqueeze(-1)),
            'multi_hypersphere': topk_mean(torch.tanh(hypersphere).unsqueeze(-1)),
            'feat': topk_mean(fused.unsqueeze(-1)),
            'feature': selected,
        }

    def train_step_mil(self, ndata, adata, select_score_type):
        """Run one MIL forward pass on a normal + abnormal batch and assemble all losses.

        The normal batch is placed first, so its rows are ``[:self.bsdiv2]`` of
        the concatenated batch, and the abnormal rows follow.

        Args:
            ndata / adata: batch dicts for the normal / abnormal halves. The keys
                read are 'hist_activity', 'max_len', 'session_label', 'acts_labels'.
            select_score_type: score used to choose the top-k steps
                ('cls', 'multi-hypersphere' or 'both').

        Returns:
            dict of loss tensors. The prototype terms are zero tensors when
            ``self.use_prototypes`` is off, so every key is always present.
            The dict also carries the flattened per-step 'action_predictions' /
            'action_labels' for the valid steps.

        Raises:
            NotImplementedError: for an unknown ``args.final_score_type``.
        """
        x = torch.cat((ndata['hist_activity'], adata['hist_activity']), 0)
        seqs_length = torch.cat((ndata['max_len'], adata['max_len']), 0)
        model_output = self.model(x, seqs_length)
        feature = model_output['feature']
        cls_scores = model_output['logits']
        padding_mask = generate_padding_mask(x.shape[1], seqs_length, self.device)
        selected = self.mil_score_generator(feature, cls_scores, padding_mask, select_score_type=select_score_type)
        select_abn_feature = selected['select_abn_feature']
        select_nor_feature = selected['select_nor_feature']

        # Per-step, top-k normal and top-k abnormal scores must come from the same family.
        prefix = {'cls': 'cls', 'multi-hypersphere': 'multi_hypersphere', 'both': 'feat'}.get(self.args.final_score_type)
        if prefix is None:
            raise NotImplementedError
        scores = selected[prefix + '_score']
        scores_normal = selected[prefix + '_score_normal']
        scores_abnormal = selected[prefix + '_score_abnormal']

        nor_scores = scores[:self.bsdiv2, :]
        abn_scores = scores[self.bsdiv2:, :]
        loss_smooth = smooth(nor_scores, ndata['max_len'], self.smooth_loss_weight)
        loss_sparse = sparsity(abn_scores, adata['max_len'], self.sparse_loss_weight)
        loss_criterion = MILloss(self.const_loss_weight)
        output_mil_loss = loss_criterion(scores_normal, scores_abnormal, ndata['session_label'], adata['session_label'],
                                         select_nor_feature, select_abn_feature)
        mil_loss = output_mil_loss['total_loss']
        cost = mil_loss + loss_smooth + loss_sparse

        # Defaults keep the output dict complete (and .item()-able) without prototypes.
        zero = torch.tensor(0.0, device=self.device)
        compactness_loss = separateness_loss = neg_feedback_loss = prototypes_loss = zero
        if self.use_prototypes:
            spheres = self.model.multi_hyperspheres
            normal_features = feature[:self.bsdiv2][padding_mask[:self.bsdiv2, :feature.shape[1]]]
            compactness_loss = spheres.compactness_loss(normal_features, use_detach=self.args.use_detach)
            if self.args.num_spheres > 1:
                # Needs at least two prototypes (nearest vs. second nearest).
                separateness_loss = spheres.separateness_loss(normal_features, use_detach=self.args.use_detach)
            neg_feedback_loss = spheres.neg_feedback_loss(select_abn_feature, use_detach=self.args.use_detach)
            prototypes_loss = (compactness_loss
                               + self.separateness_loss_weight * separateness_loss
                               + self.neg_feedback_loss_weight * neg_feedback_loss)
            cost = cost + self.prototypes_loss_weight * prototypes_loss

        output = {}
        output['total_loss'] = cost
        output['MIL_loss'] = mil_loss
        output['cls_loss'] = output_mil_loss['cls_loss']
        output['loss_con'] = output_mil_loss['loss_con']
        output['loss_con_a'] = output_mil_loss['loss_con_a']
        output['loss_con_n'] = output_mil_loss['loss_con_n']
        output['contrastive_loss'] = output_mil_loss['contrastive_loss']
        output['MIL_smooth_loss'] = loss_smooth
        output['MIL_sparse_loss'] = loss_sparse
        output['compactness_loss'] = compactness_loss
        output['separateness_loss'] = self.separateness_loss_weight * separateness_loss
        output['neg_feedback_loss'] = self.neg_feedback_loss_weight * neg_feedback_loss
        output['prototypes_loss'] = self.prototypes_loss_weight * prototypes_loss
        output['action_predictions'] = scores[padding_mask[:, :scores.shape[1]]]
        output['action_labels'] = torch.cat([ndata['acts_labels'], adata['acts_labels']], 0)[padding_mask]
        return output

    def eval_step(self, batch):
        """Score one evaluation batch and flatten predictions/labels over the valid steps.

        The per-step score follows ``args.final_score_type``:
        'cls' gives ``sigmoid(logits)``, 'multi-hypersphere' gives ``tanh(top1 prototype score)``,
        and 'both' gives the lambda-weighted sum of cls, hypersphere and magnitude scores.

        Args:
            batch: dict with 'hist_activity', 'max_len' and 'acts_labels'.

        Returns:
            dict with 'action_predictions' and 'action_labels', both indexed by the padding mask.

        Raises:
            NotImplementedError: for an unknown ``args.final_score_type``.
        """
        inputs, lengths = batch['hist_activity'], batch['max_len']
        out = self.model(inputs, lengths)
        feats, logits = out['feature'], out['logits']
        valid = generate_padding_mask(inputs.shape[1], lengths, self.device)
        hyper = self.model.multi_hyperspheres.top1_score(feats)
        magnitude = torch.norm(feats, p=2, dim=2)
        cls_prob = self.sigmoid(logits.squeeze(-1))
        fused = (self.lambda_cls * cls_prob
                 + self.lambda_multi_hypersphere * torch.tanh(hyper)
                 + self.lambda_magnitude * torch.tanh(magnitude))

        choice = self.args.final_score_type
        if choice == 'cls':
            scores = cls_prob
        elif choice == 'multi-hypersphere':
            scores = torch.tanh(hyper)
        elif choice == 'both':
            scores = fused
        else:
            raise NotImplementedError
        return {
            'action_predictions': scores[valid[:, :scores.shape[1]]],
            'action_labels': batch['acts_labels'][valid],
        }

    def train_mil(self):
        if self.continuetrain_MIL == True:
            assert os.path.exists(self.MIL_model_save_path)
            self.print_fn('continuetrain_debias, load {}'.format(self.MIL_model_save_path))
            checkpoint = torch.load(self.MIL_model_save_path)
            self.model.load_state_dict(checkpoint['model'])
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            if self.scheduler is not None:
                self.scheduler.load_state_dict(checkpoint['scheduler'])
            self.start_mil_epoch = checkpoint['epoch'] + 1
            self.it = checkpoint['it']
            self.maxauc = checkpoint['auc']
            self.binary_thresh = checkpoint['threshold']
        es = 0
        for epoch in range(self.start_mil_epoch, self.mil_epochs):
            self.print_fn('Naive MIL training Epoch:{}.....'.format(epoch))
            self.mil_cur_epoch = epoch
            if self.it >= self.num_train_iter:
                break
            y_true = []
            y_pred = []
            total_loss = 0
            num_batchs = len(self.train_loader_dict['normal_seq'])
            t = tqdm(zip(self.train_loader_dict['normal_seq'], self.train_loader_dict['abnormal_seq']), desc='Naive MIL Training, epoch:{}'.format(epoch), total=num_batchs)
            for step, (ndata, adata) in enumerate(t):
                if step % 1000 == 0:
                    self.eval_func(epoch, 'mil')
                self.model.train()
                self.ema_model.train()
                if self.it >= self.num_train_iter: 
                    break
                max_len = max(adata['hist_activity'].shape[1], ndata['hist_activity'].shape[1])
                ndata = self.collate_fn(ndata, max_len)
                adata = self.collate_fn(adata, max_len)
                for k, v in ndata.items():
                    ndata[k] = ndata[k].to(self.device)
                for k, v in adata.items():
                    adata[k] = adata[k].to(self.device)
                if self.args.use_hypersphere_warm_up:
                    if self.it <= self.args.rate_warm_up * num_batchs:
                        output = self.train_step_mil(ndata, adata, select_score_type='multi-hypersphere')
                    else:
                        output = self.train_step_mil(ndata, adata, select_score_type='both')
                else:
                    output = self.train_step_mil(ndata, adata, select_score_type=self.select_score_type)
                loss = output['total_loss'] 
                MIL_loss = output['MIL_loss']
                loss_con = output['loss_con']
                loss_con_a = output['loss_con_a']
                loss_con_n = output['loss_con_n']
                loss_contrastive = output['contrastive_loss'] 
                compactness_loss = output['compactness_loss']
                separateness_loss = output['separateness_loss']
                neg_feedback_loss = output['neg_feedback_loss']
                smooth_loss = output['MIL_smooth_loss']
                sparse_loss = output['MIL_sparse_loss']
                prototypes_loss = output['prototypes_loss'] 
                total_loss += loss.item()
                loss.backward()
                if (self.clip_grad > 0):
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip_grad)
                self.optimizer.step()
                if self.scheduler is not None:
                    self.scheduler.step()
                if self.it % 50 == 0:
                    if (self.it // 50) % 2 == 0:
                        t.set_postfix(loss=loss.item(), cls=output['cls_loss'].item(), pro=prototypes_loss.item(), smo=smooth_loss.item(), spa=sparse_loss.item())
                    else:
                        t.set_postfix(compactness_loss=compactness_loss.item(),
                                      separateness_loss=separateness_loss.item(),
                                      neg_feedback_loss=neg_feedback_loss.item())
                self.optimizer.zero_grad()
                y_true.append(output['action_labels'].detach().cpu().numpy())
                y_pred.append(output['action_predictions'].detach().cpu().numpy())
                self.it += 1
            y_true_all = np.hstack(y_true)
            y_pred_all = np.hstack(y_pred)
            auc = roc_auc_score(y_true_all, y_pred_all)
            fpr, tpr, threshold = bestthreshold_with_ROC(y_true_all, y_pred_all)
            self.print_fn("Train Epoch:{}\ntrain_loss:{}\nauc:{}\nfpr:{}\ntpr:{}\nthreshold:{}".format(epoch, total_loss / num_batchs, auc, fpr, tpr, threshold))
            self.print_fn("\n")
            self.print_fn('Naive MIL validating Epoch:{}.....'.format(epoch))
            self.model.eval()
            self.ema_model.eval()
            y_true = []
            y_pred = []
            num_batchs = len(self.val_loader)
            t = tqdm(self.val_loader, desc='Naive MIL validating, epoch:{}'.format(epoch), total=num_batchs)
            for step, batch in enumerate(t):
                max_len = batch['hist_activity'].shape[1]
                batch = self.collate_fn(batch, max_len)
                for k, v in batch.items():
                    batch[k] = batch[k].to(self.device)
                with torch.no_grad():
                    output = self.eval_step(batch)
                y_true.append(output['action_labels'].detach().cpu().numpy())
                y_pred.append(output['action_predictions'].detach().cpu().numpy())
            y_true_all = np.hstack(y_true)
            y_pred_all = np.hstack(y_pred)
            auc = roc_auc_score(y_true_all, y_pred_all)
            fpr, tpr, threshold = bestthreshold_with_ROC(y_true_all, y_pred_all)
            self.print_fn("Eval Epoch:{}\nauc:{}\nfpr:{}\ntpr:{}\nthreshold:{}".format(epoch, auc, fpr, tpr, threshold))
            if self.args.use_nni:
                nni.report_intermediate_result(auc)
            if auc > self.maxauc:
                self.maxauc = auc
                self.binary_thresh = threshold
                state = {
                    'model': self.model.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'scheduler': self.scheduler.state_dict() if self.scheduler is not None else None,
                    'epoch': epoch,
                    'it': self.it,
                    'auc': auc,
                    'fpr': fpr,
                    'tpr': tpr,
                    'threshold': threshold,
                    'y_true_all': y_true_all,
                    'y_pred_all': y_pred_all
                }
                es = 0
                torch.save(state, self.MIL_model_save_path)
            else:
                es += 1
                logger.info("Counter {} of 5".format(es))
                if es > 3:
                    logger.info("Early stopping with best_auc: {}".format(self.maxauc))
                    break
        if self.args.use_nni:
            nni.report_final_result(auc)
        state = torch.load(self.MIL_model_save_path)
        self.print_fn("----------------------------------------------------------------------------------------------------------------------")
        self.print_fn("best Epoch:{}\nAuc:{}\nFPR:{}\nTPR:{}".format(state['epoch'], state['auc'],
                                                                            state['fpr'], state['tpr']))

    def train_step_occ(self, ndata):
        """Compute the one-class (OCC) multi-hypersphere losses for one normal batch.

        Only prototype losses are used at this stage (``self.use_prototypes``
        must be true). Compactness, separateness (only when more than one
        sphere is configured) and gcpct losses are computed on the features of
        non-padded positions.

        Args:
            ndata: collated batch with 'hist_activity', 'max_len' and
                'acts_labels', already on ``self.device``.

        Returns:
            dict with 'total_loss' (the prototype loss to back-propagate), the
            individual (weighted) loss terms, the flattened per-position
            'action_predictions'/'action_labels' at valid positions, and the raw
            'scores'/'feature' tensors.
        """
        x = ndata['hist_activity']
        seqs_length = ndata['max_len']
        model_output = self.model(x, seqs_length)
        feature = model_output['feature']
        padding_mask = generate_padding_mask(x.shape[1], seqs_length, self.device)
        # Read the flag from the trainer's own config. The bare global ``args``
        # only exists when this file is executed as a script.
        use_detach = self.args.use_detach
        hyperspheres = self.model.multi_hyperspheres
        scores = hyperspheres.top1_score(feature, use_detach=use_detach)
        assert self.use_prototypes
        # Features at non-padded positions, flattened to (num_valid, dim).
        valid_feature = feature[padding_mask[:, :feature.shape[1]]]
        compactness_loss = hyperspheres.compactness_loss(valid_feature, use_detach=use_detach)
        if self.args.num_spheres > 1:
            separateness_loss = hyperspheres.separateness_loss(valid_feature, use_detach=use_detach)
        else:
            # Separateness is undefined with a single sphere.
            separateness_loss = torch.tensor(0.0).to(self.device)
        gcpct_loss = hyperspheres.gcpct_loss(valid_feature)
        prototypes_loss = ((1 - self.separateness_loss_weight) * compactness_loss
                           + self.separateness_loss_weight * separateness_loss
                           + self.gcpct_loss_weight * gcpct_loss)
        output = {}
        output['total_loss'] = prototypes_loss
        output['compactness_loss'] = compactness_loss
        output['separateness_loss'] = self.separateness_loss_weight * separateness_loss
        output['gcpct_loss'] = self.gcpct_loss_weight * gcpct_loss
        output['action_predictions'] = scores[padding_mask[:, :scores.shape[1]]]
        output['action_labels'] = ndata['acts_labels'][padding_mask]
        output['scores'] = scores
        output['feature'] = feature
        return output

    def eval_step_occ(self, batch):
        """Score one validation batch for the one-class (OCC) stage.

        The per-position anomaly score is ``tanh`` of the model's top-1
        multi-hypersphere score. Only non-padded positions are returned.
        (The unused classifier/magnitude score mix that used to be computed
        here and then discarded has been removed.)

        Args:
            batch: collated batch with 'hist_activity', 'max_len' and
                'acts_labels', already on ``self.device``.

        Returns:
            dict with 'action_predictions' (flattened scores at valid
            positions) and 'action_labels' (labels at the same positions).
        """
        x = batch['hist_activity']
        seqs_length = batch['max_len']
        features = self.model(x, seqs_length)['feature']
        padding_mask = generate_padding_mask(x.shape[1], seqs_length, self.device)
        scores = torch.tanh(self.model.multi_hyperspheres.top1_score(features))
        return {
            'action_predictions': scores[padding_mask[:, :scores.shape[1]]],
            'action_labels': batch['acts_labels'][padding_mask],
        }

    def train_occ(self):
        """Train the one-class (OCC) multi-hypersphere stage on normal sequences.

        Each epoch:
        - trains on ``self.train_loader_dict['normal_seq']`` via ``train_step_occ``;
        - validates on ``self.val_loader`` via ``eval_step_occ``.

        Whenever validation ROC-AUC beats ``self.maxauc``, a checkpoint is
        written to ``self.OC_model_save_path``. Training stops early once the
        non-improving-epoch counter exceeds 3. When NNI is enabled,
        intermediate and final AUCs are reported.
        """
        it = 0
        # Early-stopping counter. It must exist before the first epoch because
        # the non-improving branch below does ``es += 1``.
        es = 0
        for epoch in range(self.occ_epochs):
            self.print_fn('multi-hypersphere occ training Epoch:{}.....'.format(epoch))
            self.model.train()
            total_loss = 0
            num_batchs = len(self.train_loader_dict['normal_seq'])
            t = tqdm(self.train_loader_dict['normal_seq'], desc='multi-hypersphere occ Training, epoch:{}'.format(epoch), total=num_batchs)
            for ndata in t:
                max_len = ndata['hist_activity'].shape[1]
                ndata = self.collate_fn(ndata, max_len)
                for k in ndata:
                    ndata[k] = ndata[k].to(self.device)
                output = self.train_step_occ(ndata)
                loss = output['total_loss']
                compactness_loss = output['compactness_loss']
                separateness_loss = output['separateness_loss']
                gcpct_loss = output['gcpct_loss']
                total_loss += loss.item()
                loss.backward()
                if self.clip_grad > 0:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip_grad)
                self.optimizer.step()
                if self.scheduler is not None:
                    self.scheduler.step()
                if it % 50 == 0:
                    t.set_postfix(loss=loss.item(), compact_loss=compactness_loss.item(), sepa_loss=separateness_loss.item(), gcpct_loss=gcpct_loss.item())
                if it % 1500 == 0:
                    # Periodic debug dump of the sphere centres, a feature and the scores.
                    print("hyper:")
                    print(self.model.multi_hyperspheres.prototypes_centers.data)
                    print("feature[0]:")
                    print(output['feature'][0])
                    print("score:")
                    print(output['scores'])
                self.optimizer.zero_grad()
                it += 1
            self.print_fn("Train Epoch:{}\ntrain_loss:{}".format(epoch, total_loss / num_batchs))
            self.print_fn("\n")
            self.print_fn('multi-hypersphere occ validating Epoch:{}.....'.format(epoch))
            self.model.eval()
            y_true = []
            y_pred = []
            num_batchs = len(self.val_loader)
            t = tqdm(self.val_loader, desc='multi-hypersphere occ validating, epoch:{}'.format(epoch), total=num_batchs)
            for batch in t:
                max_len = batch['hist_activity'].shape[1]
                batch = self.collate_fn(batch, max_len)
                for k in batch:
                    batch[k] = batch[k].to(self.device)
                with torch.no_grad():
                    output = self.eval_step_occ(batch)
                y_true.append(output['action_labels'].detach().cpu().numpy())
                y_pred.append(output['action_predictions'].detach().cpu().numpy())
            y_true_all = np.hstack(y_true)
            y_pred_all = np.hstack(y_pred)
            auc = roc_auc_score(y_true_all, y_pred_all)
            fpr, tpr, threshold = bestthreshold_with_ROC(y_true_all, y_pred_all)
            self.print_fn("Eval Epoch:{}\nauc:{}\nfpr:{}\ntpr:{}\nthreshold:{}".format(epoch, auc, fpr, tpr, threshold))
            if self.args.use_nni:
                nni.report_intermediate_result(auc)
            if auc > self.maxauc:
                self.maxauc = auc
                state = {
                    'model': self.model.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'scheduler': self.scheduler.state_dict() if self.scheduler is not None else None,
                    'epoch': epoch,
                    # NOTE(review): stores self.it, not the local OCC step
                    # counter ``it`` — confirm which one consumers expect.
                    'it': self.it,
                    'auc': auc,
                    'fpr': fpr,
                    'tpr': tpr,
                    'threshold': threshold,
                    'y_true_all': y_true_all,
                    'y_pred_all': y_pred_all
                }
                es = 0
                torch.save(state, self.OC_model_save_path)
            else:
                es += 1
                logger.info("Counter {} of 5".format(es))
                if es > 3:
                    logger.info("Early stopping with best_auc: {}".format(self.maxauc))
                    break
        if self.args.use_nni:
            nni.report_final_result(auc)

    def _select_final_scores(self, logits, features):
        """Map model logits/features to per-position scores per ``args.final_score_type``.

        - 'cls': sigmoid(logits).
        - 'multi-hypersphere': tanh of the top-1 score from ``self.model.multi_hyperspheres``.
        - 'both': the lambda-weighted sum of sigmoid(logits), tanh(hypersphere
          score) and tanh(L2 norm of the features over dim 2).

        Raises:
            NotImplementedError: for any other ``final_score_type``.
        """
        hypersphere_scores = self.model.multi_hyperspheres.top1_score(features, use_detach=self.args.use_detach)
        magnitude_scores = torch.norm(features, p=2, dim=2)
        cls_scores = self.sigmoid(logits.squeeze(-1))
        score_type = self.args.final_score_type
        if score_type == 'cls':
            return cls_scores
        if score_type == 'multi-hypersphere':
            return torch.tanh(hypersphere_scores)
        if score_type == 'both':
            return (self.lambda_cls * cls_scores
                    + self.lambda_multi_hypersphere * torch.tanh(hypersphere_scores)
                    + self.lambda_magnitude * torch.tanh(magnitude_scores))
        raise NotImplementedError

    @staticmethod
    def _masked_mean(values, mask):
        """Mean of ``values`` over positions where ``mask`` is non-zero.

        Returns a zero scalar that stays in the graph when the mask is empty,
        instead of the NaN that ``0 / 0`` would produce.
        """
        weighted = values * mask
        denom = mask.sum()
        if denom > 0:
            return weighted.sum() / denom
        return weighted.sum() * 0

    def train_step_adaptive_self_training(self, ndata, adata, high_confidence_mask, high_pseudo_label, mid_confidence_mask, preds_mean, uncertain_e, padding_mask_n, padding_mask_a):
        """One debias-stage self-training step on a normal batch and an abnormal batch.

        Loss terms:
        - sup_loss: BCE toward 0 on every valid position of the normal batch,
          plus BCE toward ``high_pseudo_label`` on high-confidence abnormal positions.
        - mid_loss: for mid-confidence abnormal positions, a mix (via
          ``hard_loss_weight``) of
          - hard BCE on positions passing the adaptive threshold, with labels
            from ``thresholding_mask_generator.gen_pseudo_labels``;
          - soft BCE toward the EMA teacher's scores.
        - all_loss: pairs of valid abnormal positions whose feature cosine
          similarity exceeds ``tau_sim`` are pushed toward similar
          (1 - s, s) score distributions.
        - smooth / sparse regularisers on the abnormal scores.
        - prototype losses (compactness, separateness, negative feedback),
          only when ``self.use_prototypes`` is true.

        Robustness:
        - Empty masks yield a zero loss term instead of NaN or an unreduced
          tensor.
        - Prototype loss outputs default to zero when prototypes are disabled.
        - The EMA teacher is evaluated without gradients, since it only
          provides targets.

        Args:
            ndata, adata: collated normal / abnormal batches ('hist_activity',
                'max_len', 'acts_labels').
            high_confidence_mask, mid_confidence_mask: per-position masks for
                the abnormal batch (assumed 0/1 — TODO confirm).
            high_pseudo_label: targets for the high-confidence positions.
            preds_mean: passed to the thresholding mask generator.
            uncertain_e: unused; kept for interface compatibility.
            padding_mask_n, padding_mask_a: masks of valid (non-padded)
                positions, used for boolean indexing.

        Returns:
            dict of the individual (weighted) losses, 'total_loss', the
            features, the adaptive-threshold utilisation ratio, and the
            flattened 'action_predictions'/'action_labels' of both batches.
        """
        # ---- normal batch: every valid position is supervised toward 0
        model_output_n = self.model(ndata['hist_activity'], ndata['max_len'])
        features_n = model_output_n['feature']
        scores_n = self._select_final_scores(model_output_n['logits'], features_n)
        valid_n = padding_mask_n[:, :scores_n.shape[1]]
        sup_loss_n = self.cls_loss(scores_n, torch.zeros_like(scores_n), reduction='none', type='BCE', use_sigmoid=False)
        sup_loss = self._masked_mean(sup_loss_n, valid_n)

        # ---- abnormal batch: high-confidence pseudo labels
        model_output_a = self.model(adata['hist_activity'], adata['max_len'])
        features_a = model_output_a['feature']
        scores_a = self._select_final_scores(model_output_a['logits'], features_a)
        seq_a = scores_a.shape[1]
        valid_a = padding_mask_a[:, :seq_a]
        high_a = valid_a * high_confidence_mask[:, :seq_a]
        mid_a = valid_a * mid_confidence_mask[:, :seq_a]
        sup_loss_a = self.cls_loss(scores_a, high_pseudo_label, reduction='none', type='BCE', use_sigmoid=False)
        sup_loss = sup_loss + self._masked_mean(sup_loss_a, high_a)
        feat_dict = {'feats_n': features_n, 'feats_a': features_a}

        # ---- mid-confidence positions: hard labels where the adaptive threshold passes
        adaptive_thresh_mask = self.thresholding_mask_generator.masking(preds_mean, mid_confidence_mask)
        adaptive_thresh_mask = adaptive_thresh_mask * mid_confidence_mask
        hard_pseudo_label = self.thresholding_mask_generator.gen_pseudo_labels(preds_mean)
        mid_loss = self.cls_loss(scores_a, hard_pseudo_label, reduction='none', type='BCE', use_sigmoid=False)
        hard_mask = mid_a * adaptive_thresh_mask[:, :seq_a]
        hard_count = hard_mask.sum()
        if hard_count > 0:
            print("sum={},".format(hard_count.item()))
            mid_loss = (mid_loss * hard_mask).sum() / hard_count
        else:
            print("mid_loss divide 0 ")
            # Reduce to a zero scalar; leaving the per-position tensor here
            # made total_loss non-scalar and broke backward().
            mid_loss = mid_loss.sum() * 0

        # ---- soft targets from the EMA teacher (no gradient through the teacher)
        with torch.no_grad():
            ema_model_output_a = self.ema_model(adata['hist_activity'], adata['max_len'])
            ema_scores_a = self._select_final_scores(ema_model_output_a['logits'], ema_model_output_a['feature'])
        soft_loss = self.cls_loss(scores_a, ema_scores_a, reduction='none', type='BCE', use_sigmoid=False)
        soft_loss = self._masked_mean(soft_loss, mid_a)

        loss_smooth = smooth(scores_a, adata['max_len'], self.smooth_loss_weight)
        loss_sparse = sparsity(scores_a, adata['max_len'], self.sparse_loss_weight)

        # ---- consistency: similar features (upper-triangle pairs) -> similar score distributions
        valid_feat_a = padding_mask_a[:, :features_a.shape[1]]
        flatten_features_a = features_a[valid_feat_a]
        cos_feature_sim = F.cosine_similarity(flatten_features_a.unsqueeze(1), flatten_features_a.unsqueeze(0), dim=-1)
        tri_mask = torch.triu(torch.ones_like(cos_feature_sim), diagonal=1)
        sim_mask = (cos_feature_sim * tri_mask) > self.tau_sim
        scores_dist = torch.cat([(1 - scores_a).unsqueeze(-1), scores_a.unsqueeze(-1)], dim=-1)
        flatten_scores_a = scores_dist[padding_mask_a[:, :scores_dist.shape[1]]]
        cos_scores_sim = F.cosine_similarity(flatten_scores_a.unsqueeze(1), flatten_scores_a.unsqueeze(0), dim=-1) * tri_mask
        all_loss = self.cls_loss(torch.clamp(cos_scores_sim, max=1.0), torch.ones_like(cos_scores_sim), reduction='none', type='BCE', use_sigmoid=False)
        all_loss = self._masked_mean(all_loss, sim_mask.to(torch.int64))

        weighted_mid_loss = self.mid_loss_weight * (self.hard_loss_weight * mid_loss + (1 - self.hard_loss_weight) * soft_loss)
        total_loss = sup_loss + weighted_mid_loss + self.all_loss_weight * all_loss + loss_smooth + loss_sparse

        # Defaults so the output dict is well-defined when prototypes are disabled.
        zero = scores_a.new_zeros(())
        compactness_loss = separateness_loss = neg_feedback_loss = prototypes_loss = zero
        if self.use_prototypes:
            hyperspheres = self.model.multi_hyperspheres
            high_conf_a = high_confidence_mask[:, :features_a.shape[1]].to(torch.bool)
            high_feats_a = features_a[high_conf_a]
            high_labels_a = high_pseudo_label[high_conf_a].to(torch.bool)
            # Labelled normals: all valid normal-batch positions plus
            # high-confidence abnormal-batch positions pseudo-labelled 0.
            normal_features_lb = torch.cat([features_n[padding_mask_n[:, :features_n.shape[1]]], high_feats_a[~high_labels_a]], dim=0)
            adaptive_thresh_mask_padding = adaptive_thresh_mask[padding_mask_a[:, :adaptive_thresh_mask.shape[1]]].to(torch.bool)
            features_ulb = features_a[valid_feat_a][adaptive_thresh_mask_padding]
            hard_pseudo_label_padding = hard_pseudo_label[padding_mask_a[:, :hard_pseudo_label.shape[1]]]
            hard_pseudo_label_ulb = hard_pseudo_label_padding[adaptive_thresh_mask_padding].to(torch.bool)
            normal_features = torch.cat([normal_features_lb, features_ulb[~hard_pseudo_label_ulb]])
            abnormal_features = torch.cat([features_ulb[hard_pseudo_label_ulb], high_feats_a[high_labels_a]], dim=0)
            compactness_loss = hyperspheres.compactness_loss(normal_features, use_detach=self.args.use_detach)
            if self.args.num_spheres > 1:
                separateness_loss = hyperspheres.separateness_loss(normal_features, use_detach=self.args.use_detach)
            else:
                separateness_loss = torch.tensor(0.0).to(self.device)
            neg_feedback_loss = hyperspheres.neg_feedback_loss(abnormal_features, use_detach=self.args.use_detach)
            prototypes_loss = compactness_loss + self.separateness_loss_weight * separateness_loss + self.neg_feedback_loss_weight * neg_feedback_loss
            total_loss += self.prototypes_loss_weight * prototypes_loss

        output = {}
        output["total_loss"] = total_loss
        output["feat"] = feat_dict
        output["sup_loss"] = sup_loss
        output["mid_loss"] = weighted_mid_loss
        output["all_loss"] = self.all_loss_weight * all_loss
        output["prototypes_loss"] = self.prototypes_loss_weight * prototypes_loss
        output["smooth_loss"] = loss_smooth
        output["sparse_loss"] = loss_sparse
        output["hard_loss"] = self.hard_loss_weight * mid_loss
        output["soft_loss"] = (1 - self.hard_loss_weight) * soft_loss
        output["compactness_loss"] = compactness_loss
        output["separateness_loss"] = self.separateness_loss_weight * separateness_loss
        output["neg_feedback_loss"] = self.neg_feedback_loss_weight * neg_feedback_loss
        output["util_ratio"] = adaptive_thresh_mask.float().mean()
        output['action_predictions'] = torch.cat([scores_n[valid_n], scores_a[valid_a]], 0)
        output['action_labels'] = torch.cat([ndata['acts_labels'][padding_mask_n], adata['acts_labels'][padding_mask_a]], 0)
        return output

    def enable_dropout(self):
        """Switch every ``nn.Dropout`` layer of the student and EMA models to train mode.

        Other layers keep their current mode. Only modules whose exact type
        is ``nn.Dropout`` are affected.
        """
        for network in (self.model, self.ema_model):
            dropout_layers = [layer for layer in network.modules() if type(layer) is nn.Dropout]
            for layer in dropout_layers:
                layer.train()


    def eval_func(self, epoch, stage, mode='all'):
        """Validate on ``self.val_loader`` and checkpoint when AUC improves.

        Runs ``self.eval_step`` over every validation batch without gradients,
        then computes ROC-AUC and the best ROC threshold. If the AUC beats
        ``self.maxauc``, a checkpoint is saved for the stage. Otherwise the
        early-stopping counter ``self.es`` is incremented and the process
        exits once it exceeds 5.

        Args:
            epoch: current epoch (logged and stored in the checkpoint).
            stage: 'mil' or 'debias'. Selects both the iteration counter
                (``self.it`` / ``self.debias_it``) and the save path.
            mode: only 'all' is supported.

        Raises:
            NotImplementedError: for an unsupported ``stage`` or ``mode``.
        """
        # Use the counter of the stage being validated. train_mil resumes
        # ``self.it`` from checkpoint['it'], so MIL checkpoints must store
        # ``self.it`` rather than ``self.debias_it``.
        if stage == 'mil':
            cur_it = self.it
        elif stage == 'debias':
            cur_it = self.debias_it
        else:
            raise NotImplementedError
        self.print_fn('{} validating it:{}, epoch:{}.....'.format(stage, cur_it, epoch))
        self.model.eval()
        self.ema_model.eval()
        if mode != 'all':
            raise NotImplementedError
        y_true = []
        y_pred = []
        num_batchs = len(self.val_loader)
        t = tqdm(self.val_loader, desc='{} all validating, it:{}'.format(stage, cur_it),
                 total=num_batchs, position=0, leave=True)
        for batch in t:
            max_len = batch['hist_activity'].shape[1]
            batch = self.collate_fn(batch, max_len)
            for k in batch:
                batch[k] = batch[k].to(self.device)
            with torch.no_grad():
                output = self.eval_step(batch)
            y_true.append(output['action_labels'].detach().cpu().numpy())
            y_pred.append(output['action_predictions'].detach().cpu().numpy())
        y_true_all = np.hstack(y_true)
        y_pred_all = np.hstack(y_pred)
        auc = roc_auc_score(y_true_all, y_pred_all)
        fpr, tpr, threshold = bestthreshold_with_ROC(y_true_all, y_pred_all)
        self.print_fn("{} Eval it:{}, epoch:{}\nauc:{}\nfpr:{}\ntpr:{}\nthreshold:{}".format(mode, cur_it, epoch, auc, fpr, tpr, threshold))
        if auc > self.maxauc:
            self.maxauc = auc
            self.binary_thresh = threshold
            state = {
                'model': self.model.state_dict(),
                'ema_model': self.ema_model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'scheduler': self.scheduler.state_dict() if self.scheduler is not None else None,
                'epoch': epoch,
                'it': cur_it,
                'auc': auc,
                'fpr': fpr,
                'tpr': tpr,
                'threshold': threshold,
                'y_true_all': y_true_all,
                'y_pred_all': y_pred_all
            }
            self.es = 0
            save_path = self.MIL_model_save_path if stage == 'mil' else self.final_model_save_path
            torch.save(state, save_path)
        else:
            self.es += 1
            logger.info("Counter {} of 5".format(self.es))
            if self.es > 5:
                logger.info("Early stopping with best_auc: {}".format(self.maxauc))
                sys.exit()

    def update_ema_model(self):
        """Blend the student weights into the EMA teacher, in place.

        For each parameter pair (matched by ``named_parameters()`` order):
        ``ema = alpha * ema + (1 - alpha) * student``, with
        ``alpha = self.args.ema_alpha``.
        """
        decay = self.args.ema_alpha
        teacher_params = (p for _, p in self.ema_model.named_parameters())
        student_params = (p for _, p in self.model.named_parameters())
        for teacher, student in zip(teacher_params, student_params):
            teacher.data.mul_(decay).add_(student.data, alpha=1 - decay)

    def collate_fn(self, batch, maxlen=None, has_al=True, flatten=False):
        """Right-pad a batch's sequences with zeros to a common length.

        Args:
            batch: dict of tensors containing 'hist_activity' of shape
                (bs, seq_len) (after optional flattening) and 'max_len';
                'acts_labels' with the same shape is padded too when
                ``has_al`` is true. The dict is modified in place and returned.
            maxlen: target length. Defaults to the largest value in
                ``batch['max_len']``.
            has_al: whether 'acts_labels' is present and must be padded.
            flatten: if true, first merge leading dims of every tensor
                (2-D tensors become 1-D; others become (-1, last_dim)).

        Returns:
            The same dict. Padded entries are int64 tensors on ``self.device``.
        """
        if flatten:
            for key in batch:
                value = batch[key]
                batch[key] = value.reshape(-1) if value.dim() == 2 else value.reshape(-1, value.shape[-1])
        if maxlen is None:
            maxlen = batch['max_len'].max().item()
        bs, seq_len = batch['hist_activity'].shape
        keys = ('hist_activity', 'acts_labels') if has_al else ('hist_activity',)
        for key in keys:
            padded = torch.zeros((bs, maxlen), dtype=torch.long, device=self.device)
            padded[:, :seq_len] = batch[key]
            batch[key] = padded
        return batch

    def get_save_dict(self):
        """Return the dict to checkpoint. Not implemented by this trainer."""
        raise NotImplementedError()

    def load_model(self, load_path):
        """Restore state from ``load_path``. Not implemented by this trainer."""
        raise NotImplementedError()

    def set_optimizer(self):
        """
        set optimizer for algorithm

        Builds per-parameter groups and wraps them in the optimizer named by
        ``args.optim``:
        - groups come from ``param_groups_layer_decay`` when
          ``args.layer_decay != 1.0``, otherwise from
          ``param_groups_weight_decay``;
        - 'SGD' uses Nesterov momentum; 'AdamW' is the other supported choice.

        Returns:
            the constructed ``torch.optim.Optimizer``.

        Raises:
            NotImplementedError: if ``args.optim`` is neither 'SGD' nor 'AdamW'
                (previously this surfaced as an UnboundLocalError).
        """
        self.print_fn("Create optimizer")
        assert self.args.layer_decay <= 1.0
        if self.args.layer_decay != 1.0:
            per_param_args = param_groups_layer_decay(self.model, self.args.lr, self.args.weight_decay,
                                                      layer_decay=self.args.layer_decay)
        else:
            per_param_args = param_groups_weight_decay(self.model, self.args.weight_decay)
        if self.args.optim == 'SGD':
            return torch.optim.SGD(per_param_args, lr=self.args.lr, momentum=self.args.momentum,
                                   weight_decay=self.args.weight_decay, nesterov=True)
        if self.args.optim == 'AdamW':
            return torch.optim.AdamW(per_param_args, lr=self.args.lr, weight_decay=self.args.weight_decay)
        raise NotImplementedError('Unsupported optimizer: {}'.format(self.args.optim))

    def set_debais_scheduler(self):
        """Install the debias-stage LR scheduler on ``self.scheduler``.

        With ``args.use_scheduler`` truthy, this is a cosine schedule with
        warm-up over ``self.num_debias_train_iter`` steps, of which
        ``self.num_debais_warmup_iter`` are warm-up. Otherwise the scheduler
        is set to ``None``.
        """
        self.print_fn("Create debais scheduler")
        scheduler = None
        if self.args.use_scheduler:
            scheduler = get_cosine_schedule_with_warmup(
                self.optimizer,
                self.num_debias_train_iter,
                num_warmup_steps=self.num_debais_warmup_iter,
            )
        self.scheduler = scheduler

def ArgumentParser(argv=None):
    """Build this script's command-line parser and parse the arguments.

    ``--use_prototypes`` and ``--use_scheduler`` now parse their value as a
    boolean ('true'/'false', 'yes'/'no', '1'/'0', ...). Previously any value,
    including the string 'False', was stored as a truthy string. Both still
    default to ``True``.

    Args:
        argv: list of argument strings to parse. ``None`` (the default)
            parses ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace holding every option defined below.
    """

    def _str2bool(value):
        """Convert a textual boolean flag value to ``bool``."""
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))

    parser = argparse.ArgumentParser()
    parser.add_argument('--srcroot', default='data/r4.2', type=str, help='data/r4.2,data/r5.2')
    parser.add_argument('--root', default='data', type=str)
    parser.add_argument('--seed', type=int, default=42, help='Random seed.')
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--num_workers', type=int, default=8)
    parser.add_argument('--lr', type=float, default=5e-6)
    parser.add_argument('--momentum', type=float, default=0.9)
    parser.add_argument('--weight_decay', type=float, default=0.0005)
    parser.add_argument('--layer_decay', type=float, default=1)
    parser.add_argument('--encoder_type', default='gru', type=str, help='lstm/gru/transformer')
    parser.add_argument('--seq_embedding_size', type=int, default=128)
    parser.add_argument('--hidden_size', type=int, default=128)
    parser.add_argument('--num_layers', type=int, default=2)
    parser.add_argument('--model_dropout', type=float, default=0.3, help='model dropout rate')
    parser.add_argument('--use_layernorm', default=False, action='store_true')
    parser.add_argument('--reduction', default='lastpositionattention', type=str,
                        help='None/avgpooling/selfattention+avgpooling/lastpositionattention/clsattention/attentionnetpooling')
    parser.add_argument('--device', type=str, default='cuda:0', help='cuda:0/cuda:1/cpu')
    parser.add_argument('--UnderSampling', type=float, default=-1, help='UnderSampling rate')
    parser.add_argument('--OverSampling', type=float, default=1, help='OverSampling rate')
    parser.add_argument('--stage', default='mil', type=str, help='occ/mil/debias')
    parser.add_argument("--use_prototypes", default=True, type=_str2bool)
    parser.add_argument("--use_hypersphere_warm_up", default=False, action='store_true')
    parser.add_argument("--init_prototypes_with_clustering", default=False, action='store_true')
    parser.add_argument("--similarity_type", type=str, default='l2', help='l2/inner_product/cosine')
    parser.add_argument("--num_spheres", type=int, default=50)
    parser.add_argument('--occ_epochs', type=int, default=2)
    parser.add_argument('--mil_epochs', type=int, default=5)
    parser.add_argument('--debias_epochs', default=1, type=int)
    parser.add_argument('--select_score_type', type=str, default='both', help='cls/multi-hypersphere/both')
    parser.add_argument('--final_score_type', type=str, default='both',  help='cls/multi-hypersphere/both')
    parser.add_argument('--lambda_cls', type=float, default=0.9)
    parser.add_argument('--lambda_magnitude', type=float, default=0)
    # NOTE: the __main__ block overrides this with 1 - lambda_cls.
    parser.add_argument('--lambda_multi_hypersphere', type=float, default=0.1)
    parser.add_argument('--miltopk', type=int, default=1)
    parser.add_argument('--classification_level', default='action', type=str, help='session/action')
    parser.add_argument('--prediction_mode', default='post_hoc', type=str, help='post_hoc/real_time')
    parser.add_argument('--mil_dropout', type=float, default=0, help='mil dropout rate')
    parser.add_argument("--smooth_loss_weight", type=float, default=0.01)
    parser.add_argument("--sparse_loss_weight", type=float, default=0.1)
    parser.add_argument("--const_loss_weight", type=float, default=0)
    parser.add_argument("--separateness_loss_weight", type=float, default=0.5)
    parser.add_argument("--neg_feedback_loss_weight", type=float, default=0)
    parser.add_argument("--prototypes_loss_weight", type=float, default=1.0)
    parser.add_argument("--gcpct_loss_weight", type=float, default=0)
    parser.add_argument("--hard_loss_weight", type=float, default=1.0)
    parser.add_argument("--all_loss_weight", type=float, default=1.0)
    parser.add_argument('--forward_passes', default=5, type=int)
    parser.add_argument('--ema_alpha', type=float, default=0.999)
    parser.add_argument("--T", type=float, default=0.5)
    parser.add_argument('--ema_p', type=float, default=0.999)
    parser.add_argument('--optim', default='AdamW', type=str, help='AdamW/SGD')
    parser.add_argument('--num_classes', default=1, type=int)
    parser.add_argument('--use_quantile', default=False, action='store_true')
    parser.add_argument('--clip_thresh', default=False, action='store_true')
    parser.add_argument('--mid_loss_weight', type=float, default=1.0)
    parser.add_argument('--clip_grad', type=float, default=0)
    parser.add_argument('--high_confidence_threshold_rate', type=float, default=0.1)
    parser.add_argument('--mid_confidence_threshold_rate', type=float, default=0.1)
    parser.add_argument('--binary_thresh', type=float, default=0.5)
    parser.add_argument('--tau_c', type=float, default=0.5)
    parser.add_argument('--tau_sim', type=float, default=0.9)
    parser.add_argument("--plot_cluster", default=False, action='store_true')
    parser.add_argument("--check_multi_hyperspheres", default=False, action='store_true')
    parser.add_argument("--continuetrain_MIL", default=False, action='store_true')
    parser.add_argument("--continuetrain_debias", default=False, action='store_true')
    parser.add_argument("--use_nni", default=False, action='store_true')
    parser.add_argument("--use_detach", default=False, action='store_true')
    parser.add_argument("--use_scheduler", default=True, type=_str2bool)
    parser.add_argument('--rate_warm_up', type=float, default=1.0)
    parser.add_argument("--debug", default=False, action='store_true')
    parser.add_argument("--use_att_cons", default=False, action='store_true')

    return parser.parse_args(argv)


def init_center_c(train_loader, net, device, eps=0.01):
    """Initialize hypersphere center c as the mean feature over an initial forward pass.

    Runs ``net`` over every batch of ``train_loader`` in eval mode without
    gradients. It gathers the feature vectors of all non-padded positions and
    returns their mean: the total sum divided by the total number of valid
    positions, not an average of per-batch means. Coordinates whose magnitude
    is below ``eps`` are pushed to ``-eps``/``+eps``, keeping their sign.
    Exact zeros are left unchanged.

    Args:
        train_loader: iterable of raw batches. Each batch goes through
            ``collate_fn(batch, has_al=True)``, which must return a mapping of
            tensors containing at least ``'hist_activity'`` (2-D) and
            ``'max_len'`` (one length per row).
        net: model called as ``net(hist_activity, max_len)``. It must return a
            dict with a ``'feature'`` tensor and expose ``d_model``.
        device: device the batches and the accumulator live on.
        eps: minimum absolute value for non-zero center coordinates.

    Returns:
        A float32 tensor of shape ``(net.d_model,)`` on ``device``.

    Raises:
        ValueError: if the loader yields no valid (non-padded) positions.

    Note:
        ``net`` is left in eval mode.
    """
    feature_sum = torch.zeros(net.d_model, device=device)
    n_positions = 0
    net.eval()
    with torch.no_grad():
        for batch in tqdm(train_loader, desc='Initialize hypersphere center c'):
            batch = collate_fn(batch, has_al=True)
            batch = {k: v.to(device) for k, v in batch.items()}
            # Use the `net` argument, not a global model.
            model_output = net(batch['hist_activity'], batch['max_len'])
            _, max_len = batch['hist_activity'].shape
            # (bs, max_len) mask: True where the position index is below that row's length.
            padding_mask = (torch.arange(max_len, device=device).unsqueeze(0)
                            < batch['max_len'].unsqueeze(1))
            out = model_output['feature'][padding_mask]
            # Accumulate the sum and the count so every valid position has equal weight.
            # This also means a batch with no valid positions cannot inject NaN.
            feature_sum += out.sum(dim=0)
            n_positions += out.shape[0]
    if n_positions == 0:
        raise ValueError('init_center_c: train_loader produced no valid (non-padded) positions')
    c = feature_sum / n_positions
    # Keep small coordinates away from zero while preserving their sign.
    c[(c.abs() < eps) & (c < 0)] = -eps
    c[(c.abs() < eps) & (c > 0)] = eps
    return c


if __name__ == '__main__':
    args = ArgumentParser()
    seed_everything(args.seed)
    device = torch.device(args.device)
    # The multi-hypersphere loss weight is the complement of the classification loss weight.
    args.lambda_multi_hypersphere = 1 - args.lambda_cls
    os.makedirs(os.path.join(args.root, 'output'), exist_ok=True)
    logger = get_logger(os.path.join(args.root, 'output/temp.log'))
    logger.info(args)

    print("start load data")
    num_features, cat_features, seq_features, encoders, df = split_train_val_with_date(
        args.srcroot, args.root, sp='2011-01-01 00:00:00')
    print("load data over")
    dftrain, dfval = df
    if args.stage == 'mil' and not (args.UnderSampling == 1 or args.OverSampling == 1):
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        raise ValueError("stage 'mil' requires --UnderSampling 1 or --OverSampling 1")

    # ---- optional re-sampling of the training split (-1 disables a sampler) ----
    Y = dftrain['session_label']
    X = dftrain.drop('session_label', axis=1)
    logger.info("positive : negtive = %d : %d" % (Y[Y == 1].count(), Y[Y == 0].count()))
    if args.UnderSampling != -1:
        under = RandomUnderSampler(sampling_strategy=args.UnderSampling)
        X, Y = under.fit_resample(X, Y)
        logger.info(
            "After UnderSampling, positive : negtive = %d : %d" % (Y[Y == 1].count(), Y[Y == 0].count()))
    if args.OverSampling != -1:
        over = RandomOverSampler(sampling_strategy=args.OverSampling)
        X, Y = over.fit_resample(X, Y)
        logger.info("After OverSampling, positive : negtive = %d : %d" % (Y[Y == 1].count(), Y[Y == 0].count()))
    if args.UnderSampling != -1 or args.OverSampling != -1:
        dftrain = X
        dftrain['session_label'] = Y
    # Keep only training sessions with more than `miltopk` activities.
    dftrain = dftrain[dftrain['hist_activity'].apply(len) > args.miltopk]

    # ---- datasets and loaders ----
    dftrain_a = dftrain[dftrain['session_label'] == 1]
    dftrain_n = dftrain[dftrain['session_label'] == 0]
    dfval_a = dfval[dfval['session_label'] == 1]
    label_col = ['acts_labels', 'session_label']
    cat_nums = {k: v.dimension() for k, v in encoders.items()}
    ds_train_a = Df2Dataset(dftrain_a, num_features, cat_features, seq_features, encoders, label_col=label_col)
    ds_train_n = Df2Dataset(dftrain_n, num_features, cat_features, seq_features, encoders, label_col=label_col)
    ds_train = Df2Dataset(dftrain, num_features, cat_features, seq_features, encoders, label_col=label_col)
    ds_val = Df2Dataset(dfval, num_features, cat_features, seq_features, encoders, label_col=label_col)
    ds_val_a = Df2Dataset(dfval_a, num_features, cat_features, seq_features, encoders, label_col=label_col)
    # Training loaders use half the batch size. NOTE(review): presumably so the trainer can
    # pair a label-1 half-batch with a label-0 half-batch; confirm in MILtrainer.
    args.bsdiv2 = args.batch_size // 2
    dl_train_n = DataLoader(ds_train_n, batch_size=args.bsdiv2, num_workers=args.num_workers, shuffle=True, drop_last=True)
    dl_train_a = DataLoader(ds_train_a, batch_size=args.bsdiv2, num_workers=args.num_workers, shuffle=True, drop_last=True)
    dl_train_random = DataLoader(ds_train, batch_size=args.bsdiv2, num_workers=args.num_workers, shuffle=True, drop_last=True)
    dl_val = DataLoader(ds_val, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=False)
    dl_val_a = DataLoader(ds_val_a, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=False)

    # ---- model and EMA copy ----
    # NOTE(review): the (-3, *24, +25) arithmetic must match how the data pipeline encodes
    # hist_activity tokens; confirm there before changing it.
    num_embeddings = (cat_nums['hist_activity'] - 3) * 24 + 25
    if args.prediction_mode == 'post_hoc':
        num_embeddings += 1  # one extra embedding id in post_hoc mode
    # Build both networks from the same kwargs so the EMA copy has exactly the same
    # architecture as the trained model. Previously ema_model silently dropped use_att_cons.
    model_kwargs = dict(
        seq_embedding_size=args.seq_embedding_size,
        hidden_size=args.hidden_size,
        dropout=args.model_dropout,
        reduction=args.reduction if args.reduction != 'None' else None,
        LayerNorm=args.use_layernorm,
        encoder_type=args.encoder_type,
        num_layers=args.num_layers,
        num_class=1,
        prediction_mode=args.prediction_mode,
        num_spheres=args.num_spheres,
        use_prototypes=args.use_prototypes,
        similarity_type=args.similarity_type,
        use_att_cons=args.use_att_cons,
    )
    model = SeqwithClassifier(num_embeddings, **model_kwargs)
    model = model.to(device)
    ema_model = SeqwithClassifier(num_embeddings, **model_kwargs)
    ema_model = ema_model.to(device)

    # ---- optional prototype initialization ----
    if args.init_prototypes_with_clustering:
        if not args.use_prototypes:
            raise ValueError('--init_prototypes_with_clustering requires --use_prototypes')
        data_tag = args.srcroot[-4:]
        ln_tag = 'ln' if args.use_layernorm else 'noln'
        normal_vector_save_path = os.path.join(args.root, '{}_{}_{}_normal_vector_save_path.npy'.format(data_tag, args.encoder_type, ln_tag))
        cluster_center_save_path = os.path.join(args.root, '{}_{}_{}_{}clusters_cluster_center_save_path.npy'.format(data_tag, args.encoder_type, ln_tag, args.num_spheres))
        if args.num_spheres == 1:
            lc = init_center_c(dl_train_n, model, device)
            cluster_centers = lc.detach().cpu().numpy()
        elif os.path.exists(cluster_center_save_path):
            cluster_centers = np.load(cluster_center_save_path)
            if args.plot_cluster:
                xb = np.load(normal_vector_save_path)
        else:
            model.eval()
            logger.info('generating vector for clustering...')
            vector_list = []
            # Inference only: no autograd graph is needed for the collected features.
            with torch.no_grad():
                for batch in tqdm(dl_train_n, desc='generating vector for clustering'):
                    batch = collate_fn(batch, has_al=True)
                    batch = {k: v.to(device) for k, v in batch.items()}
                    model_output = model(batch['hist_activity'], batch['max_len'])
                    _, max_len = batch['hist_activity'].shape
                    # (bs, max_len) mask of non-padded positions.
                    padding_mask = (torch.arange(max_len, device=device).unsqueeze(0)
                                    < batch['max_len'].unsqueeze(1))
                    vector_list.append(model_output['feature'][padding_mask].cpu().numpy())
            xb = np.vstack(vector_list)
            logger.info('num of vectors: {}'.format(xb.shape[0]))
            logger.info('saving vectors for clustering to {}'.format(normal_vector_save_path))
            np.save(normal_vector_save_path, xb)
            logger.info('Clustering into {} classes'.format(args.num_spheres))
            # Sample at most 500k vectors. choice(..., replace=False) raises if asked for
            # more samples than there are vectors.
            n_cluster_samples = min(500000, xb.shape[0])
            samples_indices = np.random.choice(xb.shape[0], n_cluster_samples, replace=False)
            xb = xb[samples_indices, :]
            logger.info('after sampling, num of vectors: {}'.format(xb.shape[0]))
            km = KMeans(n_clusters=args.num_spheres, n_init=2, max_iter=100, random_state=args.seed).fit(xb)
            cluster_centers = km.cluster_centers_
            logger.info('saving cluster center to {}'.format(cluster_center_save_path))
            np.save(cluster_center_save_path, cluster_centers)
        # Load the centers into the model on every branch. Previously the single-sphere
        # center was computed and then discarded. copy_ broadcasts the source to the
        # destination shape, e.g. (d,) -> (1, d).
        model.multi_hyperspheres.prototypes_centers.data.copy_(torch.from_numpy(cluster_centers).float().to(device))

    # ---- training ----
    miltrainer = MILtrainer(args, model, dl_train_a, dl_train_n, dl_train_random, dl_val, dl_val_a, device, logger, ema_model=ema_model)
    if args.stage == 'occ':
        miltrainer.train_occ()
    elif args.stage == 'mil':
        checkpoint = torch.load(os.path.join(args.root, '{}_{}_None_{}_noln_level_occ_checkpoint.pth'.format(
            args.srcroot[-4:], args.encoder_type,
            args.classification_level)), map_location=device)
        print("occ auc: {}".format(checkpoint['auc']))
        model.load_state_dict(checkpoint['model'], strict=False)
        # Start the EMA copy from the loaded weights.
        ema_model.load_state_dict(model.state_dict(), strict=False)
        model = model.to(device)
        ema_model = ema_model.to(device)
        miltrainer.model = model
        miltrainer.ema_model = ema_model
        miltrainer.train_mil()



