from copy import deepcopy
import numpy as np
import torch
import pandas as pd
# Metrics
from eval.mle.mle import get_evaluator
from eval.visualize_density import plot_density
from sdmetrics.reports.single_table import QualityReport, DiagnosticReport
from sdmetrics.single_table import LogisticDetection
from sklearn.preprocessing import OneHotEncoder

from tqdm import tqdm

# synthcity
from synthcity.metrics import eval_statistical
from synthcity.plugins.core.dataloader import GenericDataLoader
# Disable pandas' chained-assignment (SettingWithCopy) warning. This is a
# process-wide pandas option, so it also affects every other module that
# imports pandas after this file has been loaded.
pd.options.mode.chained_assignment = None


class TabMetrics(object):
    def __init__(self, real_data_path, test_data_path, val_data_path, info, device, metric_list) -> None:
        self.real_data_path = real_data_path
        self.test_data_path = test_data_path
        self.val_data_path = val_data_path
        self.info = info
        self.device = device
        self.real_data_size = len(pd.read_csv(real_data_path))
        print(self.real_data_size)
        self.metric_list = metric_list

    def evaluate(self, syn_data):
        """Run every configured metric on a private copy of ``syn_data``.

        Each name in ``self.metric_list`` is resolved to the bound method
        ``evaluate_<name>``, which is called with the copied data and must
        return a ``(metrics, extras)`` pair; either element may be ``None``.

        Args:
            syn_data: Synthetic table to evaluate. It is deep-copied once, so
                the caller's object is never modified.

        Returns:
            tuple[dict, dict]: The merged metric dict and the merged extras
            dict of all metrics (later metrics overwrite equal keys).

        Raises:
            AttributeError: If a metric name has no ``evaluate_<name>`` method.
        """
        metrics, extras = {}, {}
        # One copy is shared by all metrics: in-place changes made by one
        # metric (e.g. relabelled columns) are visible to the following ones.
        syn_data_cp = deepcopy(syn_data)
        for metric in self.metric_list:
            # Attribute lookup instead of eval(): the metric name is only ever
            # used as a method name and is never executed as an expression.
            func = getattr(self, f"evaluate_{metric}", None)
            if not callable(func):
                raise AttributeError(
                    f"{type(self).__name__} has no metric method 'evaluate_{metric}'"
                )
            print(f"Evaluating {metric}")
            out_metrics, out_extras = func(syn_data_cp)
            if out_metrics is not None:
                metrics.update(out_metrics)
            if out_extras is not None:
                extras.update(out_extras)
        return metrics, extras
    
    def evaluate_density(self, syn_data):
        """Score column shapes and column-pair trends with SDMetrics reports.

        The real CSV and ``syn_data`` are relabelled with positional integer
        column names (``syn_data`` is relabelled in place). A synthetic table
        with a single column is treated as "target only": it is completed with
        the real feature columns before scoring.

        Args:
            syn_data: Synthetic table as a pandas DataFrame.

        Returns:
            tuple[dict, dict]: Metrics ``quality/Shape``, ``quality/Trend`` and
            ``quality/Overall``; extras with the per-column ``shapes`` and
            per-pair ``trends`` detail tables of the quality report.
        """
        real_df = pd.read_csv(self.real_data_path)
        real_df.columns = range(len(real_df.columns))
        syn_data.columns = range(len(syn_data.columns))

        info = deepcopy(self.info)

        y_only = len(syn_data.columns) == 1
        if y_only:
            # Only the target was synthesised: borrow all other columns from
            # the real table so that the reports see a full-width frame.
            syn_data = self.complete_y_only_data(syn_data, real_df, info['target_col_idx'][0])

        # Normalise the metadata column keys to int (presumably they can be
        # strings, e.g. after a JSON round-trip — TODO confirm).
        raw_metadata = info['metadata']
        raw_metadata['columns'] = {int(key): value for key, value in raw_metadata['columns'].items()}

        new_real_data, new_syn_data, metadata = reorder(real_df, syn_data, info)

        qual_report = QualityReport()
        qual_report.generate(new_real_data, new_syn_data, metadata)

        # The diagnostic report is generated as well, but its properties are
        # not part of the returned values.
        diag_report = DiagnosticReport()
        diag_report.generate(new_real_data, new_syn_data, metadata)

        quality_scores = qual_report.get_properties()['Score']
        diag_report.get_properties()

        # NOTE(review): assumes rows 0 and 1 of the properties table are
        # 'Column Shapes' and 'Column Pair Trends' — TODO confirm.
        shape_score = quality_scores[0]
        trend_score = quality_scores[1]
        overall_score = (shape_score + trend_score) / 2

        shape_details = qual_report.get_details(property_name='Column Shapes')
        trend_details = qual_report.get_details(property_name='Column Pair Trends')

        if y_only:
            # Report the lowest per-column shape score in the target-only
            # case; the overall score above still uses the aggregated shape.
            shape_score = shape_details['Score'].min()

        return (
            {
                "quality/Shape": shape_score,
                "quality/Trend": trend_score,
                "quality/Overall": overall_score,
            },
            {
                "shapes": shape_details,
                "trends": trend_details,
            },
        )
    
    def evaluate_mle(self, syn_data):
        """Measure machine-learning efficacy of ``syn_data``.

        The synthetic table is handed to the project's MLE evaluator as the
        training set and the test CSV as the test set. The validation CSV is
        used only when ``self.val_data_path`` is set and its column count
        matches the training data; otherwise ``val=None`` is passed.

        Args:
            syn_data: Synthetic table as a pandas DataFrame (not modified).

        Returns:
            tuple[dict, dict]: ``{"mle": score}`` where ``score`` is the RMSE
            (regression) or ROC-AUC (classification) of the XGBoost model when
            present, else of the first model reporting that key, else
            ``None``; and ``{"mle": overall_scores}`` mapping each score group
            to ``{model_name: scores}``.

        Note:
            A failing classification evaluator is reported on stdout and
            yields empty score groups; a failing regression evaluator
            propagates its exception (unchanged from the previous behaviour).
        """
        info = deepcopy(self.info)
        num_col_idx = list(info.get('num_col_idx', []))

        def _coerce_numeric(df):
            # Coerce numeric columns to float so MLE feat_transform does not
            # fail on stray strings; unparsable entries become 0.
            df = df.copy()
            if hasattr(df, 'columns') and hasattr(df, 'iloc'):
                for i in num_col_idx:
                    if i < df.shape[1]:
                        df.iloc[:, i] = pd.to_numeric(df.iloc[:, i], errors='coerce').fillna(0)
            return df

        def _load_split(path):
            # Read a CSV split, relabel its columns positionally and coerce
            # it exactly like the training data.
            frame = pd.read_csv(path)
            frame.columns = range(len(frame.columns))
            return _coerce_numeric(frame).to_numpy()

        train = _coerce_numeric(syn_data).to_numpy()
        test = _load_split(self.test_data_path)
        val = None
        if self.val_data_path:
            val = _load_split(self.val_data_path)
            if val.shape[1] != train.shape[1]:
                # val.csv schema must match train (e.g. 15 cols for adult);
                # else the evaluator is left to work without a validation set.
                val = None

        task_type = info['task_type']
        # MLE evaluator expects 'binclass' or 'multiclass', not 'classification'
        if task_type == 'classification':
            task_type = 'binclass'
            info['task_type'] = task_type

        evaluator = get_evaluator(task_type)

        # Map every score-group name to the list returned by the evaluator.
        # An explicit mapping replaces the former eval() on local names.
        if task_type == 'regression':
            best_r2_scores, best_rmse_scores = evaluator(train, test, info, val=val)
            score_groups = {
                'best_r2_scores': best_r2_scores,
                'best_rmse_scores': best_rmse_scores,
            }
            headline_group, preferred_model, headline_key = 'best_rmse_scores', 'XGBRegressor', 'RMSE'
        else:
            group_names = (
                'best_f1_scores',
                'best_weighted_scores',
                'best_auroc_scores',
                'best_acc_scores',
                'best_avg_scores',
            )
            try:
                f1, weighted, auroc, acc, avg = evaluator(train, test, info, val=val)
                score_groups = dict(zip(group_names, (f1, weighted, auroc, acc, avg)))
            except Exception as e:
                print(f"MLE evaluator failed: {e}")
                score_groups = {name: [] for name in group_names}
            headline_group, preferred_model, headline_key = 'best_auroc_scores', 'XGBClassifier', 'roc_auc'

        # Re-key every group by model name; 'name' is removed from the
        # per-model dict because it becomes the key.
        overall_scores = {}
        for score_name, methods in score_groups.items():
            by_model = {}
            for method in methods:
                name = method.pop('name')
                by_model[name] = method
            overall_scores[score_name] = by_model

        try:
            candidates = overall_scores.get(headline_group, {})
            mle_score = candidates.get(preferred_model, {}).get(headline_key)
            if mle_score is None and candidates:
                # Fall back to the first model that reports the headline key.
                for scores in candidates.values():
                    if headline_key in scores:
                        mle_score = scores[headline_key]
                        break
            if mle_score is None:
                print("error in MLE compute: no score found")
        except Exception as e:
            mle_score = None
            print(f"error in MLE compute: {e}")

        out_metrics = {
            "mle": mle_score,
        }
        out_extras = {
            "mle": overall_scores,
        }
        return out_metrics, out_extras
    
    def evaluate_c2st(self, syn_data):
        """Compute the classifier two-sample test score via ``LogisticDetection``.

        The real CSV and ``syn_data`` are relabelled with positional integer
        column names (``syn_data`` in place), reordered together with the
        metadata by ``reorder`` and passed to SDMetrics' ``LogisticDetection``.

        Args:
            syn_data: Synthetic table as a pandas DataFrame.

        Returns:
            tuple[dict, dict]: ``{"c2st": score}`` and an empty extras dict.
        """
        info = deepcopy(self.info)
        real_df = pd.read_csv(self.real_data_path)

        for frame in (real_df, syn_data):
            frame.columns = range(len(frame.columns))

        # Metadata column keys are normalised to int before reordering.
        raw_metadata = info['metadata']
        raw_metadata['columns'] = {int(key): value for key, value in raw_metadata['columns'].items()}

        new_real_data, new_syn_data, metadata = reorder(real_df, syn_data, info)

        detection_score = LogisticDetection.compute(
            real_data=new_real_data,
            synthetic_data=new_syn_data,
            metadata=metadata,
        )
        return {"c2st": detection_score}, {}

    def evaluate_dcr(self, syn_data):
        """Compute the Distance-to-Closest-Record (DCR) score of ``syn_data``.

        Numerical columns are divided by their value range in the real data,
        categorical columns are one-hot encoded (encoder fitted on real plus
        test rows), and for every synthetic row the L1 distance to its nearest
        real row and to its nearest test row is computed on ``self.device``.

        Args:
            syn_data: Synthetic table as a pandas DataFrame with the same
                number of columns as the real CSV; its columns are relabelled
                in place with positional integers.

        Returns:
            tuple[dict, dict]: ``{"dcr": score}`` where ``score`` is the
            fraction of synthetic rows strictly closer to a real row than to a
            test row; extras ``dcr_real`` / ``dcr_test`` hold the per-row
            nearest distances as numpy arrays.

        Raises:
            ValueError: From pandas when the tables differ in column count, or
                from the encoder when ``syn_data`` holds a categorical value
                unseen in the real and test data (scikit-learn's default
                ``handle_unknown='error'``).
            ZeroDivisionError: When ``syn_data`` has no rows.
        """
        info = deepcopy(self.info)
        real_data = pd.read_csv(self.real_data_path)
        test_data = pd.read_csv(self.test_data_path)

        num_col_idx = list(info['num_col_idx'])
        cat_col_idx = list(info['cat_col_idx'])
        target_col_idx = list(info['target_col_idx'])
        # The target is numerical for regression and categorical otherwise.
        if info['task_type'] == 'regression':
            num_col_idx += target_col_idx
        else:
            cat_col_idx += target_col_idx

        # All three tables are addressed by position; the real table's width
        # is used for each of them, so a width mismatch raises here.
        n_cols = len(real_data.columns)
        for frame in (real_data, syn_data, test_data):
            frame.columns = range(n_cols)

        num_ranges = np.array(
            [real_data[i].max() - real_data[i].min() for i in num_col_idx]
        )
        # A constant column has range 0; dividing by it would turn every
        # distance into NaN/inf, so such columns are left unscaled.
        num_ranges = np.where(num_ranges == 0, 1, num_ranges)

        num_real_data_np = real_data[num_col_idx].to_numpy()
        num_syn_data_np = syn_data[num_col_idx].to_numpy()
        num_test_data_np = test_data[num_col_idx].to_numpy()

        if cat_col_idx:
            cat_real_data_np = real_data[cat_col_idx].to_numpy().astype('str')
            cat_syn_data_np = syn_data[cat_col_idx].to_numpy().astype('str')
            cat_test_data_np = test_data[cat_col_idx].to_numpy().astype('str')

            encoder = OneHotEncoder()
            encoder.fit(np.concatenate([cat_real_data_np, cat_test_data_np], axis=0))

            cat_real_data_oh = encoder.transform(cat_real_data_np).toarray()
            cat_syn_data_oh = encoder.transform(cat_syn_data_np).toarray()
            cat_test_data_oh = encoder.transform(cat_test_data_np).toarray()
        else:
            # Nothing to encode: the encoder cannot be fitted on zero
            # features, so use zero-width blocks instead.
            cat_real_data_oh = np.zeros((len(real_data), 0))
            cat_syn_data_oh = np.zeros((len(syn_data), 0))
            cat_test_data_oh = np.zeros((len(test_data), 0))

        real_data_np = np.concatenate([num_real_data_np / num_ranges, cat_real_data_oh], axis=1)
        syn_data_np = np.concatenate([num_syn_data_np / num_ranges, cat_syn_data_oh], axis=1)
        test_data_np = np.concatenate([num_test_data_np / num_ranges, cat_test_data_oh], axis=1)

        device = self.device
        real_data_th = torch.tensor(real_data_np).to(device)
        syn_data_th = torch.tensor(syn_data_np).to(device)
        test_data_th = torch.tensor(test_data_np).to(device)

        # This estimation should make sure that dcr_real and dcr_test fit
        # into 10GB of GPU memory. Without categorical columns the total
        # feature width is used; the batch size is never allowed to reach 0.
        width = cat_real_data_oh.shape[1] or real_data_np.shape[1]
        batch_size = max(1, 10000 // max(1, width))

        dcrs_real = []
        dcrs_test = []
        n_syn = syn_data_th.shape[0]
        for i in tqdm(range(n_syn // batch_size + 1)):
            # Slicing past the end is safe, so the last (possibly empty)
            # batch needs no special case.
            batch_syn_data_th = syn_data_th[i * batch_size: (i + 1) * batch_size]

            # Broadcast to (batch, rows, features), then L1 distance per pair
            # and the minimum over the reference rows.
            dcr_real = (batch_syn_data_th[:, None] - real_data_th).abs().sum(dim=2).min(dim=1).values
            dcr_test = (batch_syn_data_th[:, None] - test_data_th).abs().sum(dim=2).min(dim=1).values
            dcrs_real.append(dcr_real)
            dcrs_test.append(dcr_test)

        dcrs_real = torch.cat(dcrs_real)
        dcrs_test = torch.cat(dcrs_test)

        score = (dcrs_real < dcrs_test).nonzero().shape[0] / dcrs_real.shape[0]

        out_metrics = {
            "dcr": score,
        }
        out_extras = {
            "dcr_real": dcrs_real.cpu().numpy(),
            "dcr_test": dcrs_test.cpu().numpy(),
        }
        return out_metrics, out_extras
        
    def evaluate_quality(self, syn_data):
        """Compute naive alpha-precision and beta-recall with synthcity.

        Numerical columns are used as they are and categorical columns are
        one-hot encoded with an encoder fitted on the real data only. When the
        two tables differ in length, both are subsampled to the shorter one.

        Args:
            syn_data: Synthetic table as a pandas DataFrame; its columns are
                relabelled in place with positional integers.

        Returns:
            tuple[dict, dict]: Metrics ``quality/alpha_precision_all`` and
            ``quality/beta_recall_all``; an empty extras dict.

        Note:
            On ``EOFError`` / ``FileNotFoundError`` from synthcity the
            directory ``~/.cache/synthcity`` is deleted and the evaluation is
            retried once; if the retry fails too, both metrics are 0.0.
        """
        info = deepcopy(self.info)
        real_data = pd.read_csv(self.real_data_path)

        real_data.columns = range(len(real_data.columns))
        syn_data.columns = range(len(syn_data.columns))

        num_col_idx = list(info['num_col_idx'])
        cat_col_idx = list(info['cat_col_idx'])
        target_col_idx = list(info['target_col_idx'])
        # The target is numerical for regression and categorical otherwise.
        if info['task_type'] == 'regression':
            num_col_idx += target_col_idx
        else:
            cat_col_idx += target_col_idx

        def _normalize_cat_df(df: pd.DataFrame) -> np.ndarray:
            # For some datasets (e.g. "default") the real data stores
            # categories as integers (1, 2) while the synthetic data stores
            # them as floats (1.0, 2.0). Cast to string and strip a single
            # trailing ".0" so that "1" and "1.0" become the same category.
            df_str = df.astype("str")
            df_str = df_str.apply(lambda col: col.str.replace(r"\.0$", "", regex=True))
            return df_str.to_numpy()

        # Raw numeric columns are used (no coercion); NaNs in the synthetic
        # frame are filled further below.
        num_real_data_np = real_data[num_col_idx].to_numpy()
        num_syn_data_np = syn_data[num_col_idx].to_numpy()

        if cat_col_idx:
            cat_real_data_np = _normalize_cat_df(real_data[cat_col_idx])
            cat_syn_data_np = _normalize_cat_df(syn_data[cat_col_idx])

            # No handle_unknown at first, so the encoder matches the real
            # categories only.
            encoder = OneHotEncoder()
            encoder.fit(cat_real_data_np)

            cat_real_data_oh = encoder.transform(cat_real_data_np).toarray()
            try:
                cat_syn_data_oh = encoder.transform(cat_syn_data_np).toarray()
            except ValueError:
                # Synthetic has unknown categories; fall back to ignoring
                # them so the metric can still be computed.
                encoder_ignore = OneHotEncoder(handle_unknown="ignore")
                encoder_ignore.fit(cat_real_data_np)
                cat_syn_data_oh = encoder_ignore.transform(cat_syn_data_np).toarray()
        else:
            # Nothing to encode: the encoder cannot be fitted on zero
            # features, so use zero-width blocks instead.
            cat_real_data_oh = np.zeros((len(real_data), 0))
            cat_syn_data_oh = np.zeros((len(syn_data), 0))

        le_real_data = pd.DataFrame(
            np.concatenate((num_real_data_np, cat_real_data_oh), axis=1)
        ).astype(float)
        le_syn_data = pd.DataFrame(
            np.concatenate((num_syn_data_np, cat_syn_data_oh), axis=1)
        ).astype(float)

        # Fill NaN in synthetic so we can still compute metrics
        # (e.g. when the schema only partially matches).
        if le_syn_data.isnull().values.any():
            nan_count = le_syn_data.isnull().sum().sum()
            print(f"Synthetic data contains {nan_count} NaN(s); filling with 0 for quality metrics.")
            le_syn_data = le_syn_data.fillna(0)

        np.set_printoptions(precision=4)

        print('=========== All Features ===========')
        print('Data shape: ', le_syn_data.shape)

        # Synthcity quality evaluator requires real and synthetic to have the same length
        if len(le_syn_data) != len(le_real_data):
            n = min(len(le_syn_data), len(le_real_data))
            print(f"Subsampling to same length n={n} (real had {len(le_real_data)}, synthetic had {len(le_syn_data)})")
            le_syn_data = le_syn_data.sample(n=n, random_state=42)
            le_real_data = le_real_data.sample(n=n, random_state=42)

        X_syn_loader = GenericDataLoader(le_syn_data)
        X_real_loader = GenericDataLoader(le_real_data)

        def _alpha_precision_naive():
            # Run AlphaPrecision on all features and keep the entries of the
            # naive implementation only.
            quality_evaluator = eval_statistical.AlphaPrecision()
            full_res = quality_evaluator.evaluate(X_real_loader, X_syn_loader)
            return {k: v for (k, v) in full_res.items() if "naive" in k}

        score_line = 'alpha precision: {:.6f}, beta recall: {:.6f}'

        # Only computed for 'all' features (no separate num/cat computations).
        out_metrics = {}
        try:
            qual_res = _alpha_precision_naive()
            print('all')
            print(score_line.format(qual_res['delta_precision_alpha_naive'], qual_res['delta_coverage_beta_naive']))
        except (EOFError, FileNotFoundError) as e:
            # Handle corrupted cache file - clear cache and retry once
            import os
            import shutil
            cache_dir = os.path.expanduser('~/.cache/synthcity')
            if os.path.exists(cache_dir):
                print(f'Warning: Cache error detected ({e}). Clearing synthcity cache and retrying...')
                try:
                    shutil.rmtree(cache_dir)
                    os.makedirs(cache_dir, exist_ok=True)
                except Exception as cleanup_error:
                    print(f'Warning: Could not clear cache: {cleanup_error}')

            # Retry evaluation after clearing cache
            try:
                qual_res = _alpha_precision_naive()
                print('all (after cache clear)')
                print(score_line.format(qual_res['delta_precision_alpha_naive'], qual_res['delta_coverage_beta_naive']))
            except Exception as retry_error:
                print(f'Error: Failed to evaluate quality after cache clear: {retry_error}')
                # Set default values to continue evaluation
                qual_res = {
                    'delta_precision_alpha_naive': 0.0,
                    'delta_coverage_beta_naive': 0.0
                }

        out_metrics['quality/alpha_precision_all'] = qual_res['delta_precision_alpha_naive']
        out_metrics['quality/beta_recall_all'] = qual_res['delta_coverage_beta_naive']

        out_extras = {}

        return out_metrics, out_extras
    
    def plot_density(self, syn_data):
        """Render the density comparison of synthetic versus real data.

        Works on a deep copy of ``syn_data``. The real CSV keeps its original
        column names here. A single-column synthetic table is treated as
        "target only" and completed with the real table's other columns
        before plotting.

        Args:
            syn_data: Synthetic table as a pandas DataFrame (not modified).

        Returns:
            Whatever the module-level ``plot_density`` helper returns for the
            synthetic frame, the real frame and a copy of ``self.info``.
        """
        syn_frame = deepcopy(syn_data)
        real_frame = pd.read_csv(self.real_data_path)
        info = deepcopy(self.info)

        if len(syn_frame.columns) == 1:
            # The real frame is addressed by column name, so translate the
            # target index into its name first.
            target_position = info['target_col_idx'][0]
            target_name = info['column_names'][target_position]
            syn_frame = self.complete_y_only_data(syn_frame, real_frame, target_name)

        return plot_density(syn_frame, real_frame, info)
    
    def complete_y_only_data(self, syn_data, real_data, target_col_idx):
        syn_target_col = deepcopy(syn_data.iloc[:, 0])
        syn_data = deepcopy(real_data)
        syn_data[target_col_idx] = syn_target_col
        return syn_data
        

def reorder(real_data, syn_data, info):
    """Reorder both tables to "numerical columns first, categorical after".

    The target columns are appended to the numerical group for regression
    tasks and to the categorical group otherwise. The column metadata is
    re-keyed to the new positions.

    Args:
        real_data: Real table whose column labels are positional integers.
        syn_data: Synthetic table with the same column labels.
        info: Dataset description with ``num_col_idx``, ``cat_col_idx``,
            ``target_col_idx``, ``task_type`` and ``metadata`` (whose
            ``columns`` dict is keyed by the original integer positions).
            ``info`` is not modified.

    Returns:
        tuple: ``(new_real_data, new_syn_data, metadata)`` where both frames
        have columns ``0..k-1`` in the new order and ``metadata`` is a shallow
        copy of ``info['metadata']`` with ``columns`` keyed by new position.
    """
    # Work on copies so that neither the index lists nor the metadata of the
    # caller's info are altered; reorder can be called repeatedly on it.
    num_col_idx = list(info['num_col_idx'])
    cat_col_idx = list(info['cat_col_idx'])
    target_col_idx = list(info['target_col_idx'])

    if info['task_type'] == 'regression':
        num_col_idx += target_col_idx
    else:
        cat_col_idx += target_col_idx

    def _rearrange(frame):
        # Numerical block followed by categorical block, relabelled 0..k-1.
        rearranged = pd.concat([frame[num_col_idx], frame[cat_col_idx]], axis=1)
        rearranged.columns = range(len(rearranged.columns))
        return rearranged

    new_real_data = _rearrange(real_data)
    new_syn_data = _rearrange(syn_data)

    metadata = dict(info['metadata'])
    old_columns = metadata['columns']
    # New position -> metadata entry of the column that moved there.
    metadata['columns'] = {
        new_idx: old_columns[old_idx]
        for new_idx, old_idx in enumerate(num_col_idx + cat_col_idx)
    }

    return new_real_data, new_syn_data, metadata