import os
import random
import datetime

from typing import List, Any, Tuple, Dict

import numpy as np

from sklearn.metrics import mean_squared_error

import torch
from torch import nn, Tensor
from torch.distributions import MultivariateNormal, Uniform
from torch.utils.data import DataLoader

import lightning as pl
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.callbacks.early_stopping import EarlyStopping

from lib.dre.common.util import seed_everything
from lib.dre.common.estimate import gen_estimate_density_rate_func, gen_estimate_divergence_func
from lib.dre.lightning.linear_zero_weights import LightningNaiveDREdenseNN




# Module-level side effect (runs on import): make newly created floating-point
# tensors CUDA float32 tensors by default. Presumably requires a CUDA-enabled
# PyTorch build — TODO confirm behaviour on CPU-only machines.
# NOTE(review): torch.set_default_tensor_type is deprecated since PyTorch 2.1
# in favour of torch.set_default_dtype() / torch.set_default_device().
torch.set_default_tensor_type('torch.cuda.FloatTensor')


def _as_float_tensor(array_np: Any, like_tsr: Tensor) -> Tensor:
    """Convert a numpy array to float32 and place it on ``like_tsr``'s device.

    Uses ``.to(device=..., dtype=...)`` rather than ``Tensor.type_as``: a CPU
    tensor cast with ``type_as`` to a CUDA type lands on the *current* CUDA
    device, not necessarily on ``like_tsr``'s device index.
    """
    return torch.from_numpy(array_np.astype(np.float32)).to(
        device=like_tsr.device, dtype=like_tsr.dtype)


def _monitoring_schedule(
        n_rows_data: int,
        batchsize: int,
        n_all_epochs: int,
        all_points_to_monitor: int) -> Tuple[int, int, int]:
    """Split one epoch of training rows into equally sized monitoring chunks.

    Returns:
        ``(n_points_to_monitor_per_epoch, n_steps_per_monitoring,
        n_train_data_per_monitoring)``, where a step is one mini-batch.

    Raises:
        ValueError: if the configuration yields no monitoring point per epoch
            or chunks with zero mini-batches.
    """
    if batchsize <= 0:
        raise ValueError(f'batchsize must be positive, got {batchsize}')
    if n_all_epochs <= 0:
        raise ValueError(f'n_all_epochs must be positive, got {n_all_epochs}')
    n_points_to_monitor_per_epoch = all_points_to_monitor // n_all_epochs
    if n_points_to_monitor_per_epoch <= 0:
        raise ValueError(
            f'all_points_to_monitor ({all_points_to_monitor}) must be at least '
            f'n_all_epochs ({n_all_epochs})')
    n_all_steps_per_epoch = n_rows_data // batchsize
    n_steps_per_monitoring = n_all_steps_per_epoch // n_points_to_monitor_per_epoch
    if n_steps_per_monitoring <= 0:
        raise ValueError(
            f'{n_rows_data} training rows with batchsize {batchsize} give '
            f'{n_all_steps_per_epoch} steps per epoch, too few for '
            f'{n_points_to_monitor_per_epoch} monitoring points per epoch')
    return (n_points_to_monitor_per_epoch,
            n_steps_per_monitoring,
            n_steps_per_monitoring * batchsize)


def _rate_statistics(estimated_rate_np: Any, true_rate_np: Any) -> Dict[str, Any]:
    """Summarise estimated density-ratio values against the ground truth."""
    mse = mean_squared_error(true_rate_np, estimated_rate_np)
    # Flatten both sides so an (n,) prediction and an (n, 1) target are not
    # broadcast into an (n, n) difference matrix.
    bias = np.mean(np.ravel(estimated_rate_np) - np.ravel(true_rate_np))
    return {
        'mse': mse,
        'bias': bias,
        'max_pred_rate': np.max(estimated_rate_np),
        'min_pred_rate': np.min(estimated_rate_np),
        'median_pred_rate': np.median(estimated_rate_np),
    }


def dre_train_for_all_epoch(
            train_denominator_data_np: Any,
            train_numerator_data_np: Any,
            eval_denominator_data_np: Any,
            eval_numerator_data_np: Any,
            test_denominator_data_np: Any,
            test_numerator_data_np: Any,
            params_training: Dict[str, Any],
            batchsize: int,
            n_all_epochs: int,
            method: str,
            all_points_to_monitor: int,
            out_results_dir: str,
            do_save_result: bool,
            seed: int,
            device_id: int = None,
            true_rate_for_test: Any = None,
            true_rate_for_eval: Any = None
            ) -> Dict[int, Dict[str, Any]]:
    """Train a density-ratio model, evaluating it at regular monitoring points.

    Each epoch splits the training rows into ``all_points_to_monitor //
    n_all_epochs`` consecutive chunks. Each chunk is fitted for one Lightning
    epoch by a fresh ``Trainer``, and the model is then evaluated on the test
    set. Rows beyond the last full chunk are never trained on.

    Args:
        train_denominator_data_np, train_numerator_data_np: 2-D numpy arrays
            of training samples (rows are samples).
        eval_denominator_data_np, eval_numerator_data_np: currently not used
            by this function; kept for interface compatibility.
        test_denominator_data_np, test_numerator_data_np: samples used to
            estimate the divergence (and the ratio) after every chunk.
        params_training: configuration passed to the model and the
            estimators. Mutated in place: ``'eval_mse'``, ``'test_mse'`` and
            ``'input_dim'`` are written into it.
        batchsize: mini-batch size.
        n_all_epochs: number of passes over the chunked training data.
        method: method name forwarded to the model and estimator factories.
        all_points_to_monitor: total number of evaluation points over all
            epochs (floor-divided evenly across epochs).
        out_results_dir: root directory for ``all_logs``, ``energy_models``
            and, when saving, ``estimated``.
        do_save_result: if True, overwrite ``estimated/vals_estimated.npy``
            and ``energy_models/energy_model.mdl`` after every point.
        seed: passed to ``seed_everything``.
        device_id: CUDA device index; ``None`` uses the current CUDA device.
        true_rate_for_test: ground-truth ratio for the test denominator
            samples; enables the error statistics below.
        true_rate_for_eval: only its presence matters; sets
            ``params_training['eval_mse']``.

    Returns:
        Mapping from the cumulative number of training steps (mini-batches)
        to a stats dict with ``'target_divergence'`` and, when
        ``true_rate_for_test`` is given, ``'mse'``, ``'bias'``,
        ``'max_pred_rate'``, ``'min_pred_rate'`` and ``'median_pred_rate'``.

    Raises:
        ValueError: if the batch size / epoch / monitoring configuration
            leaves no (non-empty) chunk to train on. Raised before any
            seeding, directory creation or GPU work.
    """
    (n_points_to_monitor_per_epoch,
     n_steps_per_monitoring,
     n_train_data_per_monitoring) = _monitoring_schedule(
        train_denominator_data_np.shape[0], batchsize,
        n_all_epochs, all_points_to_monitor)

    seed_everything(seed)
    out_log_dir = os.path.join(out_results_dir, 'all_logs')
    os.makedirs(out_log_dir, exist_ok=True)
    out_model_dir = os.path.join(out_results_dir, 'energy_models')
    os.makedirs(out_model_dir, exist_ok=True)
    out_energy_mdl_filepath = os.path.join(out_model_dir, 'energy_model.mdl')
    out_pred_dir = os.path.join(out_results_dir, 'estimated')
    if do_save_result:
        os.makedirs(out_pred_dir, exist_ok=True)

    if device_id is None:
        device_to_use = torch.device('cuda')
        # Hand Lightning the GPU that torch.device('cuda') resolves to;
        # [None] is not a valid ``devices`` value for the Trainer.
        trainer_devices = [torch.cuda.current_device()]
    else:
        device_to_use = torch.device(f'cuda:{device_id}')
        trainer_devices = [device_id]

    train_denominator_tsr = torch.from_numpy(
        train_denominator_data_np.astype(np.float32)).to(device_to_use)
    train_numerator_tsr = _as_float_tensor(
        train_numerator_data_np, train_denominator_tsr)
    test_denominator_tsr = _as_float_tensor(
        test_denominator_data_np, train_denominator_tsr)
    test_numerator_tsr = _as_float_tensor(
        test_numerator_data_np, train_denominator_tsr)

    # Written into the caller's dict; presumably read by the model/estimators.
    params_training['eval_mse'] = true_rate_for_eval is not None
    params_training['test_mse'] = true_rate_for_test is not None
    params_training['input_dim'] = train_denominator_tsr.shape[1]

    model_to_train = LightningNaiveDREdenseNN(method, params_training)
    esti_divergence_func = gen_estimate_divergence_func(method, params_training)
    esti_density_rate_func = gen_estimate_density_rate_func(method)
    # Predicted rates are needed for the error statistics and for saving.
    need_test_rates = true_rate_for_test is not None or do_save_result

    res_dict: Dict[int, Dict[str, Any]] = {}
    current_step = 0
    for _i_epoch in range(n_all_epochs):
        for i_point in range(n_points_to_monitor_per_epoch):
            start_row = i_point * n_train_data_per_monitoring
            end_row = start_row + n_train_data_per_monitoring
            chunk_dataloader = DataLoader(
                torch.utils.data.TensorDataset(
                    train_denominator_tsr[start_row:end_row, :],
                    train_numerator_tsr[start_row:end_row, :]),
                shuffle=True,
                # presumably needed because CUDA is the default tensor type,
                # so shuffling must use a CUDA generator — TODO confirm
                generator=torch.Generator(device=device_to_use),
                batch_size=batchsize)

            logname_suffix = datetime.datetime.now().strftime('%Y-%m-%d_%H_%M_%S')
            logger = TensorBoardLogger(
                out_log_dir,
                f'Run_{logname_suffix}',
                default_hp_metric=False)

            # NOTE(review): a fresh Trainer per chunk presumably re-runs
            # configure_optimizers on every fit, resetting optimizer state
            # (e.g. Adam moments) each chunk — confirm this is intended.
            pytorch_trainer_for_one_epoch = pl.Trainer(
                accelerator='gpu',
                devices=trainer_devices,
                strategy='auto',
                logger=logger,
                max_epochs=1,
                enable_progress_bar=False)
            pytorch_trainer_for_one_epoch.fit(model_to_train, chunk_dataloader)

            prob_model = model_to_train.model_.to(device_to_use)
            current_train_stats_dict = {
                'target_divergence': esti_divergence_func(
                    test_denominator_tsr, test_numerator_tsr, prob_model)}
            if need_test_rates:
                test_estimated_rate_np = esti_density_rate_func(
                    test_denominator_tsr, prob_model).cpu().detach().numpy()
            if true_rate_for_test is not None:
                current_train_stats_dict.update(
                    _rate_statistics(test_estimated_rate_np, true_rate_for_test))
            current_step += n_steps_per_monitoring
            res_dict[current_step] = current_train_stats_dict

            if do_save_result:
                # Both files are overwritten at every monitoring point, so
                # they end up holding the latest predictions / weights.
                np.save(
                    os.path.join(out_pred_dir, 'vals_estimated'),
                    test_estimated_rate_np)
                torch.save(
                    model_to_train.model_.model_list_.state_dict(),
                    out_energy_mdl_filepath)

    return res_dict


