import os
import logging
from tqdm import tqdm
import seaborn as sns
import matplotlib.pyplot as plt
import torchvision
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from transformers.optimization import get_constant_schedule_with_warmup
from tensorboardX import SummaryWriter
import pickle

import selfies as sf

from models.nn import MLPResDual
from models.transformer import TransformerNet
from models.marnet_match import MarNetModel
from utils.ising_utils import prepare_ising_data, LatticeIsingModel
from utils.data_utils import load_dataset
from utils.mar_utils_mol import gen_order
from utils.mol_utils import multiple_indices_to_string, string_to_int
from utils.constants import BIT_UNKNOWN_VAL
from utils.var_utils import VarEstimator
from utils.eval_mol import MolEvalModel

class Runner(object):
    def __init__(self, cfg):
        """Set up data, score model, network, optimizer and scheduler from ``cfg``.

        Args:
            cfg: experiment configuration (attribute access). ``cfg.data`` selects
                'ising' or 'MOSES'; ``cfg.arch`` selects 'mlp_dual' or 'transformer'.

        Raises:
            ValueError: if ``cfg.data`` or ``cfg.arch`` is not a supported value.
        """
        self.cfg = cfg
        self.writer = SummaryWriter(os.getcwd())
        self.device_id = 'cuda:{}'.format(cfg.local_rank)
        self.master_node = (self.cfg.local_rank == 0)
        self.distributed = (self.cfg.world_size > 1)
        # Figures are created once and reused for density plots during evaluation.
        self.fig1 = plt.figure()
        self.fig2 = plt.figure()
        if cfg.alpha_annealing > 0:
            self.alpha_end = cfg.alpha
            self.alpha_init = cfg.alpha_init
        if self.cfg.data == 'ising':
            self.score_model = LatticeIsingModel(cfg.ising_model.dim, cfg.ising_model.sigma, cfg.ising_model.bias)
            if cfg.eval_reverse_kl:
                # Ground-truth samples scored by the reverse-KL / NLL evaluation.
                if cfg.ising_model.samples_path is not None:
                    # NOTE: pickle.load can execute arbitrary code; only point
                    # samples_path at trusted files.
                    with open(cfg.ising_model.samples_path, 'rb') as f:
                        self.ising_samples = pickle.load(f).to(self.device_id)
                else:
                    self.ising_samples = self.score_model.generate_samples(cfg.ising_model.n_samples, cfg.ising_model.gt_steps)
                    self.ising_samples = self.ising_samples.to(self.device_id)
                    with open("{}/data.pkl".format(self.cfg.log_dir), 'wb') as f:
                        pickle.dump(self.ising_samples, f)
            self.train_loader, self.val_loader, self.test_loader = prepare_ising_data(self.score_model, cfg, distributed=self.distributed)
        elif self.cfg.data == 'MOSES':
            self.train_loader, self.val_loader, self.test_loader = load_dataset(cfg, distributed=self.distributed)
            self.score_model = MolEvalModel(cfg.alphabet, cfg.string_type, cfg.metric_name, cfg.target_value, self.cfg.tau)
        else:
            # Fail early rather than with an AttributeError on self.score_model below.
            raise ValueError("Unknown dataset {}".format(cfg.data))
        self.epoch = 0

        if cfg.arch == 'mlp_dual':
            self.nn = MLPResDual(cfg.nn.hidden_dim, cfg.K, cfg.L, cfg.nn.n_layers, cfg.nn.res)
        elif cfg.arch == 'transformer':
            self.nn = TransformerNet(
                num_src_vocab=(cfg.K + 1), # add one for mask token
                num_tgt_vocab=cfg.K,
                embedding_dim=768,
                hidden_size=3072,
                nheads=12,
                n_layers=12,
                max_src_len=self.cfg.L,
                is_cls_token=self.cfg.nn.is_cls_token,
            )
        else:
            raise ValueError("Unknown model {}".format(cfg.arch))
        self.nn.to(self.device_id)
        logging.info(self.nn)

        # Values in [1, K]; 0 is reserved for the unknown/mask token.
        init_samples = torch.randint(low=1, high=cfg.K+1, size=(cfg.batch_size, cfg.L)).float()
        self.marnet = MarNetModel(self.nn, self.score_model, init_samples, cfg)
        self.marnet.to(self.device_id)

        if self.distributed:
            self.marnet = torch.nn.parallel.DistributedDataParallel(self.marnet, device_ids=[cfg.local_rank], output_device=cfg.local_rank)
            self.marnet_module = self.marnet.module
        else:
            self.marnet_module = self.marnet

        self.clip_grad = self.cfg.clip_grad

        # Go through marnet_module: a DistributedDataParallel wrapper does not
        # expose the wrapped model's `net` / `LogZ` attributes.
        param_list = [{'params': self.marnet_module.net.parameters(), 'lr': self.cfg.lr},
                      {'params': self.marnet_module.LogZ, 'lr': self.cfg.zlr}]

        if self.cfg.arch == "transformer":
            self.optimizer = optim.AdamW(param_list, betas=(0.9, 0.99), eps=1e-08, weight_decay=0.05)
            # NOTE(review): train() calls scheduler.step() once per epoch, so this
            # warmup spans 5000 epochs rather than 5000 optimizer steps — confirm.
            self.scheduler = get_constant_schedule_with_warmup(self.optimizer, num_warmup_steps=5000) # some cuda driver issues
        elif self.cfg.arch.startswith("mlp"):
            self.optimizer = optim.Adam(param_list)
            self.scheduler = StepLR(self.optimizer, step_size=cfg.lr_decay_step, gamma=0.1)
        else:
            raise ValueError("Unknown arch {}".format(self.cfg.arch))

        if self.cfg.load and (self.cfg.loadpath is not None):
            self.load(self.cfg.loadpath)

        self.save_every = self.cfg.save_every
        # Evaluate / checkpoint every this many epochs.
        self.eval_every = 5

    def load(self, path):
        """Restore the network weights stored in the checkpoint at ``path``.

        Tensors that were saved on ``cuda:0`` are remapped onto this rank's
        device. Only ``checkpoint['net']`` is restored (non-strictly); the
        optimizer state is left untouched.
        """
        # torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(path, map_location={"cuda:0": self.device_id})
        net_state = checkpoint['net']
        self.marnet_module.load_state_dict(net_state, strict=False)
        print("loaded", flush=True)
    
    def train(self):
        """Run the training loop from ``self.epoch`` until ``cfg.n_epochs``.

        Gradients are accumulated across mini-batches, and an optimizer step is
        taken once at least ``32 // world_size`` samples have been seen on this
        rank. About every ``5120 // world_size`` samples the master node logs the
        current mini-batch statistics. Every ``self.eval_every`` epochs, the
        master node does three things: it evaluates the KL estimate, checkpoints
        the best model (when ``cfg.save_model``) and logs epoch train/test
        metrics. All ranks run ``test()`` so that the distributed reduction lines
        up. The LR scheduler is stepped once per epoch.
        """
        print("training rank %u" % self.cfg.local_rank, flush=True)
        self.marnet.train()
        dataloader = self.train_loader

        it = 0
        best_kl_div = None
        while self.epoch < self.cfg.n_epochs:
            epoch_metrics = {
                'mb_loss': 0,
                'mb_loss_begin': 0,
                'mb_loss_end': 0,
                'count': 0,
            }

            bsz = 0    # samples accumulated since the last optimizer step
            accum = 0  # samples seen since the last periodic log line

            if self.cfg.objective == 'KL':
                rand_order = gen_order(self.cfg.batch_size, self.cfg.L, self.device_id, gen_order=self.cfg.gen_order)
                with torch.no_grad():
                    # DistributedDataParallel does not forward `sample`; call the wrapped module.
                    self.marnet_module.samples = self.marnet_module.sample(rand_order, self.cfg.batch_size)

            self.marnet.train()

            pbar = tqdm(dataloader)
            pbar.set_description("Epoch {}: Training".format(self.epoch))
            for x, _ in pbar:
                x = x.cuda(device=self.device_id, non_blocking=True)
                # Model inputs are in [1, K]; the score model expects [0, K-1].
                y = self.score_model(x-1.0)
                x = x.squeeze(dim=1)
                loss, mb_loss, mb_loss_begin, mb_loss_end, logf_t, logp_x = self.marnet(x, y)
                loss.backward()

                count = x.shape[0]
                epoch_metrics['mb_loss'] += mb_loss.item() * count
                epoch_metrics['mb_loss_begin'] += mb_loss_begin.item() * count
                epoch_metrics['mb_loss_end'] += mb_loss_end.item() * count
                epoch_metrics['count'] += count

                bsz += x.shape[0]
                accum += x.shape[0]
                last_lr = self.scheduler.get_last_lr()

                # Gradient accumulation: step once ~32 samples (globally) were seen.
                if bsz >= 32 // self.cfg.world_size:
                    if self.clip_grad > 0:
                        total_norm = torch.nn.utils.clip_grad_norm_(self.marnet.parameters(), self.clip_grad)
                        if total_norm > 1e4:
                            print("grad_norm is {}".format(total_norm))
                    self.optimizer.step()
                    self.optimizer.zero_grad()
                    bsz = 0

                if accum >= 5120 // self.cfg.world_size:
                    if self.master_node:
                        logging.info("Iter %u out of %u, mb-loss-begin: %.2f, mb-loss: %.2f, mb-loss-end: %.2f, loss: %.2f, logz: %.2f lr: %f"
                            % (it, len(dataloader), mb_loss_begin.item(), mb_loss.item(), mb_loss_end.item(), loss.item(), self.marnet_module.LogZ.item(), last_lr[0]))
                        logging.info("Iter %u out of %u, p_mean: %.2f, p_std: %.2f, f_t_mean: %.2f, f_t_std: %.2f, alpha: %.2e"
                            % (it, len(dataloader), logp_x.mean().item(), logp_x.std().item(), logf_t.mean().item(), logf_t.std().item(), self.cfg.alpha))
                        self.writer.add_scalar('Obj/mb_loss_begin', mb_loss_begin.item(), it + 1)
                        self.writer.add_scalar('Obj/mb_loss_end', mb_loss_end.item(), it + 1)
                        self.writer.add_scalar('Obj/mb_loss', mb_loss.item(), it + 1)
                        self.writer.add_scalar('Obj/loss', loss.item(), it + 1)
                        self.writer.add_scalar('Obj/logZ', self.marnet_module.LogZ.item(), it + 1)
                        self.writer.add_scalar('Obj/p_mean', logp_x.mean().item(), it + 1)
                        self.writer.add_scalar('Obj/p_std', logp_x.std().item(), it + 1)
                        self.writer.add_scalar('Obj/f_t_mean', logf_t.mean().item(), it + 1)
                        self.writer.add_scalar('Obj/f_t_std', logf_t.std().item(), it + 1)
                    # Reset on every rank, not only the master, so the threshold keeps
                    # meaning "every ~5120 samples" everywhere.
                    accum = 0

                pbar.set_postfix({"mb_loss_begin": f"{mb_loss_begin.item():.2f}",
                    "mb": f"{mb_loss.item():.2e}", "mb_loss_end": f"{mb_loss_end.item():.2f}",
                    "loss": f"{loss.item():.2e}", "logZ": f"{self.marnet_module.LogZ.item():.2f}",
                    "p_mean": f"{logp_x.mean().item():.2f}", "p_std": f"{logp_x.std().item():.2f}",
                    "f_t_mean": f"{logf_t.mean().item():.2f}", "f_t_std": f"{logf_t.std().item():.2f}",
                    "lr": f"{last_lr[0]:.2e}"})
                it += 1

            if self.epoch % self.eval_every == 0:
                with torch.no_grad():
                    metric_tensor = torch.tensor([epoch_metrics['mb_loss'], epoch_metrics['mb_loss_begin'],
                        epoch_metrics['mb_loss_end'], epoch_metrics['count']])
                    if self.distributed:
                        torch.distributed.reduce(metric_tensor, dst=0)

                if self.master_node:
                    kl_div_est = self.eval_kl(self.cfg.eval_reverse_kl)
                    if best_kl_div is None:
                        best_kl_div = kl_div_est
                    if self.cfg.save_model:
                        if kl_div_est <= best_kl_div:
                            best_kl_div = kl_div_est
                            states = {
                                'net': self.marnet_module.state_dict(),
                                # 'optimizer': self.optimizer.state_dict(),
                                'epoch': self.epoch + 1,
                                'L': self.cfg.L,
                                'K': self.cfg.K,
                            }
                            torch.save(states, os.path.join(self.cfg.model_dir, 'checkpoint.pth'))
                # Every rank must call test(): it performs a collective reduce.
                test_epoch_metric_tensor = self.test()

                if self.master_node:
                    # Sums -> per-sample means (last entry is the sample count).
                    metric_tensor[0] /= metric_tensor[-1]
                    metric_tensor[1] /= metric_tensor[-1]
                    metric_tensor[2] /= metric_tensor[-1]
                    self.writer.add_scalar('Loss/train_mb_loss', metric_tensor[0], self.epoch)
                    self.writer.add_scalar('Loss/train_mb_loss_begin', metric_tensor[1], self.epoch)
                    self.writer.add_scalar('Loss/train_mb_loss_end', metric_tensor[2], self.epoch)
                    self.writer.add_scalar('Loss/test_mb_loss', test_epoch_metric_tensor[0], self.epoch)
                    self.writer.add_scalar('Loss/test_mb_loss_begin', test_epoch_metric_tensor[1], self.epoch)
                    self.writer.add_scalar('Loss/test_mb_loss_end', test_epoch_metric_tensor[2], self.epoch)
                    self.writer.add_scalar('Loss/test_mb_diff', test_epoch_metric_tensor[3], self.epoch)
                    self.writer.add_scalar('Loss/test_mb_diff_var', test_epoch_metric_tensor[4], self.epoch)
                    logging.info("Epoch %u out of %u, test mb_loss: %.2f, test mb_loss_begin: %.2f, test mb_loss_end: %.2f" % (
                        self.epoch, self.cfg.n_epochs, test_epoch_metric_tensor[0], test_epoch_metric_tensor[1], test_epoch_metric_tensor[2]))
            # NOTE(review): stepped once per epoch; for the transformer warmup
            # schedule this means warmup is counted in epochs — confirm intended.
            self.scheduler.step()
            self.epoch += 1

    def eval_kl(self, reverse=None):
        """Estimate a KL divergence (or NLL) of the model and log it to TensorBoard.

        Args:
            reverse: if True, score the stored ground-truth samples
                ``self.ising_samples`` (only created for Ising data) and report
                the negative mean model log-likelihood. If False, draw
                ``cfg.batch_size`` model samples and report
                ``mean(log p_model - log f_true)``. ``None`` (the default) uses
                ``cfg.eval_reverse_kl``, matching the previous behavior of a
                no-argument call.

        Returns:
            The scalar estimate (a tensor).
        """
        if reverse is None:
            reverse = self.cfg.eval_reverse_kl
        self.marnet.eval()
        # DistributedDataParallel only forwards __call__; the sampling/scoring
        # helpers have to be called on the wrapped module.
        model = self.marnet_module
        with torch.no_grad():
            if reverse:
                samples = self.ising_samples + 1.0
            else:
                rand_order_gen = gen_order(self.cfg.batch_size, self.cfg.L, self.device_id, gen_order=self.cfg.gen_order)
                samples = model.sample(rand_order_gen, self.cfg.batch_size)
            samples_to_plot = samples[:100,]
            # One Monte-Carlo estimate of the model log-likelihood per sample.
            samples_logp = model.est_logp(samples, 1, self.cfg.gen_order)
            samples_logf_true = model.score_model(samples - 1.0) # convert back to [0:K-1] first
        if reverse:
            kl_div = - samples_logp.mean()
        else:
            kl_div = (samples_logp - samples_logf_true).mean()
        if self.cfg.plot_samples and self.epoch % self.cfg.plot_every == 0:
            # NOTE(review): fig1/fig2 are never cleared in the visible code, so each
            # saved image overlays the curves of all previous plotted epochs.
            plt.figure(self.fig1.number)
            sns.kdeplot(samples_logf_true.cpu().numpy(), fill=True)
            self.fig1.savefig(os.path.join(self.cfg.log_dir, 'samples_epoch{}.png'.format(self.epoch)))
            plt.figure(self.fig2.number)
            data_scores = self.score_model.get_scores(samples - 1.0)
            sns.kdeplot(data_scores.cpu().numpy(), fill=True)
            self.fig2.savefig(os.path.join(self.cfg.log_dir, 'data_scores_epoch{}.png'.format(self.epoch)))
            if self.cfg.data == 'ising':
                file_name = os.path.join(self.cfg.log_dir, 'samples_vis_epoch{}.png'.format(self.epoch))
                torchvision.utils.save_image(
                    samples_to_plot.float().reshape(samples_to_plot.shape[0], 1, self.cfg.ising_model.dim, self.cfg.ising_model.dim), file_name, normalize=True, nrow=int(samples_to_plot.shape[0] ** .5))
            with open("{}/model_samples.pkl".format(self.cfg.log_dir), 'wb') as f:
                pickle.dump(samples.cpu(), f)
            with open("{}/model_samples_scores.pkl".format(self.cfg.log_dir), 'wb') as f:
                pickle.dump(data_scores.cpu(), f)

        self.writer.add_scalar('Loss/p_mean', samples_logp.mean().item(), self.epoch)
        self.writer.add_scalar('Loss/p_std', samples_logp.std().item(), self.epoch)
        self.writer.add_scalar('Loss/f_t_mean', samples_logf_true.mean().item(), self.epoch)
        self.writer.add_scalar('Loss/f_t_std', samples_logf_true.std().item(), self.epoch)
        self.writer.add_scalar('Loss/test_KL_div', kl_div, self.epoch)
        logging.info("test KL_div/nll: %.2f, p_mean:%.2f, p_std: %.2f, f_t_mean: %.2f, f_t_std: %.2f" % (
            kl_div, samples_logp.mean(), samples_logp.std(), samples_logf_true.mean(), samples_logf_true.std()))
        return kl_div

    def eval_variance(self, repeat=100):
        """Estimate the variance and bias of the ``use_marg=True`` gradient and log them.

        A fixed random batch is drawn after ``torch.manual_seed(100)``. Note that
        this reseeds the global RNG as a side effect. The batch is first scored
        with ``use_marg=False``, and ``VarEstimator`` is built from ``net``.
        After that, ``repeat`` gradients with ``use_marg=True`` are fed to
        ``estimator.update``. Gradients are zeroed before returning.

        Args:
            repeat: number of ``use_marg=True`` gradient samples.
        """
        self.marnet.eval()
        # DistributedDataParallel does not forward eval_loss / net; use the wrapped module.
        model = self.marnet_module
        torch.manual_seed(100)
        samples = torch.randint(low=1, high=self.cfg.K+1, size=(self.cfg.batch_size, self.cfg.L)).float().to(self.device_id)
        # Reference gradient; presumably captured by VarEstimator's constructor — TODO confirm.
        loss = model.eval_loss(samples, use_marg=False)
        loss.backward()
        estimator = VarEstimator(model.net)
        self.optimizer.zero_grad()
        for _ in range(repeat):
            loss = model.eval_loss(samples, use_marg=True)
            loss.backward()
            estimator.update(model.net)
            self.optimizer.zero_grad()
        var = estimator.get_variance()
        bias = estimator.get_bias()
        logging.info("======>variance: %.4e" % var)
        logging.info("======>bias: %.4e" % bias)
        self.optimizer.zero_grad()
        self.writer.add_scalar('Grad/grad_variance', var, self.epoch)
        self.writer.add_scalar('Grad/grad_bias', bias, self.epoch)

    def test(self):
        """Evaluate the model on up to ``cfg.eval.num_batches`` batches of the test set.

        Returns:
            A 6-element tensor ``[mb_loss, mb_loss_begin, mb_loss_end, mb_diff,
            mb_diff_var, count]``. When distributed, the per-rank sums are reduced
            onto rank 0. On the master node the first five entries are then
            turned into per-sample means (when ``count > 0``); on other ranks
            they remain local sums.
        """
        self.marnet.eval()
        dataloader = self.test_loader
        mode = 'test'

        epoch_metrics = {
            'mb_loss': 0,
            'mb_loss_begin': 0,
            'mb_loss_end': 0,
            'mb_diff': 0,
            'mb_diff_var': 0,
            'count': 0,
        }

        pbar = tqdm(dataloader)
        pbar.set_description("Testing calculating likelihood")
        it = 0
        for x, _ in pbar:
            x = x.cuda(device=self.device_id, non_blocking=True)
            x = x.squeeze(dim=1)
            with torch.no_grad():
                # Model inputs are in [1, K]; the score model expects [0, K-1].
                y = self.score_model(x-1.0)
                loss, mb_loss, mb_loss_begin, mb_loss_end, logf_t, logp_x = self.marnet(x, y)
            pbar.set_postfix({
                "mb_loss": f"{mb_loss:.2f}", "mb_loss_begin": f"{mb_loss_begin:.2f}",
                "mb_loss_end": f"{mb_loss_end:.2f}", "loss": f"{loss:.2f}",
                "p_mean": f"{logp_x.mean().item():.2f}", "p_std": f"{logp_x.std().item():.2f}",
                "f_t_mean": f"{logf_t.mean().item():.2f}", "f_t_std": f"{logf_t.std().item():.2f}"
            })

            count = x.shape[0]
            diff = logp_x - logf_t
            epoch_metrics['mb_loss'] += mb_loss.item() * count
            epoch_metrics['mb_loss_begin'] += mb_loss_begin.item() * count
            epoch_metrics['mb_loss_end'] += mb_loss_end.item() * count
            epoch_metrics['mb_diff'] += torch.abs(diff).mean().item() * count
            epoch_metrics['mb_diff_var'] += torch.var(diff).item() * count
            epoch_metrics['count'] += count

            # Stop only after the batch has been accumulated, so exactly
            # num_batches batches contribute (previously the last one was dropped).
            it += 1
            if it == self.cfg.eval.num_batches:
                break

        with torch.no_grad():
            metric_tensor = torch.tensor([epoch_metrics['mb_loss'], epoch_metrics['mb_loss_begin'],
                epoch_metrics['mb_loss_end'], epoch_metrics['mb_diff'], epoch_metrics['mb_diff_var'],
                epoch_metrics['count']])
            if self.distributed:
                torch.distributed.reduce(metric_tensor, dst=0)

            if self.master_node:
                # Guard against an empty test loader (division by zero -> NaN).
                if metric_tensor[-1] > 0:
                    metric_tensor[:-1] /= metric_tensor[-1]
                logging.info("%s count, %u mb_loss: %.4f, mb_loss_begin: %.4f, mb_loss_end: %.4f" % (
                    mode, metric_tensor[-1], metric_tensor[0], metric_tensor[1], metric_tensor[2]))

        return metric_tensor
    
    def generate(self):
        """Sample ``cfg.gen_num_samples`` sequences, log KL statistics and save outputs.

        The samples and their ``get_scores`` values are pickled to
        ``cfg.log_dir``, and a score plot is written to ``samples.png``. For
        Ising data with ``cfg.eval_reverse_kl``, the stored ground-truth samples
        are also scored. Their scores are pickled and their NLL is logged.
        """
        self.marnet.eval()
        # DistributedDataParallel does not forward sample/net/est_logp/score_model;
        # use the wrapped module.
        model = self.marnet_module
        rand_order_gen = gen_order(self.cfg.gen_num_samples, self.cfg.L, self.device_id, gen_order=self.cfg.gen_order)
        with torch.no_grad():
            samples = model.sample(rand_order_gen, self.cfg.gen_num_samples)
            # No-op if sample() already returns tensors on this device.
            samples = samples.cuda(device=self.device_id, non_blocking=True)
            logp_x, _ = model.net(samples) # (B)
            samples_logp = model.est_logp(samples, 1, self.cfg.gen_order)
            samples_logf_true = model.score_model(samples - 1.0)
            kl_div = (samples_logp - samples_logf_true).mean()
        logging.info("test KL div: %.2f, f_t_mean:%.2f, f_t_std: %.2f, logp_est_mean: %.2f, logp_est_std: %.2f" % (
            kl_div, samples_logf_true.mean(), samples_logf_true.std(), samples_logp.mean(), samples_logp.std()))
        logging.info("logp_x_mean: %.2f, logp_x_std: %.2f" % (logp_x.mean().item(), logp_x.std().item()))
        data_scores = self.score_model.get_scores(samples - 1.0)
        save_path = os.path.join(self.cfg.log_dir, 'samples.png')
        self.score_model.plot_scores(data_scores.cpu().numpy(), save_path)
        with open("{}/model_samples.pkl".format(self.cfg.log_dir), 'wb') as f:
            pickle.dump(samples.cpu(), f)
        with open("{}/model_samples_scores.pkl".format(self.cfg.log_dir), 'wb') as f:
            pickle.dump(data_scores.cpu(), f)
        if self.cfg.data == 'ising' and self.cfg.eval_reverse_kl:
            samples = self.ising_samples + 1.0
            with torch.no_grad():
                samples_logp = model.est_logp(samples, 1, self.cfg.gen_order)
                samples_scores = self.score_model.get_scores(samples - 1.0)
                samples_logf_true = model.score_model(samples - 1.0) # convert back to [0:K-1] first
            save_path = os.path.join(self.cfg.log_dir, 'ising_samples_scores.pkl')
            with open(save_path, 'wb') as f:
                pickle.dump(samples_scores.cpu(), f)
            nll = - samples_logp.mean()
            logging.info("test nll: %.2f, f_t_mean:%.2f, f_t_std: %.2f" % (nll, samples_logf_true.mean(), samples_logf_true.std()))


    def generate_mols(self):
        """Sample molecules from the model and write them to ``cfg.gen_dir``.

        Runs ``cfg.gen_num_samples // cfg.generate_batch_size`` batches; any
        remainder is not generated. With ``cfg.conditional``, ``cfg.string_example``
        is encoded, positions ``[cfg.start, cfg.end)`` are masked with
        ``BIT_UNKNOWN_VAL``, and only those positions are sampled. Samples from
        every batch are written to ``generated_samples_smiles.txt``, and for
        SELFIES also to ``generated_samples_selfies.txt``. In the conditional
        case, the first line of each file is the example string.

        Raises:
            ValueError: if ``cfg.string_type`` is neither 'SELFIES' nor 'SMILES'.
        """
        self.marnet.eval()
        if self.cfg.string_type not in ('SELFIES', 'SMILES'):
            raise ValueError("Unknown string type {}".format(self.cfg.string_type))
        is_selfies = self.cfg.string_type == 'SELFIES'
        batch_size = self.cfg.generate_batch_size
        iters = self.cfg.gen_num_samples // batch_size
        os.makedirs(self.cfg.gen_dir, exist_ok=True)
        selfies_path = os.path.join(self.cfg.gen_dir, 'generated_samples_selfies.txt')
        smiles_path = os.path.join(self.cfg.gen_dir, 'generated_samples_smiles.txt')

        # The conditioning input is identical for every batch, so build it once.
        x_cond = None
        string_selfies, string_smiles = None, None
        if self.cfg.conditional:
            # convert from smiles to indices
            string_selfies = sf.encoder(self.cfg.string_example)
            string_smiles = sf.decoder(string_selfies)
            # NOTE(review): the SELFIES string is encoded even when string_type is 'SMILES' — confirm.
            x_cond = string_to_int(string_selfies, self.cfg.string_type, self.cfg.L, self.cfg.alphabet)
            x_cond = torch.tensor(x_cond, dtype=torch.long, device=self.device_id).unsqueeze(0) + 1 # add 1 to include unknown as 0
            x_cond[:, self.cfg.start:self.cfg.end] = BIT_UNKNOWN_VAL

        # Truncate the output files once (writing the example as a header when
        # conditional); each batch is appended below instead of overwriting.
        if is_selfies:
            with open(selfies_path, 'w') as f:
                if string_selfies is not None:
                    f.write(string_selfies + '\n')
        with open(smiles_path, 'w') as f:
            if string_smiles is not None:
                f.write(string_smiles + '\n')

        model = self.marnet_module
        pbar = tqdm(range(iters))
        for i in pbar:
            with torch.no_grad():
                if x_cond is not None:
                    rand_order_gen = gen_order(
                        batch_size, self.cfg.end-self.cfg.start, self.device_id, gen_order=self.cfg.gen_order
                    )
                    x_gen = model.cond_sample(x_cond, rand_order_gen, batch_size)
                else:
                    rand_order_gen = gen_order(
                        batch_size, self.cfg.L, self.device_id, gen_order=self.cfg.gen_order
                    )
                    x_gen = model.sample(rand_order_gen, batch_size)
                logp_ebm, log_z = model.eval_ll(x_gen)
                logp = model.est_logp(x_gen, self.cfg.eval.mc_ll, self.cfg.gen_order) # (B,)
                samples_logf_true = model.score_model(x_gen - 1.0)
                kl_div = (logp - samples_logf_true).mean()
                logging.info("logp_ebm: %.4f logp mean: %.4f std: %.4f" % ((logp_ebm - log_z).mean(), logp.mean(), logp.std()))
                logging.info("test KL div: %.4f, f_t_mean:%.4f, f_t_std: %.4f" % (
                    kl_div, samples_logf_true.mean(), samples_logf_true.std()))
            # Back to alphabet indices in [0, K-1].
            x_gen = (x_gen-1.0).int().cpu().numpy().tolist()
            if is_selfies:
                x_gen_selfies = multiple_indices_to_string(x_gen, self.cfg.alphabet)
                with open(selfies_path, 'a') as f:
                    f.writelines(s + '\n' for s in x_gen_selfies)
                x_gen_smiles = list(map(sf.decoder, x_gen_selfies))
            else:
                x_gen_smiles = multiple_indices_to_string(x_gen, self.cfg.alphabet)
            with open(smiles_path, 'a') as f:
                f.writelines(s + '\n' for s in x_gen_smiles)
            pbar.set_postfix({"generated samples": f"{(i+1)*batch_size}"})