from model.modules import LatentModule, StoNet_Model
import torch
import torch.nn as nn
import numpy as np
from torch.distributions.normal import Normal
from torch.utils.tensorboard import SummaryWriter
from sklearn.metrics import recall_score


class StoNet_Proxy(nn.Module):
    """StoNet with a proxy input: x -> Z (confounder layer) -> A (treatment layer) -> Y.

    The network is split into three ``LatentModule`` stages:

    * ``module_xz`` maps the observed proxy ``x`` to the confounder layer Z;
    * ``module_za`` maps Z to the treatment layer A, whose column(s)
      ``treat_node`` hold the treatment logit(s);
    * ``module_ay`` maps A -- with the treatment column(s) overwritten by the
      given treatment -- to the output layer.

    Args:
        hidden_dim: widths of the hidden layers (any sequence; not modified).
        input_dim: dimension of the proxy input ``x``.
        output_dim: dimension of the output layer.
        confounder_layer: index (into ``hidden_dim``) of the confounder layer Z.
        treat_layer: index (into ``hidden_dim``) of the treatment layer A.
        treat_node: column of A holding a binary treatment, or a
            list/tuple/ndarray of columns for a multi-level treatment. The
            columns are assumed contiguous and ascending, because only the
            first and last entries are used to split A in ``likelihood_latent``.
        sigma_list: per-layer noise parameters. The squared errors in
            ``likelihood_latent`` are divided by ``2 * sigma``, so they act as
            variances. Required by ``likelihood_latent`` and
            ``backward_imputation``; when None, ``sigma_list`` is not set.
        CE_treat_weight: ``weight`` of ``CrossEntropyLoss`` (multi-level
            treatment) or ``pos_weight`` of ``BCEWithLogitsLoss`` (binary).
    """

    # Container types that mark ``treat_node`` as a multi-level treatment.
    _MULTI_TREAT_TYPES = (list, tuple, np.ndarray)

    def __init__(self, hidden_dim, input_dim, output_dim, confounder_layer, treat_layer, treat_node,
                 sigma_list=None, CE_treat_weight=None):
        super().__init__()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.hidden_dim = hidden_dim
        self.treat_layer = treat_layer
        self.treat_node = treat_node
        self.confounder_layer = confounder_layer

        # Widths of every layer after the input, ending with the output layer.
        layer_dim = list(hidden_dim) + [output_dim]

        self.module_xz = LatentModule(input_dim, layer_dim[:confounder_layer + 1], first_module=True)
        self.module_za = LatentModule(layer_dim[confounder_layer], layer_dim[confounder_layer + 1:treat_layer + 1])
        self.module_ay = LatentModule(layer_dim[treat_layer], layer_dim[treat_layer + 1:])

        self.module_list = nn.ModuleList([self.module_xz, self.module_za, self.module_ay])

        if sigma_list is not None:
            self.sigma_z = sigma_list[0]
            self.sigma_list = sigma_list

        # Gaussian part of the imputation likelihood (sum of squared errors).
        self.sse = nn.MSELoss(reduction='sum')

        # Treatment part of the imputation likelihood; both losses take logits.
        if self._multi_treat():
            self.treat_loss = nn.CrossEntropyLoss(weight=CE_treat_weight, reduction='sum')
        else:
            self.treat_loss = nn.BCEWithLogitsLoss(pos_weight=CE_treat_weight, reduction='sum')

    def _multi_treat(self):
        """Return True when ``treat_node`` describes a multi-level treatment."""
        return isinstance(self.treat_node, self._MULTI_TREAT_TYPES)

    def forward(self, x, treat):
        """Predict the output for ``x`` under treatment ``treat``.

        Args:
            x: proxy input.
            treat: values written into the treatment column(s) of A before
                the last stage is applied.

        Returns:
            (y, ps): the output of ``module_ay``; and the treatment
            probabilities predicted from Z (softmax over the treatment
            columns for a multi-level treatment, sigmoid for a binary one).
        """
        z = self.module_xz(x)
        a_temp = self.module_za(z)
        # Copy the treatment logits before the same column(s) are overwritten in place.
        logits = a_temp[:, self.treat_node].clone()
        ps = torch.softmax(logits, dim=1) if self._multi_treat() else torch.sigmoid(logits)
        a_temp[:, self.treat_node] = treat
        y = self.module_ay(a_temp)

        return y, ps

    def likelihood_latent(self, hidden_list, module_index, outcome_loss, y, input=None, forward_z=None,
                          treat_loss_weight=1):
        """Log-likelihood (up to constants) of one stage of the network.

        Args:
            hidden_list: ``[z, a]`` hidden-layer values.
            module_index: 0 -> log p(Z|X), 1 -> log p(A|Z), otherwise log p(Y|A).
            outcome_loss: loss used for the output stage, called as ``outcome_loss(pred, y)``.
            y: observed output.
            input: proxy ``x``; used for index 0 when ``forward_z`` is None
                (parameter update).
            forward_z: precomputed mean of Z; used for index 0 during imputation.
            treat_loss_weight: multiplier of the treatment term of log p(A|Z).

        Raises:
            ValueError: if ``module_index == 0`` and neither ``forward_z`` nor
                ``input`` is given.
        """
        if module_index == 0:  # log_likelihood(Z|X)
            if forward_z is not None:  # for imputation
                mean_z = forward_z
            elif input is not None:  # for parameter update
                mean_z = self.module_list[0](input)
            else:
                raise ValueError("likelihood_latent(module_index=0) requires `forward_z` or `input`")
            return -self.sse(mean_z, hidden_list[0]) / (2 * self.sigma_list[0])

        if module_index == 1:  # log_likelihood(A|Z)
            h = self.module_list[1](hidden_list[0])
            a = hidden_list[1]

            likelihood_treat = -self.treat_loss(h[:, self.treat_node], a[:, self.treat_node]) * treat_loss_weight

            if self._multi_treat():
                idx_start, idx_end = self.treat_node[0], self.treat_node[-1]
            else:
                idx_start = idx_end = self.treat_node

            # Units before and after the treatment column(s) get a Gaussian term around h.
            likelihood_rest_1 = -self.sse(h[:, :idx_start], a[:, :idx_start]) / (2 * self.sigma_list[1])
            likelihood_rest_2 = -self.sse(h[:, idx_end + 1:], a[:, idx_end + 1:]) / (2 * self.sigma_list[1])

            return likelihood_treat + likelihood_rest_1 + likelihood_rest_2

        # log_likelihood(Y|A, Z)
        return -outcome_loss(self.module_list[module_index](hidden_list[module_index - 1]), y) / (
                2 * self.sigma_list[-1])

    def backward_imputation(self, mh_step, impute_lrs, outcome_loss, x, treat, y, itas, treat_loss_weight=1):
        """Impute the hidden layers ``[z, a]`` given ``(x, treat, y)`` by SGHMC.

        The hidden units start at the forward-pass values (treatment
        column(s) of A set to ``treat``). They are then moved for ``mh_step``
        sweeps, A before Z, along the gradient of the two likelihood terms
        each layer appears in, plus momentum and Gaussian noise.

        Args:
            mh_step: number of sweeps.
            impute_lrs: per-layer step sizes ``[lr_z, lr_a]``.
            outcome_loss: loss passed to ``likelihood_latent`` for the output stage.
            x, treat, y: observed proxy, treatment and output.
            itas: per-layer friction terms; the momentum decay is ``lr * ita``.
            treat_loss_weight: multiplier of the treatment term of log p(A|Z).

        Returns:
            ``[z, a]``: imputed hidden tensors (leaves with ``requires_grad=True``).
            The treatment column(s) of ``a`` keep the values of ``treat``.
        """
        # Initialise the hidden units with a forward pass.
        z = self.module_xz(x).detach()
        a = self.module_za(z).detach()
        a[:, self.treat_node] = treat
        hidden_list = [z, a]
        momentum_list = [torch.zeros_like(hidden) for hidden in hidden_list]

        # The forward-pass mean of Z anchors log p(Z|X) throughout imputation.
        forward_z = z.clone()
        for hidden in hidden_list:
            hidden.requires_grad_(True)

        # backward imputation by SGHMC
        for _ in range(mh_step):
            for module_index in reversed(range(len(hidden_list))):
                hidden = hidden_list[module_index]
                hidden.grad = None

                # A layer appears in the likelihood of the stage above it and of its own stage.
                likelihood_above = self.likelihood_latent(hidden_list=hidden_list, module_index=module_index + 1,
                                                          outcome_loss=outcome_loss, y=y, forward_z=forward_z,
                                                          treat_loss_weight=treat_loss_weight)
                likelihood_own = self.likelihood_latent(hidden_list=hidden_list, module_index=module_index,
                                                        outcome_loss=outcome_loss, y=y, forward_z=forward_z,
                                                        treat_loss_weight=treat_loss_weight)
                likelihood_above.backward()
                likelihood_own.backward()

                lr = impute_lrs[module_index]
                alpha = lr * itas[module_index]

                with torch.no_grad():
                    # Noise is drawn on the hidden tensor's own device/dtype.
                    noise = torch.randn_like(hidden) * np.sqrt(2 * alpha)
                    momentum = (1 - alpha) * momentum_list[module_index] + lr * hidden.grad + noise
                    if module_index == 1:
                        # treatment node will not be updated
                        momentum[:, self.treat_node] = 0
                    momentum_list[module_index] = momentum
                    hidden.add_(lr * momentum)

        return hidden_list


class StoNet_Proxy_Model(StoNet_Model):
    """Training / evaluation wrapper around ``StoNet_Proxy``.

    Seeds numpy and torch, builds the network on its device, attaches a
    TensorBoard ``SummaryWriter`` and delegates the generic training loop to
    ``StoNet_Model`` (``_train_prep`` / ``_train``). This class adds the
    treatment-specific parts: treatment loss/accuracy, treatment variable
    selection and counterfactual prediction.
    """

    # Container types that mark ``treat_node`` as a multi-level treatment.
    _MULTI_TREAT_TYPES = (list, tuple, np.ndarray)

    def __init__(self, hidden_dim, input_dim, output_dim, confounder_layer, treat_layer, treat_node,
                 seed, log_dir=None, sigma_list=None, outcome_cat=False):
        np.random.seed(seed)
        torch.manual_seed(seed)
        model = StoNet_Proxy(hidden_dim, input_dim, output_dim, confounder_layer, treat_layer, treat_node, sigma_list)
        model.to(model.device)
        self.writer = SummaryWriter(log_dir=log_dir)

        super().__init__(model, outcome_cat)

    def _is_multi_level_treat(self):
        """Return True when the model's ``treat_node`` describes a multi-level treatment."""
        return isinstance(self.model.treat_node, self._MULTI_TREAT_TYPES)

    @staticmethod
    def _categorical_treat_loss(ps, treat):
        """Summed cross entropy between target distributions ``treat`` and predicted probabilities ``ps``."""
        # Clamp guards against log(0) when a predicted probability underflows.
        return -(treat * torch.log(ps.clamp_min(1e-12))).sum()

    def train_prep(self, treat_loss_weight, para_lrs_train, para_lrs_fine_tune):
        """Prepare training state and the treatment evaluation loss.

        ``self.treat_loss`` is called on the treatment *probabilities* returned by
        ``predict`` and is sum-reduced. ``calculate_loss`` divides its
        accumulated value by the dataset size, so the result is a per-sample
        average.
        """
        self._train_prep(para_lrs_train, para_lrs_fine_tune)

        self.input_gamma_path.update(var_selected_treat={}, num_selected_treat=[])

        self.backward_imputation_args.update(treat_loss_weight=treat_loss_weight)
        self.likelihood_latent_args.update(treat_loss_weight=treat_loss_weight)

        if self._is_multi_level_treat():
            # CrossEntropyLoss expects logits, but predict() returns softmax
            # probabilities, so the cross entropy is taken on the probabilities.
            self.treat_loss = self._categorical_treat_loss
        else:
            self.treat_loss = nn.BCELoss(reduction='sum')

    def train(self, mode, train_data, val_data, epochs, batch_size,
              impute_lrs, mh_step, itas, para_lrs_train, para_lrs_fine_tune,
              para_lr_decay, impute_lr_decay,
              prior_sigma_0, prior_sigma_1, treat_loss_weight, lambda_n,
              y_scale=1, CE_weight=None):
        """Run ``train_prep`` and then the shared ``StoNet_Model._train`` loop; flush the writer at the end."""
        self.train_prep(treat_loss_weight, para_lrs_train, para_lrs_fine_tune)
        self._train(mode, train_data, val_data, epochs, batch_size,
                    impute_lrs, mh_step, itas, para_lr_decay, impute_lr_decay,
                    prior_sigma_0, prior_sigma_1, lambda_n,
                    CE_weight, y_scale)

        self.writer.flush()

    def performance_eval(self, eval_set, data, loss_func, y_scale, epoch, mode):
        """Compute metrics on ``data``, log them to TensorBoard, and print the losses.

        ``eval_set`` is 'train' or 'val'; recall is logged only for 'val' with a
        categorical outcome.

        Returns:
            (out_loss, out_correct, treat_loss, treat_correct) from ``calculate_loss``.
        """
        out_loss, out_correct, treat_loss, treat_correct, recall_neg, recall_pos = self.calculate_loss(data, loss_func, y_scale)

        if eval_set == 'train':
            self.writer.add_scalar(mode + ": train_loss", out_loss, epoch)
            print(f"Avg train loss: {out_loss:>8f} \n")

            self.writer.add_scalar(mode + ": treat_train_loss", treat_loss, epoch)
            self.writer.add_scalar(mode + ": treat_train_acc", treat_correct, epoch)

            if self.outcome_cat:
                self.writer.add_scalar(mode + ": train_acc", out_correct, epoch)
                print(f"train accuracy: {out_correct:>8f} \n")

        elif eval_set == 'val':
            self.writer.add_scalar(mode + ": val_loss", out_loss, epoch)
            print(f"Avg val loss: {out_loss:>8f} \n")

            self.writer.add_scalar(mode + ": treat_val_loss", treat_loss, epoch)
            self.writer.add_scalar(mode + ": treat_val_acc", treat_correct, epoch)

            if self.outcome_cat:
                self.writer.add_scalar(mode + ": val_acc", out_correct, epoch)
                print(f"val accuracy: {out_correct:>8f} \n")
                self.writer.add_scalar(mode + ": val_recall_neg", recall_neg, epoch)
                self.writer.add_scalar(mode + ": val_recall_pos", recall_pos, epoch)
        return out_loss, out_correct, treat_loss, treat_correct

    def calculate_loss(self, data, loss_func, y_scale):
        """Evaluate outcome and treatment performance on ``data``.

        Args:
            data: iterable of ``(y, treat, x, *rest)`` batches with a sized
                ``data.dataset`` (e.g. a DataLoader).
            loss_func: outcome loss; presumably sum-reduced, since the total is
                divided by the dataset size (TODO confirm against callers).
            y_scale: factor applied to the RMSE for continuous outcomes.

        Returns:
            (out_loss, out_correct, treat_loss, treat_correct, recall_neg, recall_pos):
            - out_loss: RMSE * y_scale for continuous outcomes, or the average
              loss for categorical ones.
            - out_correct: outcome accuracy (0 for continuous outcomes).
            - treat_loss / treat_correct: average treatment loss and accuracy.
            - recall_neg / recall_pos: recall of outcome classes 0 and 1 over
              all of ``data`` (0 for continuous outcomes).
        """
        multi_treat = self._is_multi_level_treat()
        out_loss, out_correct, treat_loss, treat_correct = 0, 0, 0, 0
        y_all, pred_all = [], []
        with torch.no_grad():
            for y, treat, x, *_ in data:
                pred, ps = self.predict(x, treat)
                out_loss += loss_func(pred, y).item()
                treat_loss += self.treat_loss(ps, treat).item()  # ps holds probabilities, not logits
                if multi_treat:
                    treat_correct += (ps.argmax(dim=1) == treat.argmax(dim=1)).sum().item()
                else:
                    treat_correct += ((ps > 0.5) == treat).sum().item()
                if self.outcome_cat:
                    pred = pred.argmax(1)
                    out_correct += (pred == y).type(torch.float).sum().item()
                    # Collect labels on the CPU so recall covers the whole dataset.
                    y_all.append(y.cpu())
                    pred_all.append(pred.cpu())

        n = len(data.dataset)
        recall_neg, recall_pos = 0, 0
        if self.outcome_cat:
            out_loss /= n
            out_correct /= n
            if y_all:
                # labels=[0, 1] keeps both entries even if a class is absent.
                recall_neg, recall_pos = recall_score(torch.cat(y_all).numpy(), torch.cat(pred_all).numpy(),
                                                      labels=[0, 1], average=None, zero_division=0).tolist()
        else:
            # use RMSE as the model performance metric for regression tasks
            out_loss = np.sqrt(out_loss / n) * y_scale

        treat_loss /= n
        treat_correct /= n

        return out_loss, out_correct, treat_loss, treat_correct, recall_neg, recall_pos

    def selected_variable(self, epoch):
        """Record which input variables are connected to the outcome and to the treatment.

        The stored per-layer indicators ``self.para_gamma_path[str(epoch)]`` are
        chained by matrix products from the input layer upward. The treatment
        indicator is read off at layer ``treat_layer``: for a multi-level
        treatment, a variable counts only if it is connected to every treatment
        node. Results are appended to ``self.input_gamma_path`` and printed.

        Raises:
            ValueError: if no weight parameter corresponds to ``treat_layer``.
        """
        gamma = self.para_gamma_path[str(epoch)]
        var_ind_out = np.identity(len(gamma['0.module.0.weight'][0]), dtype=bool)
        var_ind_treat = None
        for i, (name, _) in enumerate(self.model.module_list.named_parameters()):
            # Even-indexed parameters are treated as layer weights (weights and
            # biases alternate) -- TODO confirm this holds for every LatentModule.
            if i % 2 == 0:
                var_ind_out = np.matmul(gamma[name], var_ind_out)
                if i // 2 == self.model.treat_layer:
                    var_ind_treat = np.copy(var_ind_out[self.model.treat_node, :])
        if var_ind_treat is None:
            raise ValueError(f"treat_layer={self.model.treat_layer} does not match any weight parameter")

        var_ind_out = np.max(var_ind_out, 0)
        if self._is_multi_level_treat():
            var_ind_treat = np.prod(var_ind_treat, axis=0)
        num_selected_out = np.sum(var_ind_out)
        num_selected_treat = np.sum(var_ind_treat)
        self.input_gamma_path['var_selected'][str(epoch)] = var_ind_out.tolist()
        self.input_gamma_path['var_selected_treat'][str(epoch)] = var_ind_treat.tolist()
        self.input_gamma_path['num_selected'].append(num_selected_out.astype("float64"))
        self.input_gamma_path['num_selected_treat'].append(num_selected_treat.astype("float64"))
        print('number of selected input variable for outcome:', num_selected_out)
        print('number of selected input variable for treatment:', num_selected_treat)

    def predict(self, x, treat):
        """Apply pruning masks if pruning is active, then return ``self.model(x, treat)`` (outcome, treatment probabilities)."""
        if self.prune_flag == 1:
            self.prune_masked_para()

        return self.model(x, treat)

    def get_y_cf(self, x, z_sd, outcome_cat, sample_size=100, seed=1):
        """Predict the full set of counterfactual outcomes.

        Z is sampled ``sample_size`` times from a Normal centred on
        ``module_xz(x)`` with scale ``z_sd``. The outcome is predicted for
        every treatment assignment and averaged over the samples.

        Args:
            x: observed proxy variable.
            z_sd: standard deviation of p(z|x); either specified or estimated.
            outcome_cat: True for a categorical outcome (softmax probabilities
                are averaged), False for a continuous one.
            sample_size: number of Z samples.
            seed: seed for numpy and torch (reseeds the global RNGs).

        Returns:
            List of averaged predictions with ``len(treat_nodes) + 1`` entries.
            Entry k (k < number of treatment nodes) sets treatment node k to 1
            and the others to 0. The last entry sets all treatment nodes to 0
            (control). For a binary treatment this is ``[treated, control]``.
        """
        if self.prune_flag == 1:
            self.prune_masked_para()

        # sample z from p(z|x)
        z_mean = self.model.module_xz(x)

        np.random.seed(seed)
        torch.manual_seed(seed)
        z = Normal(loc=z_mean, scale=z_sd).sample((sample_size,))  # sample_size x batch x dim_confounder_layer

        if self._is_multi_level_treat():
            treat_node = [int(node) for node in self.model.treat_node]
        else:
            treat_node = [self.model.treat_node]
        n_nodes = len(treat_node)

        # Treatment layer computed once, so all counterfactuals share the same
        # non-treatment units.
        a_base = self.model.module_za(z)

        y_list = []
        for level in range(n_nodes + 1):
            # One value per treatment node, on the same device/dtype as the layer.
            treat = torch.zeros(tuple(z.shape[:2]) + (n_nodes,), dtype=a_base.dtype, device=a_base.device)
            if level < n_nodes:
                treat[:, :, level] = 1
            a_temp = a_base.clone()
            a_temp[:, :, treat_node] = treat
            y_cf = self.model.module_ay(a_temp)

            if outcome_cat:
                y_cf = torch.softmax(y_cf, dim=2)
            y_list.append(y_cf.mean(dim=0))

        return y_list