import torch


class Args:
    """Hard-coded default run settings, exposed as plain instance attributes.

    Every instance starts from the same values; callers are free to
    overwrite individual attributes after construction.
    """

    def __init__(self):
        defaults = {
            # device
            "device": 0,
            # model
            "gnn": "gin-virtual",
            "drop_ratio": 0.5,
            "num_layer": 3,
            "emb_dim": 256,
            "use_linear_predictor": False,
            "gamma": 0.4,
            # training
            "batch_size": 512,
            "epochs": 200,
            "patience": 50,
            "lr": 1e-3,
            "l2reg": 1e-8,
            "use_lr_scheduler": False,
            "use_clip_norm": False,
            "path_list": [1, 4],
            "initw_name": "default",
            "base_dir": "data_cache/polymer",
            # dataset
            "trials": 2,
        }
        # The dict is rebuilt on every call, so mutable values such as
        # ``path_list`` are never shared between instances.
        for name, value in defaults.items():
            setattr(self, name, value)


def get_args():
    """Return a freshly constructed ``Args`` holding the default settings."""
    args = Args()
    return args


class WMAELoss(torch.nn.Module):
    """Weighted mean absolute error over five polymer properties.

    Each property's absolute error is divided by that property's value range
    (max - min from ``MINMAX_DICT``), averaged over the entries whose target
    is not NaN, and weighted by ``1 / sqrt(number of valid entries)``. The
    result is the weighted sum divided by the total weight of the properties
    that have at least one valid target.
    """

    def __init__(self):
        super().__init__()
        # These values are from the train data.
        self.MINMAX_DICT = {
            "Tg": [-148.0297376, 472.25],
            "FFV": [0.2269924, 0.77709707],
            "Tc": [0.0465, 0.524],
            "Density": [0.748691234, 1.840998909],
            "Rg": [9.7283551, 34.672905605],
        }
        self.property_names = ["Tg", "FFV", "Tc", "Density", "Rg"]

        # Precompute property ranges for scaling. Registered as a buffer so
        # that ``module.to(device)`` moves it with the module; non-persistent
        # so the module's ``state_dict`` stays empty, as before.
        self.register_buffer(
            "property_ranges",
            torch.tensor(
                [
                    self.MINMAX_DICT[prop][1] - self.MINMAX_DICT[prop][0]
                    for prop in self.property_names
                ]
            ),
            persistent=False,
        )

    def forward(self, predictions, targets):
        """Calculate the weighted MAE loss.

        Args:
            predictions: tensor of shape (batch_size, 5)
            targets: tensor of shape (batch_size, 5); NaN marks a missing
                label and is excluded from the loss.

        Returns:
            A scalar tensor. When no target is valid, a fresh zero tensor
            with ``requires_grad=True`` that is not connected to
            ``predictions``.
        """
        device = predictions.device
        # Follow the predictions' device even if the module itself was never
        # moved; the moved copy is cached for subsequent calls.
        if self.property_ranges.device != device:
            self.property_ranges = self.property_ranges.to(device)
        ranges = self.property_ranges

        valid_mask = ~torch.isnan(targets)
        valid_counts = valid_mask.sum(dim=0).float()
        has_valid = valid_counts > 0

        if not bool(has_valid.any()):
            return torch.tensor(0.0, device=device, requires_grad=True)

        # Neutralise missing targets *before* the subtraction so that NaN
        # never enters the autograd graph.
        safe_targets = torch.where(valid_mask, targets, torch.zeros_like(targets))
        scaled_errors = torch.abs(predictions - safe_targets) / ranges.unsqueeze(0)
        masked_errors = torch.where(
            valid_mask, scaled_errors, torch.zeros_like(scaled_errors)
        )

        # Per-property mean over valid rows; the clamp only protects the
        # division for empty columns, whose weight is zeroed just below.
        safe_counts = valid_counts.clamp(min=1.0)
        property_maes = masked_errors.sum(dim=0) / safe_counts

        # Any common rescaling of the weights cancels in the ratio below, so
        # no intermediate normalisation is needed.
        property_weights = torch.where(
            has_valid, torch.rsqrt(safe_counts), torch.zeros_like(safe_counts)
        )

        return (property_maes * property_weights).sum() / property_weights.sum()


# Module-level loss instance; train_with_loss looks this name up at call time.
criterion = WMAELoss()


def train_with_loss(args, model, device, loader, optimizers, optimizer_name):
    """Run one training pass over ``loader`` with the selected optimizer.

    Parameters:
        args            -- settings object; only ``use_clip_norm`` is read here
        model           -- network exposing ``graph_encoder``, ``predictor`` and
                           ``separator`` sub-modules; calling it on a batch must
                           return a mapping with ``pred_rem``, ``pred_rep`` and
                           ``loss_reg`` entries
        device          -- device each batch is moved to
        loader          -- iterable of batches
        optimizers      -- mapping from optimizer name to optimizer
        optimizer_name  -- key into ``optimizers``; "predictor" trains the
                           encoder and predictor with the separator frozen,
                           "separator" does the opposite

    Returns:
        Mean loss over the batches that were actually processed, or 0 when
        every batch was skipped.
    """
    optimizer = optimizers[optimizer_name]
    model.train()
    if optimizer_name == "predictor":
        set_requires_grad([model.graph_encoder, model.predictor], requires_grad=True)
        set_requires_grad([model.separator], requires_grad=False)
    if optimizer_name == "separator":
        set_requires_grad([model.separator], requires_grad=True)
        set_requires_grad([model.graph_encoder, model.predictor], requires_grad=False)

    total_loss = 0
    num_batches = 0

    for batch in loader:
        batch = batch.to(device)

        # Skip degenerate batches: a single row in ``x``, or a last batch
        # index of 0 (presumably a batch holding one graph -- confirm
        # against the data loader).
        if batch.x.shape[0] == 1 or batch.batch[-1] == 0:
            continue

        optimizer.zero_grad()
        pred = model(batch)

        target = batch.y.to(torch.float32)
        # Each target row is repeated (last batch index + 1) times to line up
        # with ``pred_rep``.
        target_rep = target.repeat_interleave(batch.batch[-1] + 1, dim=0)

        # Out-of-place additions: ``criterion`` may return a leaf tensor with
        # requires_grad=True (when no target is valid), and an in-place
        # ``+=`` on such a tensor raises a RuntimeError.
        loss = criterion(pred["pred_rem"].to(torch.float32), target)
        loss = loss + criterion(pred["pred_rep"].to(torch.float32), target_rep)

        if optimizer_name == "separator":
            loss = loss + pred["loss_reg"]

        total_loss += loss.item()
        num_batches += 1

        loss.backward()
        if args.use_clip_norm:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

    return total_loss / num_batches if num_batches > 0 else 0


def eval(args, model, device, loader, evaluator):
    """Collect predictions over ``loader`` and score them with ``evaluator``.

    Note that this module-level name shadows the ``eval`` builtin.

    Parameters:
        args       -- unused in this function
        model      -- network providing ``eval()`` and ``eval_forward(batch)``
        device     -- device each batch is moved to
        loader     -- iterable of batches with ``x`` and ``y`` attributes
        evaluator  -- object whose ``eval`` method receives a dict with the
                      NumPy arrays ``y_true`` and ``y_pred``

    Returns:
        Whatever ``evaluator.eval`` returns.
    """
    model.eval()
    true_chunks = []
    pred_chunks = []

    for batch in loader:
        batch = batch.to(device)

        # Batches with a single row in ``x`` contribute nothing.
        if batch.x.shape[0] == 1:
            continue

        with torch.no_grad():
            output = model.eval_forward(batch)

        # Labels are reshaped to the prediction's shape before being stored.
        true_chunks.append(batch.y.view(output.shape).detach().cpu())
        pred_chunks.append(output.detach().cpu())

    # NOTE(review): if every batch is skipped both lists are empty and
    # torch.cat will presumably raise -- confirm whether callers rely on that.
    return evaluator.eval(
        {
            "y_true": torch.cat(true_chunks, dim=0).numpy(),
            "y_pred": torch.cat(pred_chunks, dim=0).numpy(),
        }
    )


def init_weights(net, init_type="normal", init_gain=0.02):
    """Initialize network weights in place.

    Parameters:
        net (network)     -- network to be initialized
        init_type (str)   -- the name of an initialization method:
                             normal | xavier | kaiming | orthogonal | default
        init_gain (float) -- scaling factor for normal, xavier and orthogonal.

    Modules whose class name contains "Conv" or "Linear" get their weight
    initialized with the chosen method and their bias (if any) set to zero.
    With ``init_type="default"`` the weight is left untouched, but the bias
    is still zeroed. Modules whose class name contains "BatchNorm2d" get a
    weight drawn from N(1, init_gain) and a zero bias, whatever ``init_type``
    is.

    Raises:
        NotImplementedError -- for an unknown ``init_type``; it is only
                               raised once a Conv/Linear module is reached.
    """

    def init_func(m):  # define the initialization function
        classname = m.__class__.__name__
        is_conv_or_linear = "Conv" in classname or "Linear" in classname

        if hasattr(m, "weight") and is_conv_or_linear:
            if init_type == "normal":
                torch.nn.init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == "xavier":
                torch.nn.init.xavier_normal_(m.weight.data, gain=init_gain)
            elif init_type == "kaiming":
                torch.nn.init.kaiming_normal_(m.weight.data, a=0, mode="fan_in")
            elif init_type == "orthogonal":
                torch.nn.init.orthogonal_(m.weight.data, gain=init_gain)
            elif init_type == "default":
                pass
            else:
                raise NotImplementedError(
                    "initialization method [%s] is not implemented" % init_type
                )
            if hasattr(m, "bias") and m.bias is not None:
                torch.nn.init.constant_(m.bias.data, 0.0)
        elif "BatchNorm2d" in classname:
            # BatchNorm Layer's weight is not a matrix; only normal
            # distribution applies. With affine=False both weight and bias
            # are None, so each is guarded instead of dereferenced blindly.
            if getattr(m, "weight", None) is not None:
                torch.nn.init.normal_(m.weight.data, 1.0, init_gain)
            if getattr(m, "bias", None) is not None:
                torch.nn.init.constant_(m.bias.data, 0.0)

    net.apply(init_func)  # apply the initialization function <init_func>


def set_requires_grad(nets, requires_grad=False):
    """Set ``requires_grad`` on every parameter of the given network(s).

    Freezing networks (``requires_grad=False``) avoids unnecessary gradient
    computations.

    Parameters:
        nets (network list)   -- a list of networks, or a single network;
                                 ``None`` entries are skipped
        requires_grad (bool)  -- whether the networks require gradients or not
    """
    networks = nets if isinstance(nets, list) else [nets]
    for module in networks:
        if module is None:
            continue
        for parameter in module.parameters():
            parameter.requires_grad = requires_grad


# The code is from torch_scatter: https://github.com/rusty1s/pytorch_scatter/blob/1.3.0/torch_scatter/add.py
from itertools import repeat


def maybe_dim_size(index, dim_size=None):
    """Return ``dim_size`` if given, else the smallest size that fits ``index``.

    The inferred size is ``index.max() + 1``, or 0 for an empty ``index``.
    """
    if dim_size is None:
        if index.numel() == 0:
            return 0
        return index.max().item() + 1
    return dim_size


def gen(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
    """Prepare the arguments for a scatter operation.

    Returns the tuple ``(src, out, index, dim)`` where ``dim`` is
    non-negative, a 1-D ``index`` has been expanded to the shape of ``src``,
    and ``out`` has been allocated (filled with ``fill_value``) if it was not
    supplied. ``dim_size`` and ``fill_value`` are ignored when ``out`` is
    given.
    """
    # Indexing a range turns a negative dim into its non-negative equivalent
    # and raises IndexError for an out-of-range one.
    dim = range(src.dim())[dim]

    # A 1-D index is viewed as size 1 everywhere except along ``dim`` and
    # then broadcast to the full shape of ``src``.
    if index.dim() == 1:
        view_shape = [1] * src.dim()
        view_shape[dim] = src.size(dim)
        index = index.view(view_shape).expand_as(src)

    # Allocate the destination only when the caller did not pass one.
    if out is None:
        target_shape = list(src.size())
        target_shape[dim] = maybe_dim_size(index, dim_size)
        out = src.new_full(target_shape, fill_value)

    return src, out, index, dim


def scatter_add(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
    r"""Sum the values of ``src`` into ``out`` at the positions in ``index``.

    Along axis ``dim`` each value of ``src`` is added to the output slot named
    by the corresponding entry of ``index``; for all other axes the value
    keeps its own position. Values that map to the same slot **add up**.

    If ``src`` and ``index`` are n-dimensional with size
    :math:`(x_0, ..., x_{i-1}, x_i, x_{i+1}, ..., x_{n-1})` and ``dim`` is
    :math:`i`, then ``out`` must be n-dimensional with size
    :math:`(x_0, ..., x_{i-1}, y, x_{i+1}, ..., x_{n-1})`, and every value of
    ``index`` must lie between ``0`` and ``out.size(dim) - 1``.

    In one dimension the operation computes

    .. math::
        \mathrm{out}_i = \mathrm{out}_i + \sum_j \mathrm{src}_j

    with the sum running over all :math:`j` for which
    :math:`\mathrm{index}_j = i`.

    Args:
        src (Tensor): The source tensor.
        index (LongTensor): The indices of elements to scatter. A 1-D index
            is broadcast to the shape of ``src``.
        dim (int, optional): The axis along which to index. (default: ``-1``)
        out (Tensor, optional): The destination tensor. When given it is
            modified in place and returned. (default: ``None``)
        dim_size (int, optional): Size of the created output along ``dim``
            when ``out`` is not given. If omitted as well, the smallest
            output that fits ``index`` is created. (default: ``None``)
        fill_value (int, optional): Initial value of the created output when
            ``out`` is not given. (default: ``0``)

    Returns:
        Tensor: the destination tensor holding the accumulated sums.

    Example::

        src = torch.Tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
        index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
        out = src.new_zeros((2, 6))

        out = scatter_add(src, index, out=out)
        # tensor([[0., 0., 4., 3., 3., 0.],
        #         [2., 4., 4., 0., 0., 0.]])
    """
    prepared = gen(
        src, index, dim=dim, out=out, dim_size=dim_size, fill_value=fill_value
    )
    values, target, expanded_index, axis = prepared
    # scatter_add_ accumulates in place and hands back the same tensor.
    return target.scatter_add_(axis, expanded_index, values)
