import math

import torch
import torch.nn as nn

import model.models as models


class FCAE(nn.Module):
    """Fully connected autoencoder whose decoder is conditioned on ``x[:, 0]``.

    The encoder flattens each sample and maps it to a ``latent_dim``-sized
    code. The decoder receives that code concatenated with the flattened
    first slice along dim 1 of the input (``x[:, 0]``) and produces a
    flattened reconstruction. The reconstruction is reshaped back to the
    input's shape.

    Args:
        input_dim: Per-sample input shape, excluding the batch dimension.
            The product of its entries is the flattened feature size.
        ic_dim: Number of elements in the flattened ``x[:, 0]`` slice that
            is appended to the latent code before decoding.
        latent_dim: Size of the latent code. Defaults to 32, the value that
            was previously hard-coded.
    """

    def __init__(self, input_dim, ic_dim, latent_dim=32):
        super().__init__()
        self.latent_dim = latent_dim
        # Flattened per-sample size; valid for any number of non-batch dims.
        self.input_dim = math.prod(input_dim)
        # Layer order/sizes are unchanged so state_dict keys stay compatible.
        self.encoder_fc = nn.Sequential(
            nn.Linear(self.input_dim, 256),
            nn.ReLU(True),
            nn.Linear(256, 256),
            nn.ReLU(True),
            nn.Linear(256, 256),
            nn.ReLU(True),
            nn.Linear(256, 128),
            nn.ReLU(True),
            nn.Linear(128, 128),
            nn.ReLU(True),
            nn.Linear(128, self.latent_dim))

        self.decoder_fc = nn.Sequential(
            nn.Linear(self.latent_dim + ic_dim, 128),
            nn.ReLU(True),
            nn.Linear(128, 128),
            nn.ReLU(True),
            nn.Linear(128, 256),
            nn.ReLU(True),
            nn.Linear(256, 256),
            nn.ReLU(True),
            nn.Linear(256, 256),
            nn.ReLU(True),
            nn.Linear(256, self.input_dim))

    def forward(self, x):
        """Encode ``x`` and decode it conditioned on ``x[:, 0]``.

        Args:
            x: Batched tensor; everything after the batch dimension must
                hold ``self.input_dim`` elements in total.

        Returns:
            Tuple ``(out, z)``. ``out`` is the reconstruction, reshaped to
            ``x.shape``. ``z`` is the latent code of shape
            ``(batch, latent_dim)``.
        """
        batch = x.shape[0]
        # reshape (not view) so non-contiguous inputs are accepted.
        input_x = x.reshape(batch, -1)
        z = self.encoder_fc(input_x)
        # Flatten the conditioning slice: for rank-4 inputs x[:, 0] is 3-D
        # and could not otherwise be concatenated with the 2-D latent code.
        ic = x[:, 0].reshape(batch, -1)
        input_z = torch.cat((z, ic), dim=1)
        out = self.decoder_fc(input_z)
        return out.reshape(x.shape), z


class PDExplain(nn.Module):
    """Autoencoder plus a head that maps its latent code to PDE parameters.

    Members, created in this order:
      * ``ae_model``: an ``FCAE`` over the solution context, conditioned on
        ``x_dim`` features.
      * ``context_to_params_model``: an instance of the class returned by
        ``get_context_to_params_model(pde_type)``, built with the
        autoencoder's latent size.
      * ``loss_func``: the callable returned by ``get_pde_loss(pde_type)``.
        It is stored here but not invoked by this class.
    """

    def __init__(self, input_dim, x_dim, pde_type):
        super().__init__()
        self.ae_model = FCAE(input_dim, x_dim)
        params_cls = get_context_to_params_model(pde_type)
        self.context_to_params_model = params_cls(self.ae_model.latent_dim)
        self.loss_func = get_pde_loss(pde_type)

    def forward(self, t, f, sol_context):
        """Reconstruct ``sol_context`` and predict parameters from its latent code.

        Returns:
            ``(recon_sol, params)``: the autoencoder reconstruction and the
            result of ``context_to_params_model(t, f, latent)``.
        """
        reconstruction, latent = self.ae_model(sol_context)
        return reconstruction, self.context_to_params_model(t, f, latent)

    def forward_multiple_t(self, t, f, sol_context):
        """Like ``forward``, but broadcasts the latent code to ``t.shape[0]`` rows.

        ``expand`` only broadcasts size-1 dimensions. Presumably
        ``sol_context`` therefore holds a single sample (batch 1); TODO
        confirm with callers.
        """
        reconstruction, latent = self.ae_model(sol_context)
        tiled = latent.expand(t.shape[0], latent.shape[1])
        return reconstruction, self.context_to_params_model(t, f, tiled)


def get_context_to_params_model(pde_type):
    """Return the context-to-parameters model class for ``pde_type``.

    Args:
        pde_type: Either ``'const'`` or ``'burgers'``.

    Returns:
        The class object itself, not an instance.

    Raises:
        ValueError: If ``pde_type`` is not one of the supported names.
    """
    if pde_type == 'const':
        return models.const_pde_model.ContextToParamsConst
    if pde_type == 'burgers':
        # NOTE(review): this resolves a class named ``ContextToParamsConst``
        # inside burgers_model; confirm that is the intended class name.
        return models.burgers_model.ContextToParamsConst
    raise ValueError('pde_type should be [const/burgers]')


def get_pde_loss(pde_type):
    """Return the PDE loss function associated with ``pde_type``.

    Args:
        pde_type: Either ``'const'`` or ``'burgers'``.

    Returns:
        The loss callable from the matching model module (not invoked here).

    Raises:
        ValueError: If ``pde_type`` is not one of the supported names.
    """
    if pde_type == 'const':
        loss_fn = models.const_pde_model.calc_const_pde_loss
    elif pde_type == 'burgers':
        loss_fn = models.burgers_model.calc_burgers_pde_loss
    else:
        raise ValueError('pde_type should be [const/burgers]')
    return loss_fn
