#!/usr/bin/env python
# -*- coding:utf-8 -*-
# software: PyCharm
import numpy as np
import math
import random
import torch
import torch.nn as nn
import time
import scipy.stats as st
from torch.distributions import MultivariateNormal
from torch.autograd import Variable
from numpy.linalg import *

from torchvision import datasets

class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience.

    Lower validation loss is treated as better. Every time the loss improves,
    the model's ``state_dict`` is written to ``path`` so the best weights can
    be reloaded after training.
    """

    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt'):
        """
        Args:
            patience (int): How many consecutive calls without improvement to tolerate
                            before ``early_stop`` is set.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement.
                            Default: False
            delta (float): Minimum decrease in validation loss required to count as an
                            improvement.
                            Default: 0
            path (str): File the improved model's ``state_dict`` is saved to.
                            Default: 'checkpoint.pt'
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0          # consecutive calls without improvement
        self.best_score = None    # lowest validation loss accepted so far
        self.early_stop = False
        # Lowercase np.inf: the np.Inf alias was removed in NumPy 2.0.
        self.val_loss_min = np.inf
        self.delta = delta
        self.path = path

    def __call__(self, val_loss, model):
        """Record one validation loss; checkpoint on improvement, count otherwise."""
        score = val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score > self.best_score - self.delta:
            # Loss did not drop by at least ``delta`` below the best seen so far.
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Saves model when validation loss decrease.'''
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss


def weight_init(m):
    """Initialise the parameters of a single module in place.

    Intended for ``model.apply(weight_init)``. Handles:
      * ``nn.Linear``: Xavier-normal weights, zero bias.
      * ``nn.Conv2d``: Kaiming-normal weights (fan_out, ReLU); bias left as is.
      * ``nn.BatchNorm2d``: weight 1, bias 0.
    Other module types are left untouched. Parameters that do not exist
    (``bias=False`` layers, ``affine=False`` batch norm) are skipped.

    Args:
        m (nn.Module): Module to initialise.
    """
    if isinstance(m, nn.Linear):
        nn.init.xavier_normal_(m.weight)
        # Linear(..., bias=False) has m.bias == None.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
    elif isinstance(m, nn.BatchNorm2d):
        # BatchNorm2d(affine=False) has no learnable weight/bias.
        if m.weight is not None:
            nn.init.constant_(m.weight, 1)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)


# def Training_breprocess(MyNet, trainer, mydataLoaderS, mydataLoaderT, Test_DataS, Test_DataT, num_epochs):
#     # patience = 1000
#     early_stopping = EarlyStopping(patience, verbose=True)
#     for epoch in range(num_epochs):
#         dataloader_iterator = iter(mydataLoaderS)
#         #         print('ITERATION:',epoch)
#         for BatchDataT in mydataLoaderT:
#             try:
#                 BatchDataS = next(dataloader_iterator)
#             except StopIteration:
#                 dataloader_iterator = iter(mydataLoaderS)
#                 BatchDataS = next(dataloader_iterator)
#             trainer.zero_grad()
#             l = lr_bregman(MyNet(BatchDataT), MyNet(BatchDataS))
#             l.backward()
#             trainer.step()
#         nl2 = lr_bregman(MyNet(Test_DataT), MyNet(Test_DataS)).detach().numpy().item()
#         early_stopping(nl2, MyNet)
#         if early_stopping.early_stop:
#             print('Early stopping at ITERATION:', epoch)
#             break
#     #     MyNet.load_state_dict(torch.load('checkpoint.pt'))
#     return early_stopping.best_score
def Training_breprocess(MyNet, trainer, mydataLoaderS, mydataLoaderT, num_epochs, device):
    """Train ``MyNet`` with the ``lr_bregman`` loss for ``num_epochs`` epochs.

    Each epoch walks once over ``mydataLoaderT``. Every target batch is paired
    with the next batch from ``mydataLoaderS``. When the source loader runs out
    first, it is restarted so pairing continues until the target loader ends.

    Args:
        MyNet: Model producing the scores fed to ``lr_bregman``.
        trainer: Optimizer stepping ``MyNet``'s parameters.
        mydataLoaderS: Iterable of source batches (second ``lr_bregman`` argument).
        mydataLoaderT: Iterable of target batches (first ``lr_bregman`` argument).
        num_epochs (int): Number of passes over ``mydataLoaderT``.
        device: Passed to each batch's ``.to()`` before the forward pass.
    """
    for _ in range(num_epochs):
        source_iter = iter(mydataLoaderS)
        for target_batch in mydataLoaderT:
            try:
                source_batch = next(source_iter)
            except StopIteration:
                # Source side exhausted: start it over so this target batch still has a partner.
                source_iter = iter(mydataLoaderS)
                source_batch = next(source_iter)

            trainer.zero_grad()
            target_scores = MyNet(target_batch.to(device))
            source_scores = MyNet(source_batch.to(device))
            loss = lr_bregman(target_scores, source_scores)
            loss.backward()
            trainer.step()


def lr_bregman(D_hat_q, D_hat_p):  # q in numerator
    """Logistic-regression Bregman loss.

    Computes ``mean(log(1 + exp(-D_hat_q))) + mean(log(1 + exp(D_hat_p)))``.
    ``D_hat_q`` is scored against samples from q (the numerator distribution)
    and ``D_hat_p`` against samples from p. Presumably both are the model's
    log density-ratio estimates; confirm against the caller.

    Args:
        D_hat_q (torch.Tensor): Model outputs on the q (numerator) samples.
        D_hat_p (torch.Tensor): Model outputs on the p samples.

    Returns:
        torch.Tensor: Scalar loss.
    """
    # softplus(x) == log(1 + exp(x)), but it does not overflow to inf (and the
    # gradient does not become nan) for large |x|, unlike the literal formula.
    loss_q = nn.functional.softplus(-D_hat_q).mean()
    loss_p = nn.functional.softplus(D_hat_p).mean()
    return loss_q + loss_p


def setup_seed(seed, cuda):
    """Seed the PyTorch, NumPy and Python ``random`` RNGs for reproducible runs.

    Args:
        seed (int): Seed applied to every RNG.
        cuda (bool): If True, also seed all CUDA devices and configure cuDNN
            for deterministic algorithm selection.
    """
    torch.manual_seed(seed)
    if cuda:
        torch.cuda.manual_seed_all(seed)
        # deterministic=True forces deterministic conv kernels. benchmark=False
        # stops cuDNN from auto-tuning, which may pick different algorithms per run.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    random.seed(seed)
