import os
import datetime

from typing import List, Any, Dict, Tuple

from scipy import linalg
import numpy as np

import torch
from torch import Tensor
from torch.utils.data import DataLoader

import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger

from .NeuralBW import LightningNbwDense, AlphaInfosTracker


def calc_true_alpha_info(
        alpha: float,
        rho: float,
        d: int) -> float:
    """Evaluate the alpha-information between two centred Gaussians in closed form.

    The "original" distribution is N(0, S), where S has ones on the diagonal
    and ``rho`` everywhere else. The "balanced" distribution is N(0, I_d).
    With ``A = alpha * S^{-1} + (1 - alpha) * I_d``, the returned value is

        1 / (alpha * (alpha - 1)) * (
            |S|^(-alpha/2) * |I_d|^(-(1-alpha)/2) * |A|^(-1/2) - 1)

    Args:
        alpha: Divergence order. A value of 0 or 1 makes the leading factor
            ``1 / (alpha * (alpha - 1))`` divide by zero.
        rho: Common off-diagonal entry (pairwise correlation) of S.
        d: Dimension of both distributions.

    Returns:
        The closed-form value, as produced by NumPy/SciPy arithmetic.
    """
    cov_original = np.full((d, d), rho)
    np.fill_diagonal(cov_original, 1)
    # The balanced covariance is the identity, so its determinant is 1.
    cov_balanced_det = 1

    precision_mix = (alpha * linalg.inv(cov_original)
                     + (1 - alpha) * np.identity(d))
    # |A|^{-1}; its square root below gives |A|^{-1/2}.
    precision_mix_det_inv = 1 / linalg.det(precision_mix)

    scale = 1 / (alpha * (alpha - 1))
    overlap = ((1 / linalg.det(cov_original)) ** (alpha / 2)
               * (1 / cov_balanced_det) ** ((1 - alpha) / 2)
               * precision_mix_det_inv ** (1 / 2))
    return scale * (overlap - 1)


def generate_data(
        n_data_train: int,
        n_data_test: int,
        rho: float,
        d: int) -> List[List[np.ndarray]]:
    """Sample train/test data from a zero-mean equicorrelated Gaussian.

    The covariance has ones on the diagonal and ``rho`` off the diagonal.
    Draws come from NumPy's global RNG: the training matrix is drawn first,
    then the test matrix.

    Args:
        n_data_train: Number of training rows to draw.
        n_data_test: Number of test rows to draw.
        rho: Common off-diagonal covariance entry. It must give a valid
            covariance; ``np.random.multivariate_normal`` warns otherwise.
        d: Number of dimensions (columns).

    Returns:
        ``[train_columns, test_columns]``. Each is a list of ``d`` arrays,
        where element ``i`` is column ``i`` of the sampled matrix with
        shape ``(n_rows, 1)``.
    """
    mean_vec = np.zeros(d)
    cov_mat = np.full((d, d), rho)
    np.fill_diagonal(cov_mat, 1)
    train_data_mat = np.random.multivariate_normal(
        mean_vec, cov_mat, n_data_train)
    test_data_mat = np.random.multivariate_normal(
        mean_vec, cov_mat, n_data_test)

    # Index with a list (``[i]``) so each column stays 2-D: (n_rows, 1).
    train_columns = [train_data_mat[:, [i]] for i in range(d)]
    test_columns = [test_data_mat[:, [i]] for i in range(d)]

    return [train_columns, test_columns]


def train_for_one_epoch(
            model_to_train: LightningNbwDense,
            train_explanatories_tsr: Tensor,
            train_explanatories_shuffled_tsr: Tensor,
            eval_explanatories_tsr: Tensor,
            eval_explanatories_shuffled_tsr: Tensor,
            pytorch_trainer_for_one_epoch: pl.Trainer,
            batchsize: int,
        ) -> None:
    """Fit ``model_to_train`` with the given Trainer, then run its test loop.

    Each dataset pairs row ``k`` of the original tensor with row ``k`` of
    the shuffled tensor, so every batch yields ``(original, shuffled)``.
    How many epochs ``fit`` runs is decided by the Trainer's configuration.

    Args:
        model_to_train: Lightning module to fit; it is updated in place.
        train_explanatories_tsr: Training inputs.
        train_explanatories_shuffled_tsr: Shuffled counterpart of the
            training inputs; must have the same number of rows.
        eval_explanatories_tsr: Validation/test inputs.
        eval_explanatories_shuffled_tsr: Shuffled counterpart of the
            validation/test inputs.
        pytorch_trainer_for_one_epoch: Trainer used for both ``fit`` and
            ``test``.
        batchsize: Batch size for both DataLoaders.
    """
    # Clone so the datasets hold their own copies of the caller's tensors.
    train_dataset = torch.utils.data.TensorDataset(
        train_explanatories_tsr.clone(),
        train_explanatories_shuffled_tsr.clone())
    train_dataloader = DataLoader(
        train_dataset,
        shuffle=True,
        batch_size=batchsize)

    eval_dataset = torch.utils.data.TensorDataset(
        eval_explanatories_tsr.clone(),
        eval_explanatories_shuffled_tsr.clone())
    # The original/shuffled pairing is fixed inside the dataset, so
    # reordering eval rows only adds nondeterminism. Lightning also warns
    # when val/test loaders shuffle.
    eval_dataloader = DataLoader(
        eval_dataset,
        shuffle=False,
        batch_size=batchsize)

    pytorch_trainer_for_one_epoch.fit(
        model_to_train,
        train_dataloaders=train_dataloader,
        val_dataloaders=eval_dataloader)
    pytorch_trainer_for_one_epoch.test(
        model_to_train, dataloaders=eval_dataloader)


def _to_original_and_shuffled(
        explanatories: List[Any]) -> Tuple[Tensor, Tensor]:
    """Stack column arrays into a float32 tensor plus a column-shuffled copy.

    Each array in ``explanatories`` becomes a float32 tensor with
    ``requires_grad`` enabled. The tensors are concatenated along ``dim=1``.
    The shuffled copy permutes every column with its own ``torch.randperm``,
    which breaks the row alignment between columns. Shuffled columns are
    detached from the autograd graph.
    """
    column_tensors = []
    for column in explanatories:
        # Cast to float32 to match PyTorch's default parameter dtype.
        column_tensor = torch.from_numpy(column.astype(np.float32))
        column_tensor.requires_grad_(True)
        column_tensors.append(column_tensor)

    # Generate all permutations after conversion, one per column in order.
    shuffled_columns = [
        column_tensor[torch.randperm(column_tensor.size()[0]), :].detach()
        for column_tensor in column_tensors]

    return (torch.cat(column_tensors, dim=1),
            torch.cat(shuffled_columns, dim=1))


def train_for_all_epochs(
            alpha: float,
            train_explanatories_to_be_balanced: List[Any],
            eval_explanatories_to_be_balanced: List[Any],
            params_nbw: dict,
            learning_rate: float,
            batchsize: int,
            n_all_epochs: int,
            out_results_dir: str,
            *,
            accelerator: str = 'gpu',
            devices: Any = 2,
            strategy: str = 'dp') -> Dict[int, float]:
    """Train a ``LightningNbwDense`` model and track an alpha-information estimate.

    Every epoch runs through ``train_for_one_epoch`` with a freshly built
    ``pl.Trainer`` (``max_epochs=1``). After each epoch the model is applied
    to the full eval set, and
    ``1/(alpha*(1-alpha)) - (mean(P)/alpha + mean(Q)/(1-alpha))`` is recorded.

    Args:
        alpha: Divergence order; must not be 0 or 1.
        train_explanatories_to_be_balanced: Training columns, each a NumPy
            array whose first axis is rows (e.g. from ``generate_data``).
        eval_explanatories_to_be_balanced: Evaluation columns, same layout.
        params_nbw: Model configuration forwarded to ``LightningNbwDense``.
        learning_rate: Learning rate forwarded to ``LightningNbwDense``.
        batchsize: Batch size for the train and eval DataLoaders.
        n_all_epochs: Number of one-epoch training rounds.
        out_results_dir: Root directory. ``all_logs`` (TensorBoard) and
            ``_nbw_models`` are created under it.
        accelerator: ``pl.Trainer`` accelerator (default ``'gpu'``).
        devices: ``pl.Trainer`` devices (default ``2``).
        strategy: ``pl.Trainer`` strategy (default ``'dp'``).

    Returns:
        Dict mapping a step index to the estimate measured after that epoch.
    """
    out_log_dir = os.path.join(out_results_dir, 'all_logs')
    os.makedirs(out_log_dir, exist_ok=True)
    # Created for model artefacts; nothing in this function writes to it.
    os.makedirs(os.path.join(out_results_dir, '_nbw_models'), exist_ok=True)

    n_rows_train = train_explanatories_to_be_balanced[0].shape[0]
    # DataLoader keeps the final partial batch (drop_last=False), so one
    # epoch takes ceil(N / batchsize) optimizer steps. This assumes a
    # single-process strategy such as 'dp'; a distributed sampler would
    # split the rows across processes.
    n_step_per_epoch = -(-n_rows_train // batchsize)

    train_explanatories_tsr, train_explanatories_shuffled_tsr = (
        _to_original_and_shuffled(train_explanatories_to_be_balanced))
    eval_explanatories_tsr, eval_explanatories_shuffled_tsr = (
        _to_original_and_shuffled(eval_explanatories_to_be_balanced))

    # Build the NBW model and a logger shared by every per-epoch Trainer.
    logname_suffix = datetime.datetime.now().strftime('%Y-%m-%d_%H_%M_%S')
    logname = f'build_{logname_suffix}'
    logger = TensorBoardLogger(
        out_log_dir,
        logname,
        default_hp_metric=False)
    model_to_train = LightningNbwDense(
        alpha,
        params_nbw,
        learning_rate,
        [])

    res_dict: Dict[int, float] = {}
    for i_epoch in range(n_all_epochs):
        # NOTE(review): a new Trainer per epoch restarts its step counter
        # on every `fit` (and presumably re-runs `configure_optimizers`,
        # resetting optimizer state). Confirm this is intended.
        pytorch_trainer_for_one_epoch = pl.Trainer(
            accelerator=accelerator,
            devices=devices,
            strategy=strategy,
            logger=logger,
            max_epochs=1,
            enable_progress_bar=False)
        train_for_one_epoch(
            model_to_train,
            train_explanatories_tsr,
            train_explanatories_shuffled_tsr,
            eval_explanatories_tsr,
            eval_explanatories_shuffled_tsr,
            pytorch_trainer_for_one_epoch,
            batchsize)

        adjust_prob_P, adjust_prob_Q = model_to_train(
            eval_explanatories_tsr, eval_explanatories_shuffled_tsr)
        loss = (
            torch.mean(adjust_prob_P) / alpha
            + torch.mean(adjust_prob_Q) / (1 - alpha)
        ).item()
        alpha_information = 1 / (alpha * (1 - alpha)) - loss
        # NOTE(review): the key is the step count at the *start* of this
        # epoch, although the estimate is measured after training it.
        # Kept as-is for compatibility with existing consumers.
        res_dict[i_epoch * n_step_per_epoch] = alpha_information

    return res_dict
