import os, sys
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

import numpy as np
from tqdm import tqdm
from dataloader import TimeSeriesLoader
from utils import *

import torch
import torch.nn as nn

class DeepSVDD(nn.Module):
    """Bias-free MLP encoder with a fixed center ``c`` for one-class training.

    The network is a stack of ``nn.Linear(..., bias=False)`` layers with
    ReLU (and optional dropout) between them; the last layer has no
    activation. The center ``c`` is a registered buffer of size
    ``hidden_dim[-1]`` so it moves with the module and is part of its
    ``state_dict``.

    Args:
        input_dim: Number of features per (flattened) sample.
        hidden_dim: Sizes of the linear layers; the last entry is the
            embedding size. Must not be empty.
        dropout_rate: Dropout probability applied after every hidden ReLU.
            ``0``/``None`` disables dropout.

    Raises:
        ValueError: If ``hidden_dim`` is empty.
    """

    def __init__(self, input_dim, hidden_dim=(64, 32), dropout_rate=0.2):
        super().__init__()
        # Copy into a fresh list so instances never share (or mutate) the
        # default / the caller's sequence.
        hidden_dim = list(hidden_dim)
        if not hidden_dim:
            raise ValueError("hidden_dim must contain at least one layer size.")

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.dropout_rate = dropout_rate

        self.net = self._build_model()
        # All-zero until init_c() (or a checkpoint load) fills it in.
        self.register_buffer('c', torch.zeros(hidden_dim[-1]))

    def _build_model(self):
        """Assemble the ``nn.Sequential`` encoder described by ``hidden_dim``."""
        n_layers = len(self.hidden_dim)
        use_dropout = bool(self.dropout_rate) and self.dropout_rate > 0.0
        mods = []

        in_dim = self.input_dim
        for i, out_dim in enumerate(self.hidden_dim):
            mods.append(nn.Linear(in_dim, out_dim, bias=False))
            # No activation/dropout after the final (embedding) layer.
            if i < n_layers - 1:
                mods.append(nn.ReLU(inplace=True))
                if use_dropout:
                    mods.append(nn.Dropout(self.dropout_rate))
            in_dim = out_dim

        return nn.Sequential(*mods)

    def forward(self, x):
        """Embed a batch; inputs with more than 2 dims are flattened per sample."""
        if x.dim() > 2:
            # reshape (not view) so non-contiguous inputs, e.g. transposed
            # windows, are accepted as well.
            x = x.reshape(x.size(0), -1)
        return self.net(x)

    def loss(self, outputs, c):
        """Return the mean squared Euclidean distance of ``outputs`` to ``c``."""
        dist = torch.sum((outputs - c) ** 2, dim=1)
        return torch.mean(dist)

    @torch.no_grad()
    def init_c(self, dataloader, eps=0.1):
        """Set ``c`` to the mean embedding of all samples in ``dataloader``.

        Coordinates whose magnitude is below ``eps`` are pushed out to
        ``-eps`` (negative values) or ``+eps`` (zero and positive values),
        so no coordinate of the center is left at or near zero.

        Args:
            dataloader: Iterable yielding input batches (tensors).
            eps: Minimum magnitude enforced on every coordinate of ``c``.

        Returns:
            The new center tensor (also stored in the ``c`` buffer).

        Raises:
            ValueError: If ``dataloader`` yields no samples.
        """
        ref = next(self.parameters())
        was_training = self.training
        self.eval()  # disable dropout while estimating the center

        try:
            n_samples = 0
            c = torch.zeros(self.hidden_dim[-1], device=ref.device, dtype=ref.dtype)

            for x in dataloader:
                z = self.forward(x.to(ref.device))
                c += z.sum(dim=0)
                n_samples += z.size(0)
        finally:
            # Do not leak the temporary eval mode to the caller.
            self.train(was_training)

        if n_samples == 0:
            raise ValueError("Cannot initialize center c from an empty dataloader.")

        c /= n_samples
        too_small = c.abs() < eps
        c[too_small & (c < 0)] = -eps
        # ``>=`` so a coordinate that is exactly 0 is moved away from zero too.
        c[too_small & (c >= 0)] = eps

        self.c = c
        return self.c

class DeepSVDDAnomalyDetector():
    """Train a ``DeepSVDD`` model on windowed data and score windows by distance to ``c``.

    Args:
        dataloader: Project loader object. This class reads its
            ``window_size``, ``train_loader``, ``test_loader``,
            ``dataset_name`` attributes and calls ``unroll_windows``.
        input_dim: Features per time step; the model input size is
            ``input_dim * dataloader.window_size``.
        hidden_dim: Layer sizes forwarded to ``DeepSVDD``.
        c: Optional pre-computed center. When given it is stored in the
            model's ``c`` buffer; otherwise ``fit`` estimates it.
        dropout_rate: Dropout probability forwarded to ``DeepSVDD``.
        device: Device to run on; defaults to CUDA when available, else CPU.
    """

    def __init__(self, dataloader, input_dim, hidden_dim=(64, 32), c=None, dropout_rate=0.2, device=None):
        self.dataloader = dataloader

        self.device = device if device else torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = DeepSVDD(input_dim*self.dataloader.window_size, hidden_dim, dropout_rate).to(self.device)

        # A user-supplied center must live in the model buffer too, otherwise
        # predict_score() sees an all-zero buffer and saved models lose it.
        self.c = None if c is None else self._sync_center(c)

    def _sync_center(self, c):
        """Write ``c`` into the model's ``c`` buffer and return the stored tensor."""
        buf = self.model.c
        center = torch.as_tensor(c, dtype=buf.dtype, device=buf.device).detach()
        if center.numel() != buf.numel():
            raise ValueError(
                f"Center vector c has {center.numel()} elements, expected {buf.numel()}."
            )
        self.model.c = center.reshape(buf.shape)
        return self.model.c

    def _batches_on_device(self, loader):
        """Yield every batch of ``loader`` moved to ``self.device``."""
        for x in loader:
            yield x.to(self.device)

    def fit(self, epochs=50,  learning_rate=1e-3, weight_decay=1e-4, data_type='train', save=False, save_path=None):
        """Train the model with Adam, minimizing the mean squared distance to ``c``.

        Args:
            epochs: Number of passes over the selected loader.
            learning_rate: Adam learning rate.
            weight_decay: Adam weight decay.
            data_type: ``'train'`` selects ``train_loader``; anything else
                selects ``test_loader``.
            save: Save the model after training.
            save_path: Destination for the saved model; a non-empty value
                also triggers saving. When saving without a path, one is
                obtained from ``build_save_path``.
        """
        dataloader = self.dataloader.train_loader if data_type == 'train' else self.dataloader.test_loader
        optimizer = torch.optim.Adam(self.model.parameters(), weight_decay=weight_decay, lr=learning_rate)

        if self.c is None:
            self.c = self.model.init_c(self._batches_on_device(dataloader))
        else:
            # Covers a center assigned to ``self.c`` after construction.
            self.c = self._sync_center(self.c)
        self.model.train()

        pbar = tqdm(range(epochs), desc='Training DeepSVDD', leave=True)
        for epoch in pbar:
            total_loss = 0.
            n_seen = 0
            ibar = tqdm(dataloader, desc='Inner loop', leave=False)
            for x in ibar:
                x = x.to(self.device)
                optimizer.zero_grad()

                output = self.model(x)
                loss = self.model.loss(output, self.c)

                loss.backward()
                optimizer.step()

                # loss is a per-batch mean: weight it by the batch size so the
                # epoch figure is a true per-sample average.
                batch_size = x.size(0)
                total_loss += loss.item() * batch_size
                n_seen += batch_size

            avg_loss = total_loss / n_seen if n_seen else float('nan')
            tqdm.write(f'Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}')

        if save or save_path:
            if save_path is None:
                save_path = build_save_path(model_name='DeepSVDD', dataset_name=self.dataloader.dataset_name, seed=get_global_seed())
            # Remembered so predict_score() can reload the same file by default.
            self.save, self.save_path = save, save_path
            save_model(self.model, save_path)
            tqdm.write(f"Model saved to {save_path}")

    @torch.no_grad()
    def predict_score(self, data_type='test', load_path=None):
        """Score the selected split by squared distance of each window to ``c``.

        If ``load_path`` (or, failing that, the path remembered by ``fit``)
        is set, the model is reloaded from it first; a missing file only
        produces a warning and the in-memory model is used.

        Args:
            data_type: ``'test'`` selects ``test_loader``; anything else
                selects ``train_loader``.
            load_path: Optional checkpoint path to load before scoring.

        Returns:
            The result of ``dataloader.unroll_windows`` applied to the
            per-window scores (each repeated over the window length), with
            the trailing axis squeezed.

        Raises:
            ValueError: If the model's center ``c`` is missing or all zeros.
        """
        path = None
        if load_path is not None:
            path = load_path
        elif getattr(self, 'save_path', None):
            path = self.save_path

        if path:
            try:
                load_model(self.model, path)
                self.model.to(self.device)
                tqdm.write(f"Model loaded from {path}")
            except FileNotFoundError:
                tqdm.write(f"[warn] load_path not found: {path} — using in-memory model.")

        dataloader = self.dataloader.test_loader if data_type == 'test' else self.dataloader.train_loader
        self.model.eval()

        if self.model.c is None or torch.all(self.model.c == 0):
            raise ValueError("Center vector c has not been initialized or loaded.")
        # The model buffer is authoritative (it may have just been loaded).
        self.c = self.model.c

        window_scores = []
        for x in self._batches_on_device(dataloader):
            outputs = self.model(x)
            dist = torch.sum((outputs - self.c)**2, dim=1)
            window_scores.append(dist.cpu().detach().numpy())

        window_scores = np.concatenate(window_scores, axis=0)

        # Give every position in a window that window's score:
        # (n_windows,) -> (n_windows, window_size, 1).
        win = self.dataloader.window_size
        window_scores = np.repeat(window_scores[:, None, None], win, axis=1)

        anomaly_scores = self.dataloader.unroll_windows(window_scores, data_type).squeeze(-1)

        return anomaly_scores



if __name__ == "__main__":
    # Script entry point: train DeepSVDD on one dataset and print its metrics.
    set_seed(42)

    dataset_name = 'SMD'

    # Split the DeepSVDD configuration for this dataset into its three parts.
    loader_config, model_config, train_config = ModelConfig('DeepSVDD').resolve(dataset_name)

    # Build the data loader for the chosen dataset.
    ts_loader = TimeSeriesLoader(dataset_name=dataset_name, **loader_config)

    # The per-step feature count is read from the training data's second axis.
    model_config['input_dim'] = ts_loader.train_ds.data.shape[1]

    detector = DeepSVDDAnomalyDetector(ts_loader, **model_config)
    detector.fit(**train_config, data_type='train')

    # Score the test split and compare against its labels.
    scores = detector.predict_score(data_type='test')
    labels = ts_loader.test_labels

    print(cal_metric(labels, scores))