import numpy as np
import pandas as pd
from typing import Tuple, Dict, List
from sklearn.preprocessing import LabelEncoder
from sklearn.neighbors import NearestNeighbors
import warnings
warnings.filterwarnings('ignore')

from config import Config
from generators.base import BaseGenerator


class DataGenerationEnv:
    def __init__(self, config: Config, generator: BaseGenerator, train_data: pd.DataFrame):
        """Build the environment state from a config, a generator and a training table.

        The constructor encodes a private copy of ``train_data``, derives the
        target buckets (quantile bins for regression, observed classes
        otherwise), fits the scoring model on the encoded data and initialises
        the per-target statistics, gate counters and column groups.

        Args:
            config: Read here for ``data.task_type``, ``data.target_column``,
                ``mask_template.templates`` and ``inpaint.anchor_rules``;
                other ``inpaint.*`` fields are read by the helpers called below.
            generator: Stored on ``self.generator`` for later sampling calls.
            train_data: Training table; it is copied and not modified here.
        """
        self.config = config
        self.generator = generator
        self.task_type = config.data.task_type
        self.is_regression = (self.task_type == "regression")
        self.target_col = config.data.target_column
        # Feature columns: everything except the target, in frame order.
        self.columns = [col for col in train_data.columns if col != self.target_col]

        # Private working copy; _encode_train_data rewrites it in place and
        # fills label_encoder / feature_encoders when encoders are needed.
        self.train_data_raw = train_data.copy()
        self.label_encoder = None
        self.feature_encoders = {}
        self._encode_train_data()

        # "Targets" are the buckets the state vector is organised around:
        # quantile bins for regression, distinct encoded labels otherwise.
        if self.is_regression:
            self._init_bins()
            self.num_targets = self.num_bins
        else:
            self.classes = np.unique(self.train_data_raw[self.target_col])
            self.num_classes = len(self.classes)
            self.num_targets = self.num_classes

        # D_real is the encoded table used for fitting and distance checks;
        # _original_train_data keeps the untouched input.
        self.D_real = self.train_data_raw.copy()
        self._original_train_data = train_data.copy()

        # Order matters: the model must exist and be fitted before
        # _init_stats asks it for predictions.
        self._init_tabpfn()
        self._fit_tabpfn()
        self._init_stats()

        self.templates = list(config.mask_template.templates.keys())
        self.num_templates = len(self.templates)
        self.anchor_rules = config.inpaint.anchor_rules
        self.num_anchor_rules = len(self.anchor_rules)

        # Accepted synthetic rows, pending proposals and per-template gate
        # pass/total counters all start empty.
        self.synthetic_buffer = pd.DataFrame()
        self.proposals = []
        self.gate_stats = {t: {"pass": 0, "total": 0} for t in self.templates}
        self._compute_column_importance()

    def _init_bins(self):
        y = self.train_data_raw[self.target_col].values.astype(float)
        self.num_bins = self.config.inpaint.num_bins
        self.bin_edges = np.percentile(y, np.linspace(0, 100, self.num_bins + 1))
        self.bin_edges[0] = -np.inf
        self.bin_edges[-1] = np.inf

    def _get_bin(self, y):
        y = np.asarray(y).astype(float)
        return np.clip(np.digitize(y, self.bin_edges) - 1, 0, self.num_bins - 1)

    def _encode_train_data(self):
        if self.is_regression:
            self.train_data_raw[self.target_col] = pd.to_numeric(
                self.train_data_raw[self.target_col], errors='coerce').astype('float64')
        else:
            if self.train_data_raw[self.target_col].dtype == 'object':
                self.label_encoder = LabelEncoder()
                self.label_encoder.fit(self.train_data_raw[self.target_col])
                self.train_data_raw[self.target_col] = self.label_encoder.transform(
                    self.train_data_raw[self.target_col])

        for col in self.columns:
            try:
                self.train_data_raw[col] = pd.to_numeric(self.train_data_raw[col], errors='raise').astype('float64')
                if self.train_data_raw[col].isna().any():
                    self.train_data_raw[col] = self.train_data_raw[col].fillna(self.train_data_raw[col].mean())
            except (ValueError, TypeError):
                le = LabelEncoder()
                self.train_data_raw[col] = le.fit_transform(self.train_data_raw[col].astype(str)).astype('float64')
                self.feature_encoders[col] = le

    def _init_tabpfn(self):
        """Create the scoring model and store it on ``self.tabpfn``.

        TabPFN is tried first. If importing or constructing it raises
        ``ImportError``, a gradient-boosting fallback is used instead:
        scikit-learn's regressor for regression tasks, LightGBM's classifier
        for classification tasks.
        """
        if not self.is_regression:
            try:
                from tabpfn import TabPFNClassifier
                model = TabPFNClassifier(ignore_pretraining_limits=True)
            except ImportError:
                from lightgbm import LGBMClassifier
                model = LGBMClassifier(n_estimators=100, random_state=42, verbose=-1)
            self.tabpfn = model
            return

        try:
            from tabpfn import TabPFNRegressor
            model = TabPFNRegressor(ignore_pretraining_limits=True)
        except ImportError:
            from sklearn.ensemble import GradientBoostingRegressor
            model = GradientBoostingRegressor(n_estimators=100, random_state=42)
        self.tabpfn = model

    def _fit_tabpfn(self):
        X = self.D_real[self.columns].values
        y = self.D_real[self.target_col].values
        self.tabpfn.fit(X, y)

    def _init_stats(self):
        y = self.D_real[self.target_col].values
        X = self.D_real[self.columns].values

        if self.is_regression:
            bins = self._get_bin(y)
            self.target_counts = np.bincount(bins, minlength=self.num_targets)
            preds = self.tabpfn.predict(X)
            residuals = np.abs(y - preds)
            self.target_nll = np.zeros(self.num_targets)
            self.target_error = np.zeros(self.num_targets)
            for b in range(self.num_targets):
                mask = (bins == b)
                if mask.sum() > 0:
                    self.target_nll[b] = residuals[mask].mean()
                    self.target_error[b] = residuals[mask].mean()
            self.target_per_target = len(y) // self.num_targets
            self._residual_threshold = np.percentile(residuals, self.config.inpaint.residual_threshold_percentile)
        else:
            self.target_counts = np.bincount(y.astype(int), minlength=self.num_targets)
            probs = self.tabpfn.predict_proba(X)
            preds = np.argmax(probs, axis=1)
            self.target_error = np.zeros(self.num_targets)
            self.target_nll = np.zeros(self.num_targets)
            for c in range(self.num_targets):
                mask = (y == c)
                if mask.sum() > 0:
                    self.target_error[c] = (preds[mask] != c).mean()
                    self.target_nll[c] = -np.log(probs[mask, c].clip(1e-10, 1)).mean()
            self.target_per_target = len(y) // self.num_targets
        self.diversity_score = 0.0

    def _compute_column_importance(self):
        """Pick the feature-index groups used by the mask templates.

        ``self.correlated_cols`` holds the positions (within ``self.columns``)
        of the features with the highest mutual information with the target.
        ``self.stable_cols`` holds the positions whose bootstrap standard
        deviation varies least across 10 resamples. Each group has
        ``max(3, n_columns // 4)`` entries (fewer if there are fewer columns).
        """
        if self.is_regression:
            from sklearn.feature_selection import mutual_info_regression as score_fn
        else:
            from sklearn.feature_selection import mutual_info_classif as score_fn

        features = self.D_real[self.columns].values
        labels = self.D_real[self.target_col].values
        n_cols = len(self.columns)
        group_size = max(3, n_cols // 4)

        relevance = score_fn(features, labels, random_state=42)
        most_relevant = np.argsort(relevance)[-group_size:]

        n_rows = len(features)
        spread_of_std = []
        for j in range(n_cols):
            boot_stds = []
            for _ in range(10):
                # Bootstrap resample of the rows, evaluated on one column.
                rows = np.random.choice(n_rows, n_rows, replace=True)
                boot_stds.append(features[rows, j].std())
            spread_of_std.append(np.std(boot_stds))
        most_stable = np.argsort(spread_of_std)[:group_size]

        self.correlated_cols = set(most_relevant.tolist())
        self.stable_cols = set(most_stable.tolist())

    def get_state(self) -> np.ndarray:
        current = self.target_counts.copy()
        if len(self.synthetic_buffer) > 0:
            encoded = self._encode_data(self.synthetic_buffer)
            if len(encoded) > 0:
                syn_y = encoded[self.target_col].values
                if self.is_regression:
                    syn_bins = self._get_bin(syn_y)
                    current = current + np.bincount(syn_bins, minlength=self.num_targets)
                else:
                    current = current + np.bincount(syn_y.astype(int), minlength=self.num_targets)

        deficit = np.clip((self.target_per_target - current) / max(self.target_per_target, 1), -1, 1)
        gate_rates = [self.gate_stats[t]["pass"] / max(self.gate_stats[t]["total"], 1) for t in self.templates]
        return np.concatenate([deficit, self.target_nll, np.array(gate_rates), [self.diversity_score]]).astype(np.float32)

    def get_state_dim(self) -> int:
        """Return the length of the vector produced by ``get_state``."""
        per_target_entries = 2 * self.num_targets  # deficit + NLL per target
        trailing_entries = self.num_templates + 1  # gate rates + diversity score
        return per_target_entries + trailing_entries

    def _determine_anchor_rule(self, target_idx: int) -> str:
        nll = self.target_nll[target_idx]
        error = self.target_error[target_idx]
        if nll > np.median(self.target_nll):
            return "high_uncertainty"
        if error > np.median(self.target_error):
            return "high_error"
        return "random"

    def step(self, action, splits: List = None) -> Tuple[float, Dict]:
        anchor_rule = self._determine_anchor_rule(action.target_class)
        anchor_rule_idx = self.anchor_rules.index(anchor_rule) if anchor_rule in self.anchor_rules else 3
        anchor_indices = self._select_anchor_indices(action.target_class, anchor_rule_idx)
        if len(anchor_indices) == 0:
            return 0.0, {"n_generated": 0, "n_passed": 0, "delta_U": 0.0, "passed_batch": pd.DataFrame()}

        template = self.templates[action.template_id]
        num_mask, cat_mask = self._build_mask(template, action.explore_level)

        try:
            generated = self.generator.sample_inpaint(
                anchor_indices=anchor_indices, num_mask=num_mask, cat_mask=cat_mask,
                n_samples_per_anchor=self.config.inpaint.samples_per_anchor,
                stochasticity=action.explore_level * 0.5)
        except Exception:
            return 0.0, {"n_generated": 0, "n_passed": 0, "delta_U": 0.0, "passed_batch": pd.DataFrame()}

        if len(generated) == 0:
            return 0.0, {"n_generated": 0, "n_passed": 0, "delta_U": 0.0, "passed_batch": pd.DataFrame()}

        passed = self._apply_gates(generated)
        self.gate_stats[template]["total"] += len(generated)
        self.gate_stats[template]["pass"] += len(passed)

        if len(passed) == 0:
            return 0.0, {"n_generated": len(generated), "n_passed": 0, "delta_U": 0.0, "passed_batch": pd.DataFrame()}

        reward, delta_U = self._compute_reward(passed, splits=splits)
        return reward, {"n_generated": len(generated), "n_passed": len(passed), "delta_U": delta_U, "passed_batch": passed}

    def _select_anchor_indices(self, target_idx: int, anchor_rule: int) -> np.ndarray:
        y = self.D_real[self.target_col].values
        if self.is_regression:
            mask = (self._get_bin(y) == target_idx)
        else:
            mask = (y == target_idx)
        candidate_indices = np.where(mask)[0]
        if len(candidate_indices) == 0:
            return np.array([], dtype=int)

        n_anchors = min(self.config.inpaint.samples_per_step // self.config.inpaint.samples_per_anchor, len(candidate_indices))
        rule = self.anchor_rules[anchor_rule]

        if rule == "high_uncertainty":
            X = self.D_real.iloc[candidate_indices][self.columns].values
            if self.is_regression:
                preds = self.tabpfn.predict(X)
                residuals = np.abs(y[candidate_indices] - preds)
                selected_local = np.argsort(residuals)[-n_anchors:]
            else:
                probs = self.tabpfn.predict_proba(X)
                entropy = -np.sum(probs * np.log(probs.clip(1e-10, 1)), axis=1)
                selected_local = np.argsort(entropy)[-n_anchors:]
            return candidate_indices[selected_local]
        elif rule == "high_error":
            X = self.D_real.iloc[candidate_indices][self.columns].values
            preds = self.tabpfn.predict(X)
            if self.is_regression:
                residuals = np.abs(y[candidate_indices] - preds)
                selected_local = np.argsort(residuals)[-n_anchors:]
            else:
                wrong_local = np.where(preds != y[candidate_indices])[0]
                if len(wrong_local) >= n_anchors:
                    return candidate_indices[wrong_local[:n_anchors]]
                selected_local = np.arange(min(n_anchors, len(candidate_indices)))
            return candidate_indices[selected_local]
        return np.random.choice(candidate_indices, size=min(n_anchors, len(candidate_indices)), replace=False)

    def _build_mask(self, template_name: str, strength: float) -> Tuple[list, list]:
        try:
            num_mask, cat_mask = self.generator.get_column_masks([])
            d_num, d_cat = len(num_mask), len(cat_mask)
        except:
            return [True], [True]

        num_mask = [True] * d_num
        cat_mask = [True] * d_cat
        all_fixed = self.correlated_cols | self.stable_cols if template_name == "conservative" else set()

        for col_idx in all_fixed:
            if col_idx < d_num:
                num_mask[col_idx] = False
            else:
                cat_mask[col_idx - d_num] = False

        if strength < 1.0:
            n_gen = sum(num_mask) + sum(cat_mask)
            n_release = int((1 - strength) * n_gen * 0.3)
            gen_idx = [i for i in range(d_num) if num_mask[i]]
            if len(gen_idx) > n_release and n_release > 0:
                for i in np.random.choice(gen_idx, n_release, replace=False):
                    num_mask[i] = False
        return num_mask, cat_mask

    def _apply_gates(self, generated: pd.DataFrame) -> pd.DataFrame:
        if len(generated) == 0:
            return generated
        encoded = self._encode_data(generated)
        if len(encoded) == 0:
            return generated.iloc[:0]

        if self.is_regression:
            passed = self._apply_gates_regression(generated, encoded)
        else:
            passed = self._apply_gates_classification(generated, encoded)

        if len(passed) > 0:
            passed = self._diversity_gate(passed)
        return passed

    def _apply_gates_classification(self, generated: pd.DataFrame, encoded: pd.DataFrame) -> pd.DataFrame:
        X = encoded[self.columns].values
        y = encoded[self.target_col].values.astype(int)
        probs = self.tabpfn.predict_proba(X)
        keep = []
        for i in range(len(probs)):
            if y[i] >= probs.shape[1]:
                continue
            p_y = probs[i, y[i]]
            p_others = np.delete(probs[i], y[i])
            margin = p_y - p_others.max() if len(p_others) > 0 else p_y
            if p_y >= self.config.inpaint.label_p_min and margin >= self.config.inpaint.label_margin_threshold:
                keep.append(i)
        return generated.iloc[keep].reset_index(drop=True)

    def _apply_gates_regression(self, generated: pd.DataFrame, encoded: pd.DataFrame) -> pd.DataFrame:
        X = encoded[self.columns].values
        y = encoded[self.target_col].values.astype(float)
        preds = self.tabpfn.predict(X)
        residuals = np.abs(y - preds)
        keep = residuals <= self._residual_threshold
        return generated.iloc[keep].reset_index(drop=True)

    def _diversity_gate(self, df: pd.DataFrame) -> pd.DataFrame:
        encoded = self._encode_data(df)
        if len(encoded) == 0:
            return df.iloc[:0]

        X_syn = encoded[self.columns].values
        y_syn = encoded[self.target_col].values
        X_real = self.D_real[self.columns].values
        y_real = self.D_real[self.target_col].values

        if not hasattr(self, '_diversity_thresholds'):
            nn = NearestNeighbors(n_neighbors=2)
            nn.fit(X_real)
            dist, _ = nn.kneighbors(X_real)
            real_real_dist = dist[:, 1] / np.sqrt(X_real.shape[1])
            self._diversity_thresholds = {}
            targets = self._get_bin(y_real) if self.is_regression else y_real
            for t in range(self.num_targets):
                mask = (targets == t)
                self._diversity_thresholds[t] = np.percentile(real_real_dist[mask], 5) if mask.sum() > 1 else 0.02

        nn = NearestNeighbors(n_neighbors=1)
        nn.fit(X_real)
        dist_syn_real, _ = nn.kneighbors(X_syn)
        norm_dist = dist_syn_real.flatten() / np.sqrt(X_syn.shape[1])

        keep_mask = np.zeros(len(df), dtype=bool)
        for i in range(len(df)):
            t = int(self._get_bin(y_syn[i])) if self.is_regression else int(y_syn[i])
            if norm_dist[i] > self._diversity_thresholds.get(t, 0.05):
                keep_mask[i] = True

        passed = df.iloc[keep_mask].reset_index(drop=True)
        if len(passed) > 1:
            X_passed = self._encode_data(passed)[self.columns].values
            nn_syn = NearestNeighbors(n_neighbors=2)
            nn_syn.fit(X_passed)
            dist_syn_syn, _ = nn_syn.kneighbors(X_passed)
            dup_mask = dist_syn_syn[:, 1] / np.sqrt(X_passed.shape[1]) > 0.01
            passed = passed.iloc[dup_mask].reset_index(drop=True)
        return passed

    def _encode_data(self, df: pd.DataFrame) -> pd.DataFrame:
        encoded = df.copy()
        if not self.is_regression and self.label_encoder is not None and self.target_col in encoded.columns:
            valid = encoded[self.target_col].isin(self.label_encoder.classes_)
            encoded = encoded[valid]
            if len(encoded) > 0:
                encoded[self.target_col] = self.label_encoder.transform(encoded[self.target_col])
        if self.is_regression and self.target_col in encoded.columns:
            encoded[self.target_col] = pd.to_numeric(encoded[self.target_col], errors='coerce')

        for col in self.columns:
            if col in encoded.columns:
                if col in self.feature_encoders:
                    le = self.feature_encoders[col]
                    valid = encoded[col].astype(str).isin(le.classes_)
                    encoded = encoded[valid]
                    if len(encoded) > 0:
                        encoded[col] = le.transform(encoded[col].astype(str)).astype('float64')
                else:
                    encoded[col] = pd.to_numeric(encoded[col], errors='coerce').astype('float64')
        return encoded.dropna().reset_index(drop=True)

    def _compute_reward(self, batch: pd.DataFrame, splits: List = None, M: int = 5, top_ratio: float = 0.2) -> Tuple[float, float]:
        if self.is_regression:
            return self._compute_reward_regression(batch, splits, M, top_ratio)
        return self._compute_reward_classification(batch, splits, M, top_ratio)

    def _compute_reward_classification(self, batch: pd.DataFrame, splits: List = None, M: int = 5, top_entropy_ratio: float = 0.2) -> Tuple[float, float]:
        if splits is None:
            n_sup = int(len(self.D_real) * 0.8)
            splits = [(np.random.permutation(len(self.D_real))[:n_sup], np.random.permutation(len(self.D_real))[n_sup:]) for _ in range(M)]

        encoded = self._encode_data(batch)
        if len(encoded) == 0:
            return 0.0, 0.0

        X_batch, y_batch = encoded[self.columns].values, encoded[self.target_col].values
        X_real, y_real = self.D_real[self.columns].values, self.D_real[self.target_col].values

        delta_Us = []
        for sup_idx, qry_idx in splits:
            try:
                self.tabpfn.fit(X_real[sup_idx], y_real[sup_idx])
                probs_0 = self.tabpfn.predict_proba(X_real[qry_idx])
                entropy = -np.sum(probs_0 * np.log(probs_0.clip(1e-10, 1)), axis=1)
                hi_idx = np.argsort(entropy)[-max(1, int(len(entropy) * top_entropy_ratio)):]
                U0 = np.mean(np.log(probs_0[hi_idx, y_real[qry_idx][hi_idx].astype(int)].clip(1e-10, 1)))

                self.tabpfn.fit(np.vstack([X_real[sup_idx], X_batch]), np.concatenate([y_real[sup_idx], y_batch]))
                probs_B = self.tabpfn.predict_proba(X_real[qry_idx])
                U_B = np.mean(np.log(probs_B[hi_idx, y_real[qry_idx][hi_idx].astype(int)].clip(1e-10, 1)))
                delta_Us.append(U_B - U0)
            except:
                delta_Us.append(0.0)

        self._fit_tabpfn()
        return float(np.mean(delta_Us)), float(np.mean(delta_Us))

    def _compute_reward_regression(self, batch: pd.DataFrame, splits: List = None, M: int = 5, top_var_ratio: float = 0.2) -> Tuple[float, float]:
        if splits is None:
            n_sup = int(len(self.D_real) * 0.8)
            splits = [(np.random.permutation(len(self.D_real))[:n_sup], np.random.permutation(len(self.D_real))[n_sup:]) for _ in range(M)]

        encoded = self._encode_data(batch)
        if len(encoded) == 0:
            return 0.0, 0.0

        X_batch, y_batch = encoded[self.columns].values, encoded[self.target_col].values.astype(float)
        X_real, y_real = self.D_real[self.columns].values, self.D_real[self.target_col].values.astype(float)

        delta_Us = []
        for sup_idx, qry_idx in splits:
            try:
                self.tabpfn.fit(X_real[sup_idx], y_real[sup_idx])
                preds_0 = self.tabpfn.predict(X_real[qry_idx])
                residuals_0 = np.abs(y_real[qry_idx] - preds_0)
                hi_idx = np.argsort(residuals_0)[-max(1, int(len(residuals_0) * top_var_ratio)):]
                U0 = -np.mean((y_real[qry_idx][hi_idx] - preds_0[hi_idx]) ** 2)

                self.tabpfn.fit(np.vstack([X_real[sup_idx], X_batch]), np.concatenate([y_real[sup_idx], y_batch]))
                preds_B = self.tabpfn.predict(X_real[qry_idx])
                U_B = -np.mean((y_real[qry_idx][hi_idx] - preds_B[hi_idx]) ** 2)
                delta_Us.append(U_B - U0)
            except:
                delta_Us.append(0.0)

        self._fit_tabpfn()
        return float(np.mean(delta_Us)), float(np.mean(delta_Us))

    def _update_diversity(self):
        if len(self.synthetic_buffer) == 0:
            self.diversity_score = 0.0
            return
        encoded = self._encode_data(self.synthetic_buffer)
        if len(encoded) == 0:
            self.diversity_score = 0.0
            return
        X_syn = encoded[self.columns].values
        X_real = self.D_real[self.columns].values
        nn = NearestNeighbors(n_neighbors=1, algorithm='ball_tree')
        nn.fit(X_real)
        dist, _ = nn.kneighbors(X_syn)
        self.diversity_score = float(dist.mean() / np.sqrt(X_syn.shape[1]))

    def commit_top_proposals(self, top_l: int = 20) -> int:
        if len(self.proposals) == 0:
            return 0
        sorted_props = sorted(self.proposals, key=lambda x: x["ig"], reverse=True)[:top_l]
        all_batches = [p["batch"] for p in sorted_props if len(p["batch"]) > 0]
        if len(all_batches) == 0:
            self.proposals = []
            return 0

        merged = pd.concat(all_batches, ignore_index=True)
        if len(merged) > 1:
            encoded = self._encode_data(merged)
            if len(encoded) > 1:
                X = encoded[self.columns].values
                nn = NearestNeighbors(n_neighbors=min(2, len(X)))
                nn.fit(X)
                dist, _ = nn.kneighbors(X)
                if dist.shape[1] > 1:
                    keep = dist[:, 1] / np.sqrt(X.shape[1]) > 0.01
                    merged = merged.iloc[keep].reset_index(drop=True)

        committed = len(merged)
        if committed > 0:
            self.synthetic_buffer = pd.concat([self.synthetic_buffer, merged], ignore_index=True)
            self._update_diversity()
        self.proposals = []
        return committed

    def get_synthetic_data(self) -> pd.DataFrame:
        return self.synthetic_buffer.copy()

    def decode_data(self, df: pd.DataFrame) -> pd.DataFrame:
        if df is None or len(df) == 0:
            return df
        decoded = df.copy()
        if not self.is_regression and self.label_encoder is not None and self.target_col in decoded.columns:
            try:
                labels = decoded[self.target_col].values.astype(int)
                valid = (labels >= 0) & (labels < len(self.label_encoder.classes_))
                decoded = decoded[valid].copy()
                if len(decoded) > 0:
                    decoded[self.target_col] = self.label_encoder.inverse_transform(decoded[self.target_col].values.astype(int))
            except:
                pass
        for col, encoder in self.feature_encoders.items():
            if col in decoded.columns:
                try:
                    values = np.clip(decoded[col].values.astype(int), 0, len(encoder.classes_) - 1)
                    decoded[col] = encoder.inverse_transform(values)
                except:
                    pass
        return decoded.reset_index(drop=True)
