from __future__ import annotations
from itertools import chain
from typing import Dict, List, Any, Optional
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt
import sys
import numpy

from src.utils.mi_estimators import *

from src.TSSI.model import TSSIModel
from src.data.data_generation import demand
from src.data.data_class import TrainDataSetTorch, TestDataSetTorch, concat_dataset
from src.utils.pytorch_linear_reg_utils import linear_reg_loss, fit_linear, linear_reg_pred, linear_reg_weight_loss
from config import Config

# Module-level counter of TSSITrainer.train() invocations; train() increments
# it via `global exp_num` at the start of every call.
exp_num: int = 0

class TSSITrainer(object):
    """Alternating trainer for the TSSI networks.

    Holds the networks passed in ``networks`` plus Adam optimizers for the
    treatment, instrumental, (phis, phit, s) and -- when present -- covariate
    networks. ``train`` generates data with ``demand``, repeatedly runs
    ``stage1_update``, ``update_covariate_net`` and ``stage2_update``, and after
    every outer epoch fits and evaluates a ``TSSIModel`` built from the current
    networks.
    """

    def __init__(self, networks: List[Any], train_params: Dict[str, Any],
                 gpu_flg: bool = False):
        """Store hyper-parameters and networks, optionally move them to GPU, build optimizers.

        Parameters
        ----------
        networks : list
            Networks indexed as: 0 treatment, 1 instrumental, 2 selection,
            3 covariate (may be None), 4 phis, 5 phit, 6 odds, 7 S, 8 s, 9 y,
            10 y1, 11 h2.
        train_params : dict
            Hyper-parameters; every key read below is required (a missing key
            raises KeyError).
        gpu_flg : bool
            Move networks to ``cuda:0`` when True and CUDA is available.
        """
        self.gpu_flg = gpu_flg and torch.cuda.is_available()
        # configure training params
        self.lam1: float = train_params["lam1"]
        self.lam2: float = train_params["lam2"]
        self.lam3: float = train_params["lam3"]
        self.lam4: float = train_params["lam4"]
        self.distance_dim: int = train_params["distance_dim"]
        self.stage1_iter: int = train_params["stage1_iter"]
        self.stage1_S1_iter: int = train_params["stage1_S1_iter"]
        self.stage2_iter: int = train_params["stage2_iter"]
        self.covariate_iter: int = train_params["covariate_iter"]
        self.mi_iter: int = train_params["mi_iter"]
        self.odds_iter: int = train_params["odds_iter"]
        self.n_epoch: int = train_params["n_epoch"]
        self.epoch: int = train_params["epoch"]
        self.add_stage1_intercept: bool = True
        self.add_stage2_intercept: bool = True
        self.treatment_weight_decay: float = train_params["treatment_weight_decay"]
        self.instrumental_weight_decay: float = train_params["instrumental_weight_decay"]
        self.covariate_weight_decay: float = train_params["covariate_weight_decay"]
        self.selection_weight_decay: float = train_params["selection_weight_decay"]
        self.r1_weight_decay: float = train_params["r1_weight_decay"]
        self.r0_weight_decay: float = train_params["r0_weight_decay"]
        self.s1_weight_decay: float = train_params["s1_weight_decay"]
        self.odds_weight_decay: float = train_params["odds_weight_decay"]
        self.S1_weight_decay: float = train_params["S1_weight_decay"]
        self.S0_weight_decay: float = train_params["S0_weight_decay"]
        self.y_weight_decay: float = train_params["y_weight_decay"]
        self.y1_weight_decay: float = train_params["y1_weight_decay"]
        self.lam_y: float = train_params["lam_y"]
        self.z_dim: float = train_params["z_dim"]
        self.z_ratio: float = train_params["z_ratio"]

        # build networks
        self.treatment_net: nn.Module = networks[0]
        self.instrumental_net: nn.Module = networks[1]
        self.selection_net: nn.Module = networks[2]
        self.covariate_net: Optional[nn.Module] = networks[3]
        self.phis_net: nn.Module = networks[4]
        self.phit_net: nn.Module = networks[5]
        self.odds_net: nn.Module = networks[6]
        self.S_net: nn.Module = networks[7]
        self.s_net: nn.Module = networks[8]
        self.y_net: nn.Module = networks[9]
        self.y1_net: nn.Module = networks[10]
        self.h2_net: nn.Module = networks[11]

        if self.gpu_flg:
            for net in (self.treatment_net, self.instrumental_net, self.covariate_net,
                        self.selection_net, self.phis_net, self.phit_net, self.odds_net,
                        self.S_net, self.s_net, self.y_net, self.y1_net, self.h2_net):
                if net is not None:
                    net.to("cuda:0")

        # Create Adam optimizers (default learning rate) for the networks this
        # trainer updates directly.
        self.treatment_opt = torch.optim.Adam(self.treatment_net.parameters(),
                                              weight_decay=self.treatment_weight_decay)
        self.instrumental_opt = torch.optim.Adam(self.instrumental_net.parameters(),
                                                 weight_decay=self.instrumental_weight_decay)
        # phis_net, phit_net and s_net are optimized jointly in stage1_update.
        self.phist_opt = torch.optim.Adam(chain(self.phis_net.parameters(),
                                                self.phit_net.parameters(),
                                                self.s_net.parameters()),
                                          weight_decay=0)
        # Stays None when no covariate network was supplied.
        self.covariate_opt: Optional[torch.optim.Optimizer] = None
        if self.covariate_net is not None:
            self.covariate_opt = torch.optim.Adam(self.covariate_net.parameters(),
                                                  weight_decay=self.covariate_weight_decay)

    def train(self, rand_seed: int = 42,
              verbose: int = 0) -> Optional[tuple[numpy.ndarray, numpy.ndarray]]:
        """Train on generated data and return the best evaluation losses.

        Data are generated once with ``demand(Config.sample_num * 10, z_dim,
        z_ratio, rand_seed)``. For each of ``self.epoch`` outer epochs the
        stage-1, covariate and stage-2 updates run ``self.n_epoch`` times; a
        ``TSSIModel`` is then built from the current networks, fitted with
        ``fit_t`` and evaluated on the selected and unselected test sets.

        Parameters
        ----------
        rand_seed : int
            Seed passed to the data generator.
        verbose : int
            Logging level; currently unused.

        Returns
        -------
        tuple of (oos_loss, unselected_loss), or None
            Losses on the selected and unselected test sets for the epoch whose
            sum of the two is smallest; None when ``self.epoch`` is not positive
            (or every epoch produced a NaN loss).

        Notes
        -----
        Each call (when ``self.epoch > 0``) multiplies ``lam1``-``lam3`` by the
        corresponding training-set sizes, so calling ``train`` twice rescales
        them again.
        """
        global exp_num
        exp_num = exp_num + 1
        writer = SummaryWriter()
        try:
            train_data, unselected_train_data, test_data, unselected_test_data = demand(
                Config.sample_num * 10, self.z_dim, self.z_ratio, rand_seed)

            # The datasets are fixed for the whole run, so convert them only once.
            train_1st_t = TrainDataSetTorch.from_numpy(
                concat_dataset(train_data, unselected_train_data))
            train_2nd_t = TrainDataSetTorch.from_numpy(train_data)
            train_3rd_t = TrainDataSetTorch.from_numpy(unselected_train_data)
            test_data_t = TestDataSetTorch.from_numpy(test_data)
            unselected_test_data_t = TestDataSetTorch.from_numpy(unselected_test_data)

            if self.gpu_flg:
                train_1st_t = train_1st_t.to_gpu()
                train_2nd_t = train_2nd_t.to_gpu()
                train_3rd_t = train_3rd_t.to_gpu()
                test_data_t = test_data_t.to_gpu()
                unselected_test_data_t = unselected_test_data_t.to_gpu()

            if self.epoch > 0:
                # Scale the ridge penalties by the number of samples in each stage.
                self.lam1 *= train_1st_t[0].size()[0]
                self.lam2 *= train_2nd_t[0].size()[0]
                self.lam3 *= train_3rd_t[0].size()[0]
                # NOTE(review): `=` (not `*=`) discards the configured lam_y;
                # confirm this is intended.
                self.lam_y = train_2nd_t[0].size()[0]

            best_loss = float("inf")
            best_res = None
            for _ in range(self.epoch):
                for tt in range(self.n_epoch):
                    self.stage1_update(train_1st_t, tt, writer)
                    self.update_covariate_net(train_1st_t, train_2nd_t, tt, writer)
                    self.stage2_update(train_1st_t, train_2nd_t, tt, writer)

                mdl = TSSIModel(self.treatment_net, self.instrumental_net, self.selection_net,
                                self.covariate_net, self.odds_net, self.phis_net, self.phit_net,
                                self.S_net, self.s_net, self.y_net, self.y1_net, self.h2_net,
                                self.add_stage1_intercept, self.add_stage2_intercept,
                                self.odds_iter, self.selection_weight_decay,
                                self.odds_weight_decay, self.y_weight_decay,
                                self.y1_weight_decay, self.lam_y, self.distance_dim)
                # shadow variable -> Selection Bias
                mdl.fit_t(train_1st_t, train_2nd_t, train_3rd_t, self.lam1, self.lam2, self.lam3)

                if self.gpu_flg:
                    torch.cuda.empty_cache()

                oos_loss: numpy.ndarray = mdl.evaluate_t(test_data_t)
                unselected_loss: numpy.ndarray = mdl.evaluate_t(unselected_test_data_t)
                total_loss = oos_loss + unselected_loss
                if total_loss < best_loss:
                    best_loss = total_loss
                    best_res = (oos_loss, unselected_loss)

            return best_res
        finally:
            # Flush and release the TensorBoard event file even on failure.
            writer.close()

    def _new_club_estimator(self):
        """Build a fresh CLUB mutual-information estimator on the training device."""
        estimator = CLUB(self.distance_dim, self.distance_dim, self.distance_dim * 2)
        if self.gpu_flg:
            estimator = estimator.to("cuda:0")
        return estimator

    def _instrumental_feature(self, data):
        """Return instrumental_net(cat(phit_net(data.instrumental), data.covariate))."""
        phit_feature = self.phit_net(data.instrumental)
        return self.instrumental_net(torch.cat((phit_feature, data.covariate), 1))

    def stage1_update(self, train_1st_t: TrainDataSetTorch, epoch: int, writer: SummaryWriter):
        """Run one round of stage-1 updates.

        Phase 1 (``stage1_iter`` steps) trains phis_net, phit_net and s_net via
        ``self.phist_opt`` on:
          * BCE between s_net(cat(phis, covariate)) and the selection label,
          * the ridge loss of regressing treatment features on phit features,
          * ``lam4``-weighted CLUB estimator terms for the pairs
            (phis, treatment), (phit, selection) and (phis, phit).
        Before every such step each CLUB estimator (rebuilt on every call) is
        fitted for ``mi_iter`` steps on detached features.

        Phase 2 (``stage1_iter`` steps) trains instrumental_net via
        ``self.instrumental_opt`` to regress the treatment features on
        instrumental_net(cat(phit, covariate)).

        ``epoch`` and ``writer`` are accepted for interface compatibility and
        are unused.
        """
        self.treatment_net.train(False)
        self.instrumental_net.train(False)
        if self.covariate_net is not None:
            self.covariate_net.train(False)
        self.phis_net.train(True)
        self.phit_net.train(True)
        self.s_net.train(True)

        bce_func = nn.BCELoss()
        mis_estimator = self._new_club_estimator()
        mit_estimator = self._new_club_estimator()
        mist_estimator = self._new_club_estimator()
        mis_optimizer = torch.optim.Adam(mis_estimator.parameters(), lr=1e-4)
        mit_optimizer = torch.optim.Adam(mit_estimator.parameters(), lr=1e-4)
        mist_optimizer = torch.optim.Adam(mist_estimator.parameters(), lr=1e-4)

        instrumental = train_1st_t.instrumental
        # Regression target: treatment features from the frozen treatment net.
        treatment_feature = self.treatment_net(train_1st_t.treatment).detach()

        for _ in range(self.stage1_iter):
            # Fit the estimators on detached features: these steps must not
            # push gradients into the representation networks.
            mis_estimator.train(True)
            mit_estimator.train(True)
            for _ in range(self.mi_iter):
                phis_feature = self.phis_net(instrumental).detach()
                phit_feature = self.phit_net(instrumental).detach()
                mis_loss = mis_estimator.learning_loss(phis_feature, train_1st_t.treatment)
                mit_loss = mit_estimator.learning_loss(phit_feature, train_1st_t.selection)
                mis_optimizer.zero_grad()
                mit_optimizer.zero_grad()
                mis_loss.backward()
                mit_loss.backward()
                mis_optimizer.step()
                mit_optimizer.step()

            mist_estimator.train(True)
            for _ in range(self.mi_iter):
                phis_feature = self.phis_net(instrumental).detach()
                phit_feature = self.phit_net(instrumental).detach()
                mist_loss = mist_estimator.learning_loss(phis_feature, phit_feature)
                mist_optimizer.zero_grad()
                mist_loss.backward()
                mist_optimizer.step()

            mis_estimator.train(False)
            mit_estimator.train(False)
            mist_estimator.train(False)

            phis_feature = self.phis_net(instrumental)
            phit_feature = self.phit_net(instrumental)
            # loss_s: selection prediction from phis features and covariates
            s_pred = self.s_net(torch.cat((phis_feature, train_1st_t.covariate), 1))
            loss_s = bce_func(s_pred, train_1st_t.selection)
            # loss_t: ridge regression of treatment features on phit features
            feature_t = TSSIModel.augment_stage1_feature(phit_feature, self.add_stage1_intercept)
            loss_t = linear_reg_loss(treatment_feature, feature_t, self.lam1)
            loss = (loss_s + loss_t
                    + self.lam4 * mis_estimator(phis_feature, train_1st_t.treatment)
                    + self.lam4 * mit_estimator(phit_feature, train_1st_t.selection)
                    + self.lam4 * mist_estimator(phis_feature, phit_feature))
            # Clear stale gradients of every parameter phist_opt is about to step.
            self.phist_opt.zero_grad()
            loss.backward()
            self.phist_opt.step()

        self.instrumental_net.train(True)
        for _ in range(self.stage1_iter):
            self.instrumental_opt.zero_grad()
            # phit_net is not optimized here; detach so no gradients leak into it.
            phit_feature = self.phit_net(instrumental).detach()
            instrumental_feature = self.instrumental_net(
                torch.cat((phit_feature, train_1st_t.covariate), 1))
            feature_t = TSSIModel.augment_stage1_feature(instrumental_feature,
                                                         self.add_stage1_intercept)
            loss_t = linear_reg_loss(treatment_feature, feature_t, self.lam1)
            loss_t.backward()
            self.instrumental_opt.step()

    def stage2_update(self, train_1st_t: TrainDataSetTorch, train_2nd_t: TrainDataSetTorch,
                      epoch: int, writer: SummaryWriter):
        """Train treatment_net on the stage-2 loss returned by ``TSSIModel.fit_2sls``.

        Instrumental, phis and covariate features are computed once from the
        frozen (eval-mode) networks; then ``stage2_iter`` optimizer steps are
        taken on ``res["stage2_loss"]``. The covariate features are passed as
        None when no covariate network is configured. ``epoch`` and ``writer``
        are unused.
        """
        self.treatment_net.train(True)
        self.instrumental_net.train(False)
        self.y_net.train(False)
        self.y1_net.train(False)
        self.selection_net.train(False)
        self.phis_net.train(False)
        self.phit_net.train(False)
        if self.covariate_net is not None:
            self.covariate_net.train(False)

        # Fixed regressors: no gradients are needed through these networks.
        covariate_1st_feature = None
        covariate_2nd_feature = None
        with torch.no_grad():
            instrumental_1st_feature = self._instrumental_feature(train_1st_t)
            instrumental_2nd_feature = self._instrumental_feature(train_2nd_t)
            phis_2nd_feature = self.phis_net(train_2nd_t.instrumental)
            if self.covariate_net is not None:
                covariate_2nd_feature = self.covariate_net(train_2nd_t.covariate)
                covariate_1st_feature = self.covariate_net(train_1st_t.covariate)

        for _ in range(self.stage2_iter):
            self.treatment_opt.zero_grad()
            treatment_1st_feature = self.treatment_net(train_1st_t.treatment)
            treatment_2nd_feature = self.treatment_net(train_2nd_t.treatment)
            res = TSSIModel.fit_2sls(treatment_1st_feature,
                                     treatment_2nd_feature,
                                     instrumental_1st_feature,
                                     instrumental_2nd_feature,
                                     phis_2nd_feature,
                                     covariate_1st_feature,
                                     covariate_2nd_feature,
                                     train_2nd_t.outcome,
                                     self.lam1, self.lam2,
                                     self.add_stage1_intercept,
                                     self.add_stage2_intercept)
            loss = res["stage2_loss"]
            loss.backward()
            self.treatment_opt.step()

    def update_covariate_net(self, train_1st_data: TrainDataSetTorch,
                             train_2nd_data: TrainDataSetTorch,
                             epoch: int, writer: SummaryWriter):
        """Train covariate_net on the stage-2 outcome regression.

        A stage-1 ridge fit (treatment features on instrumental features of
        ``train_1st_data``) predicts treatment features for ``train_2nd_data``;
        covariate_net is then trained for ``covariate_iter`` steps on the ridge
        loss of regressing the outcome on those predictions, the phis features
        and its own features. All other networks are kept in eval mode. Does
        nothing beyond setting eval modes when no covariate network is
        configured. ``epoch`` and ``writer`` are unused.
        """
        self.selection_net.train(False)
        self.instrumental_net.train(False)
        self.treatment_net.train(False)
        # Extract the fixed features in eval mode, consistent with stage2_update.
        self.phis_net.train(False)
        self.phit_net.train(False)
        if self.covariate_net is None:
            return

        with torch.no_grad():
            instrumental_1st_feature = self._instrumental_feature(train_1st_data)
            instrumental_2nd_feature = self._instrumental_feature(train_2nd_data)
            phis_2nd_feature = self.phis_net(train_2nd_data.instrumental)
            treatment_1st_feature = self.treatment_net(train_1st_data.treatment)

            feature = TSSIModel.augment_stage1_feature(instrumental_1st_feature,
                                                       self.add_stage1_intercept)
            stage1_weight = fit_linear(treatment_1st_feature, feature, self.lam1)
            feature = TSSIModel.augment_stage1_feature(instrumental_2nd_feature,
                                                       self.add_stage1_intercept)
            predicted_treatment_feature = linear_reg_pred(feature, stage1_weight)

        self.covariate_net.train(True)
        for _ in range(self.covariate_iter):
            self.covariate_opt.zero_grad()
            covariate_feature = self.covariate_net(train_2nd_data.covariate)
            # stage2 - y1 regression
            feature = TSSIModel.augment_stage_y1_feature(predicted_treatment_feature,
                                                         phis_2nd_feature,
                                                         covariate_feature,
                                                         self.add_stage2_intercept)
            loss = linear_reg_loss(train_2nd_data.outcome, feature, self.lam2)
            loss.backward()
            self.covariate_opt.step()