from os import pread
from models.AnomalyTransformer.activation import MultiheadAttention
import torch.nn as nn
import torch
import numpy as np
import torch.nn.functional as F
from models.mlp import MLP
from models.MaskedMLP import MaskedMLP
import sys
class AnomalyAttention(nn.Module):
    """Self-attention layer that also produces a Gaussian prior association.

    Each forward pass stores three tensors on the module:

    * ``self.P`` -- the prior association: a row-normalised Gaussian kernel
      over the distance ``|j - i|`` between time steps, with a learned scale
      ``sigma`` per (batch, head, step).
    * ``self.S`` -- the attention weights returned by ``self.mha``
      (presumably ``[batch, n_heads, n_steps, n_steps]``; this depends on the
      project-local ``MultiheadAttention`` -- TODO confirm).
    * ``self.Z`` -- the attention output returned by ``self.mha``.

    ``association_discrepancy`` then compares ``P`` and ``S``.

    Args:
        n_seqs: size of the feature axis that the Q/K/V/sigma projections
            consume.
        n_steps: accepted for interface compatibility; not used here.
        emb_dim: embedding width of the attention.
        n_heads: number of attention heads (also the number of sigma values
            per time step).
        device: stored on ``self.device``; the prior is built on the device
            of the incoming tensor, so this value is informational only.
        task: ``'prediction'`` selects index ``-1`` along dim 2 of the
            discrepancy; any other value averages over the last axis.
    """

    def __init__(self, n_seqs, n_steps, emb_dim=64, n_heads=4, device='cuda:0', task='prediction'):
        super().__init__()

        self.n_heads = n_heads
        self.n_seqs = n_seqs
        self.emb_dim = emb_dim
        self.mha = MultiheadAttention(emb_dim, n_heads, batch_first=True)
        self.task = task
        self.W_sigma = nn.Linear(n_seqs, self.n_heads, bias=False)

        # Q/K/V projections are created here because they are applied before
        # the inputs are handed to the attention module.
        self.W_q = nn.Linear(n_seqs, emb_dim)
        self.W_k = nn.Linear(n_seqs, emb_dim)
        self.W_v = nn.Linear(n_seqs, emb_dim)

        self.linear = nn.Linear(emb_dim, emb_dim)

        self.device = device

    def forward(self, x):
        """Run attention on ``x`` and cache ``P``, ``S`` and ``Z``.

        Args:
            x: tensor of shape ``[batch, n_seqs, n_steps]``. It is transposed
                to ``[batch, n_steps, n_seqs]`` before the linear projections.

        Returns:
            ``self.linear(self.Z)`` with dims 1 and 2 swapped (presumably
            ``[batch, emb_dim, n_steps]`` -- depends on what ``self.mha``
            returns).
        """
        # [batch, n_steps, n_seqs]
        x = torch.transpose(x, 1, 2)

        # [batch, n_steps, emb_dim]
        Q = self.W_q(x)
        K = self.W_k(x)
        V = self.W_v(x)

        # [batch, n_steps, n_heads] -> [batch, n_heads, n_steps]
        sigma = torch.transpose(self.W_sigma(x), 2, 1)

        self.P = self.prior_association(sigma)

        # S is expected to sum to 1 along its last dim.
        self.Z, self.S = self.mha(Q, K, V)

        out = self.linear(self.Z)
        return torch.transpose(out, 1, 2)

    def prior_association(self, sigma):
        """Build the row-normalised Gaussian prior from the learned scales.

        Args:
            sigma: tensor of shape ``[batch, n_heads, n_steps]``.

        Returns:
            Tensor of shape ``[batch, n_heads, n_steps, n_steps]`` where entry
            ``[b, h, i, j]`` is proportional to
            ``exp(-(j - i)^2 / (2 * sigma[b, h, i]^2)) / sigma[b, h, i]`` and
            each row (last dim) sums to 1.
        """
        seq_len = sigma.shape[2]

        # [batch, n_heads, n_steps, 1]; the epsilon keeps sigma strictly
        # positive when the projection outputs zeros (e.g. all-zero inputs).
        sigma = sigma.unsqueeze(-1).abs() + 1e-4

        # Squared distance (j - i)^2 between every pair of steps, built on the
        # same device/dtype as sigma so the module works wherever it is moved.
        steps = torch.arange(seq_len, device=sigma.device, dtype=sigma.dtype)
        sq_dist = (steps.unsqueeze(0) - steps.unsqueeze(1)) ** 2  # [n_steps, n_steps]

        # The density is computed by hand: going through torch.distributions
        # was observed to produce NaN gradients.
        density = torch.exp(-sq_dist / (2 * torch.square(sigma))) / sigma

        # Normalise each row so it sums to 1.
        return density / density.sum(dim=-1, keepdim=True)

    def association_discrepancy(self, phase='min'):
        """Symmetric KL divergence between the cached prior and attention.

        Element-wise, ``KL(P||S) + KL(S||P)`` equals
        ``(P - S) * (log P - log S)``; this is computed directly because
        ``F.kl_div`` expects its first argument in log-space.

        Args:
            phase: ``'min'`` detaches ``S`` (gradients reach only the prior);
                ``'max'`` detaches ``P`` (gradients reach only the attention).

        Returns:
            The element-wise discrepancy reduced as follows: for
            ``task == 'prediction'`` index ``-1`` along dim 2 is selected,
            otherwise the last axis is averaged; the result is then averaged
            over dim 1 (heads).

        Raises:
            ValueError: if ``phase`` is neither ``'min'`` nor ``'max'``.
        """
        if phase == 'min':
            P, S = self.P, self.S.detach()
        elif phase == 'max':
            P, S = self.P.detach(), self.S
        else:
            raise ValueError(f"phase must be 'min' or 'max', got {phase!r}")

        # Clamp before the log so entries that underflowed to 0 do not
        # produce -inf / NaN.
        eps = 1e-8
        log_ratio = torch.log(P.clamp_min(eps)) - torch.log(S.clamp_min(eps))

        # [batch, n_heads, n_steps, n_steps] (assuming S has that layout)
        kldiv = (P - S) * log_ratio

        if self.task == 'prediction':
            kldiv = kldiv[:, :, -1]
        else:
            kldiv = kldiv.mean(-1)

        # Reduce along the head axis.
        return kldiv.mean(dim=1)


class AnomalyTransformerBlock(nn.Module):
    """One encoder block: feature projection, anomaly attention, feed-forward.

    The block projects the feature axis from ``n_seqs`` to ``emb_dim``, runs
    ``AnomalyAttention`` with a residual connection and layer norm, then a
    one-layer ReLU feed-forward with a second residual connection and layer
    norm. Both layer norms and the feed-forward act on the last axis, which
    has size ``n_steps``.

    After every forward pass ``self.association_discrepancy`` holds the value
    returned by the attention's ``association_discrepancy`` for that pass.

    Args:
        n_seqs: size of the incoming feature axis (dim 1 of the input).
        n_steps: size of the time axis (last dim of the input).
        emb_dim: width the features are projected to.
        device: forwarded to the attention layer and stored on ``self.device``.
        task: forwarded to the attention layer, where it selects how the
            association discrepancy is reduced.
    """

    def __init__(self, n_seqs, n_steps, emb_dim=64, device='cuda:0', task='prediction'):
        super().__init__()
        self.n_seqs, self.n_steps = n_seqs, n_steps

        # The attention sees the already-projected features, so its input
        # width is emb_dim rather than n_seqs. n_heads is fixed to 4 here;
        # presumably emb_dim must be divisible by it -- verify against the
        # MultiheadAttention implementation.
        self.attention = AnomalyAttention(
            emb_dim, n_steps, emb_dim=emb_dim, n_heads=4, device=device, task=task
        )

        self.ln1 = nn.LayerNorm(self.n_steps)

        self.ff = nn.Sequential(
            nn.Linear(self.n_steps, self.n_steps),
            nn.ReLU()
        )
        self.ln2 = nn.LayerNorm(self.n_steps)
        self.association_discrepancy = None
        self.device = device
        self.linear = nn.Linear(n_seqs, emb_dim)

    def forward(self, x, **kwargs):
        """Apply the block to ``x``.

        Args:
            x: tensor of shape ``[batch, n_seqs, n_steps]``.
            **kwargs: ``phase`` (``'min'`` or ``'max'``) is passed to the
                attention's ``association_discrepancy``; defaults to
                ``'min'`` when omitted.

        Returns:
            Tensor of shape ``[batch, emb_dim, n_steps]``.
        """
        # Project the feature axis: the linear layer acts on the last dim, so
        # swap features to the end and back again.
        x = torch.transpose(x, 1, 2)
        x = self.linear(x)
        x = torch.transpose(x, 1, 2)

        # Attention sub-layer with residual connection.
        z = self.ln1(self.attention(x) + x)

        # Feed-forward sub-layer with residual connection.
        z = self.ln2(self.ff(z) + z)

        self.association_discrepancy = self.attention.association_discrepancy(
            phase=kwargs.get('phase', 'min')
        )

        return z

class AnomalyTransformer(nn.Module):
    """Stack of ``AnomalyTransformerBlock`` layers plus a linear predictor.

    Args:
        n_seqs: size of the feature axis of the input (dim 1).
        n_steps: size of the time axis of the input (last dim).
        n_layers: number of blocks; the first maps ``n_seqs -> emb_dim``, the
            remaining ones ``emb_dim -> emb_dim``. At least one block is
            always created.
        lambda_: weight of the association-discrepancy loss term.
        task: forwarded to every block.
        output_dim: output width of the linear predictor.
        emb_dim: embedding width used by the blocks.
        dropout_rate: dropout probability applied to the input of the blocks.
        **kwargs: must contain ``device``.

    Raises:
        KeyError: if ``device`` is not supplied in ``kwargs``.
    """

    def __init__(self, n_seqs, n_steps, n_layers=1, lambda_=3, task='prediction', output_dim=16, emb_dim=64, dropout_rate=0.2, **kwargs):
        super().__init__()
        self.task = task
        self.n_steps = n_steps
        self.device = kwargs['device']

        # Input width of each block: the raw feature count for the first one,
        # emb_dim for every block stacked on top of it.
        dims = [n_seqs] + [emb_dim] * (n_layers - 1)
        self.blocks = nn.ModuleList([
            AnomalyTransformerBlock(dim, n_steps, device=self.device, emb_dim=emb_dim, task=task)
            for dim in dims
        ])

        self.predictor = nn.Linear(n_steps, output_dim)

        self.output = None
        self.lambda_ = lambda_
        # Placeholder; overwritten on every forward pass.
        self.assoc_discrepancy = torch.zeros((n_seqs, len(self.blocks)))
        self.n_seqs = n_seqs
        self.dropout_rate = dropout_rate
        self.phase = 'min'

    def forward(self, data, **kwargs):
        """Run the blocks and the predictor.

        Args:
            data: tensor of shape ``[batch, n_seqs, n_steps]``.
            **kwargs: accepted and ignored.

        Returns:
            ``self.predictor(data)``; also stored on ``self.output``.

        Side effects:
            ``self.assoc_discrepancy`` is set to the per-block association
            discrepancies summed over blocks (``[batch, n_steps]``), and the
            output of the last block is attached to the input as ``data.x``.
        """
        # Dropout must follow the module's train/eval state; F.dropout
        # defaults to training=True and would otherwise stay active in eval.
        x = F.dropout(data, p=self.dropout_rate, training=self.training)

        self.assoc_discrepancy = torch.zeros(
            (x.shape[0], self.n_steps, len(self.blocks))
        ).float().to(self.device)
        for idx, block in enumerate(self.blocks):
            x = block(x, phase=self.phase)
            self.assoc_discrepancy[:, :, idx] = block.association_discrepancy

        # Sum over blocks -> [batch, n_steps]
        self.assoc_discrepancy = self.assoc_discrepancy.sum(dim=-1)

        # NOTE(review): the predictor is applied to the raw input `data`, not
        # to the block output `x`, so the blocks only influence training via
        # the association-discrepancy term. Feeding `x` instead would change
        # the output shape from n_seqs to emb_dim rows, so this is left as is
        # -- confirm the intended behaviour.
        data.x = x
        out = self.predictor(data)

        self.output = out

        return out

    def get_additional_loss_terms(self):
        """Flip the phase, then return the discrepancy loss term.

        Because the phase is flipped first, a forward pass that ran in
        ``'min'`` contributes ``+lambda_ * mean(discrepancy)`` and one that
        ran in ``'max'`` contributes ``-lambda_ * mean(discrepancy)``.
        """
        self.switch_phase()
        return self.assoc_discrepancy_loss_term()

    def assoc_discrepancy_loss_term(self):
        """Return the signed, ``lambda_``-weighted mean association discrepancy.

        The sign is ``+1`` when the current phase is ``'max'`` and ``-1``
        otherwise.
        """
        sgn = 1.0 if self.phase == 'max' else -1.0
        return sgn * self.lambda_ * self.assoc_discrepancy.mean()

    def switch_phase(self):
        """Toggle ``self.phase`` between ``'min'`` and ``'max'``."""
        self.phase = 'min' if self.phase == 'max' else 'max'

    def additional_anomaly_term(self):
        """Return the association discrepancy of the last forward pass."""
        # assoc_discrepancy is a tensor attribute, not a method.
        return self.assoc_discrepancy

    def anomaly_score(self, x, edges, **kwargs):
        """Return the per-sample mean squared error between output and ``x``.

        Args:
            x: input tensor; the model output must be broadcastable against it.
            edges: unused; kept for interface compatibility.
            **kwargs: forwarded to ``forward``.

        Returns:
            The element-wise squared error averaged over the last two axes.
        """
        out = self(x, **kwargs)
        return F.mse_loss(out, x, reduction='none').mean(dim=-1).mean(dim=-1)

    def test_prediction(self, x, edges, **kwargs):
        """Return the model output for ``x``.

        ``edges`` is unused and kept for interface compatibility.
        """
        return self(x, **kwargs)

