import numpy as np 
import torch
from torch.autograd import Variable

from neuralfaults.utils.helpers import (initialize_metrics,
                                    get_mean_metrics,
                                    compute_metrics)

class GAINRunner():
    """Drive training and validation epochs for a generator/discriminator pair.

    The generator is called as ``model_g(input, missing, mask)`` and the
    discriminator as ``model_d(input, mask, generated, hint)``. Inputs are
    expected to be 3-D tensors, because their shape is unpacked into the three
    size arguments of ``sample_Z`` / ``sample_M``.

    ``mask_tensor`` is used as a multiplicative weight: entries equal to 1
    keep the value from ``input_tensor``, entries equal to 0 are replaced by
    uniform noise before being fed to the generator.
    (presumably 1 = observed, 0 = missing, as in GAIN -- TODO confirm against
    the dataset code.)
    """

    # Weight applied to the reconstruction term (criterion_g) of the generator loss.
    recon_weight = 10
    # Added inside the log, and used as a floor for the loss normalizer, so
    # that neither log(0) nor a division by zero can occur.
    eps = 1e-8

    def __init__(self, device, model_g, model_d,
                optimizer_g, optimizer_d, criterion_g, criterion_d,
                train_loader, val_loader, opt):
        """Store the models, optimizers, criteria and loaders.

        Also creates the train/validation metric containers through
        ``initialize_metrics(opt)`` and fixes the hint probability to 0.9.
        """
        self.opt = opt
        self.p_hint = 0.9
        self.device = device
        self.model_g = model_g
        self.model_d = model_d
        self.optimizer_g = optimizer_g
        self.optimizer_d = optimizer_d
        self.criterion_g = criterion_g
        self.criterion_d = criterion_d
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.train_metrics = initialize_metrics(self.opt)
        self.val_metrics = initialize_metrics(self.opt)

    def set_epoch_metrics(self):
        """Replace both metric containers with freshly initialized ones."""
        self.train_metrics = initialize_metrics(self.opt)
        self.val_metrics = initialize_metrics(self.opt)

    def batch_to_gpu(self, input_tensor, output_tensor, mask_tensor, delta_tensor):
        """Move the four batch tensors to ``self.device`` and cast them to float32.

        Returns the tensors as a tuple, in the same order they were passed in.
        """
        # torch.autograd.Variable has been a no-op wrapper since PyTorch 0.4,
        # so the tensors are moved directly.
        return tuple(
            tensor.to(self.device).float()
            for tensor in (input_tensor, output_tensor, mask_tensor, delta_tensor)
        )

    def sample_M(self, m, c, n, p):
        """Return a float32 ``[m, c, n]`` tensor of 0/1 draws.

        Each entry is 1 where an independent uniform sample from [0, 1)
        exceeds ``p`` and 0 otherwise. Sampling goes through ``np.random`` so
        that NumPy seeding controls it.
        """
        uniform = torch.from_numpy(np.random.uniform(0., 1., size=[m, c, n])).to(self.device)
        return (uniform > p).float()

    def sample_Z(self, m, c, n):
        """Return a float32 ``[m, c, n]`` tensor of uniform noise in [0, 1)."""
        noise = np.random.uniform(0., 1., size=[m, c, n])
        return torch.from_numpy(noise).to(self.device).float()

    def _corrupt_inputs(self, input_tensor, mask_tensor):
        """Build the noise-filled generator input and the discriminator hint.

        Returns ``(missing_tensor, hint_tensor)``. The noise is drawn before
        the hint so the order of NumPy RNG calls matches earlier versions.
        """
        m, c, n = input_tensor.shape

        noise_tensor = self.sample_Z(m, c, n)
        # 1 with probability p_hint: positions where the true mask is revealed.
        revealed = self.sample_M(m, c, n, 1 - self.p_hint)
        # Revealed positions carry the mask value, hidden positions carry 0.5.
        hint_tensor = mask_tensor * revealed + 0.5 * (1 - revealed)

        # Keep entries where mask == 1, substitute noise where mask == 0.
        missing_tensor = mask_tensor * input_tensor + (1 - mask_tensor) * noise_tensor
        return missing_tensor, hint_tensor

    def _generator_loss(self, input_tensor, generated_tensor, discriminate_prob, mask_tensor):
        """Combine the adversarial and reconstruction terms of the generator loss.

        ``discriminate_prob`` is passed through a sigmoid here, i.e. it is
        treated as logits.
        """
        missing_weight = 1 - mask_tensor
        log_prob = (torch.sigmoid(discriminate_prob) + self.eps).log()

        # Floor the normalizer: a batch without any mask == 0 entry would
        # otherwise give 0 / 0 = NaN and poison the generator gradients.
        normalizer = missing_weight.sum().clamp(min=self.eps)

        # NOTE(review): the usual GAIN generator objective is
        # -mean((1 - M) * log D); this term has the opposite sign and an extra
        # division by the number of missing entries. Kept as-is to preserve
        # training behavior -- confirm whether this is intended.
        adversarial = (missing_weight * log_prob).mean() / normalizer

        reconstruction = self.criterion_g(input_tensor, generated_tensor)
        return adversarial + self.recon_weight * reconstruction

    def train_forward_backward(self, input_tensor, output_tensor, mask_tensor, delta_tensor):
        """Run one discriminator update followed by one generator update.

        ``output_tensor`` and ``delta_tensor`` are accepted for interface
        symmetry but are not used here.

        Returns ``(generated_tensor, loss_g, loss_d)`` where
        ``generated_tensor`` is the generator output from the generator step.
        """
        missing_tensor, hint_tensor = self._corrupt_inputs(input_tensor, mask_tensor)

        # Discriminator steps
        self.optimizer_d.zero_grad()
        # Detached: the discriminator update does not need gradients w.r.t.
        # the generator (they were zeroed before the generator step anyway).
        fake_tensor = self.model_g(input_tensor, missing_tensor, mask_tensor).detach()
        discriminate_prob = self.model_d(input_tensor, mask_tensor, fake_tensor, hint_tensor)
        loss_d = self.criterion_d(discriminate_prob, mask_tensor)
        loss_d.backward()
        self.optimizer_d.step()

        # Generator Steps
        self.optimizer_g.zero_grad()
        generated_tensor = self.model_g(input_tensor, missing_tensor, mask_tensor)
        discriminate_prob = self.model_d(input_tensor, mask_tensor, generated_tensor, hint_tensor)
        loss_g = self._generator_loss(input_tensor, generated_tensor,
                                      discriminate_prob, mask_tensor)
        loss_g.backward()
        self.optimizer_g.step()

        return generated_tensor, loss_g, loss_d

    def eval_forward(self, input_tensor, output_tensor, mask_tensor=None, delta_tensor=None):
        """Compute generator output and both losses without updating weights.

        When ``mask_tensor`` is ``None`` every entry is treated as kept
        (mask of ones). The computation runs under ``torch.no_grad()``, so the
        returned tensors are not attached to an autograd graph.

        Returns ``(generated_tensor, loss_g, loss_d)``.
        """
        if mask_tensor is None:
            mask_tensor = torch.ones_like(input_tensor)

        with torch.no_grad():
            missing_tensor, hint_tensor = self._corrupt_inputs(input_tensor, mask_tensor)

            generated_tensor = self.model_g(input_tensor, missing_tensor, mask_tensor)
            discriminate_prob = self.model_d(input_tensor, mask_tensor,
                                             generated_tensor, hint_tensor)

            loss_g = self._generator_loss(input_tensor, generated_tensor,
                                          discriminate_prob, mask_tensor)
            loss_d = self.criterion_d(discriminate_prob, mask_tensor)

        return generated_tensor, loss_g, loss_d

    def train_model(self):
        """Train both models for one pass over ``train_loader``.

        Metrics are accumulated into ``self.train_metrics`` using the
        generator loss; returns ``get_mean_metrics(self.train_metrics)``.
        """
        self.model_g.train()
        self.model_d.train()

        for batch in self.train_loader:
            input_tensor, output_tensor, mask_tensor, delta_tensor = self.batch_to_gpu(*batch)

            prediction_tensor, loss_g, loss_d = self.train_forward_backward(
                input_tensor, output_tensor, mask_tensor, delta_tensor)
            compute_metrics(self.train_metrics, loss_g, prediction_tensor,
                            output_tensor, self.opt)
            # clear batch variables from memory
            del input_tensor, output_tensor, mask_tensor, delta_tensor

        return get_mean_metrics(self.train_metrics)

    def eval_model(self):
        """Evaluate both models for one pass over ``val_loader``.

        Metrics are accumulated into ``self.val_metrics`` using the generator
        loss; returns ``get_mean_metrics(self.val_metrics)``.
        """
        self.model_g.eval()
        self.model_d.eval()

        for batch in self.val_loader:
            input_tensor, output_tensor, mask_tensor, delta_tensor = self.batch_to_gpu(*batch)

            prediction_tensor, loss_g, loss_d = self.eval_forward(
                input_tensor, output_tensor, mask_tensor, delta_tensor)
            compute_metrics(self.val_metrics, loss_g, prediction_tensor,
                            output_tensor, self.opt)
            # clear batch variables from memory
            del input_tensor, output_tensor, mask_tensor, delta_tensor

        return get_mean_metrics(self.val_metrics)
