import os, sys
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm
from utils import *
from dataloader import TimeSeriesLoader


class PlanarFlow(nn.Module):
    """
    A single planar flow: computes ``T(z) = z + u * tanh(w^T z + b)`` and
    ``log|det(dT/dz)|``.

    Args:
        D: dimensionality of the vectors the flow acts on.
        eps: lower bound applied to ``|det(dT/dz)|`` before the log in
            :meth:`log_det`, so a singular Jacobian yields a large but finite
            penalty instead of ``-inf`` (whose backward pass is NaN).

    NOTE(review): ``u`` is not re-parameterised to enforce ``w^T u >= -1``
    (Rezende & Mohamed 2015, App. A.1), so ``T`` is not guaranteed to be
    invertible; ``eps`` only keeps the log-det finite.
    """
    # Class-level fallback so objects created without the attribute still work.
    eps = 1e-8

    def __init__(self, D, eps=1e-8):
        super(PlanarFlow, self).__init__()
        self.u = nn.Parameter(torch.Tensor(1, D), requires_grad=True)
        self.w = nn.Parameter(torch.Tensor(1, D), requires_grad=True)
        self.b = nn.Parameter(torch.Tensor(1), requires_grad=True)
        self.h = torch.tanh
        self.eps = eps
        self.init_params()

    def init_params(self):
        """Draw w, b and u uniformly from [-0.01, 0.01], so T starts close to the identity."""
        # Draw order (w, b, u) is kept: it determines seeded initial values.
        self.w.data.uniform_(-0.01, 0.01)
        self.b.data.uniform_(-0.01, 0.01)
        self.u.data.uniform_(-0.01, 0.01)

    def _linear(self, z):
        """Return ``w^T z + b`` for every row of the 2-D input ``z``; shape ``(N, 1)``."""
        return torch.mm(z, self.w.T) + self.b

    def forward(self, z):
        """Apply ``T(z) = z + u * h(w^T z + b)`` row-wise; the output has the shape of ``z``."""
        return z + self.u * self.h(self._linear(z))

    def h_prime(self, x):
        """
        Derivative of tanh: ``1 - tanh(x)^2``.
        """
        return (1 - self.h(x) ** 2)

    def psi(self, z):
        """Return ``h'(w^T z + b) * w``; shape ``(N, D)``."""
        return self.h_prime(self._linear(z)) * self.w

    def log_det(self, z):
        """
        Return ``log|1 + psi(z) u^T|`` for each row of ``z``; shape ``(N, 1)``.

        ``|det|`` is clamped from below at ``self.eps``. For any non-degenerate
        value the result is unchanged; a singular Jacobian gives ``log(eps)``
        and zero (finite) gradients instead of ``-inf`` and NaN gradients.
        """
        det = 1 + torch.mm(self.psi(z), self.u.T)
        return torch.log(torch.clamp(torch.abs(det), min=self.eps))


class NormalizingFlow(nn.Module):
    """
    A normalizing flow composed of a sequence of planar flows.

    Args:
        D: dimensionality of the vectors the flows act on.
        n_flows: number of PlanarFlow layers, applied in order.
    """
    def __init__(self, D, n_flows=2):
        super(NormalizingFlow, self).__init__()
        self.flows = nn.ModuleList([PlanarFlow(D) for _ in range(n_flows)])

    def sample(self, base_samples):
        """
        Transform samples from a simple base distribution by passing them
        through every planar flow in order.
        """
        samples = base_samples
        for flow in self.flows:
            samples = flow(samples)
        return samples

    def forward(self, x):
        """
        Apply all flows to ``x`` and accumulate their log-det-Jacobian terms.

        Args:
            x: 2-D tensor of shape ``(N, D)`` (the planar layers use ``torch.mm``).

        Returns:
            ``(T(x), sum_log_det)``. ``sum_log_det`` is always a tensor of shape
            ``(N, 1)`` on ``x``'s device and dtype. It is all zeros when the flow
            has no layers.
        """
        # Start from a tensor rather than the int 0, so the return type and
        # shape do not depend on how many layers there are.
        sum_log_det = x.new_zeros(x.shape[0], 1)
        transformed_sample = x
        for flow in self.flows:
            # Each layer's log-det is evaluated at that layer's *input*.
            sum_log_det = sum_log_det + flow.log_det(transformed_sample)
            transformed_sample = flow(transformed_sample)
        return transformed_sample, sum_log_det


class Encoder(nn.Module):
    """
    Recurrent inference network producing one latent sample per time step.

    At every step a GRU cell consumes the current input. Its hidden state is
    concatenated with the previous latent sample and passed through a ReLU
    dense layer. That layer feeds a linear head (mean) and a Softplus head
    (standard deviation). A latent is drawn with the reparameterisation trick
    and, when ``use_PNF`` is set, pushed through the NormalizingFlow belonging
    to that time step.

    Args:
        input_dim: features per time step of the input.
        hidden_dim: GRU hidden size.
        latent_dim: size of each latent vector.
        dense_dim: width of the ReLU layer feeding the mean/std heads.
        rolling_size: number of per-step flows created. With ``use_PNF`` the
            sequence length must not exceed this (flows are indexed by step).
        use_PNF: whether to apply planar normalizing flows to each latent.
        PNF_layers: planar layers per flow.
    """
    def __init__(self, input_dim, hidden_dim, latent_dim, dense_dim, rolling_size, use_PNF, PNF_layers):
        super(Encoder, self).__init__()
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim
        self.use_PNF = use_PNF

        # Submodule creation order fixes both the seeded parameter
        # initialisation and the state_dict layout, so it is preserved here.
        self.gru_cell = nn.GRUCell(input_size=input_dim, hidden_size=hidden_dim)
        if use_PNF:
            self.PNF = nn.ModuleList([NormalizingFlow(D=latent_dim, n_flows=PNF_layers)
                                      for _ in range(rolling_size)])
        self.phi_enc = nn.Sequential(nn.Linear(hidden_dim + latent_dim, dense_dim), nn.ReLU())
        self.enc_means = nn.Sequential(nn.Linear(dense_dim, latent_dim))
        self.enc_stds = nn.Sequential(nn.Linear(dense_dim, latent_dim), nn.Softplus())

    def reparameterized(self, mean, std):
        """Draw ``mean + std * eps`` with ``eps ~ N(0, I)`` shaped like ``std``."""
        return mean + std * torch.randn_like(std)

    def forward(self, x):
        """
        Encode a batch of sequences.

        Args:
            x: tensor indexed as ``x[:, t]`` per step; each slice is fed to a
               GRUCell of width ``input_dim`` (so presumably ``(B, T, input_dim)``).

        Returns:
            ``(zs, z_means, z_stds, sum_logabsdet)``. The first three are stacked
            along dim 1. ``sum_logabsdet`` is a 0-dim zero tensor without flows;
            otherwise it is the sum over time of each step's flow log-det.
        """
        batch = x.shape[0]
        hidden = torch.zeros(batch, self.hidden_dim, device=x.device)
        prev_z = torch.zeros(batch, self.latent_dim, device=x.device)

        samples, means, stds = [], [], []
        total_logdet = x.new_zeros(())

        for step, step_input in enumerate(x.unbind(dim=1)):
            hidden = self.gru_cell(step_input, hidden)
            features = self.phi_enc(torch.cat([hidden, prev_z], dim=1))
            mu = self.enc_means(features)
            sigma = self.enc_stds(features)
            sample = self.reparameterized(mean=mu, std=sigma)

            if self.use_PNF:
                sample, step_logdet = self.PNF[step](sample)
                total_logdet = total_logdet + step_logdet

            samples.append(sample)
            means.append(mu)
            stds.append(sigma)
            # The (possibly flow-transformed) sample conditions the next step.
            prev_z = sample

        return (torch.stack(samples, dim=1),
                torch.stack(means, dim=1),
                torch.stack(stds, dim=1),
                total_logdet)


class Decoder(nn.Module):
    """
    Recurrent generative network mapping a latent sequence to Gaussian outputs.

    At every step a GRU cell consumes the current latent. Its hidden state goes
    through a ReLU dense layer, which feeds a Sigmoid-bounded mean head and a
    Softplus-positive standard-deviation head. A reconstruction is sampled with
    the reparameterisation trick.

    Args:
        input_dim: number of observed features to reconstruct.
        hidden_dim: GRU hidden size.
        latent_dim: size of each latent vector fed to the GRU.
        dense_dim: width of the ReLU layer feeding the mean/std heads.
    """
    def __init__(self, input_dim, hidden_dim, latent_dim, dense_dim):
        super(Decoder, self).__init__()
        self.hidden_dim = hidden_dim

        # Creation order is kept: it fixes seeded init and state_dict layout.
        self.gru_cell = nn.GRUCell(input_size=latent_dim, hidden_size=hidden_dim)
        self.phi_dec = nn.Sequential(nn.Linear(hidden_dim, dense_dim), nn.ReLU())
        self.dec_means = nn.Sequential(nn.Linear(dense_dim, input_dim), nn.Sigmoid())
        self.dec_stds = nn.Sequential(nn.Linear(dense_dim, input_dim), nn.Softplus())

    def reparameterized(self, mean, std):
        """Draw ``mean + std * eps`` with ``eps ~ N(0, I)`` shaped like ``std``."""
        return mean + std * torch.randn_like(std)

    def forward(self, x):
        """
        Decode a latent sequence.

        Args:
            x: latent tensor indexed as ``x[:, t]`` per step; each slice is fed
               to a GRUCell of width ``latent_dim``.

        Returns:
            ``(x_hats, x_means, x_stds)``, each stacked along dim 1. ``x_hats``
            are random samples; ``x_means`` lie in (0, 1) because of the Sigmoid.
        """
        hidden = torch.zeros(x.shape[0], self.hidden_dim, device=x.device)
        samples, means, stds = [], [], []

        for step_input in x.unbind(dim=1):
            hidden = self.gru_cell(step_input, hidden)
            features = self.phi_dec(hidden)
            mu = self.dec_means(features)
            sigma = self.dec_stds(features)
            means.append(mu)
            stds.append(sigma)
            samples.append(self.reparameterized(mean=mu, std=sigma))

        return (torch.stack(samples, dim=1),
                torch.stack(means, dim=1),
                torch.stack(stds, dim=1))


class OmniAnomaly(nn.Module):
    """
    Recurrent VAE with optional per-time-step planar normalizing flows on the latents.

    Encoder and decoder share ``hidden_dim`` as their dense-layer width.

    Args:
        input_dim: number of observed features per time step.
        hidden_dim: GRU hidden size (also the dense width).
        latent_dim: latent size per time step.
        rolling_size: window length; one flow is created per step.
        use_PNF: whether the encoder applies planar normalizing flows.
        PNF_layers: planar layers per flow.
    """
    def __init__(self, input_dim, hidden_dim, latent_dim, rolling_size, use_PNF=True, PNF_layers=20):
        super().__init__()
        self.encoder = Encoder(input_dim=input_dim, hidden_dim=hidden_dim, dense_dim=hidden_dim, latent_dim=latent_dim,
                               rolling_size=rolling_size, use_PNF=use_PNF, PNF_layers=PNF_layers)
        self.decoder = Decoder(input_dim=input_dim, hidden_dim=hidden_dim, dense_dim=hidden_dim, latent_dim=latent_dim)

    def forward(self, x):
        """Return ``(encoder outputs, decoder outputs)``; the decoder consumes the sampled (post-flow) latents."""
        z, z_mean, z_std, sum_logabsdet = self.encoder(x)
        x_hat, x_mean, x_std = self.decoder(z)

        return (z, z_mean, z_std, sum_logabsdet), (x_hat, x_mean, x_std)

    def loss(self, x, enc_out, dec_out):
        """
        Negative ELBO, summed over every element of the batch.

        Without flows the KL term is the closed-form KL(N(mean, std^2) || N(0, I)).
        With flows the latent actually used is z_K = T(z_0). Then

            KL(q_K || p) = E[log q_0(z_0)] - sum log|det J| - E[log p(z_K)],

        so the prior must be evaluated at the *post-flow* sample ``z``, not at
        the base Gaussian. The entropy part uses the closed form from ``z_std``.

        Returns:
            ``(loss, nll_loss, kld_loss)`` as scalar tensors.
        """
        z, z_mean, z_std, sum_logabsdet = enc_out
        _, x_mean, x_std = dec_out

        nll_loss = self.nll_gaussian(x, x_mean, x_std)

        if self.encoder.use_PNF:
            kld_base = self.kld_flow_gaussian(z, z_std)
        else:
            kld_base = self.kld_diag_gaussian(z_mean, z_std)
        # sum_logabsdet is a zero scalar when no flows are used.
        kld_loss = kld_base - sum_logabsdet.sum()

        loss = nll_loss + kld_loss
        return loss, nll_loss, kld_loss

    def nll_gaussian(self, x, x_mean, x_std):
        """Gaussian negative log-likelihood of ``x``, summed over all elements."""
        var = x_std.pow(2) + 1e-8
        return 0.5 * torch.sum(torch.log(2*torch.pi*var) + (x - x_mean).pow(2) / var)

    def kld_diag_gaussian(self, z_mean, z_std):
        """Closed-form KL(N(z_mean, z_std^2) || N(0, I)), summed over all elements."""
        var = z_std.pow(2) + 1e-6
        return 0.5 * torch.sum(z_mean.pow(2) + var - torch.log(var) - 1.0)

    def kld_flow_gaussian(self, z, z_std):
        """
        KL(q_K || N(0, I)) excluding the log-det term, summed over all elements.

        ``E[log q_0]`` uses the closed form from the base ``z_std``. ``-log p``
        is a single-sample estimate at the post-flow latent ``z``. The
        ``log(2*pi)`` terms of the two parts cancel.
        """
        var = z_std.pow(2) + 1e-6
        return 0.5 * torch.sum(z.pow(2) - torch.log(var) - 1.0)


class OmniAnomalyDetector(nn.Module):
    """
    Training and scoring wrapper around an ``OmniAnomaly`` model.

    Args:
        dataloader: project loader object. This class reads its ``train_loader``,
            ``test_loader``, ``window_size`` and ``dataset_name`` attributes and
            calls its ``unroll_windows`` method.
        input_dim, hidden_dim, latent_dim, use_PNF, PNF_layers: forwarded to
            ``OmniAnomaly``. ``rolling_size`` is taken from ``dataloader.window_size``.
        device: torch device; defaults to CUDA when available, else CPU.
    """
    def __init__(self, dataloader, input_dim, hidden_dim, latent_dim, use_PNF=True, PNF_layers=20, device=None):
        super().__init__()
        self.dataloader = dataloader
        self.device = device if device else torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = OmniAnomaly(input_dim, hidden_dim, latent_dim,
                                 rolling_size=self.dataloader.window_size, use_PNF=use_PNF, PNF_layers=PNF_layers).to(self.device)

    def _get_loader(self, data_type):
        """Return the train loader for ``'train'``; any other value selects the test loader."""
        return self.dataloader.train_loader if data_type == 'train' else self.dataloader.test_loader

    def fit(self, epochs=30, learning_rate=1e-4, weight_decay=1e-4, clip_norm=10.0, early_stopping=5, data_type='train', save=False, save_path=None):
        """
        Train with Adam and gradient-norm clipping, early-stopping on the mean epoch loss.

        Args:
            epochs: maximum number of epochs.
            learning_rate, weight_decay: Adam hyper-parameters.
            clip_norm: max global gradient norm.
            early_stopping: epochs without improvement before stopping.
            data_type: ``'train'`` selects the train loader; anything else the test loader.
            save, save_path: if either is truthy, the best model so far is saved.
                When no path is given, one is built via ``build_save_path``.

        Raises:
            ValueError: if the selected loader yields no batches.
        """
        dataloader = self._get_loader(data_type)
        optimizer = torch.optim.Adam(self.model.parameters(), lr=learning_rate, weight_decay=weight_decay)
        self.model.train()

        best_val, patience = float('inf'), 0
        pbar = tqdm(range(epochs), desc='Training OmniAnomaly', leave=True)
        for epoch in pbar:
            total_loss, total_nll_loss, total_kld_loss = 0.0, 0.0, 0.0
            n_batches = 0
            for x in tqdm(dataloader, desc='Inner loop', leave=False):
                # The model lives on self.device; batches may arrive on CPU.
                x = x.to(self.device)
                optimizer.zero_grad()

                enc_out, dec_out = self.model(x)
                loss, nll_loss, kld_loss = self.model.loss(x, enc_out, dec_out)

                loss.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(), clip_norm)
                optimizer.step()

                total_loss += loss.item()
                total_nll_loss += nll_loss.item()
                total_kld_loss += kld_loss.item()
                n_batches += 1

            if n_batches == 0:
                raise ValueError(f"The '{data_type}' loader yielded no batches; nothing to train on.")

            avg_total_loss = total_loss / n_batches
            avg_total_nll_loss = total_nll_loss / n_batches
            avg_total_kld_loss = total_kld_loss / n_batches

            tqdm.write(f'Epoch {epoch+1}/{epochs}, Loss: {avg_total_loss:.4f}, NLL Loss: {avg_total_nll_loss:.4f}, KLD Loss: {avg_total_kld_loss:.4f}')

            if avg_total_loss < best_val:
                best_val, patience = avg_total_loss, 0
                if save or save_path:
                    if save_path is None:
                        save_path = build_save_path(model_name='OmniAnomaly', dataset_name=self.dataloader.dataset_name, seed=get_global_seed())
                    # Remembered so predict_score can reload the best checkpoint.
                    self.save, self.save_path = save, save_path
                    save_model(self.model, save_path)
                    tqdm.write(f"Model saved to {save_path}")
            else:
                patience += 1
                if patience >= early_stopping:
                    tqdm.write(f'Early stopping at epoch {epoch+1}, best loss: {best_val:.4f}')
                    break

    @torch.no_grad()
    def predict_score(self, data_type='test', load_path=None):
        """
        Compute per-window anomaly scores as the L1 reconstruction error per time step.

        Weights are loaded first from ``load_path`` if given, otherwise from the
        ``save_path`` recorded by :meth:`fit`. A missing file only logs a warning
        and the in-memory model is used.

        NOTE(review): the score uses the decoder's *sampled* reconstruction
        ``x_hat`` (drawn with ``torch.randn_like``), so scores vary between calls
        unless the RNG is seeded.

        Returns:
            ``unroll_windows(...)`` applied to the concatenated scores, squeezed.
            Presumably a per-time-step numpy array; confirm against the loader.
        """
        path = None
        if load_path is not None:
            path = load_path
        elif getattr(self, 'save_path', None):
            path = self.save_path

        if path:
            try:
                load_model(self.model, path)
                self.model.to(self.device)
                tqdm.write(f"Model loaded from {path}")
            except FileNotFoundError:
                tqdm.write(f"[warn] load_path not found: {path} — using in-memory model.")

        dataloader = self._get_loader(data_type)
        self.model.eval()
        scores = []

        for x in dataloader:
            # Match the model's device; the error is moved back to CPU below.
            x = x.to(self.device)
            (_, _, _, _), (batch_x_reconstruct, _, _) = self.model(x)

            error = torch.abs(batch_x_reconstruct - x).sum(dim=-1, keepdim=True)
            scores.append(error.cpu().numpy())

        scores = self.dataloader.unroll_windows(np.concatenate(scores), data_type=data_type)
        return scores.squeeze()


if __name__ == "__main__":
    # Reproducible end-to-end run: train OmniAnomaly on SMD, score the test
    # split, and print the evaluation metrics.
    set_seed(42)

    dataset_name = 'SMD'
    loader_config, model_config, train_config = ModelConfig('OmniAnomaly').resolve(dataset_name)

    loader = TimeSeriesLoader(
        dataset_name=dataset_name,
        **{key: loader_config[key] for key in ('window_size', 'step_size', 'batch_size')},
    )

    # Feature count is taken from the second axis of the training data.
    detector = OmniAnomalyDetector(
        loader,
        loader.train_ds.data.shape[1],
        **{key: model_config[key] for key in ('hidden_dim', 'latent_dim', 'use_PNF', 'PNF_layers')},
    )

    fit_kwargs = {key: train_config[key]
                  for key in ('epochs', 'learning_rate', 'weight_decay', 'clip_norm', 'early_stopping')}
    detector.fit(data_type='train', **fit_kwargs)

    anomaly_scores = detector.predict_score(data_type='test')
    metrics = cal_metric(loader.test_ds.labels, anomaly_scores)
    print(f"metrics: {metrics}")