import itertools

import torch as T

import src.metric.divergences as dvg
from src.scm.ncm.feedforward_ncm import FF_NCM

from .base_pipeline import BasePipeline


def _is_light_color(name):
    return name.startswith("L") and name.endswith("_C")


def _is_sprite_color(name):
    return name.startswith("S") and name.endswith("_C")


def _is_sprite_shape(name):
    return name.startswith("S") and name.endswith("_S")


def _sprite_base(name):
    return name.split("_")[0]


class SpritesRelationalPipeline(BasePipeline):
    """
    Two-stage training pipeline for an NCM over light and sprite variables.

    Two-stage training:
      1) Train light color modules on light-only data.
      2) Train sprite modules on (agg_light, sprite_color, sprite_shape) tables.

    Stage 1 runs while ``current_epoch < light_train_epochs``; every later
    epoch runs stage 2. Each stage has its own optimizer (see
    ``configure_optimizers``) and ``training_step`` drives the optimizers
    manually.

    Variable naming convention used throughout: light colors are ``L*_C``,
    sprite colors are ``S*_C`` and sprite shapes are ``S*_S``; the text before
    the first underscore is the sprite "base" name.
    """

    # Not read anywhere in this class; presumably consumed by the surrounding
    # training harness (e.g. early stopping) -- TODO confirm.
    patience = 50

    def __init__(
        self,
        generator,
        do_var_list,
        dat_sets,
        cg,
        dim,
        hyperparams=None,
        ncm_model=FF_NCM,
        ncm=None,
        sprite_to_lights=None,
    ):
        """Build (or adopt) the NCM and read the training hyperparameters.

        Args:
            generator, do_var_list, dat_sets, cg, dim: forwarded unchanged to
                ``BasePipeline.__init__``. ``cg`` must expose ``v`` (the
                variables) and ``pa`` (a parent mapping); ``dim`` is the size
                given to every variable other than ``"X"`` and ``"Y"``.
            hyperparams: optional dict. Keys read here, with their defaults:
                ``"u-size"`` (1), ``"data-bs"`` (1000), ``"ncm-bs"`` (1000),
                ``"lr"`` (1e-3), ``"light-epochs"`` (50). The whole dict is
                also handed to ``ncm_model``.
            ncm_model: factory used to build the NCM when ``ncm`` is None.
            ncm: an already constructed NCM to use instead of building one.
            sprite_to_lights: optional mapping from a sprite base name to the
                list of light variables that influence it. When missing or
                empty, it is derived from the graph.
        """
        if hyperparams is None:
            hyperparams = {}

        # "X" and "Y" are treated as scalars; every other variable has `dim`
        # components.
        v_size = {v: dim if v not in ("X", "Y") else 1 for v in cg.v}
        # Explicit None check: a supplied model must never be replaced just
        # because it happens to evaluate as falsy (e.g. defines __len__).
        if ncm is None:
            ncm = ncm_model(
                cg,
                v_size=v_size,
                default_u_size=hyperparams.get("u-size", 1),
                hyperparams=hyperparams,
            )
        super().__init__(
            generator,
            do_var_list,
            dat_sets,
            cg,
            dim,
            ncm,
            batch_size=hyperparams.get("data-bs", 1000),
        )

        self.ncm_batch_size = hyperparams.get("ncm-bs", 1000)
        self.lr = hyperparams.get("lr", 1e-3)
        self.light_train_epochs = hyperparams.get("light-epochs", 50)
        self.ordered_v = cg.v
        # NOTE(review): relies on BasePipeline.__init__ having stored the
        # graph as `self.cg` -- confirm against the base class.
        self.sprite_to_lights = sprite_to_lights or self._default_sprite_to_lights()

        # training_step calls zero_grad / manual_backward / step itself.
        self.automatic_optimization = False

    def _default_sprite_to_lights(self):
        """Derive ``{sprite base: [light vars]}`` from the causal graph.

        For every sprite-color variable in ``cg.v``, the light-color variables
        among its parents (``cg.pa``) are recorded under the sprite's base
        name.
        """
        sprite_to_lights = {}
        for v in self.cg.v:
            if not _is_sprite_color(v):
                continue
            sprite_to_lights[_sprite_base(v)] = [
                p for p in self.cg.pa.get(v, []) if _is_light_color(p)
            ]
        return sprite_to_lights

    def _light_vars(self):
        """Return the light-color variables of the graph, in ``cg.v`` order."""
        return [v for v in self.cg.v if _is_light_color(v)]

    def _sprite_bases(self, batch):
        """Return the sorted, de-duplicated sprite base names keyed in *batch*."""
        return sorted(
            {
                _sprite_base(v)
                for v in batch
                if _is_sprite_color(v) or _is_sprite_shape(v)
            }
        )

    def _aggregate_lights(self, batch, base):
        """Collapse all lights attached to sprite *base* into one tensor.

        The light tensors are cast to float, stacked along a new leading
        dimension and reduced with an element-wise mode across lights.

        Returns:
            The aggregated tensor, or ``None`` when no lights are registered
            for *base* (callers skip such sprites).
        """
        light_vars = self.sprite_to_lights.get(base, [])
        if not light_vars:
            # Must be None (not 0): training_step tests `is None` to skip the
            # sprite, and a bare int cannot be concatenated with tensors.
            return None
        light_stack = T.stack([batch[v].float() for v in light_vars], dim=0)
        return T.mode(light_stack, dim=0).values

    def _ncm_param_groups(self, vars_for_group):
        """Collect the unique parameters of the NCM modules for the given vars.

        Variables without an entry in ``self.ncm.f`` are ignored. Parameters
        are de-duplicated by object identity so a parameter shared between
        modules is handed to the optimizer only once.
        """
        params = []
        seen = set()
        for v in vars_for_group:
            if v not in self.ncm.f:
                continue
            for p in self.ncm.f[v].parameters():
                pid = id(p)
                if pid in seen:
                    continue
                seen.add(pid)
                params.append(p)
        return params

    def configure_optimizers(self):
        """Return ``[light_optimizer, sprite_optimizer]`` (both AdamW).

        Index 0 updates the light-color modules (stage 1); index 1 updates the
        sprite color/shape modules (stage 2). ``training_step`` depends on
        this ordering.
        """
        light_vars = self._light_vars()
        sprite_vars = [v for v in self.cg.v if _is_sprite_color(v) or _is_sprite_shape(v)]

        light_params = self._ncm_param_groups(light_vars)
        sprite_params = self._ncm_param_groups(sprite_vars)

        light_optim = T.optim.AdamW(light_params, lr=self.lr)
        sprite_optim = T.optim.AdamW(sprite_params, lr=self.lr)
        return [light_optim, sprite_optim]

    def _predict_sprite_var(self, var, agg_light, batch, u_sample):
        """Run the NCM module of *var* on observed parents and sampled noise.

        Every light-color parent of *var* is replaced by ``agg_light``; all
        other parents are read from *batch*. The exogenous inputs are the
        entries of ``u_sample`` listed for *var* in ``cg.v2c2``.
        """
        module = self.ncm.f[var]
        pa = {}
        for parent in self.cg.pa.get(var, []):
            if _is_light_color(parent):
                pa[parent] = agg_light
            else:
                pa[parent] = batch[parent]

        u_keys = list(self.cg.v2c2.get(var, []))
        u_for_var = {k: u_sample[k] for k in u_keys}
        return module(pa, u_for_var)

    def _mmd_update(self, opt, dat_mat, ncm_mat, stage):
        """Take one manual optimizer step on the MMD loss and log the result."""
        opt.zero_grad()
        loss = dvg.MMD_loss(dat_mat.float(), ncm_mat, gamma=1)
        self.manual_backward(loss)
        opt.step()

        self.log("train_loss", loss.item(), prog_bar=True)
        self.log("stage", stage, prog_bar=True)

    def training_step(self, batch, batch_idx):
        """Run one manual optimization step for the current training stage.

        ``batch[0]`` is used as the observational batch, a mapping from
        variable name to tensor. During the first ``light_train_epochs``
        epochs the light modules are fitted by matching NCM samples to the
        observed light columns; afterwards the sprite modules are fitted on
        (aggregated light, color, shape) rows stacked over all sprites.
        Returns ``None`` in every case.
        """
        obs_batch = batch[0]
        if self.current_epoch < self.light_train_epochs:
            # Stage 1: lights only, using optimizer 0.
            opt = self.optimizers()[0]
            ncm_batch = self.ncm(self.ncm_batch_size)

            light_vars = self._light_vars()
            dat_mat = T.cat([obs_batch[k] for k in light_vars], dim=1)
            ncm_mat = T.cat([ncm_batch[k] for k in light_vars], dim=1)

            self._mmd_update(opt, dat_mat, ncm_mat, 0)
            return

        # Stage 2: sprites, using optimizer 1.
        opt = self.optimizers()[1]
        sprite_bases = self._sprite_bases(obs_batch)
        if not sprite_bases:
            return

        # One noise draw per observed row; the row count is taken from the
        # first tensor in the batch.
        u_sample = self.ncm.pu.sample(obs_batch[next(iter(obs_batch))].shape[0])
        data_rows = []
        model_rows = []
        for base in sprite_bases:
            agg_light = self._aggregate_lights(obs_batch, base)
            if agg_light is None:
                # Sprite has no associated lights: nothing to condition on.
                continue
            color_var = f"{base}_C"
            shape_var = f"{base}_S"
            if color_var not in obs_batch or shape_var not in obs_batch:
                continue

            color_pred = self._predict_sprite_var(color_var, agg_light, obs_batch, u_sample)
            shape_pred = self._predict_sprite_var(shape_var, agg_light, obs_batch, u_sample)

            data_rows.append(T.cat([agg_light, obs_batch[color_var], obs_batch[shape_var]], dim=1))
            model_rows.append(T.cat([agg_light, color_pred, shape_pred], dim=1))

        if not data_rows:
            return

        # Pool the per-sprite tables row-wise so all sprites share one loss.
        dat_mat = T.cat(data_rows, dim=0)
        ncm_mat = T.cat(model_rows, dim=0)

        self._mmd_update(opt, dat_mat, ncm_mat, 1)
