#modify coherence and lossFunc MCNet2
#modify attRes -> gateIncrease procedure, introduce optimal transport here MCNet3
#modify attRes gateIncrease ratios during training process MCNet4
#add layerNorm MCNet5
#constrain transport matrix P MCNet6
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import time
import os
import numpy as np
import random
import math
from omni_anomaly.eval_methods import pot_eval,searchThreshold
import matplotlib.pyplot as plt
from scipy.stats import norm
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
class myDataset(Dataset):
    """Sliding-window dataset over a time series.

    The ``type`` argument selects what each item contains:

    * ``"train"`` / ``"validate"``: a window of ``datas``.
    * ``"break"``: ``(window, label window)``.
    * ``"collaborate"``: ``(window, randomly paired window, LLM window,
      paired LLM window, interDifference, intraDifference)``.
    * ``"test"``: ``(window, label window, LLM window, mask window)``.

    Note that ``split`` and ``groupSplitTest`` iterate ``len(datas)-winlen``
    start positions (the last full window is omitted), whereas ``splitTest``
    and ``groupSplit`` iterate ``len(datas)-winlen+1`` positions.

    Raises:
        ValueError: if ``type`` is not one of the modes listed above
            (previously the instance was left without ``self.x``).
    """

    def __init__(self,datas,winlen,labels=None,type="train",LLMResponses=None,testMask=None):
        super(myDataset,self).__init__()
        if type=="train" or type=="validate":
            self.x=self.split(datas,winlen)
        elif type=="break":
            self.x,self.y=self.splitTest(datas,winlen,labels)
        elif type=="collaborate":
            self.x,self.llm,self.interDifferences,self.intraDifferences = self.groupSplit(datas, winlen,LLMResponses)
        elif type=="test":
            self.x,self.y,self.llm,self.mask=self.groupSplitTest(datas,winlen,labels,LLMResponses,testMask)
        else:
            raise ValueError("unknown dataset type: %r" % (type,))

        self.type=type

    def split(self,datas,winlen):
        """Return the windows ``datas[i:i+winlen]`` for ``i`` in ``range(len(datas)-winlen)``."""
        return [datas[i:i+winlen] for i in range(len(datas)-winlen)]

    def groupSplit(self,datas,winlen,LLMResponses):
        """Build windows, LLM windows and two per-window difference statistics.

        For every window:

        * ``interDifference``: max over features of the mean absolute
          difference between the window rows and the window's last row.
        * ``intraDifference``: max over features of the mean absolute
          difference between this window and every window of the series.

        ``datas`` must be a torch tensor (the windows are stacked with
        ``torch.stack``). Absolute values are taken with ``torch.abs`` so the
        data never round-trips through NumPy.
        """
        starts=range(len(datas)-winlen+1)
        xs=[datas[i:i+winlen] for i in starts]
        llms=[LLMResponses[i:i+winlen] for i in starts]
        interDifferences=[]
        intraDifferences=[]
        if not xs:
            # Series shorter than one window: nothing to stack or compare.
            return xs,llms,interDifferences,intraDifferences
        xPatches=torch.stack(xs,dim=0)#windowNum,winlen,fealen
        for i,window in enumerate(xs):
            lastRow=datas[i+winlen-1]
            interDifference=torch.max(torch.mean(torch.abs(window-lastRow),dim=0))
            # Mean over all windows first, then over the window positions.
            intraDifference=torch.max(torch.mean(torch.mean(torch.abs(xPatches-window),dim=0),dim=0))
            interDifferences.append(interDifference)
            intraDifferences.append(intraDifference)
        return xs,llms,interDifferences,intraDifferences

    def splitTest(self,datas,winlen,labels):
        """Return aligned data and label windows for ``range(len(datas)-winlen+1)``."""
        starts=range(len(datas)-winlen+1)
        xs=[datas[i:i+winlen] for i in starts]
        ys=[labels[i:i+winlen] for i in starts]
        return xs,ys

    def groupSplitTest(self,datas,winlen,labels,LLMResponses,testMask):
        """Return aligned data, label, LLM and mask windows for ``range(len(datas)-winlen)``.

        The mask windows are boolean: ``True`` where ``testMask`` equals 0.
        """
        testMask=testMask==0.
        starts=range(len(datas)-winlen)
        xs=[datas[i:i+winlen] for i in starts]
        ys=[labels[i:i+winlen] for i in starts]
        llms=[LLMResponses[i:i+winlen] for i in starts]
        masks=[testMask[i:i+winlen] for i in starts]
        return xs,ys,llms,masks

    def __getitem__(self, item):
        """Return the sample at ``item`` (taken modulo the dataset length)."""
        item = item % self.__len__()
        if self.type=="train" or self.type=="validate":
            return self.x[item]
        elif self.type=="break":
            return (self.x[item],self.y[item])
        elif self.type=="collaborate":
            # Pair the window with a random partner; redraw at most 8 times
            # when the partner happens to be the window itself.
            couple=random.randint(0,len(self.x)-1)
            count=0
            while couple==item and count<8:
                couple = random.randint(0, len(self.x) - 1)
                count+=1
            return (self.x[item],self.x[couple],self.llm[item],self.llm[couple],self.interDifferences[item],self.intraDifferences[item])
        elif self.type=="test":
            return (self.x[item],self.y[item],self.llm[item],self.mask[item])

    def __len__(self):
        """Number of windows held by the dataset."""
        return len(self.x)



class DecompositionBlock(nn.Module):
    """Split a series into a moving-average component and the residual.

    The moving average is an ``AvgPool1d`` of kernel ``pars.klen`` and stride 1
    applied along time; the series is left-padded with a copy of its own first
    ``klen-1`` steps so the output keeps the input length.
    """

    def __init__(self, pars):
        super().__init__()
        self.klen=pars.klen
        self.conv=nn.AvgPool1d(pars.klen,1)

    def forward(self,x):#x:batchsize,winlen,fealen
        """Return ``(residual, moving average)``, both shaped like ``x``."""
        channelsFirst=x.permute(0,2,1)#batchsize,fealen,winlen
        head=channelsFirst[:,:,:self.klen-1]
        smoothed=self.conv(torch.cat([head,channelsFirst],dim=-1))
        residual=channelsFirst-smoothed
        return residual.permute(0,2,1),smoothed.permute(0,2,1)#batchsize,winlen,fealen


class GaussianCurve(nn.Module):
    """Learnable-width Gaussian bumps over the column positions ``0..cols-1``.

    ``center`` is either a list with one centre per row or a single number
    used for every row. ``sigmas`` (initialised to ones, shape ``(rows, 1)``)
    is the only trainable parameter.
    """

    def __init__(self,rows,cols,center,pars):
        super().__init__()
        self.sigmas=nn.Parameter(torch.ones(rows,1))
        self.rows,self.cols=rows,cols
        self.center=center
        self.device=pars.device

    def forward(self):
        """Return a ``(rows, cols)`` matrix whose rows each sum to one."""
        positions=torch.arange(0,self.cols).repeat(self.rows,1).to(self.device)
        if isinstance(self.center,list):
            perRow=torch.tensor(self.center,dtype=torch.float)
            centers=perRow.repeat(self.cols,1).permute(1,0)
        else:
            centers=torch.ones((self.rows,self.cols))*self.center
        centers=centers.to(self.device)
        curve=torch.pow(positions-centers,2).permute(1,0)#cols,rows
        # In-place division, as before: sigmas (rows,1) must broadcast to
        # exactly the (cols,rows) shape of ``curve``.
        curve.div_(-2*torch.pow(self.sigmas,2))
        curve=torch.exp(curve)/self.sigmas#cols,rows
        total=curve.sum(dim=0)
        curve.div_(total)
        return curve.permute(1,0)


class gate(nn.Module):
    """Per-feature patch attention that rebuilds each time step from the others.

    For every feature the series is cut into overlapping patches of
    ``patchSize+1`` steps. The first ``patchSize`` steps of a patch act as
    attention query/key, the last step is the value. The attention weights
    are damped by ``1 - attCurve()``, re-normalised with a softmax and used
    to mix the (affinely rescaled) value steps.
    """
    def __init__(self,pars):
        """Build one single-head attention per feature plus the rescaling parameters.

        Reads ``pars.patchSize``, ``pars.fealen``, ``pars.winlen``,
        ``pars.device``, ``pars.omit`` and ``pars.exeTime``.
        """
        super(gate,self).__init__()
        # One attention module per feature; embedding/key dim = patchSize, value dim = 1.
        self.atts=nn.ModuleList([nn.MultiheadAttention(pars.patchSize, 1, batch_first=True,kdim=pars.patchSize,vdim=1) for i in range(pars.fealen)])
        # NOTE(review): activ, activ2 and Wx are not used by any method of this class.
        self.activ=nn.Sigmoid()
        self.patchSize=pars.patchSize
        # winlen x winlen Gaussian curve; row i is centred on column i.
        self.attCurve=GaussianCurve(pars.winlen,pars.winlen,[i for i in range(pars.winlen)],pars)
        self.softmax=nn.Softmax(dim=-1)
        self.device=pars.device
        self.activ2=nn.LeakyReLU()
        self.Wx=nn.Linear(pars.fealen,pars.fealen)
        # Per-feature affine rescaling applied to the value steps (both start at one).
        self.scaler=nn.Parameter(torch.ones(pars.fealen))
        self.bias=nn.Parameter(torch.ones(pars.fealen))
        # Per-feature costs used by getK/getC; the first entry is dropped when pars.omit is set.
        if pars.omit:
            self.cost=pars.exeTime[1:]
        else:
            self.cost=pars.exeTime

    def getK(self):
        """Return the kernel ``exp(-(c_i-c_j)^2 / 0.03^2)`` over all pairs of ``self.cost``.

        Not called by ``forward``.
        """
        intervals=torch.tensor(self.cost)
        X,Y=torch.meshgrid(intervals,intervals)
        epsilon=0.03**2
        K = torch.exp(-torch.pow(X - Y, 2) / epsilon)
        K=K.to(self.device)
        return K

    def getC(self):
        """Return the signed pairwise differences ``c_i-c_j`` of ``self.cost``.

        Not called by ``forward``.
        """
        intervals = torch.tensor(self.cost)
        X, Y = torch.meshgrid(intervals, intervals)
        C=X-Y
        C=C.to(self.device)
        return C

    def forward(self,x):#x:batchsize,winlen,fealen
        """Return ``(attRes, attWeightSave, 10)``.

        ``attRes`` is (batch, winlen, fealen); ``attWeightSave`` holds the raw
        attention weights reshaped to (batch, fealen, winlen, winlen); the
        third element is the constant ``10``.
        """
        x=x.permute(0,2,1)#batch,fealen,winlen
        batchSize,fealen,winlen=x.shape
        # Left-pad with a copy of the first patchSize steps so unfolding yields winlen patches.
        xPadding=torch.cat([x[:,:,:self.patchSize],x],dim=-1)
        xExtend=xPadding.unfold(2,self.patchSize+1,1)#batch,fealen,blockNum,patchsize+1
        _,_,blockNum,_=xExtend.shape
        attWeight=[]
        for i,att in enumerate(self.atts):
            # query = key = first patchSize steps of each patch, value = its last step;
            # only the attention weights are kept, the attention output is discarded.
            _,attWei=att(xExtend[:,i,:,:-1],xExtend[:,i,:,:-1],xExtend[:,i,:,-1:])
            attWeight.append(attWei)
        attWeight=torch.stack(attWeight,dim=0)#fealen,batch,blockNum,blockNum
        attWeight=attWeight.permute(1,0,2,3).reshape(batchSize*fealen,blockNum,blockNum)
        # Copy of the raw weights returned to the caller (this reshape requires blockNum == winlen).
        attWeightSave=attWeight.clone()
        attWeightSave=attWeightSave.reshape(batchSize,fealen,winlen,winlen)
        # Damp the weights with (1 - Gaussian curve), then re-normalise over the source axis.
        attWeight=attWeight*(1-self.attCurve()) #batch*fealen,winlen (tarlen),winlen (sourLen)
        attWeight=self.softmax(attWeight).permute(1,0,2)#tarLen, batch*fealen, sourLen
        value=xExtend[:,:,:,-1:].permute(0,2,3,1)#batch,blocknum,1,fealen
        # Feature axis is last here so scaler/bias broadcast per feature.
        value=(self.scaler*value+self.bias).permute(0,3,1,2)
        value=value.reshape(batchSize*fealen,blockNum)
        # Weighted sum of the value steps over the source axis.
        attRes=(attWeight*value).sum(dim=-1) #tarLen, batch*fealen
        attRes=attRes.permute(1,0)#batch*fealen,tarLen
        attRes=attRes.reshape(batchSize,fealen,blockNum)#batch,fealen,winlen
        attRes=attRes.permute(0,2,1)#batch,winlen,fealen
        return attRes,attWeightSave,10 #batch,winlen,fealen


class MCNetModule(nn.Module):
    """One MCNet stage: moving-average decomposition followed by the attention gate."""

    def __init__(self,pars):
        super().__init__()
        self.decomposition=DecompositionBlock(pars)
        self.gate=gate(pars)

    def forward(self,x):
        """Gate the smoothed component of ``x``; return ``(gated series, attention weights)``."""
        # Only the smoothed component is used; the residual is ignored.
        smoothed=self.decomposition(x)[1]
        gated=self.gate(smoothed)
        return gated[0],gated[1] #batch,winlen,fealen


class MCNet(nn.Module):
    """Stack of ``pars.moduleNum`` MCNetModule stages applied to successive residuals.

    Each stage reconstructs what the previous stages left unexplained
    (``x`` minus the running sum of reconstructions); the stage outputs are
    summed to form the final reconstruction.
    """

    def __init__(self,pars):
        super(MCNet,self).__init__()
        self.decompositions=nn.ModuleList([MCNetModule(pars) for _ in range(pars.moduleNum)])
        self.device=pars.device
        self.sigma=pars.sigma
        # layerNorm and softmax are not used in forward; kept so existing
        # checkpoints (which contain the layerNorm weights) still load.
        self.layerNorm=nn.LayerNorm(pars.fealen)
        self.softmax=nn.Softmax(dim=-1)

    def forward(self,x):
        """Reconstruct ``x``.

        Returns:
            recons: tensor shaped like ``x``, the sum of all stage outputs.
            reconSingles: NumPy array stacking every stage's own output.
            reconAggregs: NumPy array stacking the running sum after each stage.

        The attention weights returned by the stages are discarded.
        """
        recons=torch.zeros(x.shape).to(self.device)
        xPres=x
        reconSingles=[]
        reconAggregs=[]
        for decomp in self.decompositions:
            recon,_=decomp(xPres)
            reconSingles.append(recon.detach().cpu().numpy())
            # Out-of-place add: on CPU ``.numpy()`` shares memory with the
            # tensor, so an in-place ``+=`` would overwrite every earlier
            # snapshot stored in reconAggregs with the final sum.
            recons=recons+recon
            reconAggregs.append(recons.detach().cpu().numpy())
            xPres=x-recons
        reconSingles=np.stack(reconSingles,axis=0)
        reconAggregs=np.stack(reconAggregs,axis=0)
        return recons,reconSingles,reconAggregs


class bridgeModel(nn.Module):
    """Small MLP that fuses reconstruction errors, the small-model score and the LLM score.

    ``pars.LLMDistribution`` is an iterable of Gaussian components indexed as
    ``(weight, mean, sigma)``.
    """

    def __init__(self,pars):
        super(bridgeModel, self).__init__()
        half=int(pars.fealen/2)
        self.W1=nn.Linear(pars.fealen,half)
        self.actv=nn.LeakyReLU()
        # +3: aligned small score, raw small score and LLM score are appended.
        self.W2=nn.Linear(half+3,pars.fealen+3)
        self.W3=nn.Linear(pars.fealen+3,1)
        self.actv2=nn.Sigmoid()
        self.LLMDistribution=pars.LLMDistribution
        # +1: the raw small score is appended to the error features.
        self.alignScores=nn.Linear(half+1,1)

    def GaussianDistribution(self,smallScore,half=True):
        """Log-density of ``smallScore`` under the Gaussian mixture.

        With ``half=True`` every component is multiplied by 2 before the log
        is taken; with ``half=False`` the plain mixture density is used.
        """
        scale=2 if half else 1
        res=0
        for gauss in self.LLMDistribution:
            weight,miu,sigma=gauss[0],gauss[1],gauss[2]
            res+=scale*weight/(math.pow(2*torch.pi,0.5)*sigma)*torch.exp(-torch.pow(smallScore-miu,2)/(2*math.pow(sigma,2)))
        return torch.log(res)

    def forward(self,reconsError,smallScore,LLMscore,lambdas):
        """Return ``(output, alignSmallScore, f_smallScore)``.

        ``output`` is the fused score in (0, 1) with the trailing unit axis
        squeezed; ``alignSmallScore`` is the sigmoid-aligned small score and
        ``f_smallScore`` its mixture log-density. ``lambdas`` is accepted for
        interface compatibility and not used.

        Raises:
            RuntimeError: if the inputs cannot be concatenated; the message
                lists the offending shapes.
        """
        reconsFeature=self.actv(self.W1(reconsError))
        smallScore=torch.unsqueeze(smallScore,dim=-1)
        LLMscore=torch.unsqueeze(LLMscore,dim=-1)
        catSmallScore=torch.cat([reconsFeature,smallScore],dim=-1)
        alignSmallScore=self.actv2(self.alignScores(catSmallScore))
        f_smallScore=self.GaussianDistribution(alignSmallScore)
        try:
            catInput=torch.cat([reconsFeature,alignSmallScore,smallScore,LLMscore],dim=-1)
        except RuntimeError as err:
            # Report the shapes and fail loudly instead of exiting with status 0.
            raise RuntimeError(
                "bridgeModel.forward: cannot concatenate inputs with shapes %s %s %s %s"
                % (tuple(reconsFeature.shape),tuple(alignSmallScore.shape),
                   tuple(smallScore.shape),tuple(LLMscore.shape))) from err
        catFeature=self.actv(self.W2(catInput))
        output=self.actv2(self.W3(catFeature))
        output=output.squeeze(dim=-1)
        return output,alignSmallScore,f_smallScore


def collaborateLoss(smallScore1,smallScore2,LLMscore1,LLMscore2,output1,output2,ratios,f_score,alignSocre):
    """Pairwise ranking loss plus negative log-likelihood term.

    For a pair of windows, the bridge output difference is pushed to have the
    same sign as the small-model score difference (weighted by ``ratios[0]``)
    and as the LLM score difference (weighted by ``ratios[1]``). ``-f_score``
    is added on top. ``alignSocre`` is not used.

    Returns the summed scalar loss.
    """
    smallGap=torch.squeeze(smallScore2,dim=-1)-torch.squeeze(smallScore1,dim=-1)
    llmGap=LLMscore2-LLMscore1
    outputGap=output2-output1
    # ratios are per-sample; add a trailing axis so they broadcast over the window.
    smallWeight=ratios[0].unsqueeze(dim=-1)
    llmWeight=ratios[1].unsqueeze(dim=-1)
    rankingTerm=-smallWeight*smallGap*outputGap-llmWeight*llmGap*outputGap
    likelihoodTerm=-f_score
    return rankingTerm.sum()+likelihoodTerm.sum()

def lossFunc(recons,x):#recons,x:batch,winlen,fealen
    """Squared reconstruction error summed over the last two axes, averaged over the rest."""
    squared=torch.pow(x-recons,2)
    perSample=squared.sum(dim=-1).sum(dim=-1)
    return perSample.mean()


def train(dataloader,model,loss_fn,parameters,optimizer,iterations):
    """Run one optimisation epoch of the base model.

    ``model(x)`` must return three values, the first being the reconstruction;
    ``loss_fn(recons, x)`` gives the loss to minimise. Progress is printed
    every second batch. ``iterations`` is not used.
    """
    size=len(dataloader.dataset)
    model.train()
    for batch,x in enumerate(dataloader):
        x=x.to(parameters.device)
        # Compare against a copy made before the forward pass.
        target=x.clone()
        recons,_unusedSingles,_unusedAggregs=model(x)
        loss=loss_fn(recons,target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 2 != 0:
            continue
        loss,current=loss.item(),batch*len(x)
        print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

def collaborate(dataloader,base_model,bridge_model,parameters,optimizer,lambdas,smallLoss):
    """Train the bridge model for one epoch with a single optimiser step.

    Pass 1 scores every batch with the frozen base model to obtain the global
    min/max of the small-model score. Pass 2 normalises the scores, feeds each
    window and its random partner through the bridge model and accumulates
    ``collaborateLoss`` over all batches. Two penalties (weighted by
    ``lambdas[3]``) pull the mean and standard deviation of the aligned scores
    towards ``aggMean``/``aggStd`` derived from ``parameters.LLMDistribution``.
    The total is divided by the number of batches and back-propagated once.

    Returns the aligned small scores (last window position of every sample)
    as a NumPy array.
    """
    device=parameters.device
    num_batches=len(dataloader)
    base_model.eval()
    bridge_model.train()

    mixture=parameters.LLMDistribution
    aggMean=sum(gauss[0]*gauss[1] for gauss in mixture)
    aggStd=sum(math.pow(gauss[0],2)*gauss[2] for gauss in mixture)

    # Pass 1: raw small-model scores, only used for min/max normalisation.
    skanScores=[]
    for x,_,_,_,_,_ in dataloader:
        x=x.to(device)
        with torch.no_grad():
            recons,_,_=base_model(x)
        skanScores.append(smallLoss(x,recons,parameters))
    skanScores=torch.cat(skanScores,dim=0)
    maxScore=torch.max(skanScores)
    minScore=torch.min(skanScores)
    scoreRange=torch.pow(torch.abs(maxScore-minScore+0.0000001),lambdas[4])

    # Pass 2: accumulate the pairwise loss over the whole epoch.
    loss=torch.zeros(1).to(device)
    alignSmallScores=[]
    for x,xcollaborate,llm,llmcollaborate,inter,intra in dataloader:
        x,xcollaborate=x.to(device),xcollaborate.to(device)
        llm,llmcollaborate=llm.to(device),llmcollaborate.to(device)
        inter,intra=inter.to(device),intra.to(device)
        ratios=[inter/(inter+intra),intra/(inter+intra)]

        with torch.no_grad():
            recons,_,_=base_model(x)
            if parameters.needRecons:
                recons=reconAdjustTorch(recons,x)
        reconsError1=torch.pow(recons-x,2)
        score1=(smallLoss(x,recons,parameters)-minScore)/scoreRange

        with torch.no_grad():
            recons2,_,_=base_model(xcollaborate)
        reconsError2=torch.pow(recons2-xcollaborate,2)
        score2=(smallLoss(xcollaborate,recons2,parameters)-minScore)/scoreRange

        bridgeScore1,alignSmallScore1,f_score1=bridge_model(reconsError1,score1,llm,lambdas)
        bridgeScore2,_,_=bridge_model(reconsError2,score2,llmcollaborate,lambdas)
        loss+=collaborateLoss(score1,score2,llm,llmcollaborate,bridgeScore1,bridgeScore2,ratios,
                             f_score1,alignSmallScore1)
        alignSmallScores.append(alignSmallScore1[:,-1])

    alignSmallScores=torch.cat(alignSmallScores,dim=0)
    loss+=torch.pow(alignSmallScores.mean()-aggMean,2)*lambdas[3]
    loss+=torch.pow(alignSmallScores.std()-aggStd,2)*lambdas[3]
    loss/=num_batches

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return alignSmallScores.cpu().detach().numpy()



def collaborate_val(dataloader,base_model,bridge_model,parameters,lambdas,smallLoss):
    """Evaluate the bridge model on a collaborate-type dataloader.

    Mirrors ``collaborate`` without any optimisation: a first pass finds the
    global min/max of the small-model score, a second pass accumulates
    ``collaborateLoss`` over all batches, and a penalty (weighted by
    ``lambdas[3]``) on the distance between the mean aligned score and the
    mixture mean is added. The accumulated loss is printed.

    Returns:
        (loss divided by the number of batches,
         1-D tensor of all bridge scores of the first window of each pair).
    """
    num_batches=len(dataloader)
    device=parameters.device
    base_model.eval()
    bridge_model.eval()
    aggMean=sum(gauss[0]*gauss[1] for gauss in parameters.LLMDistribution)
    test_loss=0
    alignScores=[]
    bridgeScores=[]
    with torch.no_grad():
        # Pass 1: raw small-model scores, only used for min/max normalisation.
        skanScores=[]
        for x,_,_,_,_,_ in dataloader:
            x=x.to(device)
            recons,_,_=base_model(x)
            skanScores.append(smallLoss(x,recons,parameters))
        skanScores=torch.cat(skanScores,dim=0)
        maxScore=torch.max(skanScores)
        minScore=torch.min(skanScores)
        scoreRange=torch.pow(torch.abs(maxScore-minScore+0.0000001),lambdas[4])

        # Pass 2: pairwise loss on normalised scores.
        for x,xcollaborate,llm,llmcollaborate,inter,intra in dataloader:
            x,xcollaborate=x.to(device),xcollaborate.to(device)
            llm,llmcollaborate=llm.to(device),llmcollaborate.to(device)
            inter,intra=inter.to(device),intra.to(device)
            recons,_,_=base_model(x)
            if parameters.needRecons:
                recons=reconAdjustTorch(recons,x)
            reconsError1=torch.pow(recons-x,2)
            score1=(smallLoss(x,recons,parameters)-minScore)/scoreRange
            recons2,_,_=base_model(xcollaborate)
            reconsError2=torch.pow(recons2-xcollaborate,2)
            score2=(smallLoss(xcollaborate,recons2,parameters)-minScore)/scoreRange
            bridgeScore1,alignScore1,f_score1=bridge_model(reconsError1,score1,llm,lambdas)
            bridgeScore2,_,_=bridge_model(reconsError2,score2,llmcollaborate,lambdas)
            ratios=[inter/(inter+intra),intra/(inter+intra)]
            test_loss+=collaborateLoss(score1,score2,llm,llmcollaborate,bridgeScore1,bridgeScore2,ratios,f_score1,alignScore1)
            alignScores.append(alignScore1)
            bridgeScores.append(bridgeScore1)
    # Concatenate along the batch axis: the last batch may be smaller than the
    # others, which made concatenation along the last axis fail.
    alignScores=torch.cat(alignScores,dim=0)
    bridgeScores=torch.cat(bridgeScores,dim=0).reshape([-1])
    test_loss+=torch.pow(alignScores.mean()-aggMean,2)*lambdas[3]
    print("validate error:%f"%test_loss)
    return test_loss/num_batches,bridgeScores


def test(dataloader,base_model,bridge_model,parameters,lambdas,smallLoss,datasetID,dataset,valscores):
    """Evaluate the collaboration, the small model alone and the LLM alone.

    Scores are taken at the last position of every window, filtered with the
    boolean mask delivered by the dataloader, negated and handed to
    ``searchThreshold`` together with the labels. Results are printed.

    ``datasetID`` and ``dataset`` are not used. ``valscores`` is converted to
    NumPy but not consumed (the POT evaluation that needed it is disabled).

    Returns:
        (precision, recall, F1) of the collaboration, then of the small model,
        then of the LLM: nine values in total.
    """
    device=parameters.device
    base_model.eval()
    bridge_model.eval()

    def lastStep(parts):
        # Join the per-batch tensors and hand them over as a NumPy array.
        return torch.cat(parts,dim=0).cpu().detach().numpy()

    bridgeParts,smallParts,llmParts,labelParts,maskParts=[],[],[],[],[]
    with torch.no_grad():
        # Pass 1: raw small-model scores, only used for min/max normalisation.
        skanScores=[]
        for x,_,_,_ in dataloader:
            x=x.to(device)
            recons,_,_=base_model(x)
            skanScores.append(smallLoss(x,recons,parameters))
        skanScores=torch.cat(skanScores,dim=0)
        maxScore=torch.max(skanScores)
        minScore=torch.min(skanScores)
        scoreRange=torch.pow(torch.abs(maxScore-minScore+0.0000001),lambdas[4])

        # Pass 2: bridge scores on the normalised small-model scores.
        for x,y,llm,mask in dataloader:
            x,y,llm,mask=x.to(device),y.to(device),llm.to(device),mask.to(device)
            recons,_,_=base_model(x)
            if parameters.needRecons:
                recons=reconAdjustTorch(recons,x)
            reconsError=torch.pow(recons-x,2)
            Smallscore=(smallLoss(x,recons,parameters)-minScore)/scoreRange
            score,_,_=bridge_model(reconsError,Smallscore,llm,lambdas)
            bridgeParts.append(score[:,-1])
            smallParts.append(Smallscore[:,-1])
            llmParts.append(llm[:,-1])
            labelParts.append(y[:,-1])
            maskParts.append(mask[:,-1])

        masks=lastStep(maskParts)
        scores=lastStep(bridgeParts)
        smallScores=lastStep(smallParts)
        LLMscores=lastStep(llmParts)
        labels=lastStep(labelParts)
        valscores=valscores.cpu().detach().numpy()
        scores,smallScores=scores[masks],smallScores[masks]
        LLMscores,labels=LLMscores[masks],labels[masks]
        # Disabled alternative: pot_eval(valscores,-scores,labels,q=1e-3,level=0.08),
        # with q and level to be adjusted per subset.
        pot_result=searchThreshold(-scores,labels)
        precision=pot_result['pot-precision']
        recall=pot_result['pot-recall']
        F1=pot_result['pot-f1']
        small_result=searchThreshold(-smallScores,labels)
        LLM_result=searchThreshold(-LLMscores,labels)

    print("---------collaboration-----------")
    print("precision:%f"%precision)
    print("recall:%f"%recall)
    print("F1:%f"%F1)
    print("----------small model------------")
    print("precision:%f" % small_result['pot-precision'])
    print("recall:%f" % small_result['pot-recall'])
    print("F1:%f" % small_result['pot-f1'])
    print("----------LLM model------------")
    print("precision:%f" % LLM_result['pot-precision'])
    print("recall:%f" % LLM_result['pot-recall'])
    print("F1:%f" % LLM_result['pot-f1'])

    smallMetrics=(small_result['pot-precision'],small_result['pot-recall'],small_result['pot-f1'])
    llmMetrics=(LLM_result['pot-precision'],LLM_result['pot-recall'],LLM_result['pot-f1'])
    return (precision,recall,F1)+smallMetrics+llmMetrics

def validate(dataloader,model,loss_fn,parameters):
    """Return (and print) the mean of ``loss_fn(recons, x)`` over all batches, without gradients."""
    num_batches=len(dataloader)
    model.eval()
    batchLosses=[]
    with torch.no_grad():
        for x in dataloader:
            x=x.to(parameters.device)
            recons,_,_=model(x)
            batchLosses.append(loss_fn(recons,x))
    test_loss=sum(batchLosses)/num_batches

    print(f"Validate Error: \n  Avg loss: {test_loss:>8f} \n")
    return test_loss

def getAnomScore3(x,recons,pars):
    """Return, as a list, the signed residual ``x-recons`` summed over the last axis (``pars`` unused)."""
    residualSum=(x-recons).sum(axis=-1)
    return list(residualSum)

def getAnomScore2(x,recons,pars):
    """Signed residual weighted per feature by ``pars.exeTime`` and summed over the last axis.

    When ``pars.omit`` is truthy the first entry of ``pars.exeTime`` is skipped.
    """
    times=pars.exeTime[1:] if pars.omit else pars.exeTime
    intervals=np.array(times,dtype=np.float32)
    weighted=(x-recons)*intervals
    return weighted.sum(axis=-1)

def getAnomScoreStandard(x,recons,pars):
    """Squared residual ``(x-recons)^2`` summed over the last axis (``pars`` unused)."""
    residual=x-recons
    squared=np.power(residual,2)
    return squared.sum(axis=-1)

def getAnomScoreDivide(x,recons,pars):
    """Per-feature signed residual weighted by ``pars.exeTime`` (no reduction).

    When ``pars.omit`` is truthy the first entry of ``pars.exeTime`` is skipped.
    """
    times=pars.exeTime[1:] if pars.omit else pars.exeTime
    intervals=np.array(times,dtype=np.float64)
    return (x-recons)*intervals




def slideWindow(scores,coherences,pars):
    """Subtract a coherence-weighted window average from every score.

    ``scores`` (a list) and ``coherences`` are both front-padded with a copy
    of their first ``pars.slidewinlen`` entries. For each start position the
    window weights are the softmax of the window's coherences, and the output
    is the window's first score minus the weighted sum of the window.
    """
    span=pars.slidewinlen
    padded=scores[:span]+scores
    coherence=torch.tensor(np.hstack((coherences[:span],coherences)))
    adjusted=[]
    for start in range(len(padded)-span):
        stop=start+span
        weight=torch.softmax(coherence[start:stop],dim=0).numpy()
        adjusted.append(padded[start]-np.sum(padded[start:stop]*weight))
    return adjusted


def reconAdjust(recons,x):
    """Rescale ``recons`` so that, along axis 0, it takes on the mean and std of ``x``.

    Prints the shape of ``recons`` before rescaling.
    """
    print(recons.shape)
    centred=recons-recons.mean(axis=0,keepdims=True)
    standardised=centred/(recons.std(axis=0,keepdims=True)+0.00001)
    return standardised*x.std(axis=0,keepdims=True)+x.mean(axis=0,keepdims=True)

def reconAdjustTorch(recons,x):
    """Torch counterpart of ``reconAdjust``: match ``recons`` to the dim-0 mean and std of ``x``."""
    centred=recons-recons.mean(dim=0,keepdim=True)
    standardised=centred/(recons.std(dim=0,keepdim=True)+0.00001)
    return standardised*x.std(dim=0,keepdim=True)+x.mean(dim=0,keepdim=True)

def plotSeries(series,colors):
    """Print the shape of ``series`` and plot each of its columns in the matching colour."""
    print(series.shape)
    columns=series.transpose(1,0)
    for column,colour in zip(columns,colors):
        plt.plot(column,color=colour)

def breakPoint(dataloader,model,loss_fn,parameters,smallLoss):
    """Evaluate the base model alone on a break-type dataloader.

    For every window the last time step of the input and of the
    reconstruction is kept; the reconstructions are rescaled with
    ``reconAdjust`` and scored with ``smallLoss``. A window is labelled
    anomalous when either of its last two labels equals 1. The negated scores
    and labels are passed to ``searchThreshold``.

    Returns:
        (mean loss, precision, recall, F1, accuracy, scores as a NumPy array)
    """
    num_batches=len(dataloader)
    print(num_batches)
    model.eval()
    test_loss=0
    windowLabels=[]
    lastRecons=[]
    lastInputs=[]
    with torch.no_grad():
        for x,y in dataloader:
            x,y=x.to(parameters.device),y.to(parameters.device)
            windowLabels.append((y==1)[:,-2:].any(dim=-1))
            recons,_,_=model(x)
            test_loss+=loss_fn(recons,x)
            lastRecons.append(recons[:,-1,:])
            lastInputs.append(x[:,-1,:])
    test_loss/=num_batches

    reconSeq=torch.cat(lastRecons,dim=0).cpu().detach().numpy()
    origiSeq=torch.cat(lastInputs,dim=0).cpu().detach().numpy()
    reconSeq=reconAdjust(reconSeq,origiSeq)
    scores=smallLoss(origiSeq,reconSeq,parameters)
    print("score shape",len(scores))
    print("recon shape",reconSeq.shape)
    scores=np.array(scores)
    labels=torch.cat(windowLabels,dim=0).cpu().detach().numpy()

    pot_result=searchThreshold(-scores,labels)
    precision=pot_result['pot-precision']
    recall=pot_result['pot-recall']
    F1=pot_result['pot-f1']
    correct=pot_result['pot-TP']+pot_result['pot-TN']
    total=correct+pot_result['pot-FP']+pot_result['pot-FN']
    accuracy=torch.true_divide(correct,total).item()

    print(f"Test Error: \n  Avg loss: {test_loss:>8f} \n")
    print("precision:%.6f, recall:%.6f, F1 score:%.6f, accuracy:%.6f\n" % (precision, recall, F1, accuracy))
    print("average score:%f"%np.mean(scores))
    return test_loss,precision,recall,F1,accuracy,scores

def loadModel(path,parameters,type):
    """Build an ``MCNet`` (``type=="base"``) or a ``bridgeModel`` (anything else) and load its weights from ``path``."""
    modelClass=MCNet if type=="base" else bridgeModel
    model=modelClass(parameters)
    model.load_state_dict(torch.load(path))
    return model

def setup_seed(seed):
    """Seed Python, NumPy and torch (CPU and all CUDA devices) and request deterministic cuDNN."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

def getSeed():
    """Derive a seed from the current time in milliseconds, reduced modulo ``2**32-1``."""
    millis=int(time.time()*1000)
    return millis%(2**32-1)


def _saveCheckpoint(model,path,device):
    """Save the model's CPU state dict to ``path`` and return the model moved back to ``device``."""
    torch.save(model.cpu().state_dict(), path)
    model = model.to(device)
    print("Saved PyTorch Model State to " + path)
    return model


def detectAnomaly(trainData,valDatas,testDatas,testLabels,trainr,valdatar,testr,args,dataset,lambdas,testMask,datasetID=0):
    """Train (optionally) and evaluate the base model and the bridge model.

    Steps:
      1. Convert the inputs to float tensors, fix the random seed and build
         the datasets/dataloaders.
      2. If ``args.needTrain`` is non-zero, train the base model with early
         stopping and keep the checkpoint with the best validation loss.
      3. If ``args.needBreak`` is truthy, evaluate the base model alone and
         return ``(precision, recall, F1, scores)``.
      4. If ``args.needCollaborate`` is truthy, train the bridge model with
         early stopping (patience ``args.earlyQuit``).
      5. Reload the best bridge checkpoint, run ``test`` and write the result
         and the configuration next to the checkpoints.

    ``trainr``/``valdatar``/``testr`` are the LLM responses aligned with the
    train/validation/test series. Checkpoints and result files live in the
    directory ``"MCNet_" + str(dataset)``.

    Returns:
        The nine metrics returned by ``test`` (collaboration, small model,
        LLM), unless step 3 returned early.
    """
    trainData=torch.tensor(trainData,dtype=torch.float)
    valDatas=torch.tensor(valDatas,dtype=torch.float)
    testDatas=torch.tensor(testDatas,dtype=torch.float)
    testLabels=torch.tensor(testLabels,dtype=torch.float)
    trainr=torch.tensor(trainr,dtype=torch.float)
    valdatar=torch.tensor(valdatar,dtype=torch.float)
    testr=torch.tensor(testr,dtype=torch.float)
    seed = 1282028438#getSeed()#
    setup_seed(seed)
    if args.smallLoss=="loss2":
        smallLoss=getAnomScore2
    else:
        smallLoss=getAnomScoreStandard
    loadMod = args.loadMod != 0
    needTrain = args.needTrain != 0
    trainDataset=myDataset(trainData,args.winlen)
    valDataset=myDataset(valDatas,args.winlen,type="validate")
    breakDataset=myDataset(testDatas,args.winlen,testLabels,"break")
    collaborateDataset=myDataset(trainData,args.winlen,type="collaborate",LLMResponses=trainr)
    collaValDataset=myDataset(valDatas,args.winlen,type="collaborate",LLMResponses=valdatar)
    testDataset=myDataset(testDatas,args.winlen,testLabels,type="test",LLMResponses=testr,testMask=testMask)
    trainDataLoader=DataLoader(trainDataset,batch_size=args.batchSize,shuffle=True)
    valDataLoader=DataLoader(valDataset,batch_size=args.batchSize,shuffle=True)
    breakDataLoader=DataLoader(breakDataset,batch_size=args.batchSize,shuffle=False)
    collaborateDataloader=DataLoader(collaborateDataset,batch_size=args.batchSize,shuffle=True)
    collaValDataloader=DataLoader(collaValDataset,batch_size=args.batchSize,shuffle=False)
    testDataloader=DataLoader(testDataset,batch_size=args.batchSize,shuffle=False)
    dirName = "MCNet_"+str(dataset)
    if not os.path.exists(dirName):
        os.mkdir(dirName)
    base_modelPath = dirName + "/MCNet_base"+str(seed)+".pth"
    bridge_modelPath = dirName + "/MCNet_bridge"+str(seed)+".pth"
    if not loadMod:
        base_model=MCNet(args).to(args.device)
        bridge_model=bridgeModel(args).to(args.device)
    else:
        base_model=loadModel(base_modelPath,args,"base").to(args.device)
        bridge_model=loadModel(bridge_modelPath,args,"bridge").to(args.device)
    loss_fn=lossFunc
    base_epochs = 3
    bridge_epochs=150
    base_optimizer = torch.optim.Adam(base_model.parameters(), lr=args.trlr,weight_decay=0.1)
    bridge_optimizer=torch.optim.Adam(bridge_model.parameters(),lr=args.colr,weight_decay=0.1)

    if needTrain:
        best_loss = 9999999999
        last_loss = 999999999
        count = 0
        # Write a checkpoint up front so one exists even if no epoch improves.
        base_model=_saveCheckpoint(base_model,base_modelPath,args.device)
        for t in range(base_epochs):
            print(f"Epoch {t + 1}\n-------------------------------")
            train(trainDataLoader,base_model,loss_fn,args,base_optimizer,t)
            test_loss=validate(valDataLoader,base_model,loss_fn,args)
            if math.isnan(test_loss):
                break
            # Count consecutive epochs whose validation loss got worse.
            count = count + 1 if last_loss < test_loss else 0
            if count >= 2:
                break
            last_loss = test_loss
            if test_loss < best_loss:
                best_loss = test_loss
                base_model=_saveCheckpoint(base_model,base_modelPath,args.device)
    base_model = loadModel(base_modelPath, args,"base").to(args.device)

    if args.needBreak:
        test_loss, precision, recall, F1, accuracy,scores = breakPoint(breakDataLoader, base_model, loss_fn, args,smallLoss)
        maxscore=np.max(scores)
        print("max score:%f"%maxscore)
        return precision,recall,F1,scores

    bridgeScores=None
    if args.needCollaborate:
        best_loss = 9999999999
        last_loss = 999999999
        count = 0
        bridge_model=_saveCheckpoint(bridge_model,bridge_modelPath,args.device)
        for t in range(bridge_epochs):
            print(f"Epoch {t + 1}\n-------------------------------")
            collaborate(collaborateDataloader, base_model, bridge_model, args, bridge_optimizer, lambdas,smallLoss)
            test_loss,bridgeScores = collaborate_val(collaValDataloader, base_model,bridge_model, args,lambdas,smallLoss)
            if math.isnan(test_loss):
                break
            count = count + 1 if last_loss < test_loss else 0
            if count >= args.earlyQuit:
                break
            last_loss = test_loss
            if test_loss < best_loss:
                best_loss = test_loss
                bridge_model=_saveCheckpoint(bridge_model,bridge_modelPath,args.device)

    # The reloaded model is built on the CPU; move it to the device the
    # test batches are sent to, otherwise evaluation fails on non-CPU devices.
    bridge_model=loadModel(bridge_modelPath,args,"bridge").to(args.device)
    if bridgeScores is None:
        # No collaboration epoch ran: compute the validation scores from the
        # loaded bridge model so the argument passed to test() is defined.
        _,bridgeScores=collaborate_val(collaValDataloader, base_model,bridge_model, args,lambdas,smallLoss)
    precision,recall,F1,smp,smr,smf,llmp,llmr,llmf=test(testDataloader,base_model,bridge_model,args,lambdas,smallLoss,datasetID,dataset,bridgeScores)
    with open(dirName + "/res" + str(seed) + ".csv", "w") as f:
        f.write("%f,%f,%f\n" % (precision, recall, F1))
    with open(dirName + "/config" + str(seed) + ".txt", "w") as f:
        f.write(str(args))
    return precision,recall,F1,smp,smr,smf,llmp,llmr,llmf