import torch
import numpy as np
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torch.nn as nn
import random
import h5py
import os
import torch.utils.data
import math
from torch.utils.tensorboard import SummaryWriter


class ResidualBlock(nn.Module):
    """Basic two-convolution residual block.

    Computes ``relu(left(x) + shortcut(x))``, where ``left`` is
    conv -> BatchNorm -> ReLU -> conv -> BatchNorm. Both convolutions use
    ``kernel_size``/``padding`` and have no bias; only the first is strided.
    ``shortcut`` is the identity, or a strided 3x3 conv + BatchNorm projection
    whenever ``stride != 1`` or the channel count changes.

    Note: the projection always uses a 3x3 kernel with padding 1, whatever
    ``kernel_size`` is. The two paths still produce equal spatial sizes
    whenever ``kernel_size == 2 * padding + 1``.
    """

    def __init__(self, inchannel, outchannel, stride=1, kernel_size=3, padding=1):
        super().__init__()
        conv_opts = dict(kernel_size=kernel_size, padding=padding, bias=False)
        self.left = nn.Sequential(
            nn.Conv2d(inchannel, outchannel, stride=stride, **conv_opts),
            nn.BatchNorm2d(outchannel),
            nn.ReLU(inplace=True),
            nn.Conv2d(outchannel, outchannel, stride=1, **conv_opts),
            nn.BatchNorm2d(outchannel),
        )
        if stride == 1 and inchannel == outchannel:
            # Shapes already agree, so an empty Sequential acts as the identity.
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(inchannel, outchannel, kernel_size=3, padding=1, stride=stride, bias=False),
                nn.BatchNorm2d(outchannel),
            )

    def forward(self, x):
        main = self.left(x)
        return F.relu(main + self.shortcut(x))


class DetectionNet(nn.Module):
    """Fully convolutional detector with an existence head and a coordinate head.

    The backbone has two parts:
      * a stem (3x3 conv, BatchNorm, ReLU) that widens the single input
        channel to 64;
      * residual stages that widen the features to 512 channels, with three
        stride-2 blocks that downsample the map.

    Each head applies a residual block, a 1x1 conv and a sigmoid.
    ``forward(x)`` returns ``(exist, coord)``: a 1-channel and a 2-channel
    map, both with values in (0, 1).
    """

    def __init__(self):
        super().__init__()
        layers = [
            nn.Conv2d(1, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            ResidualBlock(64, 64),
        ]
        for width in (128, 256, 512):
            # One stage: a stride-2 block that doubles the channel count,
            # followed by a 1x1 residual block at the new width.
            layers.append(ResidualBlock(width // 2, width, stride=2))
            layers.append(ResidualBlock(width, width, kernel_size=1, padding=0))
        self.backbone = nn.Sequential(*layers)
        self.exist_branch = self._head(1)
        self.reg_coord_branch = self._head(2)

        # Kaiming-normal weights for every conv; BatchNorm starts as the identity affine map.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu')
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

    @staticmethod
    def _head(out_channels):
        """Build a prediction head: residual block, 1x1 conv to ``out_channels``, sigmoid."""
        return nn.Sequential(
            ResidualBlock(512, 512),
            nn.Conv2d(512, out_channels, kernel_size=1, padding=0),
            nn.Sigmoid(),
        )

    def forward(self, x):
        # NOTE(review): the fixed 1/100 input scaling presumably normalises the
        # raw field magnitude — confirm against the data.
        features = self.backbone(x / 100)
        return self.exist_branch(features), self.reg_coord_branch(features)


class Dataset(torch.utils.data.Dataset):
    """In-memory dataset loaded from ``detection_data/<data_type>.h5``.

    The directory is resolved relative to this source file. The HDF5 file
    must contain the datasets ``data``, ``exist`` and ``coord``. All three
    are copied fully into memory, and item ``i`` is the tuple
    ``(data[i], exist[i], coord[i])``.

    Args:
        data_type: file stem to load, e.g. ``'train'`` or ``'test'``.
    """

    def __init__(self, data_type):
        data_root = os.path.join(os.path.dirname(os.path.realpath(__file__)), "detection_data")
        path = os.path.join(data_root, data_type + '.h5')
        # Open read-only (never create or modify the file) and close it right
        # away: ``[:]`` copies each dataset into memory, so the handle is not
        # needed after loading.
        with h5py.File(path, 'r') as f:
            self.data = f['data'][:]
            self.exist = f['exist'][:]
            self.coord = f['coord'][:]

    def __getitem__(self, index):
        return self.data[index], self.exist[index], self.coord[index]

    def __len__(self):
        return self.data.shape[0]


class FocalLoss(nn.Module):
    """Binary focal loss on probabilities, normalised by the number of positives.

    For each element with true-class probability ``p_t``, the loss is
    ``-alpha_t * (1 - p_t) ** gamma * log(p_t)``. The per-element losses are
    summed and divided by ``max(target.sum(), 1)``.

    Args:
        alpha: weight for class-0 (negative) elements; class-1 elements get
            ``1 - alpha``.
        gamma: focusing exponent applied to ``(1 - p_t)``.
    """

    def __init__(self, alpha=0.4, gamma=2):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, pred, target):
        """Compute the normalised focal loss.

        Args:
            pred: probability of class 1 for each element, any shape.
            target: labels with the same number of elements as ``pred``.
                ``target.long()`` selects the class, which must be 0 or 1.

        Returns:
            Scalar tensor: summed focal loss / max(number of positives, 1).
        """
        # reshape (not view) so non-contiguous inputs are accepted.
        pred = pred.reshape(-1, 1)
        target = target.reshape(-1, 1)
        class_idx = target.long()
        # Column 0 holds P(class 0), column 1 holds P(class 1).
        two_class = torch.cat((1 - pred, pred), dim=1)
        # p_t: probability of the true class, clamped so log() stays finite.
        probs = two_class.gather(1, class_idx).clamp(min=0.0001, max=1.0)
        alpha_table = torch.tensor([self.alpha, 1 - self.alpha], dtype=pred.dtype, device=pred.device)
        alpha = alpha_table[class_idx]
        batch_loss = -alpha * (1 - probs) ** self.gamma * probs.log()
        # Clamp the positive count so a batch with no positives gives a finite
        # loss instead of inf/NaN.
        return batch_loss.sum() / target.sum().clamp(min=1)


# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Training hyper-parameters.
n_epoch = 500
batch_size = 64
lr = 1e-3
lr_gamma = 0.8  # StepLR multiplicative decay factor
lr_step = 10    # StepLR step_size

# Model and loss objects, shared as module-level globals.
net = DetectionNet().to(device)
bce_loss = FocalLoss()  # a focal loss on the existence map, despite the name
mse_loss = nn.MSELoss()


def loss_func(exist_pred, coord_pred, exist_truth, coord_truth):
    """Return the combined detection loss.

    The loss is the focal loss on the existence map (``bce_loss``) plus a
    masked squared error on the coordinates. Coordinate errors count only
    where ``exist_truth`` is set. They are summed and divided by the number
    of positive cells, clamped to at least 1 so a batch without positives
    contributes 0 instead of NaN.

    NOTE(review): assumes ``exist_truth`` is a 0/1 mask that broadcasts over
    the coordinate channels (e.g. (B, 1, H, W) vs (B, 2, H, W)) — confirm.
    """
    exist_loss = bce_loss(exist_pred, exist_truth)
    positives = exist_truth.sum().clamp(min=1)
    coord_loss = (((coord_pred - coord_truth) * exist_truth) ** 2).sum() / positives
    return exist_loss + coord_loss


def _batch_counts(exist_pred, exist_truth):
    """Return (correct, total, true_positive, positive) cell counts for one batch.

    A cell is predicted positive when ``exist_pred > 0.5`` and is truly
    positive when ``exist_truth > 0.5``. Counting happens on the CPU.
    """
    pred_pos = (exist_pred > 0.5).detach().cpu()
    truth_pos = exist_truth.cpu() > 0.5
    correct = (pred_pos == truth_pos).sum().item()
    true_pos = (pred_pos & truth_pos).sum().item()
    return correct, pred_pos.numel(), true_pos, truth_pos.sum().item()


def _run_epoch(data_loader, optimizer=None):
    """Make one pass of the global ``net`` over ``data_loader``.

    If ``optimizer`` is given, each batch is trained (zero_grad, backward,
    step). Otherwise the pass is an evaluation with gradients disabled.

    Returns:
        (mean batch loss, cell accuracy, true-positive rate). Accuracy and TPR
        are pooled over all cells of the pass, so batches without positives
        cannot cause a division by zero.
    """
    training = optimizer is not None
    net.train(training)
    loss_sum = 0.0
    n_batches = 0
    correct = total = true_pos = positives = 0
    with torch.set_grad_enabled(training):
        for data, exist_truth, coord_truth in data_loader:
            if training:
                optimizer.zero_grad()
            exist_pred, coord_pred = net(data.to(device))
            loss = loss_func(exist_pred, coord_pred, exist_truth.to(device), coord_truth.to(device))
            if training:
                loss.backward()
                optimizer.step()
            loss_sum += loss.item()
            n_batches += 1
            c, t, tp, p = _batch_counts(exist_pred, exist_truth)
            correct += c
            total += t
            true_pos += tp
            positives += p
    return (loss_sum / max(n_batches, 1),
            correct / max(total, 1),
            true_pos / max(positives, 1))


def train():
    """Train ``net`` for ``n_epoch`` epochs and keep the best checkpoint.

    Uses Adam with a StepLR schedule. Each epoch:
      * logs loss / accuracy / TPR for the train and test splits to
        TensorBoard under ``../runs/detection``;
      * saves ``net.state_dict()`` to ``detection.pt`` if the mean test loss
        is a new minimum.
    """
    writer = SummaryWriter('../runs/detection')
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=lr_step, gamma=lr_gamma)
    data_loader = torch.utils.data.DataLoader(Dataset('train'), batch_size=batch_size, shuffle=True, drop_last=False)
    test_data_loader = torch.utils.data.DataLoader(Dataset('test'), batch_size=batch_size, drop_last=False)
    lowest_test_loss = float('inf')
    try:
        for epoch in range(1, n_epoch + 1):
            train_loss, train_acc, train_tpr = _run_epoch(data_loader, optimizer)
            scheduler.step()
            test_loss, test_acc, test_tpr = _run_epoch(test_data_loader)

            writer.add_scalars('loss', {'train_loss': train_loss, 'test_loss': test_loss}, epoch)
            writer.add_scalars('acc', {'train_acc': train_acc, 'test_acc': test_acc}, epoch)
            writer.add_scalars('tpr', {'train_tpr': train_tpr, 'test_tpr': test_tpr}, epoch)
            writer.flush()

            if test_loss < lowest_test_loss:
                torch.save(net.state_dict(), "detection.pt")
                lowest_test_loss = test_loss
    finally:
        # Make sure buffered TensorBoard events are written even if training fails.
        writer.close()


def _clamped_window(center, half_width, size):
    """Return ``slice(center - half_width, center + half_width)`` clipped to ``[0, size]``."""
    return slice(max(center - half_width, 0), min(center + half_width, size))


def _refine_vortex_center(field, grid, coord, dx, half_width, max_iter=100, tol=1e-5):
    """Move ``coord`` to the intensity-weighted centroid of ``field`` around it.

    Args:
        field: 2-D tensor of field values.
        grid: tensor of shape ``field.shape + (2,)`` holding each cell's position.
        coord: starting position, shape (2,), in the same units as ``grid``.
        dx: cell size used to convert positions to cell indices.
        half_width: window half-size in cells.
        max_iter: upper bound on iterations.
        tol: stop once an update moves the position by at most this much.

    Iteration also stops early if the window's total weight is zero.
    Returns the refined position.
    """
    for _ in range(max_iter):
        cell = (coord / dx).long()
        rows = _clamped_window(int(cell[0]), half_width, field.shape[0])
        cols = _clamped_window(int(cell[1]), half_width, field.shape[1])
        weights = field[rows, cols]
        total = weights.sum()
        if total == 0:
            break
        new_coord = (weights.unsqueeze(-1) * grid[rows, cols]).sum(dim=(0, 1)) / total
        moved = (new_coord - coord).norm()
        coord = new_coord
        if moved <= tol:
            break
    return coord


def test():
    """Visually check the checkpoint ``detection.pt`` on test sample 1.

    Figure 1 (2x2):
      * [0, 0] input field;
      * [0, 1] reconstruction;
      * [1, 0] ground-truth existence map;
      * [1, 1] thresholded predicted existence map.
    The reconstruction places, at each truly-occupied cell, a Gaussian
    (sigma = 1.5 * dx) centred at the refined predicted position. Each
    Gaussian integrates to the field sum in a window around that position.

    Figure 2: one row per true vortex, showing the true (left) and predicted
    (right) coordinate channels for that cell on [0, 1] axes.
    """
    dataset = Dataset('test')
    data, exist_truth, coord_truth = dataset[1]
    data = torch.from_numpy(data).unsqueeze(0)
    exist_truth = torch.from_numpy(exist_truth).unsqueeze(0)
    coord_truth = torch.from_numpy(coord_truth).unsqueeze(0)
    net.load_state_dict(torch.load("detection.pt", map_location='cpu'))
    net.to(device)
    net.eval()
    with torch.no_grad():
        exist_pred, coord_pred = net(data.to(device))
    # Move predictions to the CPU: they are mixed with CPU tensors and passed
    # to matplotlib below.
    exist_pred = (exist_pred > 0.5).cpu()
    coord_pred = coord_pred.cpu()

    fig, ax = plt.subplots(nrows=2, ncols=2)
    im = ax[0, 0].imshow(data[0, 0, :, :])
    plt.colorbar(im, ax=ax[0, 0])
    im = ax[1, 0].imshow(exist_truth[0, 0, :, :])
    plt.colorbar(im, ax=ax[1, 0])
    im = ax[1, 1].imshow(exist_pred[0, 0, :, :].float())
    plt.colorbar(im, ax=ax[1, 1])

    field = data[0, 0]
    grid_size = data.shape[-1]
    range_size = 12
    dx = math.pi / 2 / grid_size
    sigma = 1.5 * dx
    # The backbone downsamples by 8 (three stride-2 blocks), so one existence
    # cell spans 8 input cells.
    cell_size = dx * 8
    xv, yv = torch.meshgrid([torch.arange(0, grid_size), torch.arange(0, grid_size)])
    grid = torch.stack((xv, yv), dim=-1) * dx  # (G, G, 2) position of each input cell
    num_vort = int(exist_truth.sum().item())

    data_pred = torch.zeros(grid_size, grid_size)
    for i in range(exist_truth.shape[-2]):
        for j in range(exist_truth.shape[-1]):
            if exist_truth[0, 0, i, j] != 1:
                continue
            start = (coord_pred[0, :, i, j] + torch.tensor([i, j])) * cell_size
            coord = _refine_vortex_center(field, grid, start, dx, range_size)
            cell = (coord / dx).long()
            rows = _clamped_window(int(cell[0]), range_size, grid_size)
            cols = _clamped_window(int(cell[1]), range_size, grid_size)
            # Field integral over the window (presumably the vortex strength).
            value = field[rows, cols].sum() * dx ** 2
            rij2 = ((grid - coord) ** 2).sum(dim=-1)
            data_pred += value / (2 * math.pi * sigma ** 2) * torch.exp(-rij2 / (2 * sigma ** 2))

    im = ax[0, 1].imshow(data_pred)
    plt.colorbar(im, ax=ax[0, 1])

    if num_vort > 0:
        # squeeze=False keeps ``ax`` 2-D even when there is a single vortex.
        fig, ax = plt.subplots(nrows=num_vort, ncols=2, squeeze=False)
        index = 0
        for i in range(coord_truth.shape[-2]):
            for j in range(coord_truth.shape[-1]):
                if exist_truth[0, 0, i, j] != 1:
                    continue
                for col, coords in enumerate((coord_truth, coord_pred)):
                    ax[index, col].scatter(coords[0, 0, i, j], coords[0, 1, i, j])
                    ax[index, col].set_xlim([0, 1])
                    ax[index, col].set_ylim([0, 1])
                index += 1
    plt.show()


if __name__ == "__main__":
    # Training is the default action; call test() here instead to visualise
    # predictions from a saved detection.pt.
    train()
