from typing import Optional
import torch
import numpy as np
from torch import nn
from itertools import chain
from torch.utils.tensorboard import SummaryWriter
from src.utils.mi_estimators import *
from src.utils.pytorch_linear_reg_utils import fit_linear, linear_reg_pred, outer_prod, add_const_col, fit_weighted_linear, linear_reg_loss
from src.data.data_class import TrainDataSetTorch, TestDataSetTorch, concat_dataset

# Module-level noise variance. Nothing in the visible part of this file reads it
# (generate_noise takes its own `variance` argument).
# NOTE(review): confirm it is still used elsewhere before removing.
variance = 1

class TSSIModel:
    """Bundle of feature networks, their optimizers and fitted linear weights.

    The weight attributes below are only declared here; they do not exist on an
    instance until a fitting method assigns them.
    """
    # Assigned by fit_t from the "stage1_weight" entry returned by fit_2sls.
    stage1_weight: torch.Tensor
    # Assigned by fit_t from the "stage2_weight" entry returned by fit_2sls.
    stage2_y1_weight: torch.Tensor
    # Declared but never assigned in the visible part of this file.
    stage2_y0_weight: torch.Tensor

    def __init__(self,
                 treatment_net: nn.Module,
                 instrumental_net: nn.Module,
                 selection_net: nn.Module,
                 covariate_net: Optional[nn.Module],
                 odds_net: nn.Module,
                 phis_net: nn.Module,
                 phit_net: nn.Module,
                 S_net: nn.Module,
                 s_net: nn.Module,
                 y_net: nn.Module,
                 y1_net: nn.Module,
                 h2_net: nn.Module, 
                 add_stage1_intercept: bool,
                 add_stage2_intercept: bool,
                 odds_iter: int,
                 selection_weight_decay: float,
                 odds_weight_decay: float,
                 y_weight_decay: float,
                 y1_weight_decay: float,
                 lam_y: float,
                 distance_dim: int
                 ):
        self.treatment_net = treatment_net
        self.instrumental_net = instrumental_net
        self.selection_net = selection_net
        self.odds_net = odds_net
        self.covariate_net = covariate_net
        self.phis_net = phis_net
        self.phit_net = phit_net
        self.S_net = S_net
        self.s_net = s_net
        self.y_net = y_net
        self.y1_net = y1_net
        self.h2_net = h2_net
        self.add_stage1_intercept = add_stage1_intercept
        self.add_stage2_intercept = add_stage2_intercept
        self.odds_iter = odds_iter
        self.odds_opt = torch.optim.Adam(self.odds_net.parameters(),
                                         weight_decay=odds_weight_decay)
        self.y_opt = torch.optim.Adam(self.y_net.parameters(),
                                       weight_decay=y_weight_decay)
        self.y1_opt = torch.optim.Adam(self.y1_net.parameters(),
                                       weight_decay=y1_weight_decay)
        self.selection_opt = torch.optim.Adam(self.selection_net.parameters(),
                                        weight_decay=selection_weight_decay)
        self.S_opt = torch.optim.Adam(self.S_net.parameters(),
                                        weight_decay=0)
        self.h2_opt = torch.optim.Adam(self.h2_net.parameters(),
                                        weight_decay=0)
        self.condition_dim = 2
        self.lam_y = lam_y
        self.distance_dim = distance_dim

    @staticmethod
    def augment_stage1_feature(instrumental_feature: torch.Tensor,
                               add_stage1_intercept: bool):

        feature = instrumental_feature
        if add_stage1_intercept:
            feature = add_const_col(feature)
        return feature
    
    def augment_stage1_feature_plus(instrumental_feature: torch.Tensor,
                                    covariate_feature: torch.Tensor,
                               add_stage1_intercept: bool):

        feature = instrumental_feature
        if add_stage1_intercept:
            feature = add_const_col(feature)
        if covariate_feature is not None:
            feature_tmp = covariate_feature
            if add_stage1_intercept:
                feature_tmp = add_const_col(feature_tmp)
            feature = outer_prod(feature, feature_tmp)
            feature = torch.flatten(feature, start_dim=1)

        return feature

    @staticmethod
    def augment_stage2_feature(treatment_feature: torch.Tensor,
                               covariate_feature: torch.Tensor,
                               add_stage2_intercept: bool):
        feature = treatment_feature
        if add_stage2_intercept:
            feature = add_const_col(feature)
        if covariate_feature is not None:
            feature_tmp = covariate_feature
            if add_stage2_intercept:
                feature_tmp = add_const_col(feature_tmp)
            feature = outer_prod(feature, feature_tmp)
            feature = torch.flatten(feature, start_dim=1)

        return feature

    @staticmethod
    def augment_stage_y1_feature(treatment_feature: torch.Tensor,
                                 residual: torch.Tensor,
                                 covariate_feature: Optional[torch.Tensor],
                                 add_stage2_intercept: bool):
        feature = treatment_feature
        if add_stage2_intercept:
            feature = add_const_col(feature)

        if covariate_feature is not None:
            feature_tmp = covariate_feature
            if add_stage2_intercept:
                feature_tmp = add_const_col(feature_tmp)
            feature = outer_prod(feature, feature_tmp)
            feature = torch.flatten(feature, start_dim=1)

        feature = torch.cat((feature, residual), 1)

        return feature


    
    @staticmethod
    def generate_noise(n, variance):
        std_dev = np.sqrt(variance)
        noise = np.random.normal(0, std_dev, (n, 1))
        noise = torch.from_numpy(noise).float()
        return noise.cuda()

    @staticmethod
    def fit_2sls(treatment_1st_feature: torch.Tensor,
                 treatment_2nd_feature: torch.Tensor,
                 instrumental_1st_feature: torch.Tensor,
                 instrumental_2nd_feature: torch.Tensor,
                 phis_2nd_feature: torch.Tensor,
                 covariate_1st_feature: Optional[torch.Tensor],
                 covariate_2nd_feature: Optional[torch.Tensor],
                 outcome_2nd_t: torch.Tensor,
                 lam1: float, lam2: float,
                 add_stage1_intercept: bool,
                 add_stage2_intercept: bool,
                 ):
        """Fit the two linear stages and report the regularised stage-2 loss.

        Stage 1 fits ``treatment_1st_feature`` on the augmented first-split
        instrument features with ``fit_linear(..., lam1)`` and predicts the
        treatment features for the second split. Stage 2 fits
        ``outcome_2nd_t`` on those predictions (augmented with covariates and
        ``phis_2nd_feature``) with ``fit_linear(..., lam2)``.

        ``treatment_2nd_feature`` and ``covariate_1st_feature`` are accepted
        but not used by this function.

        Returns:
            dict with keys ``stage1_weight``, ``stage2_weight`` and
            ``stage2_loss`` (squared residual norm plus ``lam2`` times the
            squared norm of the stage-2 weight).
        """
        def stage1_design(instrumental):
            # Same augmentation for the fitting split and the prediction split.
            return TSSIModel.augment_stage1_feature(instrumental, add_stage1_intercept)

        # Stage 1: instruments -> treatment features.
        stage1_weight = fit_linear(treatment_1st_feature,
                                   stage1_design(instrumental_1st_feature),
                                   lam1)
        predicted_treatment = linear_reg_pred(stage1_design(instrumental_2nd_feature),
                                              stage1_weight)

        # Stage 2: predicted treatment (+ covariates, + phis) -> outcome.
        design = TSSIModel.augment_stage_y1_feature(predicted_treatment,
                                                    phis_2nd_feature,
                                                    covariate_2nd_feature,
                                                    add_stage2_intercept)
        stage2_weight = fit_linear(outcome_2nd_t, design, lam2)
        fitted = linear_reg_pred(design, stage2_weight)
        stage2_loss = torch.norm((outcome_2nd_t - fitted)) ** 2 + lam2 * torch.norm(stage2_weight) ** 2

        return {"stage1_weight": stage1_weight,
                "stage2_weight": stage2_weight,
                "stage2_loss": stage2_loss}

    def fit_odds(self,
                 instrumental_feature: torch.Tensor,
                 phis_1st_feature: torch.Tensor, 
                 phis_2nd_feature: torch.Tensor, 
                 phis_3rd_feature: torch.Tensor, 
                 covariate_1st: Optional[torch.Tensor],
                 covariate_2nd: Optional[torch.Tensor],
                 covariate_3rd: Optional[torch.Tensor],
                 selection_probability: torch.Tensor, 
                 outcome_2nd_t: torch.Tensor,
                 ):
        
        feature = TSSIModel.augment_stage1_feature(instrumental_feature, self.add_stage1_intercept)
        predicted_treatment_feature = linear_reg_pred(feature, self.stage1_weight)
        
        self.h2_net.train(True)
        for i in range(self.odds_iter):
            self.h2_opt.zero_grad()
            h2_feature = covariate_2nd.detach()
            h2_pred = self.h2_net(h2_feature)
            h1_norm_2nd = phis_2nd_feature
            loss = torch.sum((h2_pred - h1_norm_2nd) ** 2)
            loss.backward()
            self.h2_opt.step()
        self.h2_net.train(False)

        self.S_net.train(True)
        for e in range(self.odds_iter):
            self.S_opt.zero_grad()
            S_feature = torch.cat((covariate_2nd, outcome_2nd_t, phis_2nd_feature, predicted_treatment_feature), 1).detach()
            S_pred = self.S_net(S_feature)
            pi = S_pred
            pi = torch.max(pi, torch.tensor(0.1))
            S_loss_1 = (torch.sum((1 / pi - 1) * covariate_2nd[:, 0:1]) - torch.sum(covariate_3rd[:, 0:1])) ** 2
            S_loss_2 = (torch.sum((1 / pi - 1) * covariate_2nd[:, 1:2]) - torch.sum(covariate_3rd[:, 1:2])) ** 2
            S_loss_3 = (torch.sum((1 / pi - 1) * phis_2nd_feature) - torch.sum(phis_3rd_feature)) ** 2
            S_loss = S_loss_1 + S_loss_2 + S_loss_3
            S_loss = torch.sum(S_loss)
            h2_pred = self.h2_net(covariate_2nd)
            or_loss = torch.sum(((1 / pi) * (phis_2nd_feature - h2_pred) * outcome_2nd_t) ** 2)
            loss =  S_loss + or_loss
            loss.backward()
            self.S_opt.step()

        S_feature = torch.cat((covariate_2nd, outcome_2nd_t, phis_2nd_feature, predicted_treatment_feature), 1).detach()
        p = self.S_net(S_feature).detach()
        p = torch.max(p, torch.tensor(0.1))
        selection_probability = torch.max(selection_probability, torch.tensor(0.1))
        W = 1 / p

        return W


    def fit_outcome(self,
                    covariate: torch.Tensor,
                    predicted_treatment_2nd_feature: torch.Tensor, 
                    phis_2nd_feature: torch.Tensor,
                    outcome_2nd_t: torch.Tensor,
                    W: torch.Tensor):
        self.treatment_net.train(False)
        self.instrumental_net.train(False)
        self.covariate_net.train(False)
        self.phis_net.train(False)
        self.phit_net.train(False)
        self.y_net.train(True) 
        self.y1_net.train(True)
        self.selection_net.train(False)
        self.S_net.train(False)
        self.odds_net.train(False) 
        for e in range(self.odds_iter):
            covariate_2nd_feature = self.covariate_net(covariate).detach()
            feature = torch.cat((predicted_treatment_2nd_feature, phis_2nd_feature, covariate_2nd_feature), 1).detach()
            self.y_net.zero_grad()
            outcome_y = self.y_net(feature)
            loss_y = torch.sum(W * (outcome_2nd_t - outcome_y) ** 2)
            loss_y.backward()
            self.y_opt.step()

    def fit_t(self,
              train_1st_data_t: TrainDataSetTorch,
              train_2nd_data_t: TrainDataSetTorch,
              train_3rd_data_t: TrainDataSetTorch,
              lam1: float, lam2: float, lam3: float):
        self.treatment_net.train(False)
        self.covariate_net.train(False)
        self.instrumental_net.train(False)
        treatment_1st_feature = self.treatment_net(train_1st_data_t.treatment).detach()
        treatment_2nd_feature = self.treatment_net(train_2nd_data_t.treatment).detach()
        treatment_3rd_feature = self.treatment_net(train_3rd_data_t.treatment).detach()
        train_1st_t_phit_feature = self.phit_net(train_1st_data_t.instrumental).detach()
        train_2nd_t_phit_feature = self.phit_net(train_2nd_data_t.instrumental).detach()
        train_3rd_t_phit_feature = self.phit_net(train_3rd_data_t.instrumental).detach()
        instrumental_1st_feature = self.instrumental_net(torch.cat((train_1st_t_phit_feature, train_1st_data_t.covariate), 1)).detach()
        instrumental_2nd_feature = self.instrumental_net(torch.cat((train_2nd_t_phit_feature, train_2nd_data_t.covariate), 1)).detach()
        instrumental_3rd_feature = self.instrumental_net(torch.cat((train_3rd_t_phit_feature, train_3rd_data_t.covariate), 1)).detach()
        selection_probability = train_2nd_data_t.selection_probability
        outcome_2nd_t = train_2nd_data_t.outcome
        selection_1st_d = train_1st_data_t.selection
        phis_1st_feature = self.phis_net(train_1st_data_t.instrumental).detach()
        phis_2nd_feature = self.phis_net(train_2nd_data_t.instrumental).detach()
        phis_3rd_feature = self.phis_net(train_3rd_data_t.instrumental).detach()
        covariate_1st_feature = None
        covariate_2nd_feature = None
        covariate_1st = train_1st_data_t.covariate
        covariate_2nd = train_2nd_data_t.covariate
        covariate_3rd = train_3rd_data_t.covariate
        if self.covariate_net is not None:
            covariate_1st_feature = self.covariate_net(train_1st_data_t.covariate).detach()
            covariate_2nd_feature = self.covariate_net(train_2nd_data_t.covariate).detach()
            covariate_3rd_feature = self.covariate_net(train_3rd_data_t.covariate).detach()
        res = TSSIModel.fit_2sls(treatment_1st_feature,
                                 treatment_2nd_feature,
                                 instrumental_1st_feature,
                                 instrumental_2nd_feature,
                                 phis_2nd_feature,
                                 covariate_1st_feature,
                                 covariate_2nd_feature,
                                 outcome_2nd_t,
                                 lam1, lam2,
                                 self.add_stage1_intercept,
                                 self.add_stage2_intercept)
        self.stage1_weight = res["stage1_weight"]
        self.stage2_y1_weight = res["stage2_weight"]
        # predict for stage 2 odds
        feature = TSSIModel.augment_stage1_feature(instrumental_1st_feature, self.add_stage1_intercept)
        predicted_treatment_1st_feature = linear_reg_pred(feature, self.stage1_weight)
        feature = TSSIModel.augment_stage1_feature(instrumental_2nd_feature, self.add_stage1_intercept)
        predicted_treatment_2nd_feature = linear_reg_pred(feature, self.stage1_weight)
        feature = TSSIModel.augment_stage1_feature(instrumental_3rd_feature, self.add_stage1_intercept)
        predicted_treatment_3rd_feature = linear_reg_pred(feature, self.stage1_weight)

        feature = TSSIModel.augment_stage_y1_feature(predicted_treatment_2nd_feature,
                                                     phis_2nd_feature,
                                                     covariate_2nd_feature,
                                                     self.add_stage2_intercept)
        selection_probability = train_2nd_data_t.selection_probability
        W = self.fit_odds(instrumental_2nd_feature,
                          phis_1st_feature,
                          phis_2nd_feature,
                          phis_3rd_feature,
                          covariate_1st,
                          covariate_2nd,
                          covariate_3rd,
                          selection_probability,
                          outcome_2nd_t
                          )
        self.fit_outcome(covariate_2nd, 
                         predicted_treatment_2nd_feature,
                         phis_2nd_feature,
                         outcome_2nd_t,
                         W)

    def predict_t(self, treatment: torch.Tensor, covariate: Optional[torch.Tensor],
                  instrumental: Optional[torch.Tensor]):
        treatment_feature = self.treatment_net(treatment)
        covariate_feature = None
        phis_feature = self.phis_net(instrumental)
        phit_feature = self.phit_net(instrumental)
        instrumental_feature = self.instrumental_net(torch.cat((phit_feature, covariate), 1)).detach()
        if self.covariate_net:
            covariate_feature = self.covariate_net(covariate)
        if instrumental is not None:
            feature = TSSIModel.augment_stage1_feature(instrumental_feature, self.add_stage1_intercept)
            predicted_treatment_feature = linear_reg_pred(feature, self.stage1_weight)
        else:
            condition_feature = torch.zeros((len(treatment), self.condition_dim))


        feature = torch.cat((predicted_treatment_feature, phis_feature, covariate_feature), 1) 
        return self.y_net(feature)

    def evaluate_t(self, test_data: TestDataSetTorch):
        target = test_data.structural
        with torch.no_grad():
            pred = self.predict_t(test_data.treatment, test_data.covariate, test_data.instrumental)
        res = (torch.norm((target - pred)) ** 2) / target.size()[0]
        return res.detach().cpu().numpy()