from collections.abc import Mapping
from enum import Enum

import torch
import torch.nn as nn
import torch.nn.functional as F
from skorch import NeuralNet, NeuralNetClassifier, NeuralNetRegressor
from skorch.utils import to_tensor

class SampleWeightedClassifier(NeuralNetClassifier):
    """Classifier whose loss is weighted per sample before being averaged.

    The weights are read from the ``'sample_weight'`` entry of the batch
    input ``X``, so ``X`` must support lookup by that key.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def get_loss(self, y_pred, y_true, X=None, training=False):
        """Return the mean of the sample-weighted loss.

        Parameters
        ----------
        y_pred
            Module output for the batch.
        y_true
            Targets for the batch.
        X
            Batch input; ``X['sample_weight']`` supplies the weights.
        training : bool
            Forwarded unchanged to the parent ``get_loss``.

        Returns
        -------
        The mean over ``sample_weight * loss`` as computed by the parent
        ``get_loss``.
        """
        # Forward the real batch and training flag instead of hard-coded
        # defaults so the parent implementation sees what the caller passed.
        loss_unreduced = super().get_loss(y_pred, y_true, X=X, training=training)
        # NOTE(review): the weighting is only per-sample if the criterion
        # returns unreduced losses (presumably ``criterion__reduction='none'``)
        # whose shape matches ``sample_weight`` -- confirm at the call site.
        sample_weight = to_tensor(X['sample_weight'], device=self.device)
        return (sample_weight * loss_unreduced).mean()


class RLearnerNN(NeuralNetClassifier):
    """skorch classifier with one optimizer per trainable sub-module.

    ``initialize_optimizer`` builds two optimizers:

    * ``prop_optimizer_`` over ``module_.propensity_nuisance`` parameters,
      using the ``torch.optim`` class named by ``out_optim``;
    * ``optimizer_`` over ``module_.covariate_mapper`` parameters, using
      ``self.optimizer``.

    Only the optimizer selected via ``switch_optimizer`` is stepped.

    Parameters
    ----------
    out_optim : str
        Name of a ``torch.optim`` class (looked up with ``getattr``) used
        for the propensity-nuisance optimizer.
    """

    def __init__(self, *args, out_optim="SGD", **kwargs):
        # BCELoss stays the default, but use setdefault so a caller that
        # supplies ``criterion`` explicitly does not hit a duplicate-keyword
        # TypeError.
        kwargs.setdefault("criterion", nn.BCELoss)
        super().__init__(*args, **kwargs)
        self.nuisance_treat_optim_name = out_optim
        # Key of ``optimizer_name_mapping_`` currently selected, or None.
        self.curr_optimizer = None
        # Maps the public optimizer key to the attribute name (without the
        # trailing underscore) that holds the optimizer instance.
        # NOTE: only these two optimizers are wired up; separate optimizers
        # for propensity features and the outcome model are not implemented.
        self.optimizer_name_mapping_ = {
            "propensity": "prop_optimizer",
            "tau": "optimizer",
        }

    def switch_optimizer(self, opt):
        """Select which optimizer ``_step_optimizer`` will update.

        Parameters
        ----------
        opt : str
            One of the keys of ``optimizer_name_mapping_``.

        Raises
        ------
        ValueError
            If ``opt`` is not a known optimizer key.
        """
        if opt not in self.optimizer_name_mapping_:
            raise ValueError(f"Attempting to switch to unrecognized optimizer. Known optimizers are {self.optimizer_name_mapping_.keys()}")
        self.curr_optimizer = opt

    def initialize_optimizer(self, triggered_directly=None):
        """Create ``prop_optimizer_`` and ``optimizer_``.

        Parameters
        ----------
        triggered_directly
            Unused; accepted only so the override stays call-compatible with
            the ``initialize_optimizer`` signature of the skorch base class.

        Returns
        -------
        self
        """
        named_params = self.module_.propensity_nuisance.named_parameters()
        args, kwargs = self.get_params_for_optimizer("prop_optimizer", named_params)
        prop_optim_cls = getattr(torch.optim, self.nuisance_treat_optim_name)
        self.prop_optimizer_ = prop_optim_cls(*args, **kwargs)

        named_params = self.module_.covariate_mapper.named_parameters()
        args, kwargs = self.get_params_for_optimizer("optimizer", named_params)
        self.optimizer_ = self.optimizer(*args, **kwargs)
        return self

    def _step_optimizer(self, step_fn):
        """Step only the currently selected optimizer.

        Parameters
        ----------
        step_fn : callable or None
            Closure passed to ``optimizer.step``; when None, ``step`` is
            called without arguments.

        Raises
        ------
        RuntimeError
            If no optimizer has been selected with ``switch_optimizer``.
        """
        if self.curr_optimizer is None:
            raise RuntimeError("Must manually select an optimizer to update via `switch_optimizer()`.")
        name = self.optimizer_name_mapping_[self.curr_optimizer]
        optimizer = getattr(self, name + '_')
        if step_fn is None:
            optimizer.step()
        else:
            optimizer.step(step_fn)
        

class DragonNetWrapper(NeuralNetRegressor):
    """skorch wrapper that combines a treatment loss with per-arm outcome losses.

    ``alpha`` scales the treatment loss in ``get_loss``. ``beta`` and ``eps``
    are stored on the instance but are not read by ``get_loss``.
    """

    def __init__(self, *args, alpha=1.0, beta=1.0, eps=1.0e-3, **kwargs):
        super().__init__(*args, **kwargs)
        self.alpha, self.beta, self.eps = alpha, beta, eps

    def get_loss(self, y_pred, y_true, X=None, training=False):
        """Return ``outcome_loss + alpha * treatment_loss``.

        Parameters
        ----------
        y_pred : tuple
            ``(treat_out, head_outputs)``; ``head_outputs`` is indexed by
            treatment label and column 1 of each entry is used as the
            predicted probability for binary cross-entropy.
        y_true
            Outcome targets; cast to float for the outcome loss.
        X
            Batch input; ``X["T_"]`` supplies the observed treatment labels.
        training : bool
            Not used by this implementation.
        """
        propensity_out, outcome_heads = y_pred
        targets = to_tensor(y_true, device=self.device)
        treatments = to_tensor(X["T_"], device=self.device)

        # NOTE(review): F.cross_entropy expects unnormalised logits; confirm
        # the module is configured to emit logits for the treatment output.
        treatment_loss = F.cross_entropy(propensity_out, treatments).sum()

        # One binary cross-entropy term per treatment arm present in the
        # batch, evaluated only on the samples that received that arm.
        outcome_loss = 0.
        for arm in treatments.unique():
            in_arm = treatments == arm
            arm_positive_prob = outcome_heads[arm][in_arm, 1]
            arm_targets = targets[in_arm].float()
            outcome_loss = outcome_loss + F.binary_cross_entropy(arm_positive_prob, arm_targets)

        return outcome_loss + self.alpha * treatment_loss


class BaseNN(nn.Module):
    """MLP with three ReLU hidden layers and a two-unit output layer.

    Parameters
    ----------
    input_size : int
        Number of input features.
    hidden : int
        Width of each hidden layer.
    return_logits : bool
        When True ``forward`` returns the raw output-layer values; otherwise
        it returns their softmax over the last dimension.
    """

    def __init__(self, input_size=20, hidden=300, return_logits=False):  # original: 300
        super().__init__()
        self.return_logits = return_logits

        self.fc1 = nn.Linear(input_size, hidden)
        self.fc2 = nn.Linear(hidden, hidden)
        self.fc3 = nn.Linear(hidden, hidden)
        self.out = nn.Linear(hidden, 2)

    def forward(self, X_, **kwargs):
        """Run the network on ``X_``; extra keyword arguments are ignored."""
        activations = X_
        for layer in (self.fc1, self.fc2, self.fc3):
            activations = F.relu(layer(activations))
        scores = self.out(activations)
        if self.return_logits:
            return scores
        return F.softmax(scores, dim=-1)

class FeatureMapper(nn.Module):
    """MLP mapping inputs to a ``feature_dim``-dimensional vector.

    Three ReLU hidden layers of width ``hidden`` are followed by a linear
    output layer with no activation.
    """

    def __init__(self, input_size=20, hidden=300, feature_dim=10):  # original: 300
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden)
        self.fc2 = nn.Linear(hidden, hidden)
        self.fc3 = nn.Linear(hidden, hidden)
        self.out = nn.Linear(hidden, feature_dim)

    def forward(self, X_, **kwargs):
        """Return the un-activated output features for ``X_``.

        Extra keyword arguments are accepted and ignored.
        """
        h1 = F.relu(self.fc1(X_))
        h2 = F.relu(self.fc2(h1))
        h3 = F.relu(self.fc3(h2))
        features = self.out(h3)
        return features


class DragonNet(nn.Module):
    """
        Implementation based on https://github.com/farazmah/dragonnet-pytorch/blob/master/dragonnet/model.py

        A shared three-layer ReLU trunk feeds one treatment head and
        ``n_treatments`` outcome heads.

        Parameters
        ----------
        input_size : int
            Number of input features.
        hidden : int
            Width of the shared trunk layers.
        outcome_hidden : int
            Width of the hidden layers in each outcome head.
        treatment_hidden : int
            Width of the hidden layer in the treatment head.
        n_treatments : int
            Number of treatment classes, and number of outcome heads.
        return_logits : bool
            Stored on the instance; ``forward`` does not read it (outcome
            heads always end in a softmax).
        return_treatment_logits : bool
            When True ``forward`` returns the raw treatment-head output;
            otherwise its softmax over the last dimension.
    """
    def __init__(self, input_size=20, hidden=300, outcome_hidden=100, treatment_hidden=300, n_treatments=10, return_logits=False, return_treatment_logits=False):
        super().__init__()
        self.return_logits = return_logits
        self.return_treatment_logits = return_treatment_logits

        self.fc1 = nn.Linear(in_features=input_size, out_features=hidden)
        self.fc2 = nn.Linear(in_features=hidden, out_features=hidden)
        self.fc3 = nn.Linear(in_features=hidden, out_features=hidden)
        self.heads = nn.ModuleList()

        self.treatment = nn.Sequential(
            nn.Linear(in_features=hidden, out_features=treatment_hidden),
            nn.ReLU(),
            nn.Linear(in_features=treatment_hidden, out_features=n_treatments),
        ) # we know treatment assignment is not linear in the covariates

        # One outcome head per treatment; each emits a 2-class softmax.
        for _ in range(n_treatments):
            head = nn.Sequential(
                nn.Linear(in_features=hidden, out_features=outcome_hidden),
                nn.ReLU(),
                nn.Linear(in_features=outcome_hidden, out_features=outcome_hidden),
                nn.ReLU(),
                nn.Linear(in_features=outcome_hidden, out_features=2),
                nn.Softmax(dim=-1),
            )
            self.heads.append(head)

    def forward(self, X_, **kwargs):
        """Return ``(treat_out, head_outputs)`` for the batch ``X_``.

        ``treat_out`` is the treatment-head output (logits or softmax
        probabilities, depending on ``return_treatment_logits``) and
        ``head_outputs`` is a list with one softmax output per outcome head.
        Extra keyword arguments are ignored.
        """
        z = F.relu(self.fc1(X_))
        z = F.relu(self.fc2(z))
        z = F.relu(self.fc3(z))

        # The treatment head is ``self.treatment``; the logits branch used to
        # call an undefined ``self.treat_out`` layer and raised AttributeError.
        treat_logits = self.treatment(z)
        if self.return_treatment_logits:
            treat_out = treat_logits
        else:
            treat_out = F.softmax(treat_logits, dim=-1)

        head_outputs = [head(z) for head in self.heads]

        return treat_out, head_outputs # no tarreg -- the point is to check a multimask approach


class TreatmentEmbedder(nn.Module):
    """Single linear layer projecting ``input_size`` features to one value.

    ``hidden`` is accepted by the constructor but not used.
    """

    def __init__(self, input_size=10, hidden=60):
        super().__init__()
        self.fc1 = nn.Linear(input_size, 1)

    def forward(self, x):
        """Apply the linear projection to ``x``."""
        return self.fc1(x)


class RLearnerWrapper(nn.Module):
    """Module combining two trainable mappers with two attached nuisance models.

    ``propensity_nuisance`` and ``covariate_mapper`` are trainable
    ``FeatureMapper`` networks. The outcome nuisance model and the propensity
    featurizer are not created here; they must be supplied through
    ``attach_outcome_nuisance`` and ``attach_propensity_featurizer`` before
    ``forward`` is called.

    Parameters
    ----------
    input_size : int
        Input width of ``covariate_mapper``.
    treatment_input_size : int
        Input width of ``propensity_nuisance``.
    hidden : int
        Hidden width of both mappers.
    feature_dim : int
        Output width of both mappers.
    """

    def __init__(self, input_size=100, treatment_input_size=10, hidden=300, feature_dim=10):
        super().__init__()
        # Attached later; expected to expose ``predict_proba``.
        self.outcome_nuisance = None  # BaseNN(input_size=input_size, hidden=hidden)
        # Attached later; expected to expose ``predict``.
        self.prop_feat_model = None
        self.propensity_nuisance = FeatureMapper(input_size=treatment_input_size, feature_dim=feature_dim, hidden=hidden) # h(T)
        self.covariate_mapper = FeatureMapper(input_size=input_size, feature_dim=feature_dim, hidden=hidden) # tau

    def forward(self, X_, T_, **kwargs):
        """Return the prediction for covariates ``X_`` and treatments ``T_``.

        Computes ``sum(tau(X_) * (h(T_) - prop_features), dim=-1)`` plus
        column 1 of the outcome model's ``predict_proba`` output, clipped to
        ``[0, 1]``. Extra keyword arguments are ignored.

        Raises
        ------
        RuntimeError
            If either nuisance model has not been attached.
        """
        # Fail with an actionable message rather than an AttributeError on
        # ``None`` when the nuisance models were never attached.
        if self.outcome_nuisance is None or self.prop_feat_model is None:
            raise RuntimeError(
                "Nuisance models are not attached; call `attach_outcome_nuisance()` "
                "and `attach_propensity_featurizer()` before `forward()`."
            )

        # Nuisance predictions come back as NumPy arrays and take no part in
        # gradient computation.
        with torch.no_grad():
            outcome = torch.from_numpy(self.outcome_nuisance.predict_proba(X_)).to(X_.device)
            propensity_features = torch.from_numpy(self.prop_feat_model.predict(X_)).to(X_.device)
        propensity_nuisance = self.propensity_nuisance(T_)
        covariate_features = self.covariate_mapper(X_)
        Y_ = (covariate_features * (propensity_nuisance - propensity_features)).sum(dim=-1) + outcome[:, 1]
        return torch.clip(Y_, 0., 1.)

    def attach_propensity_featurizer(self, model):
        """Store ``model`` as the propensity featurizer used by ``forward``."""
        self.prop_feat_model = model

    def attach_outcome_nuisance(self, model):
        """Store ``model`` as the outcome nuisance model used by ``forward``."""
        self.outcome_nuisance = model
