import os, sys
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

from utils import *
from dataloader import TimeSeriesLoader


class LSTMEncoder(nn.Module):
    """Batch-first LSTM that encodes a batch of windows.

    ``forward`` views its input as ``(batch, seq_len, input_dim)``, runs it
    through ``num_layers`` stacked LSTM layers and returns the full output
    sequence together with the final ``(h_n, c_n)`` state. The final state is
    also kept on ``self.hidden``.
    """

    def __init__(self, input_dim, hidden_dim, num_layers=1):
        super().__init__()
        self.input_dim, self.hidden_dim, self.num_layers = input_dim, hidden_dim, num_layers
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True)

    def forward(self, x):
        batch, steps = x.shape[0], x.shape[1]
        seq = x.view(batch, steps, self.input_dim)
        out, state = self.lstm(seq)
        # Cache the last (h_n, c_n) on the module, as callers may read it.
        self.hidden = state
        return out, state


class LSTMDecoder(nn.Module):
    """Single-step LSTM decoder projecting its output back to input space.

    Each ``forward`` call advances the (sequence-first) LSTM by exactly one
    time step from the supplied ``(h, c)`` state and maps the LSTM output
    through a linear layer to ``input_dim`` features.
    """

    def __init__(self, input_dim, hidden_dim, num_layers=1):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        # Not batch_first: forward() feeds a length-1 sequence along dim 0.
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers)
        self.linear = nn.Linear(hidden_dim, input_dim)

    def forward(self, x, encoder_hidden):
        step = x[None]  # (1, batch, input_dim)
        lstm_out, hidden = self.lstm(step, encoder_hidden)
        return self.linear(lstm_out[0]), hidden


class LSTM_AE(nn.Module):
    """LSTM encoder-decoder that reconstructs fixed-length windows.

    The encoder consumes a whole window. The decoder is then unrolled for the
    window length, seeded with the window's last time step and the encoder's
    final state. After ``fit_error_distribution`` has been called, the model
    also holds the mean (``error_mu``) and inverse covariance
    (``error_cov_inv``) of absolute reconstruction errors, used for
    Mahalanobis scoring.
    """

    # Decoding strategies accepted by ``reconstruct`` / ``loss``.
    _MODES = ('recursive', 'teacher_forcing', 'mixed_teacher_forcing')

    def __init__(self, input_dim, hidden_dim):
        super(LSTM_AE, self).__init__()
        self.input_dim = input_dim
        self.encoder = LSTMEncoder(input_dim=input_dim, hidden_dim=hidden_dim)
        self.decoder = LSTMDecoder(input_dim=input_dim, hidden_dim=hidden_dim)

        self.criterion = nn.MSELoss()

    def forward(self, x):
        """Encode ``x`` and run a single decoder step.

        Returns the decoder's first-step prediction (``(batch, input_dim)``),
        seeded with the last observed step ``x[:, -1, :]``.
        """
        _, h = self.encoder(x)
        dec_input = x[:, -1, :]
        x_hat, _ = self.decoder(dec_input, h)
        return x_hat

    def loss(self, x, mode='teacher_forcing', tf_ratio=0.5, reduction='mean'):
        """Squared reconstruction error of ``x``.

        :param x: windows of shape (batch, W, F).
        :param mode: decoding strategy, see ``reconstruct``.
        :param tf_ratio: teacher-forcing probability, see ``reconstruct``.
        :param reduction: ``'mean'`` averages the squared error over all
            elements; ``'sum'`` adds it up.
        :raises ValueError: on an unknown ``reduction`` or ``mode``.
        """
        if reduction not in ('mean', 'sum'):
            raise ValueError("Invalid reduction mode")
        x_hat = self.reconstruct(x, mode=mode, tf_ratio=tf_ratio)
        if reduction == 'mean':
            return self.criterion(x_hat, x)
        # self.criterion already averages to a scalar, so summing its output
        # would still give the mean; compute the true sum explicitly.
        return nn.functional.mse_loss(x_hat, x, reduction='sum')

    def reconstruct(self, x, mode='teacher_forcing', tf_ratio=0.5):
        """Decode a full window from the encoder state.

        The decoder is seeded with ``x[:, -1, :]`` and the encoder's final
        state, then unrolled ``W`` times. Step ``t``'s output is written to
        ``x_hat[:, t, :]``. The next decoder input depends on ``mode``:

        * ``'recursive'``: always the decoder's own prediction.
        * ``'teacher_forcing'``: one draw per call (true with probability
          ``tf_ratio``) decides whether the whole window is fed the ground
          truth ``x[:, t, :]`` or the predictions.
        * ``'mixed_teacher_forcing'``: a fresh draw at every step.

        :param x: windows of shape (batch, W, F).
        :return: tensor of shape (batch, W, F) on ``x.device``.
        :raises ValueError: if ``mode`` is not one of the three above.
        """
        if mode not in self._MODES:
            raise ValueError("mode must be one of {'recursive','teacher_forcing','mixed_teacher_forcing'}")

        _, h = self.encoder(x)
        B, W, F = x.size()
        dec_input = x[:, -1, :]
        x_hat = torch.zeros(B, W, F, device=x.device)

        # Whole-window teacher-forcing decision, drawn once before decoding.
        use_tf = mode == 'teacher_forcing' and np.random.rand() < tf_ratio

        for t in range(W):
            dec_out, h = self.decoder(dec_input, h)
            x_hat[:, t, :] = dec_out
            if mode == 'mixed_teacher_forcing':
                feed_truth = np.random.rand() < tf_ratio
            else:
                feed_truth = use_tf
            dec_input = x[:, t, :] if feed_truth else dec_out
        return x_hat

    @torch.no_grad()
    def predict_series(self, dataloader, data_type='test'):
        """Recursively reconstruct every window of a split and unroll them.

        :param dataloader: object exposing ``train_loader``/``test_loader`` and
            ``unroll_windows``.
        :param data_type: ``'train'`` selects the training loader; anything
            else selects the test loader.
        :return: result of ``dataloader.unroll_windows`` on the stacked
            ``(n_windows, W, F)`` numpy reconstructions.
        """
        loader = dataloader.train_loader if data_type == 'train' else dataloader.test_loader
        # Batches may not live on the model's device; move them there.
        device = next(self.parameters()).device
        self.eval()
        windows = [self.reconstruct(x.to(device), mode='recursive') for x in loader]
        all_windows = torch.cat(windows, dim=0).cpu().numpy()
        return dataloader.unroll_windows(all_windows, data_type=data_type)

    def fit_error_distribution(self, x, x_hat, eps=1e-6):
        """Store mean and inverse covariance of ``|x - x_hat|`` per feature.

        :param x: ground-truth series, shape (L, F).
        :param x_hat: reconstruction, same shape as ``x``.
        :param eps: ridge added to the covariance diagonal so it is invertible.
        """
        errors = np.abs(x - x_hat)
        self.error_mu = np.mean(errors, axis=0)
        # np.cov squeezes its result, so a single feature yields a 0-d array;
        # atleast_2d restores the (F, F) shape that np.eye / inv expect.
        cov = np.atleast_2d(np.cov(errors, rowvar=False))
        cov = cov + np.eye(cov.shape[0]) * eps
        self.error_cov_inv = np.linalg.inv(cov)

    def anomaly_score(self, x, x_hat, method='mahalanobis', dataloader=None, data_type='train'):
        """Per-time-step anomaly score.

        :param x: ground-truth series, shape (L, F).
        :param x_hat: reconstruction, same shape as ``x``.
        :param method: ``'naive'`` is the sigmoid of the mean absolute error
            per row. ``'mahalanobis'`` is the Mahalanobis distance of the
            absolute error from the fitted error distribution.
        :param dataloader: used only to fit the error distribution when it
            has not been fitted yet.
        :param data_type: split on which to fit that distribution.
        :raises RuntimeError: if no distribution is fitted and no
            ``dataloader`` is given.
        :raises ValueError: on an unknown ``method``.
        """
        if method == 'naive':
            err = np.abs(x - x_hat).mean(axis=1)
            return 1.0 / (1.0 + np.exp(-err))

        if method == 'mahalanobis':
            if not (hasattr(self, 'error_mu') and hasattr(self, 'error_cov_inv')):
                if dataloader is None:
                    raise RuntimeError("Error distribution not fitted and no dataloader provided to fit on train.")
                # Ground truth must come from the same split as the prediction.
                fit_ds = dataloader.train_ds if data_type == 'train' else dataloader.test_ds
                fit_true = fit_ds.data.reshape(-1, self.input_dim)
                fit_pred = self.predict_series(dataloader, data_type=data_type)
                self.fit_error_distribution(fit_true, fit_pred)

            diff = np.abs(x - x_hat) - self.error_mu
            sq_dist = np.sum(diff.dot(self.error_cov_inv) * diff, axis=1)
            # Clip tiny negative rounding residue so sqrt never yields NaN.
            return np.sqrt(np.maximum(sq_dist, 0.0))

        raise ValueError("method must be one of {'naive','mahalanobis'}")


class LSTMAEAnomalyDetector(nn.Module):
    """Training and scoring wrapper around ``LSTM_AE``.

    :param dataloader: object exposing ``train_loader``/``test_loader``,
        ``train_ds``/``test_ds`` (each with ``.data``), ``unroll_windows`` and
        ``dataset_name``.
    :param input_dim: number of features per time step.
    :param hidden_dim: LSTM hidden size.
    :param training_mode: decoding strategy passed to ``LSTM_AE.loss``.
    :param tf_ratio: initial teacher-forcing probability.
    :param dynamic_tf: if True, lower ``tf_ratio`` by ``_TF_DECAY`` after every
        epoch, never below 0.
    :param device: torch device; defaults to CUDA when available, else CPU.
    """

    # Per-epoch decrement of tf_ratio when dynamic_tf is enabled.
    _TF_DECAY = 0.02

    def __init__(self, dataloader, input_dim, hidden_dim, training_mode='teacher_forcing', tf_ratio=0.5, dynamic_tf=False, device=None):
        super().__init__()
        self.dataloader = dataloader
        self.training_mode = training_mode
        self.tf_ratio = tf_ratio
        self.dynamic_tf = dynamic_tf
        self.input_dim = input_dim

        self.device = device if device else torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = LSTM_AE(input_dim=input_dim, hidden_dim=hidden_dim).to(self.device)

    def _split_for(self, data_type):
        """Return (window loader, dataset) for 'train', or for the test split otherwise."""
        if data_type == 'train':
            return self.dataloader.train_loader, self.dataloader.train_ds
        return self.dataloader.test_loader, self.dataloader.test_ds

    def _true_series(self, data_type):
        """Return the split's raw data reshaped to (-1, input_dim)."""
        return self._split_for(data_type)[1].data.reshape(-1, self.input_dim)

    def _refit_error_distribution(self, data_type='train'):
        """Fit the model's error distribution on one split, with truth and prediction from that same split."""
        self.model.fit_error_distribution(self._true_series(data_type), self._predict(data_type=data_type))

    def fit(self, epochs=50, learning_rate=1e-4, data_type='train', save=False, save_path=None):
        """Train the autoencoder with Adam, then fit its error distribution.

        :param epochs: number of passes over the chosen split.
        :param learning_rate: Adam learning rate.
        :param data_type: ``'train'`` trains on the training split, anything
            else on the test split. The error distribution is fitted on the
            same split.
        :param save: save the model after training.
        :param save_path: where to save. Giving a path also triggers saving;
            if omitted while ``save`` is True, a path is built from the
            dataset name and global seed.
        """
        loader, _ = self._split_for(data_type)
        optimizer = optim.Adam(self.model.parameters(), lr=learning_rate)
        self.model.train()

        pbar = tqdm(range(epochs), desc = 'Training LSTM-AE', leave=True)
        for epoch in pbar:
            total_loss = 0.0
            n_batches = 0
            for x in tqdm(loader, desc='Inner loop', leave=False):
                # Mirror _predict: run each batch on the model's device.
                x = x.to(self.device)
                optimizer.zero_grad()

                loss = self.model.loss(x, mode=self.training_mode, tf_ratio=self.tf_ratio, reduction='mean')

                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                n_batches += 1

            # Each loss is already a per-batch mean, so average over batches
            # (not samples); max() guards an empty loader.
            avg_loss = total_loss / max(n_batches, 1)

            # dynamic teacher forcing
            if self.dynamic_tf and self.tf_ratio > 0:
                self.tf_ratio = max(0.0, self.tf_ratio - self._TF_DECAY)

            tqdm.write(f'Epoch {epoch + 1}/{epochs}, Loss: {avg_loss:.4f}')

        self._refit_error_distribution(data_type)

        if save or save_path:
            if save_path is None:
                save_path = build_save_path(model_name='lstmAE', dataset_name=self.dataloader.dataset_name, seed=get_global_seed())
            self.save, self.save_path = save, save_path
            save_model(self.model, save_path)
            tqdm.write(f"Model saved to {save_path}")

    @torch.no_grad()
    def predict_score(self, score_type='mahalanobis', data_type='test', load_path=None):
        """Score every time step of a split.

        If ``load_path`` (or a ``save_path`` recorded by ``fit``) is set, the
        weights are loaded from it first. A missing file only emits a warning.

        :param score_type: ``'naive'`` or ``'mahalanobis'``.
        :param data_type: ``'train'`` or any other value for the test split.
        :return: the array returned by ``LSTM_AE.anomaly_score``.
        :raises ValueError: on an unknown ``score_type``.
        """
        if score_type not in ('naive', 'mahalanobis'):
            raise ValueError("Invalid score_type. Choose 'naive' or 'mahalanobis'.")

        path = load_path if load_path is not None else getattr(self, 'save_path', None)
        if path:
            try:
                load_model(self.model, path)
                self.model.to(self.device)
                tqdm.write(f"Model loaded from {path}")
            except FileNotFoundError:
                tqdm.write(f"[warn] load_path not found: {path} — using in-memory model.")

        x = self._true_series(data_type)
        x_hat = self._predict(data_type=data_type)

        if score_type == 'naive':
            return self.model.anomaly_score(x, x_hat, method=score_type)

        # Fit the training-split error distribution here if it is missing
        # (e.g. a fresh instance with loaded weights), so the device-aware
        # _predict produces the reference reconstruction.
        if not (hasattr(self.model, 'error_mu') and hasattr(self.model, 'error_cov_inv')):
            self._refit_error_distribution('train')
        return self.model.anomaly_score(x, x_hat, method=score_type, dataloader=self.dataloader, data_type='train')

    @torch.no_grad()
    def _predict(self, data_type='test'):
        """Recursively reconstruct every window of a split and unroll them.

        :param data_type: ``'train'`` selects the training loader; anything
            else selects the test loader.
        :return: result of ``self.dataloader.unroll_windows`` on the stacked
            ``(n_windows, W, F)`` numpy reconstructions. Presumably an
            ``(L, feat)`` array; TODO confirm against ``unroll_windows``.
        """
        dataloader, _ = self._split_for(data_type)
        self.model.eval()
        preds_windows = []

        for batch in dataloader:
            x = batch.to(self.device)
            B, W, F = x.size()

            _, hidden = self.model.encoder(x)

            # Seed the decoder with the last observed step, then feed back
            # its own predictions (no teacher forcing at inference).
            decoder_input = x[:, -1, :]

            out_win = torch.zeros(B, W, F, device=self.device)
            for t in range(W):
                decoder_output, hidden = self.model.decoder(decoder_input, hidden)
                out_win[:, t, :] = decoder_output
                decoder_input = decoder_output

            preds_windows.append(out_win)

        all_windows = torch.cat(preds_windows, dim=0).cpu().numpy()
        return self.dataloader.unroll_windows(all_windows, data_type=data_type)



if __name__ == "__main__":
    # End-to-end run: train an LSTM-AE on SMD, then score the test split.
    set_seed(42)

    dataset_name = 'SMD'
    loader_cfg, model_cfg, train_cfg = ModelConfig('lstmAE').resolve(dataset_name)

    ts_loader = TimeSeriesLoader(
        dataset_name=dataset_name,
        window_size=loader_cfg['window_size'],
        step_size=loader_cfg['step_size'],
        batch_size=loader_cfg['batch_size'],
    )
    n_features = ts_loader.train_ds.data.shape[1]

    detector = LSTMAEAnomalyDetector(ts_loader, n_features, hidden_dim=model_cfg['hidden_dim'])
    detector.fit(
        epochs=train_cfg['epochs'],
        learning_rate=train_cfg['learning_rate'],
        data_type="train",
    )

    test_scores = detector.predict_score(data_type='test')
    metrics = cal_metric(ts_loader.test_ds.labels, test_scores)
    print(f"Mahalanobis metric: {metrics}")