import copy
import enlighten
import numpy as np
import scipy as sp
from loguru import logger
from abc import ABC, abstractmethod
from argparse import ArgumentParser
from typing import Dict, Callable, Optional, List, Tuple
import matplotlib.pyplot as plt

from src.data_augmentors.simulation.linear import NullSpaceTranslation as DA
from src.sem.simulation.linear import LinearSimulationSEM as SEM

from src.methods.abstract import pointEstimator as Regressor
from src.methods.regression import LeastSquaresClosedForm as ERM
from src.methods.sensitivity_models import (
    PartialR2,
    InvarianceConstrainedPartialR2 as invPartialR2,
)

from src.experiments.utils import (
    save,
    set_seed,
    tex_table,
    fit_model,
    sweep_plot,
    ANNOTATE_SWEEP_PLOT,
    radial_sweep_pcs,
    sweep_along_pc,
    # panel helpers / constants
    RC_PARAMS,
    TEX_MAPPER,
    FS_LABEL,
    FS_TICK,
    color_map,
    PLOT_FORMAT,
    PLOT_DPI,
)

# Type alias for the estimator factories stored in the method tables below.
# NOTE(review): the factories in this file are invoked with no arguments,
# while the alias declares one optional float parameter -- confirm intent.
ModelBuilder = Callable[[Optional[float]], Regressor]

# Shared enlighten manager used for every progress counter / status bar here.
MANAGER = enlighten.get_manager()
# Experiment identifier forwarded to `save` and `sweep_plot`.
EXPERIMENT: str = 'linear_simulation'
# Not referenced anywhere in the visible part of this file.
DEFAULT_QUERY_JOBS: int = 1
# Passed as `epsilon=` to the invariance-constrained partial-R2 model.
EPSILON: float = 2 ** 0
# Passed as `gamma0=` / `gamma=` to both partial-R2 sensitivity models.
GAMMA0: float = 2 ** 9
GAMMA: float = 2 ** 9
# Passed as `kappa=` when sampling from the SEM.
KAPPA: float = 2 ** 2.5


class SweepExperiment:
    """Evaluate a set of methods on a radial query sweep over repeated datasets.

    One SEM and one data augmentation are created per run.  A first dataset is
    drawn only to derive the query points (via `radial_sweep_pcs`); then, for
    each of `n_experiments` repetitions, a fresh dataset is sampled and every
    method is fitted and queried at every sweep point.
    """

    def __init__(
            self,
            seed: int,
            n_samples: int,
            kernel_dim: int,
            n_experiments: int,
            sweep_samples: int,
            methods: Dict[str, Callable[..., Regressor]],
            hyperparameters: Optional[Dict[str, Dict[str, float]]]=None
        ):
        """Store the experiment configuration.

        Args:
            seed: Seed passed to `set_seed`; negative values skip seeding.
            n_samples: Number of samples drawn from the SEM per dataset.
            kernel_dim: Forwarded to the data augmentation constructor.
            n_experiments: Number of independently sampled datasets.
            sweep_samples: Number of query points requested from the sweep.
            methods: Mapping from method name to a zero-argument model factory.
                The name ``'ATE'`` is special-cased and its factory is never
                called.
            hyperparameters: Optional per-method settings forwarded to
                `fit_model`.
        """
        self.seed = seed
        self.n_samples = n_samples
        self.kernel_dim = kernel_dim
        self.n_experiments = n_experiments
        self.sweep_samples = sweep_samples
        self.methods = methods
        self.hyperparameters = hyperparameters

    @staticmethod
    def fit(
            method_name: str,
            method: Callable[..., Regressor],
            X, y, GX,
            param: float=10.0,
            da: Optional[DA]=None,
            hyperparameters: Optional[Dict[str, Dict[str, float]]]=None
        ) -> Regressor:
        """Build a model with `method()` and fit it through `fit_model`.

        Args:
            method_name: Name forwarded to `fit_model` as ``name``.
            method: Zero-argument factory returning an unfitted model.
            X, y, GX: Data forwarded unchanged to `fit_model`.
            param: Unused; kept so existing callers keep working.
            da: Optional data augmentation forwarded to `fit_model`.
            hyperparameters: Optional settings forwarded to `fit_model`.

        Returns:
            The model instance that was handed to `fit_model`.
        """
        model = method()
        fit_model(
            model=model,
            name=method_name,
            X=X, y=y, GX=GX,
            hyperparameters=hyperparameters,
            da=da
        )
        return model

    def generate_dataset(self, sem: SEM, da: DA, param: float=KAPPA):
        """Sample `n_samples` points from `sem` and augment the treatments.

        Returns:
            Tuple ``(X, y, GX)`` where ``X, y`` come from ``sem(N=..., kappa=param)``
            and ``GX`` is the first element returned by ``da(X)``.
        """
        X, y = sem(N=self.n_samples, kappa=param)
        GX, _ = da(X)
        return X, y, GX

    def query_sweep(self, X):
        """Return the query points of the radial sweep computed from `X`."""
        # Sweep over radial angle in PC1-PC2 plane
        return radial_sweep_pcs(X, self.sweep_samples)

    def compute_result(self,
               method_name: str,
               method: Callable[..., Regressor],
               X, y, GX,
               query: np.ndarray,
               da: Optional[DA]=None):
        """Fit `method` on the given data and predict at `query`.

        Returns:
            Whatever ``model.predict(query)`` returns, unmodified.
        """
        model = self.fit(
            method_name, method, X, y, GX, da=da,
            hyperparameters=self.hyperparameters
        )
        return model.predict(query)

    def run_experiment(self):
        """Run the full sweep.

        Returns:
            Tuple ``(angles, results)``.  ``angles`` is
            ``np.linspace(0, 2*pi, len(query_values))``.  ``results`` maps each
            method name to an array of shape ``(sweep_samples, n_experiments)``,
            or ``(sweep_samples, n_experiments, 2)`` when the name contains
            ``'PI'``.
        """
        if self.seed >= 0:
            set_seed(self.seed)

        sem = SEM()
        da = DA(sem.W_XY, kernel_dim=self.kernel_dim)

        # This first dataset only serves to derive the query points.
        X, y, GX = self.generate_dataset(sem, da)
        query_values = self.query_sweep(X)

        bounds_dim = (self.sweep_samples, self.n_experiments, 2)
        results = {
            name: (np.zeros(bounds_dim[:-1]) if 'PI' not in name else np.zeros(bounds_dim))
            for name in self.methods
        }

        experiment_name = self.__class__.__name__
        pbar_experiment = MANAGER.counter(
            total=self.n_experiments, desc=f'{experiment_name}', unit='SEMs',
        )
        for j in range(self.n_experiments):
            sem_solution = sem.solution
            X, y, GX = self.generate_dataset(sem, da)

            pbar_sem = MANAGER.counter(
                total=self.sweep_samples, desc=f'SEM {j}', unit='queries', leave=False
            )
            # Each query is presented as a single-row 2-D array.
            for i, query in enumerate(query_values[:, np.newaxis, :]):
                pbar_methods = MANAGER.counter(
                    total=len(self.methods), desc=f'Query. {i}', unit='methods', leave=False
                )
                for method_name, method in self.methods.items():
                    if method_name == 'ATE':
                        value = query @ sem_solution
                    else:
                        value = self.compute_result(
                            method_name, method, X, y, GX, query, da=da
                        )
                    # Drop singleton axes before storing: writing a size-1
                    # array with ndim > 0 into a scalar slot is deprecated
                    # since NumPy 1.25.  Same normalisation as make_panel_4x3.
                    results[method_name][i, j] = np.asarray(value).squeeze()
                    pbar_methods.update()
                pbar_methods.close()
                pbar_sem.update()
            pbar_sem.close()
            pbar_experiment.update()
        pbar_experiment.close()
        # NOTE(review): make_panel_4x3 builds its radial grid with
        # endpoint=False; confirm which convention matches radial_sweep_pcs.
        return np.linspace(0, 2*np.pi, len(query_values)), results


def make_panel_4x3(
        seed: int = 42,
        n_samples: int = 2_500,
        kernel_dim: int = 3,
        n_experiments: int = 10,
        sweep_samples: int = 21,
        methods: List[str] = ('ATE','ERM','DA+ERM','PI','DA+PI','INV+PI'),
        hyperparameters: Optional[Dict[str, Dict[str, float]]] = None,
        experiment: str = EXPERIMENT,
        legend_ncols: int = 2,           # tweak legend layout
    ):
    """Run three query sweeps and save them as a 4x3 figure panel.

    Columns (left to right): sweep along principal direction 1, radial sweep
    in the PC1-PC2 plane, sweep along principal direction 2.
    Rows (top to bottom): worst-case squared error of the interval methods,
    mean interval width, histograms of original vs. augmented data projected
    on the swept direction (the middle column holds the legend instead), and
    the mean predictions / interval bands.

    Args:
        seed: Seed passed to `set_seed`; negative values skip seeding.
        n_samples: Number of samples drawn from the SEM per dataset.
        kernel_dim: Forwarded to the data augmentation constructor.
        n_experiments: Number of freshly sampled datasets per sweep.
        sweep_samples: Number of query points per sweep.
        methods: Method names to include; names that are not supported are
            ignored.
        hyperparameters: Optional settings forwarded to `fit_model`.
        experiment: Experiment identifier forwarded to `save`.
        legend_ncols: Number of columns in the legend.

    Returns:
        None.  The figure is written through `save` with name ``'panel_4x3'``.
    """
    if seed >= 0:
        set_seed(seed)
    plt.rcParams.update(RC_PARAMS)

    # --- SEM / DA baseline ---
    sem = SEM()
    da = DA(sem.W_XY, kernel_dim=kernel_dim)
    X0, _ = sem(N=n_samples, kappa=KAPPA)
    GX0, _ = da(X0)

    # --- Sweeps (PC1, PC2 with std_range=3.0; radial: angles) ---
    pc1_pts, t1, mean1, pc1_vec = sweep_along_pc(X0, pc_index=0, n_steps=sweep_samples, std_range=3.0)
    pc2_pts, t2, mean2, pc2_vec = sweep_along_pc(X0, pc_index=1, n_steps=sweep_samples, std_range=3.0)
    radial_pts = radial_sweep_pcs(X0, n_points=sweep_samples)
    theta = np.linspace(0, 2*np.pi, sweep_samples, endpoint=False)

    # --- Methods and builders ---
    all_methods: Dict[str, ModelBuilder] = {
        'ATE': lambda: None,
        'ERM': lambda: ERM(),
        'DA+ERM': lambda: ERM(),
        'PI': lambda: PartialR2(gamma=GAMMA, gamma0=GAMMA0),
        'DA+PI': lambda: PartialR2(gamma=GAMMA, gamma0=GAMMA0),
        'INV+PI': lambda: invPartialR2(gamma=GAMMA, gamma0=GAMMA0, epsilon=EPSILON),
    }
    methods = [m for m in methods if m in all_methods]
    builders: Dict[str, ModelBuilder] = {m: all_methods[m] for m in methods}

    # --- Run a sweep for a set of query points ---
    def run_sweep(query_points: np.ndarray) -> Tuple[Dict[str, np.ndarray], np.ndarray]:
        """Return per-method results and the ground truth for `query_points`.

        Results have shape (S, E) for point estimators and (S, E, 2) for
        methods whose name contains 'PI'; the ground truth has shape (S,).
        """
        S, E = query_points.shape[0], n_experiments
        res: Dict[str, np.ndarray] = {name: (np.zeros((S, E, 2)) if 'PI' in name else np.zeros((S, E)))
                                      for name in builders}
        gt = np.zeros(S)
        for j in range(E):
            Xj, yj = sem(N=n_samples, kappa=KAPPA)
            GXj, _ = da(Xj)
            for i in range(S):
                q = query_points[i][np.newaxis, :]
                gt[i] = (q @ sem.solution).item()
                for name, build in builders.items():
                    if name == 'ATE':
                        res[name][i, j] = (q @ sem.solution).item()
                        continue
                    # NOTE(review): the model is re-fitted for every query on
                    # the same (Xj, yj, GXj).  Hoisting the fit out of the
                    # query loop would be cheaper but may change results if
                    # fit_model consumes random numbers -- confirm first.
                    model = build()
                    fit_model(model=model, name=name, X=Xj, y=yj, GX=GXj,
                              hyperparameters=hyperparameters, da=da)
                    out = np.asarray(model.predict(q)).squeeze()
                    if 'PI' in name:
                        res[name][i, j, 0] = out[0]
                        res[name][i, j, 1] = out[1]
                    else:
                        res[name][i, j] = float(out)
        return res, gt

    res1, gt1 = run_sweep(pc1_pts)
    res2, gt2 = run_sweep(pc2_pts)
    resR, gtR = run_sweep(radial_pts)

    # Column order: PC1, radial, PC2 (the sweeps themselves ran PC1, PC2, radial).
    xs = [t1, theta, t2]
    packs = [(res1, gt1), (resR, gtR), (res2, gt2)]

    # --- Helpers for aggregations ---
    def mean_band(m):  # PI -> (S,) low_mean, (S,) high_mean
        return m[:, :, 0].mean(axis=1), m[:, :, 1].mean(axis=1)

    def mean_line(m):  # point -> (S,)
        return m.mean(axis=1)

    def width(m):      # PI width (point estimators -> zeros)
        return (m[:, :, 1] - m[:, :, 0]).mean(axis=1) if m.ndim == 3 else np.zeros(m.shape[0])

    def e_worst(m: np.ndarray, gt: np.ndarray, reduce: str = 'max') -> np.ndarray:
        """
        Worst-case squared error per sweep point i.

        For interval results the larger of the two endpoint errors is used.
        `reduce` is 'max' or a float quantile in (0, 1) taken over experiments.
        """
        gt = gt[:, None]  # (S,1)

        if m.ndim == 2:
            se = (m - gt) ** 2                          # (S,E)
        elif m.ndim == 3 and m.shape[-1] == 2:
            lo, hi = m[:, :, 0], m[:, :, 1]
            se = np.maximum((lo - gt) ** 2, (hi - gt) ** 2)  # (S,E)
        else:
            raise ValueError("Unexpected shape for m.")

        if reduce == 'max':
            return se.max(axis=1)                       # (S,)
        elif isinstance(reduce, float) and 0.0 < reduce < 1.0:
            return np.quantile(se, q=reduce, axis=1)    # (S,)
        else:
            raise ValueError("reduce should be 'max' or a quantile in (0,1).")

    # --- Histogram projections (computed on full data; view is clipped by shared x) ---
    h_pc1_X = (X0 - mean1) @ pc1_vec
    h_pc1_G = (GX0 - mean1) @ pc1_vec
    h_pc2_X = (X0 - mean2) @ pc2_vec
    h_pc2_G = (GX0 - mean2) @ pc2_vec

    # --- Figure: 4 rows x 3 columns; the bottom row is taller (0.7 vs 0.2) ---
    fig, axes = plt.subplots(
        4, 3, figsize=(15, 8),
        sharex='col',
        gridspec_kw={'height_ratios': [0.2, 0.2, 0.2, 0.7]},
        constrained_layout=True
    )

    # --- Colors ---
    palette = plt.rcParams['axes.prop_cycle'].by_key().get('color', ['C0','C1','C2','C3','C4','C5'])

    def mcolor(name):
        return palette[color_map.get(name, 0) % len(palette)]

    def band_alpha(name):
        # The invariance-constrained method is drawn slightly more opaque.
        return 0.3 if 'INV' in name else 0.2

    orig_color = palette[color_map.get('ERM', 0) % len(palette)]      # histogram: original data
    aug_color  = palette[color_map.get('DA+ERM', 3) % len(palette)]   # histogram: augmented data

    # --- Legend holder (will be placed at [2,1]) ---
    legend_handles: Dict[str, object] = {}

    # Titles and labels follow the PC1 / radial / PC2 column order.
    column_titles = [
        r'principal direction 1' + '\n' + r'${\bm{x}} := t\cdot {\bm{u}}_1$',
        r'radial sweep' + '\n' + r'${\bm{x}} := {\sigma}_1\sin(\theta){\bm{u}}_1 + {\sigma}_2 \cos(\theta){\bm{u}}_2$',
        r'principal direction 2' + '\n' + r'${\bm{x}} := t\cdot {\bm{u}}_2$',
    ]
    bottom_ylabs = [
        r'${\bm{h}}^\top {\bm{x}}$',
    ]
    sweep_labels = ['','','']

    for col in range(3):
        res, gt = packs[col]
        xgrid = xs[col]

        # ---- Row 3 (bottom): predictions ----
        ax_pred = axes[3, col]
        for name in builders:
            label = TEX_MAPPER.get(name, name)
            if 'PI' in name:
                low, high = mean_band(res[name])
                h = ax_pred.fill_between(xgrid, low, high, alpha=band_alpha(name),
                                         edgecolor='none', facecolor=mcolor(name))
            else:
                y = mean_line(res[name])
                if name == 'ATE':
                    h = ax_pred.plot(xgrid, y, linestyle='--', linewidth=2, color='black')[0]
                else:
                    h = ax_pred.plot(xgrid, y, linewidth=2, color=mcolor(name))[0]
            # Keep the first artist per label so the legend has one entry each.
            if label not in legend_handles:
                legend_handles[label] = h

        # The radial column (index 1) is parameterised by theta, the others by t.
        ax_pred.set_xlabel(r'$\theta$' if col == 1 else r'$t$', fontsize=FS_LABEL)

        if col == 0:
            ax_pred.set_ylabel(bottom_ylabs[col], fontsize=FS_LABEL)
        ax_pred.tick_params(labelsize=FS_TICK)
        ax_pred.set_xlim([xgrid.min(), xgrid.max()])

        # titles under bottom plots
        ax_pred.text(0.5, -0.25, sweep_labels[col], transform=ax_pred.transAxes,
                     ha='center', va='top', fontsize=FS_LABEL)

        # ---- Row 1: WIDTH (PI only non-zero) ----
        ax_w = axes[1, col]
        plotted_any = False
        for name in builders:
            if 'PI' not in name:
                continue
            plotted_any = True
            w = width(res[name])
            ax_w.fill_between(xgrid, 0.0, w, alpha=band_alpha(name),
                              edgecolor='none', facecolor=mcolor(name))
            ax_w.plot(xgrid, w, linewidth=1.5, color=mcolor(name))
        if not plotted_any:
            ax_w.fill_between(xgrid, 0.0, 0.0, alpha=0.1, edgecolor='none', facecolor='0.8')
        if col == 0:
            ax_w.set_ylabel('width', fontsize=FS_LABEL)
        ax_w.tick_params(labelsize=FS_TICK)
        ax_w.set_ylim(0, None)
        ax_w.margins(y=0)

        # ---- Row 0 (top): E_worst (PI, DA+PI and INV+PI) ----
        ax_ew = axes[0, col]
        methods_ew = ('PI', 'DA+PI', 'INV+PI')
        plotted_any = False
        for name in methods_ew:
            if name not in builders or name not in res:
                continue
            plotted_any = True
            ew = e_worst(res[name], gt, reduce='max')
            ax_ew.fill_between(xgrid, 0.0, ew, alpha=band_alpha(name),
                               edgecolor='none', facecolor=mcolor(name))
            ax_ew.plot(xgrid, ew, linewidth=1.5, color=mcolor(name))

        if not plotted_any:
            ax_ew.fill_between(xgrid, 0.0, 0.0, alpha=0.1, edgecolor='none', facecolor='0.8')

        if col == 0:
            ax_ew.set_ylabel(r'$E_{\mathrm{worst}}^{\operatorname{do}({\bm{x}})}$', fontsize=FS_LABEL)
        ax_ew.tick_params(labelsize=FS_TICK)
        ax_ew.set_ylim(0, None)
        ax_ew.margins(y=0)

        # Add column title at top
        ax_ew.set_title(column_titles[col], fontsize=FS_LABEL, pad=8)

        # ---- Row 2: DENSITY (PC columns only) ----
        # Col 0 is PC1, col 2 is PC2, col 1 is left empty for the legend.
        ax_hist = axes[2, col]
        if col == 0:
            ax_hist.hist(h_pc1_X, bins=40, density=True, alpha=0.45, color=orig_color)
            ax_hist.hist(h_pc1_G, bins=40, density=True, alpha=0.45, color=aug_color)
            ax_hist.set_ylabel('density', fontsize=FS_LABEL)
        elif col == 2:
            ax_hist.hist(h_pc2_X, bins=40, density=True, alpha=0.45, color=orig_color)
            ax_hist.hist(h_pc2_G, bins=40, density=True, alpha=0.45, color=aug_color)
        else:
            # This is col 1 (Radial), kept empty for Legend
            ax_hist.axis('off')
        ax_hist.tick_params(labelsize=FS_TICK)
        ax_hist.set_ylim(0, None)
        ax_hist.margins(y=0)

    # ---- Single legend in position [2,1] (Radial column) ----
    ax_legend = axes[2, 1]

    # NOTE(review): this axis was switched off with axis('off') above, which
    # presumably also suppresses drawing of its spines, so these calls likely
    # have no visible effect -- confirm whether a border is wanted.
    ax_legend.spines['top'].set_visible(True)
    ax_legend.spines['right'].set_visible(True)
    ax_legend.spines['bottom'].set_visible(True)
    ax_legend.spines['left'].set_visible(True)

    label_order = [TEX_MAPPER.get(name, name) for name in builders.keys()]
    handles = [legend_handles[lbl] for lbl in label_order if lbl in legend_handles]

    leg = ax_legend.legend(
        handles=handles,
        labels=label_order,
        loc='center',
        ncol=legend_ncols,
        fontsize=FS_TICK + 2,
        frameon=False,  # No legend frame since the axis provides the border
        borderpad=0.3475,
        labelspacing=0.5,
        handlelength=2.5,
        handletextpad=0.6,
        columnspacing=1.1,
        bbox_to_anchor=(0.5, 0.5),
        bbox_transform=ax_legend.transAxes,
    )

    # Matplotlib 3.7 renamed `Legend.legendHandles` to `legend_handles` and
    # 3.9 removed the old name; support both so the panel works on either.
    legend_artists = getattr(leg, 'legend_handles', None)
    if legend_artists is None:
        legend_artists = leg.legendHandles
    for h in legend_artists:
        # Best effort: not every legend artist necessarily accepts a linewidth.
        try:
            h.set_linewidth(2.0)
        except Exception:
            pass

    fig.align_ylabels(axes[:, 0])    # aligns only the left-column y-labels

    save(fig, fname='panel_4x3', experiment=experiment, format=PLOT_FORMAT, dpi=PLOT_DPI)



def run(
        seed: int,
        n_samples: int,
        kernel_dim: int,
        n_experiments: int,
        sweep_samples: int,
        methods: List[str],
        augmentation: Optional[List[str]]=None,
        hyperparameters: Optional[Dict[str, Dict[str, float]]]=None,
        plot_panel: bool=False,
        panel_only: bool=False,
        **kwargs,
    ):
    """Entry point of the linear simulation experiment.

    Args:
        seed: Seed for the experiment; negative values skip seeding.
        n_samples: Number of samples per dataset.
        kernel_dim: Forwarded to the data augmentation constructor.
        n_experiments: Number of repetitions.
        sweep_samples: Number of query points per sweep.
        methods: Requested method names.  Unsupported names are reported with
            a warning and ignored.
        augmentation: Unused in this function.
        hyperparameters: Optional settings forwarded to the experiments.
        plot_panel: If true, also build the 4x3 panel after the sweep.
        panel_only: If true, build only the 4x3 panel and return.
        **kwargs: Ignored; lets callers pass extra options.
    """
    all_methods: Dict[str, ModelBuilder] = {
        'ATE': lambda: None,
        'ERM': lambda: ERM(),
        'DA+ERM': lambda: ERM(),
        'PI': lambda: PartialR2(gamma=GAMMA, gamma0=GAMMA0),
        'DA+PI': lambda: PartialR2(gamma=GAMMA, gamma0=GAMMA0),
        'INV+PI': lambda: invPartialR2(gamma=GAMMA, gamma0=GAMMA0, epsilon=EPSILON),
    }

    # Unsupported names used to be dropped silently; make that visible.
    unsupported = [m for m in methods if m not in all_methods]
    if unsupported:
        logger.warning(
            'Ignoring unsupported methods: ' + ', '.join(map(str, unsupported))
            + '. Supported methods: ' + ', '.join(all_methods) + '.'
        )

    # Panel-only mode
    if panel_only:
        make_panel_4x3(
            seed=seed,
            n_samples=n_samples,
            kernel_dim=kernel_dim,
            n_experiments=n_experiments,
            sweep_samples=sweep_samples,
            methods=methods,
            hyperparameters=hyperparameters,
            experiment=EXPERIMENT
        )
        return

    status = MANAGER.status_bar(
        status_format=u'Linear simulation{fill}Sweeping {sweep}{fill}{elapsed}',
        color='bold_underline_bright_white_on_lightslategray',
        justify=enlighten.Justify.CENTER, sweep='<parameter>',
        autorefresh=True, min_delta=0.5
    )

    selected: Dict[str, ModelBuilder] = {m: all_methods[m] for m in methods if m in all_methods}
    # Allow-list of methods taking part in the treatment-query sweep.
    sweepable = ('ERM', 'DA+ERM', 'ATE', 'PI', 'DA+PI', 'INV+PI')
    sweep_methods: Dict[str, ModelBuilder] = {
        m: builder for m, builder in selected.items() if m in sweepable
    }

    # sweep over treatment queries
    status.update(sweep='treatment query')
    logger.info('Sweeping over treatment queries.')
    treatment_values, outcome_values = SweepExperiment(
        seed=seed,
        n_samples=n_samples,
        kernel_dim=kernel_dim,
        n_experiments=n_experiments,
        methods=sweep_methods,
        sweep_samples=sweep_samples,
        hyperparameters=hyperparameters
    ).run_experiment()
    save(obj=treatment_values, fname='treatment_values', experiment=EXPERIMENT, format='pkl')
    save(obj=outcome_values, fname='outcome_values', experiment=EXPERIMENT, format='pkl')
    sweep_plot(treatment_values, outcome_values, **ANNOTATE_SWEEP_PLOT['pc12'], experiment=EXPERIMENT)

    # Optionally build the panel after the sweep.
    if plot_panel:
        make_panel_4x3(
            seed=seed,
            n_samples=n_samples,
            kernel_dim=kernel_dim,
            n_experiments=n_experiments,
            sweep_samples=sweep_samples,
            methods=list(selected),
            hyperparameters=hyperparameters,
            experiment=EXPERIMENT
        )


if __name__ == '__main__':
    # Command-line entry point: every parsed option is forwarded to `run`.
    CLI = ArgumentParser(description='Linear simulation experiment.')
    CLI.add_argument('--seed', type=int, default=42, help='Random seed for the experiment. Negative is random.')
    CLI.add_argument('--n_samples', type=int, default=2_500, help='Number of samples per experiment.')
    # `run` requires `kernel_dim`; without this option the call below fails
    # with a TypeError for the missing argument.
    CLI.add_argument('--kernel_dim', type=int, default=3, help='Kernel dimension of the data augmentation.')
    CLI.add_argument('--n_experiments', type=int, default=10, help='Number of experiments.')
    CLI.add_argument('--sweep_samples', type=int, default=10, help='Number of query points per sweep.')
    # The previous default also listed `DA+IVL-CV` and `DA+IV`, which this
    # experiment does not support and which were dropped before running.
    CLI.add_argument(
        '--methods', nargs="*", type=str,
        default=['ERM', 'DA+ERM'],
        help='Methods to use. Specify in space-separated format -- `ERM DA+ERM`. '
             'Supported: ATE ERM DA+ERM PI DA+PI INV+PI.'
    )
    CLI.add_argument('--plot-panel', action='store_true', help='After running the sweep, also generate the panel.')
    CLI.add_argument('--panel-only', action='store_true', help='Only generate the panel (skip sweep_plot).')
    args = CLI.parse_args()
    run(**vars(args))
