import abc
from abc import ABC
from typing import List

import torch
from torch import nn

from models.ClassificationModel import ClassProbabilities
from models.imputation_classifiers.ClassificationProbabilityEstimatorWithProxy import ClassificationProbabilityEstimatorWithProxy
from models.networks import BaseModel


class NetworkProbabilityEstimatorWithProxy(ClassificationProbabilityEstimatorWithProxy):
    """Neural-network class-probability estimator that takes ``y[:, 0]`` as an extra input feature.

    The network input is ``x`` concatenated with column 0 of ``y`` (hence ``x_dim + 1`` input
    features). The network is trained with cross-entropy against column 1 of ``y``.
    """

    def __init__(self, dataset_name, saved_models_path, x_dim: int, n_classes: int,
                 hidden_dims: List[int] = None, dropout: float = 0.1,
                 batch_norm: bool = False,
                 lr: float = 1e-3, wd: float = 0., device='cpu', figures_dir=None,
                 seed=0):
        """Build the network, its Adam optimizer and the cross-entropy loss.

        Args:
            dataset_name: forwarded to the base estimator.
            saved_models_path: forwarded to the base estimator.
            x_dim: number of features in ``x``; the network gets one extra input for ``y[:, 0]``.
            n_classes: number of output logits produced by the network.
            hidden_dims: hidden layer sizes passed to ``BaseModel``; defaults to ``[32, 64, 64, 32]``.
            dropout: passed to ``BaseModel`` as ``dropout``.
            batch_norm: passed to ``BaseModel`` as ``batch_norm``.
            lr: Adam learning rate.
            wd: Adam weight decay.
            device: device the network is moved to.
            figures_dir: forwarded to the base estimator.
            seed: forwarded to the base estimator.
        """
        ClassificationProbabilityEstimatorWithProxy.__init__(self, dataset_name, saved_models_path, figures_dir, seed)

        if hidden_dims is None:
            hidden_dims = [32, 64, 64, 32]
        # +1 input feature: y[:, 0] is appended to x before every forward pass.
        self._network = BaseModel(x_dim + 1, n_classes, hidden_dims=hidden_dims, dropout=dropout,
                                  batch_norm=batch_norm).to(device)
        # NOTE(review): relies on the base class exposing `network` (presumably returning
        # self._network) -- confirm in ClassificationProbabilityEstimatorWithProxy.
        self._optimizer = torch.optim.Adam(self.network.parameters(), lr=lr, weight_decay=wd)
        self.cross_entropy_loss = nn.CrossEntropyLoss()
        self.lr = lr
        self.wd = wd

    @staticmethod
    def _append_proxy(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Return ``x`` with column 0 of ``y`` concatenated as one extra trailing feature."""
        return torch.cat([x, y[:, 0].unsqueeze(-1)], dim=-1)

    def loss(self, y: torch.Tensor, prediction, d, epoch, **kwargs):
        """Cross-entropy between the logits ``prediction`` and the class labels in ``y[:, 1]``.

        ``d``, ``epoch`` and ``kwargs`` are accepted for interface compatibility and are unused.
        """
        # Select the target column before casting; equivalent to y.long()[:, 1].
        return self.cross_entropy_loss(prediction, y[:, 1].long())

    def predict(self, x, **kwargs):
        """Return the network logits for ``x``, which must already include the appended column.

        Singleton dimensions are squeezed away as before. If the output is batched with a batch
        size of 1, the batch dimension is kept, so ``nn.CrossEntropyLoss`` still receives
        ``(1, n_classes)`` logits alongside a ``(1,)`` target.
        """
        output = self.network(x)
        squeezed = output.squeeze()
        # Assumes dim 0 of a multi-dimensional output is the batch dimension -- TODO confirm
        # against BaseModel. A bare squeeze() would drop it for a single-sample mini-batch.
        if output.dim() > 1 and output.shape[0] == 1:
            squeezed = squeezed.unsqueeze(0)
        return squeezed

    @property
    def name(self) -> str:
        return "network_pe"

    def estimate_probabilities(self, x: torch.Tensor, y: torch.Tensor) -> ClassProbabilities:
        """Return softmax class probabilities for ``x`` given the extra input column ``y[:, 0]``."""
        model_input = self._append_proxy(x, y)
        # Call the module itself (not .forward) so registered nn.Module hooks run.
        logits = self.network(model_input)
        return ClassProbabilities(torch.softmax(logits, -1))

    def fit(self, x_train, y_train, deleted_train, x_val, y_val, deleted_val, epochs=1000, batch_size=64, n_wait=20,
            **kwargs):
        """Append ``y[:, 0]`` to the train/validation features, then delegate to the base ``fit``.

        Extra keyword arguments are accepted for interface compatibility but are not forwarded.
        """
        x_train = self._append_proxy(x_train, y_train)
        x_val = self._append_proxy(x_val, y_val)
        super().fit(x_train, y_train, deleted_train, x_val, y_val, deleted_val, epochs, batch_size, n_wait)

    
