import argparse
import json
import numpy as np
import torch
import os
import wandb
from time import time
from utils.helper import EarlyStopping

# from tensorboardX import SummaryWriter
from time import sleep
from torch import optim
from torch.nn import functional as F
from torch.nn.utils import clip_grad_norm_
from torch.utils import data
from tqdm import tqdm

import data as data_
import nflows.utils.torchutils as utils


import utils as cutils
from nflows import transforms, distributions, flows
import nflows.nn.nets.resnet as nn_


# Maps the user-facing ``cfg.model.model_class`` names to the
# ``base_transform_type`` keys understood by ``_create_base_transform``.
_MODEL_CLASS_TO_TRANSFORM = {
    "rq-nsf": "rq-autoregressive",
    "maf": "affine-autoregressive",
}


def _residual_net_factory(args):
    """Return a ``transform_net_create_fn`` that builds an nflows ResidualNet.

    The coupling transforms call the returned function as
    ``fn(in_features, out_features)``; every other ResidualNet setting is taken
    from the run configuration ``args``.
    """

    def create_net(in_features, out_features):
        return nn_.ResidualNet(
            in_features=in_features,
            out_features=out_features,
            hidden_features=args.hidden_features,
            context_features=None,
            num_blocks=args.num_transform_blocks,
            activation=F.relu,
            dropout_probability=args.dropout_probability,
            use_batch_norm=args.use_batch_norm,
        )

    return create_net


def _create_linear_transform(args, features):
    """Build the linear mixing layer selected by ``args.linear_transform_type``.

    Raises:
        ValueError: if the linear transform type is not recognised.
    """
    if args.linear_transform_type == "permutation":
        return transforms.RandomPermutation(features=features)
    if args.linear_transform_type == "lu":
        return transforms.CompositeTransform(
            [
                transforms.RandomPermutation(features=features),
                transforms.LULinear(features, identity_init=True),
            ]
        )
    if args.linear_transform_type == "svd":
        return transforms.CompositeTransform(
            [
                transforms.RandomPermutation(features=features),
                transforms.SVDLinear(features, num_householder=10, identity_init=True),
            ]
        )
    raise ValueError(
        "unknown linear_transform_type: {!r}".format(args.linear_transform_type)
    )


def _create_base_transform(args, features, i):
    """Build the ``i``-th base transform selected by ``args.base_transform_type``.

    ``i`` is only used by the coupling variants, to alternate which half of the
    features is masked from one flow step to the next.

    Raises:
        ValueError: if the base transform type is not recognised.
    """
    kind = args.base_transform_type
    if kind in ("affine-coupling", "quadratic-coupling", "rq-coupling"):
        mask = utils.create_alternating_binary_mask(features, even=(i % 2 == 0))
        net_fn = _residual_net_factory(args)
        if kind == "affine-coupling":
            return transforms.AffineCouplingTransform(
                mask=mask, transform_net_create_fn=net_fn
            )
        spline_cls = (
            transforms.PiecewiseQuadraticCouplingTransform
            if kind == "quadratic-coupling"
            else transforms.PiecewiseRationalQuadraticCouplingTransform
        )
        return spline_cls(
            mask=mask,
            transform_net_create_fn=net_fn,
            num_bins=args.num_bins,
            tails="linear",
            tail_bound=args.tail_bound,
            apply_unconditional_transform=args.apply_unconditional_transform,
        )
    if kind == "affine-autoregressive":
        return transforms.MaskedAffineAutoregressiveTransform(
            features=features,
            hidden_features=args.hidden_features,
            context_features=None,
            num_blocks=args.num_transform_blocks,
            use_residual_blocks=True,
            random_mask=False,
            activation=F.relu,
            dropout_probability=args.dropout_probability,
            use_batch_norm=args.use_batch_norm,
        )
    if kind in ("quadratic-autoregressive", "rq-autoregressive"):
        spline_cls = (
            transforms.MaskedPiecewiseQuadraticAutoregressiveTransform
            if kind == "quadratic-autoregressive"
            else transforms.MaskedPiecewiseRationalQuadraticAutoregressiveTransform
        )
        return spline_cls(
            features=features,
            hidden_features=args.hidden_features,
            context_features=None,
            num_bins=args.num_bins,
            tails="linear",
            tail_bound=args.tail_bound,
            num_blocks=args.num_transform_blocks,
            use_residual_blocks=True,
            random_mask=False,
            activation=F.relu,
            dropout_probability=args.dropout_probability,
            use_batch_norm=args.use_batch_norm,
        )
    raise ValueError("unknown base_transform_type: {!r}".format(kind))


def _create_transform(args, features):
    """Stack ``args.num_flow_steps`` (linear, base) pairs plus a final linear layer.

    The linear layer of each step is built before its base transform, so
    random initialisation happens in a fixed order for a given seed.
    """
    steps = []
    for i in range(args.num_flow_steps):
        linear = _create_linear_transform(args, features)
        base = _create_base_transform(args, features, i)
        steps.append(transforms.CompositeTransform([linear, base]))
    steps.append(_create_linear_transform(args, features))
    return transforms.CompositeTransform(steps)


def _infinite_batches(loader):
    """Yield batches from ``loader`` forever, starting a new pass when exhausted.

    Unlike calling ``next(iter(loader))`` every step, this builds one iterator
    (and, for a shuffling loader, one permutation) per pass over the data
    instead of one per step.

    Raises:
        ValueError: if a full pass over ``loader`` yields no batches; without
            this check the generator would loop forever.
    """
    while True:
        produced = False
        for batch in loader:
            produced = True
            yield batch
        if not produced:
            raise ValueError("training loader produced no batches")


def main(cfg, y, y_val, y_test):
    """Train an nflows flow on ``y``, early-stopping on the validation NLL of ``y_val``.

    Args:
        cfg: Run configuration. Reads ``cfg.model.model_class`` (a key of
            ``_MODEL_CLASS_TO_TRANSFORM``: ``"rq-nsf"`` or ``"maf"``),
            ``cfg.model.maxiter``, ``train_batch_size``, ``learning_rate``,
            ``num_flow_steps``, ``hidden_features``, ``num_bins``,
            ``num_transform_blocks``, ``dropout_probability``,
            ``monitor_interval``, ``monitor_interval2`` and ``cfg.base.seed``.
        y: Training data. ``y.shape[1]`` is used as the number of features
            (presumably an ``(n_samples, n_features)`` array; confirm with callers).
        y_val: Validation data with the same number of features as ``y``.
        y_test: Currently unused (test-set evaluation is disabled); kept so the
            call signature stays unchanged.

    Returns:
        The trained ``flows.Flow`` in eval mode. If a best-validation
        checkpoint was saved during training, its weights are loaded.

    Side effects:
        Requires an active ``wandb`` run whose ``config.model`` provides
        ``"patience"`` and ``"miniter"``; logs ``fit_time`` to it. Writes
        ``config.json`` to ``<log_root>/<dataset_name>/<timestamp>/`` and the
        best-validation checkpoint to the checkpoint root. ``dataset_name`` is
        hard-coded to ``"miniboone"`` and only names these paths.
    """
    args = cutils.ObjectView(
        {
            "dataset_name": "miniboone",
            "val_frac": 1.0,
            "val_batch_size": 2048,
            "num_training_steps": cfg.model.maxiter * 100,
            "anneal_learning_rate": 1,
            "grad_norm_clip_value": 5.0,
            "base_transform_type": _MODEL_CLASS_TO_TRANSFORM[cfg.model.model_class],
            "linear_transform_type": "lu",
            "apply_unconditional_transform": 1,
            "use_batch_norm": 0,
            "train_batch_size": cfg.model.train_batch_size,
            "learning_rate": cfg.model.learning_rate,
            "num_flow_steps": cfg.model.num_flow_steps,
            "hidden_features": cfg.model.hidden_features,
            "num_bins": cfg.model.num_bins,
            "num_transform_blocks": cfg.model.num_transform_blocks,
            "dropout_probability": cfg.model.dropout_probability,
            "monitor_interval": cfg.model.monitor_interval,
            "monitor_interval2": cfg.model.monitor_interval2,
            "tail_bound": 3,
            "seed": cfg.base.seed,
        }
    )

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # --- data ---------------------------------------------------------------
    train_loader = data.DataLoader(
        data.TensorDataset(torch.tensor(y)),
        batch_size=args.train_batch_size,
        shuffle=True,
        drop_last=False,
    )
    # The validation score is a sum over all samples, so order is irrelevant.
    # Not shuffling avoids a pointless permutation on every validation pass.
    val_loader = data.DataLoader(
        dataset=data.TensorDataset(torch.tensor(y_val)),
        batch_size=args.val_batch_size,
        shuffle=False,
        drop_last=False,
    )

    # --- model --------------------------------------------------------------
    features = y.shape[1]
    distribution = distributions.StandardNormal((features,))
    transform = _create_transform(args, features)
    flow = flows.Flow(transform, distribution).to(device)

    n_params = utils.get_num_parameters(flow)
    print("There are {} trainable parameters in this model.".format(n_params))

    # --- optimisation -------------------------------------------------------
    optimizer = optim.Adam(flow.parameters(), lr=args.learning_rate)
    if args.anneal_learning_rate:
        scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer, args.num_training_steps, 0
        )
    else:
        scheduler = None

    # --- logging / checkpoint locations -------------------------------------
    timestamp = cutils.get_timestamp()
    log_dir = os.path.join(cutils.get_log_root(), args.dataset_name, timestamp)
    os.makedirs(log_dir, exist_ok=True)  # also creates the parent dirs
    with open(os.path.join(log_dir, "config.json"), "w") as file:
        json.dump(vars(args), file)

    # Checkpoints are written inside the training loop, so the directory must
    # exist before training starts, not only afterwards.
    checkpoint_root = cutils.get_checkpoint_root()
    os.makedirs(checkpoint_root, exist_ok=True)
    best_path = os.path.join(
        checkpoint_root, "{}-best-val-{}.t".format(args.dataset_name, timestamp)
    )

    def early_validate(
        rho_lengths_opt,
        y_test,
        vn_perm=None,
        y_perm=None,
        bern=None,
        d_perm_inds=None,
        n_perm_inds=None,
        helper=None,
    ):
        """Return the mean negative log-likelihood of the model over a loader.

        ``rho_lengths_opt`` is the flow being evaluated and ``y_test`` is the
        loader to score (here: ``val_loader``). The remaining keyword
        parameters are unused; presumably they exist to match the callback
        signature that ``EarlyStopping`` expects (verify in utils.helper).
        """
        running_val_log_density = 0
        for val_batch in y_test:
            log_density_val = rho_lengths_opt.log_prob(
                val_batch[0].float().to(device).detach()
            )
            running_val_log_density += torch.sum(log_density_val).detach()

        return -running_val_log_density / len(y_test.dataset)

    # As used below: ``callback(flow)`` returns truthy when training should
    # stop, ``best_params`` holds the model to keep, and ``last_score`` is
    # presumably the most recent ``early_validate`` value (lower is better).
    early_stopping_wrapper = EarlyStopping(
        early_validate,
        patience=wandb.config.model["patience"],
        y_val=val_loader,
        miniter=wandb.config.model["miniter"],
    )

    # --- training loop ------------------------------------------------------
    start = time()
    best_val_score = -float("inf")
    train_batches = _infinite_batches(train_loader)
    for step in tqdm(range(args.num_training_steps)):
        batch = next(train_batches)[0].to(device).float()

        flow.train()
        optimizer.zero_grad()
        loss = -torch.mean(flow.log_prob(batch))
        loss.backward()
        if args.grad_norm_clip_value is not None:
            clip_grad_norm_(flow.parameters(), args.grad_norm_clip_value)
        optimizer.step()
        # Step the scheduler after the optimizer (current PyTorch contract).
        # This yields the same cosine schedule as the deprecated step(epoch).
        if scheduler is not None:
            scheduler.step()

        if (step + 1) % args.monitor_interval == 0:
            flow.eval()
            with torch.no_grad():
                if early_stopping_wrapper.callback(flow):
                    flow = early_stopping_wrapper.best_params
                    break
                # Convert the validation NLL back to a log-likelihood (higher is better).
                val_score = -early_stopping_wrapper.last_score
            if val_score > best_val_score:
                best_val_score = val_score
                torch.save(flow.state_dict(), best_path)

    end = time()
    print("time for {} = {}".format(cfg.model.model_class, end - start))
    wandb.log({"fit_time": end - start})

    # --- reload the best-validation weights ---------------------------------
    # A missing checkpoint is expected (e.g. every validation score was NaN).
    # Any other failure, such as a state-dict mismatch, is a real error and
    # is allowed to propagate instead of being silently swallowed.
    try:
        state_dict = torch.load(best_path, map_location=device)
    except FileNotFoundError:
        print("no best model found")
    else:
        flow.load_state_dict(state_dict)
    flow.eval()

    return flow
