import os
from pathlib import Path
import math
from argparse import ArgumentParser

import matplotlib.pyplot as plt
import torch
from torch.nn.functional import cross_entropy
from torch.utils.data import DataLoader

import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

# Optional dependencies. A failed import only prints a notice, so the module
# still loads; the corresponding name is then left undefined.
# wandb is used by train() whenever args.use_wandb is true (i.e. unless
# --no_wandb is passed).
try:
    import wandb
except ImportError:
    print('wandb not available')
# ray is used by train() only when --num_ray_workers is greater than 0.
try:
    import ray
except ImportError:
    print('ray not available')

import data
import losses
import models
import ap


def cmdline_args(argv=None):
    """Parse and validate the command-line options of this script.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) lets argparse read ``sys.argv[1:]``, which is what the
            script entry point relies on.

    Returns:
        The parsed ``argparse.Namespace``. ``set_size``/``set_dim`` are
        overwritten with fixed values for the ``clevr`` and ``mnist`` datasets,
        and ``project`` is filled with a dataset-derived name when it was not
        given explicitly.

    Raises:
        SystemExit: via ``parser.error`` when the options are inconsistent
            (non-positive set size/dim for ``random``, or
            ``set_size != set_dim`` for ``numbering``).
    """
    parser = ArgumentParser()
    # experiment config
    parser.add_argument('--project', default=None)
    parser.add_argument('--name', default='default')
    # dataset config
    parser.add_argument('--dataset', choices=['random', 'clevr', 'mnist', 'numbering'], default='numbering')
    parser.add_argument('--loss', choices=['hungarian_l2', 'hungarian_ce', 'hungarian_nl', 'chamfer', 'mse', 'ce'], default='hungarian_l2')
    parser.add_argument('--set_size', type=int, default=64)
    parser.add_argument('--set_dim', type=int, default=64)
    parser.add_argument('--input_dim', type=int, default=4)
    parser.add_argument('--dataset_size', type=int, default=64000)
    parser.add_argument('--n_obj_per_sample', type=int, default=4)
    parser.add_argument('--clevr_path', default='clevr')
    parser.add_argument('--clevr_image_input', action='store_true')
    parser.add_argument('--clevr_image_size', type=int, default=128)
    # training config
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--lr_drop_epoch', type=int, default=None)
    parser.add_argument('--checkpoint_path', default='checkpoints')
    parser.add_argument('--num_data_workers', type=int, default=0)
    parser.add_argument('--num_ray_workers', type=int, default=0)
    parser.add_argument('--check_val_every_n_epoch', type=int, default=1)
    # model config
    parser.add_argument('--model', default='idspn', choices=['idspn', 'dspn', 'slot', 'deepsets', 'lstm', 'transformer', 'random_transformer'])
    parser.add_argument('--latent_dim', type=int, default=64)
    parser.add_argument('--hidden_dim', type=int, default=64)
    parser.add_argument('--input_encoder', default='rnfs', choices=['fs','rnfs'])
    # idspn config
    parser.add_argument('--decoder_encoder', default='fs', choices=['fs','rnfs'])
    parser.add_argument('--decoder_lr', type=float, default=1.0)
    parser.add_argument('--decoder_iters', type=int, default=20)
    parser.add_argument('--decoder_momentum', type=float, default=0.9)
    parser.add_argument('--decoder_val_iters', type=int, default=None)
    parser.add_argument('--decoder_grad_clip', type=float)
    parser.add_argument('--decoder_it_schedule', action='store_true')
    parser.add_argument('--decoder_starting_set', action='store_true')
    # wandb config
    parser.add_argument('--no_wandb', dest='use_wandb', action='store_false')
    # eval config
    parser.add_argument('--progress_num_examples', type=int, default=0)
    parser.add_argument('--progress_path', default='progress')
    parser.add_argument('--eval_checkpoint', default=None)
    parser.add_argument('--test_after_training', action='store_true')

    args = parser.parse_args(argv)

    # Validate with parser.error rather than assert: asserts disappear under
    # `python -O`, and parser.error gives the user a proper usage message.
    if args.dataset == 'random':
        if args.set_size <= 0:
            parser.error('--set_size must be positive for the random dataset')
        if args.set_dim <= 0:
            parser.error('--set_dim must be positive for the random dataset')
    elif args.dataset == 'clevr':
        # Fixed set shape for this dataset; user-supplied values are ignored.
        args.set_size = 10
        args.set_dim = 19
    elif args.dataset == 'mnist':
        # Fixed set shape for this dataset; user-supplied values are ignored.
        args.set_size = 342
        args.set_dim = 3
    elif args.dataset == 'numbering':
        if args.set_size != args.set_dim:
            parser.error('--set_size and --set_dim must be equal for the numbering dataset')

    # Derive a default project name from the dataset configuration.
    if args.project is None:
        if args.dataset == 'random':
            args.project = f'random-dim{args.set_dim}-size{args.set_size}'
        elif args.dataset == 'clevr':
            args.project = 'clevr-' + ('images' if args.clevr_image_input else 'autoencode')
        elif args.dataset == 'mnist':
            args.project = 'mnist'
        elif args.dataset == 'numbering':
            args.project = f'numbering_results_ds{args.dataset_size}_best'
    return args


class SetPredictionModel(pl.LightningModule):
    """LightningModule wrapping a set-prediction network, its datasets and loss.

    The network is selected by ``args.model``, the datasets by ``args.dataset``
    and the training objective by ``args.loss``. Batches are expected to be
    ``(input, target)`` pairs, and the wrapped network is expected to return
    ``(output, set_grad)`` where ``set_grad`` may be ``None``.
    """

    def __init__(self, args):
        """Build the network and the train/val/test datasets from ``args``.

        Args:
            args: Namespace as produced by ``cmdline_args``. It is stored both
                as ``self.args`` and as Lightning hyperparameters.
        """
        super().__init__()
        self.args = args
        self.save_hyperparameters(args)
        if 'dspn' in args.model:
            # Covers both 'dspn' and 'idspn'; the latter toggles `implicit`.
            self.net = models.DSPNModel(
                input_dim=self.hparams.input_dim,
                d_in=self.hparams.set_dim,
                d_hid=self.hparams.hidden_dim,
                d_latent=self.hparams.latent_dim,
                set_size=self.hparams.set_size,

                input_encoder=self.hparams.input_encoder,
                decoder_encoder=self.hparams.decoder_encoder,
                lr=args.decoder_lr,
                iters=args.decoder_iters,
                momentum=args.decoder_momentum,
                grad_clip=args.decoder_grad_clip,
                use_starting_set=args.decoder_starting_set,

                image_input=self.hparams.clevr_image_input,
                image_size=self.hparams.clevr_image_size,
                implicit=self.hparams.model == 'idspn',
            )
        elif args.model == 'slot':
            self.net = models.SlotAttentionModel(
                d_in=self.hparams.input_dim,
                d_hid=self.hparams.hidden_dim,
                d_out=self.hparams.set_dim,
                set_size=args.set_size,
            )
        elif args.model == 'deepsets':
            self.net = models.DSModel(self.hparams.input_dim, self.hparams.hidden_dim, self.hparams.set_dim)
        elif args.model == 'lstm':
            self.net = models.LSTMModel(self.hparams.input_dim, self.hparams.hidden_dim, self.hparams.set_dim)
        elif args.model == 'transformer':
            self.net = models.TransformerModel(self.hparams.input_dim, self.hparams.hidden_dim, self.hparams.set_dim, args.set_size)
        elif args.model == 'random_transformer':
            self.net = models.RandomTransformerModel(self.hparams.input_dim, self.hparams.hidden_dim, self.hparams.set_dim, args.set_size)

        # Only the 'numbering' dataset has a dedicated test split. The other
        # datasets reuse their validation split so that test_dataloader() does
        # not fail with AttributeError; this mirrors the module-level test()
        # helper, which also evaluates on the validation loader.
        if args.dataset == 'random':
            self.trainset = data.Objects(size=args.dataset_size, cardinality=args.set_size, dim=args.set_dim)
            self.valset = data.Objects(size=args.dataset_size // 10, cardinality=args.set_size, dim=args.set_dim)
            self.testset = self.valset
        elif args.dataset.startswith('clevr'):
            self.trainset = data.CLEVR(args.clevr_path, 'train', image_input=args.clevr_image_input, image_size=self.args.clevr_image_size)
            self.valset = data.CLEVR(args.clevr_path, 'val', image_input=args.clevr_image_input, image_size=self.args.clevr_image_size)
            self.testset = self.valset
        elif args.dataset == 'mnist':
            self.trainset = data.MNISTSetMasked(train=True)
            self.valset = data.MNISTSetMasked(train=False)
            self.testset = self.valset
        elif args.dataset == 'numbering':
            self.trainset = data.NumberInput(n_samples=args.dataset_size, set_size=args.set_size, set_dim=args.input_dim, n_obj_per_sample=args.n_obj_per_sample)
            self.valset = data.NumberInput(n_samples=6400, set_size=args.set_size, set_dim=args.input_dim, n_obj_per_sample=args.n_obj_per_sample)
            self.testset = data.NumberInput(n_samples=64000, set_size=args.set_size, set_dim=args.input_dim, n_obj_per_sample=args.n_obj_per_sample)
        # Prefix prepended to the average-precision metric names in eval_step.
        self.ap_prefix = ''

    def forward(self, x):
        """Run the network on the input half of an ``(input, target)`` batch.

        Returns:
            Tuple ``(output, gt_output, set_grad)``: the prediction, the
            untouched target from the batch, and whatever the network returned
            as its second value (may be ``None``).
        """
        inputs, gt_output = x
        output, set_grad = self.net(inputs)
        return output, gt_output, set_grad

    def _loss_and_metrics(self, batch):
        """Forward ``batch`` and compute the configured loss plus metrics.

        Shared by the training and evaluation steps so both report the same
        quantities.

        Returns:
            Tuple ``(output, gt_output, metrics)``. ``metrics`` always holds
            ``loss`` and ``grad_norm``; ``micro_acc`` and ``macro_acc`` are
            included only for the hungarian losses, because they need the
            matching indices those losses return.
        """
        output, gt_output, set_grad = self(batch)

        accuracies = {}
        if 'hungarian' in self.args.loss:
            loss, indices = losses.hungarian_loss_numbering(
                batch[0], output, gt_output,
                num_workers=self.args.num_ray_workers,
                ret_indices=True,
                # 'hungarian_l2' -> 'l2', 'hungarian_ce' -> 'ce', ...
                loss_type=self.args.loss.split('_')[-1])
            loss = loss.mean(0)
            accuracies['micro_acc'] = losses.hungarian_micro_accuracy(output, gt_output, indices).mean(0)
            accuracies['macro_acc'] = losses.hungarian_macro_accuracy(output, gt_output, indices).mean(0)
        elif self.args.loss == 'chamfer':
            loss = losses.chamfer_loss(output, gt_output).mean(0)
        elif self.args.loss == 'mse':
            loss = torch.nn.functional.mse_loss(output, gt_output)
        else:
            # 'ce': class indices are needed only for the loss itself. The
            # original target is kept for plotting/AP, as in the other branches.
            # NOTE(review): assumes the class axis of `output` is where
            # cross_entropy expects it (dim 1) -- confirm against the models.
            loss = torch.nn.functional.cross_entropy(output, gt_output.argmax(-1))
        grad_norm = set_grad.norm(dim=[1, 2]).mean() if set_grad is not None else 0

        metrics = dict(loss=loss, grad_norm=grad_norm, **accuracies)
        return output, gt_output, metrics

    def training_step(self, batch, batch_nb):
        """Compute the training loss for one batch and log ``<metric>/train``."""
        _, _, metrics = self._loss_and_metrics(batch)
        self.log_dict({f"{name}/train": value for name, value in metrics.items()})
        return metrics['loss']

    def eval_step(self, batch, batch_idx, suffix):
        """Shared validation/test step.

        Logs every metric as ``<metric><suffix>``, optionally saves a progress
        plot for the first batch, and for the ``clevr`` dataset additionally
        logs average precision at several thresholds.

        Args:
            batch: ``(input, target)`` pair.
            batch_idx: Index of the batch within the evaluation epoch.
            suffix: String appended to metric names (``"/val"`` or ``"/test"``).

        Returns:
            The loss for this batch.
        """
        output, gt_output, metrics = self._loss_and_metrics(batch)

        if batch_idx == 0 and self.args.progress_num_examples > 0:
            path = os.path.join(self.args.progress_path, self.args.project, self.args.name, f"{self.global_step}.png")
            self.plot_pointset(output, gt_output, Path(path), n_examples=self.args.progress_num_examples)

        self.log_dict({name + suffix: value for name, value in metrics.items()})

        if self.args.dataset == 'clevr':
            thresholds = [float('inf'), 1, 0.5, 0.25, 0.125, 0.0625]
            aps = ap.compute_ap(gt_output, output, thresholds)
            # Loop variable deliberately not called `ap`, to avoid shadowing
            # the imported module.
            self.log_dict({
                f'{self.ap_prefix}ap/{threshold}': ap_value
                for threshold, ap_value in zip(thresholds, aps)
            })

        return metrics['loss']

    def validation_step(self, batch, batch_idx):
        """Evaluate one validation batch; metrics are logged with ``/val``."""
        return self.eval_step(batch, batch_idx, "/val")

    def test_step(self, batch, batch_idx):
        """Evaluate one test batch; metrics are logged with ``/test``."""
        return self.eval_step(batch, batch_idx, "/test")

    def configure_optimizers(self):
        """Create the Adam optimizer and, if requested, a StepLR schedule.

        When ``args.lr_drop_epoch`` is set, a ``StepLR`` scheduler with that
        step size (and its default decay factor) is attached.
        """
        opt = torch.optim.Adam(self.parameters(), lr=self.args.lr)
        if self.args.lr_drop_epoch is not None:
            scheduler = {
                'scheduler': torch.optim.lr_scheduler.StepLR(opt, step_size=self.args.lr_drop_epoch)
            }
            return [opt], [scheduler]
        return opt

    def train_dataloader(self):
        """Return a shuffled loader over the training set."""
        return DataLoader(
            self.trainset,
            batch_size=self.args.batch_size,
            shuffle=True,
            num_workers=self.args.num_data_workers,
        )

    def val_dataloader(self):
        """Return an unshuffled loader over the validation set."""
        return DataLoader(
            self.valset,
            batch_size=self.args.batch_size,
            shuffle=False,
            num_workers=self.args.num_data_workers,
        )

    def test_dataloader(self):
        """Return an unshuffled loader over the test set.

        For datasets without a dedicated test split this is the validation set
        (see ``__init__``).
        """
        return DataLoader(
            self.testset,
            batch_size=self.args.batch_size,
            shuffle=False,
            num_workers=self.args.num_data_workers,
        )

    # NOTE: the iteration-schedule hooks below are disabled, so within this
    # class --decoder_it_schedule and --decoder_val_iters currently have no
    # effect.
    # def on_train_epoch_start(self) -> None:
    #     if self.args.decoder_it_schedule:
    #         it = int(0.5 * self.args.decoder_iters)
    #         if self.current_epoch >= 0.5 * self.args.epochs:
    #             it = self.args.decoder_iters
    #         self.net.dspn.iters = it
    #     else:
    #         self.net.dspn.iters = self.args.decoder_iters

    # def on_val_epoch_start(self) -> None:
    #     self.net.dspn.iters = self.args.decoder_val_iters or self.args.decoder_iters

    def plot_pointset(self, pred, target, filename, n_examples):
        """Save a grid of scatter plots comparing predicted and target sets.

        Targets are drawn with ``o`` markers and predictions with ``x``. The
        first two feature channels are used as coordinates; for ``mnist`` the
        channels are swapped and the vertical axis is flipped.

        Args:
            pred: Predicted sets; assumed to be ``(batch, n_points, dim)`` since
                dims 1 and 2 are transposed before plotting -- TODO confirm.
            target: Target sets with the same layout as ``pred``.
            filename: ``pathlib.Path`` of the image to write; parent
                directories are created as needed.
            n_examples: Number of batch elements to plot. Clamped to the batch
                size so a small (e.g. last) batch cannot cause an IndexError.
        """
        n_examples = min(n_examples, pred.shape[0], target.shape[0])
        if n_examples <= 0:
            return

        n_rows = n_cols = math.ceil(n_examples ** 0.5)
        fig, axs = plt.subplots(n_rows, n_cols, squeeze=False, figsize=(15,15))
        try:
            # detach() so tensors that still track gradients can be plotted.
            pred = pred.detach().cpu().transpose(1, 2)
            target = target.detach().cpu().transpose(1, 2)

            is_mnist = self.args.dataset == 'mnist'
            lim = (0, 1) if is_mnist else (-3, 3)

            for i in range(n_examples):
                ax = axs[i // n_cols, i % n_cols]
                if is_mnist:
                    ax.scatter(target[i, 1], 1 - target[i, 0], marker='o', s=5**2)
                    ax.scatter(pred[i, 1], 1 - pred[i, 0], marker='x', s=5**2)
                else:
                    ax.scatter(target[i, 0], target[i, 1], marker='o', s=5**2)
                    ax.scatter(pred[i, 0], pred[i, 1], marker='x', s=5**2)
                ax.axis("equal")
                ax.set_xlim(*lim)
                ax.set_ylim(*lim)

            fig.tight_layout()
            filename.parent.mkdir(parents=True, exist_ok=True)
            fig.savefig(filename)
        finally:
            # pyplot keeps every figure alive until it is closed; without this
            # one figure would leak per validation run.
            plt.close(fig)


def train(args):
    """Train a ``SetPredictionModel`` and test the best checkpoint.

    Optionally starts ray (when ``args.num_ray_workers > 0``) and a wandb run
    (when ``args.use_wandb`` is true), fits the model with a checkpoint
    callback, and finally calls ``trainer.test()``.

    Args:
        args: Namespace as produced by ``cmdline_args``.

    Returns:
        The trained model instance.
    """
    model = SetPredictionModel(args)

    if args.num_ray_workers > 0:
        ray.init(num_cpus=args.num_ray_workers, include_dashboard=False)

    # Stays None when wandb logging is disabled.
    logger = None
    if args.use_wandb:
        run = wandb.init(
            name=args.name,
            project=args.project,
            reinit=False,
            settings=wandb.Settings(start_method="fork"),
        )
        run.define_metric("macro_acc/val", summary="max")
        run.define_metric("micro_acc/val", summary="max")
        logger = WandbLogger(log_model=True)
        logger.watch(model.net)
        wandb.config.update(args)

    # The accuracy metrics are only computed for the hungarian losses, so the
    # checkpoint callback can only track "macro_acc/val" in that case. For the
    # other losses fall back to the validation loss, which is always logged.
    if 'hungarian' in args.loss:
        monitor, mode = "macro_acc/val", "max"
    else:
        monitor, mode = "loss/val", "min"

    checkpoint_path = os.path.join(args.checkpoint_path, args.project, args.name)
    trainer = pl.Trainer(
        max_epochs=args.epochs,
        # Validate on a tenth of the batches for clevr, on everything otherwise.
        limit_val_batches=0.1 if args.dataset == 'clevr' else 1.0,
        gpus=1,
        num_nodes=1,
        logger=logger,
        callbacks=[
            ModelCheckpoint(dirpath=checkpoint_path, monitor=monitor, mode=mode),
            # EarlyStopping(monitor="loss/val", patience=20, mode="min")
        ],
        check_val_every_n_epoch=args.check_val_every_n_epoch,
    )

    trainer.fit(model)
    trainer.test()  # test best model
    return model


def test(args, model=None, trainer=None):
    """Run ``trainer.test`` for a model on its validation dataloader.

    Args:
        args: Parsed command-line namespace; ``args.eval_checkpoint`` is read
            when no model is supplied.
        model: Model to evaluate. When ``None``, one is restored from
            ``args.eval_checkpoint``.
        trainer: Trainer to use. When ``None``, a single-GPU, single-node
            trainer is created.
    """
    if model is None:
        # No model handed in: restore it from the checkpoint given on the CLI.
        model = SetPredictionModel.load_from_checkpoint(
            checkpoint_path=args.eval_checkpoint,
            args=args,
        )
    trainer = pl.Trainer(gpus=1, num_nodes=1) if trainer is None else trainer

    trainer.limit_val_batches = 1.0
    eval_loader = model.val_dataloader()
    trainer.test(model, eval_loader)


def main():
    """Script entry point: parse options, seed the RNGs, then train or evaluate.

    Evaluation is selected by passing ``--eval_checkpoint``; otherwise a
    training run is started.
    """
    args = cmdline_args()
    pl.seed_everything(args.seed)

    if args.eval_checkpoint is not None:
        # A checkpoint was given: evaluate it instead of training.
        test(args)
        return
    train(args)


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
