import torch
import torch.nn as nn
import numpy as np
import os
import time
from utils.utils import EarlyStopping, Color, adjust_learning_rate, point_adjustment
from sklearn.metrics import precision_recall_fscore_support, accuracy_score
from data_factory.data_loader import get_loader_segment
from models.KANomaly import KANomaly
import pkbar
from torchinfo import summary


class Exp_AD(object):
    """Experiment driver for reconstruction-based anomaly detection with KANomaly.

    The constructor copies every entry of ``config`` onto the instance, builds
    the train / validation / test data loaders, the model and its optimizer.
    ``train()`` fits the model, ``vali()`` computes the mean reconstruction
    loss over a loader, and ``test()`` thresholds reconstruction errors and
    reports detection metrics.

    Attributes read from ``config`` by this class: ``data_path``,
    ``batch_size``, ``win_size``, ``dataset``, ``input_c``, ``patch_sizes``,
    ``gridsize``, ``num_layers``, ``norm``, ``drop``, ``lr``, ``num_epochs``,
    ``model_save_path`` and ``anomaly_ratio``.
    """

    # Baseline settings merged into the instance before ``config``; entries in
    # ``config`` take precedence over these.
    DEFAULTS = {}

    def __init__(self, config):
        """Set up loaders, model, optimizer, device and loss from ``config``.

        Args:
            config: Mapping of setting name to value; each key becomes an
                instance attribute.
        """
        self.__dict__.update(Exp_AD.DEFAULTS, **config)
        # Resolve the device before build_model() so the model is placed on
        # exactly the device the input batches are later moved to.
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.train_loader = self._make_loader('train')
        self.vali_loader = self._make_loader('val')
        self.test_loader = self._make_loader('test')
        self.build_model()
        self.criterion = nn.MSELoss()

    def _make_loader(self, mode):
        """Build the data loader for the split named ``mode``."""
        return get_loader_segment(self.data_path,
                                  batch_size=self.batch_size,
                                  win_size=self.win_size,
                                  mode=mode,
                                  dataset=self.dataset)

    def _autocast(self):
        """Return a CUDA autocast context that is disabled when running on CPU."""
        return torch.autocast(device_type="cuda",
                              enabled=self.device.type == "cuda")

    def build_model(self):
        """Create the KANomaly model on ``self.device`` and its Adam optimizer."""
        self.model = KANomaly(win_size=self.win_size,
                              num_chn=self.input_c,
                              patch_sizes=self.patch_sizes,
                              gridsize=self.gridsize,
                              num_layers=self.num_layers,
                              norm=self.norm,
                              drop=self.drop)
        # Move the model first, then build the optimizer over its parameters.
        self.model.to(self.device)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)

    def vali(self, vali_loader):
        """Return the mean reconstruction loss of the model over ``vali_loader``.

        The model is switched to eval mode and gradients are disabled.

        Args:
            vali_loader: Iterable yielding ``(input_data, labels)`` pairs; the
                labels are ignored.

        Returns:
            Average of the per-batch ``self.criterion`` values, or NaN when
            the loader yields no batches.
        """
        self.model.eval()
        batch_losses = []
        with torch.no_grad(), self._autocast():
            for input_data, _ in vali_loader:
                batch = input_data.float().to(self.device)
                output = self.model(batch)
                batch_losses.append(self.criterion(output, batch).item())
        if not batch_losses:
            # Same value np.average would give, without the empty-slice warning.
            return float("nan")
        return np.average(batch_losses)

    def train(self):
        """Fit the model on the training loader with early stopping.

        Each epoch runs one pass over ``self.train_loader`` with gradient
        scaling and gradient-norm clipping, evaluates on ``self.vali_loader``,
        hands the validation loss to ``EarlyStopping`` and then calls
        ``adjust_learning_rate``.
        """
        print(Color.RED + "======================TRAIN MODE======================" + Color.RESET)
        # One batch is drawn only to give torchinfo the input shape.
        sample_batch, _ = next(iter(self.train_loader))
        summary(self.model,
                sample_batch.shape,
                col_width=20,
                depth=4,
                row_settings=('depth', 'var_names', 'hide_recursive_layers'),
                col_names=["input_size", "kernel_size", "output_size", "num_params", "params_percent"],
                verbose=1)
        path = self.model_save_path
        os.makedirs(path, exist_ok=True)
        early_stopping = EarlyStopping(patience=3, verbose=True, dataset_name=self.dataset)
        train_steps = len(self.train_loader)
        # Gradient scaling only matters for CUDA mixed precision.
        scaler = torch.cuda.amp.GradScaler(enabled=self.device.type == "cuda")
        for epoch in range(self.num_epochs):
            # Fresh progress bar for every epoch.
            kbar = pkbar.Kbar(target=train_steps,
                              epoch=epoch,
                              num_epochs=self.num_epochs,
                              width=40,
                              always_stateful=True)
            loss_list = []
            epoch_time = time.time()
            self.model.train()
            for i, (input_data, _) in enumerate(self.train_loader):
                self.optimizer.zero_grad()
                with self._autocast():
                    batch = input_data.float().to(self.device)
                    output = self.model(batch)
                    loss = self.criterion(output, batch)
                # NOTE(review): retain_graph=True is kept from the original;
                # it looks unnecessary for a single backward pass per batch,
                # but the model code is not visible here — confirm before
                # removing.
                scaler.scale(loss).backward(retain_graph=True)
                # Unscale before clipping so max_norm applies to true gradients.
                scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1)
                scaler.step(self.optimizer)
                scaler.update()
                # Log a plain float so the progress bar never holds a tensor
                # that is still attached to the autograd graph.
                loss_value = loss.item()
                loss_list.append(loss_value)
                kbar.update(i, values=[("loss", loss_value)])
            train_time = time.time() - epoch_time
            train_loss = np.average(loss_list)
            vali_loss = self.vali(self.vali_loader)
            # Close the epoch's bar with the aggregated metrics.
            kbar.add(1, values=[("Train Loss", train_loss), ("Val Loss", vali_loss)])
            print("Epoch: {} cost time: {:.4f}".format(epoch + 1, train_time),"sec")
            early_stopping(vali_loss, self.model, path)
            if early_stopping.early_stop:
                print("Early stopping")
                break
            adjust_learning_rate(self.optimizer, epoch + 1, self.lr)

    def _reconstruction_energy(self, loader, criterion, collect_labels=False):
        """Score every element of ``loader`` by its reconstruction error.

        The score is the element-wise ``criterion`` value averaged over the
        last dimension (presumably the channel axis — TODO confirm against
        the data loader), flattened to one dimension.

        Args:
            loader: Sized iterable yielding ``(input_data, labels)`` pairs.
            criterion: Loss that returns un-reduced, element-wise values.
            collect_labels: When true, also gather and flatten the labels.

        Returns:
            ``(energy, labels)`` where ``energy`` is a 1-D NumPy array and
            ``labels`` is a 1-D NumPy array, or ``None`` when
            ``collect_labels`` is false.
        """
        kbar = pkbar.Kbar(target=len(loader), width=40)
        energy_chunks, label_chunks = [], []
        with torch.no_grad(), self._autocast():
            for i, (input_data, labels) in enumerate(loader):
                batch = input_data.float().to(self.device)
                output = self.model(batch)
                score = torch.mean(criterion(output, batch), dim=-1)
                energy_chunks.append(score.detach().cpu().numpy())
                if collect_labels:
                    label_chunks.append(np.asarray(labels))
                kbar.update(i)
                kbar.add(1)
        energy = np.concatenate(energy_chunks, axis=0).reshape(-1)
        if not collect_labels:
            return energy, None
        return energy, np.concatenate(label_chunks, axis=0).reshape(-1)

    def test(self):
        """Evaluate the saved checkpoint on the test loader.

        Loads ``<dataset>_checkpoint.pth`` from ``self.model_save_path``,
        derives a threshold as the ``100 - anomaly_ratio`` percentile of the
        combined train and validation reconstruction energies, labels test
        elements above it as anomalies, applies ``point_adjustment`` and
        prints the resulting metrics.

        Returns:
            Tuple ``(accuracy, precision, recall, f_score)``.
        """
        checkpoint_path = os.path.join(str(self.model_save_path),
                                       str(self.dataset) + '_checkpoint.pth')
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        self.model.load_state_dict(
            torch.load(checkpoint_path, map_location=self.device))
        self.model.eval()
        print(Color.RED + "======================TEST MODE======================" + Color.RESET)
        # Element-wise errors are needed; they are reduced per element later.
        criterion = nn.MSELoss(reduction='none')

        # (1) statistics on the train set
        print("stastic on the train set...")
        train_energy, _ = self._reconstruction_energy(self.train_loader, criterion)

        # (2) find the threshold from train + validation energies
        print("find the threshold...")
        vali_energy, _ = self._reconstruction_energy(self.vali_loader, criterion)
        combined_energy = np.concatenate([train_energy, vali_energy], axis=0)
        thresh = np.percentile(combined_energy, 100 - self.anomaly_ratio)
        print("Threshold :", thresh)

        # (3) evaluation on the test set
        print("evaluation on the test set...")
        test_energy, test_labels = self._reconstruction_energy(
            self.test_loader, criterion, collect_labels=True)
        pred = (test_energy > thresh).astype(int)
        gt = test_labels.astype(int)

        # (4) detection adjustment
        gt, pred = point_adjustment(gt, pred)
        pred = np.array(pred)
        gt = np.array(gt)
        print("pred: ", pred.shape)
        print("gt:   ", gt.shape)
        accuracy = accuracy_score(gt, pred)
        precision, recall, f_score, _ = precision_recall_fscore_support(gt, pred, average='binary')
        print("Accuracy : {:0.4f}, Precision : {:0.4f}, Recall : {:0.4f}, F-score : {:0.4f} ".format(
                accuracy, precision, recall, f_score))
        return accuracy, precision, recall, f_score