import os
import copy
import time

import numpy as np
import torch
import torchvision
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score, precision_recall_fscore_support, accuracy_score
from torch.utils.tensorboard import SummaryWriter


def f1_calculator(targets, score):
    """Score a ranking by the fraction of true positives among the top-k.

    ``k`` is the number of samples labelled ``1`` in ``targets``. The ``k``
    samples with the highest ``score`` are treated as predicted positives,
    and the returned value is ``tp / (tp + fp)`` over those predictions.
    With exactly ``k`` predicted positives, precision, recall and F1
    coincide, which is why the value is reported as F1.

    Parameters
    ----------
    targets: 1-D array-like of labels, where 1 marks a positive sample and
        0 a negative one.
    score: 1-D array-like of scores, aligned with ``targets``; higher means
        "more likely positive".

    Returns
    -------
    float in [0, 1]. Returns 0.0 when ``targets`` has no positive sample
    (nothing can be ranked), instead of dividing by zero.
    """
    targets = np.asarray(targets)
    score = np.asarray(score)

    # Number of positives decides how many top-scored samples are selected.
    n_positive = int(np.count_nonzero(targets == 1))
    if n_positive == 0:
        return 0.0

    top_labels = targets[np.argsort(-score)[:n_positive]]
    n_tp = int(np.count_nonzero(top_labels == 1))
    n_fp = int(np.count_nonzero(top_labels == 0))

    # Only labels equal to 0 or 1 are counted; guard the case where the
    # selected samples carry neither label.
    selected = n_tp + n_fp
    if selected == 0:
        return 0.0
    return n_tp / selected


#Trainer for PLAD
class PLADTrainer:
    """Train a binary classifier together with a perturbator network.

    For every batch the perturbator ``e_ae`` produces an additive term and a
    multiplicative term from the flattened input. The perturbed samples
    ``x * scale + shift`` are fed to ``model`` with label 1, next to the
    original samples with their own labels. The perturbator is regularised
    so that the additive term stays close to 0 and the multiplicative term
    stays close to 1, plus a KL term on its latent statistics.
    """

    def __init__(self, model, e_ae, optimizer, lamda, device):
        """Initialize PLAD Trainer

        Parameters
        ----------
        model: Torch neural network object producing one logit per sample
            (output is squeezed along dim 1).
        e_ae: Perturbator network. Called on the flattened batch and must
            return ``(e, z_mean, z_sigma)`` where the first half of ``e``
            along dim 1 is the additive noise and the second half the
            multiplicative noise.
        optimizer: Optimizer that is stepped once per batch. Presumably it
            holds the parameters of both ``model`` and ``e_ae`` -- verify
            against the caller.
        lamda: Weight of perturbator noise loss
        device: torch.device object for device to use.
        """
        self.model = model
        self.e_ae = e_ae
        self.optimizer = optimizer
        self.lamda = lamda
        self.device = device

    def train(self, train_loader, total_epochs):
        """
        Training function
        Parameters
        ----------
        train_loader: Iterable yielding ``(data, target)`` batches of the
            training dataset.
        total_epochs: Total epochs for training.

        Returns
        -------
        List with one float per epoch: the mean over batches of the total
        batch loss (classification loss + noise loss). An epoch without any
        batch contributes 0.0.
        """
        epoch_losses = []
        for _ in range(total_epochs):
            self.e_ae.train()
            self.model.train()

            running_loss = 0.0
            n_batches = 0
            for data, target in train_loader:
                running_loss += self._train_step(data, target)
                n_batches += 1

            # max(..., 1) keeps an empty loader from dividing by zero.
            epoch_losses.append(float(running_loss) / max(n_batches, 1))
        return epoch_losses

    def _train_step(self, data, target):
        """Run one optimisation step on a batch and return its detached loss."""
        data = data.to(self.device, dtype=torch.float)
        # reshape(-1) instead of squeeze(): squeeze() turns a single-sample
        # batch into a 0-d tensor, which no longer matches the (1,) logits.
        target = target.to(self.device, dtype=torch.float).reshape(-1)

        flat = data.reshape(data.size(0), -1)
        n_features = flat.size(1)

        self.optimizer.zero_grad()

        # Noise produced from perturbator
        e, z_mean, z_sigma = self.e_ae(flat)
        shift = e[:, :n_features]
        scale = e[:, n_features:]

        # Produce anomalies. The perturbation is applied on the flattened
        # view and reshaped back so inputs with more than 2 dims also work.
        perturbed = (flat * scale + shift).reshape(data.shape)

        # Noise loss: keep the shift near 0 and the scale near 1.
        shift_loss = torch.norm(shift)
        scale_loss = torch.norm(scale - 1.0)
        # Kl-divergence of VAE
        kl_loss = self.latent_loss(z_mean, z_sigma)
        noise_loss = self.lamda * (scale_loss + shift_loss) + kl_loss

        # Cross-entropy loss of normal samples and anomalies
        logits_normal = torch.squeeze(self.model(data), dim=1)
        logits_anomaly = torch.squeeze(self.model(perturbed), dim=1)
        ce_normal = F.binary_cross_entropy_with_logits(logits_normal, target)
        ce_anomaly = F.binary_cross_entropy_with_logits(
            logits_anomaly, torch.ones_like(target))

        # A single backward over the summed loss yields the same gradients
        # as two separate passes, without retaining the graph in between.
        batch_loss = ce_normal + ce_anomaly + noise_loss
        batch_loss.backward()

        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
        torch.nn.utils.clip_grad_norm_(self.e_ae.parameters(), 1.0)
        self.optimizer.step()

        # Detach so the running sum does not hold on to autograd history.
        return batch_loss.detach()

    def test(self, test_loader):
        """Score every sample of a dataset with the classifier

        Parameters
        ----------
        test_loader: Iterable yielding ``(data, target)`` batches of the
            test dataset. The targets are not used.

        Returns
        -------
        1-D float numpy array with the raw logit (no sigmoid applied) of
        each sample, in loader order. Empty when the loader has no batch.
        """
        self.model.eval()
        scores = []
        with torch.no_grad():
            for data, _ in test_loader:
                data = data.to(self.device, dtype=torch.float)
                logits = torch.squeeze(self.model(data), dim=1)
                scores.extend(logits.cpu().tolist())
        return np.array(scores, dtype=float)

    def latent_loss(self, z_mean, z_stddev):
        """Return the KL term of the latent statistics, averaged per sample.

        Computes ``0.5 * sum(exp(z_stddev) + z_mean**2 - 1 - z_stddev)``
        divided by the batch size ``z_mean.size(0)``.

        NOTE(review): despite its name, ``z_stddev`` enters the formula the
        way a log-variance would (it is exponentiated and subtracted) --
        confirm against what the perturbator actually outputs.
        """
        kl_divergence = 0.5 * torch.sum(
            torch.exp(z_stddev) + torch.pow(z_mean, 2) - 1. - z_stddev)
        return kl_divergence / z_mean.size(0)

    def save(self, path):
        """Write the classifier weights to ``<path>/model.pt``.

        The directory is created when missing. Only ``self.model`` is
        saved; the perturbator is not.
        """
        os.makedirs(path, exist_ok=True)
        torch.save(self.model.state_dict(), os.path.join(path, 'model.pt'))

    def load(self, path, filename):
        """Load classifier weights from ``<path>/<filename>`` into the model.

        Tensors are mapped onto ``self.device`` so that a checkpoint written
        on another device (e.g. a GPU) can be restored here.
        """
        state_dict = torch.load(os.path.join(path, filename),
                                map_location=self.device)
        self.model.load_state_dict(state_dict)
