import os,sys
import itertools
from hydra.utils import get_original_cwd

import pandas as pd
import numpy as np

from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score
import sklearn.exceptions

import torch
from lightning.pytorch.loggers import WandbLogger

import wandb

#from densratio import densratio

from lib.da.configs.data_model_configs import get_dataset_class
from lib.da.configs.hparams_to_search import get_hparams_class
from lib.da.misc.utils import starting_logs, log_debug, _calc_metrics

from lib.dre.common.estimate import gen_estimate_density_rate_func
from lib.utils import set_seed_everything
from lib.dre.train.linear import dre_train_for_all_epoch
from lib.da.algorithms import get_algorithm_class
from lib.da.dataloader.dataloader import data_generator as ts_data_generator
from lib.da.dataloader.mdn_dataloader import mdn_data_generator
from lib.da.dataloader.ar_dataloader import ar_data_generator
from lib.da.dataloader.tm_dataloader import tm_data_generator

import warnings
warnings.filterwarnings("ignore", category=sklearn.exceptions.UndefinedMetricWarning)

# Turn on cuDNN benchmark mode (original note: done to speed up the TCN backbone).
torch.backends.cudnn.benchmark = True

# Projected-feature dimensions accepted by ``cross_domain_trainer``:
# the powers of two from 2**3 = 8 up to 2**7 = 128.
all_n_feats_available = [2 ** exponent for exponent in range(3, 8)]

class cross_domain_trainer(object):
    """Run importance-weighted cross-domain classification experiments.

    For every (source, target) scenario of the selected dataset and every
    run (seed), the trainer:

    1. loads source and target data and projects their first
       ``n_cols_to_use`` columns onto ``num_feats`` features with a random
       Gaussian matrix (the projected features are used only for DRE);
    2. obtains importance weights for all source samples, either by
       training the configured density-ratio-estimation (DRE) model or by
       loading a cached ``.npz`` file;
    3. grid-searches the classifier hyper-parameters on a source validation
       split, scoring each candidate by importance-weighted and by
       unweighted ROC-AUC;
    4. refits each of the two selected configurations with
       ``clf.fit_all_data`` and evaluates it on the target domain.

    Per-run metric rows are collected in ``self.result_df_list`` and written
    as CSV summaries under ``self.result_dir`` at the end of ``train``.
    """

    def __init__(self, args):
        """Read the experiment configuration and prepare output directories.

        Args:
            args: experiment config object (presumably a Hydra/OmegaConf
                config -- TODO confirm). Fields read here: ``clf_algo``,
                ``dre_method``, ``dataset``, ``device_id``, ``seed``,
                ``dre`` (``redo``, ``nn``, ``hparams``), ``num_feats``,
                ``n_cols_to_use``, ``experiment_name``, ``run_description``,
                ``data_path``, ``num_runs``, ``debug``, ``test_data_ratio``,
                ``num_threads``; ``use_wandb`` is read later.

        Exits the process via ``sys.exit`` when ``num_feats`` is not one of
        ``all_n_feats_available``.

        Raises:
            ValueError: if the classifier hyper-parameter grid is empty.
        """
        self.clf_algo = args.clf_algo  # selected classifier algorithm
        self.dre_method = args.dre_method  # selected DRE method
        self.dataset = args.dataset  # selected dataset
        self.device_id = args.device_id
        self.device = torch.device(f'cuda:{args.device_id}')
        self.seed_original = args.seed
        self.redo_dre = args.dre.redo
        self.num_feats = args.num_feats
        if self.num_feats not in all_n_feats_available:
            msg = f'Use num_feats within {all_n_feats_available}: the current value is {self.num_feats}.'
            sys.exit(msg)
        self.n_cols_to_use = args.n_cols_to_use[str(self.num_feats)]
        self.config = args

        # DRE training hyper-parameters: network settings, then settings
        # shared by all methods, then method-specific overrides (later wins).
        self.hparams_train_dre = {**args.dre.nn,
                                  **args.dre.hparams.common,
                                  **args.dre.hparams[self.dre_method]}
        # The DRE hidden width is tied to the projected feature dimension.
        self.hparams_train_dre['hidden_dim'] = self.num_feats * 4

        # Experiment description
        self.experiment_description = args.experiment_name
        self.run_description = args.run_description

        # Paths (relative to the directory Hydra was launched from)
        self.out_top_dir = os.path.join(
            get_original_cwd(),
            'out',
            self.experiment_description,
            self.run_description,
        )
        os.makedirs(self.out_top_dir, exist_ok=True)

        self.data_path = os.path.join(get_original_cwd(), args.data_path, self.dataset)

        # Number of consecutive runs (seeds) per scenario
        self.num_runs = args.num_runs

        # Dataset and hyper-parameter-search configs
        self.dataset_configs, self.hparams_class = self.get_configs()
        self.dataset_configs.device = args.device_id
        self.dataset_configs.debug = args.debug
        self.dataset_configs.used_clf_algo = args.clf_algo
        self.test_data_ratio = args.test_data_ratio

        # Every combination of the classifier hyper-parameter values to try.
        self.clf_params_to_search = {**self.hparams_class.clf_params_to_search[self.clf_algo]}
        self.clf_params = self._build_param_grid(self.clf_params_to_search)
        if not self.clf_params:
            # Fail now rather than after an expensive DRE training run.
            raise ValueError(
                f'Empty hyper-parameter grid for classifier {self.clf_algo!r}: '
                f'{self.clf_params_to_search}')

        self.result_df_list = []

        # Number of threads used to fit the classifier
        self.num_threads = args.num_threads

        self.exp_log_dir = os.path.join(self.out_top_dir, 'log')
        os.makedirs(self.exp_log_dir, exist_ok=True)
        self.dre_results_dir = os.path.join(self.out_top_dir, self.dre_method)
        os.makedirs(self.dre_results_dir, exist_ok=True)
        self.result_dir = os.path.join(self.dre_results_dir, self.clf_algo)
        os.makedirs(self.result_dir, exist_ok=True)

    @staticmethod
    def _build_param_grid(params_to_search):
        """Expand ``{name: [values, ...]}`` into a list of all combinations."""
        names = list(params_to_search.keys())
        return [dict(zip(names, combo))
                for combo in itertools.product(*params_to_search.values())]

    @staticmethod
    def _make_param_name(pars):
        """Encode a parameter dict as ``'key1_value1-key2_value2'``."""
        return '-'.join(f'{key}_{value}' for key, value in pars.items())

    def _make_result_index(self, ds_name, param_name):
        """Build the identifying columns stored with each metric row."""
        return {
            'dataset': [ds_name],
            'n_feats': [self.num_feats],
            'dre_method': [self.dre_method],
            'clf_algorithm': [self.clf_algo],
            'clf_params': [param_name],
            'seed': [self.seed_str]}

    def train(self):
        """Run all scenarios and seeds, then save the summary CSV files.

        For each ``(source, target)`` pair in ``self.dataset_configs.scenarios``
        and each run ``0 .. num_runs - 1`` (seed ``seed + run``), delegates to
        ``_run_single_seed``; all recorded metrics are then summarized into
        ``self.result_dir``.
        """
        self.da_method = f'dre-{self.dre_method}_clf-{self.clf_algo}'

        for scenario in self.dataset_configs.scenarios:
            src_id, trg_id = scenario[0], scenario[1]
            ds_name = f"{src_id}_src-{trg_id}_tgt"
            for run_id in range(self.num_runs):
                self._run_single_seed(src_id, trg_id, ds_name, run_id)

        self.summarize_and_save_results(self.result_dir,
                                        levels_to_summarize_all=[
                                            'validate_type',
                                            'n_feats',
                                            'dre_method',
                                            'clf_algorithm'],
                                        levels_to_summarize_each_senario=[
                                            'validate_type',
                                            'dataset',
                                            'n_feats',
                                            'dre_method',
                                            'clf_algorithm'])

    def _run_single_seed(self, src_id, trg_id, ds_name, run_id):
        """Run DRE, classifier selection and target evaluation for one seed.

        Sets ``self.seed_int``/``self.seed_str``/``self.logger`` for the run.
        Returns early (after logging a warning) when the importance weights
        are non-finite or sum to (almost) zero.
        """
        self.seed_int = self.seed_original + run_id
        self.seed_str = str(self.seed_int)
        out_resultname_suffix = f'{ds_name}_seed-{self.seed_int}'

        # Load source and target data
        self.load_data(src_id, trg_id)
        all_srs_X = self.src_data['X']
        all_srs_Y = self.src_data['Y']
        all_trg_X = self.trg_data['X']
        all_trg_Y = self.trg_data['Y']

        wandb_log_name = '_'.join([ds_name, self.dre_method,
                                   self.experiment_description,
                                   self.run_description, f'seed-{self.seed_int}'])
        wandb_logger = WandbLogger(project=wandb_log_name)

        # Logging
        self.logger, self.scenario_log_dir = starting_logs(self.dataset, self.da_method, self.exp_log_dir,
                                                           src_id, trg_id, self.seed_str)

        set_seed_everything(self.seed_int)

        # Random Gaussian projection (drawn after set_seed_everything) of the
        # first n_cols_to_use columns onto num_feats features. The projected
        # features feed DRE only; the classifier uses the raw columns.
        projection_map_orig_to_feats = np.random.normal(
            size=self.num_feats * self.n_cols_to_use).reshape(self.n_cols_to_use, self.num_feats)
        srs_X_to_model = all_srs_X[:, 0:self.n_cols_to_use]
        trg_X_to_model = all_trg_X[:, 0:self.n_cols_to_use]
        all_src_feat_np = np.matmul(srs_X_to_model, projection_map_orig_to_feats)
        all_trg_feat_np = np.matmul(trg_X_to_model, projection_map_orig_to_feats)

        # DRE uses the same number of source and target rows: the first
        # min(n_source, n_target) rows of each, split 80/20 into train/validate.
        min_nrow_feat = min(all_src_feat_np.shape[0], all_trg_feat_np.shape[0])
        (train_src_feat_dre, validate_src_feat_dre,
         train_trg_feat_dre, validate_trg_feat_dre) = train_test_split(
            all_src_feat_np[0:min_nrow_feat, :],
            all_trg_feat_np[0:min_nrow_feat, :],
            test_size=0.2,
            shuffle=True,
            random_state=self.seed_int)

        importance_weight_file = os.path.join(
            self.dre_results_dir, f'importance_weihgt_{out_resultname_suffix}.npz')
        importance_weight = self._get_importance_weight(
            importance_weight_file, wandb_logger, src_id, trg_id,
            train_src_feat_dre, train_trg_feat_dre,
            validate_src_feat_dre, validate_trg_feat_dre,
            all_src_feat_np)

        sum_importance_weight = np.sum(importance_weight)
        if not np.isfinite(sum_importance_weight):
            # NaN/inf weights would break the weighted metrics downstream.
            msg = {'Warning ': f'importance_weihgt contains non-finite values! {ds_name} and seed={self.seed_str} are skipped.'}
            log_debug(self.logger,
                      self.dataset, self.da_method, self.exp_log_dir,
                      src_id, trg_id, self.seed_str, msg)
            return
        if sum_importance_weight < 1E-10:
            msg = {'Warning ': f'All values of importance_weihgt are 0.0! {ds_name} and seed={self.seed_str} are skiped.'}
            log_debug(self.logger,
                      self.dataset, self.da_method, self.exp_log_dir,
                      src_id, trg_id, self.seed_str, msg)
            return

        # Train/validation data from the source domain (weights split alongside)
        (train_X, validate_X,
         train_Y, validate_Y,
         train_importance_weight, validate_importance_weight) = train_test_split(
            srs_X_to_model,
            all_srs_Y,
            importance_weight,
            test_size=self.test_data_ratio,
            shuffle=True, random_state=self.seed_int)

        clf_algorithm_class = get_algorithm_class(self.clf_algo)
        clf = clf_algorithm_class(self.dataset_configs,
                                  train_X, train_Y,
                                  validate_X, validate_Y,
                                  train_importance_weight, validate_importance_weight,
                                  self.seed_int, self.num_threads)

        best_name, best_par, best_name_non_weight, best_par_non_weight = self._search_clf_params(
            clf, validate_X, validate_Y, validate_importance_weight, ds_name, src_id, trg_id)

        # Test data comes from the target domain.
        self._evaluate_on_target(clf, trg_X_to_model, all_trg_Y, ds_name, src_id, trg_id,
                                 best_name, best_par,
                                 pred_index=0, validate_type='IW-test(target)')
        self._evaluate_on_target(clf, trg_X_to_model, all_trg_Y, ds_name, src_id, trg_id,
                                 best_name_non_weight, best_par_non_weight,
                                 pred_index=1, validate_type='NonWeight-test(target)')

    def _get_importance_weight(self, importance_weight_file, wandb_logger, src_id, trg_id,
                               train_src_feat_dre, train_trg_feat_dre,
                               validate_src_feat_dre, validate_trg_feat_dre,
                               all_src_feat_np):
        """Return importance weights for every source sample.

        Loads them from ``importance_weight_file`` unless ``self.redo_dre`` is
        set or the file does not exist; otherwise trains a DRE model, evaluates
        it on all source features and saves the weights to that file.

        Raises:
            ValueError: if the weights must be computed but ``self.dre_method``
                is not one of the supported methods.
        """
        if not self.redo_dre and os.path.isfile(importance_weight_file):
            return np.load(importance_weight_file)['arr_0']

        if self.dre_method not in ('alphaDiv', 'nnBD_LSIF', 'LSIF'):
            # Without this check the caller would read an unset (or stale,
            # from a previous run) weight array.
            raise ValueError(
                f'Unsupported dre_method {self.dre_method!r}; '
                f'expected one of alphaDiv, nnBD_LSIF, LSIF.')

        params_training_dre = self.hparams_train_dre.copy()
        params_training_dre['DoBatchNormarize'] = False
        model_trained, res_estimation = dre_train_for_all_epoch(
            train_src_feat_dre,
            train_trg_feat_dre,
            validate_src_feat_dre,
            validate_trg_feat_dre,
            # The validation split is passed again in the next two slots
            # (presumably the test-set arguments -- TODO confirm).
            validate_src_feat_dre,
            validate_trg_feat_dre,
            params_training_dre,
            method=self.dre_method,
            logger=wandb_logger,
            out_results_dir=self.exp_log_dir,
            do_save_result=False,
            seed=self.seed_int,
            device_id=self.device_id,
        )
        log_debug(self.logger,
                  self.dataset, self.da_method, self.exp_log_dir,
                  src_id, trg_id, self.seed_str, res_estimation)

        esti_density_rate_func = gen_estimate_density_rate_func(self.dre_method)
        all_src_feat_tsr = torch.from_numpy(
            all_src_feat_np.astype(np.float32)).to(self.device)
        importance_weight_tsr = esti_density_rate_func(all_src_feat_tsr, model_trained)
        importance_weight = importance_weight_tsr.cpu().detach().numpy()
        np.savez(importance_weight_file, importance_weight)
        return importance_weight

    def _search_clf_params(self, clf, validate_X, validate_Y, validate_importance_weight,
                           ds_name, src_id, trg_id):
        """Grid-search ``self.clf_params`` on the source validation split.

        Each candidate is fitted with ``clf.fit`` and scored by ROC-AUC of the
        first prediction output (importance-weighted metrics) and of the
        second prediction output (unweighted metrics). The first candidate is
        the initial best with a baseline score of 0.5; later candidates only
        replace it when strictly better.

        Returns:
            ``(best_name, best_params, best_name_non_weight,
            best_params_non_weight)``.

        Raises:
            ValueError: if ``self.clf_params`` is empty.
        """
        param_name_to_params = {}
        best_name = best_name_non_weight = None
        best_roc_auc = best_roc_auc_non_weight = 0.5
        for pars in self.clf_params:
            param_name = self._make_param_name(pars)
            param_name_to_params[param_name] = pars
            if best_name is None:
                best_name = best_name_non_weight = param_name

            clf.fit(pars)
            src_pred_score, src_pred_score_non_weight = clf.predict(validate_X)

            index_to_save_result = self._make_result_index(ds_name, param_name)
            metrics = self.calc_results_per_run(index_to_save_result,
                                                src_pred_score, validate_Y,
                                                validate_type='IW-validate(soucre)',
                                                sample_weight=validate_importance_weight)
            metrics_non_weight = self.calc_results_per_run(index_to_save_result,
                                                           src_pred_score_non_weight, validate_Y,
                                                           validate_type='NonWeight-validate(soucre)')

            # Logging (weighted metrics only, as before)
            res = pars.copy()
            res.update(metrics)
            log_debug(self.logger,
                      self.dataset, self.da_method, self.exp_log_dir,
                      src_id, trg_id, self.seed_str, res)

            if metrics['roc_auc'] > best_roc_auc:
                best_name = param_name
                best_roc_auc = metrics['roc_auc']
            if metrics_non_weight['roc_auc'] > best_roc_auc_non_weight:
                best_name_non_weight = param_name
                best_roc_auc_non_weight = metrics_non_weight['roc_auc']

        if best_name is None:
            raise ValueError('No classifier hyper-parameters to search.')
        return (best_name, param_name_to_params[best_name],
                best_name_non_weight, param_name_to_params[best_name_non_weight])

    def _evaluate_on_target(self, clf, test_X, test_Y, ds_name, src_id, trg_id,
                            param_name, params, pred_index, validate_type):
        """Refit ``clf`` with ``params`` and record metrics on the target data.

        ``pred_index`` selects which output of ``clf.predict`` is scored
        (0: importance-weighted model, 1: unweighted model). The metric row is
        tagged with ``param_name``, i.e. the configuration actually evaluated.
        """
        clf.fit_all_data(params)
        trg_pred_score = clf.predict(test_X)[pred_index]
        test_metrics = self.calc_results_per_run(self._make_result_index(ds_name, param_name),
                                                 trg_pred_score, test_Y,
                                                 validate_type=validate_type)
        test_res = params.copy()
        test_res.update(test_metrics)
        log_debug(self.logger,
                  self.dataset, self.da_method, self.exp_log_dir,
                  src_id, trg_id, self.seed_str, test_res)

    def get_configs(self):
        """Return ``(dataset_configs, hparams_configs)`` instances for ``self.dataset``."""
        dataset_class = get_dataset_class(self.dataset)
        hparams_class = get_hparams_class(self.dataset)
        return dataset_class(), hparams_class()

    def load_data(self, src_id, trg_id):
        """Load source and target domain data into ``self.src_data``/``self.trg_data``.

        Raises:
            ValueError: if ``self.dataset`` has no registered data generator.
        """
        if self.dataset == 'MINI_DOMAIN_NET':
            data_generator = mdn_data_generator
        elif self.dataset == 'AMAZON_REVIEWS':
            data_generator = ar_data_generator
        elif self.dataset == 'TRANSFORMED_MOONS':
            data_generator = tm_data_generator
        elif self.dataset in ('EEG', 'WISDM', 'HAR', 'HHAR_SA'):
            data_generator = ts_data_generator
        else:
            raise ValueError(f'Unsupported dataset: {self.dataset!r}')

        self.src_data = data_generator(
            self.data_path, src_id, self.dataset_configs, self.seed_original)
        self.trg_data = data_generator(
            self.data_path, trg_id, self.dataset_configs, self.seed_original)

    def calc_results_per_run(self,
                             resutl_id_dict,
                             pred_score, true_labels, validate_type,
                             sample_weight=None):
        """Compute accuracy, F1 and ROC-AUC for one prediction set and record them.

        Args:
            resutl_id_dict: identifying columns (each value a one-element list),
                used together with ``validate_type`` as the row index.
            pred_score: predicted scores passed to ``roc_auc_score`` and
                ``_calc_metrics``.
            true_labels: ground-truth labels.
            validate_type: label of the evaluation split.
            sample_weight: optional per-sample weights for weighted metrics.

        Appends a one-row DataFrame to ``self.result_df_list`` and, when
        ``self.config.use_wandb`` is set, logs the metrics to wandb.

        Returns:
            dict with keys ``'accuracy'``, ``'f1'`` and ``'roc_auc'``.
        """
        # Only forward sample_weight when given, as _calc_metrics was never
        # called with an explicit None before.
        weight_kwargs = {} if sample_weight is None else {'sample_weight': sample_weight}
        roc_auc = roc_auc_score(true_labels, pred_score, **weight_kwargs)
        acc, f1 = _calc_metrics(pred_score, true_labels,
                                self.dataset_configs.class_names,
                                **weight_kwargs)

        run_metrics = {
            'accuracy': acc,
            'f1': f1,
            'roc_auc': roc_auc}

        results_dict = {'validate_type': validate_type}
        results_dict.update(resutl_id_dict)
        results_dict.update(run_metrics)
        df = pd.DataFrame.from_dict(results_dict)
        df.set_index(['validate_type'] + list(resutl_id_dict.keys()), inplace=True)
        self.result_df_list.append(df)

        if self.config.use_wandb:
            log_run_metrics = resutl_id_dict.copy()
            log_run_metrics.update(run_metrics)
            wandb.log(log_run_metrics)

        return run_metrics

    def summarize_and_save_results(self, out_dir,
                                   levels_to_summarize_all,
                                   levels_to_summarize_each_senario):
        """Concatenate recorded metric rows and write CSV summaries to ``out_dir``.

        Writes ``all_results.csv`` (every row), ``summary_all.csv``
        (``describe()`` grouped by ``levels_to_summarize_all``) and
        ``summary_each_senario.csv`` (grouped by
        ``levels_to_summarize_each_senario``). Flattened copies are kept on
        ``self.results_df``, ``self.result_summary_all`` and
        ``self.result_summary_df_each_senario``.

        Raises:
            ValueError: if no metric rows were recorded.
        """
        if not self.result_df_list:
            raise ValueError('No results to summarize: every scenario/seed was skipped or nothing was run.')
        results = pd.concat(self.result_df_list, axis=0)
        result_summary_all = results.groupby(
            level=levels_to_summarize_all).describe()
        result_summary_each_senario = results.groupby(
            level=levels_to_summarize_each_senario).describe()
        self.results_df = results.reset_index(drop=False)
        self.result_summary_all = result_summary_all.reset_index(drop=False)
        self.result_summary_df_each_senario = result_summary_each_senario.reset_index(drop=False)
        results.to_csv(os.path.join(out_dir, "all_results.csv"))
        result_summary_each_senario.to_csv(os.path.join(out_dir, "summary_each_senario.csv"))
        result_summary_all.to_csv(os.path.join(out_dir, "summary_all.csv"))

       
