import argparse
import importlib
import json
import os
import torch
import torch.utils.data
from torchvision import transforms
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from datasets.datasets_pro import get_dataloader
from models.heatmaps import gen_heatmaps, heatmap2keypoints
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np


# Command-line interface. Options are declared in a single table and
# registered on the parser in the order listed.
parser = argparse.ArgumentParser()
for flag, spec in (
    ('--model', {'type': str, 'default': 'detector_res'}),
    ('--generator_log', {'type': str, 'default': 'log/generator_pro_discriminator_pro_k10_tau0.01_celeba_128'}),
    ('--lr', {'type': float, 'default': 1e-4}),
    ('--data_root', {'type': str, 'default': '../data/celeba'}),
    ('--wild', {'action': "store_true"}),
    ('--num_workers', {'type': int, 'default': 6}),
    ('--checkpoint', {'type': int, 'default': 199}),
):
    parser.add_argument(flag, **spec)
args = parser.parse_args()

# Settings of the generator run, read from its log directory.
generator_params_path = os.path.join(args.generator_log, 'parameters.json')
with open(generator_params_path, 'rt') as f:
    generator_args = json.load(f)

# The detector's log directory is named after the detector model and the
# generator settings it was trained against.
run_name = '{0}_{1}_k{2}_{3}_{4}'.format(
    args.model,
    generator_args['model'],
    generator_args['n_keypoints'],
    generator_args['class_name'],
    generator_args['n_embedding'],
)
args.log = os.path.join('log', run_name)

# Persist the effective arguments (including the derived `log` path).
os.makedirs(args.log, exist_ok=True)
with open(os.path.join(args.log, 'parameters.json'), 'wt') as f:
    json.dump(vars(args), f, indent=2)

# Schedule used to recover the generator's (stage, alpha) at the requested
# checkpoint epoch. Each entry of `n_epochs` is the length of one stage.
n_epochs = [100, 100, 100]

# stage_list[e] is the stage index active at epoch e.
# np.repeat produces an integer array directly; the former
# `.astype(np.int)` fails on NumPy >= 1.24, where the `np.int` alias
# was removed.
stage_list = np.repeat(np.arange(len(n_epochs)), n_epochs).astype(int)

# Number of epochs at the start of each stage over which alpha ramps 0 -> 1.
offset = 50

# alpha_list[e]: linear ramp from 0 to 1 over the first `offset` epochs of a
# stage, then constant 1 for the remainder of that stage.
# NOTE(review): alpha is presumably the fade-in weight consumed by the
# generator's forward pass -- confirm against the generator model.
alpha_list = []
for stage_epochs in n_epochs:
    alpha_list.append(np.linspace(0, 1, offset))
    alpha_list.append(np.linspace(1, 1, stage_epochs - offset))
alpha_list = np.concatenate(alpha_list)

# Checkpoints past the end of the schedule use the last stage, fully faded in.
if args.checkpoint < len(stage_list):
    stage = stage_list[args.checkpoint]
    alpha = alpha_list[args.checkpoint]
else:
    stage = stage_list[-1]
    alpha = 1

# Per-stage batch size and image resolution, indexed by `stage`.
batch_sizes = [256, 128, 64]
image_sizes = [64, 128, 256]
image_size = image_sizes[stage]

# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Rebuild the generator from the settings stored with its training run.
gen_model = importlib.import_module('models.' + generator_args['model'])
generator_config = {key: generator_args[key] for key in ('z_dim', 'n_keypoints', 'n_embedding', 'tau')}
generator = gen_model.Generator(generator_config).to(device)

# Restore the generator weights saved at the requested epoch. The storage is
# returned unchanged by map_location, i.e. tensors are loaded onto the CPU
# before load_state_dict copies them into the model.
# NOTE(review): the path is built from generator_args['log'] (the path stored
# in the generator's parameters.json), not from --generator_log -- confirm the
# two always agree.
checkpoint_path = os.path.join(
    generator_args['log'], 'checkpoints', 'epoch_{}.model'.format(args.checkpoint))
gen_checkpoint = torch.load(checkpoint_path, map_location=lambda storage, location: storage)
generator.load_state_dict(gen_checkpoint['generator'])
del gen_checkpoint  # release the loaded checkpoint dict

# The detector is the only model that is optimised in this script.
det_model = importlib.import_module('models.' + args.model)
detector = det_model.Detector(generator_args).to(device)
optim = torch.optim.Adam(detector.parameters(), lr=args.lr)

# Wrap the generator for multi-GPU sampling; the underlying model stays
# reachable as `generator.module`.
generator = torch.nn.DataParallel(generator)

# Select the MAFL variant that matches the --wild flag. The original
# expression was inverted (it chose 'mafl' when --wild was given and
# 'mafl_wild' otherwise).
mafl_class = 'mafl_wild' if args.wild else 'mafl'

# Real-image loaders used only for evaluation: the train split fits the linear
# regressor, the test split measures its error. drop_last=False keeps every
# sample.
mafl_train_dataloader = get_dataloader(data_root=args.data_root, image_size=image_size, class_name=mafl_class+'_train',
                                       batch_size=batch_sizes[stage],
                                       num_workers=args.num_workers, pin_memory=True, drop_last=False)
mafl_test_dataloader = get_dataloader(data_root=args.data_root, image_size=image_size, class_name=mafl_class+'_test',
                                      batch_size=batch_sizes[stage],
                                      num_workers=args.num_workers, pin_memory=True, drop_last=False)


class Dataset(torch.utils.data.Dataset):
    """In-memory dataset of generator samples paired with their keypoints.

    All samples are produced once, at construction time, by feeding random
    noise through the module-level ``generator`` at the current ``stage`` and
    ``alpha``. Images and keypoints are kept as CPU tensors.
    """

    def __init__(self):
        super().__init__()
        generator.module.eval()

        batch_size = batch_sizes[stage]
        # Number of generator batches to draw (roughly 200000 samples total).
        n_batches = 200000 // batch_size

        img_chunks = []
        keypoint_chunks = []
        with torch.no_grad():
            for _ in range(n_batches):
                # One noise tensor per entry of the generator's noise_shapes.
                input_batch = {}
                for noise_i, noise_shape in enumerate(generator.module.noise_shapes):
                    noise = torch.randn(batch_size, *noise_shape).to(device)
                    input_batch['input_noise{}'.format(noise_i)] = noise

                generated = generator(input_batch, stage, alpha)
                img_chunks.append(generated['img'].detach().cpu())

                keypoints = generator.module.gen_keypoints(input_batch)
                keypoint_chunks.append(keypoints.detach().cpu())

        self.imgs = torch.cat(img_chunks)
        self.keypoints = torch.cat(keypoint_chunks)
        # Per-channel (x - 0.5) / 0.5, applied lazily in __getitem__.
        self.transform = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

    def __getitem__(self, idx):
        """Return the normalised image and its keypoints for sample ``idx``."""
        return {
            'img': self.transform(self.imgs[idx]),
            'keypoints': self.keypoints[idx],
        }

    def __len__(self):
        """Number of stored samples."""
        return self.imgs.size(0)


def get_generated_dataloader():
    """Build a shuffled DataLoader over a freshly generated ``Dataset``.

    Constructing ``Dataset()`` samples the generator, so each call creates a
    new synthetic training set. Incomplete final batches are dropped.

    The worker count comes from this script's ``--num_workers`` option, the
    same as the MAFL loaders; previously the value stored in the generator's
    parameters was used, so the command-line flag was ignored here.
    """
    return torch.utils.data.DataLoader(
        Dataset(),
        batch_size=batch_sizes[stage],
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True,
    )


# Built once at module level: constructing the underlying Dataset runs the
# generator to pre-compute every training sample before training starts.
generated_dataloader = get_generated_dataloader()


def train_one_epoch():
    """Train the detector for one pass over the generated data.

    For every batch the detector's predicted heatmap is regressed (MSE)
    onto a heatmap rendered from the generator's keypoints.

    Returns:
        The mean of the per-batch losses over the epoch.
    """
    generator.eval()
    detector.train()

    n_batches = len(generated_dataloader)
    running_loss = 0

    for cpu_batch in generated_dataloader:
        optim.zero_grad()

        # Move every tensor of the batch to the training device.
        batch = {}
        for key, value in cpu_batch.items():
            batch[key] = value.to(device)

        predicted = detector(batch)['heatmap']
        target = gen_heatmaps(batch['keypoints'], heatmap_size=image_size, tau=0.01)
        loss = F.mse_loss(predicted, target)

        loss.backward()
        optim.step()
        running_loss += loss.item()

    return running_loss / n_batches


def _predict_keypoints(dataloader):
    """Run the detector over ``dataloader`` and collect flattened keypoints.

    Must be called with gradients disabled by the caller.

    Returns:
        ``(predictions, targets)``, each reshaped to ``(n_samples, -1)``.
    """
    predictions = []
    targets = []
    for real_batch in dataloader:
        real_batch['img'] = real_batch['img'].to(device)
        heatmap = detector(real_batch)['heatmap']
        predictions.append(heatmap2keypoints(heatmap).detach().cpu())
        targets.append(real_batch['keypoints'])
    predictions = torch.cat(predictions)
    targets = torch.cat(targets)
    return (predictions.reshape(predictions.shape[0], -1),
            targets.reshape(targets.shape[0], -1))


def _save_keypoint_grid(samples, keypoints, path, title=None):
    """Save an 8x8 grid of ``samples`` with ``keypoints`` overlaid to ``path``."""
    fig = plt.figure(figsize=(8, 8), dpi=256)
    if title is not None:
        fig.suptitle(title)
    gs = gridspec.GridSpec(8, 8)
    gs.update(wspace=0.1, hspace=0.1)

    for i, sample in enumerate(samples):
        ax = plt.subplot(gs[i])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        # Undo the (x - 0.5) / 0.5 normalisation for display.
        plt.imshow(sample.permute((1, 2, 0)) * 0.5 + 0.5)
        # Keypoints are indexed [.., 0] -> plot y, [.., 1] -> plot x.
        plt.scatter(keypoints[i, :, 1], keypoints[i, :, 0], s=20, marker='+')

    plt.savefig(path, bbox_inches='tight')
    plt.close(fig)


def evaluate(epoch_label=None):
    """Evaluate the detector on MAFL through a linear keypoint regression.

    A least-squares map ``beta`` from the detector's flattened keypoints to the
    annotated landmarks is fitted on the MAFL train split and scored on the
    test split. Two figures are written to ``<log>/eval``: the raw detected
    keypoints and the regressed landmarks for the first 64 test images.

    Args:
        epoch_label: Prefix for the saved figure names. When omitted, the
            module-level ``epoch`` (set by the training loop) is used, or
            ``'final'`` if no such variable exists.

    Returns:
        ``(eval_loss, normalized_loss, beta)``: the test MSE, the mean
        per-landmark error divided by the distance between landmarks 0 and 1,
        and the fitted regression matrix.
    """
    if epoch_label is None:
        # Previously the global `epoch` was read unconditionally, which raised
        # NameError if evaluate() ran before any training epoch.
        epoch_label = globals().get('epoch', 'final')

    eval_dir = os.path.join(args.log, 'eval')
    os.makedirs(eval_dir, exist_ok=True)

    detector.eval()

    with torch.no_grad():
        train_X, train_y = _predict_keypoints(mafl_train_dataloader)

        # Normal equations: beta = (X^T X)^-1 X^T y.
        gram = train_X.T @ train_X
        try:
            gram_inv = gram.inverse()
        except RuntimeError:
            # Singular Gram matrix: add an identity (ridge) term. It is sized
            # from the data instead of the former hard-coded 20, which only
            # matched detectors with exactly 10 keypoints.
            gram_inv = (gram + torch.eye(gram.shape[0], dtype=gram.dtype)).inverse()
        beta = gram_inv @ train_X.T @ train_y

        test_X, test_y = _predict_keypoints(mafl_test_dataloader)
        regressed = test_X @ beta
        eval_loss = F.mse_loss(regressed, test_y)

        # The annotations are treated as 5 landmarks with 2 coordinates each.
        n_test = test_X.shape[0]
        landmarks = test_y.reshape(n_test, 5, 2)
        unnormalized_loss = (regressed - test_y).reshape(n_test, 5, 2).norm(dim=-1)
        # Distance between landmarks 0 and 1 (named for the eyes in the
        # original code), used to normalise the error.
        eye_distance = (landmarks[:, 0, :] - landmarks[:, 1, :]).norm(dim=-1)
        normalized_loss = (unnormalized_loss / eye_distance.unsqueeze(1)).mean()

        # Predictions for the first 64 test images, for visual inspection.
        plot_size = 64
        samples = torch.stack([mafl_test_dataloader.dataset[i]['img'] for i in range(plot_size)])
        pred_heatmaps = detector({'img': samples.to(device)})['heatmap']
        pred_keypoints = heatmap2keypoints(pred_heatmaps).detach().cpu().squeeze()
        reg_keypoints = (pred_keypoints.reshape(pred_keypoints.shape[0], -1) @ beta).reshape(pred_keypoints.shape[0], 5, 2)

    _save_keypoint_grid(samples, pred_keypoints,
                        os.path.join(eval_dir, '{}.png'.format(epoch_label)))
    _save_keypoint_grid(samples, reg_keypoints,
                        os.path.join(eval_dir, '{}_regressed.png'.format(epoch_label)),
                        title=str(eval_loss) + '    {}'.format(normalized_loss))

    return eval_loss.item(), normalized_loss.item(), beta


if __name__ == '__main__':
    # Train the detector on generated data, evaluate once on MAFL, then save
    # the detector, optimiser state and the fitted regression matrix.
    writer = SummaryWriter(os.path.join(args.log, 'runs'))
    checkpoint_dir = os.path.join(args.log, 'checkpoints')
    os.makedirs(checkpoint_dir, exist_ok=True)

    # `epoch` is deliberately a module-level name: evaluate() reads it to
    # label the figures it saves.
    for epoch in range(999):
        train_loss = train_one_epoch()
        # The writer was previously created but never written to, so the
        # per-epoch training loss was lost.
        writer.add_scalar('train_loss', train_loss, epoch)

    test_loss, test_normalized_loss, beta = evaluate()
    writer.add_scalar('test_loss', test_loss, epoch)
    writer.add_scalar('test_normalized_loss', test_normalized_loss, epoch)
    writer.close()  # flush pending events to disk

    torch.save(
        {
            'detector': detector.state_dict(),
            'optim': optim.state_dict(),
            'beta': beta
        },
        os.path.join(checkpoint_dir, 'detector.model')
    )
