import os
import copy
import numpy as np
import torch
import torchvision
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score, precision_recall_fscore_support, accuracy_score
from torch.utils.tensorboard import SummaryWriter



#Trainer for PLAD
class PLADTrainer:
    """Trainer for PLAD.

    Jointly trains a classifier ``model`` and a perturbator ``e_ae``. The
    perturbator produces additive and multiplicative noise that turns input
    samples into pseudo-anomalies, which the classifier is trained to label 0
    while the original samples keep their given target.
    """

    def __init__(self, model, e_ae, optimizer, lamda, device):
        """Initialize PLAD Trainer

        Parameters
        ----------
        model: Torch neural network object (the classifier). Expected to
            output one logit per sample, shaped (batch, 1).
        e_ae: Perturbator network. Called on the flattened input and must
            return ``(e, z_mean, z_sigma)``; the first half of ``e``'s columns
            is used as additive noise, the second half as multiplicative noise.
        optimizer: Optimizer stepped on the combined loss (presumably holding
            the parameters of both ``model`` and ``e_ae`` -- verify at caller).
        lamda: Weight of perturbator noise loss
        device: torch.device object for device to use.
        """
        self.model = model
        self.e_ae = e_ae
        self.optimizer = optimizer
        self.lamda = lamda
        self.device = device

    @staticmethod
    def _validate_metric(metric):
        """Raise ValueError unless ``metric`` is one of the supported names."""
        if metric not in ('AUC', 'F1'):
            raise ValueError(
                "metric must be 'AUC' or 'F1', got {!r}".format(metric))

    def train(self, train_loader, val_loader, learning_rate, total_epochs,
              metric='AUC'):
        """
        Training function

        Trains for ``total_epochs`` epochs, evaluates on ``val_loader`` after
        every epoch, and finally restores (in place) the classifier weights
        that achieved the best validation score.

        Parameters
        ----------
        train_loader: Dataloader object for the training dataset; must yield
            ``(data, target, _)`` triples.
        val_loader: Dataloader object for the validation dataset.
        learning_rate: Unused -- the learning rate is whatever the optimizer
            given to ``__init__`` was configured with. Kept for API
            compatibility.
        total_epochs: Total epochs for training.
        metric: Metric used for evaluation (AUC / F1).

        Returns
        -------
        The best validation score seen (``-inf`` if no epoch ran).

        Raises
        ------
        ValueError: if ``metric`` is unsupported or ``train_loader`` yields
            no batches.
        """
        # Fail fast instead of after a full epoch of training.
        self._validate_metric(metric)

        best_score = -np.inf
        best_state = None

        for epoch in range(total_epochs):
            avg_loss = self._train_epoch(train_loader)

            test_score = self.test(val_loader, metric)
            if test_score > best_score:
                best_score = test_score
                # Snapshot weights and buffers rather than the whole module,
                # so self.model stays the object the optimizer updates.
                best_state = copy.deepcopy(self.model.state_dict())
            print('Epoch: {}, Loss: {}, {}: {}'.format(epoch, avg_loss, metric, test_score))

        # Only restore when some epoch produced a snapshot; otherwise keep the
        # current model instead of replacing it with None.
        if best_state is not None:
            self.model.load_state_dict(best_state)
        print('\nBest test {}: {}'.format(
            metric, best_score
        ))
        return best_score

    def _train_epoch(self, train_loader):
        """Run one optimization pass over ``train_loader``; return the mean loss."""
        self.e_ae.train()
        self.model.train()

        total_loss = 0.0
        n_batches = 0
        for data, target, _ in train_loader:
            data = data.to(self.device).to(torch.float)
            # reshape(-1) rather than squeeze(): stays 1-D even for a batch of
            # size 1, matching the (batch,) logits below.
            target = target.to(self.device).to(torch.float).reshape(-1)

            self.optimizer.zero_grad()

            # Noise produced from perturbator, computed on the flattened input.
            e_x = data.view(data.size(0), -1)
            n_features = e_x.size(1)
            e, z_mean, z_sigma = self.e_ae(e_x)
            e1 = e[:, :n_features]  # additive noise, pulled towards 0
            e2 = e[:, n_features:]  # multiplicative noise, pulled towards 1

            # Noise loss
            e_loss1 = F.mse_loss(e1, torch.zeros_like(e_x))
            e_loss2 = F.mse_loss(e2, torch.ones_like(e_x))

            # KL-divergence of VAE
            kl_loss = self.latent_loss(z_mean, z_sigma)

            # Produce anomalies. The noise is reshaped back to the input's own
            # shape (this was previously hard-coded to 1x28x28).
            data_t = data * e2.reshape(data.shape) + e1.reshape(data.shape)

            # Cross-entropy loss of normal samples and anomalies (label 0)
            logits1 = torch.squeeze(self.model(data), dim=1)
            ce_loss1 = F.binary_cross_entropy_with_logits(logits1, target)
            logits2 = torch.squeeze(self.model(data_t), dim=1)
            ce_loss2 = F.binary_cross_entropy_with_logits(logits2, torch.zeros_like(target))

            loss = ce_loss1 + ce_loss2 + self.lamda * (e_loss2 + e_loss1) + kl_loss
            loss.backward()
            self.optimizer.step()

            # detach() so the running sum does not keep every batch's autograd
            # graph alive until the end of the epoch.
            total_loss += loss.detach()
            n_batches += 1

        if n_batches == 0:
            raise ValueError('train_loader yielded no batches')
        return (total_loss / n_batches).item()

    def test(self, test_loader, metric):
        """Evaluate the model

        Parameters
        ----------
        test_loader: Dataloader object for the test dataset; must yield
            ``(data, target, _)`` triples.
        metric: Metric used for evaluation (AUC / F1).

        Returns
        -------
        The AUC or F1 score computed from the classifier's raw logits.

        Raises
        ------
        ValueError: if ``metric`` is unsupported or ``test_loader`` is empty.
        """
        self._validate_metric(metric)
        self.model.eval()

        all_labels = []
        all_scores = []
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            for data, target, _ in test_loader:
                data = data.to(self.device).to(torch.float)
                target = target.to(self.device).to(torch.float).reshape(-1)

                # Raw logits (not sigmoid probabilities) are used as scores.
                logits = torch.squeeze(self.model(data), dim=1)
                all_labels.append(target.cpu().numpy())
                all_scores.append(logits.cpu().numpy())

        if not all_labels:
            raise ValueError('test_loader yielded no samples')
        # float64 matches the precision of the previous tolist()-based path.
        labels = np.concatenate(all_labels).astype(np.float64)
        scores = np.concatenate(all_scores).astype(np.float64)

        # Compute test score
        if metric == 'F1':
            # Evaluation based on https://openreview.net/forum?id=BJJLHbb0-
            # Samples scoring at or above the 20th percentile are predicted as
            # label 1, the remaining ones as 0.
            thresh = np.percentile(scores, 20)
            y_pred = np.where(scores >= thresh, 1, 0)
            _, _, test_metric, _ = precision_recall_fscore_support(
                labels, y_pred, average="binary")
        else:  # 'AUC'
            test_metric = roc_auc_score(labels, scores)

        return test_metric

    def latent_loss(self, z_mean, z_stddev):
        """Return the VAE KL term, summed over elements and averaged over the batch.

        NOTE: despite its name, ``z_stddev`` enters the formula as a
        log-variance: 0.5 * sum(exp(s) + m**2 - 1 - s) is the closed-form
        KL(N(m, exp(s)) || N(0, 1)) when ``s`` is the log-variance.
        """
        kl_divergence = 0.5 * torch.sum(torch.exp(z_stddev) + torch.pow(z_mean, 2) - 1. - z_stddev)
        return kl_divergence / z_mean.size(0)

    def save(self, path):
        """Save the classifier's state_dict to ``<path>/model.pt``."""
        torch.save(self.model.state_dict(), os.path.join(path, 'model.pt'))

    def load(self, path, filename):
        """Load a classifier state_dict from ``<path>/<filename>``.

        ``map_location`` places the stored tensors on ``self.device``, so a
        checkpoint written on a GPU can be loaded on a CPU-only machine.
        """
        state = torch.load(os.path.join(path, filename), map_location=self.device)
        self.model.load_state_dict(state)
