from abc import ABC, abstractmethod
from omegaconf import OmegaConf
import logging
from data.data_preprocess import DataProcessor
import torch
import os
from pathlib import Path
import pickle
from datetime import datetime
import time
import pandas as pd

from evaluation.eval_detection import DetectionScore
from evaluation.eval_similarity import SimilarityScores
from evaluation.eval_ml_efficiency import MLEScore # this leads to pkg_resources warning due to outdated dython dependencies
from evaluation.eval_dcr import DCRScore
from evaluation.eval_alphaprecision import AlphaPrecision
from evaluation.eval_missings import MissingEvaluator
from tqdm import tqdm


class Experiment(ABC):
    """Abstract base class for a generative-model experiment on one dataset.

    The constructor loads the shared configuration, creates the working
    directory ``results/<dataset>/<exp_path>`` and builds the
    ``DataProcessor``. Subclasses provide the model-specific ``train``,
    ``sample``, ``save_model`` and ``load_model``; this class provides the
    evaluation routines that operate on the samples they return.
    """

    # Upper bound on the number of synthetic rows drawn per evaluation iteration.
    _MAX_NUM_SAMPLES = 100_000
    # Sample size and seed used for the sampling-speed benchmark in `evaluate`.
    _BENCHMARK_NUM_SAMPLES = 1000
    _BENCHMARK_SEED = 42

    def __init__(self, config, args):
        """Set up paths, the data processor and the evaluation sample size.

        Args:
            config: Model configuration; ``config.data.cat_encoding`` is read
                here and ``config.model_name`` is read in ``evaluate``.
            args: Object exposing ``seed``, ``exp_path``, ``dataset`` and
                ``miss_mechanism`` (presumably parsed CLI arguments -- confirm
                against the caller).
        """
        self.seed = args.seed
        self.config = config
        # NOTE: relative path, resolved against the current working directory.
        self.common_config = OmegaConf.load('experiments/configs/common_config.yaml')
        self.device = self.common_config.device

        # define paths; fall back to a timestamp when no experiment name is given
        if args.exp_path is None:
            exp_path = datetime.now().strftime('%Y_%m_%d_%H-%M-%S')
        else:
            exp_path = args.exp_path
        self.workdir = Path('results') / args.dataset / exp_path
        (self.workdir / 'figures').mkdir(parents=True, exist_ok=True)

        logging.warning(f'=== Initializing {args.dataset} dataset ===')
        self.dataset = args.dataset
        self.data_processor = DataProcessor(
            self.dataset,
            self.config.data.cat_encoding,
            self.seed,
            self.common_config.val_prop,
            self.common_config.test_prop,
            cat_min_freq=self.common_config.cat_min_freq,
            missing_mechanism=args.miss_mechanism,
            seed_missings=self.seed,
            p_miss=self.common_config.p_miss,
            p_obs=self.common_config.p_obs,
        )

        # sample as many observations as there are in the real train set,
        # capped at _MAX_NUM_SAMPLES
        df_trn, _, _ = self.data_processor.get_data_splits()
        self.num_samples = min(df_trn.height, self._MAX_NUM_SAMPLES)

        cuda_available = torch.cuda.is_available()
        print(f"CUDA available? {cuda_available}")

        if self.common_config.use_tf32 and cuda_available:
            torch.set_float32_matmul_precision('high')

    @abstractmethod
    def train(self, **kwargs):
        """Model-specific training hook (not called from this base class)."""

    @abstractmethod
    def sample(self, num_samples, **kwargs):
        """Draw ``num_samples`` synthetic rows.

        The evaluation routines call this with a ``seed`` keyword argument and
        unpack the return value as ``(X_cat_gen, X_num_gen)``, which is then
        passed to ``DataProcessor.postprocess``.
        """

    @abstractmethod
    def save_model(self):
        """Model-specific persistence hook (not called from this base class)."""

    @abstractmethod
    def load_model(self):
        """Load the generative model; called first by the evaluation routines."""

    @staticmethod
    def _iteration_seed(i):
        """Return the sampling/evaluation seed for iteration ``i`` (0, 42, 86, ...)."""
        return (42 + (i - 1)) * i

    @staticmethod
    def _aggregate_results(per_iteration):
        """Turn a list of per-iteration dicts into a dict of per-key lists.

        The keys of the first dict define the output keys. An empty input
        (e.g. ``eval_sample_iter == 0``) yields an empty dict instead of
        raising ``IndexError``.
        """
        if not per_iteration:
            return {}
        return {key: [res[key] for res in per_iteration] for key in per_iteration[0]}

    def _dump_pickle(self, filename, obj):
        """Pickle ``obj`` to ``<workdir>/<filename>``."""
        with open(self.workdir / filename, 'wb') as f:
            pickle.dump(obj, f)

    def evaluate(self):
        """Benchmark sampling speed and run the full evaluation suite.

        Files written below ``self.workdir``:
            * ``sample_time.pkl`` -- wall-clock seconds of one ``sample`` call
              for ``_BENCHMARK_NUM_SAMPLES`` rows.
            * ``samples/gen_data_<seed>.parquet`` -- the post-processed
              synthetic data of every iteration.
            * ``eval_results.pkl`` -- dict mapping each metric name to the
              list of its values over the ``eval_sample_iter`` iterations.
        """
        logging.warning('=== Loading generative model... ===')
        self.load_model()

        logging.warning('=== Benchmarking sampling speed... ===')
        sample_start_time = time.monotonic()
        # only the duration matters here, the generated samples are discarded
        self.sample(self._BENCHMARK_NUM_SAMPLES, seed=self._BENCHMARK_SEED)
        sample_duration = time.monotonic() - sample_start_time
        self._dump_pickle('sample_time.pkl', sample_duration)

        logging.warning('=== Evaluating synthetic data... ===')

        # init evaluators and ground truth data
        dp = self.data_processor
        df_trn, _, df_tst = dp.get_data_splits()
        samples_dir = self.workdir / 'samples'
        samples_dir.mkdir(parents=True, exist_ok=True)

        # training rows without missing values; loop-invariant, so computed once
        df_trn_complete = df_trn.drop_nulls()

        detect_score = DetectionScore(dp.cat_cols, dp.num_cols)
        sim_score = SimilarityScores(df_trn_complete, df_tst.drop_nulls(), dp.cat_cols)
        dcr_score = DCRScore(df_trn, df_tst, dp.cat_cols, dp.num_cols)
        mle_score = MLEScore(dp.cat_cols, dp.num_cols, dp.target)
        alphaprec = AlphaPrecision(dp.cat_cols)
        # missingness is only evaluated when a missing mechanism is configured
        miss_score = None
        if dp.missing_mechanism is not None:
            miss_score = MissingEvaluator(df_trn, dp.num_cols, dp.cat_cols)

        eval_results = []
        includes_miss_ind = self.config.model_name != 'tabcascade'

        for i in tqdm(range(self.common_config.eval_sample_iter)):
            seed = self._iteration_seed(i)
            X_cat_gen, X_num_gen = self.sample(self.num_samples, seed=seed)
            df_gen = dp.postprocess(X_cat_gen, X_num_gen,
                                    includes_miss_ind=includes_miss_ind)
            df_gen.write_parquet(samples_dir / f'gen_data_{seed}.parquet')
            df_gen_complete = df_gen.drop_nulls()

            # evaluate
            results = {}
            results.update(detect_score.estimate_score(df_trn, df_gen, seed=seed, nfold=5))
            results.update(sim_score.compute_similarity(df_trn_complete, df_gen_complete))
            shape_trend = sim_score.compute_colwise_density_metrics(df_trn, df_gen)
            # flatten the nested result into a single level, joining keys with '_'
            shape_trend = pd.json_normalize(shape_trend, sep='_').iloc[0].to_dict()
            results.update(shape_trend)
            results.update(sim_score.compute_diff_in_corr(df_gen_complete))

            dcr_res = dcr_score.compute_dcr(df_gen, seed=seed)
            # the 'dcr_raw' entry is not kept in the aggregated results
            del dcr_res['dcr_raw']
            results.update(dcr_res)

            results.update(mle_score.get_score(df_trn, df_tst, df_gen, seed=seed))
            results.update(alphaprec.estimate_scores(df_trn, df_gen))

            # evaluate missingness
            if miss_score is not None:
                results.update(miss_score.eval_cond_dist(df_trn, df_tst, df_gen, seed=seed))
                results.update(miss_score.eval_correlation(df_trn, df_gen))
                results.update(miss_score.eval_similarity(df_trn, df_gen))

            eval_results.append(results)

        # aggregate results
        self._dump_pickle('eval_results.pkl', self._aggregate_results(eval_results))
        # WARNING level like every other progress message, so it is not
        # filtered out under the default logging configuration
        logging.warning('=== Evaluation finished, results saved! ===')

    def evaluate_motivation(self):
        """Compute detection scores with different ``drop`` settings.

        For every iteration the detection score is estimated three times, with
        ``drop='none'``, ``drop='num'`` and ``drop='cat'``, and stored as
        ``detection_all``, ``detection_cat`` and ``detection_num``
        (presumably ``drop`` removes the named feature type -- confirm in
        ``DetectionScore``). The per-key lists are pickled to
        ``<workdir>/motivation_results.pkl``.
        """
        logging.warning('=== Loading generative model... ===')
        self.load_model()

        eval_results = []
        detect_score = DetectionScore(self.data_processor.cat_cols, self.data_processor.num_cols)
        df_trn, _, _ = self.data_processor.get_data_splits()

        drop_by_key = (
            ('detection_all', 'none'),
            ('detection_cat', 'num'),
            ('detection_num', 'cat'),
        )

        for i in tqdm(range(self.common_config.eval_sample_iter)):
            seed = self._iteration_seed(i)
            X_cat_gen, X_num_gen = self.sample(self.num_samples, seed=seed)
            # NOTE(review): unlike `evaluate`, `includes_miss_ind` is not passed
            # here, so the default of `postprocess` applies -- confirm intended.
            df_gen = self.data_processor.postprocess(X_cat_gen, X_num_gen)

            results = {}
            for key, drop in drop_by_key:
                score = detect_score.estimate_score(df_trn, df_gen, seed=seed, nfold=5, drop=drop)
                results[key] = score['detection_score']

            eval_results.append(results)

        # aggregate results
        self._dump_pickle('motivation_results.pkl', self._aggregate_results(eval_results))

    def save_train_time(self, duration):
        """Save training time in minutes.

        ``duration`` is divided by 60 before being pickled to
        ``<workdir>/train_time.pkl``, i.e. it is expected in seconds.
        """
        self._dump_pickle('train_time.pkl', duration / 60)