import torch
import time
import numpy as np
import argparse
import warnings
from eval_model import eval_model
from data.dataloader import get_dataloader
from utils.setup_utils import set_seed, setup_wandb, get_model, LrScheduler, setup_direcs
import wandb
import gc
from eval.eval_metrics import get_recon_loss
# Silence NumPy's VisibleDeprecationWarning. NumPy 2.0 removed the class from
# the top-level namespace (it lives in ``numpy.exceptions``, available since
# NumPy 1.25), so resolve it from there when that module exists and fall back
# to the top-level name on older releases.
_np_exceptions = getattr(np, "exceptions", np)
warnings.filterwarnings("ignore", category=_np_exceptions.VisibleDeprecationWarning)

# Device that training batches are moved to: the GPU when CUDA is available,
# otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _endless_batches(loader):
    """Yield batches from *loader* forever, starting a new pass when one ends."""
    while True:
        produced = False
        for batch in loader:
            produced = True
            yield batch
        if not produced:
            # Without this guard an empty loader would spin forever.
            raise ValueError("train_loader produced no batches")


def train_model(args):
    """Train the model described by *args*, checkpointing and evaluating periodically.

    The function seeds the RNGs, builds the data loader and model through the
    project helpers, and then either runs a single test evaluation
    (``args.test == 1``) or optimises the model with Adam until
    ``args.num_iters`` global iterations have been reached.

    The per-step loss is ``args.sigma * reconstruction`` plus, when enabled,
    ``args.beta * model.dkl`` and an ``alpha``-weighted interaction term taken
    from the transformer decoder.

    Args:
        args: Parsed command-line namespace. Attributes read directly here:
            ``seed``, ``local``, ``lr``, ``warmup_lr_steps``, ``decay_rate``,
            ``decay_steps``, ``alpha``, ``warmup_alpha_steps``,
            ``decay_steps_reg``, ``load_seed``, ``load_iter``, ``test``,
            ``num_iters``, ``sigma``, ``beta``, ``decoder`` and ``eval_iter``.
            The helper functions receive the whole namespace and may read more.

    Returns:
        None. Model and optimizer state dicts are written under the prefix
        returned by ``setup_direcs`` every ``args.eval_iter`` iterations and at
        the final iteration.
    """
    # fix random seed
    seed = set_seed(args.seed)
    print("Seed: ", seed)

    # get directory for model checkpoints
    model_dir = setup_direcs(args, seed)

    # setup weights and biases (skipped for local runs)
    if args.local == 0:
        setup_wandb(args, seed)

    # load data (only the first of the three returned loaders is used here)
    train_loader, _, _ = get_dataloader(args)

    # create model
    model = get_model(args)

    print("Num Parameters: ", sum(p.numel() for p in model.parameters() if p.requires_grad))

    # Learning-rate schedule: scheduler 1 covers the warm-up phase, scheduler 2
    # takes over after `late_lr_start` iterations with its own base rate.
    # remember to update when loading models
    late_lr = 1e-4
    late_lr_start = 30000
    lr_scheduler_1 = LrScheduler(args.lr, args.warmup_lr_steps, args.decay_rate, args.decay_steps)
    lr_scheduler_2 = LrScheduler(late_lr, late_lr_start, args.decay_rate, args.decay_steps)
    alpha_scheduler = LrScheduler(args.alpha, args.warmup_alpha_steps, .2, args.decay_steps_reg)

    # start from the load iteration if we're resuming, otherwise from zero
    glob_it = args.load_iter if args.load_seed > 0 else 0

    # set optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=lr_scheduler_1.get_current_lr(glob_it))

    # NOTE(review): restoring the model/optimizer state dicts for
    # args.load_seed > 0 is currently disabled in this function. It used to read
    #   <load_dir>_iter_<glob_it>_model_state_dict.pth
    #   <load_dir>_iter_<glob_it>_optimizer_state_dict.pth
    # with load_dir = setup_direcs(args, args.load_seed, load=True).
    # As written, glob_it starts at args.load_iter but no checkpoint is loaded
    # here -- confirm whether this is intended before resuming runs.

    if args.test == 1:
        torch.cuda.empty_cache()
        gc.collect()
        eval_model(args, model, validation=False)
        wandb.finish()
        return

    # train loop
    model.train()

    # One persistent iterator over the loader. Calling next(iter(train_loader))
    # every step would rebuild the iterator each time and only ever return the
    # first batch of a fresh pass.
    batches = _endless_batches(train_loader)

    while glob_it < args.num_iters:
        glob_it += 1
        optimizer.zero_grad()

        # get data (second element of each batch is not used for training)
        x, _ = next(batches)
        x = x.to(device)

        # model returns reconstructed latents and observations; only the
        # reconstructed observations enter the loss
        _, xh = model(x)

        # reconstruction loss
        loss = args.sigma * get_recon_loss(x, xh)

        # kl loss
        if args.beta > 0:
            loss += args.beta * model.dkl

        # interaction loss
        if args.decoder == "transformer" and args.alpha > 0:
            _, interac_dec_cross_att = model.decoder.transformer.compute_interaction()
            # ramp alpha during its warm-up, then hold it at the target value
            if glob_it < args.warmup_alpha_steps:
                alpha = alpha_scheduler.get_current_lr(glob_it)
            else:
                alpha = args.alpha
            loss += alpha * interac_dec_cross_att

        loss.backward()
        optimizer.step()

        # Update learning rate for the next step
        if glob_it < args.warmup_lr_steps:
            new_lr = lr_scheduler_1.get_current_lr(glob_it)
        elif glob_it > late_lr_start:
            new_lr = lr_scheduler_2.get_current_lr(glob_it)
        else:
            new_lr = args.lr

        for param_group in optimizer.param_groups:
            param_group['lr'] = new_lr

        # checkpoint + eval every eval_iter iterations and at the very end
        if glob_it % args.eval_iter == 0 or glob_it == args.num_iters:
            torch.cuda.empty_cache()
            gc.collect()
            checkpoint_prefix = model_dir + "_iter_" + str(glob_it)
            torch.save(model.state_dict(),
                       checkpoint_prefix + "_model_state_dict.pth")
            torch.save(optimizer.state_dict(),
                       checkpoint_prefix + "_optimizer_state_dict.pth")
            # run eval
            eval_model(args, model)
            torch.cuda.empty_cache()
            gc.collect()
            # switch back to training mode after evaluation
            model.train()

    if args.local == 0:
        wandb.finish()


if __name__ == "__main__":
    # Command-line entry point: parse hyper-parameters, apply the dataset
    # preset (if any) and start training.
    # Defaults are written as native values of each option's declared type
    # rather than strings; argparse would convert string defaults through
    # `type`, so the parsed values are the same.
    parser = argparse.ArgumentParser()

    # data / architecture
    # NOTE: --data, --num_slots, --slot_dim, --batch_size and --lr have no
    # default and are None unless given on the command line or set by a preset.
    parser.add_argument("--data", type=str)
    parser.add_argument("--encoder", type=str, default="transformer")
    parser.add_argument("--decoder", type=str, default="transformer")
    parser.add_argument("--num_slots", type=int)
    parser.add_argument("--slot_dim", type=int)
    parser.add_argument("--proj_dim", type=int, default=516)
    parser.add_argument("--query_dim", type=int, default=180)
    parser.add_argument("--proj_layers", type=int, default=2)

    # loss weights
    parser.add_argument("--sigma", type=float, default=1.0)
    parser.add_argument("--beta", type=float, default=0.0)
    parser.add_argument("--alpha", type=float, default=0.0)

    # optimisation
    parser.add_argument("--batch_size", type=int)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--lr", type=float)
    parser.add_argument("--warmup_lr_steps", type=int, default=10000)
    parser.add_argument("--warmup_alpha_steps", type=int, default=0)
    parser.add_argument("--warmup_beta_steps", type=int, default=0)
    parser.add_argument("--decay_rate", type=float, default=0.1)
    # these two step counts are declared as floats, so keep float defaults
    parser.add_argument("--decay_steps", type=float, default=500000.0)
    parser.add_argument("--decay_steps_reg", type=float, default=250000.0)
    parser.add_argument("--num_iters", type=int, default=500000)
    parser.add_argument("--eval_iter", type=int, default=10000)

    # run mode / checkpoint selection
    parser.add_argument("--local", type=int, default=0)
    parser.add_argument("--load_seed", type=int, default=-1)
    parser.add_argument("--load_iter", type=int, default=0)
    parser.add_argument("--test", type=int, default=0)
    args = parser.parse_args()

    # Dataset preset. These assignments overwrite whatever was passed on the
    # command line for the same options.
    if args.data == "spriteworld_4":
        args.encoder = "transformer"
        args.decoder = "transformer"
        args.num_slots = 5
        args.slot_dim = 32
        args.batch_size = 64
        args.lr = 5e-4
        args.sigma = 5.
        args.alpha = 0.5
        args.beta = 0.5

    train_model(args)
