
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import time
import os
import numpy as np
import random
import math
from Utils.eval_methods import searchThreshold
import matplotlib.pyplot as plt
# Workaround for the Intel OpenMP "duplicate runtime" abort (OMP Error #15) that
# can occur when several libraries in the process each bundle their own OpenMP
# runtime. NOTE(review): this only suppresses the abort; it does not remove the
# duplicate runtime.
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
class myDataset(Dataset):
    """Sliding-window dataset that serves stacks of ``m`` consecutive windows.

    ``datas`` is first cut into overlapping windows of ``winlen`` steps
    (stride 1). Each item is then a stack of ``m`` consecutive windows, so an
    item has shape ``(m, winlen, ...)``.

    Args:
        datas: Tensor indexed by time along its first axis.
        winlen: Number of time steps per window.
        m: Number of consecutive windows stacked into one item.
        labels: Tensor of per-step labels; required when ``type == "test"``.
        type: One of ``"train"``, ``"validate"`` or ``"test"``.

    Raises:
        ValueError: If ``type`` is not a supported mode, or if ``type`` is
            ``"test"`` and no ``labels`` were given.
    """

    def __init__(self,datas,winlen,m,labels=None,type="train"):
        super(myDataset,self).__init__()
        if type in ("train","validate"):
            self.x=self.split(datas,winlen)
        elif type=="test":
            if labels is None:
                raise ValueError("labels must be provided when type is 'test'")
            self.x,self.y=self.splitTest(datas,winlen,labels)
        else:
            # Previously an unknown mode produced an object without ``self.x``
            # that only failed later, far away from the actual mistake.
            raise ValueError(
                "type must be 'train', 'validate' or 'test', got %r" % (type,))

        self.type=type
        self.m=m

    def split(self,datas,winlen):
        """Return all stride-1 windows of ``datas`` stacked along a new axis 0.

        The range stops at ``len(datas) - winlen``, so the final possible
        window is not produced here (unlike ``splitTest``).
        """
        windows=[datas[i:i+winlen] for i in range(len(datas)-winlen)]
        return torch.stack(windows,dim=0)

    def splitTest(self,datas,winlen,labels):
        """Return ``(windows, label_windows)``, including the final window."""
        starts=range(len(datas)-winlen+1)
        xs=torch.stack([datas[i:i+winlen] for i in starts],dim=0)
        ys=torch.stack([labels[i:i+winlen] for i in starts],dim=0)
        return xs,ys

    def __getitem__(self, item):
        # Indices wrap around, so any integer maps onto a valid item.
        item = item % self.__len__()
        stack=self.x[item:item+self.m]
        if self.type=="test":
            # The label window belongs to the last window of the stack.
            return (stack,self.y[item+self.m-1])
        return stack

    def __len__(self):
        return len(self.x)-self.m+1



class DecompositionBlock(nn.Module):
    """Split a series into a moving-average part and the remainder.

    The moving average uses a window of ``pars.klen`` steps with stride 1.
    The input is left-padded with its own first ``klen - 1`` steps so the
    averaged output keeps the input length.
    """

    def __init__(self, pars):
        super().__init__()
        self.klen = pars.klen
        self.conv = nn.AvgPool1d(pars.klen, 1)

    def forward(self, x):
        """Return ``(residual, moving_average)``, both shaped like ``x``.

        ``x`` is expected as (batchsize, winlen, fealen).
        """
        _batch, _winlen, _fealen = x.shape
        channels_first = x.permute(0, 2, 1)  # batchsize,fealen,winlen
        head = channels_first[:, :, :self.klen - 1]
        averaged = self.conv(torch.cat([head, channels_first], dim=-1))
        remainder = channels_first - averaged
        # Back to batchsize,winlen,fealen for both outputs.
        return remainder.permute(0, 2, 1), averaged.permute(0, 2, 1)


class GaussianCurve(nn.Module):
    """Learnable, row-normalised Gaussian bumps over column positions.

    ``center`` is either a list with one centre per row or a single scalar
    shared by all rows. ``sigmas`` is a learnable ``(rows, 1)`` parameter
    initialised to ones.
    NOTE(review): the divisions by ``sigmas`` happen on the transposed
    ``(cols, rows)`` tensor, so they only broadcast when ``rows == cols``
    (as in the one visible caller) -- confirm this is intended.
    """

    def __init__(self, rows, cols, center, pars):
        super().__init__()
        self.sigmas = nn.Parameter(torch.ones(rows, 1))
        self.rows, self.cols = rows, cols
        self.center = center
        self.device = pars.device

    def forward(self):
        """Return a ``(rows, cols)`` tensor whose rows each sum to one."""
        positions = torch.arange(0, self.cols).repeat(self.rows, 1).to(self.device)
        if isinstance(self.center, list):
            # One centre per row, repeated across every column.
            mu = torch.tensor(self.center, dtype=torch.float)
            mu = mu.repeat(self.cols, 1).permute(1, 0)
        else:
            mu = torch.ones((self.rows, self.cols)) * self.center
        mu = mu.to(self.device)

        kernel = torch.pow(positions - mu, 2).permute(1, 0)  # cols,rows
        kernel /= -2 * torch.pow(self.sigmas, 2)
        kernel = torch.exp(kernel) / self.sigmas
        kernel /= kernel.sum(dim=0)  # normalise over column positions
        return kernel.permute(1, 0)


class gate(nn.Module):
    """Per-feature attention over time that re-mixes a series.

    For every feature an independent single-head ``nn.MultiheadAttention``
    compares short history patches of the series. Only its attention weights
    are used: they are damped by ``1 - GaussianCurve``, renormalised with a
    softmax, and applied to an affine transform (``scaler``, ``bias``) of the
    series values.

    ``activ``, ``activ2`` and ``Wx`` are created here but are not used by the
    active code in ``forward``; ``getK`` and ``getC`` are only referenced from
    commented-out code in this class.
    """
    def __init__(self,pars):
        """Build the per-feature attention modules and the affine parameters.

        Reads ``patchSize``, ``fealen``, ``winlen``, ``device``, ``omit`` and
        ``exeTime`` from ``pars``.
        """
        super(gate,self).__init__()
        # One attention module per feature; queries/keys are patches of length
        # patchSize, values are scalars (vdim=1).
        self.atts=nn.ModuleList([nn.MultiheadAttention(pars.patchSize, 1, batch_first=True,kdim=pars.patchSize,vdim=1) for i in range(pars.fealen)])
        #self.att1 = nn.MultiheadAttention(pars.patchSize, 1, batch_first=True,kdim=pars.patchSize,vdim=1)
        #self.att2 = nn.MultiheadAttention(pars.patchSize, 1, batch_first=True, kdim=pars.patchSize, vdim=1)
        self.activ=nn.Sigmoid()
        self.patchSize=pars.patchSize
        # winlen x winlen curve with one Gaussian centred on each time index.
        self.attCurve=GaussianCurve(pars.winlen,pars.winlen,[i for i in range(pars.winlen)],pars)
        self.softmax=nn.Softmax(dim=-1)
        self.device=pars.device
        self.activ2=nn.LeakyReLU()
        self.Wx=nn.Linear(pars.fealen,pars.fealen)
        # Per-feature affine transform applied to the values in forward().
        self.scaler=nn.Parameter(torch.ones(pars.fealen))
        self.bias=nn.Parameter(torch.ones(pars.fealen))
        # When pars.omit is set, the first entry of exeTime is dropped.
        if pars.omit:
            self.cost=pars.exeTime[1:]
        else:
            self.cost=pars.exeTime
        #self.u=nn.Parameter(torch.ones(pars.fealen))
        #self.v=nn.Parameter(torch.ones(pars.fealen))

    def getK(self):
        """Return ``exp(-(c_i - c_j)^2 / 0.03**2)`` for all pairs of ``self.cost``."""
        intervals=torch.tensor(self.cost)
        X,Y=torch.meshgrid(intervals,intervals)
        epsilon=0.03**2
        K = torch.exp(-torch.pow(X - Y, 2) / epsilon)
        K=K.to(self.device)
        return K

    def getC(self):
        """Return the matrix of pairwise differences ``c_i - c_j`` of ``self.cost``."""
        intervals = torch.tensor(self.cost)
        X, Y = torch.meshgrid(intervals, intervals)
        C=X-Y
        C=C.to(self.device)
        return C

    def forward(self,x):#x:batchsize,winlen,fealen
        """Apply the per-feature attention mixing to ``x``.

        Returns a 3-tuple: the mixed series (batch, winlen, fealen), a copy of
        the raw attention weights reshaped to (batch, fealen, winlen, winlen),
        and the constant ``10``.
        """
        x=x.permute(0,2,1)#batch,fealen,winlen
        batchSize,fealen,winlen=x.shape
        # Left-pad with the first patchSize steps so that unfolding yields
        # exactly one window per original time step (blockNum == winlen).
        xPadding=torch.cat([x[:,:,:self.patchSize],x],dim=-1)
        xExtend=xPadding.unfold(2,self.patchSize+1,1)#batch,fealen,blockNum,patchsize+1
        _,_,blockNum,_=xExtend.shape
        attWeight=[]
        for i,att in enumerate(self.atts):
            # Query and key: the first patchSize entries of each window;
            # value: its last entry. The attention output is discarded and
            # only the weight matrix is kept.
            _,attWei=att(xExtend[:,i,:,:-1],xExtend[:,i,:,:-1],xExtend[:,i,:,-1:])
            attWeight.append(attWei)
        attWeight=torch.stack(attWeight,dim=0)#fealen,batch,blockNum,blockNum
        attWeight=attWeight.permute(1,0,2,3).reshape(batchSize*fealen,blockNum,blockNum)
        #xExtend=xExtend.reshape(batchSize*fealen,blockNum,self.patchSize+1)
        #_,attWeight=self.att1(xExtend[:,:,:-1],xExtend[:,:,:-1],xExtend[:,:,-1:])#batchSize*fealen,blockNum,1
        # Keep an untouched copy of the raw weights for the caller.
        attWeightSave=attWeight.clone()
        attWeightSave=attWeightSave.reshape(batchSize,fealen,winlen,winlen)
        # Damp each weight by (1 - Gaussian curve value), then renormalise.
        attWeight=attWeight*(1-self.attCurve()) #batch*fealen,winlen (tarlen),winlen (sourLen)
        attWeight=self.softmax(attWeight).permute(1,0,2)#tarLen, batch*fealen, sourLen
        #xEmbed=self.linear(xExtend)
        #xEmbed=self.activ2(xEmbed)
        #xExtend = xExtend.reshape(batchSize * fealen, blockNum, self.patchSize + 1)
        # The last entry of every window, moved so fealen is the trailing axis
        # and the per-feature scaler/bias broadcast over it.
        value=xExtend[:,:,:,-1:].permute(0,2,3,1)#batch,blocknum,1,fealen
        #value=self.Wx(value).permute(0,3,1,2)
        value=(self.scaler*value+self.bias).permute(0,3,1,2)
        value=value.reshape(batchSize*fealen,blockNum)
        # Weighted sum over source positions for every target position.
        attRes=(attWeight*value).sum(dim=-1) #tarLen, batch*fealen
        attRes=attRes.permute(1,0)#batch*fealen,tarLen
        #attGate,_=self.att2(xExtend[:,-1,:-1],xExtend[:,:-1,:-1],xExtend[:,:-1,-1])#batchSize*fealen,1,1
        attRes=attRes.reshape(batchSize,fealen,blockNum)#batch,fealen,winlen
        attRes=attRes.permute(0,2,1)#batch,winlen,fealen
        #K=self.getK()
        #C=self.getC()
        #P=torch.matmul(torch.diag(self.u),K)
        #P=torch.matmul(P,torch.diag(self.v))
        #gateIncrease=torch.matmul(attRes,P.transpose(1,0))
        #cost=torch.sum(P*C)
        #presIncrease=ratios[0]*attRes+ratios[1]*gateIncrease
        # The third element is a fixed placeholder value.
        return attRes,attWeightSave,10 #batch,winlen,fealen


class MCNetModule(nn.Module):
    """One stage: moving-average decomposition followed by the attention gate.

    Only the moving-average output of the decomposition is passed to the
    gate; the residual output is discarded.
    """

    def __init__(self, pars):
        super().__init__()
        self.decomposition = DecompositionBlock(pars)
        self.gate = gate(pars)

    def forward(self, x):
        """Return ``(gated_series, attention_weights)`` for ``x``.

        The gated series is (batch, winlen, fealen).
        """
        _, smoothed = self.decomposition(x)
        gated, weights, _ = self.gate(smoothed)
        return gated, weights


class MCNet(nn.Module):
    """Stack of ``MCNetModule`` stages fitted to successive residuals.

    Each stage reconstructs what the previous stages left unexplained
    (``x - running_sum``); the final reconstruction is the sum of all stage
    outputs.

    ``sigma``, ``layerNorm`` and ``softmax`` are not used by ``forward`` but
    are kept so that existing checkpoints (which contain the ``layerNorm``
    parameters) still load.
    """
    def __init__(self,pars):
        super(MCNet,self).__init__()
        self.decompositions=nn.ModuleList([MCNetModule(pars) for i in range(pars.moduleNum)])
        self.device=pars.device
        self.sigma=pars.sigma
        self.layerNorm=nn.LayerNorm(pars.fealen)
        self.softmax=nn.Softmax(dim=-1)

    def forward(self,x):
        """Reconstruct ``x`` (batch, winlen, fealen).

        Returns:
            recons: Tensor, sum of all stage outputs (same shape as ``x``).
            reconSingles: ndarray ``(moduleNum, *x.shape)`` with the output of
                each individual stage.
            reconAggregs: ndarray ``(moduleNum, *x.shape)`` with the running
                sum after each stage.
        """
        recons=torch.zeros(x.shape).to(self.device)
        xPres=x
        reconSingles=[]
        reconAggregs=[]
        for decomp in self.decompositions:
            recon,_=decomp(xPres)
            reconSingles.append(recon.detach().cpu().numpy())
            # Accumulate out of place: on CPU, ``.numpy()`` shares memory with
            # the tensor, so an in-place ``+=`` would retroactively overwrite
            # every snapshot already stored in ``reconAggregs``.
            recons=recons+recon
            reconAggregs.append(recons.detach().cpu().numpy())
            # The next stage models whatever is still unexplained.
            xPres=x-recons
        reconSingles=np.stack(reconSingles,axis=0)
        reconAggregs=np.stack(reconAggregs,axis=0)
        return recons,reconSingles,reconAggregs

class RelaModel(nn.Module):
    """Mixes features through a learnable adjacency, then reconstructs with MCNet.

    Args:
        pars: Configuration object; ``impactLength`` and ``fealen`` are read
            here, the rest is forwarded to ``MCNet``.
        relaMat: Initial value of the learnable adjacency ``Adj``.
    """
    def __init__(self,pars,relaMat):
        super(RelaModel, self).__init__()
        self.timeDependency=MCNet(pars)
        self.timeDecay=nn.Parameter(torch.ones((pars.impactLength,pars.fealen,pars.fealen)))
        # nn.Parameter shares storage with the tensor it wraps. Clone so that
        # optimiser steps and load_state_dict() cannot silently overwrite the
        # caller's prior matrix.
        self.Adj=nn.Parameter(relaMat.detach().clone())
        self.Embed=nn.Linear(pars.fealen,pars.fealen)
        self.actv=nn.LeakyReLU()
        self.mus=nn.Parameter(torch.zeros(pars.fealen))

    def forward(self,x):#batch,m,winlen,fealen
        """Return ``(recons, smooth, Adj)`` for a stack of windows.

        ``smooth`` is the sum, over the batch and all feature pairs, of the
        L2 distance between the two features' series in the last window of
        the stack, weighted by ``Adj``.
        """
        lx=x[:,-1].transpose(2,1)  # batch,fealen,winlen (last window only)
        lx1=lx.unsqueeze(2)
        lx2=lx.unsqueeze(1)
        smooth=(torch.norm(lx1-lx2,dim=-1)*self.Adj).sum()
        x=self.actv(self.Embed(x))
        # One (fealen, fealen) matrix per stacked window; broadcasting requires
        # the stack depth m to match pars.impactLength.
        x=torch.matmul(x,self.timeDecay)
        x=torch.matmul(x,self.Adj)
        # Collapse the stack of windows and add the per-feature offset.
        x=torch.sum(x,dim=1)+self.mus
        recons,_,_=self.timeDependency(x)
        return recons,smooth,self.Adj


def lossFunc1(recons,x,smooth,adj,relaMat,lambdas):
    """Reconstruction loss plus three penalties weighted by ``lambdas``.

    The reconstruction term is the squared error summed over the last two
    axes and averaged over the rest. Added to it are
    ``lambdas[0] * sum(adj**2)``, ``lambdas[1] * smooth`` and
    ``lambdas[2] * sum((relaMat - adj)**2)``.
    """
    total = (x - recons).pow(2).sum(dim=-1).sum(dim=-1).mean()

    size_penalty = lambdas[0] * adj.pow(2).sum()
    smooth_penalty = lambdas[1] * smooth
    prior_penalty = lambdas[2] * (relaMat - adj).pow(2).sum()

    total.add_(size_penalty + smooth_penalty)
    total.add_(prior_penalty)
    return total

def lossFunc2(recons,x):
    """Squared reconstruction error, summed over the last two axes and then averaged."""
    squared = (x - recons).pow(2)
    per_sample = squared.sum(dim=-1).sum(dim=-1)
    return per_sample.mean()


def train(dataloader,model,parameters,optimizer_adj,optimizer_theta,relaMat,epoch=0):
    """Run one epoch of alternating updates.

    For every batch, ``optimizer_adj`` takes a step on the regularised loss
    (``lossFunc1``), then the model is run again and ``optimizer_theta`` takes
    a step on the plain reconstruction loss (``lossFunc2``). The target is the
    last window of each stack. ``epoch`` is accepted but not used.
    """
    def _step(optimizer, objective):
        # Clear this optimiser's gradients, backpropagate, update.
        optimizer.zero_grad()
        objective.backward()
        optimizer.step()

    size = len(dataloader.dataset)
    model.train()

    for batch, x in enumerate(dataloader):
        x = x.to(parameters.device)
        target = x[:, -1]

        # Phase 1: regularised objective.
        recons, smooth, adj = model(x)
        _step(optimizer_adj,
              lossFunc1(recons, target, smooth, adj, relaMat, parameters.lambdas))

        # Phase 2: fresh forward pass, plain reconstruction objective.
        recons, _, _ = model(x)
        loss = lossFunc2(recons, target)
        _step(optimizer_theta, loss)

        if batch % 2 != 0:
            continue
        loss, current = loss.item(), batch * len(x)
        print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

def test(dataloader,model,parameters,datasetID,dataset):
    """Score the test set and report precision, recall and F1.

    The anomaly score of an item is the squared reconstruction error (summed
    over features) at the last time step of its last window; the label is the
    last entry of the item's label window. The negated scores and the labels
    are passed to ``searchThreshold``, whose ``pot-*`` entries are returned.

    Args:
        dataloader: Yields ``(x, y)`` batches.
        model: Model returning a 3-tuple whose first element is the
            reconstruction.
        parameters: Provides ``device`` and ``needRecons``.
        datasetID: Unused; kept for interface compatibility.
        dataset: Name used in the printed report.

    Returns:
        ``(precision, recall, F1)``.
    """
    model.eval()
    scores=[]
    labels=[]
    with torch.no_grad():
        for x,y in dataloader:
            x=x.to(parameters.device)
            xOriginal=x[:,-1]
            recons,_,_=model(x)
            if parameters.needRecons:
                recons=reconAdjustTorch(recons,xOriginal)
            reconsError=torch.pow(recons-xOriginal,2).sum(dim=-1)
            # Move per-batch results off the device right away; nothing else
            # from the batch is retained.
            scores.append(reconsError[:,-1].cpu())
            labels.append(y[:,-1].cpu())
    scores=torch.cat(scores,dim=0).numpy()
    labels=torch.cat(labels,dim=0).numpy()

    pot_result = searchThreshold(-scores, labels)
    precision = pot_result['pot-precision']
    recall = pot_result['pot-recall']
    F1 = pot_result['pot-f1']

    print("---------Metrics of %s-----------"%dataset)
    print("precision:%f"%precision)
    print("recall:%f"%recall)
    print("F1:%f"%F1)

    return precision,recall,F1

def validate(dataloader,model,parameters):
    """Return the mean per-batch ``lossFunc2`` over ``dataloader``.

    The model is put in eval mode and run without gradient tracking; the
    target of every batch is the last window of each stack. The average is
    printed and returned.
    """
    num_batches = len(dataloader)
    model.eval()
    batch_losses = []
    with torch.no_grad():
        for stack in dataloader:
            stack = stack.to(parameters.device)
            recons, _, _ = model(stack)
            batch_losses.append(lossFunc2(recons, stack[:, -1]))
    test_loss = sum(batch_losses) / num_batches

    print(f"Validate Error: \n  Avg loss: {test_loss:>8f} \n")
    return test_loss




def slideWindow(scores,coherences,pars):
    """Subtract a softmax-weighted window average from each score.

    Both sequences are left-padded with their own first
    ``pars.slidewinlen`` entries. For every window start ``i`` the result is
    ``scores[i] - sum(scores[i:i+winlen] * softmax(coherence[i:i+winlen]))``,
    indexed on the padded sequences.
    NOTE(review): ``scores`` is padded with ``+``, which concatenates only
    for Python lists -- presumably a list is expected; confirm with callers.
    """
    span = pars.slidewinlen
    padded = scores[:span] + scores
    coherence = torch.tensor(np.hstack((coherences[:span], coherences)))

    def _adjusted(start):
        stop = start + span
        weight = torch.softmax(coherence[start:stop], dim=0).numpy()
        return padded[start] - np.sum(padded[start:stop] * weight)

    return [_adjusted(start) for start in range(len(padded) - span)]


def reconAdjust(recons,x):
    """Rescale ``recons`` so its axis-0 mean and std match those of ``x``.

    ``recons`` is standardised along axis 0 (with a 1e-5 guard on the
    divisor), then scaled by the std of ``x`` and shifted by its mean.
    The shape of ``recons`` is printed as a side effect.
    """
    print(recons.shape)
    centred = recons - recons.mean(axis=0, keepdims=True)
    spread = recons.std(axis=0, keepdims=True) + 0.00001
    standardised = centred / spread
    return (standardised * x.std(axis=0, keepdims=True)) + x.mean(axis=0, keepdims=True)

def reconAdjustTorch(recons,x):
    """Tensor version: match the dim-0 mean and std of ``recons`` to ``x``.

    ``recons`` is standardised along dim 0 (with a 1e-5 guard on the
    divisor), then scaled by the std of ``x`` and shifted by its mean.
    """
    centred = recons - recons.mean(dim=0, keepdims=True)
    spread = recons.std(dim=0, keepdims=True) + 0.00001
    standardised = centred / spread
    return (standardised * x.std(dim=0, keepdims=True)) + x.mean(dim=0, keepdims=True)

def plotSeries(series,colors):
    """Plot every column of a 2-D ``series`` in the matching colour.

    The shape of ``series`` is printed first. Columns and colours are paired
    with ``zip``, so the shorter of the two decides how many lines are drawn.
    """
    print(series.shape)
    columns = series.transpose(1, 0)
    for curve, shade in zip(columns, colors):
        plt.plot(curve, color=shade)


def loadModel(path,parameters,relaMat):
    """Build a ``RelaModel`` and load its weights from ``path``.

    The returned model lives on the CPU; callers move it to their device.

    NOTE: ``torch.load`` unpickles the file, so only load checkpoints from a
    trusted source.
    """
    model=RelaModel(parameters,relaMat)
    # map_location lets checkpoints written from a GPU load on CPU-only hosts.
    state=torch.load(path,map_location="cpu")
    model.load_state_dict(state)
    return model

def setup_seed(seed):
    """Seed torch (CPU and all CUDA devices), NumPy and ``random``.

    Also switches cuDNN to deterministic mode.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for seeder in seeders:
        seeder(seed)
    torch.backends.cudnn.deterministic = True

def getSeed():
    """Return the current wall-clock time in milliseconds, modulo ``2**32 - 1``."""
    millis = int(time.time() * 1000)
    return millis % (2 ** 32 - 1)


def detectAnomaly(trainData,valDatas,testDatas,testLabels,args,dataset,relaMat,datasetID=0):
    """Train (optionally), evaluate and record one anomaly-detection run.

    A fresh time-based seed is drawn for every call and used both for the RNGs
    and in the names of the files written under ``ADSec_<dataset>/``:
    ``ADSec<seed>.pth`` (checkpoint), ``res<seed>.csv`` (precision, recall,
    F1) and ``config<seed>.txt`` (``str(args)``).

    Args:
        trainData, valDatas, testDatas: Array-likes indexed by time along
            their first axis; converted to float tensors.
        testLabels: Per-step labels for ``testDatas``.
        args: Configuration object. Read here: ``device``, ``loadMod``,
            ``needTrain``, ``winlen``, ``impactLength``, ``batchSize``,
            ``epoch``, ``trlr``; further fields are read by the helpers.
        dataset: Name used for the output directory and the report.
        relaMat: Prior adjacency tensor used to initialise the model and as
            the reference in the training loss.
        datasetID: Forwarded to ``test``.

    Returns:
        ``(precision, recall, F1)`` on the test set.

    NOTE(review): the checkpoint path embeds the freshly drawn seed, so with
    ``args.loadMod`` set (or with both ``loadMod`` and ``needTrain`` unset)
    the checkpoint cannot exist yet and loading raises. Left unchanged because
    fixing it needs an explicit path in the interface.
    """
    # Frozen copy of the prior for the loss. It must live on the training
    # device, because it is subtracted from the model's adjacency there.
    relaMatClone=relaMat.detach().clone().to(args.device)
    trainData=torch.tensor(trainData,dtype=torch.float)
    valDatas=torch.tensor(valDatas,dtype=torch.float)
    testDatas=torch.tensor(testDatas,dtype=torch.float)
    testLabels=torch.tensor(testLabels,dtype=torch.float)
    seed = getSeed()
    setup_seed(seed)

    loadMod = args.loadMod != 0
    needTrain = args.needTrain != 0
    trainDataset=myDataset(trainData,args.winlen,args.impactLength)
    valDataset=myDataset(valDatas,args.winlen,args.impactLength,type="validate")
    testDataset=myDataset(testDatas,args.winlen,args.impactLength,testLabels,type="test")
    trainDataLoader=DataLoader(trainDataset,batch_size=args.batchSize,shuffle=True)
    valDataLoader=DataLoader(valDataset,batch_size=args.batchSize,shuffle=True)
    testDataloader=DataLoader(testDataset,batch_size=args.batchSize,shuffle=False)
    dirName = "ADSec_"+str(dataset)
    if not os.path.exists(dirName):
        os.mkdir(dirName)
    modelPath = dirName + "/ADSec"+str(seed)+".pth"
    if not loadMod:
        model=RelaModel(args,relaMat).to(args.device)
    else:
        model=loadModel(modelPath,args,relaMat).to(args.device)
    base_epochs = args.epoch

    # Two parameter groups, each driven by its own optimiser in train().
    def _is_graph_param(name):
        return 'Adj' in name or 'Embed' in name

    params_to_adj = [
        param for name, param in model.named_parameters()
        if _is_graph_param(name)
    ]
    params_other = [
        param for name, param in model.named_parameters()
        if not _is_graph_param(name)
    ]
    optimizer_adj = torch.optim.Adam(params_to_adj, lr=args.trlr,weight_decay=0.1)
    optimizer_theta= torch.optim.Adam(params_other,lr=args.trlr,weight_decay=0.1)
    if needTrain:
        best_loss = 9999999999
        last_loss = 999999999
        count = 0
        # Save the initial weights so a checkpoint exists even if no epoch
        # ever improves on the sentinel loss (e.g. NaN in the first epoch).
        torch.save(model.cpu().state_dict(), modelPath)
        model = model.to(args.device)
        print("Saved PyTorch Model State to " + modelPath)
        for t in range(base_epochs):
            print(f"Epoch {t + 1}\n-------------------------------")
            train(trainDataLoader,model,args,optimizer_adj,optimizer_theta,relaMatClone,t)
            test_loss=validate(valDataLoader,model,args)
            if math.isnan(test_loss):
                break
            # Early stopping: two consecutive epochs with a rising validation loss.
            if last_loss < test_loss:
                count += 1
            else:
                count = 0
            if count >= 2:
                break
            last_loss = test_loss
            if test_loss < best_loss:
                best_loss = test_loss
                torch.save(model.cpu().state_dict(), modelPath)
                model = model.to(args.device)
                print("Saved PyTorch Model State to " + modelPath)


    # Evaluate the best checkpoint. loadModel returns a CPU model, so it must
    # be moved to the device that test() sends its batches to.
    model=loadModel(modelPath,args,relaMat).to(args.device)
    testStart=time.time()
    precision,recall,F1=test(testDataloader,model,args,datasetID,dataset)
    testEnd=time.time()
    print("test time:",testEnd-testStart)
    with open(dirName + "/res" + str(seed) + ".csv", "w") as f:
        f.write("%f,%f,%f\n" % (precision, recall, F1))
    with open(dirName + "/config" + str(seed) + ".txt", "w") as f:
        f.write(str(args))
    return precision,recall,F1