import os, sys
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm
from utils import *
from dataloader import TimeSeriesLoader


class Chomp1d(nn.Module):
    """Drop the last ``chomp_size`` time steps of a (N, C, L) tensor.

    Used after a Conv1d padded by ``chomp_size`` on both sides so that only the
    left padding survives and the convolution does not look ahead in time.
    A ``chomp_size`` of 0 (e.g. kernel_size=1, where no padding is added) leaves
    the input untouched.
    """

    def __init__(self, chomp_size):
        super(Chomp1d, self).__init__()
        if chomp_size < 0:
            raise ValueError(f"chomp_size must be non-negative, got {chomp_size}")
        self.chomp_size = chomp_size

    def forward(self, x):
        # `x[..., :-0]` is an EMPTY slice, so the zero case must be handled explicitly.
        if self.chomp_size == 0:
            return x.contiguous()
        return x[:, :, :-self.chomp_size].contiguous()


class TemporalBlock(nn.Module):
    def __init__(self, n_inputs, n_outputs, kernel_size, stride, dilation, padding, dropout=0.2):
        super(TemporalBlock, self).__init__()
        self.conv1 = nn.utils.parametrizations.weight_norm(nn.Conv1d(n_inputs, n_outputs, kernel_size, 
                                                                     stride=stride, padding=padding, dilation=dilation))
        self.chomp1 = Chomp1d(padding)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout)

        self.net = nn.Sequential(self.conv1, self.chomp1, self.relu1, self.dropout1)

        self.downsample = nn.Conv1d(n_inputs, n_outputs, 1) if n_inputs != n_outputs else None
        self.relu = nn.ReLU()
        self.init_weights()

    def init_weights(self):
        self.conv1.weight.data.normal_(0, 0.01)
        if self.downsample is not None:
            self.downsample.weight.data.normal_(0, 0.01)

    def forward(self, x):
        out = self.net(x)
        res = x if self.downsample is None else self.downsample(x)
        return self.relu(out + res)


class TemporalConvNet(nn.Module):
    """Stack of TemporalBlocks with exponentially growing dilation.

    Level ``i`` uses dilation ``2**i`` and padding ``(kernel_size - 1) * 2**i``.
    The first level reads ``input_dim`` channels; every later level reads the
    previous level's output width. Input/output are channels-first (N, C, L).
    """

    def __init__(self, input_dim, num_channels, kernel_size=2, dropout=0.2):
        super(TemporalConvNet, self).__init__()
        in_widths = (input_dim, *num_channels[:-1])
        blocks = []
        for level, (c_in, c_out) in enumerate(zip(in_widths, num_channels)):
            dilation = 2 ** level
            blocks.append(TemporalBlock(c_in, c_out, kernel_size, stride=1, dilation=dilation,
                                        padding=(kernel_size - 1) * dilation, dropout=dropout))
        self.TCN = nn.Sequential(*blocks)

    def forward(self, x):
        # Pure delegation to the sequential stack.
        out = self.TCN(x)
        return out


class PlanarFlow(nn.Module):
    """
    A single planar flow T(z) = z + u * tanh(w.z + b) together with log|det dT/dz|.

    ``z`` is a 2-D batch (B, D). The Jacobian determinant is 1 + psi(z).u with
    psi(z) = tanh'(w.z + b) * w.

    NOTE(review): a planar flow is only invertible when w.u >= -1; this class
    does not enforce that, and log_det takes the absolute value of the
    determinant.
    """

    def __init__(self, D):
        super(PlanarFlow, self).__init__()
        self.u = nn.Parameter(torch.empty(1, D))
        self.w = nn.Parameter(torch.empty(1, D))
        self.b = nn.Parameter(torch.empty(1))
        self.h = torch.tanh
        self.init_params()

    def init_params(self):
        # Small symmetric init: with u ~ 0 the flow starts close to the identity.
        for param in (self.w, self.u, self.b):
            nn.init.uniform_(param, -0.01, 0.01)

    def _affine(self, z):
        """Return w.z + b for every row of z, shape (B, 1)."""
        return torch.mm(z, self.w.T) + self.b

    def forward(self, z):
        return z + self.u * self.h(self._affine(z))

    def h_prime(self, x):
        """
        Derivative of tanh: 1 - tanh(x)**2.
        """
        return 1 - self.h(x).pow(2)

    def psi(self, z):
        """Return h'(w.z + b) * w, shape (B, D)."""
        return self.h_prime(self._affine(z)) * self.w

    def log_det(self, z):
        """Return log|1 + psi(z).u| per sample, shape (B, 1)."""
        det = 1 + torch.mm(self.psi(z), self.u.T)
        return det.abs().log()


class NormalizingFlow(nn.Module):
    """
    A normalizing flow composed of a sequence of planar flows.

    ``forward`` applies the flows in order and accumulates, per sample, the
    log|det J| of each step evaluated at that step's input.
    """

    def __init__(self, D, n_flows=2):
        super(NormalizingFlow, self).__init__()
        self.flows = nn.ModuleList([PlanarFlow(D) for _ in range(n_flows)])

    def forward(self, z):  # z: (B, D)
        """
        Args:
            z: 2-D tensor of shape (B, D).

        Returns:
            (zK, sum_logdet): the transformed batch (B, D) and the summed
            log-determinants (B, 1). With ``n_flows == 0`` this is ``(z, zeros)``.
        """
        # Start from a (B, 1) zero tensor rather than the float 0. so the return
        # type and shape are the same even with no flows; Encoder concatenates
        # these along dim 1.
        sum_logdet = z.new_zeros(z.size(0), 1)
        for flow in self.flows:
            sum_logdet = sum_logdet + flow.log_det(z)  # (B,1), evaluated before the step
            z = flow(z)
        return z, sum_logdet


class Encoder(nn.Module):
    """Map an input window to per-time-step latent samples.

    A TCN plus a dense ReLU layer produces features. From these, a Gaussian
    (linear mean, softplus std) over the latent is read off at each time step,
    and z0 is sampled by reparameterisation. With ``use_PNF`` set, the sample
    at step t goes through its own NormalizingFlow (``rolling_size`` of them)
    to give zK_t, and the per-step log-determinants are summed.

    forward returns ``(z0_seq, zK_seq, mu_z, std_z, sum_logdet)``. Without
    flows, zK_seq is z0_seq and sum_logdet is a zero vector of length B.
    """

    def __init__(self, input_dim, hidden_dim, dense_dim, latent_dim, tcn_levels, rolling_size, kernel_size, dropout, use_PNF, PNF_layers):
        super(Encoder, self).__init__()
        self.input_dim, self.hidden_dim = input_dim, hidden_dim
        self.dense_dim, self.latent_dim = dense_dim, latent_dim
        self.rolling_size = rolling_size
        self.use_PNF, self.PNF_layers = use_PNF, PNF_layers

        # Creation order (TCN, flows, dense, mean head, std head) fixes the
        # order in which parameters are initialised and registered.
        self.TCN = TemporalConvNet(input_dim=input_dim, num_channels=[hidden_dim] * tcn_levels,
                                   kernel_size=kernel_size, dropout=dropout)
        if use_PNF:
            self.PNF = nn.ModuleList(
                [NormalizingFlow(D=latent_dim, n_flows=PNF_layers) for _ in range(rolling_size)]
            )
        self.phi_enc = nn.Sequential(nn.Linear(hidden_dim, dense_dim), nn.ReLU())
        self.enc_means = nn.Sequential(nn.Linear(dense_dim, latent_dim))
        self.enc_stds = nn.Sequential(nn.Linear(dense_dim, latent_dim), nn.Softplus())

    def reparameterized(self, mean, std):
        """Sample mean + std * eps with eps ~ N(0, I)."""
        return mean + std * torch.randn_like(std)

    def forward(self, x):
        # x is indexed as (batch, time, features); the TCN wants channels-first.
        batch_size, seq_len, _ = x.size()
        features = self.phi_enc(self.TCN(x.transpose(1, 2)).transpose(1, 2))

        mu_z = self.enc_means(features)   # (B,T,z)
        std_z = self.enc_stds(features)   # (B,T,z)
        z0_seq = self.reparameterized(mu_z, std_z)

        if self.use_PNF:
            per_step = [self.PNF[t](z0_seq[:, t, :]) for t in range(seq_len)]
            zK_seq = torch.stack([zk for zk, _ in per_step], dim=1)                    # (B,T,z)
            sum_logdet = torch.cat([ld for _, ld in per_step], dim=1).sum(dim=1)     # (B,)
            return z0_seq, zK_seq, mu_z, std_z, sum_logdet

        return z0_seq, z0_seq, mu_z, std_z, x.new_zeros(batch_size)


class Decoder(nn.Module):
    """Map a latent sequence to a per-time-step Gaussian over the input features.

    Pipeline: TCN over the latent sequence -> dense ReLU layer -> two heads, a
    sigmoid-bounded mean and a softplus std, each with ``input_dim`` outputs.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim, dense_dim, tcn_levels, kernel_size, dropout):
        super(Decoder, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.dense_dim = dense_dim
        self.latent_dim = latent_dim

        # Creation order (TCN, dense, mean head, std head) fixes parameter init order.
        self.TCN = TemporalConvNet(input_dim=latent_dim, num_channels=[hidden_dim] * tcn_levels,
                                   kernel_size=kernel_size, dropout=dropout)
        self.phi_dec = nn.Sequential(nn.Linear(hidden_dim, dense_dim), nn.ReLU())
        self.dec_means = nn.Sequential(nn.Linear(dense_dim, input_dim), nn.Sigmoid())
        self.dec_stds = nn.Sequential(nn.Linear(dense_dim, input_dim), nn.Softplus())

    def reparameterized(self, mu, std):
        """Sample mu + std * eps with eps ~ N(0, I); not used by forward."""
        return mu + std * torch.randn_like(std)

    def forward(self, z):
        """Decode z indexed as (B, T, latent_dim) into (mu_x, std_x), each (B, T, input_dim)."""
        channels_first = z.transpose(1, 2)
        features = self.phi_dec(self.TCN(channels_first).transpose(1, 2))
        return self.dec_means(features), self.dec_stds(features)


class LUAD(nn.Module):
    """TCN variational autoencoder with optional per-time-step planar flows.

    The encoder gives a Gaussian q0(z0 | x) at each time step. When ``use_PNF``
    is set, each step's sample is pushed through its own normalizing flow to zK.
    The decoder maps zK to a Gaussian over x. Training minimises the negative
    ELBO; anomaly scores are the reconstruction NLL.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim, rolling_size,
                 use_PNF=True, PNF_layers=2, tcn_levels=3, kernel_size=7, dropout=0.2):
        super().__init__()

        self.encoder = Encoder(input_dim=input_dim, hidden_dim=hidden_dim, latent_dim=latent_dim,
                               dense_dim=hidden_dim, rolling_size=rolling_size,
                               use_PNF=use_PNF, PNF_layers=PNF_layers, tcn_levels=tcn_levels, kernel_size=kernel_size, dropout=dropout)
        self.decoder = Decoder(input_dim=input_dim, hidden_dim=hidden_dim, latent_dim=latent_dim,
                               dense_dim=hidden_dim, tcn_levels=tcn_levels, kernel_size=kernel_size, dropout=dropout)

        # log(2*pi) as a Python float (computed in float32).
        self.log_two_pi = torch.log(torch.tensor(2*torch.pi)).item()

    def forward(self, x):
        z0_seq, zK_seq, mu_z, std_z, sum_logdet = self.encoder(x)
        mu_x, std_x = self.decoder(zK_seq)
        return z0_seq, zK_seq, mu_z, std_z, sum_logdet, mu_x, std_x

    @staticmethod
    def _gaussian_nll(x, mu, std, log_two_pi):
        """Elementwise -log N(x; mu, std**2). The 1e-12 terms guard against std -> 0."""
        return 0.5 * (((x - mu) ** 2) / (std ** 2 + 1e-12) + 2 * torch.log(std + 1e-12) + log_two_pi)

    @staticmethod
    def _kl_per_sample(mu_z, std_z, zK, sum_logdet, use_flow):
        """Per-sample KL(q(zK | x) || N(0, I)), summed over time and latent dims -> (B,).

        Without flows, the closed-form Gaussian KL of q0 is used.

        With flows, q(zK) = q0(z0) / prod|det J|, so the KL is
        E[log q0(z0)] - E[sum log|det|] - E[log p(zK)]. The entropy of q0 is
        analytic, but the prior must be evaluated at the *flowed* sample zK
        (single-sample estimate). Evaluating it at z0, as the closed-form KL
        implicitly does, leaves the log-det term unpenalised, and the loss
        becomes unbounded below.
        """
        if use_flow:
            kl_t = 0.5 * (zK ** 2 - 1.0 - 2.0 * torch.log(std_z + 1e-12))
        else:
            kl_t = 0.5 * (mu_z ** 2 + std_z ** 2 - 1.0 - 2.0 * torch.log(std_z + 1e-12))
        kl = kl_t.sum(dim=(1, 2))
        if sum_logdet is not None:
            kl = kl - sum_logdet
        return kl

    def loss(self, x):
        """Return (negative ELBO, reconstruction NLL, KL), each averaged over the batch."""
        z0, zK, mu_z, std_z, sum_logdet, mu_x, std_x = self.forward(x)

        rec_loss = self._gaussian_nll(x, mu_x, std_x, self.log_two_pi).sum(dim=(1, 2)).mean()
        use_flow = getattr(self.encoder, 'use_PNF', False)
        kl_loss = self._kl_per_sample(mu_z, std_z, zK, sum_logdet, use_flow).mean()

        elbo = rec_loss + kl_loss  # minimize -ELBO
        return elbo, rec_loss, kl_loss

    @torch.no_grad()
    def anomaly_score(self, x, reduce_feat='mean'):
        """Per-time-step reconstruction NLL, reduced over features -> (B, T, 1).

        ``reduce_feat='mean'`` averages over features; any other value sums.
        """
        *_, mu_x, std_x = self.forward(x)
        nll = self._gaussian_nll(x, mu_x, std_x, self.log_two_pi)
        if reduce_feat == 'mean':
            return nll.mean(dim=2, keepdim=True)
        return nll.sum(dim=2, keepdim=True)


class LUADAnomalyDetector:
    """Train a LUAD model and convert its reconstruction NLL into anomaly scores.

    ``dataloader`` must expose ``window_size``, ``train_loader``,
    ``test_loader``, ``dataset_name`` and ``unroll_windows(...)``; all of these
    are used below.
    """

    def __init__(self, dataloader, input_dim, hidden_dim, latent_dim,
                 use_PNF=True, PNF_layers=2, tcn_levels=3, kernel_size=7, dropout=0.2, device=None):
        self.dataloader = dataloader
        self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.model = LUAD(input_dim, hidden_dim, latent_dim, rolling_size=self.dataloader.window_size,
                          use_PNF=use_PNF, PNF_layers=PNF_layers, tcn_levels=tcn_levels, kernel_size=kernel_size, dropout=dropout).to(self.device)

    def fit(self, epochs=50, learning_rate=1e-5, data_type='train', save=False, save_path=None):
        """Minimise the negative ELBO with Adam.

        Args:
            epochs: number of passes over the loader.
            learning_rate: Adam learning rate.
            data_type: 'train' selects ``train_loader``; anything else selects ``test_loader``.
            save: save the weights after training (to ``save_path`` or a built default path).
            save_path: explicit destination; passing it also triggers a save.
        """
        dataloader = self.dataloader.train_loader if data_type == 'train' else self.dataloader.test_loader
        optimizer = torch.optim.Adam(self.model.parameters(), lr=learning_rate)
        self.model.train()

        pbar = tqdm(range(epochs), desc='Training LUAD', leave=True)
        for epoch in pbar:
            total_loss, total_rec_loss, total_kl_loss = 0.0, 0.0, 0.0
            ibar = tqdm(dataloader, desc='Inner loop', leave=False)
            for x in ibar:
                # The loader is not guaranteed to produce batches on the model's
                # device; .to() is a no-op when they already are.
                x = x.to(self.device)
                optimizer.zero_grad()

                loss, rec_loss, kl_loss = self.model.loss(x)

                loss.backward()
                optimizer.step()
                total_loss += loss.item()
                total_rec_loss += rec_loss.item()
                total_kl_loss += kl_loss.item()

            # Guard the averages against an empty loader.
            n_batches = max(len(dataloader), 1)
            avg_loss = total_loss / n_batches
            avg_rec_loss = total_rec_loss / n_batches
            avg_kl_loss = total_kl_loss / n_batches
            tqdm.write(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.5f}, Reconstruction Loss: {avg_rec_loss:.5f}, KL Loss: {avg_kl_loss:.5f}")

        if save or save_path:
            if save_path is None:
                save_path = build_save_path(model_name='LUAD', dataset_name=self.dataloader.dataset_name, seed=get_global_seed())
            self.save, self.save_path = save, save_path
            save_model(self.model, save_path)
            tqdm.write(f"Model saved to {save_path}")

    @torch.no_grad()
    def predict_score(self, data_type='test', load_path=None):
        """Score every window and map the window scores back to per-sample scores.

        Loads weights from ``load_path`` or, failing that, from the path of the
        last save; a missing file only logs a warning and keeps the in-memory
        model. ``data_type='test'`` scores ``test_loader``; anything else scores
        ``train_loader``.
        """
        path = None
        if load_path is not None:
            path = load_path
        elif getattr(self, 'save_path', None):
            path = self.save_path

        if path:
            try:
                load_model(self.model, path)
                self.model.to(self.device)
                tqdm.write(f"Model loaded from {path}")
            except FileNotFoundError:
                tqdm.write(f"[warn] load_path not found: {path} — using in-memory model.")

        dataloader = self.dataloader.test_loader if data_type == 'test' else self.dataloader.train_loader
        self.model.eval()
        window_scores = []

        for x in dataloader:
            # Same device handling as in fit().
            scores = self.model.anomaly_score(x.to(self.device))
            window_scores.append(scores.cpu().numpy())

        window_scores = np.concatenate(window_scores, axis=0)

        sample_scores = self.dataloader.unroll_windows(window_scores, data_type=data_type)
        return sample_scores

if __name__ == "__main__":
    set_seed(42)
    dataset_name = 'SMD'

    # Resolve the per-dataset loader / model / training configuration.
    cfg = ModelConfig('LUAD')
    loader_config, model_config, train_config = cfg.resolve(dataset_name)

    loader_kwargs = {name: loader_config[name] for name in ('window_size', 'step_size', 'batch_size')}
    loader = TimeSeriesLoader(dataset_name=dataset_name, **loader_kwargs)

    # The column count of the training data sets the model's input width.
    input_dim = loader.train_ds.data.shape[1]

    model_kwargs = {
        name: model_config[name]
        for name in ('hidden_dim', 'latent_dim', 'use_PNF', 'PNF_layers',
                     'tcn_levels', 'kernel_size', 'dropout')
    }
    model = LUADAnomalyDetector(dataloader=loader, input_dim=input_dim, **model_kwargs)

    model.fit(epochs=train_config['epochs'],
              learning_rate=train_config['learning_rate'],
              data_type='train')
    anomaly_score = model.predict_score(data_type='test')
    y_label = loader.test_labels

    metrics = cal_metric(y_label, anomaly_score)
    print(f"metrics: {metrics}")
