from __future__ import annotations
from typing import Dict, Any, Optional
import torch
from torch import nn
import logging
from pathlib import Path
import copy

import numpy as np

from src.models.D1SIV.nn_structure import build_extractor
# from src.models.D1SIV.monitor import D1SIVMonitor
from src.models.D1SIV.model import D1SIVModel
from src.models.D1SIV.data_loader import get_minibatch_loader

from src.data import generate_train_data, generate_test_data
from src.data.data_class import TrainDataSet, TrainDataSetTorch, TestDataSetTorch
from src.utils.pytorch_linear_reg_utils import linear_reg_loss, fit_linear, linear_reg_pred

# getLogger() is called without a name, so this is the root logger; every
# message emitted by this module goes through the root logger's handlers.
logger = logging.getLogger()


class D1SIVTrainer(object):
    """Train the D1SIV feature networks and report the out-of-sample loss.

    The trainer owns a treatment network, an instrumental network and an
    optional covariate network (all produced by ``build_extractor``), one Adam
    optimizer per network, and the hyper-parameters read from ``train_params``.
    ``train`` runs the epoch loop, fits a ``D1SIVModel`` on the learned
    features and evaluates it on generated test data.
    """

    def __init__(self, data_configs: Dict[str, Any], train_params: Dict[str, Any],
                 gpu_flg: bool = False, dump_folder: Optional[Path] = None):
        """
        Parameters
        ----------
        data_configs : Dict[str, Any]
            Data settings. Must contain ``"data_name"`` (used to pick the
            networks); the whole dict is later forwarded as keyword arguments
            to ``generate_train_data`` / ``generate_test_data``.
        train_params : Dict[str, Any]
            Must contain ``lam1``, ``lam2``, ``stage1_iter``, ``stage2_iter``,
            ``covariate_iter``, ``n_epoch``, ``treatment_weight_decay``,
            ``instrumental_weight_decay`` and ``covariate_weight_decay``.
            ``stage1_minibatch`` and ``stage2_minibatch`` are optional.
        gpu_flg : bool
            Request GPU execution; only honoured when CUDA is available.
        dump_folder : Optional[Path]
            Currently unused: the monitor that consumed it is disabled.
        """
        self.data_config = data_configs
        self.gpu_flg = gpu_flg and torch.cuda.is_available()
        if self.gpu_flg:
            logger.info("gpu mode")

        # configure training params
        self.lam1: float = train_params["lam1"]
        self.lam2: float = train_params["lam2"]
        # Unscaled copy of lam2. train() multiplies lam2 by the sample count,
        # and deriving it from this copy keeps repeated train() calls from
        # compounding the scaling.
        self._lam2_per_sample: float = self.lam2
        self.stage1_iter: int = train_params["stage1_iter"]
        self.stage2_iter: int = train_params["stage2_iter"]
        self.covariate_iter: int = train_params["covariate_iter"]
        self.n_epoch: int = train_params["n_epoch"]
        self.add_stage1_intercept = True
        self.add_stage2_intercept = True
        self.treatment_weight_decay = train_params["treatment_weight_decay"]
        self.instrumental_weight_decay = train_params["instrumental_weight_decay"]
        self.covariate_weight_decay = train_params["covariate_weight_decay"]
        self.stage1_minibatch = train_params.get("stage1_minibatch", None)
        self.stage2_minibatch = train_params.get("stage2_minibatch", None)

        # build networks: (treatment, instrumental, covariate-or-None)
        networks = build_extractor(data_configs["data_name"])
        self.treatment_net: nn.Module = networks[0]
        self.instrumental_net: nn.Module = networks[1]
        self.covariate_net: Optional[nn.Module] = networks[2]
        if self.gpu_flg:
            self.treatment_net.to("cuda:0")
            self.instrumental_net.to("cuda:0")
            if self.covariate_net is not None:
                self.covariate_net.to("cuda:0")

        self.treatment_opt = torch.optim.Adam(self.treatment_net.parameters(),
                                              weight_decay=self.treatment_weight_decay)
        self.instrumental_opt = torch.optim.Adam(self.instrumental_net.parameters(),
                                                 weight_decay=self.instrumental_weight_decay)
        # Explicit None test: truthiness of a container module (e.g.
        # nn.Sequential) depends on its length, not on its presence.
        self.covariate_opt: Optional[torch.optim.Optimizer] = None
        if self.covariate_net is not None:
            self.covariate_opt = torch.optim.Adam(self.covariate_net.parameters(),
                                                  weight_decay=self.covariate_weight_decay)

        # build monitor (disabled; kept as None so train() skips monitoring)
        self.monitor = None
        # if dump_folder is not None:
        #     self.monitor = D1SIVMonitor(dump_folder, self)

    def train(self, rand_seed: int = 1997, verbose: int = 0) -> float:
        """Run the full training loop and evaluate on test data.

        Parameters
        ----------
        rand_seed: int
            random seed forwarded to the training-data generator
        verbose : int
            Determine the level of logging (>= 1 logs epoch ends, >= 2 also
            logs every optimisation step)

        Returns
        -------
        oos_result : float
            The performance of model evaluated by oos
        """
        train_data = generate_train_data(rand_seed=rand_seed, **self.data_config)
        test_data = generate_test_data(**self.data_config)
        test_data_t = TestDataSetTorch.from_numpy(test_data)
        train_data_t = TrainDataSetTorch.from_numpy(train_data)
        if self.gpu_flg:
            train_data_t = train_data_t.to_gpu()
            test_data_t = test_data_t.to_gpu()

        if self.monitor is not None:
            new_rand_seed = np.random.randint(100_000)
            # Validation set uses its own size; generate it from the modified
            # copy so the data_size override actually takes effect.
            validation_config = copy.copy(self.data_config)
            validation_config["data_size"] = 5000
            validation_data = generate_train_data(rand_seed=new_rand_seed,
                                                  **validation_config)
            validation_data_t = TrainDataSetTorch.from_numpy(validation_data)
            if self.gpu_flg:
                validation_data_t = validation_data_t.to_gpu()
            self.monitor.configure_data(train_data_t, test_data_t, validation_data_t)

        # Scale the stage-2 regulariser by the number of training samples,
        # always starting from the unscaled value so that calling train()
        # again does not multiply an already-scaled lam2.
        n_samples = train_data_t[0].size()[0]
        self.lam2 = self._lam2_per_sample * n_samples

        # No minibatch size configured -> use the whole training set per step.
        if self.stage2_minibatch is None:
            self.stage2_minibatch = train_data_t.treatment.shape[0]

        for epoch in range(self.n_epoch):
            stage2_loader = get_minibatch_loader(train_data_t, self.stage2_minibatch)
            for train_data_t_sub in stage2_loader:
                if self.covariate_net is not None:
                    self.update_covariate_net(train_data_t_sub, verbose)
                self.stage2_update(train_data_t_sub, verbose)
            if self.monitor is not None:
                self.monitor.record(verbose)
            if verbose >= 1:
                logger.info(f"Epoch {epoch} ended")

        mdl = D1SIVModel(self.treatment_net, self.covariate_net,
                         self.add_stage1_intercept, self.add_stage2_intercept)
        mdl.fit_t(train_data_t, self.lam1, self.lam2)
        if self.gpu_flg:
            torch.cuda.empty_cache()

        oos_loss: float = mdl.evaluate_t(test_data_t).data.item()
        return oos_loss

    def MDD_loss(self, y_hat: torch.Tensor,
                 y: torch.Tensor,
                 z: torch.Tensor):
        """Return the negated mean of (residual Gram matrix * pairwise distances of z).

        The residual ``y - y_hat`` is centred by its mean over dim 0, its Gram
        matrix ``e @ e.T`` is multiplied element-wise with the Euclidean
        distance matrix of ``z``, and the negative mean of that product is
        returned. (Name suggests "martingale difference divergence" --
        NOTE(review): confirm against the method's reference.)

        Parameters
        ----------
        y_hat : torch.Tensor
            Predictions; 2-D with samples along dim 0 (``.t()`` is applied to
            the residual).
        y : torch.Tensor
            Targets, same shape as ``y_hat``.
        z : torch.Tensor
            Instrument values, 2-D with the same number of rows (passed
            directly to ``torch.cdist``).

        Returns
        -------
        mdd : torch.Tensor
            Scalar tensor holding the loss.
        """
        residual = y - y_hat
        centred = residual - torch.mean(residual, dim=0)
        residual_gram = centred @ centred.t()
        z_distances = torch.cdist(z, z, p=2)
        return -(torch.mean(residual_gram * z_distances))

    def _mdd_objective(self, train_data: TrainDataSetTorch,
                       treatment_feature: torch.Tensor,
                       covariate_feature: Optional[torch.Tensor],
                       instrumental_feature: torch.Tensor) -> torch.Tensor:
        """Build the stage-2 features, regress the outcome and return the training loss."""
        feature = D1SIVModel.augment_stage2_feature(treatment_feature,
                                                    covariate_feature,
                                                    self.add_stage2_intercept)
        loss1, yhat, _ = linear_reg_loss(train_data.outcome, feature, self.lam2)
        # The regression loss is weighted by zero: it is only evaluated to
        # obtain yhat, and the MDD term alone drives the gradient.
        return 0 * loss1 + self.MDD_loss(yhat, train_data.outcome, instrumental_feature)

    def stage2_update(self, train_data: TrainDataSetTorch, verbose: int):
        """Take ``stage2_iter`` optimiser steps on the treatment network.

        Parameters
        ----------
        train_data : TrainDataSetTorch
            (Mini)batch providing ``treatment``, ``instrumental``, ``outcome``
            and, when a covariate network exists, ``covariate``.
        verbose : int
            When >= 2 the loss of every step is logged.
        """
        self.treatment_net.train(True)
        if self.covariate_net is not None:
            self.covariate_net.train(False)

        # Raw instruments are used as-is; no gradient flows into them.
        instrumental_feature = train_data.instrumental.detach()

        # Covariate features are held fixed while the treatment net is updated.
        covariate_feature = None
        if self.covariate_net is not None:
            covariate_feature = self.covariate_net(train_data.covariate).detach()

        for _ in range(self.stage2_iter):
            self.treatment_opt.zero_grad()
            treatment_feature = self.treatment_net(train_data.treatment)
            loss = self._mdd_objective(train_data, treatment_feature,
                                       covariate_feature, instrumental_feature)
            loss.backward()
            if verbose >= 2:
                logger.info(f"stage2 learning: {loss.item()}")
            self.treatment_opt.step()

    def update_covariate_net(self, train_data: TrainDataSetTorch,
                             verbose: int):
        """Take ``covariate_iter`` optimiser steps on the covariate network.

        Must only be called when a covariate network (and therefore its
        optimizer) exists; ``train`` guards the call accordingly.

        Parameters
        ----------
        train_data : TrainDataSetTorch
            (Mini)batch providing ``treatment``, ``instrumental``,
            ``covariate`` and ``outcome``.
        verbose : int
            When >= 2 the loss of every step is logged.
        """
        instrumental_feature = train_data.instrumental.detach()

        # Treatment features are held fixed while the covariate net is updated.
        self.treatment_net.train(False)
        treatment_feature = self.treatment_net(train_data.treatment).detach()

        self.covariate_net.train(True)
        for _ in range(self.covariate_iter):
            self.covariate_opt.zero_grad()
            covariate_feature = self.covariate_net(train_data.covariate)
            loss = self._mdd_objective(train_data, treatment_feature,
                                       covariate_feature, instrumental_feature)
            loss.backward()
            if verbose >= 2:
                logger.info(f"update covariate: {loss.item()}")
            self.covariate_opt.step()
