import os
import random
import datetime

import math
from typing import List, Any, Tuple, Dict
from abc import ABCMeta, abstractmethod

import numpy as np

from sklearn.metrics import mean_squared_error, mean_absolute_error

import torch
from torch import nn, Tensor
from torch.distributions import MultivariateNormal, Uniform
from torch.utils.data import DataLoader


import lightning as pl
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.callbacks.early_stopping import EarlyStopping

from lib.dre.common.util import ABCProbRateDense

def calc_Lp_error(esitmate_rate: Any, true_rate:Any, p:float):
    """Return the L^p error between estimated and reference values.

    Computes ``mean(|esitmate_rate - true_rate| ** p) ** (1 / p)`` over all
    elements.

    Args:
        esitmate_rate: Estimated values; anything ``np.asarray`` accepts.
        true_rate: Reference values, broadcastable against
            ``esitmate_rate``.  The argument is not modified.
        p: Order of the error; must be at least 1.

    Returns:
        The L^p error as a NumPy floating-point scalar.

    Raises:
        AssertionError: If ``p`` is smaller than 1.
    """
    assert p >= 1, 'p must be 1.0 or greater.'
    # The difference has to be formed explicitly: the second positional
    # argument of ``np.absolute`` is ``out``, so passing ``true_rate`` there
    # would overwrite the caller's array instead of subtracting it.
    differences = np.asarray(esitmate_rate) - np.asarray(true_rate)
    abs_errors_power_p = np.power(np.absolute(differences), p)
    return np.power(np.mean(abs_errors_power_p), 1 / p)

def gen_estimate_density_rate_func(method: str):
    """Build the density-ratio estimator that matches a training ``method``.

    The returned callable has the signature
    ``(denominator_data, prob_rate_model) -> Tensor``.  It evaluates
    ``T = prob_rate_model('estimation', denominator_data)`` and turns the
    model output into the ratio estimate (``estimated_dQdP`` in this module)
    according to ``method``:

    * ``'alphaDiv'``: ``exp(-T) / mean(exp(-T))``.
    * ``'LSIF-energy'``, ``'KLdivergence-energy'``,
      ``'nnBD-KLdivergence-energy'``, ``'penalty-KLdivergence-energy'``:
      ``exp(T) / mean(exp(T))``.
    * ``'LSIF'``, ``'nnBD-LSIF'``, ``'penalty-LSIF'``, ``'KLdivergence'``,
      ``'nnBD-KLdivergence'``, ``'penalty-KLdivergence'``, ``'GAN'``,
      ``'alphaDiv-biased'``, ``'alphaDiv-biased-truncated'``: ``T`` itself,
      without normalisation.

    Args:
        method: Name of the estimation method.

    Returns:
        The estimator callable, or ``None`` when ``method`` is not one of
        the names listed above.
    """
    energy_methods = (
        'LSIF-energy',
        'KLdivergence-energy',
        'nnBD-KLdivergence-energy',
        'penalty-KLdivergence-energy',
    )
    raw_rate_methods = (
        'LSIF',
        'nnBD-LSIF',
        'penalty-LSIF',
        'KLdivergence',
        'nnBD-KLdivergence',
        'penalty-KLdivergence',
        'GAN',
        'alphaDiv-biased',
        'alphaDiv-biased-truncated',
    )

    if method == 'alphaDiv':
        def _estimate_density_rate_from_energy(
                denominator_data: Tensor,
                prob_rate_model: ABCProbRateDense) -> Tensor:
            """Return ``exp(-T)`` divided by its mean over the batch."""
            T_P = prob_rate_model('estimation', denominator_data)
            exp_minus_T_P = torch.exp(-T_P)
            return exp_minus_T_P / torch.mean(exp_minus_T_P)
        return _estimate_density_rate_from_energy

    if method in energy_methods:
        def _estimate_density_rate_from_energy(
                denominator_data: Tensor,
                prob_rate_model: ABCProbRateDense) -> Tensor:
            """Return ``exp(T)`` divided by its mean over the batch."""
            T_P = prob_rate_model('estimation', denominator_data)
            exp_T_P = torch.exp(T_P)
            return exp_T_P / torch.mean(exp_T_P)
        return _estimate_density_rate_from_energy

    if method in raw_rate_methods:
        def _estimate_density_rate_raw(
                denominator_data: Tensor,
                prob_rate_model: ABCProbRateDense) -> Tensor:
            """Return the model output unchanged."""
            # NOTE(review): this branch used to compute ``raw / mean(raw)``
            # and then discard it, returning the raw output.  The returned
            # value is kept as it was; confirm whether the mean-normalised
            # estimate (as in the energy branches) was the intended result.
            return prob_rate_model('estimation', denominator_data)
        return _estimate_density_rate_raw

    # Unrecognised method names yield no estimator.
    return None

def gen_estimate_target_divergence_func(
    method: str,
    params: dict):
    """Build the estimator of the divergence that ``method`` optimises.

    The returned callable has the signature
    ``(denominator_data, numerator_data, prob_rate_model) -> float``.  It
    calls ``prob_rate_model('optimization', denominator_data,
    numerator_data)``, which must return two sequences, and uses element
    ``[0]`` of each (written ``P`` and ``Q`` below).  The result is a Python
    float obtained with ``.item()``; ``a`` stands for ``params['alpha']``:

    * ``'alphaDiv'``:
      ``1/(a(1-a)) - mean(exp(a*Q))/a - mean(exp((a-1)*P))/(1-a)``.
    * ``'alphaDiv-biased'``, ``'alphaDiv-biased-truncated'``:
      ``1/(a(1-a)) - mean(Q**a)/a - mean(P**(a-1))/(1-a)``.
    * ``'LSIF-energy'``: ``1/2 + mean(exp(Q)) - mean(exp(P)**2)/2``.
    * ``'LSIF'``, ``'nnBD-LSIF'``, ``'penalty-LSIF'``:
      ``1/2 + mean(Q) - mean(P**2)/2``.
    * ``'KLdivergence'``, ``'nnBD-KLdivergence'``,
      ``'penalty-KLdivergence'``: ``1 + mean(log(Q)) - mean(P)``.
    * ``'GAN'``: ``mean(log(1/(1+P))) + mean(log(Q/(1+Q)))``.
    * ``'KLdivergence-energy'``, ``'nnBD-KLdivergence-energy'``,
      ``'penalty-KLdivergence-energy'``: ``1 + mean(Q) - mean(exp(P))``.

    Args:
        method: Name of the estimation method.
        params: Method parameters.  Only ``params['alpha']`` is used, and
            only by the ``alphaDiv*`` methods; it is looked up each time the
            returned callable runs, not when it is built.

    Returns:
        The estimator callable, or ``None`` when ``method`` is not one of
        the names listed above.
    """
    def _model_outputs(denominator_data, numerator_data, prob_rate_model):
        # The model returns one sequence per sample set; only the first
        # entry of each sequence enters the divergence estimate.
        outputs_P, outputs_Q = prob_rate_model(
            'optimization', denominator_data, numerator_data)
        return outputs_P[0], outputs_Q[0]

    def _alpha_divergence(moment_P: Tensor, moment_Q: Tensor, alpha) -> float:
        # Shared tail of both alpha-divergence variants.
        loss_de = moment_P / (1 - alpha)
        loss_nu = moment_Q / alpha
        return (1 / (alpha * (1 - alpha)) - (loss_nu + loss_de)).item()

    def _lsif_divergence(rate_P: Tensor, rate_Q: Tensor) -> float:
        # NOTE(review): for rates identically 1 this evaluates to 1, not 0;
        # confirm that the additive constant 1/2 is the intended one.
        loss_de = torch.mean(rate_P * rate_P) / 2.0
        loss_nu = -torch.mean(rate_Q)
        return (1 / 2 - (loss_nu + loss_de)).item()

    if method == 'alphaDiv':
        def _estimate_target_div_func(
                denominator_data: Tensor,
                numerator_data: Tensor,
                prob_rate_model: ABCProbRateDense) -> float:
            """Alpha-divergence estimate from energy outputs."""
            alpha = params['alpha']
            energy_P, energy_Q = _model_outputs(
                denominator_data, numerator_data, prob_rate_model)
            moment_Q = torch.mean(torch.exp(alpha * energy_Q))
            moment_P = torch.mean(torch.exp((alpha - 1) * energy_P))
            return _alpha_divergence(moment_P, moment_Q, alpha)
        return _estimate_target_div_func

    if method in ('alphaDiv-biased', 'alphaDiv-biased-truncated'):
        def _estimate_target_div_func(
                denominator_data: Tensor,
                numerator_data: Tensor,
                prob_rate_model: ABCProbRateDense) -> float:
            """Alpha-divergence estimate from direct rate outputs."""
            alpha = params['alpha']
            rate_P, rate_Q = _model_outputs(
                denominator_data, numerator_data, prob_rate_model)
            moment_Q = torch.mean(torch.pow(rate_Q, alpha))
            moment_P = torch.mean(torch.pow(rate_P, (alpha - 1)))
            return _alpha_divergence(moment_P, moment_Q, alpha)
        return _estimate_target_div_func

    if method == 'LSIF-energy':
        def _estimate_target_div_func(
                denominator_data: Tensor,
                numerator_data: Tensor,
                prob_rate_model: ABCProbRateDense) -> float:
            """LSIF estimate with rates given by ``exp`` of the energies."""
            energy_P, energy_Q = _model_outputs(
                denominator_data, numerator_data, prob_rate_model)
            return _lsif_divergence(torch.exp(energy_P), torch.exp(energy_Q))
        return _estimate_target_div_func

    if method in ('LSIF', 'nnBD-LSIF', 'penalty-LSIF'):
        def _estimate_target_div_func(
                denominator_data: Tensor,
                numerator_data: Tensor,
                prob_rate_model: ABCProbRateDense) -> float:
            """LSIF estimate from direct rate outputs."""
            rate_P, rate_Q = _model_outputs(
                denominator_data, numerator_data, prob_rate_model)
            return _lsif_divergence(rate_P, rate_Q)
        return _estimate_target_div_func

    if method in ('KLdivergence',
                  'nnBD-KLdivergence',
                  'penalty-KLdivergence'):
        def _estimate_target_div_func(
                denominator_data: Tensor,
                numerator_data: Tensor,
                prob_rate_model: ABCProbRateDense) -> float:
            """KL estimate from direct rate outputs."""
            rate_P, rate_Q = _model_outputs(
                denominator_data, numerator_data, prob_rate_model)
            loss_de = torch.mean(rate_P)
            loss_nu = -torch.mean(torch.log(rate_Q))
            return (1 - (loss_nu + loss_de)).item()
        return _estimate_target_div_func

    if method == 'GAN':
        def _estimate_target_div_func(
                denominator_data: Tensor,
                numerator_data: Tensor,
                prob_rate_model: ABCProbRateDense) -> float:
            """Negated GAN-style loss from direct rate outputs."""
            rate_P, rate_Q = _model_outputs(
                denominator_data, numerator_data, prob_rate_model)
            loss_de = torch.mean(-torch.log(1 / (1 + rate_P)))
            loss_nu = torch.mean(-torch.log(rate_Q / (1 + rate_Q)))
            return -(loss_nu + loss_de).item()
        return _estimate_target_div_func

    if method in ('KLdivergence-energy',
                  'nnBD-KLdivergence-energy',
                  'penalty-KLdivergence-energy'):
        def _estimate_target_div_func(
                denominator_data: Tensor,
                numerator_data: Tensor,
                prob_rate_model: ABCProbRateDense) -> float:
            """KL estimate from energy outputs."""
            energy_P, energy_Q = _model_outputs(
                denominator_data, numerator_data, prob_rate_model)
            loss_de = torch.mean(torch.exp(energy_P))
            loss_nu = -torch.mean(energy_Q)
            return (1 - (loss_nu + loss_de)).item()
        return _estimate_target_div_func

    # Unrecognised method names yield no estimator.
    return None
    
def gen_estimate_KL_divergence_func(
    method: str,
    params: dict):
    """Build a KL-divergence estimator for a model trained with ``method``.

    The returned callable has the signature
    ``(denominator_data, numerator_data, prob_rate_model)
    -> (estimate_KL_1, estimate_KL_2)``.  It calls
    ``prob_rate_model('optimization', denominator_data, numerator_data)``,
    which must return two sequences, and uses element ``[0]`` of each
    (written ``P`` and ``Q`` below).  Both estimates are Python floats:

    * ``'alphaDiv'``: with ``R_P = exp(-P) / mean(exp(-P))`` and
      ``R_Q = exp(Q) / mean(exp(Q))``,
      ``estimate_KL_1 = 1 - mean(R_P) + mean(-log(R_Q))`` and
      ``estimate_KL_2 = mean(-log(R_Q))``.
    * ``'LSIF-energy'``, ``'KLdivergence-energy'``,
      ``'nnBD-KLdivergence-energy'``, ``'penalty-KLdivergence-energy'``:
      ``estimate_KL_1 = 1 - mean(exp(P)) + mean(Q)`` and, with
      ``R_P = exp(P) / mean(exp(P))``,
      ``estimate_KL_2 = mean(R_P * log(R_P))``.
    * ``'alphaDiv-biased'``, ``'alphaDiv-biased-truncated'``, ``'LSIF'``,
      ``'nnBD-LSIF'``, ``'penalty-LSIF'``, ``'KLdivergence'``,
      ``'nnBD-KLdivergence'``, ``'penalty-KLdivergence'``, ``'GAN'``:
      ``estimate_KL_1 = 1 - mean(P) + mean(log(Q))`` and, with
      ``R_P = P / mean(P)``, ``estimate_KL_2 = mean(R_P * log(R_P))``.

    Args:
        method: Name of the estimation method.
        params: Method parameters.  Not read by any estimator built here;
            accepted so the signature matches the other factories.

    Returns:
        The estimator callable, or ``None`` when ``method`` is not one of
        the names listed above.
    """
    energy_methods = (
        'LSIF-energy',
        'KLdivergence-energy',
        'nnBD-KLdivergence-energy',
        'penalty-KLdivergence-energy',
    )
    raw_rate_methods = (
        'alphaDiv-biased',
        'alphaDiv-biased-truncated',
        'LSIF',
        'nnBD-LSIF',
        'penalty-LSIF',
        'KLdivergence',
        'nnBD-KLdivergence',
        'penalty-KLdivergence',
        'GAN',
    )

    def _model_outputs(denominator_data, numerator_data, prob_rate_model):
        # The model returns one sequence per sample set; only the first
        # entry of each sequence enters the estimates.
        outputs_P, outputs_Q = prob_rate_model(
            'optimization', denominator_data, numerator_data)
        return outputs_P[0], outputs_Q[0]

    if method == 'alphaDiv':
        def _estimate_KL_div_func(
                denominator_data: Tensor,
                numerator_data: Tensor,
                prob_rate_model: ABCProbRateDense) -> Tuple[float, float]:
            """KL estimates from alpha-divergence energies."""
            energy_P, energy_Q = _model_outputs(
                denominator_data, numerator_data, prob_rate_model)
            exp_plus_energy_Q = torch.exp(energy_Q)
            exp_minus_energy_P = torch.exp(-energy_P)
            Rate_Q = exp_plus_energy_Q / torch.mean(exp_plus_energy_Q)
            Rate_P = exp_minus_energy_P / torch.mean(exp_minus_energy_P)
            # Rate_Q enters through -log(Rate_Q) in both estimates, unlike
            # the other branches, which use log of the Q-side output.
            mean_minus_log_Rate_Q = torch.mean(-torch.log(Rate_Q))
            loss_de = torch.mean(Rate_P)
            loss_nu = -mean_minus_log_Rate_Q
            estimate_KL_1 = (1 - (loss_nu + loss_de)).item()
            estimate_KL_2 = mean_minus_log_Rate_Q.item()
            return estimate_KL_1, estimate_KL_2
        return _estimate_KL_div_func

    if method in energy_methods:
        def _estimate_KL_div_func(
                denominator_data: Tensor,
                numerator_data: Tensor,
                prob_rate_model: ABCProbRateDense) -> Tuple[float, float]:
            """KL estimates from energy outputs."""
            energy_P, energy_Q = _model_outputs(
                denominator_data, numerator_data, prob_rate_model)
            Rate_P_raw = torch.exp(energy_P)
            loss_de = torch.mean(Rate_P_raw)
            loss_nu = -torch.mean(energy_Q)
            estimate_KL_1 = (1 - (loss_nu + loss_de)).item()
            Rate_P = Rate_P_raw / torch.mean(Rate_P_raw)
            estimate_KL_2 = torch.mean(Rate_P * torch.log(Rate_P)).item()
            return estimate_KL_1, estimate_KL_2
        return _estimate_KL_div_func

    if method in raw_rate_methods:
        def _estimate_KL_div_func(
                denominator_data: Tensor,
                numerator_data: Tensor,
                prob_rate_model: ABCProbRateDense) -> Tuple[float, float]:
            """KL estimates from direct rate outputs."""
            Rate_P_raw, Rate_Q_raw = _model_outputs(
                denominator_data, numerator_data, prob_rate_model)
            # The first estimate uses the raw outputs; only the second one
            # uses the mean-normalised P-side rate.
            loss_de = torch.mean(Rate_P_raw)
            loss_nu = -torch.mean(torch.log(Rate_Q_raw))
            estimate_KL_1 = (1 - (loss_nu + loss_de)).item()
            Rate_P = Rate_P_raw / torch.mean(Rate_P_raw)
            estimate_KL_2 = torch.mean(Rate_P * torch.log(Rate_P)).item()
            return estimate_KL_1, estimate_KL_2
        return _estimate_KL_div_func

    # Unrecognised method names yield no estimator.
    return None
 
    
