import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from pytorch_lightning.core.lightning import LightningModule

import util


class Flatten(nn.Module):
    """Collapse every dimension after the first (batch) dimension into one.

    An input of shape ``(N, d1, d2, ...)`` becomes ``(N, d1 * d2 * ...)``.
    """

    def forward(self, input):
        # ``reshape`` returns a view whenever the memory layout allows it (the
        # same result ``view`` produced) and falls back to a copy for
        # non-contiguous inputs such as permuted tensors, where ``view`` raises.
        return input.reshape(input.size(0), -1)

class MLP(nn.Module):
    """Fully connected network: three ReLU hidden layers plus an output layer.

    The output layer has ``args.num_bin`` units. Its activations are passed
    through a softmax over dim 1 when ``use_softmax`` is true, otherwise
    through an element-wise sigmoid.

    Args:
        input_len: number of input features.
        args: configuration object; ``args.num_bin`` is the output width and
            ``args.num_neuron`` the hidden width.
        use_softmax: use softmax (dim 1) instead of sigmoid on the output.
        verbose: print ``'MLP'`` when the model is built.
    """

    def __init__(self, input_len, args, use_softmax=False, verbose=True):
        super(MLP, self).__init__()

        if verbose:
            print('MLP')
        out_width = args.num_bin
        self.use_softmax = use_softmax
        hidden_width = args.num_neuron

        # All layers are constructed first and only then re-initialised, in
        # fc1..fc4 order, so the random-number stream is consumed in a fixed order.
        self.fc1 = nn.Linear(input_len, hidden_width)
        self.fc2 = nn.Linear(hidden_width, hidden_width)
        self.fc3 = nn.Linear(hidden_width, hidden_width)
        self.fc4 = nn.Linear(hidden_width, out_width)
        # Only the weights get Kaiming-normal init; biases keep nn.Linear's default.
        for linear in (self.fc1, self.fc2, self.fc3, self.fc4):
            nn.init.kaiming_normal_(linear.weight)

    def forward(self, x):
        """Apply the three ReLU hidden layers, then the output layer and squashing."""
        h = x
        for linear in (self.fc1, self.fc2, self.fc3):
            h = F.relu(linear(h))
        logits = self.fc4(h)
        return torch.softmax(logits, dim=1) if self.use_softmax else torch.sigmoid(logits)

class LightningModuleWithCustomLossComputation(LightningModule):
    """LightningModule base class that delegates loss computation to ``self.loss_function``.

    This class does not define ``forward`` or ``loss_function``; subclasses are
    expected to provide both. The loss is called as
    ``self.loss_function(out, y, y[:,1], stage)`` with ``stage`` one of
    ``"train"``, ``'val'`` or ``'test'``. At the end of every epoch the mean of
    the per-batch losses is logged as ``train_loss`` / ``val_loss`` /
    ``test_loss``.
    """

    def __init__(self):
        super().__init__()

    def _log_mean_loss(self, step_outputs, metric_name):
        """Log the unweighted mean of ``entry['loss']`` over ``step_outputs``.

        Args:
            step_outputs: list of dicts returned by the ``*_step`` methods;
                each must contain a ``'loss'`` entry.
            metric_name: key under which the epoch mean is logged.

        Nothing is logged when ``step_outputs`` is empty, instead of dividing by zero.
        """
        if not step_outputs:
            return
        # NOTE(review): every batch is weighted equally even though the
        # val/test steps also return 'batch_size'. If loss_function returns a
        # per-batch mean, a smaller final batch is over-weighted. Confirm intent.
        total = sum((entry['loss'] for entry in step_outputs), 0.0)
        avg_loss = total / len(step_outputs)
        self.log(metric_name, avg_loss, on_epoch=True, on_step=False, logger=True)

    def training_step(self, batch, batch_idx):
        """Sum the loss over every ``(x, y)`` pair contained in ``batch``.

        ``batch`` is iterated as a sequence of ``(x, y)`` pairs, presumably one
        per training dataloader. TODO: confirm against the Trainer setup.
        """
        loss = 0.0
        for x, y in batch:
            out = self.forward(x)
            loss += self.loss_function(out, y, y[:,1], "train")
        return { 'loss': loss }

    def training_epoch_end(self, training_step_outputs):
        """Log the mean training loss of the epoch as ``train_loss``."""
        self._log_mean_loss(training_step_outputs, 'train_loss')

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss for one ``(x, y)`` batch."""
        x, y = batch
        out = self.forward(x)
        loss = self.loss_function(out, y, y[:,1], 'val')
        return { 'loss': loss, 'batch_size': x.shape[0] }

    def validation_epoch_end(self, validation_step_outputs):
        """Log the mean validation loss of the epoch as ``val_loss``."""
        self._log_mean_loss(validation_step_outputs, 'val_loss')

    def test_step(self, batch, batch_idx):
        """Compute the test loss for one ``(x, y)`` batch."""
        x, y = batch
        out = self.forward(x)
        loss = self.loss_function(out, y, y[:,1], 'test')
        return { 'loss': loss, 'batch_size': x.shape[0] }

    def test_epoch_end(self, test_step_outputs):
        """Log the mean test loss as ``test_loss``."""
        self._log_mean_loss(test_step_outputs, 'test_loss')

#class Softmax(nn.Module):
class Softmax(LightningModuleWithCustomLossComputation):
    """Lightning wrapper around a single network trained with a custom loss.

    Args:
        neural_network: module whose output ``forward`` and ``predict`` return.
        loss_function: callable stored as ``self.loss_function``; the parent
            class invokes it from its training/validation/test steps.
        args: configuration object; ``args.learning_rate`` is the Adam
            learning rate.
    """

    def __init__(self, neural_network, loss_function, args):
        super().__init__()
        self.neural_network = neural_network
        self.loss_function = loss_function
        self.lr = args.learning_rate

    def forward(self, x, lengths=None):
        """Run the wrapped network on ``x``; ``lengths`` is accepted but unused."""
        net = self.neural_network
        return net(x)

    def predict(self, x, lengths=None):
        """Run the same computation as ``forward``; ``lengths`` is accepted but unused."""
        net = self.neural_network
        return net(x)

    def configure_optimizers(self):
        """Return an Adam optimizer over all module parameters at ``self.lr``."""
        learning_rate = self.lr
        return optim.Adam(self.parameters(), lr=learning_rate)

class SurvivalGame(LightningModule):
    """LightningModule holding two networks, ``nn_f`` and ``nn_g``, trained jointly.

    Both networks see the same input. The loss is called as
    ``loss_function(out_f, out_g, y[:,0], y[:,1], stage)`` with ``stage`` one
    of ``'train'``, ``'val'`` or ``'test'``. Columns 0 and 1 of ``y`` are
    passed as separate targets; their meaning is not shown here (TODO confirm,
    presumably time and event indicator). ``forward`` returns only ``nn_f``'s
    output. At the end of every epoch the mean of the per-batch losses is
    logged as ``train_loss`` / ``val_loss`` / ``test_loss``.

    Args:
        nn_f: first network; its output is what ``forward`` returns.
        nn_g: second network, used only inside the loss.
        loss_function: callable computing the joint loss (signature above).
        args: configuration object; ``args.learning_rate`` is the Adam
            learning rate.
    """

    def __init__(self, nn_f, nn_g, loss_function, args):
        super().__init__()
        self.nn_f = nn_f
        self.nn_g = nn_g
        self.loss_function = loss_function
        self.lr = args.learning_rate

    def forward(self, x, lengths=None):
        """Return ``nn_f(x)``; ``lengths`` is accepted but unused."""
        return self.nn_f(x)

    def configure_optimizers(self):
        """Return a single Adam optimizer over all registered parameters (both networks)."""
        return optim.Adam(self.parameters(), lr=self.lr)

    def _log_mean_loss(self, step_outputs, metric_name):
        """Log the unweighted mean of ``entry['loss']`` over ``step_outputs``.

        Args:
            step_outputs: list of dicts returned by the ``*_step`` methods;
                each must contain a ``'loss'`` entry.
            metric_name: key under which the epoch mean is logged.

        Nothing is logged when ``step_outputs`` is empty, instead of dividing by zero.
        """
        if not step_outputs:
            return
        # NOTE(review): batches are weighted equally although val/test steps
        # also return 'batch_size'. Confirm this is the intended average.
        total = sum((entry['loss'] for entry in step_outputs), 0.0)
        avg_loss = total / len(step_outputs)
        self.log(metric_name, avg_loss, on_epoch=True, on_step=False, logger=True)

    def training_step(self, batch, batch_idx):
        """Sum the joint loss over every ``(x, y)`` pair contained in ``batch``.

        ``batch`` is iterated as a sequence of ``(x, y)`` pairs, presumably one
        per training dataloader. TODO: confirm against the Trainer setup.
        """
        loss = 0.0
        for x, y in batch:
            out_f = self.nn_f(x)
            out_g = self.nn_g(x)
            loss += self.loss_function(out_f, out_g, y[:,0], y[:,1], 'train')
        return { 'loss': loss }

    def training_epoch_end(self, training_step_outputs):
        """Log the mean training loss of the epoch as ``train_loss``."""
        self._log_mean_loss(training_step_outputs, 'train_loss')

    def validation_step(self, batch, batch_idx):
        """Compute the joint validation loss for one ``(x, y)`` batch."""
        x, y = batch
        out_f = self.nn_f(x)
        out_g = self.nn_g(x)
        loss = self.loss_function(out_f, out_g, y[:,0], y[:,1], 'val')
        return { 'loss': loss, 'batch_size': x.shape[0] }

    def validation_epoch_end(self, validation_step_outputs):
        """Log the mean validation loss of the epoch as ``val_loss``."""
        self._log_mean_loss(validation_step_outputs, 'val_loss')

    def test_step(self, batch, batch_idx):
        """Compute the joint test loss for one ``(x, y)`` batch."""
        x, y = batch
        out_f = self.nn_f(x)
        out_g = self.nn_g(x)
        loss = self.loss_function(out_f, out_g, y[:,0], y[:,1], 'test')
        return { 'loss': loss, 'batch_size': x.shape[0] }

    def test_epoch_end(self, test_step_outputs):
        """Log the mean test loss as ``test_loss``."""
        self._log_mean_loss(test_step_outputs, 'test_loss')
