from tqdm import tqdm
import torch
import numpy as np
import torch.nn as nn
from sklearn.metrics import roc_auc_score

class Flow_Trainer():
    """Train a likelihood model and extract per-sample outputs on three data splits.

    The wrapped ``model`` is called as ``z, likelihood = model(x)``. Training
    minimises ``-mean(likelihood)``, so ``likelihood`` is presumably a
    per-sample log-likelihood (e.g. of a normalizing flow) — confirm against
    the model. ``extract_volume`` additionally requires a
    ``model.volume_extract(x)`` method.

    Every loader must yield ``(x, y)`` pairs; ``y`` is ignored. Inputs are cast
    with ``x.float()`` and moved to ``device`` before each forward pass.
    """

    def __init__(
        self,
        model,
        epochs,
        train_loader,
        test_normal_loader,
        test_anomaly_loader,
        optimizer,
        scheduler,
        device,
    ):
        """Store the model, data loaders and optimisation objects.

        Args:
            model: module returning ``(z, likelihood)`` from ``forward``.
            epochs: number of passes over ``train_loader`` made by ``fit``.
            train_loader: iterable of ``(x, y)`` training batches.
            test_normal_loader: iterable of ``(x, y)`` batches of normal test data.
            test_anomaly_loader: iterable of ``(x, y)`` batches of anomalous test data.
            optimizer: torch optimizer over ``model``'s parameters.
            scheduler: LR scheduler stepped once per epoch, or ``None`` to skip.
            device: target passed to ``Tensor.to`` for every input batch.
        """
        self.model = model
        self.epochs = epochs
        self.train_loader = train_loader
        self.test_normal_loader = test_normal_loader
        self.test_anomaly_loader = test_anomaly_loader
        self.device = device
        self.optimizer = optimizer
        self.scheduler = scheduler

    def _loaders(self):
        """Return the (train, test-normal, test-anomaly) loaders in that fixed order."""
        return self.train_loader, self.test_normal_loader, self.test_anomaly_loader

    def _collect(self, loader, per_batch):
        """Apply ``per_batch`` to every input batch of ``loader`` and flatten the results.

        Each result tensor is moved to the CPU and converted with ``tolist()``;
        the per-batch lists are concatenated, i.e. results are joined along the
        first (batch) dimension. Labels yielded by the loader are ignored.
        """
        out = []
        for x, _ in loader:
            x = x.float().to(self.device)
            out.extend(per_batch(x).cpu().detach().tolist())
        return out

    def fit(self):
        """Train the model for ``self.epochs`` epochs by minimising the mean NLL.

        The scheduler (if any) is stepped once at the end of every epoch.

        Returns:
            List with one float per epoch: the mean over mini-batches of the
            batch-mean ``likelihood`` output. An epoch with no batches is
            reported as ``nan``.
        """
        history = []
        for _ in range(self.epochs):
            self.model.train()
            total_likelihood = 0.0
            num_minibatches = 0

            for x, _ in self.train_loader:
                x = x.float().to(self.device)
                _, likelihood = self.model(x)
                loss = -torch.mean(likelihood)  # NLL

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                # .item() yields a plain float; accumulating the tensor itself would
                # keep chaining autograd history across the whole epoch.
                total_likelihood -= loss.item()
                num_minibatches += 1

            if self.scheduler is not None:
                self.scheduler.step()
            history.append(
                total_likelihood / num_minibatches if num_minibatches else float("nan")
            )
        return history

    def valid(self):
        """Collect the model's per-sample ``likelihood`` output on every split.

        Runs in eval mode with gradients disabled.

        Returns:
            Tuple ``(train, test_normal, test_anomaly)`` of flat Python lists, in
            loader order. Empty loaders give empty lists.
        """
        self.model.eval()
        with torch.no_grad():
            results = [
                self._collect(loader, lambda x: self.model(x)[1])
                for loader in self._loaders()
            ]
        return tuple(results)

    def extract_latent(self):
        """Return the model's ``z`` output for every sample of each split.

        Runs in eval mode with gradients disabled.

        Returns:
            Tuple of three tensors ``(train, test_normal, test_anomaly)``. Each is
            rebuilt with ``torch.tensor`` from the concatenated ``z.tolist()``
            rows, so it has torch's default dtype for float inputs.
        """
        self.model.eval()
        with torch.no_grad():
            latents = [
                self._collect(loader, lambda x: self.model(x)[0])
                for loader in self._loaders()
            ]
        return tuple(torch.tensor(rows) for rows in latents)

    def extract_volume(self):
        """Return ``model.volume_extract(x)`` for every sample of each split.

        Runs in eval mode with gradients disabled. The exact meaning of the
        values depends on ``volume_extract``; presumably a per-sample log-det
        term — confirm against the model.

        Returns:
            Tuple of three tensors ``(train, test_normal, test_anomaly)``, built
            with ``torch.tensor`` from the concatenated per-batch lists.
        """
        self.model.eval()
        with torch.no_grad():
            logdets = [
                self._collect(loader, lambda x: self.model.volume_extract(x))
                for loader in self._loaders()
            ]
        return tuple(torch.tensor(rows) for rows in logdets)
