import numpy as np 
import torch
from torch.autograd import Variable

from neuralfaults.utils.helpers import (initialize_metrics,
                                    get_mean_metrics,
                                    compute_metrics)

class E2ERunner:
    """Run pretraining, adversarial training and evaluation epochs for a
    generator (``model_g``) / discriminator (``model_d``) pair.

    Every batch yielded by ``train_loader`` / ``val_loader`` is unpacked as
    ``(input, output, mask, delta)``, moved to ``device`` as float32, and the
    mask is inverted (``1 - mask``) before it is passed to the models.

    Loss convention used here: the sigmoid of the discriminator output is
    treated as the probability that a sample is real. The generator minimises
    ``-mean(D(fake)) + l2_weight * mean((fake - input) ** 2)`` and the
    discriminator minimises ``mean(D(fake)) - mean(D(real))``.
    """

    # Default weight of the L2 reconstruction term in the generator loss;
    # overridable per instance through the ``l2_weight`` constructor argument.
    l2_weight = 50.0

    def __init__(self, device, model_g, model_d,
                 pretrain_optim, optimizer_g, optimizer_d,
                 criterion_g, criterion_d,
                 train_loader, val_loader, opt, l2_weight=50.0):
        self.opt = opt
        # NOTE(review): not read by any method of this class — presumably the
        # hint probability intended for ``sample_M``; confirm with callers.
        self.p_hint = 0.9
        self.device = device
        self.model_g = model_g
        self.model_d = model_d
        self.pretrain_optim = pretrain_optim
        self.optimizer_g = optimizer_g
        self.optimizer_d = optimizer_d
        # Kept as attributes for callers; this class computes its losses inline.
        self.criterion_g = criterion_g
        self.criterion_d = criterion_d
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.l2_weight = l2_weight
        self.train_metrics = initialize_metrics(self.opt)
        self.val_metrics = initialize_metrics(self.opt)

    def set_epoch_metrics(self):
        """Reset the train and validation metric accumulators."""
        self.train_metrics = initialize_metrics(self.opt)
        self.val_metrics = initialize_metrics(self.opt)

    def batch_to_gpu(self, input_tensor, output_tensor, mask_tensor, delta_tensor):
        """Move the four batch tensors to ``self.device`` as float32.

        Returns:
            Tuple ``(input, output, mask, delta)`` in the same order.
        """
        return tuple(
            tensor.to(self.device).float()
            for tensor in (input_tensor, output_tensor, mask_tensor, delta_tensor)
        )

    def _prepare_batch(self, batch):
        """Unpack a loader batch, move it to the device and invert the mask."""
        input_tensor, output_tensor, mask_tensor, delta_tensor = batch
        input_tensor, output_tensor, mask_tensor, delta_tensor = self.batch_to_gpu(
            input_tensor, output_tensor, mask_tensor, delta_tensor)
        return input_tensor, output_tensor, 1 - mask_tensor, delta_tensor

    def sample_M(self, m, c, n, p):
        """Return an ``(m, c, n)`` tensor of 0./1. values on ``self.device``.

        An entry is 1 where a U(0, 1) draw exceeds ``p`` (probability ``1 - p``).
        NOTE(review): not called anywhere in this class.
        """
        samples = torch.tensor(np.random.uniform(0., 1., size=[m, c, n])).to(self.device)
        return 1. * (samples > p)

    def sample_ita(self, m, c, n):
        """Return ``(m, c, n)`` float32 noise drawn from U(0, 0.01) on ``self.device``."""
        return torch.tensor(np.random.uniform(0., 0.01, size=[m, c, n])).to(self.device).float()

    def _real_prob(self, tensor, mask_tensor, delta_tensor):
        """Discriminator output squashed to (0, 1): probability ``tensor`` is real."""
        return torch.sigmoid(self.model_d(tensor, mask_tensor, delta_tensor))

    def _generator_loss(self, input_tensor, generated_tensor, fake_prob):
        """Adversarial term (fool D) plus weighted L2 reconstruction of the input."""
        l2_loss = ((generated_tensor - input_tensor) ** 2).mean()
        return -fake_prob.mean() + self.l2_weight * l2_loss

    @staticmethod
    def _discriminator_loss(real_prob, fake_prob):
        """Push D up on real data and down on generated data."""
        return fake_prob.mean() - real_prob.mean()

    def train_forward_backward(self, input_tensor, output_tensor, mask_tensor, delta_tensor):
        """Run one discriminator update followed by one generator update.

        ``output_tensor`` is accepted for interface compatibility but unused.

        Returns:
            ``(generated_tensor, loss_g, loss_d)`` where ``loss_d`` is the loss
            before the discriminator step and ``loss_g`` the loss before the
            generator step.
        """
        ita_tensor = self.sample_ita(*input_tensor.shape)
        generated_tensor = self.model_g(input_tensor, mask_tensor, ita_tensor, delta_tensor)

        # Discriminator step. The generator output is detached so this update
        # sends no gradient into the generator (previously the two backward
        # passes cancelled each other's generator gradients exactly).
        self.optimizer_d.zero_grad()
        real_prob = self._real_prob(input_tensor, mask_tensor, delta_tensor)
        fake_prob = self._real_prob(generated_tensor.detach(), mask_tensor, delta_tensor)
        loss_d = self._discriminator_loss(real_prob, fake_prob)
        loss_d.backward()
        self.optimizer_d.step()

        # Generator step: re-run D after its update so G trains against the
        # current discriminator. Gradients this leaves on D's parameters are
        # cleared by the next discriminator zero_grad().
        self.optimizer_g.zero_grad()
        fake_prob = self._real_prob(generated_tensor, mask_tensor, delta_tensor)
        loss_g = self._generator_loss(input_tensor, generated_tensor, fake_prob)
        loss_g.backward()
        self.optimizer_g.step()

        return generated_tensor, loss_g, loss_d

    def eval_forward(self, input_tensor, output_tensor, mask_tensor=None, delta_tensor=None):
        """Compute generator output and both losses without building a graph.

        ``output_tensor`` is accepted for interface compatibility but unused.

        Returns:
            ``(generated_tensor, loss_g, loss_d)``.
        """
        with torch.no_grad():
            real_prob = self._real_prob(input_tensor, mask_tensor, delta_tensor)
            ita_tensor = self.sample_ita(*input_tensor.shape)
            generated_tensor = self.model_g(input_tensor, mask_tensor, ita_tensor, delta_tensor)
            fake_prob = self._real_prob(generated_tensor, mask_tensor, delta_tensor)
            loss_g = self._generator_loss(input_tensor, generated_tensor, fake_prob)
            loss_d = self._discriminator_loss(real_prob, fake_prob)
        return generated_tensor, loss_g, loss_d

    def pretrain_model(self):
        """Pretrain the generator alone on L2 reconstruction for one epoch.

        Returns:
            Mean training metrics as produced by ``get_mean_metrics``.
        """
        self.model_g.train()

        for batch in self.train_loader:
            self.pretrain_optim.zero_grad()

            input_tensor, output_tensor, mask_tensor, delta_tensor = self._prepare_batch(batch)
            ita_tensor = self.sample_ita(*mask_tensor.shape)
            prediction_tensor = self.model_g(input_tensor, mask_tensor, ita_tensor, delta_tensor)
            pretrain_loss = ((prediction_tensor - input_tensor) ** 2).mean()
            pretrain_loss.backward()
            self.pretrain_optim.step()

            compute_metrics(self.train_metrics, pretrain_loss, prediction_tensor,
                            output_tensor, self.opt)

            # Release device tensors before the next batch is moved over.
            del input_tensor, output_tensor, mask_tensor, delta_tensor

        return get_mean_metrics(self.train_metrics)

    # Backward-compatible alias for the original (misspelled) method name.
    pretrian_model = pretrain_model

    def train_model(self):
        """Run one adversarial training epoch and return mean train metrics."""
        self.model_g.train()
        self.model_d.train()

        for batch in self.train_loader:
            input_tensor, output_tensor, mask_tensor, delta_tensor = self._prepare_batch(batch)
            prediction_tensor, loss_g, loss_d = self.train_forward_backward(
                input_tensor, output_tensor, mask_tensor, delta_tensor)
            compute_metrics(self.train_metrics, loss_g, prediction_tensor,
                            output_tensor, self.opt)
            # Release device tensors before the next batch is moved over.
            del input_tensor, output_tensor, mask_tensor, delta_tensor

        return get_mean_metrics(self.train_metrics)

    def eval_model(self):
        """Run one evaluation epoch over ``val_loader`` and return mean metrics."""
        self.model_g.eval()
        self.model_d.eval()

        for batch in self.val_loader:
            input_tensor, output_tensor, mask_tensor, delta_tensor = self._prepare_batch(batch)
            prediction_tensor, loss_g, loss_d = self.eval_forward(
                input_tensor, output_tensor, mask_tensor, delta_tensor)
            compute_metrics(self.val_metrics, loss_g, prediction_tensor,
                            output_tensor, self.opt)
            # Release device tensors before the next batch is moved over.
            del input_tensor, output_tensor, mask_tensor, delta_tensor

        return get_mean_metrics(self.val_metrics)
