import torch as T
import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset

from src.scm.ncm.mle_ncm import MLE_NCM
from src.ds.counterfactual import CTF


class RelationalIDPipeline(pl.LightningModule):
    """
    Relational ID training across multiple source skeletons with a shared-parameter NCM.

    Every training step (manual optimization) does two things:

    1. For each (source skeleton, do-set) pair, accumulates ``-count * log p``
       over the distinct rows of that pair's data, where ``log p`` comes from
       the source NCM's ``likelihood``. The sum is divided by the total number
       of rows tallied.
    2. If a ``query`` is given, backpropagates
       ``weight * query_sign * query_value`` computed on the target NCM, where
       ``weight`` decays log-linearly from ``max_lambda`` to ``min_lambda``
       over ``max_query_iters`` epochs.

    The caller must ensure parameter sharing across per-skeleton NCMs. The simplest
    pattern is to build NCMs with shared modules (e.g., shared ModuleDict entries
    or wrappers) and pass them in via `source_ncms` and `target_ncm`.
    """

    # Not read inside this class; presumably consumed by external
    # early-stopping / trainer setup code -- verify before changing.
    patience = 400

    def __init__(
        self,
        source_specs,
        target_spec,
        hyperparams=None,
        ncm_model=MLE_NCM,
        source_ncms=None,
        target_ncm=None,
        query=None,
        query_sign=-1,
        role_vars=None,
    ):
        """
        Args:
            source_specs: list of dicts with keys:
                - "cg": CausalGraph for the source skeleton
                - "dat_sets": list of data dicts, one per do-set
                - "do_var_list": list of do-set dicts
            target_spec: dict with keys:
                - "cg": CausalGraph for the target skeleton
            hyperparams: dict of training hyperparameters. Keys read here:
                "mc-sample-size", "ncm-bs", "lr", "min-lambda", "max-lambda",
                "max-query-iters", "full-batch", "fast-counts", "fast-counts-n",
                "profile-counts", "profile-likelihood", "profile-every",
                "allow-unshared", and (only when models are built here) "u-size".
            ncm_model: NCM class to construct when source_ncms/target_ncm not provided
            source_ncms: list of NCMs (one per source skeleton), already sharing params
            target_ncm: NCM for target skeleton, sharing params with sources
            query: CTF query to min/max on the target skeleton, or a
                list/tuple/set of queries whose values are summed
            query_sign: -1 to maximize query, +1 to minimize query
            role_vars: iterable of variable names; those present in a source
                NCM's ``v`` are added to the ``skip`` set passed to its
                ``likelihood`` call.

        Raises:
            ValueError: if ``source_ncms`` or ``target_ncm`` is missing and
                ``hyperparams['allow-unshared']`` is not truthy.
        """
        super().__init__()
        if hyperparams is None:
            hyperparams = dict()

        self.source_specs = source_specs
        self.target_spec = target_spec
        self.query = query
        self.query_sign = query_sign
        self.role_vars = set(role_vars or [])

        self.mc_sample_size = hyperparams.get("mc-sample-size", 10000)
        self.ncm_batch_size = hyperparams.get("ncm-bs", 1000)
        self.lr = hyperparams.get("lr", 4e-3)
        self.min_lambda = hyperparams.get("min-lambda", 0.001)
        self.max_lambda = hyperparams.get("max-lambda", 1.0)
        self.max_query_iters = hyperparams.get("max-query-iters", 3000)
        self.full_batch = hyperparams.get("full-batch", False)
        self.fast_counts = hyperparams.get("fast-counts", False)
        self.fast_counts_n = int(hyperparams.get("fast-counts-n", 100))
        self.profile_counts = hyperparams.get("profile-counts", False)
        self.profile_likelihood = hyperparams.get("profile-likelihood", False)
        self.profile_every = int(hyperparams.get("profile-every", 1))
        self._profile_step = 0

        if source_ncms is None or target_ncm is None:
            if not hyperparams.get("allow-unshared", False):
                raise ValueError(
                    "RelationalIDPipeline requires shared parameters across skeletons. "
                    "Pass source_ncms and target_ncm that share modules, or set "
                    "hyperparams['allow-unshared']=True to build separate models."
                )
            # NOTE: rebuilds *both* the source and the target models, replacing
            # whichever one of them may have been passed in.
            source_ncms, target_ncm = self._build_ncms(
                ncm_model=ncm_model, hyperparams=hyperparams
            )

        self.source_ncms = T.nn.ModuleList(source_ncms)
        self.target_ncm = target_ncm

        self.data_counts = None
        self.pair_index = self._build_pairs()
        if self.full_batch:
            # Tally every (source, do-set) dataset once up front; training_step
            # then reuses these counts and ignores the DataLoader batch.
            self.data_counts = []
            for src_idx, do_idx in self.pair_index:
                dat = self.source_specs[src_idx]["dat_sets"][do_idx]
                self.data_counts.append(
                    self._get_data_counts(dat, self._rows_to_count(dat))
                )

        # Optimization is driven by hand in training_step
        # (manual_backward / opt.step()).
        self.automatic_optimization = False

    def _build_ncms(self, ncm_model, hyperparams):
        """Construct one fresh NCM per source skeleton plus one for the target.

        The models built here do not share parameters. Within this class this
        is only reached when ``hyperparams['allow-unshared']`` is truthy.

        Returns:
            ``(source_ncms, target_ncm)``: a list with one model per entry of
            ``self.source_specs`` (same order) and the target model.
        """
        u_size = hyperparams.get("u-size", 1)

        def make(cg):
            # Every observed variable is given size 1.
            return ncm_model(
                cg,
                v_size={k: 1 for k in cg},
                default_u_size=u_size,
                hyperparams=hyperparams,
            )

        # Sources are built before the target, preserving construction order.
        source_ncms = [make(spec["cg"]) for spec in self.source_specs]
        target_ncm = make(self.target_spec["cg"])
        return source_ncms, target_ncm

    def _build_pairs(self):
        """Return ``(src_idx, do_idx)`` for every do-set of every source, in order."""
        return [
            (src_idx, do_idx)
            for src_idx, spec in enumerate(self.source_specs)
            for do_idx in range(len(spec["do_var_list"]))
        ]

    def _rows_to_count(self, dat):
        """Number of leading rows of ``dat`` to tally (capped when fast-counts is on)."""
        n = dat[next(iter(dat))].shape[0]
        if self.fast_counts:
            n = min(n, self.fast_counts_n)
        return n

    def _get_data_counts(self, data, n):
        """Tally the first ``n`` rows of ``data`` into ``{row_key: multiplicity}``.

        Each row key is a frozenset of ``(name, tuple(values))`` pairs, so
        identical rows collapse into one entry and the likelihood is evaluated
        once per distinct row. Delegates to the timed variant when
        ``profile_counts`` is set.
        """
        if self.profile_counts:
            return self._get_data_counts_profiled(data, n)
        return self._tally_rows(data, n)

    def _get_data_counts_profiled(self, data, n):
        """Same result as ``_tally_rows`` but prints how long the tally took."""
        import time

        start = time.perf_counter()
        counts = self._tally_rows(data, n)
        duration = time.perf_counter() - start
        print(f"[profile] _get_data_counts n={n} took {duration:.4f}s")
        return counts

    @staticmethod
    def _tally_rows(data, n):
        """Count distinct rows among the first ``n`` rows of each column of ``data``."""
        names = list(data)
        # Move each column to host memory once instead of once per row.
        columns = [data[k][:n].cpu().tolist() for k in names]
        counts = dict()
        for i in range(n):
            point_key = frozenset(
                (k, tuple(col[i])) for k, col in zip(names, columns)
            )
            counts[point_key] = counts.get(point_key, 0) + 1
        return counts

    def _unique_params(self):
        """Parameters of all source NCMs and the target NCM, each listed once.

        De-duplication is by object identity so that modules shared between
        NCMs contribute their parameters to the optimizer only once.
        """
        params = []
        seen = set()
        for ncm in list(self.source_ncms) + [self.target_ncm]:
            for p in ncm.parameters():
                pid = id(p)
                if pid in seen:
                    continue
                seen.add(pid)
                params.append(p)
        return params

    def configure_optimizers(self):
        """AdamW over the de-duplicated NCM parameters with a cosine warm-restart schedule."""
        optim = T.optim.AdamW(self._unique_params(), lr=self.lr)
        return {
            "optimizer": optim,
            "lr_scheduler": T.optim.lr_scheduler.CosineAnnealingWarmRestarts(
                optim, 50, 1, eta_min=1e-4
            ),
        }

    def on_train_epoch_end(self):
        """Step the LR scheduler once per epoch.

        Lightning does not step schedulers itself when ``automatic_optimization``
        is False, so without this hook the schedule returned by
        ``configure_optimizers`` would never take effect.
        """
        sched = self.lr_schedulers()
        if sched is not None:
            sched.step()

    def train_dataloader(self):
        """Shuffle index-aligned rows drawn from every (source, do-set) dataset."""
        datasets = [spec["dat_sets"] for spec in self.source_specs]
        # drop_last=True: if the shortest dataset has fewer than `ncm-bs` rows,
        # the loader yields no batches at all.
        return DataLoader(
            RelationalSCMDataset(datasets, self.pair_index),
            batch_size=self.ncm_batch_size,
            shuffle=True,
            drop_last=True,
        )

    def _query_weight(self):
        """Annealed weight on the query term.

        Interpolates log-linearly from ``max_lambda`` (epoch 0) down to
        ``min_lambda`` (epoch >= ``max_query_iters``). A non-positive
        ``max_query_iters`` means "already fully annealed".
        """
        if self.max_query_iters > 0:
            reg_ratio = (
                min(self.current_epoch, self.max_query_iters) / self.max_query_iters
            )
        else:
            reg_ratio = 1.0
        reg_up = T.log(T.tensor(float(self.max_lambda)))
        reg_low = T.log(T.tensor(float(self.min_lambda)))
        return T.exp(reg_up - reg_ratio * (reg_up - reg_low))

    def _query_list(self):
        """Normalize ``self.query`` into a (possibly empty) list of queries."""
        if self.query is None:
            return []
        if isinstance(self.query, (list, tuple, set)):
            return list(self.query)
        return [self.query]

    def _likelihood(self, ncm, data_point, skip_vars, profile=False):
        """Evaluate ``ncm.likelihood`` on one data point, optionally printing its runtime."""
        if not profile:
            return ncm.likelihood(data_point, skip=skip_vars, mc_size=self.mc_sample_size)
        import time

        start = time.perf_counter()
        log_pv = ncm.likelihood(data_point, skip=skip_vars, mc_size=self.mc_sample_size)
        duration = time.perf_counter() - start
        print(
            f"[profile] likelihood mc={self.mc_sample_size} took {duration:.4f}s"
        )
        return log_pv

    def _source_nll(self, batch, do_profile):
        """Sum ``-count * log p`` over the distinct rows of every (source, do-set) pair.

        Returns:
            ``(loss, total_count)``. ``loss`` remains the float ``0.0`` when
            no rows were tallied.
        """
        loss = 0.0
        total_count = 0
        profile_likelihood = self.profile_likelihood and do_profile
        for pair_idx, (src_idx, do_idx) in enumerate(self.pair_index):
            ncm = self.source_ncms[src_idx]
            do_set = self.source_specs[src_idx]["do_var_list"][do_idx]
            # Intervened variables plus this skeleton's role variables are
            # handed to ncm.likelihood as `skip`.
            skip_vars = set(do_set.keys()).union(self.role_vars.intersection(ncm.v))

            if self.full_batch:
                data_counts = self.data_counts[pair_idx]
            else:
                pair_batch = batch[pair_idx]
                data_counts = self._get_data_counts(
                    pair_batch, self._rows_to_count(pair_batch)
                )

            for point, count in data_counts.items():
                data_point = {
                    k: T.ByteTensor(v).to(device=self.device) for (k, v) in point
                }
                log_pv = self._likelihood(ncm, data_point, skip_vars, profile_likelihood)
                loss -= count * log_pv
                total_count += count
        return loss, total_count

    def training_step(self, batch, batch_idx):
        """One manual-optimization step: source likelihood term, then the query term.

        Gradients of both terms are accumulated (separate ``manual_backward``
        calls) before a single ``opt.step()``.
        """
        opt = self.optimizers()
        max_reg = self._query_weight()

        opt.zero_grad()
        self._profile_step += 1
        do_profile = self._profile_step % self.profile_every == 0

        loss, total_count = self._source_nll(batch, do_profile)
        if total_count > 0:
            loss = loss / total_count
            loss_record = loss.item() if hasattr(loss, "item") else float(loss)
            self.manual_backward(loss)
        else:
            # Nothing was tallied: `loss` is the float 0.0 and has no graph.
            loss_record = 0.0

        q_loss_record = 0.0
        queries = self._query_list()
        if queries:
            q_loss = 0.0
            for q in queries:
                q_loss += self.target_ncm.compute_ctf(q, n=self.mc_sample_size)
            q_loss = max_reg * self.query_sign * q_loss
            q_loss_record = q_loss.item() if hasattr(q_loss, "item") else float(q_loss)
            self.manual_backward(q_loss)

        opt.step()

        self.log("train_loss", loss_record, prog_bar=True)
        self.log("P_loss", loss_record, prog_bar=True)
        self.log("Q_loss", q_loss_record, prog_bar=True)


class RelationalSCMDataset(Dataset):
    """Zip several per-(source, do-set) data dicts into one index-aligned dataset.

    Item ``idx`` is a list with one entry per element of ``pair_index`` (same
    order). Each entry maps every variable name of that pair's data dict to
    the ``idx``-th element of its column. The dataset length is the shortest
    pair's column length, so every index is valid for every pair.
    """

    def __init__(self, source_dat_sets, pair_index):
        self.source_dat_sets = source_dat_sets
        self.pair_index = pair_index

        def first_column_len(dat):
            # The first key's column stands in for the length of the whole dict.
            first_key = next(iter(dat))
            return len(dat[first_key])

        sizes = [first_column_len(source_dat_sets[s][d]) for s, d in pair_index]
        self.length = min(sizes, default=0)

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        return [
            {name: column[idx] for name, column in self.source_dat_sets[s][d].items()}
            for s, d in self.pair_index
        ]
