import numpy as np 
import torch
from torch.autograd import Variable

from neuralfaults.utils.helpers import (initialize_metrics,
                                    get_mean_metrics,
                                    compute_metrics)

class SGANRunner():
    """Train/evaluate loop for a generator / critic (discriminator) pair.

    ``model_g`` is called as ``model_g(input, noise, mask, error)`` and must
    return a ``(fake, error_gen)`` pair. ``model_d`` maps a tensor to a score
    tensor whose mean enters the losses.

    Args:
        device: torch device that batches and sampled noise are moved to.
        model_g: generator module.
        model_d: critic / discriminator module.
        optimizer_g: optimizer over the generator parameters.
        optimizer_d: optimizer over the critic parameters.
        criterion_g: stored on the instance; not used by this class (losses
            are computed inline).
        criterion_d: stored on the instance; not used by this class.
        train_loader: iterable of ``(input, output, mask, delta)`` batches.
        val_loader: iterable of ``(input, output, mask, delta)`` batches.
        opt: options object forwarded to the metric helpers.
    """

    def __init__(self, device, model_g, model_d,
                optimizer_g, optimizer_d, criterion_g, criterion_d,
                train_loader, val_loader, opt):
        self.opt = opt
        self.p_hint = 0.9  # not read by any method of this class
        self.device = device
        self.model_g = model_g
        self.model_d = model_d
        self.optimizer_g = optimizer_g
        self.optimizer_d = optimizer_d
        self.criterion_g = criterion_g
        self.criterion_d = criterion_d
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.train_metrics = initialize_metrics(self.opt)
        self.val_metrics = initialize_metrics(self.opt)

    def set_epoch_metrics(self):
        """Reset the train and validation metric accumulators."""
        self.train_metrics = initialize_metrics(self.opt)
        self.val_metrics = initialize_metrics(self.opt)

    def batch_to_gpu(self, input_tensor, output_tensor, mask_tensor, delta_tensor):
        """Move the four batch tensors to ``self.device`` as float32.

        Returns:
            Tuple ``(input, output, mask, delta)`` on ``self.device``.
        """
        # torch.autograd.Variable is a no-op wrapper since PyTorch 0.4.
        return tuple(t.to(self.device).float()
                     for t in (input_tensor, output_tensor, mask_tensor, delta_tensor))

    def _sample_uniform(self, low, high, m, c, n):
        """Draw an (m, c, n) float32 tensor from numpy's U[low, high) on self.device."""
        sample = np.random.uniform(low, high, size=[m, c, n])
        # Convert to float32 on the CPU before the transfer (half the bytes moved).
        return torch.from_numpy(sample).float().to(self.device)

    def sample_E(self, m, c, n):
        """Sample the error tensor: (m, c, n), uniform in [-0.01, 0.01)."""
        return self._sample_uniform(-0.01, 0.01, m, c, n)

    def sample_Z(self, m, c, n):
        """Sample the noise tensor: (m, c, n), uniform in [0, 0.01)."""
        return self._sample_uniform(0., 0.01, m, c, n)

    @staticmethod
    def _critic_loss(real_prob, fake_prob):
        """Critic loss ``E[D(fake)] - E[D(real)]`` (the WGAN critic form)."""
        return -1 * real_prob.mean() + fake_prob.mean()

    @staticmethod
    def _generator_loss(fake_tensor, input_tensor, error_gen, error_tensor, fake_prob):
        """Generator loss: 2 * (reconstruction MSE + error MSE) - E[D(fake)]."""
        reconstruction = ((fake_tensor - input_tensor) ** 2).mean()
        error_reconstruction = ((error_gen - error_tensor) ** 2).mean()
        return 2 * (reconstruction + error_reconstruction) - fake_prob.mean()

    def train_forward_backward(self, input_tensor, output_tensor, mask_tensor, delta_tensor, train_g):
        """Run one critic update followed by one generator update.

        Args:
            input_tensor: real batch, also the reconstruction target.
            output_tensor: unused here.
            mask_tensor: forwarded to ``model_g``.
            delta_tensor: unused here.
            train_g: currently ignored; the generator is updated on every
                call. Kept for interface compatibility.

        Returns:
            ``(fake_tensor, loss_g, loss_d)`` from the generator step and the
            critic step respectively.
        """
        org_shape = input_tensor.shape

        # --- Critic step ---
        self.optimizer_d.zero_grad()
        noise_tensor = self.sample_Z(*org_shape)
        error_tensor = self.sample_E(*org_shape)
        # The generator output only feeds the critic here, so no graph is needed
        # through G (this also keeps loss_d from backpropagating into G).
        with torch.no_grad():
            fake_tensor, _ = self.model_g(input_tensor, noise_tensor, mask_tensor, error_tensor)
        fake_prob = self.model_d(fake_tensor)
        real_prob = self.model_d(input_tensor)
        loss_d = self._critic_loss(real_prob, fake_prob)
        loss_d.backward()
        self.optimizer_d.step()
        # NOTE(review): no Lipschitz constraint (weight clipping or gradient
        # penalty) is applied to the critic.

        # --- Generator step ---
        self.optimizer_g.zero_grad()
        noise_tensor = self.sample_Z(*org_shape)
        error_tensor = self.sample_E(*org_shape)
        fake_tensor, error_gen = self.model_g(input_tensor, noise_tensor, mask_tensor, error_tensor)
        fake_prob = self.model_d(fake_tensor)
        loss_g = self._generator_loss(fake_tensor, input_tensor, error_gen, error_tensor, fake_prob)
        # This backward also leaves gradients on the critic parameters; they are
        # cleared by optimizer_d.zero_grad() before the next critic step.
        loss_g.backward()
        self.optimizer_g.step()

        return fake_tensor, loss_g, loss_d

    def eval_forward(self, input_tensor, output_tensor, mask_tensor=None, delta_tensor=None):
        """Compute generator output and both losses without updating or tracking gradients.

        ``output_tensor`` and ``delta_tensor`` are unused.

        Returns:
            ``(fake_tensor, loss_g, loss_d)``, none of which require grad.
        """
        org_shape = input_tensor.shape

        with torch.no_grad():
            noise_tensor = self.sample_Z(*org_shape)
            error_tensor = self.sample_E(*org_shape)
            fake_tensor, _ = self.model_g(input_tensor, noise_tensor, mask_tensor, error_tensor)
            fake_prob = self.model_d(fake_tensor)
            real_prob = self.model_d(input_tensor)
            loss_d = self._critic_loss(real_prob, fake_prob)

            noise_tensor = self.sample_Z(*org_shape)
            error_tensor = self.sample_E(*org_shape)
            fake_tensor, error_gen = self.model_g(input_tensor, noise_tensor, mask_tensor, error_tensor)
            fake_prob = self.model_d(fake_tensor)
            loss_g = self._generator_loss(fake_tensor, input_tensor, error_gen, error_tensor, fake_prob)

        return fake_tensor, loss_g, loss_d

    def train_model(self):
        """Run one training epoch and return the mean training metrics."""
        self.model_g.train()
        self.model_d.train()

        for step, batch in enumerate(self.train_loader):
            input_tensor, output_tensor, mask_tensor, delta_tensor = batch
            input_tensor, output_tensor, mask_tensor, delta_tensor = self.batch_to_gpu(
                input_tensor, output_tensor, mask_tensor, delta_tensor)
            # Recomputed per batch: True on every 5th batch, False otherwise.
            # train_forward_backward currently ignores it.
            train_g = step % 5 == 0

            prediction_tensor, loss_g, loss_d = self.train_forward_backward(
                input_tensor, output_tensor, mask_tensor, delta_tensor, train_g)
            compute_metrics(self.train_metrics, loss_g, prediction_tensor,
                            input_tensor, self.opt)
            # clear batch variables from memory
            del input_tensor, output_tensor, mask_tensor, delta_tensor

        return get_mean_metrics(self.train_metrics)

    def eval_model(self):
        """Run one validation pass and return the mean validation metrics."""
        self.model_g.eval()
        self.model_d.eval()

        for batch in self.val_loader:
            input_tensor, output_tensor, mask_tensor, delta_tensor = batch
            input_tensor, output_tensor, mask_tensor, delta_tensor = self.batch_to_gpu(
                input_tensor, output_tensor, mask_tensor, delta_tensor)

            prediction_tensor, loss_g, loss_d = self.eval_forward(
                input_tensor, output_tensor, mask_tensor, delta_tensor)
            compute_metrics(self.val_metrics, loss_g, prediction_tensor,
                            input_tensor, self.opt)
            # clear batch variables from memory
            del input_tensor, output_tensor, mask_tensor, delta_tensor

        return get_mean_metrics(self.val_metrics)
