import numpy as np 
import torch
from torch.autograd import Variable

from neuralfaults.utils.helpers import (initialize_metrics,
                                    get_mean_metrics,
                                    compute_metrics)

class GANRunner:
    """Run pretraining, adversarial training and evaluation for an imputation GAN.

    The runner holds three models, each with its own optimizer:

    * a latent mapper ``model_z``;
    * a generator ``model_g``;
    * a discriminator ``model_d``.

    It also holds a separate pretraining optimizer, the train/validation
    loaders, and the per-epoch metric accumulators created by
    ``initialize_metrics``.
    """

    # Width of the uniform latent vector fed to the generator.
    Z_DIM = 64
    # Weight of the adversarial term subtracted from the imputation loss.
    IMPUTE_ADV_WEIGHT = 0.15

    def __init__(self, device, model_z, model_g, model_d,
                 optimizer_impute, pretrain_optim, optimizer_g, optimizer_d,
                 criterion_g, criterion_d,
                 train_loader, val_loader, opt):
        self.opt = opt
        # NOTE(review): p_hint is stored but never read inside this class.
        self.p_hint = 0.9
        self.device = device
        self.model_z = model_z
        self.model_g = model_g
        self.model_d = model_d
        self.pretrain_optim = pretrain_optim
        self.optimizer_impute = optimizer_impute
        self.optimizer_g = optimizer_g
        self.optimizer_d = optimizer_d
        # NOTE(review): the criteria are stored but the losses below are
        # computed inline and do not use them.
        self.criterion_g = criterion_g
        self.criterion_d = criterion_d
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.set_epoch_metrics()

    def set_epoch_metrics(self):
        """Reset the train and validation metric accumulators."""
        self.train_metrics = initialize_metrics(self.opt)
        self.val_metrics = initialize_metrics(self.opt)

    def batch_to_gpu(self, input_tensor, output_tensor, mask_tensor, delta_tensor):
        """Move the four batch tensors to ``self.device`` and cast them to float32.

        ``torch.autograd.Variable`` wrapping has been unnecessary since
        PyTorch 0.4 (tensors and Variables were merged), so each tensor is
        moved directly.

        Returns:
            Tuple ``(input, output, mask, delta)`` on ``self.device``.
        """
        return tuple(tensor.to(self.device).float()
                     for tensor in (input_tensor, output_tensor, mask_tensor, delta_tensor))

    def sample_M(self, m, c, n, p):
        """Return an ``(m, c, n)`` tensor of 0./1. values on ``self.device``.

        An entry is 1. where an independent uniform [0, 1) draw exceeds ``p``.
        The draws come from NumPy's global RNG.
        """
        uniform = torch.tensor(np.random.uniform(0., 1., size=[m, c, n])).to(self.device)
        return 1. * (uniform > p)

    def sample_Z(self, m, z_dim):
        """Return an ``(m, z_dim)`` float32 tensor of uniform [0, 1) noise on ``self.device``.

        The noise is drawn from NumPy's global RNG.
        """
        return torch.tensor(np.random.uniform(0., 1., size=[m, z_dim])).to(self.device).float()

    def _forward(self, input_tensor, mask_tensor, delta_tensor):
        """Run the forward pass shared by training and evaluation.

        Returns:
            ``(imputed_tensor, loss_g, loss_d, impute_loss)``. ``imputed_tensor``
            is the generator output obtained from ``model_z``-mapped noise.
        """
        batch_size = input_tensor.shape[0]

        # Discriminator score for the observed batch.
        real_prob = torch.sigmoid(self.model_d(input_tensor, mask_tensor, delta_tensor))

        # Generator driven by raw uniform noise, then scored by the discriminator.
        z_tensor = self.sample_Z(batch_size, self.Z_DIM)
        generated_tensor = self.model_g(input_tensor, mask_tensor, delta_tensor, z_tensor, gen=True)
        fake_prob = torch.sigmoid(self.model_d(generated_tensor, mask_tensor, delta_tensor))

        # Generator driven by the same noise after passing it through model_z.
        mapped_z = self.model_z(z_tensor)
        imputed_tensor = self.model_g(input_tensor, mask_tensor, delta_tensor, mapped_z, gen=True)
        imputed_fake_prob = torch.sigmoid(self.model_d(imputed_tensor, mask_tensor, delta_tensor))

        # Reconstruction error against the input, minus a bonus for fooling D.
        impute_loss = (((imputed_tensor - input_tensor) ** 2).mean()
                       - self.IMPUTE_ADV_WEIGHT * imputed_fake_prob.mean())
        # NOTE(review): minimizing loss_d lowers D's probability on BOTH the
        # real and the generated batch. A conventional discriminator loss
        # rewards a high real probability, so confirm this is intended.
        loss_d = fake_prob.mean() + real_prob.mean()
        loss_g = -1 * fake_prob.mean()
        return imputed_tensor, loss_g, loss_d, impute_loss

    def train_forward_backward(self, input_tensor, output_tensor, mask_tensor, delta_tensor):
        """Run one optimization step for the impute, discriminator and generator optimizers.

        ``output_tensor`` is accepted for interface symmetry but not used.

        Returns:
            ``(imputed_tensor, loss_g, loss_d, impute_loss)``.
        """
        # Reset every optimizer stepped below. optimizer_impute was previously
        # never zeroed, so its gradients accumulated across batches.
        self.optimizer_impute.zero_grad()
        self.optimizer_d.zero_grad()
        self.optimizer_g.zero_grad()

        imputed_tensor, loss_g, loss_d, impute_loss = self._forward(
            input_tensor, mask_tensor, delta_tensor)

        # All three losses backpropagate through model_g and model_d, so each
        # optimizer steps on the summed gradients that reach its parameters.
        # Earlier backward calls share graph pieces with later ones, which is
        # why they retain the graph.
        impute_loss.backward(retain_graph=True)
        loss_d.backward(retain_graph=True)
        loss_g.backward()

        self.optimizer_impute.step()
        self.optimizer_d.step()
        self.optimizer_g.step()

        return imputed_tensor, loss_g, loss_d, impute_loss

    def eval_forward(self, input_tensor, output_tensor, mask_tensor=None, delta_tensor=None):
        """Compute the same outputs and losses as training, without any update.

        Returns:
            ``(imputed_tensor, loss_g, loss_d, impute_loss)``.
        """
        return self._forward(input_tensor, mask_tensor, delta_tensor)

    def _prepare_batch(self, batch):
        """Unpack a loader batch, move it to the device and invert the mask.

        Every loop in this runner feeds the models ``1 - mask`` rather than
        the loader's raw mask.
        """
        input_tensor, output_tensor, mask_tensor, delta_tensor = batch
        input_tensor, output_tensor, mask_tensor, delta_tensor = self.batch_to_gpu(
            input_tensor, output_tensor, mask_tensor, delta_tensor)
        return input_tensor, output_tensor, 1 - mask_tensor, delta_tensor

    def pretrain_model(self):
        """Pretrain ``model_g`` for one pass over ``train_loader`` with an MSE loss.

        Metrics accumulate into ``self.train_metrics``; reset them with
        ``set_epoch_metrics`` between epochs.

        Returns:
            ``get_mean_metrics(self.train_metrics)``.
        """
        self.model_g.train()

        for batch in self.train_loader:
            self.pretrain_optim.zero_grad()

            input_tensor, output_tensor, mask_tensor, delta_tensor = self._prepare_batch(batch)
            prediction_tensor = self.model_g(input_tensor, mask_tensor, delta_tensor)
            pretrain_loss = ((prediction_tensor - input_tensor) ** 2).mean()
            pretrain_loss.backward()
            self.pretrain_optim.step()

            compute_metrics(self.train_metrics, pretrain_loss, prediction_tensor,
                            output_tensor, self.opt)

            # Drop batch references before the next batch is loaded.
            del input_tensor, output_tensor, mask_tensor, delta_tensor

        return get_mean_metrics(self.train_metrics)

    # Backward-compatible alias for the original (misspelled) method name.
    pretrian_model = pretrain_model

    def train_model(self):
        """Run one adversarial training pass over ``train_loader``.

        Returns:
            ``get_mean_metrics(self.train_metrics)``.
        """
        # model_z participates in the forward pass, so switch its mode as well.
        self.model_z.train()
        self.model_g.train()
        self.model_d.train()

        for batch in self.train_loader:
            input_tensor, output_tensor, mask_tensor, delta_tensor = self._prepare_batch(batch)
            prediction_tensor, loss_g, loss_d, impute_loss = self.train_forward_backward(
                input_tensor, output_tensor, mask_tensor, delta_tensor)
            compute_metrics(self.train_metrics, loss_g, prediction_tensor,
                            output_tensor, self.opt)
            # Clear batch variables from memory.
            del input_tensor, output_tensor, mask_tensor, delta_tensor

        return get_mean_metrics(self.train_metrics)

    def eval_model(self):
        """Evaluate all models over ``val_loader`` without tracking gradients.

        Returns:
            ``get_mean_metrics(self.val_metrics)``.
        """
        self.model_z.eval()
        self.model_g.eval()
        self.model_d.eval()

        # No parameter updates happen here, so skip building autograd graphs.
        with torch.no_grad():
            for batch in self.val_loader:
                input_tensor, output_tensor, mask_tensor, delta_tensor = self._prepare_batch(batch)
                prediction_tensor, loss_g, loss_d, impute_loss = self.eval_forward(
                    input_tensor, output_tensor, mask_tensor, delta_tensor)
                compute_metrics(self.val_metrics, loss_g, prediction_tensor,
                                output_tensor, self.opt)
                # Clear batch variables from memory.
                del input_tensor, output_tensor, mask_tensor, delta_tensor

        return get_mean_metrics(self.val_metrics)
