import os
import torch
import torch.optim as optim
from torch.nn.utils.clip_grad import clip_grad_norm_
import numpy as np
import matplotlib.pyplot as plt
import pickle
from time import time
from logging import getLogger
import torch.nn as nn
from causally.utils.utils import ensure_dir, get_local_time
from causally.trainer.AbstractTrainer import AbstractTrainer
from causally.evaluator.evaluator import SKlearnEvaluator
class SKTrainer(AbstractTrainer):
    """Trainer that drives a model through a single-call fit/predict interface.

    Fitting is one call to ``model.calculate_loss`` on the full training
    arrays, and evaluation compares ``model.predict`` output against the
    difference between the treated and control outcomes, scored by a
    ``SKlearnEvaluator``.
    """

    def __init__(self, config, model):
        """Read trainer settings from ``config`` and prepare the checkpoint path.

        Args:
            config: Mapping-like configuration object indexed by string keys.
            model: Model object; must expose ``calculate_loss`` and ``predict``
                as used by ``fit`` and the evaluation methods.
        """
        super(SKTrainer, self).__init__(config, model)

        self.logger = getLogger()

        # Settings copied from the configuration. Several of these are not
        # referenced by the methods visible in this class; they are kept so
        # the attribute set stays the same for any external reader.
        self.learner = config['optimizer']
        self.learning_rate = config['learning_rate']
        self.epochs = config['epochs']
        # The evaluation interval can never exceed the number of epochs.
        self.eval_step = min(config['eval_step'], self.epochs)
        self.stopping_step = config['stopping_step']
        self.clip_grad_norm = config['clip_grad_norm']
        self.valid_metric_bigger = config['valid_metric_bigger']
        self.test_batch_size = config['eval_batch_size']
        self.device = config['device']

        # Checkpoint file name: <model>-<dataset>-<local time>.pkl
        self.checkpoint_dir = config['checkpoint_dir']
        ensure_dir(self.checkpoint_dir)
        # NOTE(review): self.config is presumably set by AbstractTrainer.__init__
        # -- confirm against the base class.
        saved_model_file = '{}-{}-{}.pkl'.format(
            self.config['model'], self.config['dataset'], get_local_time())
        self.saved_model_file = os.path.join(self.checkpoint_dir, saved_model_file)

        # Bookkeeping state.
        self.start_epoch = 0
        self.cur_step = 0
        self.best_valid_score = -1000000
        self.best_valid_result = None

        self.sklearnEvaluator = SKlearnEvaluator(config=config)

    def fit(self, train_data=None, valid_data=None):
        """Fit the model on the full training set in one call.

        Args:
            train_data: Dataset object whose ``get_data()`` returns a 7-tuple;
                the first four items are passed to ``model.calculate_loss``.
            valid_data: Accepted for interface compatibility; not used.

        Raises:
            ValueError: If ``train_data`` is not supplied.
        """
        if train_data is None:
            # Fail early with a clear message instead of an AttributeError
            # on ``None.get_data()``.
            raise ValueError("train_data must be provided to SKTrainer.fit")

        x, t, y, w, _, _, _ = train_data.get_data()
        # NOTE(review): calculate_loss presumably fits the underlying
        # estimator as a side effect -- confirm in the model class.
        self.model.calculate_loss(x, t, y, w)

    def criterion(self, x, y):
        """Return the mean of the element-wise squared difference of x and y."""
        return np.mean(np.square(x - y))

    def _evaluate_pair(self, treat_data, control_data):
        """Score model predictions on one treated/control dataset pair.

        Shared implementation behind ``_in_evaluate`` and ``_out_evaluate``.

        Args:
            treat_data: Dataset whose ``get_data()`` supplies the covariates,
                the treated outcome, and the auxiliary evaluation arrays.
            control_data: Dataset whose ``get_data()`` supplies the control
                outcome (third item); the other items are ignored.

        Returns:
            Whatever ``SKlearnEvaluator.evaluate()`` returns after collecting
            this pair's ground truth and predictions.
        """
        (treated_x, _, treated_y, _,
         indicator_random, factual_outcome, test_treatment) = treat_data.get_data()
        _, _, control_y, _, _, _, _ = control_data.get_data()

        # Ground-truth effect: treated outcome minus control outcome.
        true_ite = treated_y - control_y

        # All-zeros / all-ones treatment vectors, one entry per sample.
        n_samples = len(true_ite)
        control_t = np.zeros(n_samples)
        treated_t = np.ones(n_samples)
        preds = self.model.predict(treated_x, control_t, treated_t)

        self.sklearnEvaluator.collect(
            ground_truth=true_ite,
            prediction=preds,
            treatment=test_treatment,
            indicator_random=indicator_random,
            factual_outcome=factual_outcome,
        )
        return self.sklearnEvaluator.evaluate()

    def _in_evaluate(self, treat_data, control_data):
        """Evaluate on the in-sample (training) treated/control pair."""
        return self._evaluate_pair(treat_data, control_data)

    def _out_evaluate(self, treat_data, control_data):
        """Evaluate on the out-of-sample (test) treated/control pair."""
        return self._evaluate_pair(treat_data, control_data)

    def evaluate(self, test_treated_data, test_control_data, train_treated_data,
                 train_control_data, load_best_model=True, model_file=None):
        """Evaluate on both the test and the training data and merge the results.

        Args:
            test_treated_data: Treated dataset for out-of-sample evaluation.
            test_control_data: Control dataset for out-of-sample evaluation.
            train_treated_data: Treated dataset for in-sample evaluation.
            train_control_data: Control dataset for in-sample evaluation.
            load_best_model: Accepted for interface compatibility; not used.
            model_file: Accepted for interface compatibility; not used.

        Returns:
            dict: The out-of-sample results under their original keys, plus the
            in-sample results under ``'in_<suffix>'`` where ``<suffix>`` is the
            text after the last underscore of the original key.
        """
        # Out-of-sample first, then in-sample (the evaluator is shared).
        out_ret = self._out_evaluate(test_treated_data, test_control_data)
        in_ret = self._in_evaluate(train_treated_data, train_control_data)

        ret = dict(out_ret)
        # Only the last '_'-separated segment of each key is kept, so in-sample
        # keys that share a final segment overwrite one another.
        for key, value in in_ret.items():
            ret['in_{}'.format(key.split('_')[-1])] = value
        return ret