# -*- coding: utf-8 -*-

import copy
import os
from typing import Dict

import numpy as np
import pandas as pd
import torch
import torch.optim as optim
from tqdm import tqdm

from common.constants import ACCURACY, CHECKPOINT_PATH, CHUNK_SIZE, CLASSIFICATION, \
    CURRENT_MODEL_NAME, DEVICE, EPOCHS, experiment_cols, ML_UTILITY_OUT_PATH, LATENT_SPACE_SIZE, \
    LEARNING_RATE, TRAIN_LOAD_FROM_CHECKPOINT,  MACRO_F1, MASK_COL_NUMBER, MASK_COL_NUMBER_CONSTANT, \
    MAX_COL_TO_MASK, MIN_COL_TO_MASK, MSE, PROMPT, TIMESTEPS, TRAIN_BATCH_SIZE, \
    TEST_BATCH_SIZE, WEIGHTED_F1
from common.utils import DatasetMetadata
from experiments.utils_experiments import data_to_supervised_model_format, test_ml, train_xgboost
from tabular_diffusion.denoising_models.fast_data_loader_with_binary_mask import FastTensorDataLoaderWithBinaryMask
from tabular_diffusion.diffusion_manager import DiffusionManager


def train_one_epoch(diffusion_manager: DiffusionManager,
                    denoising_model: torch.nn.Module,
                    optimizer: optim.Optimizer,
                    data_loader: FastTensorDataLoaderWithBinaryMask,
                    epoch: int,
                    epochs: int,
                    device: str,
                    scheduler_iter=None,
                    scheduler_epoch=None,
                    clip_value: float=None,
                    clip_norm: float=None,
                    target_distribution: torch.tensor=None,
                    disable_progress_bar: bool=False) -> Dict:
    """Run one optimisation pass over ``data_loader``.

    For every ``(batch, mask)`` pair yielded by the loader the continuous and
    categorical losses returned by ``diffusion_manager.total_loss`` are summed,
    back-propagated and applied through ``optimizer``.

    :param diffusion_manager: DiffusionManager, provides ``total_loss``
    :param denoising_model: torch.nn.Module, network being trained (put in train mode here)
    :param optimizer: optim.Optimizer, optimizer of ``denoising_model``
    :param data_loader: FastTensorDataLoaderWithBinaryMask, yields ``(batch, mask)`` pairs
    :param epoch: int, zero-based index of the current epoch (only used for display)
    :param epochs: int, total number of epochs (only used for display)
    :param device: str, device the batch and the mask are moved to
    :param scheduler_iter: optional scheduler, stepped after every optimizer step
    :param scheduler_epoch: optional scheduler, stepped once after the last batch
    :param clip_value: optional float, forwarded to ``torch.nn.utils.clip_grad_value_``
    :param clip_norm: optional float, forwarded to ``torch.nn.utils.clip_grad_norm_``
    :param target_distribution: torch.tensor, indexed with the last column of the
        batch (cast to long) to build per-row weights ``1 / p / len(p)``
    :param disable_progress_bar: bool, True in order to disable the progress bar
    :return: Dict with key ``"loss"``: the row-weighted mean of the total loss
    :raises ZeroDivisionError: if ``data_loader`` yields no batch
    """
    denoising_model.train()

    total_running = 0.0
    cont_running = 0.0
    cat_running = 0.0
    rows_seen = 0
    # Size of the first batch; only used to estimate the total shown in the progress bar.
    first_batch_rows = 1

    progress = tqdm(data_loader, desc='EPOCH {}/{} - TRAIN'.format(epoch + 1, epochs), disable=disable_progress_bar)
    for step, (batch, mask) in enumerate(progress):
        if step == 0:
            first_batch_rows = batch.shape[0]

        optimizer.zero_grad()

        if target_distribution is None:
            row_weights = None
        else:
            target_idx = batch[:, -1].to(torch.long)
            row_weights = 1 / target_distribution[target_idx] / target_distribution.shape[0]

        cont_loss, cat_loss = diffusion_manager.total_loss(denoising_model,
                                                           batch.to(device, dtype=torch.float32),
                                                           mask=mask.to(device),
                                                           target_weights=row_weights)

        # Either component may be missing (None); use whatever is available.
        if cont_loss is None:
            loss = cat_loss
        elif cat_loss is None:
            loss = cont_loss
        else:
            loss = cont_loss + cat_loss

        loss.backward()
        if clip_value is not None:
            torch.nn.utils.clip_grad_value_(denoising_model.parameters(), clip_value)
        if clip_norm is not None:
            torch.nn.utils.clip_grad_norm_(denoising_model.parameters(), clip_norm)
        optimizer.step()
        if scheduler_iter is not None:
            scheduler_iter.step()

        # Accumulate losses weighted by the number of rows in the batch.
        rows_in_batch = len(batch)
        if cont_loss is not None:
            cont_running += cont_loss.detach().cpu().item() * rows_in_batch
        if cat_loss is not None:
            cat_running += cat_loss.detach().cpu().item() * rows_in_batch
        total_running += loss.detach().cpu().item() * rows_in_batch
        rows_seen += rows_in_batch

        expected_rows = len(data_loader) * first_batch_rows
        progress.set_postfix(Datapoint='{}/{}'.format(rows_seen, expected_rows),
                             Loss_Tot='{:.5f}'.format(total_running / rows_seen),
                             Loss_Cont='{:.5f}'.format(cont_running / rows_seen),
                             Loss_Cat='{:.5f}'.format(cat_running / rows_seen))

    progress.close()
    if scheduler_epoch is not None:
        scheduler_epoch.step()

    mean_loss = total_running / rows_seen
    return {'loss': mean_loss}


def eval_fn(diffusion_manager: DiffusionManager,
            denoising_model: torch.nn.Module,
            data_loader,
            epoch,
            epochs,
            device,
            target_distribution=None,
            disable_progress_bar=False) -> Dict:
    """Compute the mean loss over ``data_loader`` without updating the model.

    The model is switched to eval mode and every batch is scored under
    ``torch.no_grad()`` with ``diffusion_manager.total_loss``.

    :param diffusion_manager: DiffusionManager, provides ``total_loss``
    :param denoising_model: torch.nn.Module, the denoising model
    :param data_loader: test/evaluation data loader yielding ``(batch, mask)`` pairs
    :param epoch: int, zero-based index of the current epoch (only used for display)
    :param epochs: int, total number of epochs (only used for display)
    :param device: str, device the batch and the mask are moved to
    :param target_distribution: torch.tensor, indexed with the last column of the
        batch (cast to long) to build per-row weights ``1 / p / len(p)``
    :param disable_progress_bar: bool, True in order to disable the progress bar
    :return: Dict with key ``"loss"``: the row-weighted mean of the total loss
    :raises ZeroDivisionError: if ``data_loader`` yields no batch
    """
    denoising_model.eval()

    total_running = 0.0
    cont_running = 0.0
    cat_running = 0.0
    rows_seen = 0
    # Size of the first batch; only used to estimate the total shown in the progress bar.
    first_batch_rows = 1

    with torch.no_grad():
        progress = tqdm(data_loader, desc='EPOCH {}/{} - TEST'.format(epoch + 1, epochs), disable=disable_progress_bar)
        for step, (batch, mask) in enumerate(progress):
            if step == 0:
                first_batch_rows = batch.shape[0]

            if target_distribution is None:
                row_weights = None
            else:
                target_idx = batch[:, -1].to(torch.long)
                row_weights = 1 / target_distribution[target_idx] / target_distribution.shape[0]

            cont_loss, cat_loss = diffusion_manager.total_loss(denoising_model,
                                                               batch.to(device, dtype=torch.float32),
                                                               mask=mask.to(device),
                                                               target_weights=row_weights)

            # Either component may be missing (None); use whatever is available.
            if cont_loss is None:
                loss = cat_loss
            elif cat_loss is None:
                loss = cont_loss
            else:
                loss = cont_loss + cat_loss

            # Accumulate losses weighted by the number of rows in the batch.
            rows_in_batch = len(batch)
            total_running += loss.detach().cpu().item() * rows_in_batch
            if cont_loss is not None:
                cont_running += cont_loss.detach().cpu().item() * rows_in_batch
            if cat_loss is not None:
                cat_running += cat_loss.detach().cpu().item() * rows_in_batch
            rows_seen += rows_in_batch

            expected_rows = len(data_loader) * first_batch_rows
            progress.set_postfix(Datapoint='{}/{}'.format(rows_seen, expected_rows),
                                 Loss_Tot='{:.5f}'.format(total_running / rows_seen),
                                 Loss_Cont='{:.5f}'.format(cont_running / rows_seen),
                                 Loss_Cat='{:.5f}'.format(cat_running / rows_seen))

        progress.close()

    mean_loss = total_running / rows_seen
    return {'loss': mean_loss}


def _cpu_state_snapshot(model: torch.nn.Module) -> Dict:
    """Return a copy of ``model.state_dict()`` whose tensors are independent CPU copies."""
    live_state = model.state_dict()
    # Shallow-copy first so the result has the same mapping type as the live state dict,
    # then replace every entry with a copy that no longer shares storage with the model.
    snapshot = copy.copy(live_state)
    for name, value in live_state.items():
        if torch.is_tensor(value):
            snapshot[name] = value.detach().to('cpu', copy=True)
        else:
            snapshot[name] = copy.deepcopy(value)
    return snapshot


def train_ddpm(diffusion_manager: DiffusionManager,
               denoising_model: torch.nn.Module,
               optimizer: optim.Optimizer,
               train_dataloader: FastTensorDataLoaderWithBinaryMask,
               test_dataloader: FastTensorDataLoaderWithBinaryMask,
               epochs: int,
               device: str,
               best_denoising_model_path: str,
               target_distribution: torch.tensor = None,
               save_dn_fn_state: bool =True,
               disable_progress_bar: bool=False) -> Dict:
    """Train flow: alternate one training epoch and one evaluation epoch.

    After every epoch the evaluation loss is compared with the best value seen
    so far; on improvement the model is (optionally) checkpointed and a CPU
    snapshot of its weights is kept.

    :param diffusion_manager: Diffusion manager
    :param denoising_model: torch.nn.Module, the denoising model to train/test
    :param optimizer: optim.optimizer, the denoising model optimizer
    :param train_dataloader: FastTensorDataLoaderWithBinaryMask, Train data loader
    :param test_dataloader: FastTensorDataLoaderWithBinaryMask, Test data loader
    :param epochs: int, epochs number
    :param device: str, device to use
    :param best_denoising_model_path: str, path where the best denoising model has to be saved
    :param target_distribution: torch.tensor, categorical target distribution
    :param save_dn_fn_state: bool, If True, the best model is saved through
        ``denoising_model.save_checkpoint_model``
    :param disable_progress_bar: bool, If True the progress bar is disabled and a
        one-line summary is printed per epoch instead
    :return: Dict, state dict (CPU tensors) of the epoch with the lowest evaluation
        loss, or None if no epoch improved on ``inf`` (e.g. ``epochs == 0`` or the
        loss was always NaN)
    """
    best_result = float('inf')
    best_model_state = None
    # ``Module.to`` works in place, so a single move is enough: the model is never
    # taken off ``device`` afterwards.
    denoising_model = denoising_model.to(device)
    for epoch in range(epochs):
        # Train
        train_one_epoch(diffusion_manager,
                        denoising_model,
                        optimizer,
                        train_dataloader,
                        epoch,
                        epochs,
                        device,
                        target_distribution=target_distribution,
                        disable_progress_bar=disable_progress_bar)
        # Eval
        eval_dict = eval_fn(diffusion_manager,
                            denoising_model,
                            test_dataloader,
                            epoch,
                            epochs,
                            device,
                            target_distribution=target_distribution,
                            disable_progress_bar=disable_progress_bar)
        eval_loss = eval_dict['loss']

        # Checkpoint
        if eval_loss < best_result:
            if save_dn_fn_state:
                denoising_model.save_checkpoint_model(metric=eval_loss,
                                                      epoch=epoch,
                                                      checkpoint_path=best_denoising_model_path)
            # ``state_dict()`` tensors share storage with the live parameters. Keeping
            # them as-is would let later optimizer steps overwrite the "best" weights
            # whenever the model already lives on the CPU, so take real copies.
            best_model_state = _cpu_state_snapshot(denoising_model)
            best_result = eval_loss

        if disable_progress_bar:
            print('EPOCH {}/{}: {:.5f} - best: {:.5f}'.format(epoch+1, epochs, eval_loss, best_result))
    return best_model_state


def main_train(dn_fn: torch.nn.Module,
               x_train_torch: torch.tensor,
               y_train_torch: torch.tensor,
               x_test_torch: torch.tensor,
               y_test_torch: torch.tensor,
               meta: DatasetMetadata,
               model_params: Dict,
               optimizer_params: Dict,
               train_params: Dict,
               mask_params: Dict,
               save_dn_fn_state: bool=True,
               disable_progress_bar: bool=False) -> Dict:
    """Build model, optimizer, data loaders and diffusion manager, then train.

    For classification the target is appended as the last column of the data
    matrix; otherwise it is prepended as the first column.

    :param dn_fn: callable invoked with keyword arguments to build the denoising model
    :param x_train_torch: torch.tensor, train features
    :param y_train_torch: torch.tensor, train targets
    :param x_test_torch: torch.tensor, test features
    :param y_test_torch: torch.tensor, test targets
    :param meta: DatasetMetadata (namedTupled)
    :param model_params: Dict, model parameters
    :param optimizer_params: Dict, optimizer parameters
    :param train_params: Dict, train parameters
    :param mask_params: Dict, mask parameters
    :param save_dn_fn_state: bool, forwarded to ``train_ddpm``
    :param disable_progress_bar: bool, if True the progress bar is disabled
    :return: Dict, whatever ``train_ddpm`` returns (the best model state)
    """
    device = train_params[DEVICE]
    is_classification = meta.problem_type == CLASSIFICATION

    # Target weights: the empirical frequencies are computed but then replaced by
    # ones, so only their length/dtype/device survive (uniform weighting).
    # NOTE(review): the length is the number of distinct target values and the
    # vector is later indexed with the last data column -- confirm this is in
    # range for non-classification problems.
    _, value_counts = np.unique(y_train_torch.cpu().numpy(), return_counts=True)
    empirical_freqs = torch.tensor(value_counts / y_train_torch.shape[0]).to(device)
    distr_weights = torch.ones_like(empirical_freqs).to(device)

    # Denoising model for diffusion process
    if is_classification:
        target_classes = [meta.num_classes]
        extra_num = 0
    else:
        target_classes = []
        extra_num = 1
    denoising_model = dn_fn(num_cont=len(meta.continuous_features_idxs) + extra_num,
                            num_classes=meta.categorical_lengths + target_classes,
                            hidden_size=model_params[LATENT_SPACE_SIZE],
                            timesteps=model_params[TIMESTEPS],
                            params=model_params,
                            problem_type=meta.problem_type,
                            with_target=True)
    denoising_model = denoising_model.to(device)
    if model_params[TRAIN_LOAD_FROM_CHECKPOINT]:
        denoising_model.load_checkpoint_model(checkpoint_path=model_params[CHECKPOINT_PATH],
                                              device=device)
    optimizer = optim.Adam(denoising_model.parameters(), lr=optimizer_params[LEARNING_RATE])

    # Data Loader
    def with_target(features, target):
        # Target goes last for classification, first otherwise.
        columns = [features, target] if is_classification else [target, features]
        return torch.cat(columns, dim=1)

    loader_options = dict(shuffle=True,
                          drop_last=True,
                          mask_percentage_row=1.0,
                          mask_col_num=mask_params[MASK_COL_NUMBER],
                          mask_col_num_constant=mask_params[MASK_COL_NUMBER_CONSTANT],
                          mask_on_first_feature=is_classification,
                          mask_on_last_feature=not is_classification)

    train_dataloader = FastTensorDataLoaderWithBinaryMask(with_target(x_train_torch, y_train_torch),
                                                          batch_size=train_params[TRAIN_BATCH_SIZE],
                                                          **loader_options)
    test_dataloader = FastTensorDataLoaderWithBinaryMask(with_target(x_test_torch, y_test_torch),
                                                         batch_size=train_params[TEST_BATCH_SIZE],
                                                         **loader_options)

    # Diffusion Manager
    dm = DiffusionManager(len(meta.continuous_features_idxs) + extra_num,
                          meta.categorical_lengths + target_classes,
                          PROMPT,
                          timesteps=model_params[TIMESTEPS],
                          problem_type=meta.problem_type,
                          device=device)

    print('Start training...')
    return train_ddpm(diffusion_manager=dm,
                      denoising_model=denoising_model,
                      optimizer=optimizer,
                      train_dataloader=train_dataloader,
                      test_dataloader=test_dataloader,
                      epochs=train_params[EPOCHS],
                      device=device,
                      best_denoising_model_path=model_params[CHECKPOINT_PATH],
                      target_distribution=distr_weights,
                      save_dn_fn_state=save_dn_fn_state,
                      disable_progress_bar=disable_progress_bar)


def main_test(dn_fn: torch.nn.Module,
              x_train_torch: torch.tensor,
              y_train_torch: torch.tensor,
              x_test_torch: torch.tensor,
              y_test_torch: torch.tensor,
              meta: DatasetMetadata,
              model_params: Dict,
              train_params: Dict,
              svt_exp_params: Dict,
              dataset_name: str,
              model_state=None,
              disable_progress_bar: bool=False) -> pd.DataFrame:
    """Main test flow (ML utility experiment).

    For every number of masked columns between ``MIN_COL_TO_MASK`` and
    ``MAX_COL_TO_MASK`` (inclusive) the training data is re-sampled through the
    diffusion model, an XGBoost model is trained on the sampled data and it is
    scored on the real test data. A baseline XGBoost trained on the real
    training data is evaluated (and printed) once, on the first iteration.

    :param dn_fn: callable invoked with keyword arguments to build the denoising model
    :param x_train_torch: torch.tensor, train features
    :param y_train_torch: torch.tensor, train targets
    :param x_test_torch: torch.tensor, test features
    :param y_test_torch: torch.tensor, test targets
    :param meta: DatasetMetadata
    :param model_params: Dict, model parameters
    :param train_params: Dict, train parameters
    :param svt_exp_params: Dict, experiment parameters
    :param dataset_name: str, forwarded to ``train_xgboost``
    :param model_state: optional state dict; when None the weights are loaded from
        ``model_params[CHECKPOINT_PATH]`` instead
    :param disable_progress_bar: bool, if True the progress bar is disabled
    :return: pd.DataFrame with one row per number of masked columns and columns
        ``experiment_cols``; unused metric slots hold -1.0. When
        ``svt_exp_params[ML_UTILITY_OUT_PATH]`` is non-empty the frame is also
        written next to it as ``<path without extension>_ML.csv`` ('|' separated)
    """
    is_classification = meta.problem_type == CLASSIFICATION
    device = train_params[DEVICE]

    # Denoising model
    print('Loading pre-trained denoising model...')
    target_classes = [meta.num_classes] if is_classification else []
    extra_num = 0 if is_classification else 1
    sample_denoising_model = dn_fn(num_cont=len(meta.continuous_features_idxs) + extra_num,
                                   num_classes=meta.categorical_lengths + target_classes,
                                   hidden_size=model_params[LATENT_SPACE_SIZE],
                                   timesteps=model_params[TIMESTEPS],
                                   params=model_params,
                                   problem_type=meta.problem_type,
                                   with_target=True).to(device)
    if model_state is None:
        sample_denoising_model.load_checkpoint_model(checkpoint_path=model_params[CHECKPOINT_PATH],
                                                     device=device)
    else:
        sample_denoising_model.load_state_dict(model_state)
        sample_denoising_model.to(device)
    # A freshly built module starts in training mode; switch to eval so that layers
    # such as dropout / batch-norm (if the model has any) behave as at inference.
    sample_denoising_model.eval()

    # Diffusion Manager
    print('Diffusion Manager...')
    dm = DiffusionManager(len(meta.continuous_features_idxs) + extra_num,
                          meta.categorical_lengths + target_classes,
                          PROMPT,
                          timesteps=model_params[TIMESTEPS],
                          problem_type=meta.problem_type,
                          device=device)

    # Loop-invariant inputs of the sampling step: target is the last column for
    # classification and the first column otherwise.
    mask_on_first_feature = is_classification
    mask_on_last_feature = not mask_on_first_feature
    with torch.no_grad():
        data_train = (torch.cat([x_train_torch, y_train_torch], dim=1) if is_classification else
                      torch.cat([y_train_torch, x_train_torch], dim=1))

    res_output = []
    baseline_pending = True
    for i in range(svt_exp_params[MIN_COL_TO_MASK], svt_exp_params[MAX_COL_TO_MASK] + 1):
        print('=========================================')
        print('Number of columns to compute: {}'.format(i))
        tmp_res = ['MLUtilityTest',
                   meta.dataset_name,
                   model_params[CURRENT_MODEL_NAME],
                   'XGBoost',
                   i]

        # Sampling
        with torch.no_grad():
            print('SAMPLING...')
            print('\tSampling chunk size: {}'.format(svt_exp_params[CHUNK_SIZE]))
            dataloader = FastTensorDataLoaderWithBinaryMask(data_train,
                                                            batch_size=svt_exp_params[CHUNK_SIZE],
                                                            shuffle=True,
                                                            drop_last=False,
                                                            mask_percentage_row=1.0,
                                                            mask_col_num=i,
                                                            mask_col_num_constant=True,
                                                            mask_on_first_feature=mask_on_first_feature,
                                                            mask_on_last_feature=mask_on_last_feature)

            res, original, mask = dm.sample(sample_denoising_model,
                                            dataloader,
                                            disable_progress_bar)
        print('\tSample shape: {}'.format(res.shape))

        # ML Test
        print('ML Utility Test...')
        if is_classification:
            fake_x = res[:, :-1]
            fake_y = res[:, -1]
        else:
            fake_x = res[:, 1:]
            fake_y = res[:, 0]
        # other_x / other_y: index 0 is the sampled data, index 1 the real test data.
        (rf_x_train, rf_y_train,
         other_x, other_y) = data_to_supervised_model_format(x_train_torch,
                                                             y_train_torch,
                                                             [fake_x, x_test_torch],
                                                             [fake_y.unsqueeze(dim=1), y_test_torch],
                                                             len(meta.continuous_features_idxs),
                                                             problem_type=meta.problem_type)

        if baseline_pending:
            # Train the XGBoost baseline on real data (done only once)
            print('Start training XGBoost on real training data...')
            tester = train_xgboost(rf_x_train, rf_y_train, dataset_name)

            # Test model on original data
            print('Start testing XGBoost on real testing data...')
            metric = test_ml(tester, other_x[1], other_y[1], meta.problem_type)
            if is_classification:
                print('TRAIN ON REAL DATA --> ACC. ON REAL DATA: {:.3f}'.format(100 * metric[ACCURACY]))
                print('TRAIN ON REAL DATA --> MACRO F1 ON REAL DATA: {:.3f}'.format(100 * metric[MACRO_F1]))
                print('TRAIN ON REAL DATA --> WEIGHTED F1 ON REAL DATA: {:.3f}'.format(100 * metric[WEIGHTED_F1]))
            else:
                print('TRAIN ON REAL DATA --> MSE ON REAL DATA: {:.5f}'.format(metric[MSE]))
            baseline_pending = False

        # Train model on masked data
        print('Start training XGBoost on masked training data...')
        tester = train_xgboost(other_x[0], other_y[0], dataset_name)

        # Test model on original data
        print('Start testing Masked XGBoost on original testing data...')
        ori_test_metric = test_ml(tester, other_x[1], other_y[1], meta.problem_type)
        if is_classification:
            print('TRAIN ON SIMULATED DATA --> ACC. ON REAL DATA: {:.3f}'.format(100 * ori_test_metric[ACCURACY]))
            print('TRAIN ON SIMULATED DATA --> MACRO F1 ON REAL DATA: {:.3f}'.format(100 * ori_test_metric[MACRO_F1]))
            print('TRAIN ON SIMULATED DATA --> WEIGHTED F1 ON REAL DATA: {:.3f}'.format(100 * ori_test_metric[WEIGHTED_F1]))
            # Slot 0 (MSE) is not applicable to classification.
            tmp_res.extend([-1.0,
                            ori_test_metric[ACCURACY],
                            ori_test_metric[MACRO_F1],
                            ori_test_metric[WEIGHTED_F1],
                            -1.0,
                            -1.0,
                            -1.0,
                            -1.0])
        else:
            print('TRAIN ON SIMULATED DATA --> MSE ON REAL DATA: {:.5f}'.format(ori_test_metric[MSE]))
            # Only the MSE slot is filled for non-classification problems.
            tmp_res.extend([ori_test_metric[MSE],
                            -1.0,
                            -1.0,
                            -1.0,
                            -1.0,
                            -1.0,
                            -1.0,
                            -1.0])
        res_output.append(tmp_res)
    df = pd.DataFrame(res_output, columns=experiment_cols)

    out_path = svt_exp_params[ML_UTILITY_OUT_PATH]
    if out_path:
        # Strip the real extension (whatever its length) instead of assuming that
        # the path always ends with a 4-character suffix such as ".csv".
        out_root, _ = os.path.splitext(out_path)
        df.to_csv(out_root + '_ML.csv', sep='|')
    return df
