from tqdm import tqdm
import torch
import numpy as np
import torch.nn as nn
import normflows as nf
from copy import deepcopy
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score
from util import *

class Flow_Trainer_cnn():
    """Train a multiscale (CNN) flow and extract likelihood features for OOD scoring.

    The model must provide ``forward_kld(x)`` (training loss) and
    ``inverse_and_log_det(x)`` returning ``(z_levels, log_det)``, where
    ``z_levels`` is indexable with at least three latent tensors.
    ``log_prob`` is expected to come from ``util`` via the star import —
    presumably a base-distribution log-density on flattened latents; confirm.
    """

    def __init__(self, model, epochs, in_train_loader, in_test_loader, out_test_loader,
                  optimizer, scheduler, device, in_datasetname, out_datasetname, n_dims):
        """Store the training/evaluation configuration.

        Args:
            model: flow model (see class docstring for the required methods).
            epochs: number of training epochs; ``1`` enables a two-step smoke test.
            in_train_loader: iterable of in-distribution training batches.
            in_test_loader: iterable of in-distribution evaluation batches.
            out_test_loader: iterable of out-of-distribution evaluation batches.
            optimizer: torch optimizer over the model parameters.
            scheduler: LR scheduler stepped once per epoch, or ``None``.
            device: device each batch is moved to before the model call.
            in_datasetname: label used when printing in-distribution metrics.
            out_datasetname: label used when printing OOD metrics.
            n_dims: per-sample dimensionality used to normalise bits per dim.
        """
        self.model = model
        self.epochs = epochs
        self.in_train_loader = in_train_loader
        self.in_test_loader = in_test_loader
        self.out_test_loader = out_test_loader
        self.device = device
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.in_datasetname = in_datasetname
        self.out_datasetname = out_datasetname
        self.n_dims = n_dims

    def fit(self):
        """Train for ``self.epochs`` epochs on ``in_train_loader``.

        A minibatch whose loss is NaN/Inf is reported and skipped (no backward,
        no optimizer step). ``nf.utils.update_lipschitz(model, n_iterations=3)``
        is called after every minibatch. With ``epochs == 1`` the loop stops
        after two successful steps (quick smoke-test mode).

        Returns:
            np.ndarray: mean finite training loss per epoch (NaN for an epoch
            in which no minibatch produced a finite loss).
        """
        self.model.train()
        history = []
        for i in tqdm(range(self.epochs)):
            loss_sum = 0.0
            minibatch_cnt = 0
            for x in self.in_train_loader:
                self.optimizer.zero_grad()
                loss = self.model.forward_kld(x.to(self.device))
                if torch.isfinite(loss):
                    loss.backward()
                    self.optimizer.step()
                    # Accumulate a Python float so no autograd graph is retained.
                    loss_sum += loss.item()
                    minibatch_cnt += 1
                else:
                    print("NaN/Inf Detected!")
                # Release this step's tensors (also on the NaN path) before the next batch.
                del x, loss
                # Smoke-test shortcut: a single-epoch run only takes two steps.
                if self.epochs == 1 and minibatch_cnt > 1:
                    break
                nf.utils.update_lipschitz(self.model, n_iterations=3)

            # Guard against an epoch where every minibatch was non-finite / loader empty.
            loss_mean = loss_sum / minibatch_cnt if minibatch_cnt else float("nan")
            print(f"({i+1}/{self.epochs}) /// Loss : {loss_mean}")
            if self.scheduler is not None:
                self.scheduler.step()
            history.append(loss_mean)
        return np.array(history, dtype=np.float64)

    def _collect_split(self, loader):
        """Run every batch of ``loader`` through the inverse flow without grad.

        Returns:
            tuple: ``(ll, z_ll, log_det, z_tensor, bpd_cum, n)`` where ``ll`` is the
            per-sample ``log_prob(z) + log_det`` list, ``z_ll`` the per-sample
            ``log_prob(z)`` list over all batches, ``log_det`` the per-sample
            log-determinant list, ``z_tensor`` the flattened latents stacked along
            dim 0 (left on the model's device), ``bpd_cum`` the summed bits per dim
            over non-NaN samples and ``n`` the number of non-NaN samples.
        """
        ll, zs, log_dets = [], [], []
        bpd_cum = 0
        n = 0
        with torch.no_grad():
            for x in loader:
                z_levels, log_det = self.model.inverse_and_log_det(x.to(self.device))
                # Flatten the first three latent levels per sample and join them.
                z = torch.cat(
                    [level.reshape(level.shape[0], -1)
                     for level in (z_levels[0], z_levels[1], z_levels[2])],
                    dim=1,
                )
                nll_np = -(log_prob(z) + log_det).cpu().detach().numpy()
                # detach(): the flow may enable grad internally; don't keep graphs alive.
                zs.append(z.detach())
                log_dets.append(log_det.detach())
                ll.extend((-nll_np).tolist())
                # NLL in nats -> bits per dim, plus 8 (= log2(256); presumably for
                # 8-bit data rescaled to [0, 1] — confirm). NaN samples are skipped.
                bpd_cum += np.nansum(nll_np / np.log(2) / self.n_dims + 8)
                n += len(x) - np.sum(np.isnan(nll_np))

            z_tensor = torch.cat(zs, dim=0)
            z_ll = log_prob(z_tensor).cpu().detach().tolist()
            log_det_list = torch.cat(log_dets).cpu().detach().tolist()
        return ll, z_ll, log_det_list, z_tensor, bpd_cum, n

    def extract_features(self):
        """Compute likelihood features for the in- and out-of-distribution test sets.

        Prints bits per dim for the OOD set only.

        Returns:
            tuple: ``(in_ll, in_z_ll, in_log_det, in_z_tensor,
            out_ll, out_z_ll, out_log_det, out_z_tensor)``.
        """
        self.model.eval()
        in_ll, in_z_ll, in_log_det, in_z_tensor, _, _ = self._collect_split(self.in_test_loader)
        out_ll, out_z_ll, out_log_det, out_z_tensor, bpd_cum, n = self._collect_split(
            self.out_test_loader
        )
        print(f'{self.out_datasetname} Bits per dim: ', bpd_cum / n)
        return in_ll, in_z_ll, in_log_det, in_z_tensor, out_ll, out_z_ll, out_log_det, out_z_tensor
    



class Flow_Trainer_realnvp():
    """Train a RealNVP-style flow and extract per-sample log-likelihoods for OOD scoring.

    The model must provide ``forward_kld(x)`` (training loss) and
    ``log_prob(x)`` (per-sample log-likelihood).
    """

    def __init__(self, model, epochs, in_train_loader, in_test_loader, out_test_loader,
                  optimizer, scheduler, device, in_datasetname, out_datasetname, n_dims):
        """Store the training/evaluation configuration.

        Args:
            model: flow model (see class docstring for the required methods).
            epochs: number of training epochs.
            in_train_loader: iterable of in-distribution training batches.
            in_test_loader: iterable of in-distribution evaluation batches.
            out_test_loader: iterable of out-of-distribution evaluation batches.
            optimizer: torch optimizer over the model parameters.
            scheduler: LR scheduler stepped once per epoch, or ``None``.
            device: device each batch is moved to before the model call.
            in_datasetname: in-distribution dataset label (stored, unused here).
            out_datasetname: OOD dataset label (stored, unused here).
            n_dims: per-sample dimensionality (stored, unused here).
        """
        self.model = model
        self.epochs = epochs
        self.in_train_loader = in_train_loader
        self.in_test_loader = in_test_loader
        self.out_test_loader = out_test_loader
        self.device = device
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.in_datasetname = in_datasetname
        self.out_datasetname = out_datasetname
        self.n_dims = n_dims

    def fit(self):
        """Train for ``self.epochs`` epochs on ``in_train_loader``.

        A minibatch whose loss is NaN/Inf is reported and skipped (no backward,
        no optimizer step). Progress is printed every 20 epochs.

        Returns:
            np.ndarray: mean finite training loss per epoch (NaN for an epoch
            in which no minibatch produced a finite loss).
        """
        self.model.train()
        history = []
        for i in tqdm(range(self.epochs)):
            loss_sum = 0.0
            minibatch_cnt = 0
            for x in self.in_train_loader:
                self.optimizer.zero_grad()
                loss = self.model.forward_kld(x.to(self.device))
                if torch.isfinite(loss):
                    loss.backward()
                    self.optimizer.step()
                    # Accumulate a Python float: summing the tensor itself would
                    # chain every minibatch's autograd graph for the whole epoch.
                    loss_sum += loss.item()
                    minibatch_cnt += 1
                else:
                    print("NaN/Inf Detected!")
                del x, loss

            # Guard against an epoch where every minibatch was non-finite / loader empty.
            loss_mean = loss_sum / minibatch_cnt if minibatch_cnt else float("nan")

            if (i+1) % 20 == 0:
                print(f"({i+1}/{self.epochs}) /// Loss : {loss_mean}")
            if self.scheduler is not None:
                self.scheduler.step()
            history.append(loss_mean)
        return np.array(history, dtype=np.float64)

    def _finite_log_likelihoods(self, loader):
        """Return ``model.log_prob`` for every sample of ``loader``, dropping NaN/Inf values."""
        values = []
        with torch.no_grad():
            for x in loader:
                values.extend(self.model.log_prob(x.to(self.device)).cpu().numpy().tolist())
        values = np.array(values)
        return values[np.isfinite(values)].tolist()

    def extract_features(self):
        """Compute finite per-sample log-likelihoods for the in- and OOD test sets.

        Returns:
            tuple: ``(in_ll, out_ll)`` as Python lists of floats.
        """
        self.model.eval()
        in_ll = self._finite_log_likelihoods(self.in_test_loader)
        out_ll = self._finite_log_likelihoods(self.out_test_loader)
        return in_ll, out_ll



class Flow_Trainer_nsf():
    """Train a neural-spline flow and extract likelihood features for OOD scoring.

    The model must provide ``forward_kld(x)``, ``log_prob(x)`` and
    ``inverse_and_log_det(x)``. ``log_prob`` (module-level) is expected to come
    from ``util`` via the star import — presumably a base-distribution
    log-density on latents; confirm.
    """

    def __init__(self, model, epochs, in_train_loader, in_test_loader, out_test_loader,
                  optimizer, scheduler, device, in_datasetname, out_datasetname, n_dims):
        """Store the training/evaluation configuration.

        Args:
            model: flow model (see class docstring for the required methods).
            epochs: number of training epochs.
            in_train_loader: iterable of in-distribution training batches.
            in_test_loader: iterable of in-distribution evaluation batches.
            out_test_loader: iterable of out-of-distribution evaluation batches.
            optimizer: torch optimizer over the model parameters.
            scheduler: LR scheduler stepped once per epoch, or ``None``.
            device: device each batch is moved to before the model call.
            in_datasetname: label used when printing in-distribution metrics.
            out_datasetname: label used when printing OOD metrics.
            n_dims: per-sample dimensionality used to normalise bits per dim.
        """
        self.model = model
        self.epochs = epochs
        self.in_train_loader = in_train_loader
        self.in_test_loader = in_test_loader
        self.out_test_loader = out_test_loader
        self.device = device
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.in_datasetname = in_datasetname
        self.out_datasetname = out_datasetname
        self.n_dims = n_dims

    def fit(self):
        """Train for ``self.epochs`` epochs on ``in_train_loader``.

        A minibatch whose loss is NaN/Inf is reported and skipped (no backward,
        no optimizer step). Progress is printed every 20 epochs.

        Returns:
            np.ndarray: mean finite training loss per epoch (NaN for an epoch
            in which no minibatch produced a finite loss).
        """
        self.model.train()
        history = []
        for i in tqdm(range(self.epochs)):
            loss_sum = 0.0
            minibatch_cnt = 0
            for x in self.in_train_loader:
                self.optimizer.zero_grad()
                loss = self.model.forward_kld(x.to(self.device))
                if torch.isfinite(loss):
                    loss.backward()
                    self.optimizer.step()
                    # Accumulate a Python float: summing the tensor itself would
                    # chain every minibatch's autograd graph for the whole epoch.
                    loss_sum += loss.item()
                    minibatch_cnt += 1
                else:
                    print("NaN/Inf Detected!")
                del x, loss

            # Guard against an epoch where every minibatch was non-finite / loader empty.
            loss_mean = loss_sum / minibatch_cnt if minibatch_cnt else float("nan")

            if (i+1) % 20 == 0:
                print(f"({i+1}/{self.epochs}) /// Loss : {loss_mean}")
            if self.scheduler is not None:
                self.scheduler.step()
            history.append(loss_mean)
        return np.array(history, dtype=np.float64)

    def _collect_split(self, loader):
        """Evaluate the flow on every batch of ``loader`` without grad.

        Returns:
            tuple: ``(ll, z_ll, log_det, z_tensor, bpd_cum, n)`` where ``ll`` is the
            per-sample ``model.log_prob`` list, ``z_ll`` the per-sample module-level
            ``log_prob(z)`` list, ``log_det`` the per-sample log-determinant list,
            ``z_tensor`` the latents stacked along dim 0 (left on the model's
            device), ``bpd_cum`` the summed bits per dim over non-NaN samples and
            ``n`` the number of non-NaN samples.
        """
        ll, zs, log_dets = [], [], []
        bpd_cum = 0
        n = 0
        with torch.no_grad():
            for x in loader:
                # Transfer once and reuse for both model passes.
                x_dev = x.to(self.device)
                ll_np = self.model.log_prob(x_dev).cpu().numpy()
                z, log_det = self.model.inverse_and_log_det(x_dev)

                zs.append(z)
                log_dets.append(log_det)
                ll.extend(ll_np.tolist())

                # NLL in nats -> bits per dim, plus 8 (= log2(256); presumably for
                # 8-bit data rescaled to [0, 1] — confirm). NaN samples are skipped.
                bpd_cum += np.nansum(-ll_np / np.log(2) / self.n_dims + 8)
                n += len(x) - np.sum(np.isnan(ll_np))

            z_tensor = torch.cat(zs, dim=0)
            z_ll = log_prob(z_tensor).cpu().detach().tolist()
            log_det_list = torch.cat(log_dets).cpu().detach().tolist()
        return ll, z_ll, log_det_list, z_tensor, bpd_cum, n

    def extract_features(self):
        """Compute likelihood features for the in- and out-of-distribution test sets.

        Prints bits per dim for both sets.

        Returns:
            tuple: ``(in_ll, in_z_ll, in_log_det, in_z_tensor,
            out_ll, out_z_ll, out_log_det, out_z_tensor)``.
        """
        self.model.eval()
        in_ll, in_z_ll, in_log_det, in_z_tensor, bpd_cum, n = self._collect_split(
            self.in_test_loader
        )
        print(f'{self.in_datasetname} Bits per dim: ', bpd_cum / n)

        out_ll, out_z_ll, out_log_det, out_z_tensor, bpd_cum, n = self._collect_split(
            self.out_test_loader
        )
        print(f'{self.out_datasetname} Bits per dim: ', bpd_cum / n)

        return in_ll, in_z_ll, in_log_det, in_z_tensor, out_ll, out_z_ll, out_log_det, out_z_tensor

