from numpy.lib.twodim_base import mask_indices
import torch
from torch.autograd import Variable

from neuralfaults.utils.helpers import (initialize_metrics,
                                    get_mean_metrics,
                                    compute_metrics)

class Runner():
    """Run training and validation passes for a model.

    The runner holds the model, optimizer, loss criterion, and the two data
    loaders. It has two batch layouts, selected by ``opt.fail_quants``:

    * empty ``fail_quants``: batches are ``(input, output)`` and the model is
      called as ``model(input)``;
    * non-empty ``fail_quants``: batches are ``(input, output, mask, delta)``
      and the model is called as ``model(input, mask, delta)``. When
      ``opt.impute_model`` contains ``'brits'`` the model is expected to
      return ``(loss, prediction)`` and the criterion loss is added to it.
    """

    def __init__(self, device, model,
                optimizer, criterion,
                train_loader, val_loader, opt):
        """Store the collaborators and create the metric containers.

        Args:
            device: Device the batches are moved to (passed to ``Tensor.to``).
            model: Callable model; ``train()`` / ``eval()`` are called on it.
            optimizer: Optimizer stepping the model parameters.
            criterion: Callable ``criterion(prediction, target) -> loss``.
            train_loader: Iterable of training batches.
            val_loader: Iterable of validation batches.
            opt: Options object; ``fail_quants`` and ``impute_model`` are
                read here, and it is forwarded to the metric helpers.
        """
        self.opt = opt
        self.device = device
        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.train_metrics = initialize_metrics(self.opt)
        self.val_metrics = initialize_metrics(self.opt)

    def set_epoch_metrics(self):
        """Replace both metric containers with freshly initialized ones."""
        self.train_metrics = initialize_metrics(self.opt)
        self.val_metrics = initialize_metrics(self.opt)

    def _uses_masks(self):
        """Return True when ``opt.fail_quants`` is non-empty."""
        return len(self.opt.fail_quants) > 0

    def _to_device(self, tensor):
        """Move ``tensor`` to the runner's device as float32."""
        # The deprecated ``Variable`` wrapper is not needed: tensors can be
        # moved directly, and device + dtype are converted in a single call.
        return tensor.to(device=self.device, dtype=torch.float32)

    def batch_to_gpu(self, input_tensor, output_tensor, mask_tensor=None, delta_tensor=None):
        """Move one batch to ``self.device`` and cast it to float32.

        Returns:
            ``(input, output, mask, delta)`` when ``opt.fail_quants`` is
            non-empty, otherwise ``(input, output)``.
        """
        input_tensor = self._to_device(input_tensor)
        output_tensor = self._to_device(output_tensor)

        if self._uses_masks():
            mask_tensor = self._to_device(mask_tensor)
            delta_tensor = self._to_device(delta_tensor)
            return input_tensor, output_tensor, mask_tensor, delta_tensor

        return input_tensor, output_tensor

    def _forward(self, input_tensor, output_tensor, mask_tensor=None, delta_tensor=None):
        """Run the model on one batch and return ``(prediction, loss)``.

        Shared by training and evaluation so both compute the loss the
        same way.
        """
        if not self._uses_masks():
            prediction_tensor = self.model(input_tensor)
            loss = self.criterion(prediction_tensor, output_tensor)
            return prediction_tensor, loss

        if 'brits' in self.opt.impute_model:
            model_loss, prediction_tensor = self.model(input_tensor, mask_tensor, delta_tensor)
            # Out-of-place add: do not mutate the tensor the model returned.
            loss = model_loss + self.criterion(prediction_tensor, output_tensor)
        else:
            prediction_tensor = self.model(input_tensor, mask_tensor, delta_tensor)
            loss = self.criterion(prediction_tensor, output_tensor)
        return prediction_tensor, loss

    def train_forward_backward(self, input_tensor, output_tensor, mask_tensor=None, delta_tensor=None):
        """Do one optimization step on a batch.

        Zeroes the gradients, runs the forward pass, backpropagates the loss
        and steps the optimizer.

        Returns:
            ``(prediction, loss)`` for the batch.
        """
        self.optimizer.zero_grad()
        prediction_tensor, loss = self._forward(input_tensor, output_tensor,
                                                mask_tensor, delta_tensor)
        loss.backward()
        self.optimizer.step()
        return prediction_tensor, loss

    def eval_forward(self, input_tensor, output_tensor, mask_tensor=None, delta_tensor=None):
        """Compute predictions and loss for a batch without updating the model.

        Runs under ``torch.no_grad()`` so no autograd graph is built; the
        returned tensors therefore do not require grad.

        Returns:
            ``(prediction, loss)`` for the batch.
        """
        with torch.no_grad():
            return self._forward(input_tensor, output_tensor,
                                 mask_tensor, delta_tensor)

    def _run_epoch(self, loader, step, metrics):
        """Feed every batch of ``loader`` through ``step`` and record metrics.

        Args:
            loader: Iterable of batches.
            step: ``train_forward_backward`` or ``eval_forward``.
            metrics: Container handed to ``compute_metrics`` for each batch.

        Returns:
            The result of ``get_mean_metrics(metrics)``.
        """
        masked = self._uses_masks()

        for batch in loader:
            if masked:
                input_tensor, output_tensor, mask_tensor, delta_tensor = batch
                input_tensor, output_tensor, mask_tensor, delta_tensor = self.batch_to_gpu(
                    input_tensor, output_tensor, mask_tensor, delta_tensor)
                # The model and the metrics use the complement of the
                # loader's mask.
                mask_tensor = 1 - mask_tensor
                prediction_tensor, loss = step(input_tensor, output_tensor,
                                               mask_tensor, delta_tensor)
                # NOTE(review): targets are masked here but predictions are
                # not -- confirm this asymmetry is intended.
                metric_target = output_tensor * mask_tensor
                del mask_tensor, delta_tensor
            else:
                input_tensor, output_tensor = batch
                input_tensor, output_tensor = self.batch_to_gpu(input_tensor,
                                                                output_tensor)
                prediction_tensor, loss = step(input_tensor, output_tensor)
                metric_target = output_tensor

            compute_metrics(metrics, loss, prediction_tensor,
                            metric_target, self.opt)
            # Drop every per-batch reference before the next batch is fetched.
            del batch, input_tensor, output_tensor, metric_target
            del prediction_tensor, loss

        return get_mean_metrics(metrics)

    def train_model(self):
        """Train on every batch of ``train_loader``.

        Returns:
            ``get_mean_metrics(self.train_metrics)`` after the pass.
        """
        self.model.train()
        return self._run_epoch(self.train_loader, self.train_forward_backward,
                               self.train_metrics)

    def eval_model(self):
        """Evaluate on every batch of ``val_loader`` without training.

        Returns:
            ``get_mean_metrics(self.val_metrics)`` after the pass.
        """
        self.model.eval()
        return self._run_epoch(self.val_loader, self.eval_forward,
                               self.val_metrics)
