import os, sys
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from tqdm import tqdm

from dataloader import TimeSeriesLoader
from utils import *


class LSTMVAE(nn.Module):
    """Denoising LSTM variational autoencoder with a Gaussian output head.

    Both LSTMs are built with ``batch_first=True``, so inputs are expected as
    ``(batch, seq_len, input_dim)``. A latent vector is produced for every
    time step (the linear heads are applied to the full LSTM output
    sequence), and the decoder emits a per-step mean and scale for each
    input feature.

    Args:
        input_dim: Number of features per time step.
        hidden_dim: Hidden size shared by the encoder and decoder LSTMs.
        latent_dim: Size of the per-step latent vector.
        n_layers: Number of stacked LSTM layers in encoder and decoder.
        noise_std: Standard deviation of the Gaussian corruption noise added
            to the input while the module is in training mode.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim, n_layers=1, noise_std=0.1):
        super().__init__()
        self.noise_std = noise_std

        # Encoder: LSTM -> Softplus -> (mu_z, logvar_z)
        self.encoder_lstm = nn.LSTM(input_dim, hidden_dim, n_layers, batch_first=True)
        self.encoder_act = nn.Softplus()

        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)

        # Decoder: LSTM -> Softplus -> (mu_x, sigma_x)
        self.decoder_lstm = nn.LSTM(latent_dim, hidden_dim, n_layers, batch_first=True)
        self.decoder_act = nn.Softplus()

        self.fc_out_mu = nn.Linear(hidden_dim, input_dim)
        # Tanh bounds the raw scale to (-1, 1); it may be negative or zero,
        # so consumers must square it and add a small floor before use.
        self.fc_out_sigma = nn.Sequential(nn.Linear(hidden_dim, input_dim),
                                          nn.Tanh())

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        """Encode ``x``, sample latents, and decode output distribution params.

        Args:
            x: Input tensor laid out as ``(batch, seq_len, input_dim)``.

        Returns:
            Tuple ``(mu_x, sigma_x, mu_z, logvar_z)``: the decoder mean and raw
            scale (same trailing size as ``x``) and the encoder mean and
            log-variance (trailing size ``latent_dim``).
        """
        # 1) Denoising: the corruption noise is a training-time regulariser.
        #    Injecting it in eval mode would perturb the encoder input and make
        #    inference-time outputs depend on extra random draws.
        if self.training:
            x_in = x + torch.randn_like(x) * self.noise_std
        else:
            x_in = x

        # 2) Encode
        enc_out, _ = self.encoder_lstm(x_in)
        enc_out = self.encoder_act(enc_out)
        mu_z = self.fc_mu(enc_out)
        logvar_z = self.fc_logvar(enc_out)
        z = self.reparameterize(mu_z, logvar_z)

        # 3) Decode
        dec_out, _ = self.decoder_lstm(z)
        dec_out = self.decoder_act(dec_out)
        mu_x = self.fc_out_mu(dec_out)
        sigma_x = self.fc_out_sigma(dec_out)

        return mu_x, sigma_x, mu_z, logvar_z

    def loss(self, x, mu_x, sigma_x, mu_z, logvar_z, kld_coef=0.1):
        """Compute the weighted negative ELBO averaged over the batch.

        Args:
            x: Clean target tensor, rank 3 (summed over dims 1 and 2).
            mu_x: Decoder mean, same shape as ``x``.
            sigma_x: Decoder raw scale, same shape as ``x``; it is squared here.
            mu_z: Encoder mean.
            logvar_z: Encoder log-variance.
            kld_coef: Weight applied to the KL term.

        Returns:
            Tuple ``(total, recon_mean, kld_mean)`` of scalar tensors, where
            ``total = recon_mean + kld_coef * kld_mean``.
        """
        # Floor the variance so log() and the division stay finite when the
        # Tanh head outputs (near) zero.
        var_x = sigma_x.pow(2) + 1e-8
        # Gaussian negative log-likelihood, summed over time and features.
        recon = 0.5 * (torch.log(var_x) + (x - mu_x).pow(2) / var_x + np.log(2*np.pi)).sum(dim=[1,2])
        # Closed-form KL(N(mu_z, exp(logvar_z)) || N(0, I)), summed likewise.
        kld = -0.5 * (1 + logvar_z - mu_z.pow(2) - logvar_z.exp()).sum(dim=[1,2])

        recon_mean = recon.mean()
        kld_mean = kld.mean()
        total = recon_mean + kld_coef * kld_mean
        return total, recon_mean, kld_mean


class LSTMVAEAnomalyDetector(nn.Module):
    """Train an ``LSTMVAE`` and convert its reconstruction NLL into scores.

    Args:
        dataloader: Project loader object; this class reads its
            ``train_loader``, ``test_loader`` and ``dataset_name`` attributes
            and calls its ``unroll_windows`` method.
        input_dim: Number of features per time step.
        hidden_dim: LSTM hidden size.
        latent_dim: Latent size.
        noise_std: Standard deviation of the denoising input noise.
        n_layers: Number of stacked LSTM layers.
        kld_coef: Weight of the KL term in the training loss.
        device: Device to run on; when falsy, CUDA is used if available,
            otherwise CPU.
    """

    def __init__(self, dataloader, input_dim, hidden_dim, latent_dim, noise_std=0.1, n_layers=1, kld_coef=0.1, device=None):
        super().__init__()
        self.dataloader = dataloader
        self.kld_coef = kld_coef

        self.device = device if device else torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = LSTMVAE(input_dim, hidden_dim, latent_dim, n_layers, noise_std).to(self.device)

    def _select_loader(self, data_type):
        """Return the train loader for ``'train'``, otherwise the test loader."""
        if data_type == 'train':
            return self.dataloader.train_loader
        return self.dataloader.test_loader

    def fit(self, epochs=50, learning_rate=1e-5, data_type='train', save=False, save_path=None):
        """Optimise the model with Adam and optionally persist the weights.

        Args:
            epochs: Number of passes over the selected loader.
            learning_rate: Adam learning rate.
            data_type: ``'train'`` selects ``train_loader``; any other value
                selects ``test_loader``.
            save: When true, save the model after training.
            save_path: Destination for the saved model. Supplying a path
                implies saving; when saving without a path, one is derived via
                ``build_save_path``.
        """
        dataloader = self._select_loader(data_type)
        optimizer = optim.Adam(self.model.parameters(), lr=learning_rate)
        self.model.train()

        pbar = tqdm(range(epochs), desc='Training LSTM-VAE', leave=True)
        for epoch in pbar:
            total_loss, total_recon, total_kld = 0., 0., 0.
            n_seen = 0
            ibar = tqdm(dataloader, desc='Inner loop', leave=False)
            for x in ibar:
                # The model lives on self.device; batches must follow it or
                # the forward pass fails with a device mismatch on CUDA.
                x = x.to(self.device)
                batch_size = x.size(0)

                optimizer.zero_grad()

                mu_x, sigma_x, mu_z, logvar_z = self.model(x)
                loss, recon, kld = self.model.loss(x, mu_x, sigma_x, mu_z, logvar_z, self.kld_coef)

                loss.backward()
                optimizer.step()

                # Losses are batch means; re-weight by batch size so the epoch
                # figure is a per-sample average.
                total_loss += loss.item() * batch_size
                total_recon += recon.item() * batch_size
                total_kld += kld.item() * batch_size
                n_seen += batch_size

            # Divide by the samples actually processed, which stays correct
            # if the loader skips an incomplete final batch; guard against an
            # empty loader.
            denom = max(n_seen, 1)
            avg_total_loss = total_loss / denom
            avg_total_recon = total_recon / denom
            avg_total_kld = total_kld / denom
            tqdm.write(f'Epoch {epoch+1}/{epochs}, Loss: {avg_total_loss:.4f}, Recon: {avg_total_recon:.4f}, KLD: {avg_total_kld:.4f}')

        if save or save_path:
            if save_path is None:
                save_path = build_save_path(model_name='lstmVAE', dataset_name=self.dataloader.dataset_name, seed=get_global_seed())
            # Remembered so predict_score can reload the same checkpoint.
            self.save, self.save_path = save, save_path
            save_model(self.model, save_path)
            tqdm.write(f"Model saved to {save_path}")

    @torch.no_grad()
    def predict_score(self, data_type='test', load_path=None):
        """Score every time step by its Gaussian reconstruction NLL.

        Args:
            data_type: ``'train'`` scores ``train_loader``; any other value
                scores ``test_loader``.
            load_path: Optional checkpoint to load first. When omitted, the
                path remembered by a previous saving ``fit`` call is used if
                there is one. A missing file only emits a warning and the
                in-memory weights are used.

        Returns:
            The array returned by ``dataloader.unroll_windows`` with
            singleton dimensions squeezed out.
        """
        path = load_path if load_path is not None else getattr(self, 'save_path', None)

        if path:
            try:
                load_model(self.model, path)
                self.model.to(self.device)
                tqdm.write(f"Model loaded from {path}")
            except FileNotFoundError:
                tqdm.write(f"[warn] load_path not found: {path} — using in-memory model.")

        dataloader = self._select_loader(data_type)
        self.model.eval()
        window_scores = []

        for x in dataloader:
            x = x.to(self.device)
            mu_x, sigma_x, _, _ = self.model(x)
            # Same variance floor as the training loss: the Tanh head can emit
            # zero, which would otherwise yield log(0) and a division by zero.
            var_x = sigma_x.pow(2) + 1e-8

            rec = 0.5 * (torch.log(var_x) + (x - mu_x).pow(2) / var_x + np.log(2*np.pi))
            # Sum over the feature axis only, keeping one score per time step.
            rec = rec.sum(dim=2)
            window_scores.append(rec.cpu().numpy())

        window_scores = np.concatenate(window_scores, axis=0)

        # Unroll with the split that was actually scored, not always 'test'.
        # NOTE(review): assumes unroll_windows accepts data_type='train' — confirm.
        scores = self.dataloader.unroll_windows(windows=np.expand_dims(window_scores, axis=-1), data_type=data_type)
        return scores.squeeze()




if __name__ == '__main__':
    # Script entry point: train the LSTM-VAE detector on one dataset, score
    # the test split, and print the evaluation metrics.
    set_seed(42)

    dataset_name = 'SMD'

    # Resolve the loader / model / training settings for this dataset.
    cfg = ModelConfig('lstmVAE')
    loader_config, model_config, train_config = cfg.resolve(dataset_name)

    loader_kwargs = {
        key: loader_config[key]
        for key in ('window_size', 'step_size', 'batch_size')
    }
    loader = TimeSeriesLoader(dataset_name=dataset_name, **loader_kwargs)

    # Feature count is taken from the second axis of the training data.
    input_dim = loader.train_ds.data.shape[1]

    detector_kwargs = {
        key: model_config[key]
        for key in ('hidden_dim', 'latent_dim', 'noise_std', 'n_layers', 'kld_coef')
    }
    model = LSTMVAEAnomalyDetector(loader, input_dim, **detector_kwargs)

    fit_kwargs = {
        'epochs': train_config['epochs'],
        'learning_rate': train_config['learning_rate'],
        'data_type': 'train',
    }
    model.fit(**fit_kwargs)

    # Score the test split and compare against its labels.
    anomaly_score = model.predict_score(data_type='test')
    y_label = loader.test_labels

    metrics = cal_metric(y_label, anomaly_score)
    print(f"metrics: {metrics}")