#!/usr/bin/env python

import cv2
import matplotlib.pyplot as plt
import numpy as np
import os
import torch
import torch.nn as nn
from tqdm import tqdm

from . import utils
import pix2sym.utils.nn

class Autoencoder(utils.nn.Network):
    """Convolutional autoencoder with a (optionally learnable) mean image.

    Inputs are centred by subtracting ``self.mu`` before encoding, and ``mu``
    is added back after decoding. ``compute_losses`` combines per-state
    reconstruction error with two penalties on the latent difference between
    consecutive states:
    - a temporal-smoothing norm, weighted by ``alpha``;
    - a cross-correlation term, weighted by ``beta``.

    Args:
        input_shape: ``(channels, height, width)`` of a single input; a batch
            of shape ``(N,) + input_shape`` is fed to the first ``Conv2d``.
        n_latent_dims: size of the latent code produced by ``encode``.
        learning_rate: learning rate of the Adam optimizer built in ``__init__``.
        alpha: weight of the temporal-smoothing loss.
        beta: weight of the cross-correlation loss.
        norm: key into ``utils.nn.norm_loss`` used for temporal smoothing;
            one of ``'l0'``, ``'lq'``, ``'l1'``, ``'l2'``.
        learn_mean_img: whether ``mu`` receives gradients.
    """

    def __init__(self, input_shape, n_latent_dims=256, learning_rate=0.01, alpha=1e-6, beta=1e-6, norm='l1', learn_mean_img=False):
        super().__init__()
        self.input_shape = tuple(input_shape)
        self.n_channels = self.input_shape[0]
        self.n_latent_dims = n_latent_dims
        self.learning_rate = learning_rate
        self.alpha = alpha
        self.beta = beta
        assert norm in ['l0', 'lq', 'l1', 'l2']
        self.norm = norm

        self.mu = nn.Parameter(torch.zeros(self.input_shape, dtype=torch.float32), requires_grad=learn_mean_img)
        # Previously the helper itself (not its result) was stored. That made
        # the name a function object: save() put its repr into the checkpoint
        # filename, and load() could never recover the same name.
        random_word = utils.random_word
        self.name = random_word() if callable(random_word) else random_word

        conv1 = nn.Conv2d(self.n_channels, 3, kernel_size=3, stride=1)
        conv2 = nn.Conv2d(3, 3, kernel_size=6, stride=2)
        encoder_layers = [conv1, nn.ReLU(), conv2, nn.ReLU()]

        # Probe the conv stack with a dummy batch to learn two shapes: the
        # flattened feature size, and the size seen by the strided conv
        # (needed to invert it exactly). ReLU does not change shapes.
        # self.input_shape is a tuple, so this also works when a list is passed.
        dummy = torch.zeros((1,) + self.input_shape, dtype=torch.float32)
        with torch.no_grad():
            pre_stride = conv1(dummy)
            features = conv2(pre_stride)
        shape_2d = list(features.size()[1:])
        shape_flat = int(np.prod(shape_2d))
        # A stride-s conv drops up to s-1 trailing rows/columns. The matching
        # ConvTranspose2d must add them back via output_padding; this is 0 for
        # the default 84x84 input.
        output_padding = tuple(
            int(target - ((size - 1) * stride + kernel))
            for target, size, stride, kernel in zip(
                pre_stride.size()[2:], features.size()[2:], conv2.stride, conv2.kernel_size)
        )

        encoder_layers.extend([
            utils.nn.Reshape(-1, shape_flat),
            nn.Linear(shape_flat, 256), nn.ReLU(),
            nn.Linear(256, 256), nn.ReLU(),
            nn.Linear(256, n_latent_dims),
        ])
        self.encoder = nn.Sequential(*encoder_layers)

        decoder_layers = [
            nn.Linear(n_latent_dims, 256), nn.ReLU(),
            nn.Linear(256, 256), nn.ReLU(),
            nn.Linear(256, shape_flat), nn.ReLU(),
            utils.nn.Reshape(-1, *shape_2d),
            nn.ConvTranspose2d(3, 3, kernel_size=6, stride=2, output_padding=output_padding), nn.ReLU(),
            nn.ConvTranspose2d(3, self.n_channels, kernel_size=3, stride=1),
            nn.Tanh(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

        # NOTE: the optimizer captures the parameters that exist right now.
        # Replacing ``mu`` later with a new Parameter (see __setattr__) leaves
        # the optimizer holding the old one.
        self.optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)

    def __str__(self):
        s = 'Network name: {}\n'.format(self.name)
        s += super().__str__()
        return s

    def forward(self, x):
        """Reconstruct ``x`` (encode, then decode); same as ``reconstruct``."""
        return self.reconstruct(x)

    def layers(self):
        """Return all submodules of the encoder, then the decoder, in order
        (the two ``Sequential`` containers themselves are excluded)."""
        return list(self.encoder.modules())[1:] + list(self.decoder.modules())[1:]

    def __setattr__(self, name, value):
        """Route assignments to ``mu`` so a plain Tensor never replaces the Parameter.

        - A ``torch.Tensor`` is copied into the existing ``mu`` Parameter. Its
          size must match; ``requires_grad`` and identity are preserved.
        - A ``torch.nn.Parameter`` replaces ``mu`` outright.
        - Any other type raises ``TypeError``.
        All other attribute names are passed through unchanged.
        """
        if name == 'mu':
            if type(value) is torch.Tensor:
                assert value.size() == self.mu.size()
                self.mu.data.copy_(value)
                return
            elif type(value) is not torch.nn.Parameter:
                raise TypeError(
                    'mu must be a torch.Tensor or torch.nn.Parameter, got {}'.format(type(value).__name__))
        super().__setattr__(name, value)

    def encode(self, x):
        """Map inputs to latent codes after subtracting the mean image."""
        return self.encoder(x - self.mu)

    def decode(self, z):
        """Map latent codes back to input space and add the mean image."""
        return self.decoder(z) + self.mu

    def reconstruct(self, x):
        """Encode then decode ``x``."""
        return self.decode(self.encode(x))

    def compute_losses(self, states, next_states):
        """Compute the training objective for a batch of consecutive state pairs.

        Returns:
            A list ``[total, reconstruction1, reconstruction2,
            temporal_smoothing, cross_correlation]``. Here
            ``total = reconstruction1 + reconstruction2
            + alpha * temporal_smoothing + beta * cross_correlation``.
        """
        x1, x2 = states, next_states
        z1 = self.encode(x1)
        z2 = self.encode(x2)
        x1_hat = self.decode(z1)
        x2_hat = self.decode(z2)
        reconstruction1 = utils.nn.mse_loss(input=x1_hat, target=x1)
        reconstruction2 = utils.nn.mse_loss(input=x2_hat, target=x2)
        # Both penalties act on the latent change between consecutive states.
        temporal_smoothing = utils.nn.norm_loss[self.norm](z2 - z1)
        cross_correlation = utils.nn.mean_xcorr(z2 - z1)
        total = (reconstruction1 + reconstruction2
                 + self.alpha * temporal_smoothing
                 + self.beta * cross_correlation)
        return [total, reconstruction1, reconstruction2, temporal_smoothing, cross_correlation]

    def reconstruction_error(self, x):
        """Return the MSE between ``x`` and its reconstruction as a Python float."""
        # Evaluation only: no need to record an autograd graph.
        with torch.no_grad():
            x_hat = self(x)
            mse = utils.nn.mse_loss(input=x_hat, target=x)
        return mse.item()

    def train_batch(self, states, next_states):
        """Take one optimizer step on the combined loss.

        Returns:
            The loss list from ``compute_losses``, converted to Python floats.
        """
        self.optimizer.zero_grad()
        losses = self.compute_losses(states, next_states)
        losses[0].backward()
        self.optimizer.step()
        return [loss_term.item() for loss_term in losses]

    def save(self):
        """Write the state dict to ``models/<name>_<time>.pytorch`` and return that path."""
        os.makedirs('models', exist_ok=True)
        time_str = utils.get_time_string()
        model_file = 'models/{}_{}.pytorch'.format(self.name, time_str)
        print('Saving {}...'.format(model_file))
        torch.save(self.state_dict(), model_file)
        return model_file

    def load(self, model_file, force_cpu=False):
        """Load a state dict written by ``save`` and restore ``self.name``.

        The name is taken as the part of the file's basename before the first
        ``'_'``. It only round-trips if the saved name itself has no ``'_'``.
        """
        print('Loading {}...'.format(model_file))
        map_loc = 'cpu' if force_cpu else None
        # torch.load unpickles the file: only load checkpoints you trust.
        state_dict = torch.load(model_file, map_location=map_loc)
        self.load_state_dict(state_dict)
        self.name = os.path.basename(model_file).split('_')[0]


def test_shapes():
    """Encoding gives (batch, n_latent_dims) codes; decoding restores the input shape."""
    batch_size, n_latent = 32, 10
    model = Autoencoder(input_shape=(1, 84, 84), n_latent_dims=n_latent)
    images = torch.zeros([batch_size, 1, 84, 84], dtype=torch.float32)

    codes = model.encode(images)
    assert codes.size() == (batch_size, n_latent)

    reconstructed = model.decode(codes)
    assert reconstructed.size() == images.size()

def test_mu():
    """Check the assignment rules Autoencoder enforces for its mean image ``mu``."""
    model = Autoencoder(input_shape=(1, 84, 84), n_latent_dims=10)
    # Default: the mean image is fixed (not learned).
    assert not model.mu.requires_grad

    ones = torch.ones([1, 84, 84], dtype=torch.float32)
    # Assigning a Parameter swaps mu out entirely, including requires_grad.
    model.mu = torch.nn.Parameter(ones, requires_grad=True)
    assert model.mu.requires_grad
    assert model.mu[0, 0, 0] == 1

    # Assigning a plain Tensor only copies values into the current Parameter.
    model.mu = torch.zeros_like(ones)
    assert type(model.mu) is torch.nn.Parameter
    assert model.mu.requires_grad
    assert model.mu[0, 0, 0] == 0

def test_losses():
    """Check the loss helpers, then check that each norm gives a distinct loss.

    First, ``utils.nn`` loss helpers are compared against hand computations.
    Then, every temporal-smoothing norm must give a distinct loss on the same
    pair of states.
    """
    import math

    def close(expected, actual):
        # Independently computed float reductions can differ in the last bits,
        # so compare with a small tolerance rather than exact equality.
        return math.isclose(float(expected), float(actual), rel_tol=1e-6, abs_tol=1e-6)

    a = torch.arange(0, 5, dtype=torch.float32)
    a = torch.stack([a, a])
    b = torch.arange(1, 6, dtype=torch.float32) * 2
    b = torch.stack([b, b])
    c = torch.tensor([-1, 1, -1, 1, -1], dtype=torch.float32)
    c = torch.stack([c, c])
    x = a
    y = b * c
    abs_diff = torch.abs(x - y)
    expected_l1 = torch.mean(torch.sum(abs_diff, dim=1))
    expected_l2 = torch.mean(torch.sqrt(torch.sum(abs_diff ** 2, dim=1)))
    expected_mse = torch.mean(abs_diff ** 2)
    assert close(expected_l1, utils.nn.l1_loss(x - y))
    assert close(expected_l2, utils.nn.l2_loss(x - y))
    assert close(expected_mse, utils.nn.mse_loss(x, y))

    input_shape = (1, 84, 84)
    aes = []
    for norm in ['l0', 'lq', 'l1', 'l2']:
        # Re-seed so every model starts from identical weights; only the norm differs.
        utils.reset_seeds(0)
        aes.append(Autoencoder(input_shape=input_shape, alpha=1.0, norm=norm))
    x1 = torch.ones((1,) + input_shape, dtype=torch.float32)
    x2 = 0.9 * x1
    with torch.no_grad():
        l0, lq, l1, l2 = (float(ae.compute_losses(x1, x2)[3]) for ae in aes)
    assert l0 not in (lq, l1, l2)
    assert lq not in (l1, l2)
    assert l1 != l2

def test_save_and_load():
    """Round-trip a model through save()/load(): weights, mean image and name must survive."""
    orig = Autoencoder(input_shape=(3, 84, 84), n_latent_dims=10)
    orig.mu = torch.ones_like(orig.mu)
    assert torch.equal(orig.mu, torch.ones_like(orig.mu))

    filename = orig.save()
    try:
        dup = Autoencoder(input_shape=(3, 84, 84), n_latent_dims=10)
        dup.load(filename)
    finally:
        # Don't leave a checkpoint behind in models/ on every test run.
        os.remove(filename)

    assert len(orig.layers()) == len(dup.layers())

    # torch.equal compares tensors directly; np.all on a torch tensor relies
    # on NumPy forwarding to Tensor.all with NumPy-style keyword arguments.
    assert torch.equal(dup.layers()[0].weight, orig.layers()[0].weight)
    assert torch.equal(dup.mu, orig.mu)
    assert dup.name == orig.name

def test_train_and_test():
    """Training steps must change the losses; evaluating twice must give identical losses."""
    input_shape = (1, 84, 84)
    ac = Autoencoder(input_shape=input_shape, n_latent_dims=10)
    x1 = torch.tensor(np.ones((1,) + input_shape), dtype=torch.float32)
    x2 = 0.9 * x1

    train_loss1 = ac.train_batch(x1, x2)
    train_loss2 = ac.train_batch(x1, x2)
    # train_batch returns lists of Python floats, so plain list comparison works.
    assert train_loss1 != train_loss2

    # Evaluation only: no autograd graph needed.
    with torch.no_grad():
        test_loss1 = ac.compute_losses(x1, x2)[0]
        test_loss2 = ac.compute_losses(x1, x2)[0]
    assert torch.equal(test_loss1, test_loss2)

def overfit_single_batch():
    """Visual sanity check: train repeatedly on one image pair and show the results.

    The pair is a checkerboard and a noisy copy of it. The plot shows each
    input's reconstruction before and after training, next to the input itself.
    """
    img_shape = (1, 84, 84)
    height, width = img_shape[1:]
    net = Autoencoder(input_shape=img_shape, n_latent_dims=10)

    x1 = np.kron([[1, -1] * 4, [-1, 1] * 4] * 4, np.ones((4, 4)))
    # cv2.resize takes dsize as (width, height) and returns a (height, width) array.
    x1 = np.array(cv2.resize(x1, (width, height)))
    x1 = torch.Tensor(x1)
    x1 = x1.view(1, 1, height, width)
    x2 = x1 + 0.3 * torch.randn_like(x1)

    with torch.no_grad():
        pre1 = net(x1)
        pre2 = net(x2)

    running_loss = 0.0
    for i in tqdm(range(1000)):
        # One optimizer (Adam) step on the same pair each iteration. train_batch
        # already returns Python floats, so calling .item() here would raise.
        running_loss += net.train_batch(x1, x2)[0]
        if i % 200 == 199:  # report the mean loss over the last 200 steps
            tqdm.write('[%d] loss: %.3f' %
                       (i + 1, running_loss / 200))
            running_loss = 0.0

    with torch.no_grad():
        post1 = net(x1)
        post2 = net(x2)
        images = torch.cat((pre1, post1, x1, pre2, post2, x2), 0).numpy()
        utils.show_batch_images(images, titles=['Before', 'After', 'Target'])

def main():
    """Seed the RNGs, report the device, print a model summary and run the self-tests."""
    utils.reset_seeds(0)

    print('Device: {}'.format(
        torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')))

    Autoencoder(input_shape=(1, 84, 84)).summary()

    for check in (test_shapes, test_mu, test_train_and_test, test_losses):
        check()
    # Also available but not run here:
    #   test_save_and_load()   -- writes a checkpoint under models/
    #   overfit_single_batch() -- trains for 1000 steps and shows images
    print('Testing complete.')

if __name__ == '__main__':
    # Executed as a script: run the self-tests in main().
    main()
