import nflows.distributions.normal
import torch
import numpy as np
import pandas as pd

from torch import optim
from nflows.flows.base import Flow
from torch.utils.data import DataLoader
from torch.distributions.normal import Normal
from torch.distributions.studentT import StudentT
from nflows.transforms.base import CompositeTransform
from nflows.transforms.autoregressive import MaskedAffineAutoregressiveTransform
from nflows.transforms.normalization import BatchNorm
from nflows.transforms.lu import LULinear

import os, os.path
import sys
import inspect
from sklearn.model_selection import train_test_split

from synthetic_experiments.data_generators import copula_generator

# Put the parent of this file's directory at the front of sys.path so that the
# `utils` imports on the following lines are resolved from there.
currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
parentdir = os.path.dirname(currentdir)
sys.path.insert(0, parentdir)
from utils.distributions import tDist, norm_tDist
from utils.tail_permutation import TailRandomPermutation, RandomPermutation

# Name of the device every tensor/module in this file is moved to ("cuda" when
# available, otherwise "cpu"). A bare `torch.device(...)` expression has no
# effect on its own, so only the string used by `.to(device)` and
# `map_location=device` is kept.
device = "cuda" if torch.cuda.is_available() else "cpu"

class mTAF:
    def __init__(self, data, num_heavy, df, num_layers=5, num_hidden=200, num_blocks=1, batch_norm=False, dropout_prob=0.0, batch_size=256, model="maf", linear_layer="permutation", model_nr=0, track_results=False):
        """Store the hyper-parameters and generate the synthetic train/val/test data.

        Args:
            data: Converted with ``int()`` and used as the dimensionality ``D``
                handed to ``copula_generator``.
            num_heavy, df: Forwarded to ``copula_generator`` and used in the
                checkpoint file name.
            num_layers, num_hidden, num_blocks, batch_norm, dropout_prob:
                Architecture settings stored for ``config``.
            batch_size: Batch size stored for the data loaders built in ``train``.
            model, linear_layer: Stored as given.
            model_nr: Identifier written next to each tracked loss value.
            track_results: If true, ``train`` appends losses to result files.
        """
        # Dimensionality; ``self.data`` itself is left as an empty string.
        self.D = int(data)
        self.data = ""

        # Architecture and optimisation settings.
        self.num_layers = num_layers
        self.num_hidden = num_hidden
        self.num_blocks = num_blocks
        self.batch_norm = batch_norm
        self.dropout = dropout_prob
        self.batch_size = batch_size
        self.model = model
        self.linear_layer = linear_layer

        # Bookkeeping for result tracking.
        self.model_nr = model_nr
        self.track_results = track_results

        # Filled in later by ``config``.
        self.marginals = []
        self.heavy_tailed = []
        self.list_tailindexes = []

        # 1. Generate the data: 3/5 of the samples are held out, and 2/3 of the
        #    hold-out part becomes the test set (the rest is validation data).
        sample_count = 25000  # 10000, 30000, or 5000
        generated = copula_generator(self.D, num_heavy, df).get_data(sample_count)
        self.data_train, holdout = train_test_split(generated, test_size=3/5)
        self.data_val, self.data_test = train_test_split(holdout, test_size=2/3)

        self.PATH_model = f"models/mtaf_df{df}h{num_heavy}"

    def tail_estimation(self):
        """Run the external tail estimator once per marginal of the validation data.

        For every dimension the absolute validation values are written, together
        with a constant helper column of ones, to ``data/marginal<j>.dat``; then
        ``../utils/tail_estimation.py`` is invoked through the shell with the
        estimator output path ``data/tail_estimator<j>.txt`` (1-based ``j``).
        The exit status of the script is not inspected.
        """
        for dim in range(self.D):
            abs_marginal = np.abs(self.data_val[:, dim])
            table = pd.DataFrame({
                "data": abs_marginal,
                "helper": np.repeat(1, len(abs_marginal)),
            })

            marginal_stem = f"data/marginal{dim + 1}"
            estimator_file = f"data/tail_estimator{dim + 1}.txt"
            np.savetxt(f"{marginal_stem}.dat", table.values, fmt=["%10.5f", "%d"])

            # Single-space separated command line for the estimator script.
            command = " ".join([
                "python ../utils/tail_estimation.py",
                f"{marginal_stem}.dat",
                f"{marginal_stem}_results.pdf",
                "--noise 0",
                "--path_estimator",
                estimator_file,
            ])
            os.system(command)

    def config(self, PATH=""):
        """Read the estimated tail indices, build the flow and reorder the data columns.

        Args:
            PATH: Prefix of the tail-estimator files; the file for marginal ``j``
                (1-based) is ``PATH + str(j) + ".txt"``. The empty string selects
                the default prefix ``"data/tail_estimator"``.

        Side effects: appends to ``self.marginals``, ``self.heavy_tailed`` and
        ``self.list_tailindexes``; sets ``self.permutation``, ``self.inv_perm``,
        ``self.marginals_permuted``, ``self.tail_index_permuted``,
        ``self.base_dist``, ``self.transform`` and ``self.flow``; and permutes the
        columns of the train/val/test arrays.
        """
        # 1. Get the Marginals
        prefix = "data/tail_estimator" if PATH == "" else PATH
        for dim in range(self.D):
            tail_index = np.loadtxt(prefix + str(dim + 1) + ".txt")

            # set tail_index > 20 to light-tailed distribution
            if tail_index > 20:
                tail_index = 0
            self.list_tailindexes.append(tail_index)

            if tail_index == 0:
                marginal = Normal(torch.tensor([0.0]).to(device), torch.tensor([1.0]).to(device))
                is_heavy = False
                message = "{}th Marginal is detected as light-tailed.".format(dim + 1)
            else:
                marginal = StudentT(torch.tensor(tail_index).to(device), torch.tensor([0.0]).to(device),
                                    torch.tensor([1.0]).to(device))
                is_heavy = True
                message = "{}th Marginal is detected as heavy-tailed with tail-index {}.".format(dim + 1, tail_index)
            self.marginals.append(marginal)
            self.heavy_tailed.append(is_heavy)
            print(message)

        num_heavy = np.sum(np.array(self.heavy_tailed))
        num_light = self.D - num_heavy

        # 2. Reorder the Marginals: sorting the boolean flags puts light-tailed
        #    (False) components before heavy-tailed (True) ones.
        self.permutation = np.argsort(np.array(self.heavy_tailed))
        self.inv_perm = np.zeros(self.D, dtype=np.int32)  # for reordering
        self.inv_perm[self.permutation] = np.arange(self.D)
        self.marginals_permuted = [self.marginals[idx] for idx in self.permutation]

        self.tail_index_permuted = np.array(self.list_tailindexes)[self.permutation]

        self.base_dist = norm_tDist([self.D], self.tail_index_permuted)

        # Every layer is: tail-aware permutation -> masked autoregressive -> batch norm.
        transforms = []
        for _ in range(self.num_layers):
            transforms.extend([
                TailRandomPermutation(num_light, num_heavy),
                MaskedAffineAutoregressiveTransform(features=self.D,
                                                    hidden_features=self.num_hidden,
                                                    num_blocks=self.num_blocks,
                                                    use_batch_norm=self.batch_norm,
                                                    dropout_probability=self.dropout),
                BatchNorm(features=self.D),
            ])
        self.transform = CompositeTransform(transforms)

        self.flow = Flow(self.transform, self.base_dist).to(device)

        # Bring the data columns into the same (light first, heavy last) order.
        column_order = self.permutation
        self.data_train = self.data_train[:, column_order]
        self.data_test = self.data_test[:, column_order]
        self.data_val = self.data_val[:, column_order]

    def train(self, num_epochs=200, lr=1e-5, lr_wd=1e-6, lr_df=0.1, patience=30, cosine_annealing=False):
        """Train the flow by minimising the negative log-likelihood, with early stopping.

        After every epoch the validation loss is computed on one randomly drawn
        validation batch and rounded to two decimals. Whenever it is at least as
        good as the best value so far, the weights are written to
        ``self.PATH_model``; if training ends on a non-improving epoch those best
        weights are loaded back. Finally the mean test loss is printed and, when
        ``self.track_results`` is set, appended to the results file.

        Args:
            num_epochs: Maximum number of passes over ``self.data_train``.
            lr: Adam learning rate for the transform and embedding-net parameters.
            lr_wd: Value passed to Adam as ``weight_decay``.
            lr_df: Adam learning rate for the base-distribution parameters. A value
                of ``0.0`` marks the run as "mTAF_fix" and adds the suffix "_fix"
                to ``self.PATH_model``.
            patience: Number of consecutive non-improving epochs that are
                tolerated; training stops on the next one.
            cosine_annealing: If true, anneal the learning rates with
                ``CosineAnnealingLR`` (``T_max=num_epochs``, ``eta_min=0``).
        """
        optimizer = optim.Adam([{"params": self.flow._distribution.parameters(), "lr": lr_df},
                                {"params": self.flow._transform.parameters()},
                                {"params": self.flow._embedding_net.parameters()}
                                ], lr=lr, weight_decay=lr_wd)
        self.lr_df = lr_df
        if self.lr_df == 0.0:
            self.mtaf_type = "mTAF_fix"
            # Guarded so that calling train() twice does not produce "..._fix_fix".
            if not self.PATH_model.endswith("_fix"):
                self.PATH_model = self.PATH_model + "_fix"
        else:
            self.mtaf_type = "mTAF"

        if cosine_annealing:
            scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epochs, 0)
        else:
            scheduler = None

        train_dataloader = DataLoader(self.data_train, batch_size=self.batch_size, shuffle=True)
        test_dataloader = DataLoader(self.data_test, batch_size=self.batch_size, shuffle=True)
        val_dataloader = DataLoader(self.data_val, batch_size=self.batch_size, shuffle=True)

        self.ls_dfs = []
        best_val = np.inf
        have_checkpoint = False
        counter_es = 0
        counter_discarded_batches = 0
        epochs_run = 0  # stays 0 when num_epochs == 0, so the final report still works
        for e in range(num_epochs):
            epochs_run = e + 1
            loss_trainepoch = []
            self.flow.train()
            for batch in train_dataloader:
                # The loader already yields tensors; as_tensor converts the dtype
                # without the copy-construct warning of torch.tensor(tensor).
                x = torch.as_tensor(batch, dtype=torch.float32).to(device)
                optimizer.zero_grad()
                loss = -self.flow.log_prob(inputs=x).mean()
                # Skip the update for inf *and* NaN losses: back-propagating
                # either would corrupt the parameters.
                if torch.isfinite(loss):
                    loss.backward()
                    optimizer.step()
                else:
                    counter_discarded_batches += 1
                    print("Discarded Batch Nr " + str(counter_discarded_batches))

                # Discarded batches still count towards the reported epoch loss.
                loss_trainepoch.append(loss.cpu().detach().numpy())

            if scheduler is not None:
                # One step per epoch; passing the epoch index is deprecated.
                scheduler.step()

            loss_train = np.around(np.mean(loss_trainepoch), decimals=2)
            self.flow.eval()
            with torch.no_grad():  # no graph is needed for validation
                val_batch = torch.as_tensor(next(iter(val_dataloader)), dtype=torch.float32).to(device)
                loss_val = np.around(torch.mean(-self.flow.log_prob(val_batch)).cpu().numpy(), decimals=2)

            improved = (not np.isnan(loss_val)) and loss_val <= best_val
            if improved:
                # Checkpoint the best weights seen so far, so that the reload
                # after training restores the best epoch, not a worse one.
                best_val = loss_val
                counter_es = 0  # reset counter
                torch.save(self.flow.state_dict(), self.PATH_model)
                have_checkpoint = True
            else:
                if counter_es == patience:  # stop training
                    break
                print(f'Early Stopping counter {counter_es + 1}/{patience}')
                counter_es += 1

            print(f'Epoch {e + 1}/{num_epochs}: Train loss = {loss_train}, Validation loss = {loss_val}')
            print("Trained df:")
            for parameter in self.base_dist.parameters():
                # .numpy() on a CPU tensor shares memory with the parameter, which
                # the optimizer updates in place; copy to keep a real history.
                df = parameter.detach().cpu().numpy().copy()
                print(df)
                self.ls_dfs.append(df)

            if self.track_results:
                with open("results/likelihood/mtaf" + str(self.num_layers) + "_train.txt", "a") as f:
                    f.write(str(loss_train) + " " + str(self.model_nr) + "\n")
                with open("results/likelihood/mtaf" + str(self.num_layers) + "_val.txt", "a") as f:
                    f.write(str(loss_val) + " " + str(self.model_nr) + "\n")

        if counter_es > 0 and have_checkpoint:
            # The last epoch(s) did not improve: go back to the best weights.
            self.flow.load_state_dict(torch.load(self.PATH_model, map_location=device))
        torch.save(self.flow.state_dict(), self.PATH_model)

        # print test loss:
        self.flow.eval()
        loss_test = []
        with torch.no_grad():
            for batch in test_dataloader:
                x = torch.as_tensor(batch, dtype=torch.float32).to(device)
                loss = -self.flow.log_prob(inputs=x).mean()
                loss_test.append(loss.cpu().numpy())
        average_testloss = np.mean(loss_test)
        print("mTAF: Final Test loss after {} Epochs: {}".format(epochs_run, average_testloss))

        if self.track_results:
            with open("results/likelihood/mtaf" + str(self.num_layers) + "_test.txt", "a") as f:
                f.write(str(average_testloss) + " " + str(self.model_nr) + "\n")

    def load_model(self, path=""):
        """Load flow weights from ``path``; the empty string selects ``self.PATH_model``."""
        checkpoint = self.PATH_model if path == "" else path
        state = torch.load(checkpoint, map_location=device)
        self.flow.load_state_dict(state)

    def save_dfs(self, PATH):
        num_files = len([name for name in os.listdir(PATH) if os.path.isfile(os.path.join(PATH, name))])
        dfs = np.array(self.ls_dfs)
        np.save(PATH + "dfs_" + str(num_files + 1), dfs)

    def sample_with(self, base_samp, context=None):
        """base_samp is a sample from the base distribution. If we want to fix only specific components, set the unfixed components
        to 0."""
        noise = self.base_dist.sample(50)
        try:  # I can also insert multiple samples
            for j in range(self.D):
                comp = base_samp[j]
                if comp != 0:
                    for i in range(50):
                        noise[i, j] = torch.tensor(comp)
        except:
            for j in range(50):
                comp = base_samp[j]
                for i in range(self.D):
                    if comp[i] != 0:
                        noise[j, i] = torch.tensor(comp[i])
        self.flow.eval()
        samples, _ = self.flow._transform.inverse(noise)

        return samples

    def save_permutation(self, path=""):
        permutations = []
        ordering = np.arange(0, self.D)
        for j in range(self.num_layers):
            perm = self.flow._transform._transforms[int(3 * j)].get_permutation().detach().cpu().numpy()
            permutations.append(perm)
            ordering = ordering[perm]
        if path=="":
           np.save(self.PATH_model + "_ordering", ordering)
        else:
            np.save(path + "_ordering", ordering)

class TAF:
    def __init__(self, data,  num_heavy, df, num_layers=5, num_hidden=200, num_blocks=1, batch_norm=False, dropout_prob=0.0, batch_size=258, model="maf", linear_layer="permutation", model_nr=0, track_results=False):
        """Store the hyper-parameters, generate the data and build the flow.

        Args:
            data: Converted with ``int()`` and used as the dimensionality ``D``
                handed to ``copula_generator``.
            num_heavy, df: Forwarded to ``copula_generator`` and used in the
                checkpoint file name.
            num_layers, num_hidden, num_blocks, batch_norm, dropout_prob:
                Architecture settings of the masked autoregressive layers.
            batch_size: Batch size stored for the data loaders built in ``train``.
            model: Stored as given.
            linear_layer: ``"permutation"`` puts a ``RandomPermutation`` in front
                of every autoregressive layer; any other value uses ``LULinear``.
            model_nr: Identifier written next to each tracked loss value.
            track_results: If true, ``train`` appends losses to result files.
        """
        # Dimensionality; ``self.data`` itself is left as an empty string.
        self.D = int(data)
        self.data = ""

        self.num_layers = num_layers
        self.num_hidden = num_hidden
        self.num_blocks = num_blocks
        self.batch_norm = batch_norm
        self.dropout = dropout_prob
        self.batch_size = batch_size
        self.model = model
        self.linear_layer = linear_layer
        self.model_nr = model_nr
        self.track_results = track_results

        self.marginals = []
        self.heavy_tailed = []
        self.list_tailindexes = []

        # 1. Generate the data: 3/5 of the samples are held out, and 2/3 of the
        #    hold-out part becomes the test set (the rest is validation data).
        sample_count = 25000
        generated = copula_generator(self.D, num_heavy, df).get_data(sample_count)
        self.data_train, holdout = train_test_split(generated, test_size=3/5)
        self.data_val, self.data_test = train_test_split(holdout, test_size=2/3)

        self.PATH_model = f"models/taf_df{df}h{num_heavy}"

        self.base_dist = tDist([self.D])

        def make_linear_layer():
            # Mixing layer placed in front of each autoregressive transform.
            if self.linear_layer == "permutation":
                return RandomPermutation(features=self.D)
            return LULinear(features=self.D)

        # Every layer is: linear/permutation -> masked autoregressive -> batch norm.
        transforms = []
        for _ in range(self.num_layers):
            transforms.extend([
                make_linear_layer(),
                MaskedAffineAutoregressiveTransform(features=self.D,
                                                    hidden_features=self.num_hidden,
                                                    num_blocks=self.num_blocks,
                                                    dropout_probability=self.dropout,
                                                    use_batch_norm=self.batch_norm),
                BatchNorm(features=self.D),
            ])

        self.transform = CompositeTransform(transforms)

        self.flow = Flow(self.transform, self.base_dist).to(device)

    def train(self, num_epochs=200, lr=1e-5, lr_wd=1e-6, lr_df=0.1, patience=30):
        """Train the flow by minimising the negative log-likelihood, with early stopping.

        After every epoch the validation loss is computed on one randomly drawn
        validation batch and rounded to two decimals. Whenever it is at least as
        good as the best value so far, the weights are written to
        ``self.PATH_model``; if training ends on a non-improving epoch those best
        weights are loaded back. Finally the mean test loss is printed and, when
        ``self.track_results`` is set, appended to the results file.

        Args:
            num_epochs: Maximum number of passes over ``self.data_train``.
            lr: Adam learning rate for the transform and embedding-net parameters.
            lr_wd: Value passed to Adam as ``weight_decay``.
            lr_df: Adam learning rate for the base-distribution parameters.
            patience: Number of consecutive non-improving epochs that are
                tolerated; training stops on the next one.
        """
        optimizer = optim.Adam([{"params": self.flow._distribution.parameters(), "lr": lr_df},
                                {"params": self.flow._transform.parameters()},
                                {"params": self.flow._embedding_net.parameters()}
                                ], lr=lr, weight_decay=lr_wd)

        train_dataloader = DataLoader(self.data_train, batch_size=self.batch_size, shuffle=True)
        test_dataloader = DataLoader(self.data_test, batch_size=self.batch_size, shuffle=True)
        val_dataloader = DataLoader(self.data_val, batch_size=self.batch_size, shuffle=True)

        best_val = np.inf
        have_checkpoint = False
        counter_es = 0
        counter_discarded_batches = 0
        epochs_run = 0  # stays 0 when num_epochs == 0, so the final report still works
        for e in range(num_epochs):
            epochs_run = e + 1
            loss_trainepoch = []
            self.flow.train()
            for batch in train_dataloader:
                # The loader already yields tensors; as_tensor converts the dtype
                # without the copy-construct warning of torch.tensor(tensor).
                x = torch.as_tensor(batch, dtype=torch.float32).to(device)
                optimizer.zero_grad()
                loss = -self.flow.log_prob(inputs=x).mean()
                # Skip the update for inf *and* NaN losses: back-propagating
                # either would corrupt the parameters.
                if torch.isfinite(loss):
                    loss.backward()
                    optimizer.step()
                else:
                    counter_discarded_batches += 1
                    print("Nr of Discarded Batches: " + str(counter_discarded_batches))
                # Discarded batches still count towards the reported epoch loss.
                loss_trainepoch.append(loss.cpu().detach().numpy())

            loss_train = np.around(np.mean(loss_trainepoch), decimals=2)
            self.flow.eval()
            with torch.no_grad():  # no graph is needed for validation
                val_batch = torch.as_tensor(next(iter(val_dataloader)), dtype=torch.float32).to(device)
                loss_val = np.around(torch.mean(-self.flow.log_prob(val_batch)).cpu().numpy(), decimals=2)

            # A NaN validation loss counts as "no improvement", as in Vanilla_Flow.train.
            improved = (not np.isnan(loss_val)) and loss_val <= best_val
            if improved:
                # Checkpoint the best weights seen so far, so that the reload
                # after training restores the best epoch, not a worse one.
                best_val = loss_val
                counter_es = 0  # reset counter
                torch.save(self.flow.state_dict(), self.PATH_model)
                have_checkpoint = True
            else:
                if counter_es == patience:  # stop training
                    break
                print(f'Early Stopping counter {counter_es + 1}/{patience}')
                counter_es += 1

            print(f'Epoch {e + 1}/{num_epochs}: Train loss = {loss_train}, Validation loss = {loss_val}')
            print("Trained df:")
            for parameter in self.base_dist.parameters():
                df = parameter.detach().cpu().numpy()
                print(df)

            if self.track_results:
                with open("results/likelihood/taf" + str(self.num_layers) + "_train.txt", "a") as f:
                    f.write(str(loss_train) + " " + str(self.model_nr) + "\n")
                with open("results/likelihood/taf" + str(self.num_layers) + "_val.txt", "a") as f:
                    f.write(str(loss_val) + " " + str(self.model_nr) + "\n")

        if counter_es > 0 and have_checkpoint:
            # The last epoch(s) did not improve: go back to the best weights.
            self.flow.load_state_dict(torch.load(self.PATH_model, map_location=device))
        torch.save(self.flow.state_dict(), self.PATH_model)

        # print test loss:
        self.flow.eval()
        loss_test = []
        with torch.no_grad():
            for batch in test_dataloader:
                x = torch.as_tensor(batch, dtype=torch.float32).to(device)
                loss = -self.flow.log_prob(inputs=x).mean()
                loss_test.append(loss.cpu().numpy())
        average_testloss = np.mean(loss_test)
        print("TAF: Final Test loss after {} Epochs: {}".format(epochs_run, average_testloss))

        if self.track_results:
            with open("results/likelihood/taf" + str(self.num_layers) + "_test.txt", "a") as f:
                f.write(str(average_testloss) + " " + str(self.model_nr) + "\n")

    def load_model(self, path=""):
        """Load flow weights from ``path``; the empty string selects ``self.PATH_model``."""
        checkpoint = self.PATH_model if path == "" else path
        state = torch.load(checkpoint, map_location=device)
        self.flow.load_state_dict(state)

    def sample_with(self, base_samp, context=None):
        """base_samp is a sample from the base distribution. If we want to fix only specific components, set the unfixed components
        to 0."""
        self.flow.eval()
        noise = self.base_dist.sample(50)
        try:  # I can also insert multiple samples
            for j in range(self.D):
                comp = base_samp[j]
                if comp != 0:
                    for i in range(50):
                        noise[i, j] = torch.tensor(comp)
        except:
            for j in range(50):
                comp = base_samp[j]
                for i in range(self.D):
                    if comp[i] != 0:
                        noise[j, i] = torch.tensor(comp[i])

        samples, _ = self.flow._transform.inverse(noise)

        return samples


    def save_permutation(self, path=""):
        # get the permutations:
        permutations = []
        ordering = np.arange(0, self.D)
        for j in range(self.num_layers):
            perm = self.flow._transform._transforms[int(3 * j)].get_permutation().detach().cpu().numpy()
            permutations.append(perm)
            ordering = ordering[perm]
        if path=="":
           np.save(self.PATH_model + "_ordering", ordering)
        else:
            np.save(path + "_ordering", ordering)

class Vanilla_Flow:
    def __init__(self, data,  num_heavy, df, num_layers=5, num_hidden=200, num_blocks=1, batch_norm=False, dropout_prob=0.0, batch_size=258, model="maf", linear_layer="permutation", model_nr=0, track_results=False):
        """Store the hyper-parameters, generate the data and build a Gaussian-base flow.

        Args:
            data: Converted with ``int()`` and used as the dimensionality ``D``
                handed to ``copula_generator``.
            num_heavy, df: Forwarded to ``copula_generator`` and used in the
                checkpoint file name.
            num_layers, num_hidden, num_blocks, batch_norm, dropout_prob:
                Architecture settings of the masked autoregressive layers.
            batch_size: Batch size stored for the data loaders built in ``train``.
            model: Stored as given.
            linear_layer: ``"permutation"`` puts a ``RandomPermutation`` in front
                of every autoregressive layer; any other value uses ``LULinear``.
            model_nr: Identifier written next to each tracked loss value.
            track_results: If true, ``train`` appends losses to result files.
        """
        self.data = data
        self.num_layers = num_layers
        self.num_hidden = num_hidden
        self.num_blocks = num_blocks
        self.batch_norm = batch_norm
        self.dropout = dropout_prob
        self.batch_size = batch_size

        self.marginals = []
        self.heavy_tailed = []
        self.list_tailindexes = []
        self.model = model
        self.linear_layer = linear_layer
        self.model_nr = model_nr
        self.track_results = track_results

        # 1. Generate the data. `data` carries the dimensionality; the attribute
        #    is reset to an empty string afterwards.
        self.D = int(self.data)
        self.data = ""
        num_samps = 25000
        data = copula_generator(self.D, num_heavy, df).get_data(num_samps)
        # 2/5 of the samples are held out; 3/5 of the hold-out part is the test set.
        # NOTE(review): mTAF, TAF and gTAF in this file hold out 3/5 and then use
        # 2/3 of that for testing - confirm that the different split is intended.
        self.data_train, self.data_val = train_test_split(data, test_size=2/5)
        self.data_val, self.data_test = train_test_split(self.data_val, test_size=3/5)

        self.PATH_model = "models/vanilla_df" + str(df) + "h" + str(num_heavy)

        # The base distribution is a D-dimensional standard normal. (A list of
        # per-dimension Normal objects used to be built here but was never read.)
        self.base_dist = nflows.distributions.normal.StandardNormal([self.D])

        # Every layer is: linear/permutation -> masked autoregressive -> batch norm.
        transforms = []
        for _ in range(self.num_layers):
            if self.linear_layer == "permutation":
                transforms.append(RandomPermutation(features=self.D))
            else:
                transforms.append(LULinear(features=self.D))
            transforms.append(MaskedAffineAutoregressiveTransform(features=self.D,
                                                                  hidden_features=self.num_hidden,
                                                                  num_blocks=self.num_blocks,
                                                                  use_batch_norm=self.batch_norm,
                                                                  dropout_probability=self.dropout))
            transforms.append(BatchNorm(features=self.D))

        self.transform = CompositeTransform(transforms)

        self.flow = Flow(self.transform, self.base_dist).to(device)



    def train(self, num_epochs=200, lr=1e-5, lr_wd=1e-6, lr_df=0.1, patience=30, cosine_annealing = True):
        """Train the flow by minimising the negative log-likelihood, with early stopping.

        After every epoch the validation loss is computed on one randomly drawn
        validation batch and rounded to two decimals. Whenever it is at least as
        good as the best value so far, the weights are written to
        ``self.PATH_model``; if training ends on a non-improving epoch those best
        weights are loaded back. Finally the mean test loss is printed and, when
        ``self.track_results`` is set, appended to the results file.

        Args:
            num_epochs: Maximum number of passes over ``self.data_train``.
            lr: Adam learning rate for all flow parameters.
            lr_wd: Value passed to Adam as ``weight_decay``.
            lr_df: Not used by this method.
            patience: Number of consecutive non-improving epochs that are
                tolerated; training stops on the next one.
            cosine_annealing: If true, anneal the learning rate with
                ``CosineAnnealingLR`` (``T_max=num_epochs``, ``eta_min=0``).
        """
        optimizer = optim.Adam(self.flow.parameters(), lr=lr, weight_decay=lr_wd)
        if cosine_annealing:
            scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epochs, 0)
        else:
            scheduler = None

        train_dataloader = DataLoader(self.data_train, batch_size=self.batch_size, shuffle=True)
        test_dataloader = DataLoader(self.data_test, batch_size=self.batch_size, shuffle=True)
        val_dataloader = DataLoader(self.data_val, batch_size=self.batch_size, shuffle=True)

        best_val = np.inf
        have_checkpoint = False
        counter_es = 0
        counter_discarded_batches = 0
        epochs_run = 0  # stays 0 when num_epochs == 0, so the final report still works
        for e in range(num_epochs):
            epochs_run = e + 1
            loss_trainepoch = []
            self.flow.train()
            for batch in train_dataloader:
                # The loader already yields tensors; as_tensor converts the dtype
                # without the copy-construct warning of torch.tensor(tensor).
                x = torch.as_tensor(batch, dtype=torch.float32).to(device)
                optimizer.zero_grad()
                loss = -self.flow.log_prob(inputs=x).mean()
                # Skip the update for inf *and* NaN losses: back-propagating
                # either would corrupt the parameters.
                if torch.isfinite(loss):
                    loss.backward()
                    optimizer.step()
                else:
                    counter_discarded_batches += 1
                    print("Nr of discarded Batches: " + str(counter_discarded_batches))
                # Discarded batches still count towards the reported epoch loss.
                loss_trainepoch.append(loss.cpu().detach().numpy())
            if scheduler is not None:
                # One step per epoch; passing the epoch index is deprecated.
                scheduler.step()

            loss_train = np.around(np.mean(loss_trainepoch), decimals=2)
            self.flow.eval()
            with torch.no_grad():  # no graph is needed for validation
                val_batch = torch.as_tensor(next(iter(val_dataloader)), dtype=torch.float32).to(device)
                loss_val = np.around(torch.mean(-self.flow.log_prob(val_batch)).cpu().numpy(), decimals=2)

            # A NaN validation loss counts as "no improvement".
            improved = (not np.isnan(loss_val)) and loss_val <= best_val
            if improved:
                # Checkpoint the best weights seen so far, so that the reload
                # after training restores the best epoch, not a worse one.
                best_val = loss_val
                counter_es = 0  # reset counter
                torch.save(self.flow.state_dict(), self.PATH_model)
                have_checkpoint = True
            else:
                if counter_es == patience:  # stop training
                    break
                print(f'Early Stopping counter {counter_es + 1}/{patience}')
                counter_es += 1

            print(f'Epoch {e + 1}/{num_epochs}: Train loss = {loss_train}, Validation loss = {loss_val}')

            if self.track_results:
                with open("results/likelihood/vanilla" + str(self.num_layers) + "_train.txt", "a") as f:
                    f.write(str(loss_train) + " " + str(self.model_nr) + "\n")
                with open("results/likelihood/vanilla" + str(self.num_layers) + "_val.txt", "a") as f:
                    f.write(str(loss_val) + " " + str(self.model_nr) + "\n")

        if counter_es > 0 and have_checkpoint:
            # The last epoch(s) did not improve: go back to the best weights.
            self.flow.load_state_dict(torch.load(self.PATH_model, map_location=device))
        torch.save(self.flow.state_dict(), self.PATH_model)

        # print test loss:
        self.flow.eval()
        loss_test = []
        with torch.no_grad():
            for batch in test_dataloader:
                x = torch.as_tensor(batch, dtype=torch.float32).to(device)
                loss = -self.flow.log_prob(inputs=x).mean()
                loss_test.append(loss.cpu().numpy())
        average_testloss = np.mean(loss_test)
        print("Vanilla: Final Test loss after {} Epochs: {}".format(epochs_run, average_testloss))

        if self.track_results:
            with open("results/likelihood/vanilla" + str(self.num_layers) + "_test.txt", "a") as f:
                f.write(str(average_testloss) + " " + str(self.model_nr) + "\n")

    def load_model(self, path=""):
        """Load flow weights from ``path``; the empty string selects ``self.PATH_model``."""
        checkpoint = self.PATH_model if path == "" else path
        state = torch.load(checkpoint, map_location=device)
        self.flow.load_state_dict(state)


    def sample_with(self, base_samp, context=None):
        """base_samp is a sample from the base distribution. If we want to fix only specific components, set the unfixed components
        to 0."""
        self.flow.eval()
        noise = self.base_dist.sample(len(base_samp))
        try:  # I can also insert multiple samples
            for j in range(self.D):
                comp = base_samp[j]
                if comp != 0:
                    for i in range(len(base_samp)):
                        noise[i, j] = torch.tensor(comp)
        except:
            for j in range(len(base_samp)):
                comp = base_samp[j]
                for i in range(self.D):
                    if comp[i] != 0:
                        noise[j, i] = torch.tensor(comp[i])
        samples, _ = self.flow._transform.inverse(noise)

        return samples

    def save_permutation(self, path=""):
        # get the permutations:
        permutations = []
        ordering = np.arange(0, self.D)
        for j in range(self.num_layers):
            perm = self.flow._transform._transforms[int(3 * j)].get_permutation().detach().cpu().numpy()
            permutations.append(perm)
            ordering = ordering[perm]
        if path=="":
           np.save(self.PATH_model + "_ordering", ordering)
        else:
            np.save(path + "_ordering", ordering)

class gTAF(mTAF):
    """Flow on synthetic copula data whose base distribution is trained too.

    ``config`` builds an nflows ``Flow`` from a ``norm_tDist`` base
    distribution and ``num_layers`` repetitions of
    ``TailRandomPermutation`` -> ``MaskedAffineAutoregressiveTransform`` ->
    ``BatchNorm``. ``train`` optimises the transform parameters and the base
    distribution parameters, the latter with their own learning rate.
    """

    def __init__(self, data, num_heavy, df, num_layers=5, num_hidden=200, num_blocks=1, batch_norm=False, dropout_prob=0.0, batch_size=258, model="maf", linear_layer="permutation", model_nr=0, track_results=False):
        """Generate the synthetic data set and store the hyper-parameters.

        Args:
            data: Converted with ``int()`` and used as the dimensionality ``D``
                of the generated data.
            num_heavy: Forwarded to ``copula_generator``; also part of the
                checkpoint path.
            df: Forwarded to ``copula_generator``; also part of the checkpoint
                path.
            num_layers: Number of permutation/MAF/BatchNorm layers.
            num_hidden: Hidden features of each autoregressive transform.
            num_blocks: Blocks of each autoregressive transform.
            batch_norm: ``use_batch_norm`` flag of each autoregressive transform.
            dropout_prob: Dropout probability of each autoregressive transform.
            batch_size: Batch size of all data loaders.
                NOTE(review): the default (258) differs from mTAF's 256 and may
                be a typo; kept unchanged for compatibility.
            model: Stored only; not read inside this class.
            linear_layer: Stored only; not read inside this class.
            model_nr: Identifier written next to each tracked loss value.
            track_results: If true, losses are appended to text files under
                ``results/likelihood/``.
        """
        self.data = data
        self.num_layers = num_layers
        self.num_hidden = num_hidden
        self.num_blocks = num_blocks
        self.batch_norm = batch_norm
        self.dropout = dropout_prob
        self.batch_size = batch_size

        self.marginals = []
        self.heavy_tailed = []
        self.list_tailindexes = []
        self.model = model
        self.linear_layer = linear_layer

        self.model_nr = model_nr
        self.track_results = track_results

        # 1. Read the data
        self.D = int(self.data)
        self.data = ""
        num_samps = 25000
        data = copula_generator(self.D, num_heavy, df).get_data(num_samps)
        # Two-stage split: 2/5 train, then the remaining 3/5 is divided into
        # 1/3 validation and 2/3 test (i.e. 40% / 20% / 40% overall).
        self.data_train, self.data_val = train_test_split(data, test_size=3/5)
        self.data_val, self.data_test = train_test_split(self.data_val, test_size=2/3)

        self.PATH_model = "models/gTAF_df" + str(df) + "h" + str(num_heavy)

    def config(self):
        """Build marginals, base distribution and flow, and permute the data.

        Every dimension is registered as heavy-tailed with an initial tail
        index of 10. The data splits are reordered with ``self.permutation``
        (the argsort of the heavy-tailed flags); ``self.inv_perm`` stores the
        inverse reordering.
        """
        # 1. Get the Marginals
        for _ in range(self.D):
            tail_index = 10
            self.list_tailindexes.append(tail_index)
            self.marginals.append(StudentT(torch.tensor(tail_index).to(device), torch.tensor([0.0]).to(device),
                                           torch.tensor([1.0]).to(device)))
            self.heavy_tailed.append(True)
        num_heavy = np.sum(np.array(self.heavy_tailed))
        num_light = self.D - num_heavy
        # 2. Reorder the Marginals
        self.permutation = np.argsort(np.array(self.heavy_tailed))
        self.inv_perm = np.zeros(self.D, dtype=np.int32)  # for reordering
        for j in range(self.D):
            self.inv_perm[self.permutation[j]] = j
        self.marginals_permuted = []
        for d in range(self.D):
            self.marginals_permuted.append(self.marginals[self.permutation[d]])

        self.tail_index_permuted = np.array(self.list_tailindexes)[self.permutation]

        self.base_dist = norm_tDist([self.D], self.tail_index_permuted)

        # Each layer contributes exactly three transforms; save_permutation
        # relies on this when it reads every third transform.
        transforms = []
        for _ in range(self.num_layers):
            transforms.append(TailRandomPermutation(num_light, num_heavy))
            transforms.append(MaskedAffineAutoregressiveTransform(features=self.D,
                                                                  hidden_features=self.num_hidden,
                                                                  num_blocks=self.num_blocks,
                                                                  use_batch_norm=self.batch_norm,
                                                                  dropout_probability=self.dropout))
            transforms.append(BatchNorm(features=self.D))

        self.transform = CompositeTransform(transforms)

        self.flow = Flow(self.transform, self.base_dist).to(device)

        # 5 adjust the Data
        self.data_train = self.data_train[:, self.permutation]
        self.data_test = self.data_test[:, self.permutation]
        self.data_val = self.data_val[:, self.permutation]

    def train(self, num_epochs=200, lr=1e-5, lr_wd=1e-6, lr_df=0.1, patience=30, cosine_annealing=False):
        """Train the flow with early stopping and report the test loss.

        After every epoch the loss on one validation batch (rounded to two
        decimals) is compared with the best value seen so far. Whenever it
        matches or beats the best value a checkpoint is written to
        ``self.PATH_model``; if training ends on a non-improving epoch that
        checkpoint is restored before the test loss is computed.

        Args:
            num_epochs: Maximum number of epochs.
            lr: Learning rate of the transform and embedding-net parameters.
            lr_wd: Passed to Adam as ``weight_decay``.
            lr_df: Learning rate of the base distribution parameters.
            patience: Number of consecutive non-improving epochs tolerated;
                training stops on the next non-improving epoch after that.
                With ``patience <= 0`` it stops at the first one.
            cosine_annealing: If true, use ``CosineAnnealingLR`` over
                ``num_epochs``.
        """
        optimizer = optim.Adam([{"params": self.flow._distribution.parameters(), "lr": lr_df},
                                {"params": self.flow._transform.parameters()},
                                {"params": self.flow._embedding_net.parameters()}
                                ], lr=lr, weight_decay=lr_wd)
        if cosine_annealing:
            scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epochs, 0)
        else:
            scheduler = None

        train_dataloader = DataLoader(self.data_train, batch_size=self.batch_size, shuffle=True)
        test_dataloader = DataLoader(self.data_test, batch_size=self.batch_size, shuffle=True)
        val_dataloader = DataLoader(self.data_val, batch_size=self.batch_size, shuffle=True)

        loss_val_list = []
        loss_trn_list = []
        self.ls_dfs = []
        counter_es = 0
        best_is_current = True  # does self.flow currently hold the best weights?
        e = -1  # keeps the final report well-defined when num_epochs == 0
        for e in range(num_epochs):
            self.flow.train()
            loss_trainepoch = []
            for batch in train_dataloader:
                # as_tensor converts without the copy-construct warning that
                # torch.tensor() raises for batches that already are tensors.
                x = torch.as_tensor(batch, dtype=torch.float32).to(device)
                optimizer.zero_grad()
                loss = -self.flow.log_prob(inputs=x).mean()
                loss.backward()
                optimizer.step()

                loss_trainepoch.append(loss.cpu().detach().numpy())

            if cosine_annealing:
                # NOTE(review): passing the epoch to step() is deprecated in
                # PyTorch; kept because step() without it would shift the
                # learning-rate schedule by one epoch.
                scheduler.step(e)

            loss_train = np.around(np.mean(loss_trainepoch), decimals=2)
            loss_trn_list.append(loss_train)

            self.flow.eval()
            with torch.no_grad():  # validation needs no autograd graph
                val_batch = torch.as_tensor(next(iter(val_dataloader)), dtype=torch.float32).to(device)
                loss_val = np.around(torch.mean(-self.flow.log_prob(val_batch)).cpu().detach().numpy(), decimals=2)
            loss_val_list.append(loss_val)

            if loss_val > min(loss_val_list):
                best_is_current = False
                if counter_es >= patience:  # stop training
                    break
                counter_es += 1
                print(f'Early Stopping counter {counter_es}/{patience}')
            else:
                # New (or tied) best validation loss: checkpoint these weights
                # so they can be restored if later epochs only get worse.
                counter_es = 0
                best_is_current = True
                torch.save(self.flow.state_dict(), self.PATH_model)

            print(f'Epoch {e + 1}/{num_epochs}: Train loss = {loss_train}, Validation loss = {loss_val}')
            print("Trained df:")
            for parameter in self.base_dist.parameters():
                # On CPU .numpy() shares memory with the parameter, which the
                # optimizer updates in place; copy to keep a real snapshot.
                df = parameter.detach().cpu().numpy().copy()
                print(df)
                self.ls_dfs.append(df)

            if self.track_results:
                with open("results/likelihood/gtaf" + str(self.num_layers) + "_train.txt", "a") as f:
                    f.write(str(loss_train) + " " + str(self.model_nr) + "\n")
                with open("results/likelihood/gtaf" + str(self.num_layers) + "_val.txt", "a") as f:
                    f.write(str(loss_val) + " " + str(self.model_nr) + "\n")

        if not best_is_current:
            # Training ended on a non-improving epoch: restore the checkpoint.
            self.flow = Flow(self.transform, self.base_dist).to(device)
            self.flow.load_state_dict(torch.load(self.PATH_model, map_location=device))
        torch.save(self.flow.state_dict(), self.PATH_model)
        # print test loss:
        with torch.no_grad():
            self.flow.eval()
            loss_test = []
            for batch in test_dataloader:
                x = torch.as_tensor(batch, dtype=torch.float32).to(device)
                loss = -self.flow.log_prob(inputs=x).mean()
                loss_test.append(loss.cpu().detach().numpy())
        average_testloss = np.mean(loss_test)
        print("gTAF: Final Test loss after {} Epochs: {}".format(e + 1, average_testloss))

        if self.track_results:
            with open("results/likelihood/gtaf" + str(self.num_layers) + "_test.txt", "a") as f:
                f.write(str(average_testloss) + " " + str(self.model_nr) + "\n" )

    def load_model(self, path=""):
        """Load a saved state dict into ``self.flow``.

        Args:
            path: Checkpoint file to read. The empty string (default) selects
                ``self.PATH_model``.
        """
        checkpoint = self.PATH_model if path == "" else path
        self.flow.load_state_dict(torch.load(checkpoint, map_location=device))

    def save_dfs(self, PATH):
        """Save the base-distribution parameter history recorded by ``train``.

        The file is written into directory ``PATH`` as ``dfs_<k>`` (``np.save``
        adds ``.npy``), where ``k`` is one more than the number of regular
        files currently in ``PATH``.

        Args:
            PATH: Existing output directory, with or without a trailing
                separator.
        """
        num_files = sum(1 for name in os.listdir(PATH) if os.path.isfile(os.path.join(PATH, name)))
        dfs = np.array(self.ls_dfs)
        # os.path.join keeps the file inside PATH even without a trailing slash.
        np.save(os.path.join(PATH, "dfs_" + str(num_files + 1)), dfs)

    def sample_with(self, base_samp, context=None):
        """Sample from the flow while pinning selected base-distribution components.

        Fresh noise is drawn from ``self.base_dist`` and every non-zero entry of
        ``base_samp`` overwrites the matching noise component; entries equal to 0
        are treated as "not fixed" and keep their random value. The resulting
        noise is pushed through the inverse of the flow's transform.

        Args:
            base_samp: Either a single sample of ``self.D`` components (1-D), in
                which case ``len(base_samp)`` noise rows are drawn and all of
                them share the fixed components, or a batch of samples (2-D,
                one row per sample). Tensors, arrays and (nested) sequences are
                accepted. Columns beyond ``self.D`` are ignored.
            context: Unused.

        Returns:
            The first element of the tuple returned by
            ``self.flow._transform.inverse``.

        Raises:
            ValueError: If ``base_samp`` is neither 1-D nor 2-D, or provides
                fewer than ``self.D`` components per sample.
        """
        # The transform stack built in config() contains BatchNorm layers, whose
        # inverse in nflows is only available in eval mode.
        self.flow.eval()
        noise = self.base_dist.sample(len(base_samp))

        # Normalise the input to one tensor so the 1-D / 2-D cases can be told
        # apart explicitly instead of through a caught exception.
        if torch.is_tensor(base_samp):
            fixed = base_samp
        elif (isinstance(base_samp, (list, tuple)) and len(base_samp) > 0
              and all(torch.is_tensor(row) for row in base_samp)):
            fixed = torch.stack(list(base_samp))
        else:
            fixed = torch.as_tensor(np.asarray(base_samp))

        if fixed.dim() not in (1, 2):
            raise ValueError("base_samp must be 1-D (one sample) or 2-D (batch of samples)")

        if fixed.numel() > 0:
            fixed = fixed[..., :self.D]
            if fixed.shape[-1] != self.D:
                raise ValueError("base_samp must provide at least self.D components per sample")
            # Build the mask before casting so the "!= 0" test sees the caller's
            # original values rather than values rounded to the noise dtype.
            pinned = (fixed != 0).to(noise.device)
            values = fixed.to(device=noise.device, dtype=noise.dtype)
            # A 1-D sample broadcasts across all noise rows.
            noise = torch.where(pinned, values, noise)

        samples, _ = self.flow._transform.inverse(noise)

        return samples

    def save_permutation(self, path=""):
        """Compose the per-layer permutations and save the resulting ordering.

        Starting from ``arange(self.D)``, the permutation of every layer's
        first transform (indices 0, 3, 6, ... of the composite transform) is
        applied in turn, and the final ordering is written with ``np.save``.

        Args:
            path: Prefix of the output file; ``"_ordering"`` is appended (and
                ``np.save`` adds ``.npy``). The empty string (default) selects
                ``self.PATH_model`` as prefix.
        """
        ordering = np.arange(0, self.D)
        for j in range(self.num_layers):
            # config() appends three transforms per layer, permutation first.
            perm = self.flow._transform._transforms[3 * j].get_permutation().detach().cpu().numpy()
            ordering = ordering[perm]

        prefix = self.PATH_model if path == "" else path
        np.save(prefix + "_ordering", ordering)
