'''
=====
- Associated publication:
url: 
doi: 
github: 
=====
'''
import torch
import torch.nn as nn
import logging
import numpy as np
from .embedding_model import EmbeddingModel
from torch.autograd import Variable

class GrayScottEmbedding(EmbeddingModel):
    """Embedding Koopman model for the 3D Gray-Scott system.

    A 3D convolutional encoder maps the two-species concentration field to a
    vector of Koopman observables, a mirrored decoder maps observables back to
    the field, and a learned banded Koopman operator advances the observables
    by one time-step.

    Args:
        config (:class:`config.configuration_phys.PhysConfig`): Configuration class with transformer/embedding parameters

    Note:
        For more information on the Gray-Scott model see "Complex Patterns in a Simple System" by John E. Pearson;
        https://doi.org/10.1126/science.261.5118.189
    """
    model_name = "embedding_grayscott"

    def __init__(self, config):
        """Constructor method
        """
        super().__init__(config)

        # Encoder: five stride-2 circular convolutions, each halving the grid.
        # Shape comments are (channels, H, W, D) for a 32^3 input grid, which
        # is the grid size the decoder below reconstructs (1 -> 32 after five
        # x2 up-samplings).
        self.observableNet = nn.Sequential(
            nn.Conv3d(2, 64, kernel_size=(5, 5, 5), stride=2, padding=2, padding_mode='circular'),
            nn.BatchNorm3d(64),
            nn.LeakyReLU(0.02, inplace=True),
            # 64, 16, 16, 16
            nn.Conv3d(64, 128, kernel_size=(3, 3, 3), stride=2, padding=1, padding_mode='circular'),
            nn.BatchNorm3d(128),
            nn.LeakyReLU(0.02, inplace=True),
            # 128, 8, 8, 8
            nn.Conv3d(128, 256, kernel_size=(3, 3, 3), stride=2, padding=1, padding_mode='circular'),
            # num_features must equal the 256 channels produced by the
            # convolution above (it was 128, which fails at run time).
            nn.BatchNorm3d(256),
            nn.LeakyReLU(0.02, inplace=True),
            # 256, 4, 4, 4
            nn.Conv3d(256, 512, kernel_size=(3, 3, 3), stride=2, padding=1, padding_mode='circular'),
            nn.BatchNorm3d(512),
            nn.LeakyReLU(0.02, inplace=True),
            # 512, 2, 2, 2
            nn.Conv3d(512, 512, kernel_size=(3, 3, 3), stride=2, padding=1, padding_mode='circular'),
            nn.BatchNorm3d(512),
            nn.LeakyReLU(0.02, inplace=True),
            # 512, 1, 1, 1
        )

        # Flattened encoder features (512) -> Koopman observables (n_embd)
        self.observableNetFC = nn.Sequential(
            nn.Linear(512, 512),
            nn.LeakyReLU(0.02, inplace=True),
            nn.Linear(512, config.n_embd),
            nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon),
        )

        # Koopman observables (n_embd) -> decoder features (512)
        self.recoveryNetFC = nn.Sequential(
            nn.Linear(config.n_embd, 512),
            # NOTE(review): a negative slope of 1.0 makes this activation the
            # identity, so the two linear layers compose into one linear map.
            # Kept as-is to stay compatible with existing checkpoints; confirm
            # whether 0.02 was intended.
            nn.LeakyReLU(1.0, inplace=True),
            nn.Linear(512, 512),
            nn.LeakyReLU(0.02, inplace=True),
        )

        # Decoder: five (x2 trilinear up-sample + circular convolution) stages.
        # NOTE(review): the first two stages use align_corners=True and the
        # first stage applies BatchNorm after the activation, unlike the
        # remaining stages. Left untouched because changing either alters the
        # behaviour of already trained models.
        self.recoveryNet = nn.Sequential(
            # 512, 1, 1, 1
            nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True),
            nn.Conv3d(512, 512, kernel_size=(3, 3, 3), stride=1, padding=1, padding_mode='circular'),
            nn.LeakyReLU(0.02, inplace=True),
            nn.BatchNorm3d(512),
            # 512, 2, 2, 2
            nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True),
            nn.Conv3d(512, 256, kernel_size=(3, 3, 3), stride=1, padding=1, padding_mode='circular'),
            nn.BatchNorm3d(256),
            nn.LeakyReLU(0.02, inplace=True),
            # 256, 4, 4, 4
            nn.Upsample(scale_factor=2, mode='trilinear', align_corners=False),
            nn.Conv3d(256, 128, kernel_size=(3, 3, 3), stride=1, padding=1, padding_mode='circular'),
            nn.BatchNorm3d(128),
            nn.LeakyReLU(0.02, inplace=True),
            # 128, 8, 8, 8
            nn.Upsample(scale_factor=2, mode='trilinear', align_corners=False),
            nn.Conv3d(128, 64, kernel_size=(3, 3, 3), stride=1, padding=1, padding_mode='circular'),
            nn.BatchNorm3d(64),
            nn.LeakyReLU(0.02, inplace=True),
            # 64, 16, 16, 16
            nn.Upsample(scale_factor=2, mode='trilinear', align_corners=False),
            nn.Conv3d(64, 64, kernel_size=(3, 3, 3), stride=1, padding=1, padding_mode='circular'),
            nn.BatchNorm3d(64),
            nn.LeakyReLU(0.02, inplace=True),
            # 64, 32, 32, 32
            nn.Conv3d(64, 2, kernel_size=(1, 1, 1), stride=1, padding=0, padding_mode='circular'),
            nn.Sigmoid()
            # 2, 32, 32, 32
        )

        # Learned Koopman operator: a learnable diagonal plus a skew-symmetric
        # part restricted to the first nine off-diagonal bands.
        self.kMatrixDiag = nn.Parameter(torch.ones(config.n_embd))

        n_embd = self.config.n_embd
        # (row, col) index pairs of the upper bands at offsets 1..9
        xidx = [np.arange(0, n_embd - offset) for offset in range(1, 10)]
        yidx = [np.arange(offset, n_embd) for offset in range(1, 10)]
        self.xidx = torch.LongTensor(np.concatenate(xidx))
        self.yidx = torch.LongTensor(np.concatenate(yidx))
        # One learnable value per upper-band entry; mirrored with opposite
        # sign into the lower bands when the operator is assembled.
        self.kMatrixUT = nn.Parameter(0.01 * torch.rand(self.xidx.size(0)))

        # Normalization occurs inside the model
        self.register_buffer('mu', torch.tensor(0.))
        self.register_buffer('std', torch.tensor(1.))
        print('Number of embedding parameters: {}'.format(super().num_parameters))

    def forward(self, x):
        """Forward pass: encode the state, then decode it again

        Args:
            x (torch.Tensor): [B, 2, H, W, D] Input feature tensor

        Returns:
            (tuple): tuple containing:

                | (torch.Tensor): [B, config.n_embd] Koopman observables
                | (torch.Tensor): [B, 2, H, W, D] Feature tensor recovered from the observables
                | (torch.Tensor): [B, 512, 1, 1, 1] Convolutional encoder features
                | (torch.Tensor): [B, 512, 1, 1, 1] Decoder features rebuilt from the observables
                | (torch.Tensor): [B, 2, H, W, D] Feature tensor recovered directly from the encoder features
        """
        # Encode
        x = self._normalize(x)
        g0 = self.observableNet(x)
        g = self.observableNetFC(g0.view(g0.size(0), -1))
        # Decode
        out0 = self.recoveryNetFC(g).view(-1, 512, 1, 1, 1)
        out = self.recoveryNet(out0)
        xhat = self._unnormalize(out)
        # Last entry decodes the encoder features directly, bypassing both
        # fully-connected bottlenecks.
        return g, xhat, g0, out0, self._unnormalize(self.recoveryNet(g0))

    def embed(self, x):
        """Embeds tensor of state variables to Koopman observables

        Args:
            x (torch.Tensor): [B, 2, H, W, D] Input feature tensor

        Returns:
            (torch.Tensor): [B, config.n_embd] Koopman observables
        """
        x = self._normalize(x)
        g0 = self.observableNet(x)
        g = self.observableNetFC(g0.view(g0.size(0), -1))
        return g

    def recover(self, g):
        """Recovers feature tensor from Koopman observables

        Args:
            g (torch.Tensor): [B, config.n_embd] Koopman observables

        Returns:
            (torch.Tensor): [B, 2, H, W, D] Physical feature tensor
        """
        out = self.recoveryNetFC(g).view(-1, 512, 1, 1, 1)
        out = self.recoveryNet(out)
        x = self._unnormalize(out)
        return x

    def koopmanOperation(self, g):
        """Applies the learned koopman operator on the given observables.

        The assembled operator is also stored on ``self.kMatrix`` so that it
        can be read back through :attr:`koopmanOperator`.

        Args:
            g (torch.Tensor): [B, config.n_embd] Koopman observables

        Returns:
            (torch.Tensor): [B, config.n_embd] Koopman observables at the next time-step
        """
        n_embd = self.config.n_embd
        # Allocate next to the operator's parameters (same dtype and device),
        # replacing the deprecated torch.autograd.Variable wrapper.
        kMatrix = torch.zeros(
            g.size(0), n_embd, n_embd,
            dtype=self.kMatrixUT.dtype, device=self.kMatrixUT.device
        )
        # Populate the off diagonal terms (skew-symmetric part)
        kMatrix[:, self.xidx, self.yidx] = self.kMatrixUT
        kMatrix[:, self.yidx, self.xidx] = -self.kMatrixUT
        # Populate the diagonal
        ind = np.diag_indices(kMatrix.shape[1])
        kMatrix[:, ind[0], ind[1]] = self.kMatrixDiag

        # Apply Koopman operation
        gnext = torch.bmm(kMatrix, g.unsqueeze(-1))
        self.kMatrix = kMatrix

        return gnext.squeeze(-1)  # Squeeze empty dim from bmm

    @property
    def koopmanOperator(self, requires_grad=True):
        """Current Koopman operator, as assembled by the last call to
        :meth:`koopmanOperation`

        Note:
            This is a property, so callers cannot pass ``requires_grad``; it
            always takes its default and the operator is returned attached to
            the autograd graph. Use ``koopmanOperator.detach()`` for a detached
            copy. ``self.kMatrix`` is only assigned in :meth:`koopmanOperation`.

        Args:
            requires_grad (bool, optional): If to return with gradient storage, defaults to True
        """
        if not requires_grad:
            return self.kMatrix.detach()
        else:
            return self.kMatrix

    @property
    def koopmanDiag(self):
        """Learnable diagonal of the Koopman operator"""
        return self.kMatrixDiag

    def _normalize(self, x):
        """Shifts and scales the input with the ``mu`` / ``std`` buffers"""
        # Leading batch axis plus three trailing spatial axes for broadcasting
        mu = self.mu.unsqueeze(0).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
        std = self.std.unsqueeze(0).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
        x = (x - mu) / std
        return x

    def _unnormalize(self, x):
        """Inverse of :meth:`_normalize`"""
        mu = self.mu.unsqueeze(0).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
        std = self.std.unsqueeze(0).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
        return std * x + mu


class GrayScottEmbeddingTrainer(nn.Module):
    """Training head for the Gray-Scott embedding model for parallel training

    Args:
        config (:class:`config.configuration_phys.PhysConfig`) Configuration class with transformer/embedding parameters
    """
    def __init__(self, config):
        """Constructor method
        """
        super().__init__()
        self.embedding_model = GrayScottEmbedding(config)

    def forward(self, x_data=None, permute=False):
        """Computes the training loss for one batch of time-series

        The embedding model is put into training mode. Every time-step is
        encoded and decoded; from the second step on, the observables of the
        first step are rolled forward with the Koopman operator and decoded
        for comparison with the true state.

        Args:
            x_data (torch.Tensor): Batch of state time-series with time on
                dim 1; each ``x_data[:, t]`` is fed to the embedding model.
            permute (bool, optional): If to apply one random circular shift
                along the last three dims to the whole batch, defaults to False

        Returns:
            (tuple): tuple containing:

                | (torch.Tensor): Total weighted training loss
                | (torch.Tensor): Detached sum over time-steps of the reconstruction MSE

        Raises:
            ValueError: If ``x_data`` is not provided
        """
        if x_data is None:
            raise ValueError("x_data must be provided to compute the embedding training loss")

        self.embedding_model.train()
        device = self.embedding_model.devices[0]
        mseLoss = nn.MSELoss()

        # Random permuting: circular shift along the three spatial dims
        if permute:
            shifts = tuple(np.random.randint(0, x_data.size(dim)) for dim in (-3, -2, -1))
            x_data = torch.roll(x_data, shifts=shifts, dims=(-3, -2, -1))

        # First time-step: reconstruction and feature-consistency terms only
        xin = x_data[:, 0].to(device)
        g_prev, x_rec, feat_enc, feat_dec, _ = self.embedding_model(xin)
        rec_err = mseLoss(xin, x_rec)  # computed once, reused for the metric
        loss = (1e4)*rec_err + (1e3)*mseLoss(feat_enc, feat_dec)
        loss_reconstruct = rec_err.detach()

        # Koopman transform
        for t0 in range(1, x_data.shape[1]):
            xin = x_data[:, t0].to(device)  # Next time-step
            _, x_rec, feat_enc, feat_dec, _ = self.embedding_model(xin)

            # Advance the previous prediction and decode it; the prediction,
            # not the freshly encoded observables, is propagated further.
            g_pred = self.embedding_model.koopmanOperation(g_prev)
            x_pred = self.embedding_model.recover(g_pred)

            rec_err = mseLoss(x_rec, xin)
            # Terms: Koopman prediction, reconstruction, feature consistency,
            # and an L2 penalty on the Koopman operator entries.
            loss = loss + mseLoss(x_pred, xin) + (1e4)*rec_err + (1e3)*mseLoss(feat_enc, feat_dec) \
                + (1e-3)*torch.sum(torch.pow(self.embedding_model.koopmanOperator, 2))

            loss_reconstruct = loss_reconstruct + rec_err.detach()
            g_prev = g_pred

        return loss, loss_reconstruct

    def save_model(self, *args, **kwargs):
        """
        Saves the embedding model
        """
        self.embedding_model.save_model(*args, **kwargs)

    def load_model(self, *args, **kwargs):
        """
        Load the embedding model
        """
        self.embedding_model.load_model(*args, **kwargs)
