import os
import numpy as np
import pandas as pd
import torch
from typing import Optional, Union, List, Dict, Callable

from config import Config, get_default_config
from environment import DataGenerationEnv
from kto_controller import KTOAgent
from utils import set_seed, setup_logging, save_checkpoint
from generators.base import BaseGenerator


class TAPTrainer:
    """TAP trainer for policy learning.

    Runs a ``KTOAgent`` against a ``DataGenerationEnv`` built from the real
    training data, periodically commits the best proposals and returns the
    synthetic rows collected by the environment.
    """

    def __init__(self, config: Optional["Config"] = None, task_type: str = "classification"):
        """Store the configuration and seed via ``set_seed``.

        Args:
            config: Configuration object; ``get_default_config()`` is used when
                it is omitted. ``config.data.task_type`` is overwritten.
            task_type: Value written to ``config.data.task_type``.
        """
        self.config = config or get_default_config()
        self.config.data.task_type = task_type
        set_seed(self.config.train.seed)

    @staticmethod
    def _make_splits(n_rows: int, num_splits: int = 5, support_fraction: float = 0.8) -> List[tuple]:
        """Draw ``num_splits`` random (support, query) index partitions.

        Each split comes from a single permutation of ``range(n_rows)``, so the
        support and query index arrays are disjoint and together cover every
        row exactly once.

        Args:
            n_rows: Number of rows to partition.
            num_splits: How many independent partitions to draw.
            support_fraction: Share of rows placed in the support part
                (``int(n_rows * support_fraction)`` rows).

        Returns:
            A list of ``(support_indices, query_indices)`` tuples of arrays.
        """
        n_support = int(n_rows * support_fraction)
        splits = []
        for _ in range(num_splits):
            order = np.random.permutation(n_rows)
            splits.append((order[:n_support], order[n_support:]))
        return splits

    def train_policy(
        self,
        train_data: Union[str, pd.DataFrame],
        generator: "BaseGenerator",
        target_col: str,
        num_steps: Optional[int] = None,
        final_samples: int = 500,
        num_splits: int = 5,
        support_fraction: float = 0.8,
    ) -> pd.DataFrame:
        """Train the policy and return the synthetic data that was collected.

        Args:
            train_data: Real training data, as a CSV path or a DataFrame
                (a DataFrame is copied before it is handed to the environment).
            generator: Generator passed to ``DataGenerationEnv``; its class
                name is used to suffix the checkpoint and log directories.
            target_col: Written to ``config.data.target_column``.
            num_steps: Overrides ``config.train.num_steps`` when not ``None``.
            final_samples: Training stops once the synthetic buffer holds this
                many rows; the returned frame is down-sampled to at most this
                many rows.
            num_splits: Number of support/query splits drawn per step.
            support_fraction: Fraction of real rows in each support split.

        Returns:
            DataFrame of synthetic rows (at most ``final_samples``).
        """
        train_df = pd.read_csv(train_data) if isinstance(train_data, str) else train_data.copy()

        self.config.data.target_column = target_col
        # ``is not None`` so that an explicit 0 is honoured instead of ignored.
        if num_steps is not None:
            self.config.train.num_steps = num_steps

        generator_name = type(generator).__name__.lower().replace('generator', '')
        checkpoint_dir = f"{self.config.train.checkpoint_dir}_{generator_name}"
        log_dir = f"{self.config.train.log_dir}_{generator_name}"
        os.makedirs(checkpoint_dir, exist_ok=True)
        os.makedirs(log_dir, exist_ok=True)
        logger = setup_logging(log_dir)

        env = DataGenerationEnv(self.config, generator, train_df)

        logger.info(f"Task: {self.config.data.task_type}, Targets: {env.num_targets}")

        agent = KTOAgent(
            state_dim=env.get_state_dim(),
            num_targets=env.num_targets,
            num_templates=env.num_templates,
            config=self.config.kto,
            device=self.config.generator.device
        )

        best_reward, best_step = float('-inf'), 0
        commit_interval = self.config.inpaint.commit_interval
        log_every = self.config.train.log_every

        # Keeps ``step`` defined for the checkpoint call when the loop body
        # never runs (num_steps <= 0).
        step = 0
        for step in range(self.config.train.num_steps):
            state = env.get_state()

            # Fresh, non-overlapping support/query partitions for this step.
            splits = self._make_splits(len(env.D_real), num_splits, support_fraction)

            action = agent.select_action(state)
            reward, info = env.step(action, splits=splits)

            if info["n_passed"] > 0:
                env.proposals.append({"batch": info["passed_batch"], "ig": reward})

            loss = agent.update(state, action, reward)

            if reward > best_reward:
                best_reward, best_step = reward, step

            if (step + 1) % commit_interval == 0:
                n_committed = env.commit_top_proposals()
                if n_committed > 0:
                    logger.info(f"Commit: {n_committed} samples")

            if step % log_every == 0:
                logger.info(
                    f"Step {step:4d} | IG={reward:+.4f} | "
                    f"gen={info['n_generated']} pass={info['n_passed']} | "
                    f"best={best_reward:+.4f}@{best_step} | loss={loss:.4f}"
                )

            if len(env.synthetic_buffer) >= final_samples:
                logger.info(f"Reached target: {final_samples} samples")
                break

        # Flush proposals gathered since the last periodic commit.
        if len(env.proposals) > 0:
            n_committed = env.commit_top_proposals()
            logger.info(f"Final commit: {n_committed} samples")

        final_data = env.get_synthetic_data()
        if len(final_data) > final_samples:
            final_data = final_data.sample(n=final_samples, random_state=42)

        save_checkpoint(checkpoint_dir, step, agent, {'num_targets': env.num_targets}, best_reward)

        logger.info(f"Training complete. Output: {len(final_data)} samples, best_R={best_reward:+.4f}@{best_step}")

        return final_data


class TAPGenerator:
    """TAP generator for inference and evaluation.

    Loads a saved policy into a ``KTOAgent``, rolls it out in a
    ``DataGenerationEnv`` to produce synthetic rows, and compares downstream
    models trained on real data against models trained on real + synthetic
    data.
    """

    def __init__(self, checkpoint: str, generator: "BaseGenerator", config: Optional["Config"] = None):
        """Rebuild the policy network from ``checkpoint``.

        Args:
            checkpoint: Path handed to ``torch.load``.
            generator: Generator whose ``columns`` define the feature set; a
                truthy ``conditional_col`` overrides the configured target.
            config: Configuration; ``get_default_config()`` when omitted.
        """
        self.config = config or get_default_config()
        self.generator = generator

        if generator.conditional_col:
            self.config.data.target_column = generator.conditional_col

        self.target_col = self.config.data.target_column
        self.columns = [col for col in generator.columns if col != self.target_col]

        # SECURITY: weights_only=False unpickles arbitrary Python objects, so
        # only load checkpoints that come from a trusted source.
        ckpt = torch.load(checkpoint, map_location=self.config.generator.device, weights_only=False)
        policy_state = ckpt['policy_state_dict']

        # Network sizes are read back from the stored weight shapes.
        state_dim = policy_state['shared_net.0.weight'].shape[1]
        head_key = 'target_head.weight' if 'target_head.weight' in policy_state else 'class_head.weight'
        num_targets = policy_state[head_key].shape[0]
        num_templates = policy_state['template_head.weight'].shape[0]

        self.agent = KTOAgent(
            state_dim, num_targets, num_templates,
            self.config.kto, self.config.generator.device
        )
        self.agent.policy.load_state_dict(policy_state)

    @staticmethod
    def _load_frame(data: Union[str, pd.DataFrame]) -> pd.DataFrame:
        """Read a CSV path, or return a copy of an in-memory DataFrame."""
        return pd.read_csv(data) if isinstance(data, str) else data.copy()

    def generate(
        self,
        real_data: Union[str, pd.DataFrame],
        n_syn: int = 100,
        max_steps: Optional[int] = None,
    ) -> pd.DataFrame:
        """Roll out the policy until the environment buffer holds ``n_syn`` rows.

        Args:
            real_data: Real data (CSV path or DataFrame) used to build the env.
            n_syn: Number of synthetic rows wanted.
            max_steps: Optional cap on environment steps. ``None`` (default)
                keeps stepping until the buffer is full; with a cap the result
                may contain fewer than ``n_syn`` rows.

        Returns:
            DataFrame from ``env.get_synthetic_data()``, down-sampled to
            ``n_syn`` rows when it holds more.
        """
        real_df = self._load_frame(real_data)
        env = DataGenerationEnv(self.config, self.generator, real_df)

        steps_taken = 0
        while len(env.synthetic_buffer) < n_syn:
            if max_steps is not None and steps_taken >= max_steps:
                break
            state = env.get_state()
            action = self.agent.select_action(state)
            env.step(action)
            steps_taken += 1

        result = env.get_synthetic_data()
        return result.sample(n=n_syn, random_state=42) if len(result) > n_syn else result

    def evaluate(
        self,
        real_data: Union[str, pd.DataFrame],
        synthetic_data: pd.DataFrame,
        test_data: Union[str, pd.DataFrame],
    ):
        """Print a real vs. real+synthetic comparison table and return the scores.

        Dispatches on ``config.data.task_type``: ``"regression"`` uses the
        regression suite, anything else the classification suite.
        """
        if self.config.data.task_type == "regression":
            return self._evaluate_regression(real_data, synthetic_data, test_data)
        else:
            return self._evaluate_classification(real_data, synthetic_data, test_data)

    def _preprocess_data(
        self,
        real_df: pd.DataFrame,
        test_df: pd.DataFrame,
        syn_df: pd.DataFrame,
        is_regression: bool
    ) -> Dict:
        """Turn real/test/synthetic frames into numeric model inputs.

        All statistics (medians, modes, target encodings, means, stds) are
        computed on the real data and then applied to test and synthetic rows.

        Args:
            real_df: Real training rows.
            test_df: Held-out rows.
            syn_df: Synthetic rows; ``None`` or an empty frame means
                "no augmentation".
            is_regression: Standardise a float target when True, otherwise
                label-encode a categorical target.

        Returns:
            Dict with ``X_real``, ``y_real``, ``X_test``, ``y_test``, ``X_aug``
            and ``y_aug`` arrays (``*_aug`` is real followed by synthetic).
        """
        feature_cols = self.columns
        X_real = real_df[feature_cols].copy()
        X_test = test_df[feature_cols].copy()

        # Column kinds are decided by the dtype in the real data.
        num_cols = [c for c in feature_cols if pd.api.types.is_numeric_dtype(X_real[c])]
        cat_cols = [c for c in feature_cols if c not in num_cols]

        num_medians = {c: X_real[c].median() for c in num_cols}
        cat_modes = {}
        for c in cat_cols:
            modes = X_real[c].mode()
            cat_modes[c] = modes.iloc[0] if len(modes) > 0 else 'unknown'

        for c in num_cols:
            X_real[c] = X_real[c].fillna(num_medians[c])
            X_test[c] = X_test[c].fillna(num_medians[c])
        for c in cat_cols:
            X_real[c] = X_real[c].fillna(cat_modes[c])
            X_test[c] = X_test[c].fillna(cat_modes[c])

        if is_regression:
            y_real = real_df[self.target_col].values.astype(float)
            y_test = test_df[self.target_col].values.astype(float)
            y_mean = y_real.mean()
            y_std = y_real.std() + 1e-8
            y_real = (y_real - y_mean) / y_std
            y_test = (y_test - y_mean) / y_std
            test_mask = np.ones(len(y_test), dtype=bool)
            remap, le = None, None
        else:
            # Imported here so the regression path does not need scikit-learn.
            from sklearn.preprocessing import LabelEncoder

            le = LabelEncoder()
            le.fit(np.concatenate([real_df[self.target_col], test_df[self.target_col]]))
            y_real = le.transform(real_df[self.target_col])
            y_test = le.transform(test_df[self.target_col])

            # Re-number so the classes seen in training are 0..k-1; test rows
            # whose class never occurs in training are masked out.
            unique_train = np.unique(y_real)
            remap = {old: new for new, old in enumerate(unique_train)}
            y_real = np.array([remap[y] for y in y_real])
            test_mask = np.array([y in remap for y in y_test], dtype=bool)
            y_test = np.array([remap.get(y, -1) for y in y_test], dtype=int)

        # Target-encode categoricals. Real rows get a leave-one-out mean (the
        # row's own target is excluded); other rows get the plain category
        # mean, falling back to the global mean for unseen categories.
        global_mean = y_real.mean()
        cat_encodings = {}
        for c in cat_cols:
            df_tmp = pd.DataFrame({'cat': X_real[c].astype(str), 'y': y_real})
            cat_mean = df_tmp.groupby('cat')['y'].mean().to_dict()
            cat_encodings[c] = (cat_mean, global_mean)

            cat_sum = df_tmp.groupby('cat')['y'].transform('sum')
            cat_count = df_tmp.groupby('cat')['y'].transform('count')
            X_real[c] = np.where(cat_count > 1, (cat_sum - y_real) / (cat_count - 1), global_mean)
            X_test[c] = X_test[c].astype(str).map(cat_mean).fillna(global_mean)

        means = {c: X_real[c].mean() for c in feature_cols}
        stds = {c: X_real[c].std() + 1e-8 for c in feature_cols}
        for c in feature_cols:
            X_real[c] = (X_real[c] - means[c]) / stds[c]
            X_test[c] = (X_test[c] - means[c]) / stds[c]

        X_real = X_real.values.astype(np.float64)
        X_test = X_test[test_mask].values.astype(np.float64)
        y_test = y_test[test_mask] if not is_regression else y_test

        if syn_df is not None and len(syn_df) > 0:
            X_syn = syn_df[feature_cols].copy()

            if is_regression:
                # Unparseable or non-finite targets are dropped, not propagated.
                y_syn = pd.to_numeric(syn_df[self.target_col], errors='coerce').to_numpy(dtype=float)
                keep = np.isfinite(y_syn)
                y_syn = (y_syn[keep] - y_mean) / y_std
            else:
                # Match labels by their string form: synthetic labels may come
                # back as strings even when the real labels are numeric, and
                # the encoder itself cannot look up labels of a different type.
                class_index = {str(label): idx for idx, label in enumerate(le.classes_)}
                encoded = [class_index.get(str(y), -1) for y in syn_df[self.target_col].values]
                y_syn = np.array([remap.get(code, -1) for code in encoded], dtype=int)
                keep = y_syn >= 0
                y_syn = y_syn[keep]

            # Copy after filtering so the assignments below write to an
            # independent frame.
            X_syn = X_syn.iloc[np.flatnonzero(keep)].copy()

            for c in num_cols:
                X_syn[c] = pd.to_numeric(X_syn[c], errors='coerce').fillna(num_medians[c])
            for c in cat_cols:
                cat_mean, gm = cat_encodings[c]
                X_syn[c] = X_syn[c].fillna(cat_modes[c]).astype(str).map(cat_mean).fillna(gm)
            for c in feature_cols:
                X_syn[c] = (X_syn[c] - means[c]) / stds[c]

            X_syn = X_syn.values.astype(np.float64)
            X_aug = np.vstack([X_real, X_syn])
            y_aug = np.concatenate([y_real, y_syn])
        else:
            X_aug, y_aug = X_real, y_real

        return {
            'X_real': X_real, 'y_real': y_real,
            'X_test': X_test, 'y_test': y_test,
            'X_aug': X_aug, 'y_aug': y_aug
        }

    def _fit_predict(self, models: List[tuple], data: Dict) -> Dict[str, Optional[tuple]]:
        """Fit every model on real and on augmented data; predict the test set.

        Returns:
            ``{name: (pred_real, pred_aug)}``; the value is ``None`` for a
            model whose fit or predict raised (the error is printed).
        """
        predictions = {name: None for name, _ in models}

        np.random.seed(42)
        for name, model_cls in models:
            try:
                baseline = model_cls(42)
                baseline.fit(data['X_real'], data['y_real'])
                pred_real = baseline.predict(data['X_test'])

                augmented = model_cls(42)
                augmented.fit(data['X_aug'], data['y_aug'])
                pred_aug = augmented.predict(data['X_test'])
            except Exception as e:
                # Best effort: one failing model must not abort the comparison.
                print(f"[{name}] Error: {e}")
            else:
                predictions[name] = (pred_real, pred_aug)

        return predictions

    @staticmethod
    def _score_predictions(predictions: Dict, y_true, metric_fn: Callable) -> Dict[str, Dict]:
        """Apply ``metric_fn(y_true, pred)`` to each model's prediction pair.

        Returns:
            ``{name: {'real': score, 'aug': score}}``; both scores are ``None``
            when the model has no predictions or the metric raised.
        """
        results = {}
        for name, preds in predictions.items():
            scores = {'real': None, 'aug': None}
            if preds is not None:
                try:
                    real_score = metric_fn(y_true, preds[0])
                    aug_score = metric_fn(y_true, preds[1])
                except Exception as e:
                    print(f"[{name}] Error: {e}")
                else:
                    scores = {'real': real_score, 'aug': aug_score}
            results[name] = scores
        return results

    def _run_models(
        self,
        models: List[tuple],
        data: Dict,
        metric_fn: Callable,
    ) -> Dict[str, Dict]:
        """Fit all models and score them with a single metric.

        Args:
            models: ``(name, factory)`` pairs; ``factory(seed)`` returns an
                unfitted estimator with ``fit`` and ``predict``.
            data: Output of ``_preprocess_data``.
            metric_fn: Called as ``metric_fn(y_test, predictions)``.

        Returns:
            ``{name: {'real': score, 'aug': score}}`` (``None`` on failure).
        """
        predictions = self._fit_predict(models, data)
        return self._score_predictions(predictions, data['y_test'], metric_fn)

    def _evaluate_classification(
        self,
        real_data: Union[str, pd.DataFrame],
        synthetic_data: pd.DataFrame,
        test_data: Union[str, pd.DataFrame],
    ):
        """Print accuracy / balanced accuracy for each classifier.

        Returns:
            ``{'accuracy': {...}, 'balanced_accuracy': {...}}`` where each
            inner dict maps a model name to ``{'real': ..., 'aug': ...}``.
        """
        from sklearn.metrics import balanced_accuracy_score, accuracy_score
        from sklearn.linear_model import LogisticRegression
        from sklearn.neural_network import MLPClassifier
        from sklearn.ensemble import RandomForestClassifier
        from lightgbm import LGBMClassifier
        from xgboost import XGBClassifier
        from sklearn.neighbors import KNeighborsClassifier

        real_df = self._load_frame(real_data)
        test_df = self._load_frame(test_data)

        data = self._preprocess_data(real_df, test_df, synthetic_data, is_regression=False)

        models = [
            ("LR", lambda s: LogisticRegression(max_iter=1000, random_state=s)),
            ("KNN", lambda s: KNeighborsClassifier(n_neighbors=5)),
            ("MLP", lambda s: MLPClassifier(hidden_layer_sizes=(100,), max_iter=500, random_state=s)),
            ("RF", lambda s: RandomForestClassifier(n_estimators=100, random_state=s)),
            ("XGB", lambda s: XGBClassifier(n_estimators=100, random_state=s, verbosity=0, use_label_encoder=False)),
            ("LGBM", lambda s: LGBMClassifier(n_estimators=100, random_state=s, verbose=-1)),
        ]

        # TabPFN is optional, as in the regression suite.
        try:
            from tabpfn import TabPFNClassifier
            models.append(("TabPFN", lambda s: TabPFNClassifier(ignore_pretraining_limits=True)))
        except ImportError:
            pass

        # Fit once, then score the same predictions with both metrics.
        predictions = self._fit_predict(models, data)
        acc_results = self._score_predictions(predictions, data['y_test'], accuracy_score)
        bacc_results = self._score_predictions(predictions, data['y_test'], balanced_accuracy_score)

        print(f"{'Model':<8} | {'Acc(Real)':>10} | {'Acc(+TAP)':>10} | {'BAcc(Real)':>11} | {'BAcc(+TAP)':>11}")
        print("-" * 60)

        for name, _ in models:
            acc_r = acc_results[name]['real']
            acc_a = acc_results[name]['aug']
            bacc_r = bacc_results[name]['real']
            bacc_a = bacc_results[name]['aug']

            # Skip rows with any missing score instead of formatting None.
            if any(v is None for v in (acc_r, acc_a, bacc_r, bacc_a)):
                continue
            print(f"{name:<8} | {acc_r*100:10.2f} | {acc_a*100:10.2f} | {bacc_r*100:11.2f} | {bacc_a*100:11.2f}")

        return {'accuracy': acc_results, 'balanced_accuracy': bacc_results}

    def _evaluate_regression(
        self,
        real_data: Union[str, pd.DataFrame],
        synthetic_data: pd.DataFrame,
        test_data: Union[str, pd.DataFrame],
    ):
        """Print RMSE / R2 for each regressor (targets are standardised).

        Returns:
            ``{'rmse': {...}, 'r2': {...}}`` where each inner dict maps a model
            name to ``{'real': ..., 'aug': ...}``.
        """
        from sklearn.metrics import mean_squared_error, r2_score
        from sklearn.linear_model import Ridge
        from sklearn.neural_network import MLPRegressor
        from sklearn.ensemble import RandomForestRegressor
        from lightgbm import LGBMRegressor
        from xgboost import XGBRegressor
        from sklearn.neighbors import KNeighborsRegressor

        real_df = self._load_frame(real_data)
        test_df = self._load_frame(test_data)

        data = self._preprocess_data(real_df, test_df, synthetic_data, is_regression=True)

        models = [
            ("Ridge", lambda s: Ridge(alpha=1.0, random_state=s)),
            ("KNN", lambda s: KNeighborsRegressor(n_neighbors=5)),
            ("MLP", lambda s: MLPRegressor(hidden_layer_sizes=(100,), max_iter=500, random_state=s)),
            ("RF", lambda s: RandomForestRegressor(n_estimators=100, random_state=s)),
            ("XGB", lambda s: XGBRegressor(n_estimators=100, random_state=s, verbosity=0)),
            ("LGBM", lambda s: LGBMRegressor(n_estimators=100, random_state=s, verbose=-1)),
        ]

        try:
            from tabpfn import TabPFNRegressor
            models.append(("TabPFN", lambda s: TabPFNRegressor(ignore_pretraining_limits=True)))
        except ImportError:
            pass

        rmse_fn = lambda y, p: np.sqrt(mean_squared_error(y, p))

        # Fit once, then score the same predictions with both metrics.
        predictions = self._fit_predict(models, data)
        rmse_results = self._score_predictions(predictions, data['y_test'], rmse_fn)
        r2_results = self._score_predictions(predictions, data['y_test'], r2_score)

        print(f"{'Model':<8} | {'RMSE(Real)':>11} | {'RMSE(+TAP)':>11} | {'R2(Real)':>10} | {'R2(+TAP)':>10}")
        print("-" * 60)

        for name, _ in models:
            rmse_r = rmse_results[name]['real']
            rmse_a = rmse_results[name]['aug']
            r2_r = r2_results[name]['real']
            r2_a = r2_results[name]['aug']

            # Skip rows with any missing score instead of formatting None.
            if any(v is None for v in (rmse_r, rmse_a, r2_r, r2_a)):
                continue
            print(f"{name:<8} | {rmse_r:11.4f} | {rmse_a:11.4f} | {r2_r*100:10.2f} | {r2_a*100:10.2f}")

        return {'rmse': rmse_results, 'r2': r2_results}
