import os
import random
import datetime

from typing import List, Any, Tuple, Dict

import numpy as np

from sklearn.metrics import mean_squared_error

import torch
from torch import nn, Tensor

from lib.dre.common.util import ABCProbRateDense, Truncated


def gen_loss_func(
    method: str,
    params: dict):
    """Return the density-ratio training loss selected by ``method``.

    Every returned loss has the signature
    ``loss(denominator_data, numerator_data, prob_rate_model) -> Tensor``.
    It calls ``prob_rate_model('optimization', denominator_data,
    numerator_data)``, which must return a pair of sequences ``(out_P, out_Q)``.
    Only ``out_P[0]`` and ``out_Q[0]`` are used. ``out_P[0]`` enters the
    denominator-side term (``loss_de``) and ``out_Q[0]`` the numerator-side
    term (``loss_nu``).

    How the model output is interpreted depends on ``method``:

    * ``'alphaDiv'`` and names ending in ``-energy`` treat the output as an
      energy ``e`` with ratio ``r = exp(e)``.
    * All other methods treat the output directly as the ratio ``r``.

    Hyper-parameters are read from ``params`` each time the loss is called,
    not when it is built:

    * ``'alphaDiv'``, ``'alphaDiv-biased'``: ``alpha`` (must not be 0 or 1,
      because the loss divides by ``alpha`` and ``1 - alpha``).
    * ``'alphaDiv-biased-truncated'``: ``alpha``, ``max_rate``.
    * ``'nnBD-LSIF'``, ``'nnBD-KLdivergence'``,
      ``'nnBD-KLdivergence-energy'``: ``C``.
    * ``'penalty-LSIF'``, ``'penalty-KLdivergence'``,
      ``'penalty-KLdivergence-energy'``: ``max_rate``, ``penalty``.
    * ``'LSIF'``, ``'LSIF-energy'``, ``'KLdivergence'``,
      ``'KLdivergence-energy'``, ``'GAN'``: none.

    Args:
        method: Name of the loss; one of the names listed above.
        params: Hyper-parameter dictionary, looked up lazily by the loss.

    Returns:
        The loss callable described above.

    Raises:
        ValueError: If ``method`` is not a recognised loss name.
    """

    def _first_outputs(denominator_data, numerator_data, prob_rate_model):
        # The model returns (outputs_P, outputs_Q); every loss below only
        # uses the first entry of each sequence.
        outputs_P, outputs_Q = prob_rate_model(
            'optimization',
            denominator_data, numerator_data)
        return outputs_P[0], outputs_Q[0]

    if method == 'alphaDiv':
        def _target_div_loss(
                denominator_data: Tensor,
                numerator_data: Tensor,
                prob_rate_model: 'ABCProbRateDense') -> Tensor:
            """alpha-divergence loss on energies.

            ``mean(exp((alpha-1)*e_P))/(1-alpha) + mean(exp(alpha*e_Q))/alpha``
            """
            alpha = params['alpha']
            energy_P, energy_Q = _first_outputs(
                denominator_data, numerator_data, prob_rate_model)
            # exp(k*e) == r**k, evaluated without forming r explicitly.
            loss_de = torch.mean(torch.exp((alpha - 1)*energy_P))/(1 - alpha)
            loss_nu = torch.mean(torch.exp(alpha*energy_Q))/alpha
            return loss_nu + loss_de
        return _target_div_loss

    elif method == 'alphaDiv-biased':
        def _target_div_loss(
                denominator_data: Tensor,
                numerator_data: Tensor,
                prob_rate_model: 'ABCProbRateDense') -> Tensor:
            """alpha-divergence loss on ratios.

            ``mean(r_P**(alpha-1))/(1-alpha) + mean(r_Q**alpha)/alpha``
            """
            alpha = params['alpha']
            rate_P, rate_Q = _first_outputs(
                denominator_data, numerator_data, prob_rate_model)
            loss_de = torch.mean(torch.pow(rate_P, alpha - 1))/(1 - alpha)
            loss_nu = torch.mean(torch.pow(rate_Q, alpha))/alpha
            return loss_nu + loss_de
        return _target_div_loss

    elif method == 'alphaDiv-biased-truncated':
        def _target_div_loss(
                denominator_data: Tensor,
                numerator_data: Tensor,
                prob_rate_model: 'ABCProbRateDense') -> Tensor:
            """``'alphaDiv-biased'`` with ratios passed through ``Truncated``.

            The truncation uses ``Truncated(min=0, max=params['max_rate'])``
            before the powers are taken.
            """
            alpha = params['alpha']
            max_rate = params['max_rate']
            truncated = Truncated(min=0, max=max_rate)
            rate_P, rate_Q = _first_outputs(
                denominator_data, numerator_data, prob_rate_model)
            loss_de = torch.mean(
                torch.pow(truncated(rate_P), alpha - 1))/(1 - alpha)
            loss_nu = torch.mean(torch.pow(truncated(rate_Q), alpha))/alpha
            return loss_nu + loss_de
        return _target_div_loss

    elif method == 'LSIF':
        def _target_div_loss(
                denominator_data: Tensor,
                numerator_data: Tensor,
                prob_rate_model: 'ABCProbRateDense') -> Tensor:
            """Least-squares loss: ``mean(r_P**2)/2 - mean(r_Q)``."""
            rate_P, rate_Q = _first_outputs(
                denominator_data, numerator_data, prob_rate_model)
            loss_de = torch.mean(rate_P*rate_P)/2.0
            loss_nu = -torch.mean(rate_Q)
            return loss_nu + loss_de
        return _target_div_loss

    elif method == 'LSIF-energy':
        def _target_div_loss(
                denominator_data: Tensor,
                numerator_data: Tensor,
                prob_rate_model: 'ABCProbRateDense') -> Tensor:
            """``'LSIF'`` loss with ratios obtained as ``exp(energy)``."""
            energy_P, energy_Q = _first_outputs(
                denominator_data, numerator_data, prob_rate_model)
            rate_P = torch.exp(energy_P)
            rate_Q = torch.exp(energy_Q)
            loss_de = torch.mean(rate_P*rate_P)/2.0
            loss_nu = -torch.mean(rate_Q)
            return loss_nu + loss_de
        return _target_div_loss

    elif method == 'nnBD-LSIF':
        def _target_div_loss(
                denominator_data: Tensor,
                numerator_data: Tensor,
                prob_rate_model: 'ABCProbRateDense') -> Tensor:
            """Non-negative (ReLU-clipped) variant of the ``'LSIF'`` loss.

            ``relu(mean(r_P**2)/2 - C*mean(r_Q**2)/2)
            - mean(r_Q - C*r_Q**2/2)``
            """
            C = params['C']
            rate_P, rate_Q = _first_outputs(
                denominator_data, numerator_data, prob_rate_model)
            loss_1_plus = torch.relu(
                torch.mean(rate_P*rate_P)/2.0
                - C*torch.mean(rate_Q*rate_Q)/2.0)
            # Only the quadratic term is halved. When the ReLU is inactive
            # (in particular when C == 0) the C terms cancel and this equals
            # the plain 'LSIF' loss.
            loss_2 = -torch.mean(rate_Q - C*rate_Q*rate_Q/2.0)
            return loss_1_plus + loss_2
        return _target_div_loss

    elif method == 'penalty-LSIF':
        def _target_div_loss(
                denominator_data: Tensor,
                numerator_data: Tensor,
                prob_rate_model: 'ABCProbRateDense') -> Tensor:
            """``'LSIF'`` loss plus ``penalty*relu(mean(r_Q) - max_rate)``."""
            max_rate = params['max_rate']
            penalty = params['penalty']
            rate_P, rate_Q = _first_outputs(
                denominator_data, numerator_data, prob_rate_model)
            loss_de = torch.mean(rate_P*rate_P)/2.0
            loss_nu = -torch.mean(rate_Q)
            # -loss_nu is mean(r_Q); only the excess over max_rate is penalised.
            return loss_nu + loss_de + penalty*torch.relu(-loss_nu - max_rate)
        return _target_div_loss

    elif method == 'KLdivergence':
        def _target_div_loss(
                denominator_data: Tensor,
                numerator_data: Tensor,
                prob_rate_model: 'ABCProbRateDense') -> Tensor:
            """KL-type loss: ``mean(r_P) - mean(log(r_Q))``."""
            rate_P, rate_Q = _first_outputs(
                denominator_data, numerator_data, prob_rate_model)
            loss_de = torch.mean(rate_P)
            loss_nu = -torch.mean(torch.log(rate_Q))
            return loss_nu + loss_de
        return _target_div_loss

    elif method == 'nnBD-KLdivergence':
        def _target_div_loss(
                denominator_data: Tensor,
                numerator_data: Tensor,
                prob_rate_model: 'ABCProbRateDense') -> Tensor:
            """Non-negative (ReLU-clipped) variant of ``'KLdivergence'``.

            ``relu(mean(r_P) - C*mean(r_Q)) - mean(log(r_Q) - C*r_Q)``
            """
            C = params['C']
            rate_P, rate_Q = _first_outputs(
                denominator_data, numerator_data, prob_rate_model)
            loss_de_plus = torch.relu(
                torch.mean(rate_P) - C*torch.mean(rate_Q))
            loss_nu = -torch.mean(torch.log(rate_Q) - C*rate_Q)
            return loss_nu + loss_de_plus
        return _target_div_loss

    elif method == 'nnBD-KLdivergence-energy':
        def _target_div_loss(
                denominator_data: Tensor,
                numerator_data: Tensor,
                prob_rate_model: 'ABCProbRateDense') -> Tensor:
            """``'nnBD-KLdivergence'`` with ratios obtained as ``exp(energy)``."""
            C = params['C']
            energy_P, energy_Q = _first_outputs(
                denominator_data, numerator_data, prob_rate_model)
            rate_P = torch.exp(energy_P)
            rate_Q = torch.exp(energy_Q)
            loss_1_plus = torch.relu(
                torch.mean(rate_P) - C*torch.mean(rate_Q))
            # log(r_Q) == energy_Q, so the energy is used directly.
            loss_2 = -torch.mean(energy_Q - C*rate_Q)
            return loss_1_plus + loss_2
        return _target_div_loss

    elif method == 'penalty-KLdivergence':
        def _target_div_loss(
                denominator_data: Tensor,
                numerator_data: Tensor,
                prob_rate_model: 'ABCProbRateDense') -> Tensor:
            """``'KLdivergence'`` plus a penalty on the numerator-side ratios.

            The added term is
            ``penalty*relu(mean(log(r_Q)) - log(max_rate))``.
            """
            max_rate = params['max_rate']
            penalty = params['penalty']
            rate_P, rate_Q = _first_outputs(
                denominator_data, numerator_data, prob_rate_model)
            loss_de = torch.mean(rate_P)
            loss_nu = -torch.mean(torch.log(rate_Q))
            return (
                loss_nu + loss_de
                + penalty*torch.relu(-loss_nu - np.log(max_rate)))
        return _target_div_loss

    elif method == 'penalty-KLdivergence-energy':
        def _target_div_loss(
                denominator_data: Tensor,
                numerator_data: Tensor,
                prob_rate_model: 'ABCProbRateDense') -> Tensor:
            """``'penalty-KLdivergence'`` with ratios as ``exp(energy)``."""
            max_rate = params['max_rate']
            penalty = params['penalty']
            energy_P, energy_Q = _first_outputs(
                denominator_data, numerator_data, prob_rate_model)
            loss_de = torch.mean(torch.exp(energy_P))
            loss_nu = -torch.mean(energy_Q)
            return (
                loss_nu + loss_de
                + penalty*torch.relu(-loss_nu - np.log(max_rate)))
        return _target_div_loss

    elif method == 'KLdivergence-energy':
        def _target_div_loss(
                denominator_data: Tensor,
                numerator_data: Tensor,
                prob_rate_model: 'ABCProbRateDense') -> Tensor:
            """``'KLdivergence'`` with ratios as ``exp(energy)``.

            ``mean(exp(e_P)) - mean(e_Q)``
            """
            energy_P, energy_Q = _first_outputs(
                denominator_data, numerator_data, prob_rate_model)
            loss_de = torch.mean(torch.exp(energy_P))
            loss_nu = -torch.mean(energy_Q)
            return loss_nu + loss_de
        return _target_div_loss

    elif method == 'GAN':
        def _target_div_loss(
                denominator_data: Tensor,
                numerator_data: Tensor,
                prob_rate_model: 'ABCProbRateDense') -> Tensor:
            """Logistic (GAN-style) loss.

            ``mean(log(1 + r_P)) + mean(log(1 + 1/r_Q))``
            """
            rate_P, rate_Q = _first_outputs(
                denominator_data, numerator_data, prob_rate_model)
            loss_de = torch.mean(torch.log(1 + rate_P))
            loss_nu = torch.mean(torch.log(1 + 1/rate_Q))
            return loss_nu + loss_de
        return _target_div_loss

    # Previously an unknown name silently returned None, which only failed
    # later when the "loss" was called.
    raise ValueError(f"Unknown loss method: {method!r}")

