import os
import pickle
import datetime

from typing import List, Any, Tuple, Optional
import numpy as np

import torch
from torch import nn, Tensor
from torch import optim
from torch.utils.data import DataLoader

import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import Callback


class AlphaInfosTracker(Callback):
    """Callback that accumulates every validation-batch output.

    After a run, ``collection`` holds all values received by
    ``on_validation_batch_end`` as one flat NumPy array, in arrival order.
    """

    def __init__(self):
        super().__init__()
        # Flat array of every value seen so far; starts empty.
        self.collection = np.empty(0)

    def on_validation_batch_end(
          self,
          trainer, pl_module, outputs,
          batch, batch_idx, dataloader_idx=0):
        """Append the outputs of one validation batch to ``collection``.

        ``dataloader_idx`` defaults to 0 so the hook can also be invoked
        without that argument (some Lightning releases omit it when only a
        single dataloader is used); callers that pass it are unaffected.

        Args:
            trainer: The running trainer (unused).
            pl_module: The module being validated (unused).
            outputs: Tensor returned by the module's ``validation_step``.
            batch: The validation batch (unused).
            batch_idx: Index of the batch (unused).
            dataloader_idx: Index of the dataloader (unused).
        """
        vals = outputs.cpu().detach().numpy()
        # np.append flattens ``vals`` before concatenating.
        self.collection = np.append(self.collection, vals)


class NbwDense(nn.Module):
    """Fully connected network that maps an input vector to one scalar.

    The network is a chain of ``Linear -> ReLU`` pairs, one pair per entry
    of ``hidden_dims``, followed by a final ``Linear`` layer of width 1
    with no activation.
    """

    def __init__(self,
                 input_dim: int,
                 hidden_dims: List[int],
                 ):
        """Build the layer stack.

        Args:
            input_dim: Number of input features.
            hidden_dims: Output width of each hidden ``Linear`` layer, in
                order. An empty list yields a single ``Linear`` layer.
        """
        super().__init__()
        # widths[k] feeds layer k; widths[k + 1] is what layer k emits.
        widths = [input_dim, *hidden_dims]
        stack: List[nn.Module] = []
        for n_in, n_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(n_in, n_out))
            stack.append(nn.ReLU())
        # Scalar output head.
        stack.append(nn.Linear(widths[-1], 1))
        self.layers_: nn.Module = nn.Sequential(*stack)

    def forward(self, data: torch.Tensor) -> torch.Tensor:
        """Apply the layer stack to ``data`` and return the result."""
        return self.layers_(data)


class NbwDenseModelList(nn.Module):
    """Ordered collection of ``NbwDense`` models that are applied jointly.

    Every model contributes a per-row factor ``exp(-f(x)) / mean(exp(-f(x)))``.
    The product of these factors, renormalized to have mean 1, is the
    weight vector returned in ``'estimate_only'`` mode. In ``'train'`` mode
    all models except the last one are treated as fixed: their combined
    weight enters the P-term as a constant, and only the last model
    receives gradients.
    """

    def __init__(self,
                 alpha: float,
                 params_for_nbw_models: dict):
        """Store the configuration; no model is created here.

        Args:
            alpha: Exponent parameter used in the training terms.
            params_for_nbw_models: Must contain ``'input_dim'``,
                ``'hidden_dim'`` and ``'n_layers'``. Each model gets
                ``n_layers - 1`` hidden layers of width ``hidden_dim``.
        """
        super().__init__()
        self.alpha_ = alpha
        self.input_dim_ = params_for_nbw_models['input_dim']
        self.hidden_dim_ = params_for_nbw_models['hidden_dim']
        self.n_layers_ = params_for_nbw_models['n_layers']
        self.hidden_dims_ = [self.hidden_dim_] * (self.n_layers_ - 1)
        self.nbw_models_ = nn.ModuleList()
        # Most recently added model (None until a model is added).
        self.new_nbw_model_ = None

    def _register_model(self, mdl: nn.Module) -> None:
        """Record ``mdl`` as the newest model and append it to the list."""
        self.new_nbw_model_ = mdl
        self.nbw_models_.append(mdl)

    def add_model(self):
        """Append a freshly initialized ``NbwDense`` model."""
        self._register_model(NbwDense(self.input_dim_, self.hidden_dims_))

    def load_and_add_model(self, nbw_model_parameters: dict):
        """Append an ``NbwDense`` model restored from a state dict.

        Args:
            nbw_model_parameters: State dict accepted by
                ``NbwDense.load_state_dict``.
        """
        mdl = NbwDense(self.input_dim_,
                       self.hidden_dims_)
        mdl.load_state_dict(nbw_model_parameters)
        self._register_model(mdl)

    def _normalized_weights(self,
                            data: torch.Tensor,
                            n_models: int) -> torch.Tensor:
        """Return the mean-1 product of factors of the first ``n_models``.

        The result carries no gradient. With ``n_models == 0`` it is a
        one-element tensor holding 1, which broadcasts against per-row
        predictions.

        Raises:
            IndexError: If the model list is empty.
        """
        if len(self.nbw_models_) == 0:
            raise IndexError(
                'NbwDenseModelList contains no model; add or load one first')
        # Take dtype and device from the first model's first layer.
        reference = self.nbw_models_[0].layers_[0].weight
        weights = reference.new_ones(1)
        # These models are constants here, so skip autograd bookkeeping.
        with torch.no_grad():
            for mdl in list(self.nbw_models_)[:n_models]:
                factor = torch.exp(-mdl(data))
                weights = weights * (factor / torch.mean(factor))
            weights = weights / torch.mean(weights)
        return weights

    def _forward_train(self,
                       data: torch.Tensor,
                       data_shuffled: torch.Tensor
                       ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute the two per-row terms used by the training loss.

        Args:
            data: Rows passed to all models for the P-term.
            data_shuffled: Rows for the Q-term; they are additionally
                permuted along the first axis before use.

        Returns:
            ``(adjusting_prob_arr_P, adjusting_prob_arr_Q)`` where the
            P-term is ``exp((alpha - 1) * f(data))`` times the constant
            weight of all earlier models, and the Q-term is
            ``exp((1 - alpha) * f(data_shuffled))``; ``f`` is the last model.
        """
        previous_weights = self._normalized_weights(
            data, len(self.nbw_models_) - 1)
        newest = self.nbw_models_[-1]

        preds_P = newest(data)
        adjusting_prob_arr_P = (
            torch.exp((self.alpha_ - 1)*preds_P)*previous_weights)

        r_idxs = torch.randperm(data_shuffled.size()[0])
        data_shuffled = data_shuffled[r_idxs, :]
        preds_Q = newest(data_shuffled)
        adjusting_prob_arr_Q = torch.exp((1 - self.alpha_)*preds_Q)

        return adjusting_prob_arr_P, adjusting_prob_arr_Q

    def _forward_estimate(self,
                          data: torch.Tensor) -> torch.Tensor:
        """Return the mean-1 weights produced by all models for ``data``."""
        return self._normalized_weights(data, len(self.nbw_models_))

    def forward(self,
                data: torch.Tensor,
                data_shuffled: torch.Tensor,
                pred_mode: str) -> Any:
        """Dispatch to the training or the estimation computation.

        Args:
            data: Input rows.
            data_shuffled: Rows for the Q-term; ignored (may be ``None``)
                in ``'estimate_only'`` mode.
            pred_mode: ``'train'`` or ``'estimate_only'``.

        Returns:
            A pair of tensors in ``'train'`` mode, a single tensor in
            ``'estimate_only'`` mode.

        Raises:
            ValueError: If ``pred_mode`` is neither of the two modes.
        """
        if pred_mode == 'train':
            return self._forward_train(data, data_shuffled)
        if pred_mode == 'estimate_only':
            return self._forward_estimate(data)
        raise ValueError(
            f"pred_mode must be 'train' or 'estimate_only', "
            f"got {pred_mode!r}")


class LightningNbwDense(pl.LightningModule):
    """Lightning wrapper that trains one new model on top of saved ones.

    Previously saved models are restored from disk (paths that do not exist
    are skipped without notice), then one freshly initialized model is
    appended. The resulting ``NbwDenseModelList`` is stored in
    ``nbw_models_``.
    """

    def __init__(self,
                 alpha: float,
                 params_nbw: dict,
                 learning_rate: float,
                 previous_nbw_model_filePath_list: List[str]
                 ):
        """Restore earlier models and append a new one.

        Args:
            alpha: Exponent parameter of the loss.
            params_nbw: Architecture parameters forwarded to
                ``NbwDenseModelList``.
            learning_rate: Learning rate for the Adam optimizer.
            previous_nbw_model_filePath_list: Files holding state dicts of
                earlier models, in the order they should be applied.
        """
        super().__init__()
        self.save_hyperparameters()
        self.alpha_ = alpha
        self.learning_rate_ = learning_rate
        nbw_modellist = NbwDenseModelList(self.alpha_, params_nbw)
        for pre_nbw_mdl_fpath in previous_nbw_model_filePath_list:
            if os.path.exists(pre_nbw_mdl_fpath):
                # Load onto CPU; load_state_dict copies the values into the
                # new module's own parameters regardless of their device.
                nbw_model_parameters = torch.load(
                    pre_nbw_mdl_fpath, map_location='cpu')
                nbw_modellist.load_and_add_model(nbw_model_parameters)
        nbw_modellist.add_model()
        self.nbw_models_ = nbw_modellist

    def forward(self, data, data_shuffled):
        """Run the model list in ``'train'`` mode on one batch."""
        return self.nbw_models_(data,
                                data_shuffled,
                                pred_mode='train')

    def _loss_from_batch(self, batch) -> torch.Tensor:
        """Return ``mean(P-term)/alpha + mean(Q-term)/(1 - alpha)``."""
        data, data_shuffled = batch
        adjust_prob_P, adjust_prob_Q = self(data, data_shuffled)
        term_P = torch.mean(adjust_prob_P)
        term_Q = torch.mean(adjust_prob_Q)
        return term_P/self.alpha_ + term_Q/(1 - self.alpha_)

    def _alpha_information(self, batch) -> torch.Tensor:
        """Return ``1 / (alpha * (1 - alpha))`` minus the batch loss."""
        loss = self._loss_from_batch(batch)
        return 1/(self.alpha_*(1 - self.alpha_)) - loss

    def training_step(self, batch, batch_idx: int) -> torch.Tensor:
        """Return the loss for one training batch."""
        return self._loss_from_batch(batch)

    def validation_step(self, batch, batch_idx: int) -> torch.Tensor:
        """Log and return the alpha information of one validation batch."""
        with torch.inference_mode():
            alpha_info = self._alpha_information(batch)
            self.log('validate_alpha_infomation', alpha_info,
                     on_step=True,
                     on_epoch=True)
            return alpha_info

    def test_step(self, batch, batch_idx: int) -> None:
        """Log the alpha information of one test batch."""
        with torch.inference_mode():
            alpha_info = self._alpha_information(batch)
            self.log('test_alpha_infomation', alpha_info,
                     on_step=True,
                     on_epoch=True)

    def configure_optimizers(self) -> optim.Optimizer:
        """Return an Adam optimizer over all parameters of the model list."""
        return optim.Adam(self.nbw_models_.parameters(),
                          lr=self.learning_rate_)


def load_nbw_models(
            alpha: float,
            params_nbw: dict,
            previous_nbw_model_filePath_list: List[str]
        ) -> NbwDenseModelList:
    """Build a model list from state dicts stored on disk.

    Args:
        alpha: Passed through to ``NbwDenseModelList``.
        params_nbw: Architecture parameters for ``NbwDenseModelList``.
        previous_nbw_model_filePath_list: State-dict files to load, in the
            order the models should be applied. Paths that do not exist
            are skipped without notice.

    Returns:
        The list holding one restored model per existing file; it is empty
        when none of the paths exists.
    """
    nbw_modellist = NbwDenseModelList(alpha, params_nbw)
    for pre_nbw_mdl_fpath in previous_nbw_model_filePath_list:
        if os.path.exists(pre_nbw_mdl_fpath):
            # Load onto CPU so files written from another device can still
            # be read; load_state_dict copies the values into the module.
            nbw_model_parameters = torch.load(
                pre_nbw_mdl_fpath, map_location='cpu')
            nbw_modellist.load_and_add_model(nbw_model_parameters)

    return nbw_modellist


def _explanatories_to_tensor_pair(
            explanatories: List[Any]
        ) -> Tuple[torch.Tensor, torch.Tensor]:
    """Convert arrays to one joined tensor plus a row-shuffled counterpart.

    Each array is cast to float32 and permuted along its first axis with its
    own random permutation; originals and permuted copies are then
    concatenated along axis 1.
    """
    tensors = []
    shuffled = []
    for arr in explanatories:
        # PyTorch layers here work in float32.
        tsr = torch.from_numpy(arr.astype(np.float32))
        # NOTE(review): kept from the original; nothing in this file reads
        # gradients w.r.t. the inputs -- confirm whether this is needed.
        tsr.requires_grad_(True)
        tensors.append(tsr)
        perm = torch.randperm(tsr.size()[0])
        shuffled.append(tsr[perm, :].detach())
    return torch.cat(tensors, dim=1), torch.cat(shuffled, dim=1)


def build_new_nbw_model(
            alpha: float,
            train_explanatories_to_be_balanced: List[Any],
            eval_explanatories_to_be_balanced: List[Any],
            params_nbw: dict,
            learning_rate: float,
            batchsize: int,
            pytorch_trainer: pl.Trainer,
            previous_nbw_model_filePath_list: List[str]
        ) -> NbwDense:
    """Fit one new model on top of previously saved ones and return it.

    Args:
        alpha: Exponent parameter of the loss.
        train_explanatories_to_be_balanced: Arrays (each providing
            ``astype`` and two axes) used for training; joined on axis 1.
        eval_explanatories_to_be_balanced: Arrays used for validation and
            testing; joined on axis 1 and served as one single batch.
        params_nbw: Architecture parameters. ``'input_dim'`` is set on this
            dict in place from the training data.
        learning_rate: Learning rate for the optimizer.
        batchsize: Batch size of the training loader.
        pytorch_trainer: Trainer used for ``fit`` and then ``test``.
        previous_nbw_model_filePath_list: State-dict files of earlier
            models that are loaded before the new model is added.

    Returns:
        The newly added model after training.
    """
    train_tsr, train_shuffled_tsr = _explanatories_to_tensor_pair(
        train_explanatories_to_be_balanced)
    train_loader = DataLoader(
        torch.utils.data.TensorDataset(train_tsr, train_shuffled_tsr),
        shuffle=True,
        batch_size=batchsize)

    eval_tsr, eval_shuffled_tsr = _explanatories_to_tensor_pair(
        eval_explanatories_to_be_balanced)
    # The evaluation data is served as one batch containing every row.
    eval_loader = DataLoader(
        torch.utils.data.TensorDataset(eval_tsr, eval_shuffled_tsr),
        shuffle=True,
        batch_size=eval_tsr.size()[0])

    # Callers in this file log ``params_nbw`` afterwards, so the in-place
    # update is kept.
    params_nbw['input_dim'] = train_tsr.shape[1]

    nbw_models = LightningNbwDense(
        alpha,
        params_nbw,
        learning_rate,
        previous_nbw_model_filePath_list)
    pytorch_trainer.fit(nbw_models, train_loader, eval_loader)
    pytorch_trainer.test(nbw_models, eval_loader)

    return nbw_models.nbw_models_.new_nbw_model_


def train_and_enhance_NBW(
            alpha: float,
            train_explanatories_to_be_balanced: List[Any],
            eval_explanatories_to_be_balanced: List[Any],
            params_nbw: dict,
            n_enhance_nbw: int,
            learning_rate: float,
            batchsize: int,
            n_epoch_model: int,
            n_epoch_estimate: int,
            out_results_dir: str
        ) -> List[Any]:
    """Train models one after another while the estimate keeps shrinking.

    Each round trains a new model on top of the accepted ones, saves it,
    and then trains a throw-away model on top of all of them. The largest
    validation output seen during that second run is the round's estimated
    alpha information. The first model is always accepted. A later model is
    accepted only when its estimate is positive and below 99% of the
    current best; otherwise the loop stops.

    Args:
        alpha: Exponent parameter of the loss.
        train_explanatories_to_be_balanced: Training arrays.
        eval_explanatories_to_be_balanced: Validation/test arrays.
        params_nbw: Architecture parameters (updated in place by
            ``build_new_nbw_model``).
        n_enhance_nbw: Maximum number of additional rounds after the first;
            a negative value means no round is run.
        learning_rate: Learning rate for the optimizer.
        batchsize: Training batch size.
        n_epoch_model: Epochs for training each new model.
        n_epoch_estimate: Epochs for each estimation run.
        out_results_dir: Directory receiving ``pytorch_logs`` and
            ``nbw_models``.

    Returns:
        ``[accepted_model_paths, best_estimate, estimates_by_path]`` where
        ``estimates_by_path`` also contains the last, rejected model.
    """
    out_log_dir = os.path.join(
        out_results_dir, 'pytorch_logs')
    os.makedirs(out_log_dir, exist_ok=True)
    out_model_dir = os.path.join(out_results_dir, 'nbw_models')
    os.makedirs(out_model_dir, exist_ok=True)

    def _make_trainer(stage: str, i_nbw: int, max_epochs: int,
                      callbacks=None) -> pl.Trainer:
        # One timestamped TensorBoard run per stage and round.
        logname_suffix = datetime.datetime.now().strftime(
            '%Y-%m-%d_%H_%M_%S')
        logger = TensorBoardLogger(
            out_log_dir,
            f'{stage}({i_nbw})_{logname_suffix}',
            default_hp_metric=False)
        return pl.Trainer(
            accelerator='gpu',
            devices=2,
            strategy='dp',
            logger=logger,
            max_epochs=max_epochs,
            callbacks=callbacks,
            enable_progress_bar=False)

    alpha_infos_of_all_nbw_models = dict()
    filePaths_of_nbw_models_to_use_list = []
    best_i_nbw = -1
    current_best_alpha_infomation_estimated = np.inf
    for _i_nbw in range(0, n_enhance_nbw + 1):
        # Snapshot of the models accepted so far.
        filePaths_of_nbw_models_for_modeling = \
            filePaths_of_nbw_models_to_use_list.copy()

        # Train the new model on top of the accepted ones.
        trainer_build = _make_trainer('modeling', _i_nbw, n_epoch_model)
        new_nbw_model = build_new_nbw_model(
            alpha,
            train_explanatories_to_be_balanced,
            eval_explanatories_to_be_balanced,
            params_nbw,
            learning_rate,
            batchsize,
            trainer_build,
            filePaths_of_nbw_models_for_modeling)

        # Persist the new model.
        out_nbw_mdl_filepath = os.path.join(
            out_model_dir, f'nbw_model_{_i_nbw:02}.mi')
        torch.save(new_nbw_model.state_dict(),
                   out_nbw_mdl_filepath)

        # Estimation run: accepted models plus the one just saved.
        alphaInfo_tracker = AlphaInfosTracker()
        trainer_estimate = _make_trainer(
            'estimating', _i_nbw, n_epoch_estimate,
            callbacks=alphaInfo_tracker)
        build_new_nbw_model(
            alpha,
            train_explanatories_to_be_balanced,
            eval_explanatories_to_be_balanced,
            params_nbw,
            learning_rate,
            batchsize,
            trainer_estimate,
            filePaths_of_nbw_models_for_modeling + [out_nbw_mdl_filepath])

        alpha_infomation_estimated = np.max(alphaInfo_tracker.collection)
        alpha_infos_of_all_nbw_models[out_nbw_mdl_filepath] = \
            alpha_infomation_estimated

        # Logging
        trainer_build.logger.log_hyperparams(params_nbw)
        trainer_build.logger.log_hyperparams(
            params_nbw,
            {'max_alpha_infomation': alpha_infomation_estimated})

        # The first model is always kept; later ones must improve the
        # estimate by more than 1% and remain positive.
        improved = (
            alpha_infomation_estimated > 0
            and alpha_infomation_estimated
            < current_best_alpha_infomation_estimated*0.99)
        if _i_nbw > 0 and not improved:
            break
        best_i_nbw = _i_nbw
        current_best_alpha_infomation_estimated = alpha_infomation_estimated
        filePaths_of_nbw_models_to_use_list.append(out_nbw_mdl_filepath)

        print('--------------------------------------------------------------')
        print(f'Number of trials to enhacne balancing: {_i_nbw}')
        print('Current Best (smallest alpha-information) = ',
              f'{current_best_alpha_infomation_estimated}, {best_i_nbw}')
        print('--------------------------------------------------------------')

    return [filePaths_of_nbw_models_to_use_list,
            current_best_alpha_infomation_estimated,
            alpha_infos_of_all_nbw_models]


def train_NBW(
            alpha: float,
            train_explanatories_to_be_balanced: List[Any],
            eval_explanatories_to_be_balanced: List[Any],
            do_estimate_alpha_inf: bool,
            params_nbw: dict,
            learning_rate: float,
            batchsize: int,
            n_epoch_model: int,
            out_results_dir: str,
            n_epoch_estimate: Optional[int] = None
        ) -> List[Any]:
    """Train and save a single model and report its alpha information.

    Args:
        alpha: Exponent parameter of the loss.
        train_explanatories_to_be_balanced: Training arrays.
        eval_explanatories_to_be_balanced: Validation/test arrays.
        do_estimate_alpha_inf: When true, a second model is trained on top
            of the saved one and the largest validation output of that run
            is reported. When false, the ``'test_alpha_infomation'`` metric
            of the first run is reported instead.
        params_nbw: Architecture parameters (updated in place by
            ``build_new_nbw_model``).
        learning_rate: Learning rate for the optimizer.
        batchsize: Training batch size.
        n_epoch_model: Epochs for training the model.
        out_results_dir: Directory receiving ``pytorch_logs`` and
            ``nbw_models``.
        n_epoch_estimate: Epochs for the estimation run; defaults to
            ``n_epoch_model`` when omitted. Only used when
            ``do_estimate_alpha_inf`` is true.

    Returns:
        ``[[model_path], estimate, {model_path: estimate}]``.
    """
    # The estimation branch previously read an undefined name here.
    if n_epoch_estimate is None:
        n_epoch_estimate = n_epoch_model

    out_log_dir = os.path.join(
        out_results_dir, 'pytorch_logs')
    os.makedirs(out_log_dir, exist_ok=True)
    out_model_dir = os.path.join(out_results_dir, 'nbw_models')
    os.makedirs(out_model_dir, exist_ok=True)

    def _make_trainer(stage: str, max_epochs: int,
                      callbacks=None) -> pl.Trainer:
        # One timestamped TensorBoard run per stage.
        logname_suffix = datetime.datetime.now().strftime(
            '%Y-%m-%d_%H_%M_%S')
        logger = TensorBoardLogger(
            out_log_dir,
            f'{stage}({_i_nbw})_{logname_suffix}',
            default_hp_metric=False)
        return pl.Trainer(
            accelerator='gpu',
            devices=2,
            strategy='dp',
            logger=logger,
            max_epochs=max_epochs,
            callbacks=callbacks,
            enable_progress_bar=False)

    # Train the model (there is only one, so its index is fixed at 0).
    _i_nbw = 0
    trainer_build = _make_trainer('modeling', n_epoch_model)
    new_nbw_model = build_new_nbw_model(
        alpha,
        train_explanatories_to_be_balanced,
        eval_explanatories_to_be_balanced,
        params_nbw,
        learning_rate,
        batchsize,
        trainer_build,
        [])

    # Persist the model.
    out_nbw_mdl_filepath = os.path.join(
        out_model_dir, f'nbw_model_{_i_nbw:02}.mi')
    torch.save(new_nbw_model.state_dict(),
               out_nbw_mdl_filepath)

    filePaths_of_nbw_models_to_use_list = [out_nbw_mdl_filepath]

    if do_estimate_alpha_inf:
        # Estimation run on top of the model just saved.
        alphaInfo_tracker = AlphaInfosTracker()
        trainer_estimate = _make_trainer(
            'estimating', n_epoch_estimate, callbacks=alphaInfo_tracker)
        build_new_nbw_model(
            alpha,
            train_explanatories_to_be_balanced,
            eval_explanatories_to_be_balanced,
            params_nbw,
            learning_rate,
            batchsize,
            trainer_estimate,
            [out_nbw_mdl_filepath])

        alpha_infomation_estimated = np.max(alphaInfo_tracker.collection)

        # Logging
        trainer_build.logger.log_hyperparams(params_nbw)
        trainer_build.logger.log_hyperparams(
            params_nbw,
            {'max_alpha_infomation': alpha_infomation_estimated})
    else:
        alpha_infomation_estimated = trainer_build.callback_metrics[
            'test_alpha_infomation'].item()

    alpha_infos_of_all_nbw_models = {
        out_nbw_mdl_filepath: alpha_infomation_estimated}

    return [filePaths_of_nbw_models_to_use_list,
            alpha_infomation_estimated,
            alpha_infos_of_all_nbw_models]


def estimate_balancing_weights(
      explanatories_for_prediction: List[Any],
      params_nbw: dict,
      filePaths_of_nbw_models_to_use_list: List[str]
  ) -> np.ndarray:
    """Compute balancing weights for the given rows from saved models.

    Args:
        explanatories_for_prediction: Arrays joined along axis 1 to form
            the model input.
        params_nbw: Architecture parameters; ``'input_dim'`` is set on this
            dict in place from the joined input.
        filePaths_of_nbw_models_to_use_list: State-dict files of the models
            to apply. Paths that do not exist are skipped by
            ``load_nbw_models``.

    Returns:
        Flat array of weights as returned by the model list in
        ``'estimate_only'`` mode.
    """
    expls_mat = np.concatenate(explanatories_for_prediction, axis=1)
    expls_tsr = torch.from_numpy(expls_mat.astype(np.float32))

    params_nbw['input_dim'] = expls_tsr.shape[1]

    nbw_models = load_nbw_models(
          1/2,  # dummy, not used
          params_nbw,
          filePaths_of_nbw_models_to_use_list)

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        esti_nbw_tsr = nbw_models(
            expls_tsr, None, 'estimate_only')

    return esti_nbw_tsr.numpy().flatten()
