import torch as T
import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset

from src.scm.ncm.gan_ncm import GAN_NCM
from src.scm.scm import expand_do


def log(x):
    """Return the elementwise natural log of ``x`` shifted by 1e-8.

    The shift keeps the result finite when ``x`` contains exact zeros.
    """
    eps = 1e-8
    shifted = x + eps
    return T.log(shifted)


class RelationalGANPipeline(pl.LightningModule):
    """
    Multi-skeleton GAN training with shared generator parameters across skeletons.

    Every source spec contributes one training "pair" per entry of its
    ``"do_var_list"``.  For each pair the corresponding source NCM generates
    a batch under that intervention and its discriminator (selected by the
    intervention index) is trained against the matching real data.  An
    optional query, evaluated on the target NCM via ``compute_ctf``, is added
    to the generator objective with a weight that moves log-linearly from
    ``max-lambda`` to ``min-lambda`` over ``max-query-iters`` epochs.

    Optimisation is manual (``automatic_optimization = False``); three
    optimizers are used, for the generator (``ncm.f``), the discriminator
    (``ncm.f_disc``) and the exogenous-noise module (``ncm.pu``).
    """

    patience = 500
    max_epochs = 3000

    def __init__(
        self,
        source_specs,
        target_spec,
        hyperparams=None,
        ncm_model=GAN_NCM,
        source_ncms=None,
        target_ncm=None,
        query=None,
        query_sign=-1,
    ):
        """Store configuration and set up (or build) the NCMs.

        Args:
            source_specs: sequence of dicts; this class reads the keys
                ``"cg"``, ``"do_var_list"``, ``"dat_sets"`` and, optionally,
                ``"v_size"`` from each.
            target_spec: dict with ``"cg"`` and optional ``"v_size"`` /
                ``"do_var_list"``; only used when models are built here.
            hyperparams: optional dict of settings; it is copied, so the
                caller's dict is never mutated.
            ncm_model: callable used by ``_build_ncms`` to construct models.
            source_ncms: optional iterable of pre-built source models.
            target_ncm: optional pre-built target model.
            query: ``None``, a single query, or a list/tuple/set of queries
                passed to ``target_ncm.compute_ctf``.
            query_sign: multiplier applied to the query loss.

        Raises:
            ValueError: if ``source_ncms`` or ``target_ncm`` is missing and
                ``hyperparams['allow-unshared']`` is not truthy.
        """
        super().__init__()
        if hyperparams is None:
            hyperparams = dict()
        self.hyperparams = dict(hyperparams)

        # Defaults are written back into the dict because the same dict is
        # forwarded to the NCM constructors in _build_ncms.
        self.hyperparams.setdefault("gan-mode", "vanilla")
        self.hyperparams.setdefault("gen-sigmoid", False)
        self.hyperparams.setdefault("perturb-sd", 0.1)
        self.hyperparams.setdefault("d-iters", 1)
        self.hyperparams.setdefault("grad-clamp", 0.01)
        self.hyperparams.setdefault("gp-weight", 10.0)
        self.hyperparams.setdefault("neural-pu", False)
        self.hyperparams.setdefault("layer-norm", False)
        self.hyperparams.setdefault("single-disc", False)

        self.source_specs = source_specs
        self.target_spec = target_spec
        self.query = query
        self.query_sign = query_sign

        self.gan_mode = self.hyperparams.get("gan-mode", "vanilla")
        self.gen_sigmoid = self.hyperparams.get("gen-sigmoid", False)
        self.perturb_sd = self.hyperparams.get("perturb-sd", 0.1)
        self.d_iters = self.hyperparams.get("d-iters", 1)
        self.grad_clamp = self.hyperparams.get("grad-clamp", 0.01)
        self.gp_weight = self.hyperparams.get("gp-weight", 10.0)
        self.lr = self.hyperparams.get("lr", 0.001)

        self.ncm_batch_size = self.hyperparams.get("ncm-bs", 1000)
        self.batch_size = self.hyperparams.get("data-bs", 1000)
        self.mc_sample_size = self.hyperparams.get("mc-sample-size", 10000)
        self.min_lambda = self.hyperparams.get("min-lambda", 0.001)
        self.max_lambda = self.hyperparams.get("max-lambda", 1.0)
        self.max_query_iters = self.hyperparams.get("max-query-iters", 3000)

        if source_ncms is None or target_ncm is None:
            if not self.hyperparams.get("allow-unshared", False):
                raise ValueError(
                    "RelationalGANPipeline requires shared parameters across skeletons. "
                    "Pass source_ncms and target_ncm that share modules, or set "
                    "hyperparams['allow-unshared']=True to build separate models."
                )
            built_sources, built_target = self._build_ncms(ncm_model=ncm_model)
            # Only fill in what the caller did not supply, so a model that
            # was passed in is never silently replaced.
            if source_ncms is None:
                source_ncms = built_sources
            if target_ncm is None:
                target_ncm = built_target

        self.source_ncms = T.nn.ModuleList(source_ncms)
        self.target_ncm = target_ncm

        self.pair_index = self._build_pairs()
        self.automatic_optimization = False

    def _build_ncms(self, ncm_model):
        """Construct one model per source spec plus one for the target spec.

        Each model receives its own copy of the hyperparameters with
        ``"do-var-list"`` set from the spec.  Variable sizes default to 1 for
        every key of the spec's ``"cg"`` when ``"v_size"`` is absent.

        Returns:
            ``(source_ncms, target_ncm)`` where ``source_ncms`` is a list in
            the order of ``self.source_specs``.
        """
        source_ncms = []
        for spec in self.source_specs:
            cg = spec["cg"]
            v_size = spec.get("v_size", {k: 1 for k in cg})
            per_hp = dict(self.hyperparams)
            per_hp["do-var-list"] = spec["do_var_list"]
            ncm = ncm_model(
                cg,
                v_size=v_size,
                default_u_size=per_hp.get("u-size", 1),
                hyperparams=per_hp,
                gen_use_sigmoid=per_hp.get("gen-sigmoid", False),
                disc_use_sigmoid=(per_hp.get("gan-mode", "NA") != "wgan"),
            )
            source_ncms.append(ncm)

        target_cg = self.target_spec["cg"]
        target_v_size = self.target_spec.get("v_size", {k: 1 for k in target_cg})
        target_hp = dict(self.hyperparams)
        # The target falls back to a single empty intervention when the spec
        # does not list any.
        target_hp["do-var-list"] = self.target_spec.get("do_var_list", [{}])
        target_ncm = ncm_model(
            target_cg,
            v_size=target_v_size,
            default_u_size=target_hp.get("u-size", 1),
            hyperparams=target_hp,
            gen_use_sigmoid=target_hp.get("gen-sigmoid", False),
            disc_use_sigmoid=(target_hp.get("gan-mode", "NA") != "wgan"),
        )
        return source_ncms, target_ncm

    def _build_pairs(self):
        """Return every ``(source index, intervention index)`` combination.

        The order is source-major and defines the position of each pair in
        the batches produced by ``train_dataloader``.
        """
        pairs = []
        for src_idx, spec in enumerate(self.source_specs):
            for do_idx in range(len(spec["do_var_list"])):
                pairs.append((src_idx, do_idx))
        return pairs

    def _unique_params(self, groups):
        """Flatten ``groups`` of parameters, keeping each object only once.

        Parameters are de-duplicated by identity so that a module shared by
        several models is handed to the optimizer a single time.
        """
        params = []
        seen = set()
        for group in groups:
            for p in group:
                pid = id(p)
                if pid in seen:
                    continue
                seen.add(pid)
                params.append(p)
        return params

    def configure_optimizers(self):
        """Create the generator, discriminator and noise optimizers.

        RMSprop is used when ``gan-mode`` is ``"wgan"``, Adam otherwise; all
        three share the learning rate ``self.lr``.

        Returns:
            ``(opt_gen, opt_disc, opt_pu)`` in that order, which is the order
            ``training_step`` unpacks them in.
        """
        # NOTE(review): only modules of the source NCMs are collected here.
        # Parameters of target_ncm are optimised only if they are the same
        # objects as source parameters -- confirm that an unshared target is
        # meant to stay fixed.
        gen_groups = []
        disc_groups = []
        pu_groups = []
        for ncm in self.source_ncms:
            gen_groups.append(ncm.f.parameters())
            disc_groups.append(ncm.f_disc.parameters())
            pu_groups.append(ncm.pu.parameters())

        if self.gan_mode == "wgan":
            opt_gen = T.optim.RMSprop(self._unique_params(gen_groups), lr=self.lr)
            opt_disc = T.optim.RMSprop(self._unique_params(disc_groups), lr=self.lr)
            opt_pu = T.optim.RMSprop(self._unique_params(pu_groups), lr=self.lr)
        else:
            opt_gen = T.optim.Adam(self._unique_params(gen_groups), lr=self.lr)
            opt_disc = T.optim.Adam(self._unique_params(disc_groups), lr=self.lr)
            opt_pu = T.optim.Adam(self._unique_params(pu_groups), lr=self.lr)
        return opt_gen, opt_disc, opt_pu

    def train_dataloader(self):
        """Return a shuffled loader over the sources' ``"dat_sets"``.

        Each item is a list with one entry per pair in ``self.pair_index``;
        incomplete final batches are dropped.
        """
        datasets = [spec["dat_sets"] for spec in self.source_specs]
        return DataLoader(
            RelationalGANSCMDataset(datasets, self.pair_index),
            batch_size=self.batch_size,
            shuffle=True,
            drop_last=True,
        )

    def _get_D_loss(self, real_out, fake_out):
        """Discriminator loss for the configured ``gan-mode``.

        For ``"wgan"`` / ``"wgangp"`` this is the negated difference of the
        mean outputs; otherwise it is the negated sum of the mean log-terms.
        """
        if self.gan_mode in {"wgan", "wgangp"}:
            return -(T.mean(real_out) - T.mean(fake_out))
        # The two terms are averaged separately: the real slice holds
        # batch_n // d_iters rows while the generated batch holds ncm-bs
        # rows, so adding them elementwise fails whenever those differ.
        return -(T.mean(log(real_out)) + T.mean(log(1 - fake_out)))

    def _get_G_loss(self, fake_out):
        """Generator loss for the configured ``gan-mode``."""
        if self.gan_mode == "bgan":
            return 0.5 * T.mean((log(fake_out) - log(1 - fake_out)) ** 2)
        if self.gan_mode in {"wgan", "wgangp"}:
            return -T.mean(fake_out)
        return -T.mean(log(fake_out))

    def _get_gradient_penalty(self, ncm, real_data, fake_data, disc_index):
        """Penalise discriminator gradients on real/fake interpolations.

        Args:
            ncm: model providing ``get_disc_outputs``.
            real_data: dict of real tensors, batch on dimension 0.
            fake_data: dict of generated tensors; indexed with the keys of
                ``real_data``.
            disc_index: discriminator index forwarded to ``get_disc_outputs``.

        Returns:
            ``gp_weight * mean(relu(||grad|| - grad_clamp) ** 2)``, i.e. only
            gradient norms above ``grad_clamp`` are penalised.
        """
        first_key = next(iter(real_data))
        # Real and generated batches can have different row counts; use the
        # common prefix so the elementwise interpolation is well defined.
        n = min(real_data[first_key].shape[0], fake_data[first_key].shape[0])
        interpolated_data = dict()
        alpha = T.rand(n, 1, device=self.device, requires_grad=True)
        for V in real_data:
            real_v = real_data[V][:n].detach()
            fake_v = fake_data[V][:n].detach()
            v_alpha = alpha.expand_as(real_v)
            interpolated_data[V] = v_alpha * real_v + (1 - v_alpha) * fake_v

        interpolated_out, inp = ncm.get_disc_outputs(interpolated_data, disc_index, include_inp=True)
        gradients = T.autograd.grad(
            outputs=interpolated_out,
            inputs=inp,
            grad_outputs=T.ones(interpolated_out.size(), device=self.device),
            create_graph=True,
            retain_graph=True,
        )[0]
        gradients = gradients.view(n, -1)
        # The small constant keeps the sqrt differentiable at zero norm.
        gradients_norm = T.sqrt(T.sum(gradients ** 2, dim=1) + 1e-12)
        return self.gp_weight * (T.relu(gradients_norm - self.grad_clamp) ** 2).mean()

    def _get_q_loss(self):
        """Evaluate the query (or sum of queries) on the target model.

        Returns:
            ``None`` when no query is configured (``self.query`` is ``None``
            or an empty collection); otherwise the value of
            ``target_ncm.compute_ctf`` for a single query, or the sum over a
            list/tuple/set of queries, each using ``mc_sample_size`` samples.
        """
        if self.query is None:
            return None
        if isinstance(self.query, (list, tuple, set)):
            if not self.query:
                # An empty collection would otherwise yield the float 0.0,
                # which has no graph to back-propagate through.
                return None
            query_loss = 0.0
            for query in self.query:
                query_loss += self.target_ncm.compute_ctf(query, n=self.mc_sample_size)
            return query_loss
        return self.target_ncm.compute_ctf(self.query, n=self.mc_sample_size)

    def training_step(self, batch, batch_idx):
        """Run one manual optimisation step.

        The step performs ``d_iters`` discriminator updates (each on its own
        slice of the real batch, summed over all pairs), then a single
        generator / noise update using the averaged generator loss over all
        pairs plus the weighted query loss, and finally logs ``G_loss``,
        ``D_loss`` and ``Q_loss``.

        Args:
            batch: sequence indexed by pair position; each entry is a dict of
                real tensors for that pair.
            batch_idx: unused.
        """
        G_opt, D_opt, PU_opt = self.optimizers()

        # Query weight: log-linear interpolation from max_lambda (epoch 0)
        # to min_lambda (epoch >= max_query_iters).
        if self.max_query_iters > 0:
            reg_ratio = min(self.current_epoch, self.max_query_iters) / self.max_query_iters
        else:
            # A non-positive horizon means the schedule is already finished.
            reg_ratio = 1.0
        reg_up = T.log(T.tensor(self.max_lambda))
        reg_low = T.log(T.tensor(self.min_lambda))
        max_reg = T.exp(reg_up - reg_ratio * (reg_up - reg_low))

        G_opt.zero_grad()
        PU_opt.zero_grad()

        total_d_loss = 0.0
        for d_iter in range(self.d_iters):
            D_opt.zero_grad()
            for pair_idx, (src_idx, do_idx) in enumerate(self.pair_index):
                ncm = self.source_ncms[src_idx]
                do_set = self.source_specs[src_idx]["do_var_list"][do_idx]
                ncm_batch = ncm(
                    self.ncm_batch_size,
                    do={k: expand_do(v, self.ncm_batch_size) for (k, v) in do_set.items()},
                )
                pair_batch = batch[pair_idx]
                batch_n = pair_batch[next(iter(pair_batch))].shape[0]
                # Each discriminator iteration sees a disjoint slice of the
                # real batch; skip the pair if the slice would be empty.
                cut = batch_n // self.d_iters
                if cut == 0:
                    continue
                real_batch = {
                    k: v[d_iter * cut:(d_iter + 1) * cut].float()
                    for (k, v) in pair_batch.items()
                }
                if not self.gen_sigmoid:
                    # Add Gaussian noise to the real, non-intervened
                    # variables; intervened ones are passed through as is.
                    new_real_batch = dict()
                    for k in real_batch:
                        v = real_batch[k]
                        if k not in do_set:
                            new_real_batch[k] = T.normal(
                                mean=v, std=self.perturb_sd * T.ones(v.shape, device=self.device)
                            )
                        else:
                            new_real_batch[k] = v
                    real_batch = new_real_batch

                ncm_disc_real_out = ncm.get_disc_outputs(real_batch, do_idx)
                ncm_disc_fake_out = ncm.get_disc_outputs(ncm_batch, do_idx)
                D_loss = self._get_D_loss(ncm_disc_real_out, ncm_disc_fake_out)

                if self.gan_mode == "wgangp":
                    grad_penalty = self._get_gradient_penalty(ncm, real_batch, ncm_batch, do_idx)
                    self.log("grad_penalty", grad_penalty, prog_bar=True)
                    D_loss = D_loss + grad_penalty

                total_d_loss += D_loss.item()
                self.manual_backward(D_loss)

            D_opt.step()

            if self.gan_mode == "wgan":
                # Weight clipping for the plain WGAN mode.
                for ncm in self.source_ncms:
                    for p in ncm.f_disc.parameters():
                        p.data.clamp_(-self.grad_clamp, self.grad_clamp)

            # The discriminator backward pass also deposited gradients on the
            # generator and noise modules; clear them before the G update.
            for ncm in self.source_ncms:
                ncm.f.zero_grad()
                ncm.f_disc.zero_grad()
                ncm.pu.zero_grad()

        n_pairs = len(self.pair_index)
        g_loss_record = 0.0
        for pair_idx, (src_idx, do_idx) in enumerate(self.pair_index):
            ncm = self.source_ncms[src_idx]
            do_set = self.source_specs[src_idx]["do_var_list"][do_idx]
            ncm_batch = ncm(
                self.ncm_batch_size,
                do={k: expand_do(v, self.ncm_batch_size) for (k, v) in do_set.items()},
            )
            ncm_disc_fake_out = ncm.get_disc_outputs(ncm_batch, do_idx)
            G_loss = self._get_G_loss(ncm_disc_fake_out) / n_pairs
            g_loss_record += G_loss.item()
            self.manual_backward(G_loss)

        q_loss_record = 0.0
        q_loss = self._get_q_loss()
        if q_loss is not None:
            q_loss = max_reg * (n_pairs ** 2) * self.query_sign * q_loss
            q_loss_record = q_loss.item()
            # A NaN query loss is still logged but never back-propagated.
            if not T.isnan(q_loss):
                self.manual_backward(q_loss)

        G_opt.step()
        PU_opt.step()

        for ncm in self.source_ncms:
            ncm.f.zero_grad()
            ncm.f_disc.zero_grad()
            ncm.pu.zero_grad()

        self.log("G_loss", g_loss_record, prog_bar=True)
        self.log("D_loss", total_d_loss, prog_bar=True)
        self.log("Q_loss", q_loss_record, prog_bar=True)


class RelationalGANSCMDataset(Dataset):
    """Row-aligned view over the datasets of several (source, intervention) pairs.

    ``source_dat_sets[src_idx][do_idx]`` is a mapping from variable name to
    an indexable column.  The dataset length is the shortest first-column
    length among the listed pairs (0 when no pairs are given), and item
    ``idx`` is a list holding, for every pair in order, the ``idx``-th entry
    of each of that pair's columns.
    """

    def __init__(self, source_dat_sets, pair_index):
        self.source_dat_sets = source_dat_sets
        self.pair_index = pair_index
        tables = [source_dat_sets[s][d] for s, d in pair_index]
        # The first column of each table stands in for the table's row count.
        sizes = [len(table[next(iter(table))]) for table in tables]
        self.length = min(sizes, default=0)

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        rows = []
        for s, d in self.pair_index:
            table = self.source_dat_sets[s][d]
            row = {name: table[name][idx] for name in table}
            rows.append(row)
        return rows
