import os
import torch
import torch.optim as optim
from torch.nn.utils.clip_grad import clip_grad_norm_
import numpy as np
import matplotlib.pyplot as plt

from time import time
from logging import getLogger
import torch.nn as nn
from causally.utils.utils import ensure_dir, get_local_time, early_stopping
from causally.trainer.AbstractTrainer import AbstractTrainer
from causally.evaluator.evaluator import CausalEvaluator
from causally.data.dataloader import TorchDataLoader
class TrainingTrainer(AbstractTrainer):
    """Trainer for causal models with early stopping, checkpointing and data relabelling.

    Besides the train/validate loop (:meth:`fit`) it can
    (a) add perturbations produced by ``model.generate_perturbation`` to the
        covariate columns of treated/control datasets (:meth:`get_new_data`), and
    (b) overwrite a dataset's ``'yf'`` column with the model's predictions
        (:meth:`evaluate`).
    """

    def __init__(self, config, model):
        super(TrainingTrainer, self).__init__(config, model)

        self.logger = getLogger()
        self.learner = config['optimizer']
        self.learning_rate = config['learning_rate']
        self.epochs = config['epochs']
        self.eval_step = min(config['eval_step'], self.epochs)
        self.stopping_step = config['stopping_step']
        self.clip_grad_norm = config['clip_grad_norm']
        self.valid_metric_bigger = config['valid_metric_bigger']
        self.test_batch_size = config['eval_batch_size']
        self.device = config['device']
        self.checkpoint_dir = config['checkpoint_dir']
        ensure_dir(self.checkpoint_dir)
        # Append the raw time() value (dots swapped for dashes) to the checkpoint name.
        timestamp = '-'.join(str(time()).strip().split('.'))
        saved_model_file = '{}-{}-{}-{}.pth'.format(
            self.config['model'], self.config['dataset'], get_local_time(), timestamp)
        self.saved_model_file = os.path.join(self.checkpoint_dir, saved_model_file)
        self.start_epoch = 0
        self.cur_step = 0
        # Start from the worst possible score for the configured direction. A fixed
        # 1e8 can never be beaten when bigger is better, and blocks saving when
        # smaller is better but scores exceed 1e8.
        self.best_valid_score = -float('inf') if self.valid_metric_bigger else float('inf')
        self.best_valid_result = None
        self.train_loss_dict = {}
        self.optimizer = self._build_optimizer()
        self.causalEvaluator = CausalEvaluator(config=config)

    def to_device(self, x, t, y, w):
        """Move the four batch tensors (x, t, y, w) to ``self.device``."""
        return x.to(self.device), t.to(self.device), y.to(self.device), w.to(self.device)

    def _build_optimizer(self):
        """Create the optimizer named by ``config['optimizer']`` (case-insensitive).

        Supported: adam, sgd, adagrad, rmsprop. Any other name logs a warning and
        falls back to Adam. All use ``config['learning_rate']``.
        """
        optimizers = {
            'adam': optim.Adam,
            'sgd': optim.SGD,
            'adagrad': optim.Adagrad,
            'rmsprop': optim.RMSprop,
        }
        optimizer_cls = optimizers.get(self.learner.lower())
        if optimizer_cls is None:
            self.logger.warning('Received unrecognized optimizer, set default Adam optimizer')
            optimizer_cls = optim.Adam
        return optimizer_cls(self.model.parameters(), lr=self.learning_rate)

    @torch.no_grad()
    def _valid_epoch(self, val_data, loss_func=None):
        """Compute the validation loss without updating the model.

        Args:
            val_data: loader yielding batches whose first four items are
                (x, t, y, w); must also expose ``get_X_size()``.
            loss_func: optional ``(x, t, y, w) -> scalar tensor``; defaults to
                ``self.model.calculate_loss``.

        Returns:
            float: sum of per-batch losses divided by ``val_data.get_X_size()[0]``
            (presumably the number of samples — confirm against the loader).
        """
        self.model.eval()
        n_samples = val_data.get_X_size()[0]
        loss_func = loss_func or self.model.calculate_loss
        # Start at 0.0 (not None) so an empty loader does not raise TypeError.
        total_loss = 0.0
        for batch_data in val_data:
            x, t, y, w = self.to_device(batch_data[0], batch_data[1], batch_data[2], batch_data[3])
            total_loss += loss_func(x, t, y, w).item()
        return total_loss / n_samples

    def _train_epoch(self, train_data, loss_func=None):
        """Run one optimisation pass over ``train_data``.

        Args:
            train_data: loader yielding batches whose first four items are
                (x, t, y, w); must also expose ``get_X_size()``.
            loss_func: optional ``(x, t, y, w) -> scalar tensor``; defaults to
                ``self.model.calculate_loss``.

        Returns:
            float: sum of per-batch losses divided by ``train_data.get_X_size()[0]``.
        """
        self.model.train()
        n_samples = train_data.get_X_size()[0]
        loss_func = loss_func or self.model.calculate_loss
        total_loss = 0.0
        for batch_data in train_data:
            x, t, y, w = self.to_device(batch_data[0], batch_data[1], batch_data[2], batch_data[3])
            self.optimizer.zero_grad()
            losses = loss_func(x, t, y, w)
            total_loss += losses.item()
            # NaN detection is available via self._check_nan(losses) but is disabled here.
            losses.backward()
            if self.clip_grad_norm:
                clip_grad_norm_(self.model.parameters(), self.clip_grad_norm)
            self.optimizer.step()
        return total_loss / n_samples

    def _save_checkpoint(self, epoch):
        """Write config, progress counters, model and optimizer state to ``self.saved_model_file``.

        Note that the full ``self.model`` object is pickled alongside its state_dict.
        """
        state = {
            'config': self.config,
            'epoch': epoch,
            'cur_step': self.cur_step,
            'best_valid_score': self.best_valid_score,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'model': self.model
        }
        torch.save(state, self.saved_model_file)

    def resume_checkpoint(self, resume_file):
        """Restore model/optimizer state and counters so :meth:`fit` continues after the saved epoch.

        Args:
            resume_file: path (str or path-like) of a file written by :meth:`_save_checkpoint`.
        """
        resume_file = str(resume_file)
        # NOTE: torch.load unpickles arbitrary objects (the checkpoint stores the config
        # and the whole model), so only load trusted files. torch >= 2.6 defaults to
        # weights_only=True, which may reject such checkpoints — TODO confirm torch version.
        # map_location lets a checkpoint saved on GPU be resumed on the configured device.
        checkpoint = torch.load(resume_file, map_location=self.device)
        self.start_epoch = checkpoint['epoch'] + 1
        self.cur_step = checkpoint['cur_step']
        self.best_valid_score = checkpoint['best_valid_score']

        # load architecture params from checkpoint
        if checkpoint['config']['model'].lower() != self.config['model'].lower():
            self.logger.warning('Architecture configuration given in config file is different from that of checkpoint. '
                                'This may yield an exception while state_dict is being loaded.')
        self.model.load_state_dict(checkpoint['state_dict'])

        # Load optimizer state only when the optimizer type is unchanged: another
        # optimizer's state dict would replace param_groups with foreign hyper-parameters.
        if str(checkpoint['config']['optimizer']).lower() == self.learner.lower():
            self.optimizer.load_state_dict(checkpoint['optimizer'])
        else:
            self.logger.warning('Optimizer type given in config file is different from that of checkpoint. '
                                'Optimizer state is not loaded.')
        message_output = 'Checkpoint loaded. Resume training from epoch {}'.format(self.start_epoch)
        self.logger.info(message_output)

    def _check_nan(self, loss):
        """Raise ``ValueError`` if ``loss`` is NaN."""
        if torch.isnan(loss):
            raise ValueError('Training loss is nan')

    def _generate_train_loss_output(self, epoch_idx, s_time, e_time, losses):
        """Format the per-epoch training log line.

        ``losses`` may be a single number or a tuple of component losses. In both
        cases the line keeps the ``epoch %d training [time: ...`` prefix.
        """
        train_loss_output = 'epoch %d training [time: %.2fs, ' % (epoch_idx, e_time - s_time)
        if isinstance(losses, tuple):
            # Append (not assign) so the epoch/time prefix is kept.
            train_loss_output += ', '.join('train_loss%d: %.4f' % (idx + 1, loss) for idx, loss in enumerate(losses))
        else:
            train_loss_output += 'train loss: %.4f' % losses
        return train_loss_output + ']'

    def fit(self, train_data, valid_data=None, verbose=True, saved=True):
        """Train from ``self.start_epoch`` to ``self.epochs`` with periodic validation and early stopping.

        Args:
            train_data: training loader (see :meth:`_train_epoch`).
            valid_data: optional validation loader. When it is falsy, or
                ``eval_step`` <= 0, no validation runs and a checkpoint is written
                every epoch.
            verbose: log per-epoch progress via ``self.logger``.
            saved: write checkpoints to ``self.saved_model_file``.

        Returns:
            ``self.best_valid_score`` as maintained by ``early_stopping`` (the
            initial ±inf sentinel if no validation ran).
        """
        if saved and self.start_epoch >= self.epochs:
            # Nothing left to train (e.g. resumed past the last epoch): still emit a checkpoint.
            self._save_checkpoint(-1)

        for epoch_idx in range(self.start_epoch, self.epochs):
            # train
            training_start_time = time()
            train_loss = self._train_epoch(train_data)
            self.train_loss_dict[epoch_idx] = sum(train_loss) if isinstance(train_loss, tuple) else train_loss
            training_end_time = time()
            train_loss_output = self._generate_train_loss_output(
                epoch_idx, training_start_time, training_end_time, train_loss)
            if verbose:
                self.logger.info(train_loss_output)

            # No validation configured: checkpoint every epoch and move on.
            if self.eval_step <= 0 or not valid_data:
                if saved:
                    self._save_checkpoint(epoch_idx)
                    update_output = 'Saving current: %s' % self.saved_model_file
                    if verbose:
                        self.logger.info(update_output)
                continue
            if (epoch_idx + 1) % self.eval_step != 0:
                continue

            # eval
            valid_start_time = time()
            valid_score = self._valid_epoch(valid_data)
            self.best_valid_score, self.cur_step, stop_flag, update_flag = early_stopping(
                valid_score, self.best_valid_score, self.cur_step,
                max_step=self.stopping_step, bigger=self.valid_metric_bigger)
            valid_end_time = time()
            valid_score_output = "epoch %d evaluating [time: %.2fs, valid score:%f]" % \
                                 (epoch_idx, valid_end_time - valid_start_time, valid_score)
            if verbose:
                self.logger.info(valid_score_output)
            if update_flag and saved:
                self._save_checkpoint(epoch_idx)
                update_output = 'Saving current best: %s' % self.saved_model_file
                if verbose:
                    self.logger.info(update_output)

            if stop_flag:
                # Presumably cur_step counts evaluations since the last improvement,
                # so stepping back cur_step * eval_step epochs lands on the best one.
                stop_output = 'Finished training, best eval result in epoch %d' % \
                              (epoch_idx - self.cur_step * self.eval_step)
                if verbose:
                    self.logger.info(stop_output)
                break
        return self.best_valid_score

    def _frame_to_tensors(self, frame):
        """Build (x, t, y, w) tensors on ``self.device`` from a dataset frame.

        Reads the ``config['column_x']`` covariate columns plus ``'treatment'``,
        ``'factual_outcome'`` and ``'weight'``. The last three are reshaped to
        column vectors. Treatment is cast to int; the others to float.
        """
        x = torch.from_numpy(frame[self.config['column_x']].values).float().to(self.device)
        t = torch.from_numpy(frame['treatment'].values.reshape(-1, 1)).int().to(self.device)
        y = torch.from_numpy(frame['factual_outcome'].values.reshape(-1, 1)).float().to(self.device)
        w = torch.from_numpy(frame['weight'].values.reshape(-1, 1)).float().to(self.device)
        return x, t, y, w

    def _perturb_pair(self, treated_data, control_data):
        """Add one perturbation, computed from ``treated_data``, to both datasets' covariates.

        Mutates ``treated_data.dataset`` and ``control_data.dataset`` in place.
        """
        columns = self.config['column_x']
        x, t, y, w = self._frame_to_tensors(treated_data.dataset)
        grad = self.model.generate_perturbation(x, t, y, w, self.config['perturbation'])
        # NOTE(review): grad is added directly to DataFrame columns; this presumably
        # requires generate_perturbation to return a CPU/numpy-compatible array — confirm.
        treated_data.dataset[columns] = treated_data.dataset[columns] + grad
        control_data.dataset[columns] = control_data.dataset[columns] + grad

    def _eval_loader(self, dataset):
        """Wrap ``dataset`` in an unshuffled TorchDataLoader using ``eval_batch_size``."""
        return TorchDataLoader(config=self.config, dataset=dataset,
                               batch_size=self.config['eval_batch_size'], shuffle=False)

    def get_new_data(self, test_treated_data, test_control_data, train_treated_data, train_control_data):
        """Perturb the covariates of the test and train splits and rebuild loaders.

        For each split, the perturbation is computed from the treated dataset and
        added to both the treated and the control dataset (mutating them in place).

        Returns:
            tuple: (test_treated, test_control, train_treated, train_control)
            TorchDataLoaders, unshuffled, with ``eval_batch_size``.
        """
        # Test split is perturbed before the train split.
        self._perturb_pair(test_treated_data, test_control_data)
        self._perturb_pair(train_treated_data, train_control_data)

        return (self._eval_loader(test_treated_data.dataset),
                self._eval_loader(test_control_data.dataset),
                self._eval_loader(train_treated_data.dataset),
                self._eval_loader(train_control_data.dataset))

    @torch.no_grad()
    def evaluate(self, train_data, load_best_model=True, model_file=None):
        """Relabel ``train_data`` with model predictions and return a shuffled loader over it.

        Args:
            train_data: object whose ``.dataset`` frame has the ``config['column_x']``
                and ``'treatment'`` columns. Its ``'yf'`` column is overwritten in place.
            load_best_model: if True, first load the state_dict from a checkpoint.
            model_file: checkpoint path; defaults to ``self.saved_model_file``.

        Returns:
            TorchDataLoader over the relabelled dataset with ``train_batch_size``
            and ``shuffle=True``.
        """
        if load_best_model:
            checkpoint_file = model_file or self.saved_model_file
            # NOTE: torch.load unpickles the file; only load trusted checkpoints.
            checkpoint = torch.load(checkpoint_file, map_location=self.device)
            self.model.load_state_dict(checkpoint['state_dict'])
            message_output = 'Loading model structure and parameters from {}'.format(checkpoint_file)
            self.logger.info(message_output)
        # Predictions must not depend on whether fit() last ran a training or a
        # validation epoch, so force eval mode (affects e.g. dropout / batch norm).
        self.model.eval()
        train_x = torch.from_numpy(train_data.dataset[self.config['column_x']].values).float().to(self.device)
        train_t = torch.from_numpy(train_data.dataset['treatment'].values.reshape(-1, 1)).int().to(self.device)
        pred = self.model.get_predict_yf(train_x, train_t)
        # NOTE(review): assigning pred to a DataFrame column presumably requires a
        # CPU/numpy result from get_predict_yf — confirm for CUDA devices.
        train_data.dataset['yf'] = pred
        train_new_data = TorchDataLoader(config=self.config, dataset=train_data.dataset,
                                         batch_size=self.config['train_batch_size'], shuffle=True)

        return train_new_data
