import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl

from models.transformer_model import GraphTransformer

from metrics.train_metrics import TrainLossDiscrete

from src import utils
from src.flow_matching.ref_dist import RefDistribution
from src.flow_matching.time_distorter import TimeDistorter

from src.flow_matching.bw_flow import MeanMetric, BWVelocity
from src.flow_matching.linear_flow import LinearVelocity


class DiscreteBWFlow(pl.LightningModule):
    def __init__(
        self,
        cfg,
        dataset_infos,
        train_metrics,
        sampling_metrics,
        extra_features,
        domain_features,
    ):
        """Assemble the flow model, its reference distribution and its loss.

        Args:
            cfg: Nested configuration object. This constructor reads the
                ``sample``, ``dataset``, ``general``, ``model`` and ``train``
                sections (see the attribute accesses below).
            dataset_infos: Dataset description providing ``input_dims``,
                ``output_dims`` and ``nodes_dist``.
            train_metrics: Callable stored as ``self.train_metrics`` and
                invoked from ``training_step``.
            sampling_metrics: Object stored as ``self.sampling_metrics``; its
                ``forward`` is invoked from ``sample_and_evaluate``.
            extra_features: Callable applied to the noisy-data dict in
                ``compute_extra_data``.
            domain_features: Callable applied to the noisy-data dict (after
                ``ref_dist.ignore_virtual_classes``) in ``compute_extra_data``.
        """
        super().__init__()

        self.cfg = cfg
        self.sample_type = cfg.sample.sample_method
        self.name = f"{cfg.dataset.name}_{cfg.general.name}"
        self.model_dtype = torch.float32

        # T is the number of discrete time bins used when building `t_int`
        # in `apply_noise`.
        self.T = cfg.model.diffusion_steps
        # Number of sampling steps: taken from cfg.sample.sample_steps whenever
        # cfg.general.test_only is not None, otherwise it falls back to T.
        # (`sample_batch` overwrites this with cfg.sample.sample_steps anyway.)
        self.sample_T = (
            cfg.sample.sample_steps if cfg.general.test_only is not None else self.T
        )
        self.eta = self.cfg.sample.eta
        self.omega = self.cfg.sample.omega

        # Dataset info
        self.input_dims = dataset_infos.input_dims
        self.output_dims = dataset_infos.output_dims
        self.dataset_info = dataset_infos
        self.node_dist = dataset_infos.nodes_dist

        # Largest index of node_dist.prob, and the first index whose
        # probability is strictly positive.
        print("max num nodes", len(self.node_dist.prob) - 1)
        print("min num nodes", torch.where(self.node_dist.prob > 0)[0][0])

        self.train_metrics = train_metrics
        self.sampling_metrics = sampling_metrics

        self.extra_features = extra_features
        self.domain_features = domain_features

        # Reference (noise) distribution. The limit distribution is fetched
        # before the two update_* calls below.
        # NOTE(review): update_input_output_dims presumably adjusts
        # self.input_dims in place before the model is built — confirm.
        self.ref_dist = RefDistribution(cfg.model.transition, dataset_infos)
        self.limit_dist = self.ref_dist.get_limit_dist()
        self.ref_dist.update_input_output_dims(self.input_dims)
        self.ref_dist.update_dataset_infos(self.dataset_info)

        # BW flow components
        self.mean_fn = MeanMetric(cfg)
        self.BW_velocity = BWVelocity(cfg)
        # NOTE(review): self.device is read here, during construction, i.e.
        # possibly before the module has been moved to its final device —
        # confirm LinearVelocity does not rely on the value it receives here.
        self.linear_velocity=LinearVelocity(cfg, self.limit_dist, self.device)

        # Loss
        self.train_loss = TrainLossDiscrete(
            self.cfg.model.lambda_train,
            self.cfg.model.label_smoothing,
            utils.PlaceHolder(X=None, E=None, y=None)
        )

        # Model
        self.model = GraphTransformer(
            n_layers=cfg.model.n_layers,
            input_dims=self.input_dims,
            hidden_mlp_dims=self.cfg.model.hidden_mlp_dims,
            hidden_dims=self.cfg.model.hidden_dims,
            output_dims=self.output_dims,
            act_fn_in=nn.ReLU(),
            act_fn_out=nn.ReLU(),
        )

        # The two metric objects are excluded from the saved hyperparameters.
        self.save_hyperparameters(
            ignore=[
                "train_metrics",
                "sampling_metrics",
            ],
        )

        # Filled in by `on_fit_start` with the length of the train dataloader.
        self.train_iterations = None
        self.log_every_steps = cfg.general.log_every_steps

        # Time distortion used both for training (`apply_noise`) and for
        # sampling (`sample_batch`).
        self.time_distorter = TimeDistorter(
            train_distortion=cfg.train.time_distortion,
            sample_distortion=cfg.sample.time_distortion,
            s=cfg.train.mode_s,
            alpha=1,
            beta=1,
        )

    # -------------------------------------------------------------------------
    # Lightning hooks
    # -------------------------------------------------------------------------

    def configure_optimizers(self):
        return torch.optim.AdamW(
            self.parameters(),
            lr=self.cfg.train.lr,
            amsgrad=True,
            weight_decay=self.cfg.train.weight_decay,
        )

    def training_step(self, data, i):
        if data.edge_index.numel() == 0:
            self.print("Found a batch with no edges. Skipping.")
            return

        dense_data, node_mask = utils.to_dense(
            data.x, data.edge_index, data.edge_attr, data.batch
        )
        dense_data = dense_data.mask(node_mask)
        X, E = dense_data.X, dense_data.E

        noisy_data = self.apply_noise(X, E, data.y, node_mask)
        extra_data = self.compute_extra_data(noisy_data)
        pred = self.forward(noisy_data, extra_data, node_mask)
        log_now = i % self.log_every_steps == 0

        loss = self.train_loss(
            masked_pred_X=pred.X,
            masked_pred_E=pred.E,
            pred_y=pred.y,
            true_X=X,
            true_E=E,
            true_y=data.y,
            weight=None
        )

        self.train_metrics(
            masked_pred_X=pred.X,
            masked_pred_E=pred.E,
            true_X=X,
            true_E=E,
            log=log_now,
        )
        return {"loss": loss}

    def on_fit_start(self) -> None:
        self.train_iterations = len(self.trainer.datamodule.train_dataloader())
        self.print(
            "Size of the input features",
            self.input_dims["X"],
            self.input_dims["E"],
            self.input_dims["y"],
        )

    def on_train_epoch_end(self) -> None:
        to_log = self.train_loss.log_epoch_metrics()
        self.print(
            f"Epoch {self.current_epoch}: "
            f"X_CE: {to_log['train_epoch/x_CE'] :.3f} "
            f"-- E_CE: {to_log['train_epoch/E_CE'] :.3f} "
            f"-- y_CE: {to_log['train_epoch/y_CE'] :.3f} "
        )
        epoch_at_metrics, epoch_bond_metrics = self.train_metrics.log_epoch_metrics()
        self.print(
            f"Epoch {self.current_epoch}: {epoch_at_metrics} -- {epoch_bond_metrics}"
        )

    def test_step(self, data, i):
        _ = self.sample_and_evaluate(test=True)

    # -------------------------------------------------------------------------
    # Sampling + evaluation
    # -------------------------------------------------------------------------

    def sample_and_evaluate(
        self,
        test=False,
        samples=None,
        sample_only=False,
    ):

        samples_to_generate = (
            self.cfg.general.final_model_samples_to_generate
            * self.cfg.general.num_sample_fold
        )
        samples_left_to_generate = samples_to_generate
        samples_left_to_save = self.cfg.general.final_model_samples_to_save
        chains_left_to_save = self.cfg.general.final_model_chains_to_save

        if samples is None:
            samples = []
            labels = []
            current_id = 0

            while samples_left_to_generate > 0:
                self.print(
                    f"Samples left to generate: {samples_left_to_generate}/"
                    f"{samples_to_generate}",
                    end="",
                    flush=True,
                )

                bs = 2 * self.cfg.train.batch_size
                to_generate = min(samples_left_to_generate, bs)
                to_save = min(samples_left_to_save, bs)
                chains_save = min(chains_left_to_save, bs)

                cur_samples, cur_labels = self.sample_batch(
                    to_generate,
                    num_nodes=None,
                )

                samples.extend(cur_samples)
                labels.extend(cur_labels)

                current_id += to_generate
                samples_left_to_save -= to_save
                samples_left_to_generate -= to_generate
                chains_left_to_save -= chains_save

            if sample_only:
                return None

        to_log = self.sampling_metrics.forward(
            samples,
            ref_metrics=self.dataset_info.ref_metrics,
            name=self.name,
            current_epoch=self.current_epoch,
            val_counter=-1,
            test=test,
            local_rank=self.local_rank,
            labels=None,
        )

        return to_log

    # -------------------------------------------------------------------------
    # Diffusion / flow steps
    # -------------------------------------------------------------------------

    def apply_noise(self, X, E, y, node_mask, t=None):
        """Draw a time and a noisy graph at that time for a clean batch.

        Node probabilities at time ``t`` come from
        ``linear_velocity.p_xt_g_x1``. Edge probabilities combine that linear
        path with ``mean_fn.mean_fn_batched`` evaluated on the "edge present"
        indicators (``1 - E[..., 0]``) of a noise sample and of the clean data;
        how they are combined depends on ``cfg.dataset.binary_edge_type``.

        Args:
            X: Clean node features; unpacked as a rank-3 tensor whose last axis
                is arg-maxed into class labels.
            E: Clean edge features; last axis is arg-maxed into class labels
                and channel 0 is treated as "no edge".
            y: Global features, passed through unchanged.
            node_mask: Mask of valid nodes.
            t: Optional pre-chosen time; sampled from the time distorter when
                ``None``.

        Returns:
            Dict with keys ``t_int``, ``t``, ``X_t``, ``E_t``, ``y_t`` and
            ``node_mask``.
        """
        batch_size, _, _ = X.shape

        # Time: either supplied by the caller or drawn by the distorter.
        t_float = (
            self.time_distorter.train_ft(batch_size, self.device)
            if t is None
            else t
        )
        time_bin = (t_float * self.T).long().float() + 1
        t_int = torch.clamp(time_bin, 1, self.T)

        noise = utils.sample_discrete_feature_noise(
            limit_dist=self.limit_dist,
            node_mask=node_mask,
        )

        # "Edge present" indicators of the noise sample and of the clean data.
        adj_noise = 1 - noise.E[..., 0]
        adj_clean = 1 - E[..., 0]

        # Both branches below need the linear-path marginals followed by the
        # mean_fn edge probability, in this order.
        prob_X_t, prob_E_linear = self.linear_velocity.p_xt_g_x1(
            X1=torch.argmax(X, dim=-1),
            E1=torch.argmax(E, dim=-1),
            t=t_float,
        )
        edge_prob = self.mean_fn.mean_fn_batched(
            adj_noise, adj_clean, t_float, ret_prob=True
        )
        if not self.cfg.model.mix_rate:
            edge_prob = edge_prob.int().float()

        if self.cfg.dataset.binary_edge_type:
            # Two edge classes: (no edge, edge).
            prob_E_t = torch.stack((1 - edge_prob, edge_prob), dim=-1)
        else:
            # Draw a symmetric edge indicator from the upper triangle.
            edge_draw = torch.triu(torch.bernoulli(edge_prob), diagonal=1)
            edge_draw = edge_draw + torch.transpose(edge_draw, 1, 2)

            # "No edge" mass is kept only where no edge was drawn; the other
            # classes are rescaled by the complement of that mass.
            no_edge = prob_E_linear[..., 0] * (1 - edge_draw)
            typed = prob_E_linear[..., 1:] * (1 - no_edge).unsqueeze(-1)
            prob_E_t = torch.cat((no_edge.unsqueeze(-1), typed), dim=-1)

        sampled_t = utils.sample_discrete_features(
            probX=prob_X_t,
            probE=prob_E_t,
            node_mask=node_mask,
        )

        ref_dims = self.ref_dist.get_ref_dims()
        X_t = F.one_hot(sampled_t.X, num_classes=ref_dims["X"])
        E_t = F.one_hot(sampled_t.E, num_classes=ref_dims["E"])

        z_t = utils.PlaceHolder(X=X_t, E=E_t, y=y).type_as(X_t).mask(node_mask)

        return {
            "t_int": t_int,
            "t": t_float,
            "X_t": z_t.X,
            "E_t": z_t.E,
            "y_t": z_t.y,
            "node_mask": node_mask,
        }

    def forward(self, noisy_data, extra_data, node_mask):
        X_in = torch.cat((noisy_data["X_t"], extra_data.X), dim=2).float()
        E_in = torch.cat((noisy_data["E_t"], extra_data.E), dim=3).float()
        y_in = torch.hstack((noisy_data["y_t"], extra_data.y)).float()
        return self.model(X_in, E_in, y_in, node_mask)

    @torch.no_grad()
    def sample_batch(
        self,
        batch_size: int,
        num_nodes=None,
    ):
        """Generate a batch of graphs by integrating from noise.

        Args:
            batch_size: Number of graphs to generate.
            num_nodes: ``None`` (sizes are drawn from ``self.node_dist``), an
                ``int`` (same size for every graph), or a tensor of per-graph
                sizes.

        Returns:
            ``(molecule_list, label_list)``: per graph, a
            ``[node_types, edge_types]`` pair cropped to that graph's size, and
            the matching ``y`` row; all moved to CPU.
        """
        if num_nodes is None:
            n_nodes = self.node_dist.sample_n(batch_size, self.device)
        elif isinstance(num_nodes, int):
            n_nodes = torch.full(
                (batch_size,),
                num_nodes,
                device=self.device,
                dtype=torch.int,
            )
        else:
            assert isinstance(num_nodes, torch.Tensor)
            n_nodes = num_nodes

        # Boolean mask marking the first n_nodes[b] positions of every row.
        n_max = torch.max(n_nodes).item()
        positions = torch.arange(n_max, device=self.device)
        positions = positions.unsqueeze(0).expand(batch_size, -1)
        node_mask = positions < n_nodes.unsqueeze(1)

        # Sample initial noise
        z_T = utils.sample_discrete_feature_noise(
            limit_dist=self.ref_dist.get_limit_dist(),
            node_mask=node_mask,
        )
        X = z_T.X
        E = z_T.E
        y = z_T.y
        # Kept on the instance: `compute_discrete_velocity` reads it.
        self.initial_sample = z_T

        assert (E == torch.transpose(E, 1, 2)).all()

        n_steps = self.cfg.sample.sample_steps
        self.sample_T = n_steps

        for step in range(0, n_steps):
            # Current and next time, normalised, then distorted.
            t_array = step * torch.ones((batch_size, 1)).type_as(y)
            s_array = t_array + 1

            t_norm = self.time_distorter.apply_distortion(
                t_array / n_steps,
                self.cfg.sample.time_distortion,
            )
            s_norm = self.time_distorter.apply_distortion(
                s_array / n_steps,
                self.cfg.sample.time_distortion,
            )

            sampled_s, _ = self.sample_p_zs_given_zt(
                t_norm, s_norm,
                X, E, y, node_mask
            )
            X = sampled_s.X
            E = sampled_s.E
            y = sampled_s.y

        # Collapse the last state and crop each graph to its own size.
        sampled_s = sampled_s.mask(node_mask, collapse=True)
        X = sampled_s.X
        E = sampled_s.E
        y = sampled_s.y

        molecule_list = [
            [
                X[idx, :n_nodes[idx]].cpu(),
                E[idx, :n_nodes[idx], :n_nodes[idx]].cpu(),
            ]
            for idx in range(batch_size)
        ]
        label_list = [y[idx].cpu() for idx in range(batch_size)]

        return molecule_list, label_list

    def sample_p_zs_given_zt(
        self,
        t, s,
        X_t, E_t, y_t, node_mask,
    ):
        """Advance the sampler by one step, from time ``t`` to time ``s``.

        The network is evaluated on the current state and its softmax output is
        converted into categorical distributions for the next state according
        to ``self.sample_type``:

        * ``"linear_flow"`` / ``"bw_flow"``: a rate is obtained from
          ``compute_discrete_velocity`` and multiplied by the step size. The
          entries for classes other than the current one are clamped to
          ``[0, 1]`` and the current-class entry receives the remaining mass.
        * ``"prob_path"``: node probabilities interpolate between the current
          one-hot state and the prediction with weight
          ``w = (s - t) / (1 - t + 1e-3)`` clamped to ``[0, 1]``; edge
          probabilities come from ``mean_fn.mean_fn_batched``.

        When ``s[0] == 1.0`` the network prediction itself is used as the
        sampling distribution, whichever method is selected.

        Args:
            t: Current time per batch element.
            s: Next time per batch element.
            X_t: Current one-hot node state (rank 3, classes on the last axis).
            E_t: Current one-hot edge state (rank 4, classes on the last axis,
                channel 0 treated as "no edge").
            y_t: Current global features.
            node_mask: Mask of valid nodes.

        Returns:
            ``(out_one_hot, out_discrete)``: the sampled next state as masked
            one-hot tensors, and the same state masked with ``collapse=True``.
            Both carry an empty ``(bs, 0)`` ``y``.

        Raises:
            NotImplementedError: if ``self.sample_type`` is not supported and
                this is not the final step.
        """
        bs = X_t.shape[0]
        # Step size is read from the first batch element only.
        # NOTE(review): assumes every element shares the same (t, s) — confirm.
        dt = (s - t)[0]

        noisy_data = {
            "X_t": X_t,
            "E_t": E_t,
            "y_t": y_t,
            "t": t,
            "node_mask": node_mask,
        }

        extra_data = self.compute_extra_data(noisy_data)
        pred = self.forward(noisy_data, extra_data, node_mask)

        pred_X = F.softmax(pred.X, dim=-1)
        pred_E = F.softmax(pred.E, dim=-1)

        limit_x = self.limit_dist.X
        limit_e = self.limit_dist.E

        # NOTE(review): exact float comparison — relies on the final `s`
        # being exactly 1.0 after time distortion.
        is_final_step = bool(s[0] == 1.0)

        if self.sample_type in ["linear_flow", "bw_flow"]:
            X_t_label = X_t.argmax(-1, keepdim=True)
            E_t_label = E_t.argmax(-1, keepdim=True)
            sampled_1 = utils.sample_discrete_features(
                pred_X, pred_E, node_mask=node_mask
            )

            R_t_X, R_t_E, _, _ = self.compute_discrete_velocity(
                t, sampled_1.X, sampled_1.E, X_t_label, E_t_label,
                pred_X, pred_E, node_mask,
            )

            step_probs_X = R_t_X * dt
            step_probs_E = R_t_E * dt

            # 1) Zero the current-class entry and clamp the remaining entries
            #    to valid probabilities. The clamp must be in-place: an
            #    out-of-place `.clamp(...)` here would be silently discarded.
            step_probs_X.scatter_(-1, X_t_label, 0.0).clamp_(min=0.0, max=1.0)
            step_probs_E.scatter_(-1, E_t_label, 0.0).clamp_(min=0.0, max=1.0)

            # 2) Give the current class whatever mass is left.
            step_probs_X.scatter_(
                -1,
                X_t_label,
                (1.0 - step_probs_X.sum(dim=-1, keepdim=True)),
            )
            step_probs_E.scatter_(
                -1,
                E_t_label,
                (1.0 - step_probs_E.sum(dim=-1, keepdim=True)),
            )

            # The remaining mass can be negative when the others sum to > 1.
            prob_X = step_probs_X.clamp(min=0.0)
            prob_E = step_probs_E.clamp(min=0.0)

        elif self.sample_type == "prob_path":
            At_batch = 1 - E_t[..., 0]
            A1_batch = 1 - pred_E[..., 0]

            eps_t = 1e-3  # keeps the denominator away from zero as t -> 1
            s_float, t_float = s.view(bs, 1, 1), t.view(bs, 1, 1)

            w = (s_float - t_float) / (1.0 - t_float + eps_t)
            w = torch.clamp(w, 0.0, 1.0)

            prob_X = X_t * (1.0 - w) + pred_X * w
            prob_E = self.mean_fn.mean_fn_batched(
                At_batch,
                A1_batch,
                w,
                ret_prob=True,
            )
            prob_E = (
                prob_E if self.cfg.model.mix_rate else prob_E.int().float()
            )
            prob_E = torch.stack((1 - prob_E, prob_E), dim=-1)

        elif not is_final_step:
            # Without this, the code would fail later with an unbound
            # `prob_X`; fail early with a clear message instead.
            raise NotImplementedError(
                f"Unsupported sample method: {self.sample_type!r}"
            )

        if is_final_step:
            prob_X, prob_E = pred_X, pred_E

        sampled_s = utils.sample_discrete_features(
            prob_X, prob_E, node_mask=node_mask,
        )

        X_s = F.one_hot(sampled_s.X, num_classes=len(limit_x)).float()
        E_s = F.one_hot(sampled_s.E, num_classes=len(limit_e)).float()

        assert (E_s == torch.transpose(E_s, 1, 2)).all()
        assert (X_t.shape == X_s.shape) and (E_t.shape == E_s.shape)

        y_to_save = torch.zeros([y_t.shape[0], 0], device=self.device)

        out_one_hot = utils.PlaceHolder(X=X_s, E=E_s, y=y_to_save)
        out_discrete = utils.PlaceHolder(X=X_s, E=E_s, y=y_to_save)

        out_one_hot = out_one_hot.mask(node_mask).type_as(y_t)
        out_discrete = out_discrete.mask(node_mask, collapse=True).type_as(y_t)

        return out_one_hot, out_discrete


    def compute_discrete_velocity(
        self,
        t,
        X_1_pred,
        E_1_pred,
        X_t_label,
        E_t_label,
        pred_X,
        pred_E,
        node_mask,
        func="relu",
    ):

        (
            pt_vals_X, pt_vals_E,
            pt_vals_at_Xt, pt_vals_at_Et,
            dt_p_vals_X, dt_p_vals_E,
            dt_p_vals_at_Xt, dt_p_vals_at_Et,
        ) = self.linear_velocity.compute_pt_vals(t, X_t_label, E_t_label, X_1_pred, E_1_pred)

        if self.sample_type == 'bw_flow':
            limit_dist = self.limit_dist.to_device(self.device)
            X_1_label = pred_X.argmax(-1, keepdim=True)
            E_1_label = pred_E.argmax(-1, keepdim=True)
            X_1_onehot = F.one_hot(X_1_label.squeeze(-1), num_classes=len(limit_dist.X)).float()
            X_t_onehot = F.one_hot(X_t_label.squeeze(-1), num_classes=len(limit_dist.X)).float()
            E_0_label = self.initial_sample.E.argmax(-1, keepdim=True)

            Rstar_t_X, Rstar_t_E = self.BW_velocity.compute_velocity(
                t.unsqueeze(1),
                X_1_onehot, E_1_label,
                X_t_onehot, E_t_label, E_0_label
            )

        elif self.sample_type == 'linear_flow':
            Rstar_t_X, Rstar_t_E = self.linear_velocity.compute_Rstar(
                X_1_pred, E_1_pred,
                X_t_label, E_t_label,
                pt_vals_X, pt_vals_E,
                pt_vals_at_Xt, pt_vals_at_Et,
                dt_p_vals_X, dt_p_vals_E,
                dt_p_vals_at_Xt,  dt_p_vals_at_Et,
                func,
            )
        else:
            raise NotImplementedError()

        X_mask = torch.ones_like(pt_vals_X)
        E_mask = torch.ones_like(pt_vals_E)

        Rdb_t_X = pt_vals_X * X_mask * self.eta
        Rdb_t_E = pt_vals_E * E_mask * self.eta

        R_t_X, R_t_E = self.linear_velocity.compute_R(
            Rstar_t_X, Rstar_t_E,
            Rdb_t_X, Rdb_t_E,
            pt_vals_at_Xt, pt_vals_at_Et,
            pt_vals_X, pt_vals_E
        )

        return R_t_X, R_t_E, X_mask, E_mask

    def compute_extra_data(self, noisy_data):
        """Build the additional network inputs for a noisy state.

        Combines the output of ``self.extra_features`` (computed on the raw
        noisy data) with the output of ``self.domain_features`` (computed after
        ``ref_dist.ignore_virtual_classes`` has been applied to ``X_t``,
        ``E_t`` and ``y_t``), and appends the time ``t`` to the global part.

        Args:
            noisy_data: Dict holding at least ``X_t``, ``E_t``, ``y_t``, ``t``.

        Returns:
            A ``utils.PlaceHolder`` with the concatenated ``X``, ``E`` and
            ``y`` extras.
        """
        general = self.extra_features(noisy_data)

        stripped_X, stripped_E, stripped_y = self.ref_dist.ignore_virtual_classes(
            noisy_data["X_t"],
            noisy_data["E_t"],
            noisy_data["y_t"],
        )
        # Shallow copy with the three state entries replaced; the caller's
        # dict is left untouched.
        domain_input = {
            **noisy_data,
            "X_t": stripped_X,
            "E_t": stripped_E,
            "y_t": stripped_y,
        }
        domain = self.domain_features(domain_input)

        extra_X = torch.cat((general.X, domain.X), dim=-1)
        extra_E = torch.cat((general.E, domain.E), dim=-1)
        extra_y = torch.cat((general.y, domain.y), dim=-1)

        # Time is appended as the last global feature(s).
        extra_y = torch.cat((extra_y, noisy_data["t"]), dim=1)

        return utils.PlaceHolder(X=extra_X, E=extra_E, y=extra_y)



