#!/usr/bin/python2.7

import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
import copy
import numpy as np
from tqdm import tqdm
from itertools import repeat
from feature_dataset import align_batch
import pdb


def channel_normalization(x):
    """Scale ``x`` by its largest absolute activation along dimension 1.

    The maximum of ``|x|`` is taken over ``dim=1`` and kept as a singleton
    dimension so that it broadcasts against ``x``. A constant ``1e-5`` is
    added to that maximum, so the division is defined even where every
    activation along dimension 1 is zero.

    @param: x, tensor with at least two dimensions
    @return: tensor of the same shape as ``x``
    """
    peak, _ = x.abs().max(dim=1, keepdim=True)
    return x / (peak + 1e-5)


class SpatialDropout(nn.Module):
    """
    Dropout whose mask is shared across some dimensions of the input.

    input: (batch, timesteps, embedding)

    With the default noise shape one keep/drop decision is drawn per
    (batch index, last-dimension index) pair and broadcast over every
    dimension in between, so the units that are dropped together are the
    ones that share an index along the last dimension.
    """

    def __init__(self, drop=0.5):
        """
        @param: drop, float, probability of zeroing a unit
        """
        super(SpatialDropout, self).__init__()
        self.drop = drop

    def forward(self, inputs, noise_shape=None):
        """
        @param: inputs, tensor
        @param: noise_shape, tuple; shape of the dropout mask, which must be
                expandable to ``inputs.shape``. Defaults to
                (batch, 1, ..., 1, last_dim).
        @return: ``inputs`` itself when the module is in eval mode or
                 ``drop == 0``; otherwise a new tensor with the mask applied
                 and the kept values rescaled by ``1 / (1 - drop)``.
        """
        if noise_shape is None:
            # By default share the mask along all of the middle dimensions.
            noise_shape = (inputs.shape[0], *repeat(1, inputs.dim() - 2), inputs.shape[-1])

        # Stored because _make_noises reads the shape from the instance.
        self.noise_shape = noise_shape

        # Nothing to drop: return the input untouched, without copying it.
        if not self.training or self.drop == 0:
            return inputs

        noises = self._make_noises(inputs)
        if self.drop == 1:
            # 1 / (1 - drop) is undefined here; everything is dropped.
            noises.fill_(0.0)
        else:
            keep_prob = 1 - self.drop
            noises.bernoulli_(keep_prob).div_(keep_prob)

        # Out-of-place multiply leaves ``inputs`` unmodified.
        return inputs * noises.expand_as(inputs)

    def _make_noises(self, inputs):
        """Allocate an uninitialised mask with the dtype/device of ``inputs``."""
        return inputs.new_empty(self.noise_shape)


class ED_TCN(nn.Module):
    """Encoder-decoder temporal convolutional network.

    The encoder repeatedly applies convolution, dropout, ReLU, channel
    normalization and max-pooling by 2. The decoder mirrors it, upsampling
    back to the length recorded before each encoder stage and then applying
    convolution, dropout, ReLU and channel normalization.
    """

    def __init__(self, num_layers, conv_len, num_f_maps, dim, num_classes):
        """
        @param: num_layers, int, number of encoder stages (and decoder stages)
        @param: conv_len, int, kernel size of the temporal convolutions
        @param: num_f_maps, int, number of feature maps inside the network
        @param: dim, int, number of input channels
        @param: num_classes, int, number of output channels
        """
        super(ED_TCN, self).__init__()
        half_kernel = conv_len // 2

        def temporal_conv():
            return nn.Conv1d(num_f_maps, num_f_maps, conv_len, padding=half_kernel)

        # Layers are created in this order so that seeded initialisation and
        # state_dict keys stay the same.
        self.conv_1x1_in = nn.Conv1d(dim, num_f_maps, 1)
        self.encoder_layers = nn.ModuleList([temporal_conv() for _ in range(num_layers)])
        self.decoder_layers = nn.ModuleList([temporal_conv() for _ in range(num_layers)])
        self.conv_out = nn.Conv1d(num_f_maps, num_classes, 1)
        self.dropout = SpatialDropout(0.3)
        self.pool = nn.MaxPool1d(2)

    def forward(self, x):
        """
        @param: x, tensor fed to a Conv1d with ``dim`` input channels after
                being L2-normalized along dimension 1
                (presumably (batch, dim, time) -- confirm against caller)
        @return: tensor of raw class scores; no softmax is applied here
        """
        hidden = self.conv_1x1_in(F.normalize(x, dim=1))

        # Remember the length entering each encoder stage so the decoder can
        # undo the pooling exactly, including for odd lengths.
        stage_lengths = []
        for encoder in self.encoder_layers:
            stage_lengths.append(hidden.shape[-1])
            hidden = self.dropout(encoder(hidden))
            hidden = channel_normalization(F.relu(hidden))
            hidden = self.pool(hidden)

        # Walk the recorded lengths backwards: the last pooled stage is
        # restored first.
        for decoder, target_length in zip(self.decoder_layers, reversed(stage_lengths)):
            hidden = F.interpolate(hidden, size=target_length, mode='nearest')
            hidden = self.dropout(decoder(hidden))
            hidden = channel_normalization(F.relu(hidden))

        return self.conv_out(hidden)


class Trainer:
    """Builds an ``ED_TCN`` model and provides training and evaluation loops."""

    def __init__(self, num_layers, conv_len, num_f_maps, dim, num_classes):
        """
        All arguments are forwarded to ``ED_TCN``; ``num_classes`` is also
        kept to reshape the predictions for the loss.
        """
        self.model = ED_TCN(num_layers, conv_len, num_f_maps, dim, num_classes)
        # Targets equal to -100 do not contribute to the loss.
        self.ce = nn.CrossEntropyLoss(ignore_index=-100)
        self.mse = nn.MSELoss(reduction='none')
        self.num_classes = num_classes

    def train(self, save_dir, data_loader, num_epochs, learning_rate, device, causal):
        """Train the model with Adam and checkpoint it after every epoch.

        @param: save_dir, str, existing directory that receives
                ``epoch-<k>.model`` and ``epoch-<k>.opt`` (k starts at 1)
        @param: data_loader, iterable of raw batches passed to ``align_batch``
        @param: num_epochs, int
        @param: learning_rate, float
        @param: device, device the model and batches are moved to
        @param: causal, forwarded unchanged to ``align_batch``
        """
        self.model.train()
        self.model.to(device)
        optimizer = optim.Adam(self.model.parameters(), lr=learning_rate)

        for epoch in range(num_epochs):
            correct = 0
            total = 0
            pbar = tqdm(data_loader, ncols=80)
            for batch in pbar:
                # align_batch is expected to return (input, target, mask).
                batch_input, batch_target, mask = align_batch(batch, self.num_classes, causal)
                batch_input = batch_input.to(device)
                batch_target = batch_target.to(device)
                mask = mask.to(device)

                optimizer.zero_grad()
                predictions = self.model(batch_input)

                # Flatten (batch, classes, time) scores to (batch * time, classes)
                # so they line up with the flattened targets.
                loss = self.ce(predictions.transpose(2, 1).contiguous().view(-1, self.num_classes), batch_target.view(-1))
                loss.backward()
                optimizer.step()

                loss_value = loss.item()
                _, predicted = torch.max(predictions.detach(), 1)

                # The mask is indexed as 3-D; its first slice along dim 1 is
                # used as the per-frame weight. No squeeze here: squeezing
                # would drop the time axis for length-1 sequences and make
                # the product broadcast to the wrong shape.
                frame_mask = mask[:, 0, :]
                correct += ((predicted == batch_target).float() * frame_mask).sum().item()
                total += torch.sum(frame_mask).item()

                # Guard against batches seen so far containing no valid frame.
                accuracy = correct / total if total else 0.0
                pbar.set_description("epoch %d  Loss %.4f  Acc %.4f" % (epoch, loss_value, accuracy))
            torch.save(self.model.state_dict(), save_dir + "/epoch-" + str(epoch + 1) + ".model")
            torch.save(optimizer.state_dict(), save_dir + "/epoch-" + str(epoch + 1) + ".opt")

    def predict(self, model_dir, data_loader, epoch, device):
        """Load the checkpoint of ``epoch`` and print the frame accuracy.

        @param: model_dir, str, directory containing ``epoch-<epoch>.model``
        @param: data_loader, iterable of (input, target, <unused>) batches
        @param: epoch, checkpoint number to load
        @param: device, device the model and batches are moved to

        Frames whose target equals the loss's ``ignore_index`` are excluded
        from both the frame count and the hit count, and every sequence in a
        batch is counted.
        """
        model_path = model_dir + "/epoch-" + str(epoch) + ".model"
        ignore_index = self.ce.ignore_index

        self.model.eval()
        n_frames = 0
        n_acc = 0
        with torch.no_grad():
            self.model.to(device)
            self.model.load_state_dict(torch.load(model_path, map_location=device))
            print("loaded " + model_path)
            pbar = tqdm(data_loader, ncols=80)
            for batch in pbar:
                batch_input, batch_target, _ = batch
                batch_input, batch_target = batch_input.to(device), batch_target.to(device)

                predictions = self.model(batch_input)
                _, predicted = torch.max(predictions, 1)

                valid = batch_target != ignore_index
                n_frames += valid.sum().item()
                n_acc += ((predicted == batch_target) & valid).sum().item()

        # An empty loader (or one with only ignored frames) yields 0.0
        # instead of raising ZeroDivisionError.
        frame_acc = float(n_acc) / float(n_frames) if n_frames else 0.0
        print("n_frames: ", n_frames, "n_acc", n_acc)
        print("Frame accuracy: ", frame_acc)
