import os
import torch
import numpy as np
import torch.nn as nn
from tqdm import tqdm
from utils.settings import *
import torch.optim as optim
import torch.nn.functional as F
import utils.settings as settings
from utils.data_loader import get_loader_segment
from model.Transformer import Transformer

def my_kl_loss(p, q):
    """
    Smoothed KL divergence D_KL(p || q) between two tensors.

    A constant of 0.0001 is added to both arguments before taking the
    logarithm so that zero entries do not produce ``-inf``. The pointwise
    terms are summed over the last axis and the result is then averaged
    over axis 1, so the inputs need at least three dimensions.
    """
    log_ratio = torch.log(p + 0.0001) - torch.log(q + 0.0001)
    pointwise = p * log_ratio
    return pointwise.sum(dim=-1).mean(dim=1)

class EarlyStopping:
    """
    Stop training once the monitored loss has stopped improving.

    The instance is called once per epoch with the current loss. Whenever the
    loss is not worse than ``best_loss + delta`` the model weights are written
    to ``<path>/checkpoint.pth`` and the patience counter is reset; otherwise
    the counter is incremented, and ``early_stop`` becomes True once it
    reaches ``patience``.

    Reference:
        Xu, J.; Wu, H.; Wang, J.; and Long, M. 2022. Anomaly
        Transformer: Time Series Anomaly Detection with Association
        Discrepancy. In International Conference on Learning
        Representations.

    Args:
        patience: number of consecutive non-improving calls tolerated before
            ``early_stop`` is set.
        verbose: when True, print a message every time a checkpoint is saved.
        dataset_name: stored as ``self.dataset``; not used by this class.
        delta: slack added to ``best_loss`` when deciding whether a loss
            counts as "worse".
    """
    def __init__(self, patience=7, verbose=False, dataset_name='', delta=0):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_loss = None
        self.early_stop = False
        # ``np.Inf`` was removed in NumPy 2.0; ``np.inf`` works on every version.
        self.loss_min = np.inf
        self.delta = delta
        self.dataset = dataset_name

    def __call__(self, loss, model, path):
        """
        Record the loss of one epoch and update the stopping state.

        Args:
            loss: loss value of the epoch that just finished.
            model: object exposing ``state_dict()``; saved on improvement.
            path: directory that receives ``checkpoint.pth``.
        """
        if self.best_loss is None:
            # First call: nothing to compare against yet, so always checkpoint.
            self.best_loss = loss
            self.save_checkpoint(loss, model, path)
        elif loss > self.best_loss + self.delta:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_loss = loss
            self.save_checkpoint(loss, model, path)
            self.counter = 0

    def save_checkpoint(self, loss, model, path):
        """
        Write ``model.state_dict()`` to ``<path>/checkpoint.pth`` and remember
        ``loss`` as the loss of the last saved checkpoint.
        """
        if self.verbose:
            print(f'Loss decreased ({self.loss_min:.6f} --> {loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), os.path.join(path, 'checkpoint.pth'))
        self.loss_min = loss

class Trainer:
    """
    Train and evaluate the two-branch (PTM / FRM) Transformer anomaly detector.

    Every key of ``config`` becomes an instance attribute. The methods of this
    class read the following attributes: ``batch_size``, ``win_size``,
    ``dataset``, ``model_save_path``, ``input_c``, ``output_c``, ``device``,
    ``lr``, ``num_epochs``, ``temp`` and ``anormly_ratio``.
    """
    DEFAULTS = {}

    def __init__(self, config):
        """
        Build the data loaders, the model/optimizer/scheduler and the losses.

        Args:
            config: mapping of settings; merged over ``Trainer.DEFAULTS`` and
                copied into the instance ``__dict__``.
        """
        self.__dict__.update(Trainer.DEFAULTS, **config)
        self.train_loader = self._make_loader('train')
        self.test_loader = self._make_loader('test')
        self.valid_loader = self._make_loader('valid')
        self._bulid_model()
        # Checkpoints are stored under <BASE_PATH>/<model_save_path>/<dataset>.
        self.path = os.path.join(BASE_PATH, self.model_save_path, self.dataset)
        self.MSELoss = torch.nn.MSELoss(reduction='mean')
        # Not referenced by the methods of this class; kept for external users.
        self.KLDivLoss = torch.nn.KLDivLoss(reduction='batchmean')

    def _make_loader(self, mode):
        """Return the data loader of the configured dataset for ``mode``."""
        return get_loader_segment(settings.DATA_PATH, batch_size=self.batch_size,
                                  win_size=self.win_size, mode=mode,
                                  dataset=self.dataset)

    def _bulid_model(self):
        """
        Create the model, its Adam optimizer and an exponential LR scheduler.

        The misspelled name is kept so existing callers keep working.
        """
        self.model = Transformer(mask=False, win_size=self.win_size, enc_in=self.input_c,
                                 c_out=self.output_c, e_layers=3).to(self.device)
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
        self.scheduler = optim.lr_scheduler.ExponentialLR(self.optimizer, gamma=0.95)

    def cal_loss(self, out: tuple, target):
        """
        Calculate loss -- PTM recloss, FRM recloss and DKL loss.

        Args:
            out: model output ``(recon_PTM, recon_FRM, attn_PTM, attn_FRM)``
                where the last two items are equally long sequences of
                attention tensors.
            target: tensor both reconstructions are compared against.

        Returns:
            ``(loss_t, kl_ptm)``: the total training loss and the symmetric KL
            term whose gradient reaches only the PTM attention. ``kl_ptm`` is
            the plain float ``0.0`` when the attention sequences are empty.
        """
        recon_PTM, recon_FRM, attn_grah_PTM, attn_grah_FRM = out
        loss_ptm = self.MSELoss(recon_PTM, target)
        loss_frm = self.MSELoss(recon_FRM, target)
        kl_ptm = 0.0
        kl_frm = 0.0
        # Use Adversarial Learning: both terms have the same value, but each
        # detaches the other branch, so kl_frm only back-propagates into the
        # FRM attention and kl_ptm only into the PTM attention.
        for attn_ptm, attn_frm in zip(attn_grah_PTM, attn_grah_FRM):
            ptm_fixed = attn_ptm.detach()
            frm_fixed = attn_frm.detach()
            kl_frm += (torch.mean(my_kl_loss(ptm_fixed, attn_frm))
                       + torch.mean(my_kl_loss(attn_frm, ptm_fixed)))
            kl_ptm += (torch.mean(my_kl_loss(attn_ptm, frm_fixed))
                       + torch.mean(my_kl_loss(frm_fixed, attn_ptm)))
        # kl_frm is minimized while kl_ptm enters with a negative sign.
        loss_t = loss_ptm + loss_frm + (kl_frm - kl_ptm)
        return loss_t, kl_ptm

    def cal_score(self, out: tuple, target):
        """
        Compute anomaly scores from one model output.

        The raw score is the product of the PTM reconstruction error, the FRM
        reconstruction error (each averaged over the last axis) and the
        symmetric KL discrepancy between the two attention sets scaled by
        ``self.temp``.

        Args:
            out: model output ``(recon_PTM, recon_FRM, attn_PTM, attn_FRM)``.
            target: tensor both reconstructions are compared against.

        Returns:
            ``(score, raw)``: ``raw`` is the product described above and
            ``score`` is ``softmax(raw)`` over the last axis.
        """
        elementwise_mse = torch.nn.MSELoss(reduction='none')
        recon_PTM, recon_FRM, attn_grah_PTM, attn_grah_FRM = out
        s_p = elementwise_mse(recon_PTM, target).mean(-1)
        s_f = elementwise_mse(recon_FRM, target).mean(-1)
        s_c = 0.0
        for attn_ptm, attn_frm in zip(attn_grah_PTM, attn_grah_FRM):
            s_c += (my_kl_loss(attn_ptm, attn_frm) + my_kl_loss(attn_frm, attn_ptm)) * self.temp
        raw = s_p * s_f * s_c
        score = torch.softmax(raw, dim=-1)
        # return label-based score and ori score
        return score, raw

    def train(self):
        """
        Run the training loop with early stopping on the mean epoch loss.

        The best weights are written to ``<self.path>/checkpoint.pth`` by the
        ``EarlyStopping`` helper (patience 5).
        """
        print("======================TRAIN MODE======================")
        early_stopping = EarlyStopping(patience=5, verbose=True, dataset_name=self.dataset)
        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(self.path, exist_ok=True)
        for epoch in range(self.num_epochs):
            loss_list, loss_kl_list = [], []
            self.model.train()
            for data, _ in tqdm(self.train_loader):
                self.optimizer.zero_grad()
                data = data.float().to(self.device)
                out = self.model(data)
                loss_t, loss_kl = self.cal_loss(out, data)
                loss_t.backward()
                self.optimizer.step()
                loss_list.append(loss_t.item())
                # float() accepts both a 0-dim tensor and the plain 0.0 that
                # cal_loss returns when the model yields no attention tensors.
                loss_kl_list.append(float(loss_kl))
            self.scheduler.step()
            loss_mean = np.average(loss_list)
            loss_kl_mean = np.average(loss_kl_list)
            print(
                "Epoch: {0} | Recon Loss: {1:.5f} KL Loss:{2:.5f}".format(
                    epoch + 1, loss_mean, loss_kl_mean))
            early_stopping(loss_mean, self.model, self.path)
            if early_stopping.early_stop:
                print("Early stopping")
                break

    def _collect_scores(self, loader, with_raw_and_labels=False):
        """
        Score every batch of ``loader`` without tracking gradients.

        Returns:
            ``(scores, raw, labels)`` as flattened NumPy arrays. ``raw`` and
            ``labels`` are ``None`` unless ``with_raw_and_labels`` is True.
        """
        scores, raw_scores, all_labels = [], [], []
        with torch.no_grad():
            for data, labels in loader:
                data = data.float().to(self.device)
                score, raw = self.cal_score(self.model(data), data)
                scores.append(score.detach().cpu().numpy())
                if with_raw_and_labels:
                    raw_scores.append(raw.detach().cpu().numpy())
                    all_labels.append(labels)
        scores = np.concatenate(scores, axis=0).reshape(-1)
        if not with_raw_and_labels:
            return scores, None, None
        raw_scores = np.concatenate(raw_scores, axis=0).reshape(-1)
        all_labels = np.concatenate(all_labels, axis=0).reshape(-1)
        return scores, raw_scores, all_labels

    @staticmethod
    def _point_adjust(pred, gt):
        """
        Mark a whole ground-truth anomaly segment as detected once any point
        inside it is predicted anomalous; returns an adjusted copy of ``pred``.
        """
        pred = np.array(pred)
        anomaly_state = False
        for i, label in enumerate(gt):
            if label == 1 and pred[i] == 1 and not anomaly_state:
                anomaly_state = True
                # Walk back to the start of the segment. The stop value is -1
                # so that index 0 is included when the segment begins there.
                for j in range(i, -1, -1):
                    if gt[j] == 0:
                        break
                    pred[j] = 1
                # Walk forward to the end of the segment.
                for j in range(i, len(gt)):
                    if gt[j] == 0:
                        break
                    pred[j] = 1
            elif label == 0:
                anomaly_state = False
            if anomaly_state:
                pred[i] = 1
        return pred

    def test(self):
        """
        Evaluate the best checkpoint on the test set.

        The threshold is the ``100 - anormly_ratio`` percentile of the softmax
        scores of the train and validation sets. Test predictions are
        point-adjusted before accuracy / precision / recall / F-score are
        computed; AUC-ROC and AUC-PR use the raw (pre-softmax) scores.

        Returns:
            ``(accuracy, precision, recall, f_score)``.
        """
        # map_location lets a checkpoint saved on one device load on another.
        self.model.load_state_dict(
            torch.load(
                os.path.join(self.path, 'checkpoint.pth'),
                map_location=self.device))
        self.model.eval()
        print("======================TEST MODE======================")
        # (1) statistics on the train set
        train_energy, _, _ = self._collect_scores(self.train_loader)

        # (2) find the threshold
        valid_energy, _, _ = self._collect_scores(self.valid_loader)
        combined_energy = np.concatenate([train_energy, valid_energy], axis=0)
        thresh = np.percentile(combined_energy, 100 - self.anormly_ratio)
        print("Threshold :", thresh)

        # (3) evaluation on the test set
        test_energy, test_energy_t, test_labels = self._collect_scores(
            self.test_loader, with_raw_and_labels=True)

        pred = (test_energy > thresh).astype(int)
        gt = test_labels.astype(int)
        print("pred:   ", pred.shape)
        print("gt:     ", gt.shape)
        pred = self._point_adjust(pred, gt)
        print("pred: ", pred.shape)
        print("gt:   ", gt.shape)

        from sklearn.metrics import precision_recall_fscore_support
        from sklearn.metrics import accuracy_score
        from sklearn.metrics import average_precision_score, roc_auc_score
        accuracy = accuracy_score(gt, pred)
        precision, recall, f_score, support = precision_recall_fscore_support(gt, pred,
                                                                              average='binary')
        auc_score = roc_auc_score(gt, test_energy_t)
        pr_score = average_precision_score(gt, test_energy_t)
        print(
            "Accuracy : {:0.4f}, Precision : {:0.4f}, Recall : {:0.4f}, F-score : {:0.4f} , AUC-ROC : {:0.4f}, AUC-PR : {:0.4f}".format(
                accuracy, precision,
                recall, f_score, auc_score, pr_score))

        return accuracy, precision, recall, f_score