import os, sys
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

import numpy as np
from tqdm import tqdm
from dataloader import TimeSeriesLoader
from utils import *
import math
from math import sqrt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm


class TriangularCausalMask:
    """Boolean causal mask of shape ``[B, 1, L, L]``.

    An entry ``[b, 0, i, j]`` is ``True`` exactly when ``j > i`` (strictly
    above the main diagonal), i.e. it marks the positions to be hidden.

    Args:
        B: Batch size (leading dimension of the mask).
        L: Sequence length (size of the two trailing dimensions).
        device: Device the finished mask is moved to.
    """

    def __init__(self, B, L, device="cpu"):
        with torch.no_grad():
            all_true = torch.ones([B, 1, L, L], dtype=torch.bool)
            # diagonal=1 keeps only the entries strictly above the diagonal.
            strictly_upper = torch.triu(all_true, diagonal=1)
            self._mask = strictly_upper.to(device)

    @property
    def mask(self):
        """Return the stored boolean mask tensor."""
        return self._mask


class AnomalyAttention(nn.Module):
    """Scaled dot-product attention that also builds a Gaussian prior association.

    Besides the usual attention output, ``forward`` computes a prior over key
    positions: a Gaussian kernel of the absolute index distance ``|i - j|``
    whose width ``sigma`` is supplied per query position and head.

    Args:
        win_size: Window length; fixes the size of the ``|i - j|`` table.
        mask_flag: When true, masked score positions are filled with ``-inf``.
        scale: Multiplier applied to the raw scores; when falsy,
            ``1 / sqrt(E)`` is used (``E`` = per-head query width).
        attention_dropout: Dropout probability applied to the softmax weights.
        output_attention: When true, ``forward`` additionally returns the
            attention weights, the prior and the expanded sigma.
    """

    def __init__(self, win_size, mask_flag=True, scale=None, attention_dropout=0.0, output_attention=False):
        super(AnomalyAttention, self).__init__()
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)

        # |i - j| for every pair of window positions, built without Python loops.
        positions = torch.arange(win_size, dtype=torch.get_default_dtype())
        distances = (positions.unsqueeze(0) - positions.unsqueeze(1)).abs()
        # A buffer follows the module across devices (no hard-coded CUDA);
        # persistent=False keeps it out of state_dict so checkpoints keep
        # the same set of keys.
        self.register_buffer('distances', distances, persistent=False)

    def forward(self, queries, keys, values, sigma, attn_mask):
        """Compute the attention output and the prior association.

        Args:
            queries: Tensor of shape ``[B, L, H, E]``.
            keys: Tensor of shape ``[B, S, H, E]``.
            values: Tensor of shape ``[B, S, H, D]``.
            sigma: Raw (pre-activation) scale of shape ``[B, L, H]``.
            attn_mask: Object exposing a boolean ``.mask``; only consulted when
                ``mask_flag`` is set. ``None`` selects a causal mask.

        Returns:
            ``(V, series, prior, sigma)`` when ``output_attention`` is true,
            where ``series`` is the softmax attention, ``prior`` the Gaussian
            kernel and ``sigma`` the activated scale expanded over the last
            axis; otherwise ``(V, None)``.
        """
        B, L, H, E = queries.shape
        scale = self.scale or 1. / sqrt(E)

        scores = torch.einsum("blhe,bshe->bhls", queries, keys)
        if self.mask_flag:
            if attn_mask is None:
                attn_mask = TriangularCausalMask(B, L, device=queries.device)
            scores.masked_fill_(attn_mask.mask, -np.inf)
        attn = scale * scores

        sigma = sigma.transpose(1, 2)  # B L H ->  B H L
        window_size = attn.shape[-1]
        # Map the raw value to a strictly positive width: 3 ** (sigmoid(5x) + 1e-5) - 1.
        sigma = torch.sigmoid(sigma * 5) + 1e-5
        sigma = torch.pow(3, sigma) - 1
        sigma = sigma.unsqueeze(-1).repeat(1, 1, 1, window_size)  # B H L L

        # Evaluate the prior on whatever device sigma lives on.
        distances = self.distances.to(sigma.device)
        # distances ([L, L]) broadcasts against sigma ([B, H, L, L]).
        prior = 1.0 / (math.sqrt(2 * math.pi) * sigma) * torch.exp(-distances ** 2 / 2 / (sigma ** 2))

        series = self.dropout(torch.softmax(attn, dim=-1))
        V = torch.einsum("bhls,bshd->blhd", series, values)

        if self.output_attention:
            return (V.contiguous(), series, prior, sigma)
        else:
            return (V.contiguous(), None)


class AttentionLayer(nn.Module):
    """Multi-head projection wrapper around an inner attention callable.

    Projects the inputs to per-head queries, keys, values and a per-head
    ``sigma``, runs ``attention`` on them and projects the merged heads back
    to ``d_model``.

    Args:
        attention: Callable invoked as
            ``attention(queries, keys, values, sigma, attn_mask)``. Its first
            return value must be the attention output; up to three further
            values (series, prior, sigma) are forwarded to the caller.
        d_model: Width of the input and output features.
        n_heads: Number of attention heads.
        d_keys: Per-head query/key width; defaults to ``d_model // n_heads``.
        d_values: Per-head value width; defaults to ``d_model // n_heads``.
    """

    def __init__(self, attention, d_model, n_heads, d_keys=None,
                 d_values=None):
        super(AttentionLayer, self).__init__()

        d_keys = d_keys or (d_model // n_heads)
        d_values = d_values or (d_model // n_heads)
        # Not used in forward(); kept so the module's parameter set is unchanged.
        self.norm = nn.LayerNorm(d_model)
        self.inner_attention = attention
        self.query_projection = nn.Linear(d_model,
                                          d_keys * n_heads)
        self.key_projection = nn.Linear(d_model,
                                        d_keys * n_heads)
        self.value_projection = nn.Linear(d_model,
                                          d_values * n_heads)
        self.sigma_projection = nn.Linear(d_model,
                                          n_heads)
        self.out_projection = nn.Linear(d_values * n_heads, d_model)

        self.n_heads = n_heads

    def forward(self, queries, keys, values, attn_mask):
        """Run the inner attention on the projected inputs.

        Args:
            queries: Tensor of shape ``[B, L, d_model]``.
            keys: Tensor of shape ``[B, S, d_model]``.
            values: Tensor of shape ``[B, S, d_model]``.
            attn_mask: Passed through unchanged to the inner attention.

        Returns:
            ``(out, series, prior, sigma)`` where ``out`` has shape
            ``[B, L, d_model]``. Entries the inner attention did not supply
            are ``None``.
        """
        B, L, _ = queries.shape
        _, S, _ = keys.shape
        H = self.n_heads
        x = queries
        queries = self.query_projection(queries).view(B, L, H, -1)
        keys = self.key_projection(keys).view(B, S, H, -1)
        values = self.value_projection(values).view(B, S, H, -1)
        # sigma is derived from the un-projected query input.
        sigma = self.sigma_projection(x).view(B, L, H)

        result = self.inner_attention(
            queries,
            keys,
            values,
            sigma,
            attn_mask
        )
        # The inner attention may return a short tuple such as ``(out, None)``
        # when attention outputs are disabled; pad it so this layer always
        # yields four values instead of failing while unpacking.
        out, *extras = result
        series, prior, sigma = (tuple(extras) + (None, None, None))[:3]
        out = out.view(B, L, -1)

        return self.out_projection(out), series, prior, sigma
    

class PositionalEmbedding(nn.Module):
    """Fixed sinusoidal positional encoding.

    Even feature indices hold ``sin(pos * w_k)`` and odd indices hold
    ``cos(pos * w_k)`` with ``w_k = exp(-2k * ln(10000) / d_model)``.

    Args:
        d_model: Width of the encoding. Odd widths are supported; the last
            (unpaired) feature then holds a sine term.
        max_len: Number of positions precomputed at construction time.
    """

    def __init__(self, d_model, max_len=5000):
        super(PositionalEmbedding, self).__init__()
        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model).float()
        pe.requires_grad = False

        position = torch.arange(0, max_len).float().unsqueeze(1)
        div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp()
        angles = position * div_term

        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one cosine column fewer than sine
        # columns, so trim the angles to the number of odd feature indices.
        pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])

        pe = pe.unsqueeze(0)
        # Registered as a buffer: saved with and moved alongside the module,
        # but not trained.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return the encodings for the first ``x.size(1)`` positions.

        Only the sequence length of ``x`` is read; the result has shape
        ``[1, x.size(1), d_model]``.
        """
        return self.pe[:, :x.size(1)]


class TokenEmbedding(nn.Module):
    """Embed each time step with a width-3 circular 1-D convolution.

    Args:
        c_in: Number of input channels (features per time step).
        d_model: Number of output channels (embedding width).
    """

    def __init__(self, c_in, d_model):
        super(TokenEmbedding, self).__init__()
        # Padding is 1 on PyTorch >= 1.5.0 and 2 on older releases
        # (presumably to keep the output as long as the input — verify).
        if torch.__version__ >= '1.5.0':
            pad = 1
        else:
            pad = 2
        self.tokenConv = nn.Conv1d(
            in_channels=c_in,
            out_channels=d_model,
            kernel_size=3,
            padding=pad,
            padding_mode='circular',
            bias=False,
        )
        # Re-initialise every Conv1d owned by this module.
        conv_layers = [mod for mod in self.modules() if isinstance(mod, nn.Conv1d)]
        for conv in conv_layers:
            nn.init.kaiming_normal_(conv.weight, mode='fan_in', nonlinearity='leaky_relu')

    def forward(self, x):
        """Map ``[B, L, c_in]`` to the convolved embedding, channels last."""
        channels_first = x.permute(0, 2, 1)
        embedded = self.tokenConv(channels_first)
        return embedded.transpose(1, 2)


class DataEmbedding(nn.Module):
    """Sum of a token (value) embedding and a positional embedding, then dropout.

    Args:
        c_in: Number of input channels handed to the token embedding.
        d_model: Embedding width shared by both embeddings.
        dropout: Dropout probability applied to the summed embedding.
    """

    def __init__(self, c_in, d_model, dropout=0.0):
        super(DataEmbedding, self).__init__()

        self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
        self.position_embedding = PositionalEmbedding(d_model=d_model)

        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        """Embed ``x`` and return the dropped-out sum of both embeddings."""
        token_part = self.value_embedding(x)
        position_part = self.position_embedding(x)
        combined = token_part + position_part
        return self.dropout(combined)


class EncoderLayer(nn.Module):
    """Attention sub-layer followed by a position-wise feed-forward sub-layer.

    Each sub-layer is wrapped in a residual connection and a LayerNorm.

    Args:
        attention: Callable invoked as ``attention(x, x, x, attn_mask=...)``
            that returns four values, the first being the attended features.
        d_model: Feature width of the layer input and output.
        d_ff: Hidden width of the feed-forward part; ``4 * d_model`` when falsy.
        dropout: Dropout probability used after each sub-layer stage.
        activation: ``"relu"`` selects ReLU; any other value selects GELU.
    """

    def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"):
        super(EncoderLayer, self).__init__()
        if not d_ff:
            d_ff = 4 * d_model
        self.attention = attention
        # kernel_size=1 convolutions act independently on every time step.
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        if activation == "relu":
            self.activation = F.relu
        else:
            self.activation = F.gelu

    def forward(self, x, attn_mask=None):
        """Apply the layer and pass through the attention by-products.

        Returns:
            ``(features, attn, mask, sigma)`` where the last three items are
            the second to fourth values returned by ``attention``.
        """
        attended, attn, mask, sigma = self.attention(x, x, x, attn_mask=attn_mask)

        # Residual + norm around the attention output.
        residual = self.norm1(x + self.dropout(attended))

        # Feed-forward part; Conv1d expects channels on dim 1, hence the transposes.
        hidden = self.conv1(residual.transpose(-1, 1))
        hidden = self.dropout(self.activation(hidden))
        hidden = self.dropout(self.conv2(hidden).transpose(-1, 1))

        return self.norm2(residual + hidden), attn, mask, sigma


class Encoder(nn.Module):
    """Stack of encoder layers with an optional final normalisation.

    Args:
        attn_layers: Iterable of modules, each called as
            ``layer(x, attn_mask=...)`` and returning
            ``(x, series, prior, sigma)``.
        norm_layer: Optional module applied to the final features.
    """

    def __init__(self, attn_layers, norm_layer=None):
        super(Encoder, self).__init__()
        self.attn_layers = nn.ModuleList(attn_layers)
        self.norm = norm_layer

    def forward(self, x, attn_mask=None):
        """Run ``x`` ([B, L, D]) through every layer in order.

        Returns:
            ``(x, series_list, prior_list, sigma_list)`` where each list holds
            one entry per layer, in layer order.
        """
        series_list, prior_list, sigma_list = [], [], []

        for layer in self.attn_layers:
            x, layer_series, layer_prior, layer_sigma = layer(x, attn_mask=attn_mask)
            series_list.append(layer_series)
            prior_list.append(layer_prior)
            sigma_list.append(layer_sigma)

        features = x if self.norm is None else self.norm(x)
        return features, series_list, prior_list, sigma_list


class AnomalyTransformer(nn.Module):
    """Encoder-only reconstruction model with association-based losses.

    The input is embedded, passed through ``e_layers`` encoder layers built
    around ``AnomalyAttention`` and projected to ``c_out`` channels.

    Args:
        win_size: Window length handed to every ``AnomalyAttention``.
        enc_in: Number of input channels.
        c_out: Number of output (reconstructed) channels.
        d_model: Model width.
        n_heads: Number of attention heads.
        e_layers: Number of encoder layers.
        d_ff: Feed-forward hidden width.
        dropout: Dropout probability used throughout the model.
        activation: Activation name forwarded to each ``EncoderLayer``.
        output_attention: When true, ``forward`` also returns the per-layer
            series, prior and sigma lists.
    """

    def __init__(self, win_size, enc_in, c_out, d_model=512, n_heads=8, e_layers=3, d_ff=512,
                 dropout=0.0, activation='gelu', output_attention=True):
        super(AnomalyTransformer, self).__init__()
        self.output_attention = output_attention

        # Encoding
        self.embedding = DataEmbedding(enc_in, d_model, dropout)

        # Encoder: build the layers one by one (attention without masking).
        layers = []
        for _ in range(e_layers):
            inner = AnomalyAttention(win_size, False, attention_dropout=dropout,
                                     output_attention=output_attention)
            attention = AttentionLayer(inner, d_model, n_heads)
            layers.append(EncoderLayer(attention, d_model, d_ff,
                                       dropout=dropout, activation=activation))
        self.encoder = Encoder(layers, norm_layer=torch.nn.LayerNorm(d_model))

        self.projection = nn.Linear(d_model, c_out, bias=True)

    def forward(self, x):
        """Reconstruct ``x``.

        Returns:
            ``(reconstruction, series, prior, sigmas)`` when
            ``output_attention`` is true, otherwise only the reconstruction.
        """
        embedded = self.embedding(x)
        encoded, series, prior, sigmas = self.encoder(embedded)
        reconstruction = self.projection(encoded)

        if not self.output_attention:
            return reconstruction  # [B, L, D]
        return reconstruction, series, prior, sigmas

    def _my_kl_loss(self, p, q):
        """Smoothed KL-style divergence: summed over the last axis, averaged over dim 1."""
        log_ratio = torch.log(p + 0.0001) - torch.log(q + 0.0001)
        per_row = torch.sum(p * log_ratio, dim=-1)
        return torch.mean(per_row, dim=1)

    def loss_function(self, x, x_recon, series, prior, window_size, type='train', temperature=None):
        """Compute the reconstruction loss and the two association losses.

        Args:
            x: Target tensor.
            x_recon: Reconstruction of ``x``.
            series: Per-layer attention tensors.
            prior: Per-layer prior tensors; each is normalised by the sum over
                its last axis before use.
            window_size: Repeat count used to expand the row sums; must match
                the size of the last axis of each prior.
            type: ``'train'`` returns scalar losses averaged over layers; any
                other value returns losses that keep their leading axes and
                are scaled by ``temperature``.
            temperature: Multiplier for the association losses outside
                training mode (required there).

        Returns:
            ``(rec_loss, (series_loss, prior_loss))``.
        """
        def row_normalized(p):
            # Divide every row by its sum over the last axis.
            row_sum = torch.unsqueeze(torch.sum(p, dim=-1), dim=-1)
            return p / row_sum.repeat(1, 1, 1, window_size)

        kl = self._my_kl_loss
        series_loss = 0.0
        prior_loss = 0.0
        n_layers = len(prior)

        if type == 'train':
            for u, layer_prior in enumerate(prior):
                layer_series = series[u]
                # Series loss: the prior is detached, gradients reach only the series.
                series_loss += (torch.mean(kl(layer_series, row_normalized(layer_prior).detach())) +
                                torch.mean(kl(row_normalized(layer_prior).detach(), layer_series)))
                # Prior loss: the series is detached, gradients reach only the prior.
                prior_loss += (torch.mean(kl(row_normalized(layer_prior), layer_series.detach())) +
                               torch.mean(kl(layer_series.detach(), row_normalized(layer_prior))))
            series_loss = series_loss / n_layers
            prior_loss = prior_loss / n_layers

            rec_loss = F.mse_loss(x_recon, x)
            return rec_loss, (series_loss, prior_loss)

        rec_loss = F.mse_loss(x_recon, x, reduction='none').mean(dim=-1)
        for u, layer_prior in enumerate(prior):
            series_term = kl(series[u], row_normalized(layer_prior).detach()) * temperature
            prior_term = kl(row_normalized(layer_prior), series[u].detach()) * temperature
            if u == 0:
                series_loss, prior_loss = series_term, prior_term
            else:
                series_loss += series_term
                prior_loss += prior_term

        return rec_loss, (series_loss, prior_loss)
        

class AnomalyTransformerDetector:
    """Training and scoring wrapper around ``AnomalyTransformer``.

    Args:
        dataloader: Object exposing ``train_loader``, ``test_loader``,
            ``window_size``, ``dataset_name`` and ``unroll_windows``.
        input_dim: Number of input channels (also used as the output width).
        hidden_dim: Used for both ``d_model`` and ``d_ff`` of the model.
        num_layers: Number of encoder layers.
        device: Device for the model; defaults to CUDA when available,
            otherwise CPU.
    """

    def __init__(self, dataloader, input_dim, hidden_dim, num_layers, device=None):
        self.dataloader = dataloader
        self.input_dim = input_dim

        self.device = device if device else torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = AnomalyTransformer(win_size=self.dataloader.window_size, enc_in=self.input_dim, c_out=self.input_dim,
                                        d_model=hidden_dim, e_layers=num_layers, d_ff=hidden_dim).to(self.device)

    def _adjust_learning_rate(self, optimizer, epoch, lr_):
        """Set every param group's learning rate to ``lr_ * 0.5 ** (epoch - 1)``."""
        lr = lr_ * (0.5 ** ((epoch - 1) // 1))
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        print('Updating learning rate to {}'.format(lr))

    def fit(self, epochs=50, learning_rate=1e-5, early_stopping=5, data_type='train', save=False, save_path=None):
        """Train the model with the two-phase (minimax) association objective.

        Args:
            epochs: Maximum number of epochs.
            learning_rate: Adam learning rate.
            early_stopping: Number of consecutive non-improving epochs after
                which training stops.
            data_type: ``'train'`` trains on ``train_loader``; any other value
                trains on ``test_loader``. It only selects the data — the
                training form of the loss is always used.
            save: Save the model whenever both epoch losses improve.
            save_path: Target path; when omitted and ``save`` is true a path is
                built with ``build_save_path``.
        """
        dataloader = self.dataloader.train_loader if data_type == 'train' else self.dataloader.test_loader
        optimizer = torch.optim.Adam(self.model.parameters(), lr=learning_rate)
        self.model.train()

        best_loss1, best_loss2, patience = float('inf'), float('inf'), 0
        pbar = tqdm(range(epochs), desc='Training AnomalyTransformer', leave=True)
        for epoch in pbar:
            loss1_list = []
            loss2_list = []
            ibar = tqdm(dataloader, desc='Inner loop', leave=False)
            for x in ibar:
                # Make sure the batch lives on the same device as the model.
                x = x.to(self.device)
                optimizer.zero_grad()
                output, series, prior, _ = self.model(x)
                # Always request the scalar training losses: `data_type` picks
                # the loader, not the loss mode (the scoring mode needs a
                # temperature and returns per-window tensors).
                rec_loss, (series_loss, prior_loss) = self.model.loss_function(
                    x, output, series, prior, self.dataloader.window_size, type='train')

                loss1 = rec_loss - 3 * series_loss
                loss2 = rec_loss + 3 * prior_loss
                loss1_list.append(loss1.item())
                loss2_list.append(loss2.item())

                # Both losses share one graph; gradients of the two backward
                # passes accumulate before the single optimizer step.
                loss1.backward(retain_graph=True)
                loss2.backward()
                optimizer.step()

            train_loss1 = np.mean(loss1_list)
            train_loss2 = np.mean(loss2_list)

            tqdm.write(f'Epoch {epoch+1}/{epochs}, Loss 1: {train_loss1:.4f}, Loss 2: {train_loss2:.4f}')

            # An epoch counts as an improvement only when both means decrease.
            if train_loss1 < best_loss1 and train_loss2 < best_loss2:
                best_loss1, best_loss2, patience = train_loss1, train_loss2, 0
                if save or save_path:
                    if save_path is None:
                        save_path = build_save_path(model_name='AnomalyTransformer', dataset_name=self.dataloader.dataset_name, seed=get_global_seed())
                    # Remembered so predict_score() can reload this checkpoint.
                    self.save, self.save_path = save, save_path
                    save_model(self.model, save_path)
                    tqdm.write(f"Model saved to {save_path}")
            else:
                patience += 1
                if patience >= early_stopping:
                    tqdm.write(f'Early stopping at epoch {epoch+1}, best loss1: {best_loss1:.4f}, best loss2: {best_loss2:.4f}')
                    break

            # NOTE: the per-epoch decay in _adjust_learning_rate is not applied here.

    @torch.no_grad()
    def predict_score(self, data_type='test', load_path=None):
        """Compute anomaly scores for one of the loaders.

        Args:
            data_type: ``'train'`` scores ``train_loader``; any other value
                scores ``test_loader``. It only selects the data — per-window
                scoring losses are always used.
            load_path: Checkpoint to load first; falls back to the path saved
                by ``fit``. A missing file only produces a warning.

        Returns:
            Whatever ``dataloader.unroll_windows`` returns for the stacked
            per-window criterion of shape ``[N, L, 1]``.
        """
        path = None
        if load_path is not None:
            path = load_path
        elif getattr(self, 'save_path', None):
            path = self.save_path

        if path:
            try:
                load_model(self.model, path)
                self.model.to(self.device)
                tqdm.write(f"Model loaded from {path}")
            except FileNotFoundError:
                tqdm.write(f"[warn] load_path not found: {path} — using in-memory model.")

        dataloader = self.dataloader.train_loader if data_type == 'train' else self.dataloader.test_loader
        self.model.eval()

        attens_energy = []

        for x in dataloader:
            # Make sure the batch lives on the same device as the model.
            x = x.to(self.device)
            output, series, prior, _ = self.model(x)
            # Force the non-training branch: scoring needs unreduced losses
            # even when the training split is being scored.
            recon_loss, (series_loss, prior_loss) = self.model.loss_function(
                x, output, series, prior, self.dataloader.window_size, type='test', temperature=50.0)

            # Association-based weighting of the reconstruction error.
            metric = torch.softmax((-series_loss - prior_loss), dim=-1)
            cri = metric * recon_loss
            cri = cri.detach().cpu().numpy()
            attens_energy.append(cri)

        attens_energy = np.concatenate(attens_energy)
        scores = self.dataloader.unroll_windows(attens_energy[:, :, None], data_type=data_type)
        return scores
    
if __name__ == "__main__":
    # Example usage: train the detector on one dataset, score the test split
    # and print the evaluation metrics.
    set_seed(42)

    dataset_name = 'SMD'

    # Resolve the loader / model / training settings for this dataset.
    loader_config, model_config, train_config = ModelConfig('AnomalyTransformer').resolve(dataset_name)

    loader = TimeSeriesLoader(dataset_name=dataset_name, **loader_config)

    detector = AnomalyTransformerDetector(
        dataloader=loader,
        input_dim=loader.input_dim,
        hidden_dim=model_config['hidden_dim'],
        num_layers=model_config['num_layers'],
    )

    fit_kwargs = {key: train_config[key] for key in ('epochs', 'learning_rate', 'early_stopping')}
    detector.fit(save=False, **fit_kwargs)

    scores = detector.predict_score(data_type='test')
    y_label = loader.test_ds.labels
    print("Anomaly scores:", scores)

    metrics = cal_metric(y_label, scores)
    print(f"metrics: {metrics}")