import os
import torch
import torch.optim as optim
from torch.nn.utils.clip_grad import clip_grad_norm_
import numpy as np
import matplotlib.pyplot as plt

from time import time
from logging import getLogger
import torch.nn as nn
from causally.utils.utils import ensure_dir, get_local_time, early_stopping
from causally.trainer.AbstractTrainer import AbstractTrainer

class standard_trainer(AbstractTrainer):

    def __init__(self, config, model):
        """Read training settings from ``config`` and prepare the optimizer.

        Args:
            config: Configuration object indexed by string keys; every key
                read below must be present.
            model: Model to train; forwarded to the parent constructor.
        """
        super().__init__(config, model)

        # Logging / optimisation settings.
        self.logger = getLogger(config['logfilename'])
        self.learner = config['optimizer']
        self.learning_rate = config['learning_rate']

        # Schedule: never evaluate less often than once per full run.
        self.epochs = config['epochs']
        self.eval_step = min(config['eval_step'], self.epochs)
        self.stopping_step = config['stopping_step']
        self.clip_grad_norm = config['clip_grad_norm']
        self.valid_metric_bigger = config['valid_metric_bigger']
        self.test_batch_size = config['eval_batch_size']
        self.device = config['device']

        # Checkpoint location: <checkpoint_dir>/IL-<start_order>-<model>-<dataset>-<exp_info>.pth
        # NOTE(review): self.config is not assigned in this method; presumably
        # the parent constructor sets it -- confirm.
        self.checkpoint_dir = config['checkpoint_dir']
        ensure_dir(self.checkpoint_dir)
        name_fields = [
            self.config[key] for key in ('start_order', 'model', 'dataset', 'exp_info')
        ]
        self.saved_name = '{}-{}-{}-{}-{}'.format('IL', *name_fields)
        self.saved_model_file = os.path.join(self.checkpoint_dir, self.saved_name + '.pth')

        # Mutable training state.
        self.start_epoch = 0
        self.cur_step = 0
        self.best_valid_score = -1000000
        self.best_valid_result = None
        self.train_loss_dict = dict()
        self.optimizer = self._build_optimizer()

    def to_device(self,x,t,y,w):
        return x.to(self.device),t.to(self.device),y.to(self.device),w.to(self.device)

    def _build_optimizer(self):

        if self.learner.lower() == 'adam':
            optimizer = optim.Adam(self.model.parameters(), lr=self.learning_rate)
        elif self.learner.lower() == 'sgd':
            optimizer = optim.SGD(self.model.parameters(), lr=self.learning_rate)
        elif self.learner.lower() == 'adagrad':
            optimizer = optim.Adagrad(self.model.parameters(), lr=self.learning_rate)
        elif self.learner.lower() == 'rmsprop':
            optimizer = optim.RMSprop(self.model.parameters(), lr=self.learning_rate)
        else:
            self.logger.warning('Received unrecognized optimizer, set default Adam optimizer')
            optimizer = optim.Adam(self.model.parameters(), lr=self.learning_rate)
        return optimizer

    @torch.no_grad()
    def _valid_epoch(self, val_data, my_val_data):

        self.model.eval()
        treated_data, control_data = val_data['treated'], val_data['control']
        n_samples = treated_data.get_X_size()[0]

        criterion = nn.MSELoss(reduction='sum')
        pehe2, true_mean_ate, pred_mean_ate = None, None, None

        for batch_treat, batch_control, batch_my_val in zip(treated_data, control_data, my_val_data):
            ground_truth = batch_treat[2].to(self.device) - batch_control[2].to(self.device)
            treat_covariates = batch_treat[0].to(self.device)
            control_covariates = batch_control[0].to(self.device)
            true_treatment = batch_my_val[1].to(self.device)
            true_yf = batch_my_val[2].to(self.device)

            if self.config['model'] == 'XNet':
                tau1, tau0, ps = self.model(treat_covariates, batch_treat[1].to(self.device))
                preds =  (1-ps) * tau1 + ps * tau0
                pred_yf = preds
                yf_score = criterion(pred_yf, true_yf)
            elif self.config['model'] == 'RNet':
                preds = self.model(treat_covariates, batch_treat[1].to(self.device))
                pred_yf = preds
                yf_score = criterion(pred_yf, ground_truth)
            else:
                preds = self.model(treat_covariates, batch_treat[1].to(self.device)) \
                        - self.model(control_covariates, batch_control[1].to(self.device))

                pred_mu0 = self.model(control_covariates, batch_control[1].to(self.device))
                pred_mu1 = self.model(treat_covariates, batch_treat[1].to(self.device))
                pred_yf = torch.where(true_treatment == 1, pred_mu1, pred_mu0)
                yf_score = criterion(pred_yf, true_yf)

            hat_ite = criterion(preds, ground_truth)
            pehe2 = hat_ite if pehe2 is None else pehe2 + hat_ite

            hat_ate = sum(preds)
            pred_mean_ate = hat_ate if pred_mean_ate is None else pred_mean_ate + hat_ate

            true_ate = sum(ground_truth)
            true_mean_ate = true_ate if true_mean_ate is None else true_mean_ate + true_ate

        pehe2 = pehe2 / n_samples
        pred_mean_ate = pred_mean_ate / n_samples
        true_mean_ate = true_mean_ate / n_samples
        yf_score = yf_score / n_samples

        ate = torch.abs(pred_mean_ate - true_mean_ate).item()
        pehe = np.sqrt(pehe2.item())

        return ate, pehe, yf_score

    def _train_epoch(self, train_data, loss_func=None):

        self.model.train()
        n_samples = train_data.get_X_size()[0]
        loss_func = loss_func or self.model.calculate_loss
      
        total_loss = None
        for x,t,y,w in train_data:
            x,t,y,w = self.to_device(x,t,y,w)
            self.optimizer.zero_grad()
            losses = loss_func(x,t,y,w)

            if isinstance(losses, tuple):
                loss = sum(losses)
                loss_tuple = tuple(per_loss.item() for per_loss in losses)
                total_loss = loss_tuple if total_loss is None else tuple(map(sum, zip(total_loss, loss_tuple)))
            else:
                loss = losses
                total_loss = losses.item() if total_loss is None else total_loss + losses.item()

            self._check_nan(loss)
            loss.backward()
            if self.clip_grad_norm:
                clip_grad_norm_(self.model.parameters(), self.clip_grad_norm)
            self.optimizer.step()

        return total_loss/n_samples

    def _save_checkpoint(self, epoch):
        state = {
            'config': self.config,
            'epoch': epoch,
            'cur_step': self.cur_step,
            'best_valid_score': self.best_valid_score,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'model': self.model
        }
        torch.save(state, self.saved_model_file)


    def fit(self, train_data, valid_data=None, my_val_data = None,
            test_treated_data=None,test_control_data=None,
             verbose=True, saved=True):

        train_loss_list = []
        valid_loss_list = []
        
        for epoch_idx in range(self.start_epoch, self.epochs):
            # train
            training_start_time = time()
            train_loss = self._train_epoch(train_data)
            self.train_loss_dict[epoch_idx] = sum(train_loss) if isinstance(train_loss, tuple) else train_loss
            training_end_time = time()
            train_loss_output = \
                self._generate_train_loss_output(epoch_idx, training_start_time, training_end_time, train_loss)
            if verbose:
                self.logger.info(train_loss_output)

            # eval
            if self.eval_step <= 0 or not valid_data:
                if saved:
                    self._save_checkpoint(epoch_idx)
                    update_output = 'Saving current: %s' % self.saved_model_file
                    if verbose:
                        self.logger.info(update_output)
                continue
            if (epoch_idx + 1) % self.eval_step == 0:
                if self.config['model'] == 'XNet' or self.config['model'] == 'RNet' or self.config['model'] == 'DRNet':
                    update_flag = True
                    if epoch_idx > self.config['no_validation_epoch']:
                        stop_flag = True
                    else:
                        stop_flag = False
                else:
                    valid_start_time = time()
                    ate, pehe, yf_score = self._valid_epoch(valid_data, my_val_data)
                    if self.config['valid_metric_criterion'] == 'factual':
                        valid_score = -yf_score # -yf_score #
                    elif self.config['valid_metric_criterion'] == 'pehe':
                        valid_score = -pehe
                    self.best_valid_score, self.cur_step, stop_flag, update_flag = early_stopping(
                        valid_score, self.best_valid_score, self.cur_step,
                        max_step=self.stopping_step, bigger=self.valid_metric_bigger)
                    valid_end_time = time()
                    valid_score_output = "epoch %d evaluating [time: %.2fs, valid_score: %f]" % \
                                        (epoch_idx, valid_end_time - valid_start_time, -valid_score)

                    train_loss_list.append(self.train_loss_dict[epoch_idx])
                    valid_loss_list.append(-valid_score)

                    if verbose:
                        self.logger.info(valid_score_output)

                if update_flag:
                    if saved:
                        self._save_checkpoint(epoch_idx)
                        update_output = 'Saving current best: %s' % self.saved_model_file
                        if verbose:
                            self.logger.info(update_output)


                if stop_flag:
                    stop_output = 'Finished training, best eval result in epoch %d' % \
                                  (epoch_idx - self.cur_step * self.eval_step)
                    if verbose:
                        self.logger.info(stop_output)
                    break

        return -self.best_valid_score


    @torch.no_grad()
    def evaluate(self, treated_data,control_data, load_best_model=True, model_file=None):

        if load_best_model:
            if model_file:
                checkpoint_file = model_file
            else:
                checkpoint_file = self.saved_model_file
            checkpoint = torch.load(checkpoint_file,weights_only=False)#,weights_only=True)
            self.model.load_state_dict(checkpoint['state_dict'])
            message_output = 'Loading model structure and parameters from {}'.format(checkpoint_file)
            self.logger.info(message_output)
        self.model.eval()

        n_samples = treated_data.get_X_size()[0]

        criterion = nn.MSELoss(reduction='sum')
        pehe2,true_mean_ate,pred_mean_ate = None,None,None

        for batch_treat,batch_control in zip(treated_data, control_data):

            ground_truth = batch_treat[2].to(self.device) - batch_control[2].to(self.device)
            treat_covariates = batch_treat[0].to(self.device) 
            control_covariates = batch_control[0].to(self.device) 

            if self.config['model'] == 'XNet':
                y1, y0, ps = self.model(treat_covariates, batch_treat[1].to(self.device))
                preds = ps * y0 + (1 - ps) * y1
            elif self.config['model'] == 'RNet' or self.config['model'] == 'DRNet':
                preds = self.model(treat_covariates, batch_treat[1].to(self.device))
            else:
                preds = self.model(treat_covariates, batch_treat[1].to(self.device)) \
                        - self.model(control_covariates, batch_control[1].to(self.device))

            hat_ite = criterion(preds.reshape(-1),ground_truth.reshape(-1))
            pehe2 = hat_ite if pehe2 is None else pehe2+hat_ite

            hat_ate = sum(preds)
            pred_mean_ate = hat_ate if pred_mean_ate is None else pred_mean_ate+hat_ate

            true_ate = sum(ground_truth)
            true_mean_ate = true_ate if true_mean_ate is None else true_mean_ate + true_ate

        pehe2 = pehe2 / n_samples
        pred_mean_ate = pred_mean_ate / n_samples
        true_mean_ate = true_mean_ate / n_samples


        ate = torch.abs(pred_mean_ate-true_mean_ate).item()
        pehe = np.sqrt(pehe2.item())

        return {'pehe': pehe, 'ate': ate}

    @torch.no_grad()
    def adversarial_measure_evaluate(self, treated_data, control_data, measure_ratio, load_best_model=True, model_file=None):
        if load_best_model:
            if model_file:
                checkpoint_file = model_file
            else:
                checkpoint_file = self.saved_model_file
            checkpoint = torch.load(checkpoint_file,weights_only=False)#,weights_only=True)
            self.model.load_state_dict(checkpoint['state_dict'])
            message_output = 'Loading model structure and parameters from {}'.format(checkpoint_file)
            self.logger.info(message_output)
        self.model.eval()

        n_samples = treated_data.get_X_size()[0]

        criterion = nn.MSELoss(reduction='sum')
        pehe2,true_mean_ate,pred_mean_ate = None,None,None

        res = {}
        measure = measure_ratio
        for mea in measure:
            for batch_treat,batch_control in zip(treated_data, control_data):
                ground_truth = batch_treat[2].to(self.device) - batch_control[2].to(self.device)
                
                covariates_perturbed = self._get_measure_data(batch_treat[0], mea)
                treat_covariates = batch_treat[0].to(self.device)  + covariates_perturbed 
                control_covariates = batch_control[0].to(self.device) + covariates_perturbed

                if self.config['model'] == 'XNet':
                    y1, y0, ps = self.model(treat_covariates, batch_treat[1].to(self.device))
                    preds = ps * y0 + (1 - ps) * y1
                elif self.config['model'] == 'RNet' or self.config['model'] == 'DRNet':
                    preds = self.model(treat_covariates, batch_treat[1].to(self.device))
                else:
                    preds = self.model(treat_covariates, batch_treat[1].to(self.device)) \
                            - self.model(control_covariates, batch_control[1].to(self.device))

                hat_ite = criterion(preds.reshape(-1),ground_truth.reshape(-1))
                pehe2 = hat_ite if pehe2 is None else pehe2+hat_ite


            hat_ate = sum(preds)
            pred_mean_ate = hat_ate if pred_mean_ate is None else pred_mean_ate+hat_ate
            true_ate = sum(ground_truth)
            true_mean_ate = true_ate if true_mean_ate is None else true_mean_ate + true_ate

            pehe2 = pehe2 / n_samples
            pred_mean_ate = pred_mean_ate / n_samples
            true_mean_ate = true_mean_ate / n_samples

            ate = torch.abs(pred_mean_ate-true_mean_ate).item()
            pehe = np.sqrt(pehe2.item())
            res['ate_{}'.format(mea)] = ate
            res['pehe_{}'.format(mea)] = pehe

        return res

    @torch.no_grad()
    def adversarial_missing_evaluate(self, treated_data, control_data, missing_ratio, load_best_model=True, model_file=None):
        if load_best_model:
            if model_file:
                checkpoint_file = model_file
            else:
                checkpoint_file = self.saved_model_file
            checkpoint = torch.load(checkpoint_file,weights_only=False)#,weights_only=True)
            self.model.load_state_dict(checkpoint['state_dict'])
            message_output = 'Loading model structure and parameters from {}'.format(checkpoint_file)
            self.logger.info(message_output)
        self.model.eval()

        n_samples = treated_data.get_X_size()[0]

        criterion = nn.MSELoss(reduction='sum')
        pehe2,true_mean_ate,pred_mean_ate = None,None,None

        res = {}
        missing = missing_ratio

        for mea in missing:
            for batch_treat,batch_control in zip(treated_data, control_data):
                ground_truth = batch_treat[2].to(self.device) - batch_control[2].to(self.device)

                covariates = self._get_missing_data(batch_treat[0], mea)
                treat_covariates = covariates.to(self.device)  
                control_covariates = covariates.to(self.device) 

                if self.config['model'] == 'XNet':
                    y1, y0, ps = self.model(treat_covariates, batch_treat[1].to(self.device))
                    preds = ps * y0 + (1 - ps) * y1
                elif self.config['model'] == 'RNet' or self.config['model'] == 'DRNet':
                    preds = self.model(treat_covariates, batch_treat[1].to(self.device))
                else:
                    preds = self.model(treat_covariates, batch_treat[1].to(self.device)) \
                            - self.model(control_covariates, batch_control[1].to(self.device))

                hat_ite = criterion(preds.reshape(-1),ground_truth.reshape(-1))
                pehe2 = hat_ite if pehe2 is None else pehe2+hat_ite


            hat_ate = sum(preds)
            pred_mean_ate = hat_ate if pred_mean_ate is None else pred_mean_ate+hat_ate
            true_ate = sum(ground_truth)
            true_mean_ate = true_ate if true_mean_ate is None else true_mean_ate + true_ate

            pehe2 = pehe2 / n_samples
            pred_mean_ate = pred_mean_ate / n_samples
            true_mean_ate = true_mean_ate / n_samples


            ate = torch.abs(pred_mean_ate-true_mean_ate).item()
            pehe = np.sqrt(pehe2.item())
            res['ate_{}'.format(mea)] = ate
            res['pehe_{}'.format(mea)] = pehe

        return res

    @torch.no_grad()
    def adversarial_hidden_evaluate(self, treated_data, control_data, hidden_ratio, load_best_model=True, model_file=None):
        if load_best_model:
            if model_file:
                checkpoint_file = model_file
            else:
                checkpoint_file = self.saved_model_file
            checkpoint = torch.load(checkpoint_file,weights_only=False)#,weights_only=True)
            self.model.load_state_dict(checkpoint['state_dict'])
            message_output = 'Loading model structure and parameters from {}'.format(checkpoint_file)
            self.logger.info(message_output)
        self.model.eval()

        n_samples = treated_data.get_X_size()[0]

        criterion = nn.MSELoss(reduction='sum')
        pehe2,true_mean_ate,pred_mean_ate = None,None,None

        res = {}
        hidden = hidden_ratio
        for mea in hidden:
            for batch_treat,batch_control in zip(treated_data, control_data):
                ground_truth = batch_treat[2].to(self.device) - batch_control[2].to(self.device)
                
                # hidden groud_truth 
                ground_truth = self._get_hidden_data(batch_treat[0], mea)

                treat_covariates = batch_treat[0].to(self.device) 
                control_covariates = batch_control[0].to(self.device) 

                if self.config['model'] == 'XNet':
                    y1, y0, ps = self.model(treat_covariates, batch_treat[1].to(self.device))
                    preds = ps * y0 + (1 - ps) * y1
                elif self.config['model'] == 'RNet' or self.config['model'] == 'DRNet':
                    preds = self.model(treat_covariates, batch_treat[1].to(self.device))
                else:
                    preds = self.model(treat_covariates, batch_treat[1].to(self.device)) \
                            - self.model(control_covariates, batch_control[1].to(self.device))

                hat_ite = criterion(preds.reshape(-1),ground_truth.reshape(-1))
                pehe2 = hat_ite if pehe2 is None else pehe2+hat_ite


            hat_ate = sum(preds)
            pred_mean_ate = hat_ate if pred_mean_ate is None else pred_mean_ate+hat_ate
            true_ate = sum(ground_truth)
            true_mean_ate = true_ate if true_mean_ate is None else true_mean_ate + true_ate

            pehe2 = pehe2 / n_samples
            pred_mean_ate = pred_mean_ate / n_samples
            true_mean_ate = true_mean_ate / n_samples


            ate = torch.abs(pred_mean_ate-true_mean_ate).item()
            pehe = np.sqrt(pehe2.item())

            res['ate_{}'.format(mea)] = ate
            res['pehe_{}'.format(mea)] = pehe

        return res