import os, sys
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

import numpy as np
from tqdm import tqdm
from dataloader import TimeSeriesLoader
from utils import *

import torch
import torch.nn as nn
import torch.optim as optim

class Encoder(nn.Module):
    """Fully connected encoder that compresses an input into a latent code.

    The layer widths shrink as ``input_dim -> input_dim // 2 ->
    input_dim // 4 -> latent_dim``; every linear layer, including the last
    one, is followed by a ReLU, so the latent code is non-negative.

    Args:
        input_dim: Size of the last dimension of the input tensor.
        latent_dim: Size of the latent code produced by ``forward``.
    """

    def __init__(self, input_dim, latent_dim):
        super().__init__()
        half_dim, quarter_dim = input_dim // 2, input_dim // 4
        # Layers are created in the same order they are applied.
        stages = [
            nn.Linear(input_dim, half_dim),
            nn.ReLU(),
            nn.Linear(half_dim, quarter_dim),
            nn.ReLU(),
            nn.Linear(quarter_dim, latent_dim),
            nn.ReLU(),
        ]
        self.encoder = nn.Sequential(*stages)

    def forward(self, x):
        """Return the latent code for ``x``."""
        latent = self.encoder(x)
        return latent
    
class Decoder(nn.Module):
    """Fully connected decoder that maps a latent code back to input space.

    The layer widths are ``latent_dim -> latent_dim // 2 -> latent_dim // 4
    -> output_dim`` and the result is squashed with a sigmoid, so every
    output lies in the range [0, 1].

    Args:
        latent_dim: Size of the last dimension of the latent code.
        output_dim: Size of the reconstructed output.
    """

    def __init__(self, latent_dim, output_dim):
        super().__init__()
        # The sigmoid is applied directly to the last linear layer. A ReLU
        # in front of it would clamp the pre-activation to >= 0 and thereby
        # restrict every output to [0.5, 1), making values below 0.5
        # impossible to reconstruct.
        # NOTE(review): the three Linear layers have no activation between
        # them, so together they form a single affine map. Inserting
        # activations would shift the Sequential indices (and therefore the
        # state_dict keys of saved checkpoints) -- confirm intent before
        # changing.
        self.decoder = nn.Sequential(nn.Linear(latent_dim, latent_dim//2),
                                     nn.Linear(latent_dim//2, latent_dim//4),
                                     nn.Linear(latent_dim//4, output_dim),
                                     nn.Sigmoid())

    def forward(self, z):
        """Return the reconstruction for latent code ``z``."""
        return self.decoder(z)
    
class USAD(nn.Module):
    """Shared encoder with two decoders, trained with a two-part loss.

    ``forward`` caches its three outputs on the instance (``w1``, ``w2``,
    ``w3``) and ``loss`` reads them back, so ``forward`` must be called on
    the same batch before ``loss``.

    Args:
        input_dim: Size of the last dimension of the input tensor.
        latent_dim: Size of the latent code shared by both decoders.
    """

    def __init__(self, input_dim, latent_dim):
        super().__init__()
        self.encoder = Encoder(input_dim, latent_dim)
        self.decoder1 = Decoder(latent_dim, input_dim)
        self.decoder2 = Decoder(latent_dim, input_dim)

    def forward(self, x):
        """Compute and cache the three reconstructions of ``x``.

        Returns:
            Tuple ``(w1, w2, w3)`` where ``w1`` and ``w2`` are the outputs of
            decoder 1 and decoder 2 on the encoded input, and ``w3`` is
            decoder 2 applied to the re-encoded ``w1``.
        """
        latent = self.encoder(x)
        self.w1, self.w2 = self.decoder1(latent), self.decoder2(latent)
        # Second pass: feed decoder 1's reconstruction back through the
        # encoder and decoder 2.
        self.w3 = self.decoder2(self.encoder(self.w1))
        return self.w1, self.w2, self.w3

    def loss(self, x, epoch):
        """Return the two decoder losses for the most recent ``forward``.

        Args:
            x: The batch that was passed to the preceding ``forward`` call.
            epoch: Weighting index ``n``; the plain reconstruction terms are
                scaled by ``1/n`` and the ``w3`` term by ``1 - 1/n``. Must be
                non-zero because it is used as a divisor.

        Returns:
            Tuple ``(loss_decoder1, loss_decoder2)``. The ``w3`` term is
            added in the first loss and subtracted in the second.
        """
        recon_weight = 1 / epoch
        second_pass_weight = 1 - recon_weight

        err_w1 = torch.mean((x - self.w1) ** 2)
        err_w2 = torch.mean((x - self.w2) ** 2)
        err_w3 = torch.mean((x - self.w3) ** 2)

        first = recon_weight * err_w1 + second_pass_weight * err_w3
        second = recon_weight * err_w2 - second_pass_weight * err_w3
        return first, second


class USADAnomalyDetector:
    """Train a ``USAD`` model and turn its reconstruction errors into scores.

    Args:
        dataloader: Project loader object. This class reads its
            ``train_loader`` / ``test_loader`` attributes (iterables of input
            tensors), its ``dataset_name`` attribute and calls its
            ``unroll_windows(windows=..., data_type=...)`` method.
        input_dim: Input size forwarded to ``USAD``.
        latent_dim: Latent size forwarded to ``USAD``.
        alpha: Weight of the first-pass error (``x`` vs ``w1``) in the score.
        beta: Weight of the second-pass error (``x`` vs ``w3``) in the score.
        device: Torch device for the model; defaults to CUDA when available,
            otherwise CPU.
    """

    def __init__(self, dataloader, input_dim, latent_dim, alpha, beta, device=None):
        self.dataloader = dataloader
        self.alpha = alpha
        self.beta = beta

        self.device = device if device else torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = USAD(input_dim, latent_dim).to(self.device)

    def fit(self, epochs=50, learning_rate=1e-3, data_type='train', save=False, save_path=None):
        """Train the model with two alternating optimizer steps per batch.

        Args:
            epochs: Number of passes over the selected loader.
            learning_rate: Learning rate for both Adam optimizers.
            data_type: ``'train'`` selects ``train_loader``; any other value
                selects ``test_loader``.
            save: When truthy, the model is saved after training.
            save_path: Destination for the saved model. Passing a path also
                triggers saving; when ``save`` is truthy and no path is
                given, one is built with ``build_save_path``.
        """
        dataloader = self.dataloader.train_loader if data_type == 'train' else self.dataloader.test_loader
        # Both optimizers share the encoder; each one owns a single decoder.
        optimizer1 = optim.Adam(list(self.model.encoder.parameters()) + list(self.model.decoder1.parameters()), lr=learning_rate)
        optimizer2 = optim.Adam(list(self.model.encoder.parameters()) + list(self.model.decoder2.parameters()), lr=learning_rate)
        self.model.train()
        pbar = tqdm(range(epochs), desc='Training USAD', leave=True)
        for epoch in pbar:
            total_loss, loss1, loss2 = 0., 0., 0.
            num_batches = 0
            ibar = tqdm(dataloader, desc='Inner loop', leave=False)
            for x in ibar:
                # The model lives on self.device, so the batch must too.
                x = x.to(self.device)

                # Step 1: update encoder + decoder1 on the first loss.
                self.model.zero_grad()
                self.model(x)
                loss_decoder1, _ = self.model.loss(x, epoch + 1)
                loss_decoder1.backward()
                optimizer1.step()

                # Step 2: update encoder + decoder2 on the second loss.
                # Gradients are cleared first; otherwise the gradients that
                # the first loss left on the encoder and decoder2 would be
                # added to this step.
                self.model.zero_grad()
                self.model(x)
                _, loss_decoder2 = self.model.loss(x, epoch + 1)
                loss_decoder2.backward()
                optimizer2.step()

                batch_loss1 = loss_decoder1.item()
                batch_loss2 = loss_decoder2.item()
                loss1 += batch_loss1
                loss2 += batch_loss2
                total_loss += batch_loss1 + batch_loss2
                num_batches += 1

            # Each batch contributes one mean loss, so the epoch average is
            # taken over batches, not samples. Guard against an empty loader.
            denom = max(num_batches, 1)
            avg_loss1 = loss1 / denom
            avg_loss2 = loss2 / denom
            avg_total_loss = total_loss / denom
            tqdm.write(f'Epoch {epoch+1}/{epochs}, Loss1: {avg_loss1:.5f}, Loss2: {avg_loss2:.5f}, Total Loss: {avg_total_loss:.5f}')

        if save or save_path:
            if save_path is None:
                save_path = build_save_path(model_name='USAD', dataset_name=self.dataloader.dataset_name, seed=get_global_seed())
            # Remembered so predict_score can reload the saved weights.
            self.save, self.save_path = save, save_path
            save_model(self.model, save_path)
            tqdm.write(f"Model saved to {save_path}")

    @torch.no_grad()
    def predict_score(self, data_type='test', load_path=None):
        """Compute anomaly scores from squared reconstruction errors.

        Args:
            data_type: ``'test'`` selects ``test_loader``; any other value
                selects ``train_loader``.
            load_path: Optional checkpoint to load before scoring. When
                omitted, the path remembered by a previous saving ``fit``
                call is used if there is one. A missing file only produces a
                warning and the in-memory weights are used.

        Returns:
            ``alpha * score1 + beta * score2``, where each score is the
            result of ``unroll_windows`` on the per-window squared errors,
            averaged over axis 1.
        """
        path = None
        if load_path is not None:
            path = load_path
        elif getattr(self, 'save_path', None):
            path = self.save_path

        if path:
            try:
                load_model(self.model, path)
                self.model.to(self.device)
                tqdm.write(f"Model loaded from {path}")
            except FileNotFoundError:
                tqdm.write(f"[warn] load_path not found: {path} — using in-memory model.")

        dataloader = self.dataloader.test_loader if data_type == 'test' else self.dataloader.train_loader
        self.model.eval()
        window_scores1 = []
        window_scores2 = []

        for x in dataloader:
            # The model lives on self.device, so the batch must too.
            x = x.to(self.device)
            w1, _, w3 = self.model(x)
            window_scores1.append(((x - w1) ** 2).cpu().detach().numpy())
            window_scores2.append(((x - w3) ** 2).cpu().detach().numpy())

        window_scores1 = np.concatenate(window_scores1, axis=0)
        window_scores2 = np.concatenate(window_scores2, axis=0)
        # NOTE(review): unroll_windows presumably maps per-window errors back
        # onto the original time axis -- confirm against the loader.
        score1 = self.dataloader.unroll_windows(windows=window_scores1, data_type=data_type).mean(axis=1)
        score2 = self.dataloader.unroll_windows(windows=window_scores2, data_type=data_type).mean(axis=1)

        return self.alpha * score1 + self.beta * score2

if __name__ == "__main__":
    # End-to-end run: configure, train on the train split, score the test
    # split and print the evaluation metrics.
    set_seed(42)

    dataset_name = 'SMD'

    # Resolve the loader / model / training settings for this dataset.
    loader_config, model_config, train_config = ModelConfig('USAD').resolve(dataset_name)

    loader = TimeSeriesLoader(
        dataset_name=dataset_name,
        window_size=loader_config['window_size'],
        step_size=loader_config['step_size'],
        batch_size=loader_config['batch_size'],
    )

    # The model input size is taken from the second axis of the train data.
    input_dim = loader.train_ds.data.shape[1]

    detector = USADAnomalyDetector(
        loader,
        input_dim,
        latent_dim=model_config['latent_dim'],
        alpha=model_config['alpha'],
        beta=model_config['beta'],
    )

    detector.fit(
        epochs=train_config['epochs'],
        learning_rate=train_config['learning_rate'],
        data_type='train',
    )

    anomaly_score = detector.predict_score(data_type='test')
    y_label = loader.test_labels

    metrics = cal_metric(y_label, anomaly_score)
    print(f"metrics: {metrics}")