# -*- coding: utf-8 -*-
# @Time    : 2022/7/28 0:43
# @File    : basic_funs.py
# @Software: PyCharm
import numpy as np
import math
import random
import torch
import torch.nn as nn
import time


class Timer:
    """Stopwatch that keeps every measured interval.

    The clock starts as soon as the object is created. Each ``stop()`` call
    appends the seconds elapsed since the most recent ``start()`` to
    ``times`` and returns that value.
    """

    def __init__(self):
        self.times = []  # recorded intervals in seconds, oldest first
        self.start()

    def start(self):
        """Reset the reference point for the next interval to now."""
        self.tik = time.time()

    def stop(self):
        """Record and return the seconds elapsed since the last ``start()``."""
        elapsed = time.time() - self.tik
        self.times.append(elapsed)
        return elapsed

    def avg(self):
        """Mean recorded interval; raises ZeroDivisionError if nothing was recorded."""
        return self.sum() / len(self.times)

    def sum(self):
        """Total of all recorded intervals (0 when empty)."""
        return sum(self.times)

    def cumsum(self):
        """Running totals of the recorded intervals, returned as a plain list."""
        running = np.cumsum(np.asarray(self.times))
        return running.tolist()


class EarlyStopping:
    """Signal early termination when the validation loss stops improving.

    Lower ``val_loss`` is better. When a call improves on the best loss seen
    so far by at least ``delta``, the model's ``state_dict`` is written to
    ``path`` and the patience counter resets. After ``patience`` consecutive
    calls without such an improvement, ``early_stop`` becomes True.
    """

    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt'):
        """
        Args:
            patience (int): How many consecutive non-improving calls to tolerate
                            before setting ``early_stop``.
                            Default: 7
            verbose (bool): Kept for API compatibility. Progress messages are
                            currently disabled, so this flag has no effect.
                            Default: False
            delta (float): Minimum decrease in the monitored loss that counts
                            as an improvement.
                            Default: 0
            path (str): File the best model's ``state_dict`` is saved to.
                            Default: 'checkpoint.pt'
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # ``np.inf`` (lowercase): the ``np.Inf`` alias was removed in NumPy 2.0.
        self.val_loss_min = np.inf
        self.delta = delta
        self.path = path

    def __call__(self, val_loss, model):
        """Update the stopping state with a new validation loss.

        Args:
            val_loss (float): Latest validation loss (lower is better).
            model (nn.Module): Model whose ``state_dict`` is saved on improvement.
        """
        score = val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score <= self.best_score - self.delta:
            # Improvement of at least ``delta``. With delta=0, a tie still
            # counts as an improvement, matching the previous behaviour.
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True

    def save_checkpoint(self, val_loss, model):
        """Save ``model.state_dict()`` to ``self.path`` and remember ``val_loss``."""
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss


def weight_init(m):
    """Initialise the parameters of a single module in place, by module type.

    Can be passed to ``nn.Module.apply`` to initialise a whole network.

    - ``nn.Linear``: Xavier-normal weight, zero bias.
    - ``nn.Conv2d``: Kaiming-normal weight (``mode='fan_out'``, ReLU gain);
      the bias is left untouched.
    - ``nn.BatchNorm2d``: weight filled with 1, bias with 0.

    Other module types are ignored. Parameters that do not exist, such as
    ``bias=False`` or ``affine=False``, are skipped instead of raising.

    Args:
        m (nn.Module): Module to initialise.
    """
    if isinstance(m, nn.Linear):
        nn.init.xavier_normal_(m.weight)
        # nn.init.constant_(None, ...) raises; Linear(bias=False) has no bias.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
    elif isinstance(m, nn.BatchNorm2d):
        # With affine=False both weight and bias are None.
        if m.weight is not None:
            nn.init.constant_(m.weight, 1)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)


class MultivariateBetaDataset(torch.utils.data.Dataset):
    """``N`` samples of a ``dim``-dimensional vector with i.i.d. Beta coordinates.

    Every coordinate is drawn independently from ``Beta(alpha, beta)``. With
    scalar ``alpha``/``beta`` (assumed — TODO confirm for other callers):

    - ``x`` is an ``(N, dim)`` tensor of samples; ``self[ix]`` returns row ``ix``.
    - ``logpdf`` is an ``(N, 1)`` tensor holding the joint log-density of each
      row, i.e. the sum of the per-coordinate Beta log-densities.
    """

    def __init__(self, N, dim, alpha, beta):
        self.N = N
        self.alpha = alpha
        self.beta = beta
        self.dim = dim

        self.dist = self.build_dist
        self.x, self.logpdf = self.DataSampling

    def __getitem__(self, ix):
        return self.x[ix, :]

    def __len__(self):
        return self.N

    @staticmethod
    def _as_float_tensor(value):
        """Convert a Beta parameter to a tensor, promoting non-float dtypes.

        ``torch.tensor(2)`` is int64, and Beta sampling is not implemented for
        integer concentrations. Integer inputs are therefore cast to the
        default float dtype. Floating inputs keep their dtype.
        """
        t = torch.as_tensor(value)
        if not torch.is_floating_point(t):
            t = t.to(torch.get_default_dtype())
        return t

    @property
    def build_dist(self):
        """A fresh ``Beta(alpha, beta)`` distribution built from the stored parameters."""
        return torch.distributions.beta.Beta(self._as_float_tensor(self.alpha),
                                             self._as_float_tensor(self.beta))

    @property
    def DataSampling(self):
        """Draw a new ``(Data, logpdf)`` pair from ``self.dist``.

        Coordinates are sampled one column at a time. This keeps the order in
        which the RNG is consumed, so seeded runs reproduce earlier datasets.
        """
        Data = torch.zeros(self.N, self.dim)
        temp = torch.zeros(self.N, 1)

        for i in range(self.dim):
            sample = self.dist.sample((self.N, 1))
            temp = temp + self.dist.log_prob(sample)
            Data[:, i] = sample[:, 0]
        return Data, temp


class bregmanFNN(nn.Module):
    """Fully connected ReLU network that outputs one scalar per input row.

    ``width_vec`` gives the layer widths ``[in, h1, ..., hk]``. Each
    consecutive pair becomes a ``Linear + ReLU`` stage, and a final
    ``Linear(hk, 1)`` produces the output. When ``width_vec`` is None the
    default ``[dim, 16, 8]`` is used.
    """

    def __init__(self, dim, width_vec: list = None):
        super(bregmanFNN, self).__init__()
        if width_vec is None:
            width_vec = [dim, 16, 8]
        self.dim = dim
        # Store the widths actually used. Before, this was None whenever the
        # default architecture was selected.
        self.width_vec = width_vec

        # Keep each hidden stage as its own nn.Sequential so state_dict keys
        # (e.g. "net.0.0.weight") match previously saved checkpoints.
        hidden = [nn.Sequential(nn.Linear(w_in, w_out), nn.ReLU())
                  for w_in, w_out in zip(width_vec[:-1], width_vec[1:])]
        self.net = nn.Sequential(*hidden, nn.Linear(width_vec[-1], 1))

    def forward(self, x):
        """Map a batch of inputs to a batch of scalars (``Ratio_est``), shape ``(..., 1)``."""
        Ratio_est = self.net(x)
        return Ratio_est


def Training_breprocess(MyNet, trainer, mydataLoaderS, mydataLoaderT, Test_DataS, Test_DataT, num_epochs,
                        patience=100):
    """Train ``MyNet`` with ``lr_bregman`` and stop early on the held-out loss.

    Each epoch makes one pass over ``mydataLoaderT``. A batch from
    ``mydataLoaderS`` is drawn alongside every T batch, and the S iterator is
    restarted whenever it is exhausted. The loss is
    ``lr_bregman(MyNet(T), MyNet(S))``. After each epoch the same loss is
    computed on ``(Test_DataT, Test_DataS)`` without gradients and passed to
    ``EarlyStopping``, which writes the best weights to ``checkpoint.pt``.

    Args:
        MyNet (nn.Module): Model to train (updated in place).
        trainer: Optimizer over ``MyNet``'s parameters.
        mydataLoaderS: Iterable of S batches (cycled as needed).
        mydataLoaderT: Iterable of T batches; defines the epoch length.
        Test_DataS, Test_DataT: Held-out tensors for the early-stopping loss.
        num_epochs (int): Maximum number of epochs.
        patience (int): Early-stopping patience in epochs. Default: 100.

    Returns:
        The best (lowest) held-out loss seen, as a Python float, or None if
        ``num_epochs`` is 0.

    Note:
        On return ``MyNet`` holds the weights from the LAST epoch, not the
        best one. Load ``checkpoint.pt`` to restore the best weights.
    """
    early_stopping = EarlyStopping(patience, verbose=True)
    for epoch in range(num_epochs):
        dataloader_iterator = iter(mydataLoaderS)
        for BatchDataT in mydataLoaderT:
            try:
                BatchDataS = next(dataloader_iterator)
            except StopIteration:
                # S is shorter than T: restart it so every T batch is paired.
                dataloader_iterator = iter(mydataLoaderS)
                BatchDataS = next(dataloader_iterator)
            trainer.zero_grad()
            loss = lr_bregman(MyNet(BatchDataT), MyNet(BatchDataS))
            loss.backward()
            trainer.step()
        # Evaluation needs no autograd graph. .item() also works for CUDA
        # tensors, unlike .numpy().
        with torch.no_grad():
            test_loss = lr_bregman(MyNet(Test_DataT), MyNet(Test_DataS)).item()
        early_stopping(test_loss, MyNet)
        if early_stopping.early_stop:
            print('Early stopping at ITERATION:', epoch)
            break
    return early_stopping.best_score


def lr_bregman(D_hat_q, D_hat_p):  # q in numerator
    """Logistic (Bregman) loss for density-ratio estimation.

    Computes ``mean(log(1 + exp(-D_hat_q))) + mean(log(1 + exp(D_hat_p)))``,
    where the distribution q is in the numerator of the ratio.

    ``softplus(x) = log(1 + exp(x))`` is used because the naive form
    overflows to ``inf`` once ``exp(x)`` exceeds the float range
    (x > ~88 in float32). softplus switches to the linear asymptote there.

    Args:
        D_hat_q (Tensor): Network outputs on samples from q.
        D_hat_p (Tensor): Network outputs on samples from p.

    Returns:
        Scalar tensor loss.
    """
    softplus = torch.nn.functional.softplus
    return softplus(-D_hat_q).mean() + softplus(D_hat_p).mean()


def accuracy_eval(ldr_est, ldr):
    """Root-mean-square error between an estimate and its reference.

    ``ldr`` presumably stands for "log density ratio" (TODO confirm). The
    inputs are broadcast against each other before the squared differences
    are averaged over all elements.

    Returns:
        Scalar tensor ``sqrt(mean((ldr_est - ldr) ** 2))``.
    """
    residual = ldr_est - ldr
    mean_sq = torch.mean(residual ** 2)
    return torch.sqrt(mean_sq)


def setup_seed(seed):
    """Seed Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed (int): Seed applied to ``random``, ``numpy.random`` and torch
            (CPU and all CUDA devices).
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode autotunes convolution algorithms per run and may choose
    # different kernels. PyTorch's reproducibility notes say to disable it
    # together with ``deterministic=True``.
    torch.backends.cudnn.benchmark = False