import os
import torch
import numpy as np
from scipy import signal
from scipy import linalg as la
from scipy import special as ss

# import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from models.op import transition

from utils import device

# __all__ = ['VAE','vae_retrain']

class _vae(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()
        self.fc1 = torch.nn.Linear(input_dim, hidden_dim)
        self.fc21 = torch.nn.Linear(hidden_dim, latent_dim)
        self.fc22 = torch.nn.Linear(hidden_dim, latent_dim)
        self.fc3 = torch.nn.Linear(latent_dim, hidden_dim)
        self.fc4 = torch.nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        h1 = torch.nn.functional.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def encoding(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return z

    def decode(self, z):
        h3 = torch.nn.functional.relu(self.fc3(z))
        # return torch.sigmoid(self.fc4(h3))
        return self.fc4(h3)

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar


# Training objective: mean-reduced reconstruction MSE + summed KL divergence
# + mean-reduced MSE between the two state trajectories.
def loss_function(recon_x, x, mu, logvar, c_x, c_recon_x):
    """Return the scalar training loss as the sum of three terms.

    Args:
        recon_x: Reconstruction produced by the model.
        x: Target the reconstruction is compared against.
        mu: Posterior mean.
        logvar: Posterior log-variance.
        c_x: State trajectory derived from ``x``.
        c_recon_x: State trajectory derived from ``recon_x``.

    Returns:
        ``mse(recon_x, x) + KL + mse(c_recon_x, c_x)``.  Both MSE terms are
        averaged over all elements; the KL term is summed over all elements.
    """
    recon_term = F.mse_loss(recon_x, x, reduction='mean')

    # Closed-form KL(N(mu, exp(logvar)) || N(0, I)), summed over every element.
    kl_inner = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_term = -0.5 * kl_inner.sum()

    state_term = F.mse_loss(c_recon_x, c_x, reduction='mean')

    # TODO: add linear and nonlinear CCA

    return recon_term + kl_term + state_term

class CustomDataset(Dataset):
    """Dataset that pairs two indexable containers element by element."""

    def __init__(self, x, y):
        """Store the inputs ``x`` and the matching targets ``y``."""
        self.x, self.y = x, y

    def __len__(self):
        """Return the number of samples, taken from ``x`` alone."""
        return len(self.x)

    def __getitem__(self, idx):
        """Return the ``(x[idx], y[idx])`` pair."""
        sample = self.x[idx]
        target = self.y[idx]
        return sample, target

def _unroll_hippo_state(state, frames, A, B):
    """Run the recurrence ``state = F.linear(state, A) + frame[:, None] @ B``.

    Args:
        state: State tensor carried in from the previous step.
        frames: Tensor iterated along its first dimension, one frame per step.
        A: Discretised state matrix, applied through ``F.linear``.
        B: Discretised input matrix, right-multiplied onto each frame.

    Returns:
        ``(final_state, stacked_states)``; ``stacked_states`` holds the state
        after every step, stacked along a new leading dimension.
    """
    history = []
    for frame in frames:
        state = F.linear(state, A) + frame.unsqueeze(-1) @ B
        history.append(state)
    return state, torch.stack(history, dim=0)


def vae_retrain(model, x, y, logger, args):
    """Train ``model`` on ``x`` with reconstruction, KL and state-matching losses.

    The samples are visited in their original order (no shuffling).  A linear
    recurrent state is rolled forward over the inputs and, in parallel, over
    the model's reconstructions; the distance between the two trajectories is
    added to the VAE loss.  The states are reset at the start of every epoch
    and carried across batches without their autograd history.

    Args:
        model: Module whose call returns ``(recon_x, mu, logvar)``.
        x: Training inputs, indexed along dim 0.  Assumed to be
            ``(T, input_dim)`` and already on ``device`` -- TODO confirm.
        y: Tensor paired with ``x`` in the dataset; it is not used by the loss.
        logger: Object exposing ``info(message)``.
        args: Namespace providing ``hippo_dim``, ``input_dim``,
            ``ode_discrete_method``, ``vae_batch_size`` and ``num_vae_epoch``.

    Returns:
        The same ``model`` object, left in training mode.
    """
    N = args.hippo_dim
    input_dim = args.input_dim
    T = x.shape[0]

    # Continuous-time operators from the project helper (presumably the LMU
    # variant of HiPPO -- confirm against models.op.transition).
    A, B = transition('lmu', N)
    C = np.ones((1, N))
    D = np.zeros((1,))

    # Discretise with step 1/T; only the state and input matrices are needed.
    A, B, _, _, _ = signal.cont2discrete((A, B, C, D), dt=1./T, method=args.ode_discrete_method)

    A = torch.Tensor(A).to(device).detach()
    B = torch.Tensor(B).to(device).detach()
    B = B.t()

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    model.train()
    dataset = CustomDataset(x.detach(), y.detach())
    # drop_last=True guarantees every batch holds exactly vae_batch_size rows.
    train_loader = DataLoader(dataset, batch_size=args.vae_batch_size, shuffle=False, drop_last=True)

    for epoch in range(args.num_vae_epoch):
        # Reset per epoch so the logged average is not cumulative over epochs.
        train_loss = 0.0
        c = torch.zeros(input_dim, N, device=device)
        c_ = torch.zeros(input_dim, N, device=device)
        for batch_idx, (batch_x, _) in enumerate(train_loader):
            optimizer.zero_grad()
            recon_x, mu, logvar = model(batch_x)

            # Roll both states through the batch, one sample per step.
            c, c_x = _unroll_hippo_state(c, batch_x, A, B)
            c_, c_recon_x = _unroll_hippo_state(c_, recon_x, A, B)

            loss = loss_function(recon_x, batch_x, mu, logvar, c_x, c_recon_x)
            loss.backward()
            train_loss = train_loss + loss.item()
            optimizer.step()

            # Carry the states into the next batch without autograd history,
            # so the next backward pass does not reach into this batch's graph.
            c = c.detach().clone()
            c_ = c_.detach().clone()

            if batch_idx % 2 == 0:
                Line = 'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(batch_x), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader),
                    loss.item() / len(batch_x))
                logger.info(Line)

        # Full-data reconstruction error; no gradients are needed here.
        model.eval()
        with torch.no_grad():
            recon_x, _, _ = model(x.detach())
            recon_error = F.mse_loss(recon_x, x.detach(), reduction='mean').item()

        logger.info('====> Epoch: {} reconstruction error: {:.4f} Average loss: {:.4f}'.format(
              epoch, recon_error, train_loss / len(train_loader.dataset)))

        model.train()

    return model