import os
import pickle
import datetime

from typing import List, Any, Tuple, Dict
import numpy as np

from scipy import linalg
import numpy as np

import torch
from torch import Tensor
from torch.utils.data import DataLoader

import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger

import lightning as pl
from lightning.pytorch.loggers import TensorBoardLogger
import logging

logging.getLogger("lightning").setLevel(logging.ERROR)

from lib.DRE.NaiveDRE import *
from lib.DensityRateEstimation import estimate_alpha_div
from lib.DensityRateEstimation import estimate_KL_div, estimate_JS_div
from lib.DensityRateEstimation import seed_everything

def _dre_train_for_one_epoch(
              model_to_train: LightningNaiveDREdenseNN,
              train_denominator_tsr: Tensor,
              train_numerator_tsr: Tensor,
              batchsize: int,
              out_log_dir,
              device_id,
          ) -> None:
    """Fit ``model_to_train`` for exactly one epoch on paired samples.

    The two tensors are zipped row-wise into a ``TensorDataset`` and fed to
    a shuffling ``DataLoader``.  A fresh Lightning ``Trainer`` with
    ``max_epochs=1`` is created for the call, logging to a TensorBoard run
    named ``Run_<timestamp>`` below ``out_log_dir``.

    Args:
        model_to_train: Lightning module that is fitted in place.
        train_denominator_tsr: Samples of the denominator distribution.
        train_numerator_tsr: Samples of the numerator distribution; must
            have the same number of rows as ``train_denominator_tsr``
            (required by ``TensorDataset``).
        batchsize: Mini-batch size of the training ``DataLoader``.
        out_log_dir: Root directory for the TensorBoard logs.
        device_id: Index of the CUDA device to train on, or ``None`` to
            use the default CUDA device.

    Returns:
        None.  The model is updated as a side effect.
    """
    if device_id is None:
        device_to_use = torch.device('cuda')
        # A list containing ``None`` is not a valid ``devices`` value for
        # the Trainer, so request a single GPU instead.
        trainer_devices = 1
    else:
        device_to_use = torch.device(f'cuda:{device_id}')
        trainer_devices = [device_id]

    train_dataset = torch.utils.data.TensorDataset(
        train_denominator_tsr, train_numerator_tsr)
    # NOTE(review): the shuffling generator lives on the CUDA device; this
    # presumably relies on the default tensor device being CUDA -- confirm.
    train_dataloaer = DataLoader(
        train_dataset,
        shuffle=True,
        generator=torch.Generator(device=device_to_use),
        batch_size=batchsize)

    # One TensorBoard run per call, named after the wall-clock time.
    logname_suffix = datetime.datetime.now().strftime('%Y-%m-%d_%H_%M_%S')
    logname = f'Run_{logname_suffix}'
    logger = TensorBoardLogger(
            out_log_dir,
            logname,
            default_hp_metric=False)

    pytorch_trainer_for_one_epoch = pl.Trainer(
            accelerator='gpu',
            devices=trainer_devices,
            strategy='auto',
            logger=logger,
            max_epochs=1,
            enable_progress_bar=False)

    pytorch_trainer_for_one_epoch.fit(
        model_to_train,
        train_dataloaer)


def dre_train_for_all_epoch(
            train_denominator_data_np: Any,
            train_numerator_data_np: Any,
            params_training: Dict[str, Any],
            learning_rate: float,
            batchsize: int,
            n_all_epochs: int,
            n_points_to_plot: int,
            out_results_dir: str,
            do_save_model: bool,
            seed: int,
            device_id: int
            ) -> Dict[int, Dict[str, float]]:
    """Train a density-ratio model chunk by chunk and track divergences.

    Every epoch is split into ``n_points_to_plot // n_all_epochs`` chunks
    of consecutive rows.  After fitting on one chunk (one Lightning
    ``Trainer`` run with ``max_epochs=1``), the alpha, KL and JS
    divergences are estimated on the full training data and recorded, so
    the returned mapping can be plotted as a convergence curve.

    Args:
        train_denominator_data_np: 2-D array of denominator samples
            (rows are samples, columns are input dimensions).
        train_numerator_data_np: 2-D array of numerator samples.
        params_training: Model hyper-parameters; must contain ``'alpha'``.
            ``params_training['input_dim']`` is set in place to the number
            of columns of the denominator data.
        learning_rate: Learning rate handed to the Lightning module.
        batchsize: Mini-batch size.
        n_all_epochs: Number of passes over the chunk schedule.
        n_points_to_plot: Total number of divergence estimates wanted
            across all epochs.
        out_results_dir: Output root; logs go to ``all_logs`` and the
            saved model to ``_energy_models`` below it.
        do_save_model: If true, save the energy model's ``state_dict``
            after training.
        seed: Seed passed to ``seed_everything``.
        device_id: Index of the CUDA device, or ``None`` for the default
            CUDA device.

    Returns:
        Mapping from a step counter to the estimates made after the
        corresponding chunk, with keys ``'alpha_divergence'``,
        ``'KL_divergence'`` and ``'JS_divergence'``.  The counter is
        ``(epoch * chunks_per_epoch + chunk) * steps_per_chunk`` so that
        entries of different epochs do not overwrite each other; for the
        first epoch it equals ``chunk * steps_per_chunk``.

    Raises:
        ValueError: If ``batchsize`` or ``n_all_epochs`` is not positive,
            if ``n_points_to_plot`` is smaller than ``n_all_epochs``, or
            if the data has too few rows to give every chunk at least one
            optimisation step.
    """
    # Validate the schedule first so a bad configuration fails before any
    # seeding, directory creation or GPU work takes place.
    if batchsize < 1:
        raise ValueError(f'batchsize must be positive, got {batchsize}')
    if n_all_epochs < 1:
        raise ValueError(
            f'n_all_epochs must be positive, got {n_all_epochs}')
    n_points_per_epoch = n_points_to_plot // n_all_epochs
    if n_points_per_epoch < 1:
        raise ValueError(
            'n_points_to_plot must be at least n_all_epochs '
            f'(got {n_points_to_plot} and {n_all_epochs})')
    n_rows_data = train_denominator_data_np.shape[0]
    n_dims_input = train_denominator_data_np.shape[1]
    n_all_steps_per_epoch = n_rows_data // batchsize
    n_steps_per_plot = n_all_steps_per_epoch // n_points_per_epoch
    if n_steps_per_plot < 1:
        raise ValueError(
            f'{n_rows_data} rows with batchsize {batchsize} give only '
            f'{n_all_steps_per_epoch} steps per epoch, which is fewer than '
            f'the {n_points_per_epoch} plot points requested per epoch')
    n_train_data_per_plot = n_steps_per_plot * batchsize

    seed_everything(seed)
    out_log_dir = os.path.join(out_results_dir, 'all_logs')
    os.makedirs(out_log_dir, exist_ok=True)
    out_model_dir = os.path.join(out_results_dir, '_energy_models')
    os.makedirs(out_model_dir, exist_ok=True)

    if device_id is None:
        device_to_use = torch.device('cuda')
        # A list containing ``None`` is not a valid ``devices`` value for
        # the Trainer, so request a single GPU instead.
        trainer_devices = 1
    else:
        device_to_use = torch.device(f'cuda:{device_id}')
        trainer_devices = [device_id]

    train_denominator_tsr = torch.from_numpy(
        train_denominator_data_np.astype(np.float32)).to(device_to_use)
    train_numerator_tsr = torch.from_numpy(
        train_numerator_data_np.astype(np.float32)).to(device_to_use)

    # Because we want to see whether learning will converge,
    # training data will be used as test data.
    test_denominator_tsr = train_denominator_tsr
    test_numerator_tsr = train_numerator_tsr

    params_training['input_dim'] = n_dims_input
    model_to_train = LightningDREdenseNN(
        params_training,
        learning_rate)

    res_dict: Dict[int, Dict[str, float]] = {}
    for _i_epoch in range(n_all_epochs):
        for _i_train in range(n_points_per_epoch):
            start_row = _i_train * n_train_data_per_plot
            end_row = start_row + n_train_data_per_plot

            train_dataset = torch.utils.data.TensorDataset(
                train_denominator_tsr[start_row:end_row, :],
                train_numerator_tsr[start_row:end_row, :])
            # NOTE(review): the shuffling generator lives on the CUDA
            # device; this presumably relies on the default tensor device
            # being CUDA -- confirm.
            train_dataloaer = DataLoader(
                train_dataset,
                shuffle=True,
                generator=torch.Generator(device=device_to_use),
                batch_size=batchsize)

            # One TensorBoard run per chunk, named after the wall clock.
            logname_suffix = datetime.datetime.now().strftime(
                '%Y-%m-%d_%H_%M_%S')
            logname = f'Run_{logname_suffix}'
            logger = TensorBoardLogger(
                    out_log_dir,
                    logname,
                    default_hp_metric=False)

            pytorch_trainer_for_one_epoch = pl.Trainer(
                    accelerator='gpu',
                    devices=trainer_devices,
                    strategy='auto',
                    logger=logger,
                    max_epochs=1,
                    enable_progress_bar=False)
            pytorch_trainer_for_one_epoch.fit(
                model_to_train,
                train_dataloaer)

            # Put the ratio model on the device holding the evaluation
            # tensors before estimating the divergences.
            prob_model = model_to_train.prob_rate_model_
            prob_model.to(device_to_use)

            test_alpha_div = estimate_alpha_div(
                params_training['alpha'],
                test_denominator_tsr,
                test_numerator_tsr,
                prob_model)
            # NOTE(review): only the denominator sample is passed here, as
            # in the original call; confirm against estimate_KL_div.
            test_KL_div_estmated = estimate_KL_div(
                test_denominator_tsr,
                prob_model)
            test_JS_div_estmated = estimate_JS_div(
                test_denominator_tsr,
                test_numerator_tsr,
                prob_model)

            # Include the epoch in the key; keying on the chunk index
            # alone made later epochs overwrite the earlier entries.
            global_step = (
                (_i_epoch * n_points_per_epoch + _i_train)
                * n_steps_per_plot)
            res_dict[global_step] = {
                'alpha_divergence': test_alpha_div,
                'KL_divergence': test_KL_div_estmated,
                'JS_divergence': test_JS_div_estmated,
            }

    if do_save_model:
        out_energy_mdl_filepath = os.path.join(
            out_model_dir, 'energy_model.mdl')
        torch.save(
            model_to_train.prob_rate_model_.energy_model_.state_dict(),
            out_energy_mdl_filepath)

    return res_dict
