import os
import pickle
import datetime

from typing import List, Any, Tuple, Dict
import numpy as np

from scipy import linalg
import numpy as np

import torch
from torch import Tensor
from torch.utils.data import DataLoader

import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger

from lib.ProbabilityRateModeling import *
from lib.DensityRateEstimation import estimate_alpha_div
from lib.DensityRateEstimation import estimate_KL_div, estimate_JS_div
from lib.DensityRateEstimation import seed_everything

def _dre_train_for_one_epoch(
        model_to_train: LightningDREdenseNN,
        train_denominator_tsr: Tensor,
        train_numerator_tsr: Tensor,
        pytorch_trainer_for_one_epoch: pl.Trainer,
        batchsize: int,
) -> None:
    """Fit ``model_to_train`` on paired denominator/numerator samples.

    The two tensors are zipped row-wise into a ``TensorDataset`` and served
    through a shuffling ``DataLoader``. How many passes are made over that
    loader is decided entirely by ``pytorch_trainer_for_one_epoch``.

    Args:
        model_to_train: Lightning module that is optimised in place.
        train_denominator_tsr: Samples from the denominator distribution.
        train_numerator_tsr: Samples from the numerator distribution; it must
            share its first-dimension length with ``train_denominator_tsr``
            (``TensorDataset`` requires this).
        pytorch_trainer_for_one_epoch: Trainer whose ``fit`` runs the loop.
        batchsize: Mini-batch size of the loader (last partial batch kept).
    """
    paired_samples = torch.utils.data.TensorDataset(
        train_denominator_tsr, train_numerator_tsr)
    pytorch_trainer_for_one_epoch.fit(
        model_to_train,
        DataLoader(paired_samples, batch_size=batchsize, shuffle=True))



def _estimate_test_divergences(
        alpha: Any,
        test_denominator_tsr: Tensor,
        test_numerator_tsr: Tensor,
        prob_rate_model: Any,
) -> Dict[str, float]:
    """Return the alpha, KL and JS divergence estimates for one model state."""
    # Dict-literal values are evaluated in order: alpha, KL, then JS.
    return {
        'alpha_divergence': estimate_alpha_div(
            alpha,
            test_denominator_tsr,
            test_numerator_tsr,
            prob_rate_model),
        # NOTE(review): the KL estimate receives only the denominator samples;
        # confirm this matches estimate_KL_div's intended signature.
        'KL_divergence': estimate_KL_div(
            test_denominator_tsr,
            prob_rate_model),
        'JS_divergence': estimate_JS_div(
            test_denominator_tsr,
            test_numerator_tsr,
            prob_rate_model),
    }


def dre_train_for_all_epoch(
        train_denominator_data_np: Any,
        train_numerator_data_np: Any,
        params_training: Dict[str, Any],
        learning_rate: float,
        batchsize: int,
        n_all_epochs: int,
        out_results_dir: str,
        do_save_moel: bool,
        seed: int,
        accelerator: str = 'gpu',
) -> Dict[int, Dict[str, float]]:
    """Train a density-ratio model epoch by epoch, recording divergence estimates.

    After every epoch the alpha, KL and JS divergence estimates are computed on
    the training data itself, which doubles as test data so that convergence
    of the learning can be monitored.

    Side effects: creates ``<out_results_dir>/all_logs`` (TensorBoard logs) and
    ``<out_results_dir>/_energy_models``; when ``do_save_moel`` is true, the
    energy model's ``state_dict`` is written to
    ``_energy_models/energy_model.mdl``.

    Args:
        train_denominator_data_np: Denominator samples, rows x features; must
            support ``.astype(np.float32)`` (presumably a NumPy array).
        train_numerator_data_np: Numerator samples, same convention.
        params_training: Hyper-parameters for ``LightningDREdenseNN``; must
            contain ``'alpha'``. ``'input_dim'`` is filled in from the data on
            an internal copy, so the caller's dict is left unchanged.
        learning_rate: Passed to ``LightningDREdenseNN``.
        batchsize: Mini-batch size; must be positive.
        n_all_epochs: Number of single-epoch training rounds.
        out_results_dir: Root directory for logs and the saved model.
        do_save_moel: Whether to save the trained energy model.
        seed: Seed passed to ``seed_everything`` before anything else runs.
        accelerator: Lightning accelerator name (default ``'gpu'``, matching
            the previous hard-coded value; use ``'cpu'``/``'auto'`` elsewhere).

    Returns:
        Mapping from ``epoch_index * steps_per_epoch`` (the number of batches
        processed *before* that epoch) to a dict with keys
        ``'alpha_divergence'``, ``'KL_divergence'`` and ``'JS_divergence'``
        measured *after* that epoch.

    Raises:
        ValueError: If ``batchsize`` is not positive.
    """
    if batchsize <= 0:
        raise ValueError(f'batchsize must be positive, got {batchsize}')

    seed_everything(seed)
    out_log_dir = os.path.join(out_results_dir, 'all_logs')
    os.makedirs(out_log_dir, exist_ok=True)
    out_model_dir = os.path.join(out_results_dir, '_energy_models')
    os.makedirs(out_model_dir, exist_ok=True)

    train_denominator_tsr = torch.from_numpy(
        train_denominator_data_np.astype(np.float32))
    train_numerator_tsr = torch.from_numpy(
        train_numerator_data_np.astype(np.float32))
    # Because we want to see whether learning will converge,
    # training data is also used as test data.
    test_denominator_tsr = train_denominator_tsr
    test_numerator_tsr = train_numerator_tsr

    n_rows_data = train_denominator_tsr.shape[0]
    n_dims_input = train_denominator_tsr.shape[1]
    # Shallow copy: do not mutate the caller's dict.
    params_training = {**params_training, 'input_dim': n_dims_input}
    model_to_train = LightningDREdenseNN(
        params_training,
        learning_rate)

    # The DataLoader keeps the final partial batch (drop_last is not set), so
    # an epoch is ceil(n_rows / batchsize) batches. Floor division undercounted
    # and, when n_rows < batchsize, gave 0 so every epoch overwrote key 0.
    n_step_per_epoch = -(-n_rows_data // batchsize)

    # One timestamp for the whole run: every per-epoch logger shares this name.
    logname = 'Run_' + datetime.datetime.now().strftime('%Y-%m-%d_%H_%M_%S')

    res_dict: Dict[int, Dict[str, float]] = {}
    for i_epoch in range(n_all_epochs):
        # A fresh Trainer with max_epochs=1 per round lets divergences be
        # evaluated between epochs while the same model object keeps training.
        # NOTE(review): each fit() presumably re-runs configure_optimizers, so
        # optimiser state (e.g. momentum) may be reset every epoch — confirm.
        logger = TensorBoardLogger(
            out_log_dir,
            logname,
            default_hp_metric=False)
        pytorch_trainer_for_one_epoch = pl.Trainer(
            accelerator=accelerator,
            devices=1,
            strategy='auto',
            logger=logger,
            max_epochs=1,
            enable_progress_bar=False)
        _dre_train_for_one_epoch(
            model_to_train,
            train_denominator_tsr,
            train_numerator_tsr,
            pytorch_trainer_for_one_epoch,
            batchsize)

        res_dict[i_epoch * n_step_per_epoch] = _estimate_test_divergences(
            params_training['alpha'],
            test_denominator_tsr,
            test_numerator_tsr,
            model_to_train.prob_rate_model_)

    if do_save_moel:
        out_energy_mdl_filepath = os.path.join(
            out_model_dir, 'energy_model.mdl')
        torch.save(
            model_to_train.prob_rate_model_.energy_model_.state_dict(),
            out_energy_mdl_filepath)

    return res_dict
