import os, sys
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

import numpy as np
from tqdm import tqdm
from dataloader import TimeSeriesLoader
from utils import *

import torch
import torch.nn as nn
import torch.nn.functional as F

class DAGMM(nn.Module):
    """Deep Autoencoding Gaussian Mixture Model (DAGMM).

    The network has three parts:

    * ``encoder`` / ``decoder``: a fully-connected autoencoder whose bottleneck
      has ``latent_dim - 2`` units.
    * ``estimation``: an MLP that maps the joint representation ``z`` to soft
      membership probabilities over ``n_gmm`` mixture components. ``z`` is the
      bottleneck code concatenated with the relative Euclidean reconstruction
      error and the cosine similarity between input and reconstruction, so it
      has ``latent_dim`` features.
    * Gaussian-mixture buffers ``phi`` (K,), ``mu`` (K, latent_dim) and ``cov``
      (K, latent_dim, latent_dim). They are overwritten with the statistics of
      the current batch every time ``_compute_gmm_params`` runs (i.e. on every
      ``loss`` call). ``_compute_energy`` uses them when no explicit parameters
      are passed.

    Args:
        input_dim: Number of features per (flattened) input sample.
        hidden_dim: Width of the first hidden layer; deeper layers use
            ``hidden_dim // 2`` and ``hidden_dim // 4``.
        n_gmm: Number of mixture components K.
        latent_dim: Dimension of ``z``. Must be >= 3 because two of its
            dimensions are reconstruction features.
        device: Stored as ``self.device``; defaults to CUDA when available.
            The module is NOT moved there by this constructor; call ``.to()``.
    """
    def __init__(self, input_dim, hidden_dim=64, n_gmm=2, latent_dim=3, device=None):
        super(DAGMM, self).__init__()

        assert latent_dim >= 3, "latent_dim must be >= 3 (encoder outputs latent_dim-2)."

        self.device = device if device else torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.encoder = nn.Sequential(nn.Linear(input_dim, hidden_dim),
                                     nn.Tanh(),
                                     nn.Linear(hidden_dim, hidden_dim//2),
                                     nn.Tanh(),
                                     nn.Linear(hidden_dim//2, hidden_dim//4),
                                     nn.Tanh(),
                                     nn.Linear(hidden_dim//4, latent_dim-2))

        self.decoder = nn.Sequential(nn.Linear(latent_dim-2, hidden_dim//4),
                                     nn.Tanh(),
                                     nn.Linear(hidden_dim//4, hidden_dim//2),
                                     nn.Tanh(),
                                     nn.Linear(hidden_dim//2, hidden_dim),
                                     nn.Tanh(),
                                     nn.Linear(hidden_dim, input_dim))

        self.estimation = nn.Sequential(nn.Linear(latent_dim, hidden_dim//4),
                                        nn.Tanh(),
                                        nn.Dropout(p=0.5),
                                        nn.Linear(hidden_dim//4, n_gmm),
                                        nn.Softmax(dim=1))

        # Buffers (not parameters): saved with the state dict and moved by .to(),
        # but never touched by the optimizer.
        self.register_buffer("phi", torch.zeros(n_gmm))
        self.register_buffer("mu", torch.zeros(n_gmm, latent_dim))
        self.register_buffer("cov", torch.zeros(n_gmm, latent_dim, latent_dim))

    def forward(self, x):
        """Encode, decode and estimate mixture memberships for a batch.

        Args:
            x: Batch of shape (N, input_dim).

        Returns:
            Tuple ``(enc, x_hat, z, gamma)``:

            * ``enc``: (N, latent_dim - 2) bottleneck codes.
            * ``x_hat``: (N, input_dim) reconstructions.
            * ``z``: (N, latent_dim) joint representation
              ``[enc, relative_euclidean, cosine]``.
            * ``gamma``: (N, n_gmm) softmax memberships.
        """
        enc = self.encoder(x)
        x_hat = self.decoder(enc)

        rec_cosine = F.cosine_similarity(x, x_hat, dim=1)
        rec_euclidean = self._relative_euclidean_distance(x, x_hat)

        z = torch.cat([enc, rec_euclidean.unsqueeze(-1), rec_cosine.unsqueeze(-1)], dim=1)
        gamma = self.estimation(z)

        return enc, x_hat, z, gamma

    def loss(self, x, x_hat, z, gamma, lambda_energy, lambda_cov_diag):
        """DAGMM objective: reconstruction + weighted energy + weighted covariance penalty.

        Side effect: refreshes the ``phi``/``mu``/``cov`` buffers with this
        batch's mixture statistics.

        Returns:
            Tuple ``(loss, sample_energy, recon_error, cov_diag)`` of scalar tensors.
        """
        # Mean squared error over all elements of the batch.
        recon_error = F.mse_loss(x, x_hat)

        phi, mu, cov = self._compute_gmm_params(z, gamma)
        sample_energy, cov_diag = self._compute_energy(z, phi, mu, cov)

        loss = recon_error + lambda_energy * sample_energy + lambda_cov_diag * cov_diag

        return loss, sample_energy, recon_error, cov_diag

    def _relative_euclidean_distance(self, a, b):
        """Row-wise ||a - b|| / ||a||; the tiny constant avoids division by zero."""
        return (a-b).norm(2, dim=1) / (a.norm(2, dim=1) + 1e-14)

    def _compute_gmm_params(self, z, gamma):
        """Estimate mixture weights, means and covariances from soft assignments.

        Args:
            z: (N, D) joint representations.
            gamma: (N, K) membership probabilities.

        Returns:
            ``(phi, mu, cov)`` with shapes (K,), (K, D) and (K, D, D). Detached
            copies are also written into the module buffers.
        """
        N = z.size(0)
        sum_gamma = torch.sum(gamma, dim=0) + 1e-14
        phi = sum_gamma / N

        # Means: [K x z_dim]
        mu = torch.sum(gamma.unsqueeze(-1) * z.unsqueeze(1), dim=0) / sum_gamma.unsqueeze(-1)

        # Covariances: [K x z_dim x z_dim]
        z_mu = z.unsqueeze(1) - mu.unsqueeze(0)
        cov = torch.sum(gamma.unsqueeze(-1).unsqueeze(-1) * (z_mu.unsqueeze(-1) * z_mu.unsqueeze(-2)), dim=0) / sum_gamma.unsqueeze(-1).unsqueeze(-1)

        # Update buffers
        self.phi.copy_(phi.detach())
        self.mu.copy_(mu.detach())
        self.cov.copy_(cov.detach())

        return phi, mu, cov

    def _compute_energy(self, z, phi=None, mu=None, cov=None, average=True):
        """Sample energy ``-log sum_k phi_k * N(z | mu_k, cov_k)`` and covariance penalty.

        The energy is evaluated in log space with ``logsumexp``. Points far from
        every component therefore get their true (large) energy instead of
        being clipped at a constant when ``exp`` underflows to zero.

        Args:
            z: (N, D) joint representations.
            phi, mu, cov: Mixture parameters; the module buffers are used for
                any argument left as ``None``.
            average: If True, return the mean energy over the batch; otherwise
                return the per-sample energies of shape (N,).

        Returns:
            ``(energy, cov_diag)``, where ``cov_diag`` is the sum over components
            of the reciprocals of the (jittered) covariance diagonals.
        """
        phi = self.phi if phi is None else phi
        mu = self.mu if mu is None else mu
        cov = self.cov if cov is None else cov

        D = cov.shape[-1]
        eps = 1e-12

        # Build the diagonal jitter on the covariance's own device/dtype. self.device
        # may not match where the module actually lives.
        eye = torch.eye(D, device=cov.device, dtype=cov.dtype).unsqueeze(0)
        cov_k = cov + eye * eps

        z_mu = z.unsqueeze(1) - mu.unsqueeze(0)  # [N, K, D]
        cov_inv = torch.linalg.inv(cov_k)
        mahalanobis = torch.einsum('nkd,kde,nke->nk', z_mu, cov_inv, z_mu)  # [N, K]

        # log|2*pi*cov_k| = D*log(2*pi) + log|cov_k|. slogdet avoids the NaN that
        # sqrt(det) would produce for a slightly negative numerical determinant.
        _, logabsdet = torch.linalg.slogdet(cov_k)  # [K]
        log_norm = D * float(np.log(2.0 * np.pi)) + logabsdet  # [K]

        # phi comes out of _compute_gmm_params strictly positive; clamping only
        # guards the all-zero initial buffers (used before any training step).
        log_weighted = (torch.log(phi.clamp_min(eps)).unsqueeze(0)
                        - 0.5 * mahalanobis
                        - 0.5 * log_norm.unsqueeze(0))  # [N, K]
        energy = -torch.logsumexp(log_weighted, dim=1)  # [N]
        if average:
            energy = energy.mean()

        cov_diag = cov_k.diagonal(dim1=-2, dim2=-1).reciprocal().sum()
        return energy, cov_diag
    

class DAGMMAnomalyDetector:
    """Train DAGMM on windowed time series and turn its energy into anomaly scores.

    Args:
        dataloader: Project loader object. This class reads ``window_size``,
            ``train_loader``, ``test_loader``, ``dataset_name`` and
            ``unroll_windows(...)`` from it.
        input_dim: Number of features per time step. Each window is flattened
            to ``input_dim * window_size`` model inputs.
        hidden_dim, n_gmm, latent_dim: Forwarded to ``DAGMM``.
        lambda_energy: Weight of the sample-energy term in the loss.
        lambda_cov_diag: Weight of the covariance-diagonal penalty in the loss.
        device: Device for the model and every batch; defaults to CUDA when available.
    """
    def __init__(self, dataloader, input_dim, hidden_dim=64, n_gmm=2, latent_dim=3, lambda_energy=0.1, lambda_cov_diag=0.005, device=None):
        self.dataloader = dataloader
        self.lambda_energy = lambda_energy
        self.lambda_cov_diag = lambda_cov_diag

        self.device = device if device else torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Pass the device through so DAGMM.device agrees with where the model is moved.
        self.model = DAGMM(input_dim*self.dataloader.window_size, hidden_dim, n_gmm, latent_dim,
                           device=self.device).to(self.device)

    def fit(self, epochs=50, learning_rate=1e-5, weight_decay=1e-4, data_type='train', save=False, save_path=None):
        """Train the model with Adam and optionally save it.

        Args:
            epochs: Number of passes over the selected loader.
            learning_rate, weight_decay: Adam hyper-parameters.
            data_type: ``'train'`` uses ``train_loader``; anything else uses ``test_loader``.
            save: If True, save the model. A default path from ``build_save_path``
                is used when ``save_path`` is None.
            save_path: Saving also happens whenever this is given.
        """
        dataloader = self.dataloader.train_loader if data_type=='train' else self.dataloader.test_loader
        optimizer = torch.optim.Adam(self.model.parameters(), weight_decay=weight_decay, lr=learning_rate)
        self.model.train()

        pbar = tqdm(range(epochs), desc='Training DAGMM', leave=True)
        for epoch in pbar:
            # The loss terms are per-batch means. Weighting them by batch size and
            # dividing by the number of samples seen gives a true per-sample epoch
            # average, including a smaller final batch.
            total_loss = total_recon_error = total_sample_energy = total_cov_diag = 0.
            n_seen = 0
            ibar = tqdm(dataloader, desc='Inner loop', leave=False)
            for x in ibar:
                # Flatten each window into one vector and move it to the model's device.
                x = x.view(x.size(0), -1).to(self.device)
                optimizer.zero_grad()

                enc, x_hat, z, gamma = self.model(x)
                loss, sample_energy, recon_error, cov_diag = self.model.loss(x, x_hat, z, gamma,
                                                                             lambda_energy=self.lambda_energy,
                                                                             lambda_cov_diag=self.lambda_cov_diag)
                loss.backward()
                optimizer.step()

                batch_size = x.size(0)
                total_loss += loss.item() * batch_size
                total_recon_error += recon_error.item() * batch_size
                total_sample_energy += sample_energy.item() * batch_size
                total_cov_diag += cov_diag.item() * batch_size
                n_seen += batch_size

            denom = max(n_seen, 1)  # guard against an empty loader
            avg_loss = total_loss / denom
            avg_recon_error = total_recon_error / denom
            avg_sample_energy = total_sample_energy / denom
            avg_cov_diag = total_cov_diag / denom
            tqdm.write(f'Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}, '
                       f'Reconstruction Error: {avg_recon_error:.4f}, '
                       f'Sample Energy: {avg_sample_energy:.4f}, '
                       f'Covariance Diagonal: {avg_cov_diag:.4f}')

        if save or save_path:
            if save_path is None:
                save_path = build_save_path(model_name='DAGMM', dataset_name=self.dataloader.dataset_name, seed=get_global_seed())
            self.save, self.save_path = save, save_path
            save_model(self.model, save_path)
            tqdm.write(f"Model saved to {save_path}")

    @torch.no_grad()
    def predict_score(self, data_type='test', load_path=None):
        """Score every window by its DAGMM sample energy and unroll the scores to time steps.

        The model is reloaded first from ``load_path`` if given, otherwise from
        the path used by the last ``fit(save=...)``. A missing file only logs a
        warning and keeps the in-memory weights.

        Args:
            data_type: ``'test'`` uses ``test_loader``; anything else uses ``train_loader``.
            load_path: Optional checkpoint path.

        Returns:
            The result of ``self.dataloader.unroll_windows`` applied to an array
            of shape (num_windows, window_size, 1). Each window's energy is
            repeated across its time steps; higher values mean more anomalous.
        """
        path = None
        if load_path is not None:
            path = load_path
        elif getattr(self, 'save_path', None):
            path = self.save_path

        if path:
            try:
                load_model(self.model, path)
                self.model.to(self.device)
                tqdm.write(f"Model loaded from {path}")
            except FileNotFoundError:
                tqdm.write(f"[warn] load_path not found: {path} — using in-memory model.")

        dataloader = self.dataloader.test_loader if data_type == 'test' else self.dataloader.train_loader
        self.model.eval()
        window_scores = []

        for x in dataloader:
            # Flatten each window and move it to the model's device.
            x = x.view(x.size(0), -1).to(self.device)
            _, _, z, _ = self.model(x)
            # Uses the GMM parameters stored in the buffers during training.
            sample_energy, _ = self.model._compute_energy(z, average=False)
            window_scores.append(sample_energy.cpu().numpy())

        window_scores = np.concatenate(window_scores)

        win = self.dataloader.window_size
        windows_energy = np.repeat(window_scores[:, None, None], repeats=win, axis=1)

        scores = self.dataloader.unroll_windows(windows_energy, data_type)
        return scores


if __name__ == "__main__":
    # Example: train DAGMM on the SMD dataset and report test-set metrics.
    set_seed(42)

    dataset_name = 'SMD'
    loader_config, model_config, train_config = ModelConfig('DAGMM').resolve(dataset_name)

    # Build the windowed data loader from the resolved loader settings.
    loader_keys = ('window_size', 'step_size', 'batch_size')
    loader = TimeSeriesLoader(dataset_name=dataset_name,
                              **{key: loader_config[key] for key in loader_keys})

    # Construct the detector; the number of features comes from the training data.
    model_keys = ('hidden_dim', 'latent_dim', 'n_gmm', 'lambda_energy', 'lambda_cov_diag')
    detector = DAGMMAnomalyDetector(loader,
                                    loader.train_ds.data.shape[1],
                                    **{key: model_config[key] for key in model_keys})

    train_keys = ('epochs', 'learning_rate', 'weight_decay')
    detector.fit(data_type='train', **{key: train_config[key] for key in train_keys})

    anomaly_score = detector.predict_score(data_type='test')
    print(cal_metric(loader.test_labels, anomaly_score))