import os
import pickle
import datetime

from typing import List, Any, Tuple, Dict
import numpy as np

from scipy import linalg
import numpy as np

import torch
from torch import nn, Tensor
from torch.utils.data import DataLoader

import lightning as pl
from lightning.pytorch.loggers import TensorBoardLogger

from lib.dre.common.util import seed_everything
from lib.dre.common.estimate import gen_estimate_target_divergence_func
from lib.dre.lightning.linear import LightningNaiveDREdenseNN


# Module-level side effect: newly created floating-point tensors default to
# CUDA float32 (this is presumably why the DataLoader shuffling generators in
# this file are created on the CUDA device). Likely fails at import time on a
# machine without CUDA.
# NOTE(review): torch.set_default_tensor_type is deprecated in recent PyTorch
# releases; torch.set_default_dtype / torch.set_default_device are the
# suggested replacements — confirm the installed version before switching.
torch.set_default_tensor_type('torch.cuda.FloatTensor')




def dre_train_for_all_epoch(
            train_denominator_data_np: Any,
            train_numerator_data_np: Any,
            params_training: Dict[str, Any],
            batchsize: int,
            n_all_epochs: int,
            out_results_dir: str,
            do_save_result: bool,
            seed: int,
            device_id: int
            ) -> Dict[int, Dict[str, float]]:
    """Train an alpha-divergence DRE model epoch by epoch and track its estimate.

    A ``LightningNaiveDREdenseNN`` is trained for ``n_all_epochs`` epochs, each
    epoch run by a fresh single-epoch ``pl.Trainer`` on the same model object.
    After every epoch the alpha-divergence between the denominator and
    numerator samples is estimated on the *training* samples themselves (no
    separate test split is used).

    Args:
        train_denominator_data_np: Denominator samples. Converted with
            ``.astype(np.float32)`` and indexed as ``shape[0]`` rows by
            ``shape[1]`` features, so presumably a 2-D numpy array — confirm.
        train_numerator_data_np: Numerator samples, paired row-by-row with the
            denominator samples by ``TensorDataset`` (so the row counts must
            match).
        params_training: Hyper-parameters passed to the model and to the
            divergence estimator. The caller's dict is not modified; a copy
            is extended with ``input_dim``, ``eval_mse`` and ``test_mse``.
        batchsize: Mini-batch size; must be positive.
        n_all_epochs: Number of training epochs.
        out_results_dir: Root output directory. TensorBoard logs go to
            ``all_logs/`` and the saved model to ``_energy_models/``.
        do_save_result: If True, save the energy model's ``state_dict`` to
            ``_energy_models/energy_model.mdl`` after training.
        seed: Seed passed to ``seed_everything`` and used for the shuffling
            generator of the training DataLoader.
        device_id: CUDA device index, or None to use the current CUDA device.

    Returns:
        Mapping ``epoch_index * n_steps_per_epoch`` ->
        ``{'alpha_divergence': estimate}``.

    Raises:
        ValueError: If ``batchsize`` is not positive.
    """
    method = 'alphaDiv'
    if batchsize <= 0:
        raise ValueError(f'batchsize must be a positive integer, got {batchsize}')
    seed_everything(seed)
    out_log_dir = os.path.join(out_results_dir, 'all_logs')
    os.makedirs(out_log_dir, exist_ok=True)
    out_model_dir = os.path.join(out_results_dir, '_energy_models')
    os.makedirs(out_model_dir, exist_ok=True)

    # Resolve one concrete GPU index so the data tensors and the Lightning
    # Trainer use the same device (passing devices=[None] to pl.Trainer is
    # invalid).
    if device_id is None:
        device_id = torch.cuda.current_device()
    device_to_use = torch.device(f'cuda:{device_id}')

    train_denominator_tsr = torch.from_numpy(
        train_denominator_data_np.astype(np.float32)).to(device_to_use)
    train_numerator_tsr = torch.from_numpy(
        train_numerator_data_np.astype(np.float32)).to(device_to_use)
    # The divergence is evaluated on the training samples themselves.
    test_denominator_tsr = train_denominator_tsr
    test_numerator_tsr = train_numerator_tsr

    # The loader is built once and reused for every epoch. The single shuffling
    # generator is shared across epochs, so its state advances and each epoch
    # gets a new permutation. A fresh torch.Generator per epoch would start
    # from the same default seed and replay the identical order every epoch.
    shuffle_generator = torch.Generator(device=device_to_use)
    if seed is not None:
        shuffle_generator.manual_seed(seed)
    train_dataloader = DataLoader(
        torch.utils.data.TensorDataset(train_denominator_tsr, train_numerator_tsr),
        shuffle=True,
        generator=shuffle_generator,
        batch_size=batchsize)

    n_rows_data = train_denominator_tsr.shape[0]
    # NOTE(review): with drop_last=False the loader yields
    # ceil(n_rows / batchsize) batches per epoch, so this floor division
    # undercounts by one when the division is not exact. Kept because it
    # defines the keys of the returned dict.
    n_steps_per_epoch = n_rows_data // batchsize
    n_dims_input = train_denominator_tsr.shape[1]

    # Work on a copy so the caller's dict is not mutated.
    params_training = dict(params_training)
    params_training['input_dim'] = n_dims_input
    params_training['eval_mse'] = False
    params_training['test_mse'] = False
    model_to_train = LightningNaiveDREdenseNN(method, params_training)
    esti_divergence_func = gen_estimate_target_divergence_func(
        method, params_training)

    # Every epoch logs under the same run name. TensorBoardLogger presumably
    # separates the per-epoch trainers into version_N subdirectories — confirm.
    logname = f"Run_{datetime.datetime.now().strftime('%Y-%m-%d_%H_%M_%S')}"

    res_dict = dict()
    for i_epoch in range(n_all_epochs):
        logger = TensorBoardLogger(
            out_log_dir,
            logname,
            default_hp_metric=False)
        # A new single-epoch Trainer per epoch allows evaluation between
        # epochs; the same model object is passed to every fit, so its weights
        # carry over.
        # NOTE(review): each fit() presumably re-runs configure_optimizers, so
        # optimizer state (e.g. Adam moments, LR schedules) may reset every
        # epoch — confirm this is intended.
        trainer = pl.Trainer(
            accelerator='gpu',
            devices=[device_id],
            strategy='auto',
            logger=logger,
            max_epochs=1,
            enable_progress_bar=False)
        trainer.fit(model_to_train, train_dataloader)
        model_trained = model_to_train.model_.to(device_to_use)
        test_alpha_div = esti_divergence_func(
            test_denominator_tsr, test_numerator_tsr, model_trained)
        # NOTE(review): the estimate is taken *after* epoch i_epoch but keyed
        # by the step count at its start (0 for the first epoch). Callers may
        # rely on these keys, so they are left unchanged.
        res_dict[i_epoch * n_steps_per_epoch] = {
            'alpha_divergence': test_alpha_div,
        }

    if do_save_result:
        out_energy_mdl_filepath = os.path.join(out_model_dir, 'energy_model.mdl')
        # NOTE(review): evaluation above uses `model_to_train.model_`, while
        # this saves `model_to_train.prob_rate_model_.energy_model_`; confirm
        # both refer to the intended network.
        torch.save(
            model_to_train.prob_rate_model_.energy_model_.state_dict(),
            out_energy_mdl_filepath)

    return res_dict
