import pickle
import time

import torch
from sklearn.preprocessing import OrdinalEncoder
from torch.utils.tensorboard import SummaryWriter
from torch_ema import ExponentialMovingAverage
from tqdm import tqdm

from experiments.models.highres.cdtd_model import HighResCDTDModel
from experiments.models.highres.flow_model import HighResFlowModel
from experiments.models.lowres.encoder import Discretizer
from experiments.models.lowres.layers import LowResMLP
from experiments.models.lowres.lowres_diff import CatCDTD, CatFlow
from experiments.models.lowres.utils import FastTensorDataLoader, cycle
from experiments.utils import set_seeds

from .experiment import Experiment


def freeze(model, part="model"):
    """Toggle ``requires_grad`` on the parameters of ``model`` based on their names.

    Parameters are split by whether their qualified name (as yielded by
    ``model.named_parameters()``) contains the substring ``"encoder"``.

    Args:
        model: a ``torch.nn.Module``.
        part: which side of the split to freeze.
            ``"model"``: freeze every parameter whose name does NOT contain
            ``"encoder"`` and make the ``"encoder"`` parameters trainable.
            ``"embedding"``: freeze the ``"encoder"`` parameters and make every
            other parameter trainable.

    Raises:
        ValueError: if ``part`` is neither ``"model"`` nor ``"embedding"``.

    Note:
        The non-frozen side is explicitly set to ``requires_grad=True``, so any
        parameter that had been frozen elsewhere is re-enabled by this call.
    """
    if part not in ("model", "embedding"):
        raise ValueError(f"part must be 'model' or 'embedding', got {part!r}")

    freeze_encoder = part == "embedding"
    for name, param in model.named_parameters():
        is_encoder = "encoder" in name
        # trainable iff the parameter is on the side that is NOT being frozen
        param.requires_grad = is_encoder != freeze_encoder


class Experiment_TabCascade(Experiment):
    """Cascaded tabular generator: a low-resolution model followed by a high-resolution model.

    Numerical columns are discretized by a ``Discretizer`` into group codes
    (``z_num``). The low-resolution model (``CatFlow``/``CatCDTD``, or a pickled
    ARF model at load time) generates the categorical columns and, when
    ``config.highres.model.condition_on_z`` is true (the default), the group
    codes as well. The high-resolution model (``HighResFlowModel`` /
    ``HighResCDTDModel``) then generates the numerical columns conditioned on
    that low-resolution output.
    """

    def __init__(self, config, args):
        super().__init__(config, args)

    # ------------------------------------------------------------------
    # helpers shared by ``train`` and ``load_model``

    def _strip_missing_indicators(self, x_cat_trn):
        """Return the categorical training columns the cascade is trained on.

        When conditioning on ``z_num``, only the first
        ``len(self.data_processor.cat_cols)`` columns are kept, which drops the
        missingness-indicator columns (used by other models). In that case,
        missing numerical values are restored from the group codes in
        ``sample``. Otherwise ``x_cat_trn`` is returned unchanged.
        """
        if self.config.highres.model.get("condition_on_z", True):
            return x_cat_trn[:, : len(self.data_processor.cat_cols)]
        return x_cat_trn

    @staticmethod
    def _class_stats(codes, n_sample):
        """Return per-column number of distinct values and their relative frequencies.

        Args:
            codes: 2-D integer tensor with one column per feature.
            n_sample: denominator used for the frequencies.

        Returns:
            ``(n_classes, proportions)``: a list of ints and a list of 1-D
            tensors holding the frequencies of the (sorted) distinct values.
        """
        n_classes, proportions = [], []
        for i in range(codes.shape[1]):
            vals, counts = codes[:, i].unique(return_counts=True)
            n_classes.append(len(vals))
            proportions.append(counts / n_sample)
        return n_classes, proportions

    def _count_classes(self, x_cat, groups):
        """Class counts and proportions for ``x_cat`` and, if conditioning on ``z_num``, for ``groups``.

        Returns:
            ``(n_classes_cat, proportions_cat, n_classes_num, proportions_num)``;
            the ``*_num`` lists are empty when ``condition_on_z`` is false.
        """
        n_sample = x_cat.shape[0]
        n_classes_cat, proportions_cat = self._class_stats(x_cat, n_sample)
        if self.config.highres.model.get("condition_on_z", True):
            n_classes_num, proportions_num = self._class_stats(groups, n_sample)
        else:
            n_classes_num, proportions_num = [], []
        return n_classes_cat, proportions_cat, n_classes_num, proportions_num

    def _build_lowres(self, n_classes, proportions):
        """Construct the (untrained) low-resolution model from ``config.lowres.model``.

        Raises:
            ValueError: if ``config.lowres.model.variant`` is not ``"flow"`` or ``"cdtd"``.
        """
        cfg = self.config.lowres.model
        if cfg.variant not in ("flow", "cdtd"):
            raise ValueError(f"Unsupported lowres variant for a trainable model: {cfg.variant!r}")

        predictor = LowResMLP(
            n_classes,
            cfg.cat_emb_dim,
            cfg.mlp_emb_dim,
            cfg.mlp_n_layers,
            cfg.mlp_n_units,
            proportions,
            cfg.mlp_act,
        )
        if cfg.variant == "flow":
            return CatFlow(
                predictor,
                n_classes,
                proportions,
                cfg.cat_emb_dim,
                cfg.sigma_min,
                cfg.sigma_max,
                cfg.sigma_data,
                cfg.normalize_by_entropy,
                cfg.timewarp_weight_low_noise,
                cfg.timewarp_variant,
                cat_emb_init_sigma=cfg.cat_emb_init_sigma,
                learn_noise_schedule=cfg.learn_noise_schedule,
                init_embs_zero=cfg.init_embs_zero,
                learn_latents=cfg.learn_latents,
                norm_dim=cfg.norm_dim,
            )
        return CatCDTD(
            predictor,
            n_classes,
            proportions,
            cfg.cat_emb_dim,
            cfg.sigma_min,
            cfg.sigma_max,
            cfg.sigma_data,
            cfg.normalize_by_entropy,
            cfg.timewarp_weight_low_noise,
            cfg.timewarp_variant,
            cfg.cat_emb_init_sigma,
        )

    def _build_highres(self, means, stds, n_classes_cat, n_classes_num):
        """Construct the (untrained) high-resolution model from ``config.highres.model``.

        Raises:
            ValueError: if ``config.highres.model.variant`` (default ``"flow"``)
                is not ``"flow"`` or ``"cdtd"``.
        """
        cfg = self.config.highres.model
        variant = cfg.get("variant", "flow")
        if variant == "flow":
            return HighResFlowModel(
                means,
                stds,
                n_classes_cat,
                n_classes_num,
                cfg.mlp_emb_dim,
                cfg.mlp_n_layers,
                cfg.mlp_n_units,
                cfg.gamma_input_dim,
                cfg.cat_emb_dim,
                coupling=cfg.get("coupling", "change_x_0"),
                learn_time_schedule=cfg.get("learn_time_schedule", True),
            )
        if variant == "cdtd":
            return HighResCDTDModel(
                means,
                stds,
                n_classes_cat,
                n_classes_num,
                cfg.mlp_emb_dim,
                cfg.mlp_n_layers,
                cfg.mlp_n_units,
                cfg.cat_emb_dim,
            )
        raise ValueError(f"Unsupported highres variant: {variant!r}")

    @staticmethod
    def _scheduled_lr(base_lr, step, num_steps_warmup, num_steps_total, linear_decay):
        """Return the learning rate to set at ``step``, or ``None`` to leave it unchanged.

        For ``step < num_steps_warmup`` the rate warms up linearly from
        ``base_lr / num_steps_warmup`` to ``base_lr``. If ``linear_decay`` is
        true, after warmup it decays linearly from ``base_lr`` towards ``1e-6``
        at ``num_steps_total``. At ``step == num_steps_warmup`` (and otherwise
        without decay), ``None`` is returned.
        """
        if step < num_steps_warmup:
            return base_lr * (step + 1) / num_steps_warmup
        if linear_decay and step > num_steps_warmup:
            rate = 1 - ((step - num_steps_warmup) / (num_steps_total - num_steps_warmup))
            return base_lr * rate + 1e-6 * (1 - rate)
        return None

    @staticmethod
    def _set_lr(optimizer, lr):
        """Overwrite the learning rate of every parameter group of ``optimizer``."""
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr

    # ------------------------------------------------------------------

    def train(self, **kwargs):
        """Fit the encoder and jointly train the low- and high-resolution models.

        The steps are:

        1. Fit a ``Discretizer`` on the (not yet imputed) numerical training
           data and encode it into group codes ``z_num``.
        2. Mean-impute ``X_num``. For ``coupling == "change_x_1"``, also
           standardize it per group.
        3. Run ``config.lowres.training.num_steps_train`` optimization steps,
           updating both models on the same minibatch in each step.

        Afterwards, the EMA weights are copied into both models and both are
        switched to eval mode. Loss histories are pickled to ``workdir``, and
        TensorBoard scalars are written to ``workdir / "tb"``.

        Keyword Args:
            save_model (bool): if true, also record the training time
                (encoder fit + training loop) and save the weights.
                Defaults to False.
        """
        save_model = kwargs.get("save_model", False)
        set_seeds(self.seed, cuda_deterministic=True)
        condition_on_z = self.config.highres.model.get("condition_on_z", True)

        #################################################
        # additional data pre-processing to get z_num
        train_loader, _ = self.data_processor.get_data_loaders(mean_impute=False)
        x_cat_trn = train_loader.data[0]
        x_num_trn = train_loader.data[1]
        x_cat_trn_no_miss = self._strip_missing_indicators(x_cat_trn)

        # encode X_num into Z_num (the encoder fit time is added to the reported training time)
        training_start_time = time.monotonic()
        self.encoder = Discretizer(
            x_num_trn,
            variant=self.config.data.encoder,
            seed=self.seed,
            k_max=self.config.data.k_max,
            max_depth=self.config.data.max_depth,
        )
        groups, mask = self.encoder.encode(x_num_trn)
        encoder_train_time = time.monotonic() - training_start_time

        if self.config.data.encoder == "gmm":
            # adjust means and remove those not appearing in the data (after hard clustering)
            # NOTE(review): only the means are filtered; self.encoder.stds is left as is but is
            # indexed below with the re-encoded group codes. Confirm both stay aligned for GMM.
            for i in range(groups.shape[1]):
                vals = groups[:, i].unique()
                self.encoder.means[i] = self.encoder.means[i][vals]

            # Re-index the group codes to 0..K_i-1 with an extra ordinal encoder,
            # since some components may never be the argmax and thus not appear in the data.
            self.gmm_ord_enc = OrdinalEncoder()
            groups = self.gmm_ord_enc.fit_transform(groups.numpy())
            groups = torch.from_numpy(groups).long()

        # mean impute numerical values (in place)
        x_means = torch.nanmean(x_num_trn, dim=0)
        for i in range(x_num_trn.shape[1]):
            x_num_trn[:, i] = torch.nan_to_num(x_num_trn[:, i], nan=x_means[i])

        means = self.encoder.means
        stds = self.encoder.stds
        if self.config.highres.model.get("coupling", "change_x_0") == "change_x_1":
            # normalize X_num using group means and stds
            for i in range(x_num_trn.shape[1]):
                # for inflated values, only center, do not scale (so std = 1);
                # this mutates self.encoder.stds in place
                stds[i][stds[i] == 0] = 1.0

                means_i = means[i][groups[:, i]]
                stds_i = stds[i][groups[:, i]]
                x_num_trn[:, i] = (x_num_trn[:, i] - means_i) / (stds_i + 1e-8)

        # construct new data loader yielding (x_cat, x_num, z_num, mask)
        batch_size = min(self.config.data.batch_size, x_num_trn.shape[0])
        train_loader = FastTensorDataLoader(
            x_cat_trn_no_miss,
            x_num_trn,
            groups,
            mask,
            batch_size=batch_size,
            shuffle=True,
            drop_last=True,
        )

        # determine n_classes for X_cat and Z_num
        n_classes_cat, proportions_cat, n_classes_num, proportions_num = self._count_classes(
            x_cat_trn_no_miss, groups
        )

        #################################################
        # setup low resolution model (models X_cat and, if conditioning, Z_num jointly)
        n_classes = n_classes_cat + n_classes_num
        proportions = proportions_cat + proportions_num
        self.lowres = self._build_lowres(n_classes, proportions)

        self.lowres.to(self.device)
        self.ema_lowres = ExponentialMovingAverage(
            self.lowres.parameters(),
            decay=self.config.lowres.training.ema_decay,
        )

        opt_lowres = torch.optim.AdamW(
            self.lowres.parameters(),
            lr=self.config.lowres.training.lr,
            weight_decay=self.config.lowres.training.weight_decay,
            betas=self.config.lowres.training.betas,
        )

        # flow variants reduce the lr on plateau; cdtd variants use a linear decay instead
        scheduler_lowres = None
        if self.config.lowres.model.variant == "flow":
            scheduler_lowres = torch.optim.lr_scheduler.ReduceLROnPlateau(
                opt_lowres,
                mode="min",
                factor=0.9,
                patience=3,
                min_lr=1e-6,
            )

        #################################################
        # setup high resolution model
        self.highres = self._build_highres(means, stds, n_classes_cat, n_classes_num)

        self.highres.to(self.device)
        num_params_lowres = sum(p.numel() for p in self.lowres.parameters())
        num_params_highres = sum(p.numel() for p in self.highres.parameters())
        print("Total parameters =", num_params_highres + num_params_lowres)
        print("Lowres parameters =", num_params_lowres)
        print("Highres parameters =", num_params_highres)

        self.ema_highres = ExponentialMovingAverage(
            self.highres.parameters(),
            decay=self.config.highres.training.ema_decay,
        )

        opt_highres = torch.optim.AdamW(
            self.highres.parameters(),
            lr=self.config.highres.training.lr,
            weight_decay=self.config.highres.training.weight_decay,
            betas=self.config.highres.training.betas,
        )

        scheduler_highres = None
        if self.config.highres.model.get("variant", "flow") == "flow":
            scheduler_highres = torch.optim.lr_scheduler.ReduceLROnPlateau(
                opt_highres,
                mode="min",
                factor=0.9,
                patience=3,
                min_lr=1e-6,
            )

        #################################################
        # train loop (both models step once per iteration; the loop length is
        # taken from the lowres config)
        lowres_train_cfg = self.config.lowres.training
        highres_train_cfg = self.config.highres.training
        num_steps_train = lowres_train_cfg.num_steps_train
        lowres_linear_decay = self.config.lowres.model.variant == "cdtd"
        highres_linear_decay = self.config.highres.model.get("variant", "flow") == "cdtd"

        train_loader = cycle(train_loader)
        step = 0
        n_inputs = 0
        lowres_loss_trn = 0
        lowres_loss_hist = []
        highres_loss_trn = 0
        highres_loss_hist = []
        monitor_grad_norm = lowres_train_cfg.freeze_emb
        writer = SummaryWriter(self.workdir / "tb")
        pbar = tqdm(total=num_steps_train)
        training_start_time = time.monotonic()

        try:
            while step < num_steps_train:
                # learning rate schedules: linear warmup, then (cdtd only) linear decay
                lr = self._scheduled_lr(
                    lowres_train_cfg.lr,
                    step,
                    lowres_train_cfg.num_steps_warmup,
                    num_steps_train,
                    lowres_linear_decay,
                )
                if lr is not None:
                    self._set_lr(opt_lowres, lr)

                lr = self._scheduled_lr(
                    highres_train_cfg.lr,
                    step,
                    highres_train_cfg.num_steps_warmup,
                    num_steps_train,
                    highres_linear_decay,
                )
                if lr is not None:
                    self._set_lr(opt_highres, lr)

                opt_lowres.zero_grad(set_to_none=True)
                opt_highres.zero_grad(set_to_none=True)

                batch = next(train_loader)
                x_cat, x_num, z_num, mask = (x.to(self.device) for x in batch)
                bsz = len(x_cat)
                n_inputs += bsz

                ################################
                # lowres model
                if condition_on_z:
                    lowres_input = torch.column_stack((x_cat, z_num))
                else:
                    lowres_input = x_cat

                losses = self.lowres.loss_fn(lowres_input)
                train_loss_lowres = losses["train_loss"]
                train_loss_lowres.backward()

                # Once past 10% of training, freeze the lowres embeddings permanently
                # when the L2 norm of their gradients drops below 0.01.
                if monitor_grad_norm:
                    emb_grad_norm = 0
                    emb_grad_norm += self.lowres.encoder.cat_emb.weight.grad.norm().item() ** 2
                    emb_grad_norm += self.lowres.encoder.cat_bias.grad.norm().item() ** 2
                    emb_grad_norm = emb_grad_norm ** (1.0 / 2)

                    if (step > 0.1 * num_steps_train) and (emb_grad_norm < 0.01):
                        freeze(self.lowres, part="embedding")
                        print(f"Freezing embeddings at step {step}...")
                        monitor_grad_norm = False

                # update parameters
                if lowres_train_cfg.clip_grad:
                    torch.nn.utils.clip_grad_norm_(self.lowres.parameters(), max_norm=1.0)
                opt_lowres.step()
                self.ema_lowres.update()
                lowres_loss_trn += train_loss_lowres.detach().item() * bsz

                ################################
                # highres model
                if not condition_on_z:
                    z_num = None
                    mask = None
                train_loss_highres = self.highres.loss_fn(x_num, x_cat, z_num, mask)
                train_loss_highres.backward()

                if highres_train_cfg.clip_grad:
                    torch.nn.utils.clip_grad_norm_(self.highres.parameters(), max_norm=1.0)
                opt_highres.step()
                self.ema_highres.update()
                highres_loss_trn += train_loss_highres.detach().item() * bsz

                ################################
                # bookkeeping and learning rate scheduling (losses are per-sample
                # averages over the steps since the last log)
                if step % lowres_train_cfg.log_steps == 0:
                    lowres_loss_trn = lowres_loss_trn / n_inputs
                    lowres_loss_hist.append(lowres_loss_trn)
                    if scheduler_lowres is not None:
                        scheduler_lowres.step(lowres_loss_trn)

                    highres_loss_trn = highres_loss_trn / n_inputs
                    highres_loss_hist.append(highres_loss_trn)
                    if scheduler_highres is not None:
                        scheduler_highres.step(highres_loss_trn)

                    train_dict = {"lowres": lowres_loss_trn, "highres": highres_loss_trn}
                    pbar.set_postfix(
                        {"loss (lowres)": f"{lowres_loss_trn:.4f}", "loss (highres)": f"{highres_loss_trn:.4f}"},
                    )
                    for metric_name, metric_value in train_dict.items():
                        writer.add_scalar("losses/{}".format(metric_name), metric_value, global_step=step)
                    lowres_loss_trn = 0
                    highres_loss_trn = 0
                    n_inputs = 0
                step += 1
                pbar.update(1)
        finally:
            pbar.close()
            training_duration = time.monotonic() - training_start_time
            # flush pending TensorBoard events and release the event file
            writer.close()

        with (self.workdir / "loss_hist_lowres.pkl").open("wb") as f:
            pickle.dump(lowres_loss_hist, f)
        with (self.workdir / "loss_hist_highres.pkl").open("wb") as f:
            pickle.dump(highres_loss_hist, f)

        # copy EMA weights to the model
        self.ema_lowres.copy_to()
        self.lowres.eval()
        self.ema_highres.copy_to()
        self.highres.eval()

        if save_model:
            self.save_train_time(training_duration + encoder_train_time)
            self.save_model()

    def save_model(self):
        """Save model weights (and, for the GMM encoder, the fitted encoders) to ``workdir``.

        Writes ``lowres.pt`` and ``highres.pt`` (state dicts). For
        ``config.data.encoder == "gmm"``, it also writes ``encoder.pt`` and
        ``gmm_ord_enc.pt`` (pickled objects).
        """
        torch.save(self.lowres.state_dict(), self.workdir / "lowres.pt")
        torch.save(self.highres.state_dict(), self.workdir / "highres.pt")
        # cannot currently save the DT encoder due to using an R function for disttree
        if self.config.data.encoder == "gmm":
            torch.save(self.encoder, self.workdir / "encoder.pt")
            torch.save(self.gmm_ord_enc, self.workdir / "gmm_ord_enc.pt")

    def load_model(self):
        """Restore the encoder and both models from ``workdir`` and put them in eval mode.

        The GMM encoder is unpickled from ``workdir``. Any other encoder is
        refitted on the training data with ``self.seed``, since it cannot be
        saved. For an ``"arf"`` lowres variant, the model is unpickled from
        ``<workdir>/../arf_lowres_<seed>/checkpoint.pkl``. Checkpoints are
        loaded onto the CPU and then moved to ``self.device``, so GPU-trained
        weights also load on CPU-only machines. For the flow high-res variant,
        data for plotting the learned noise schedule is pickled to ``workdir``.

        Raises:
            ValueError: on an unsupported lowres/highres variant.
        """
        set_seeds(self.seed)

        # additional data pre-processing to get z_num
        train_loader, _ = self.data_processor.get_data_loaders(mean_impute=False)
        x_cat_trn = train_loader.data[0]
        x_num_trn = train_loader.data[1]
        x_cat_trn_no_miss = self._strip_missing_indicators(x_cat_trn)

        # encode X_num into Z_num
        if self.config.data.encoder == "gmm":
            # weights_only=False unpickles arbitrary objects: only load from a trusted workdir
            self.encoder = torch.load(self.workdir / "encoder.pt", weights_only=False)
            self.gmm_ord_enc = torch.load(self.workdir / "gmm_ord_enc.pt", weights_only=False)
        else:
            self.encoder = Discretizer(
                x_num_trn,
                variant=self.config.data.encoder,
                seed=self.seed,
                k_max=self.config.data.k_max,
                max_depth=self.config.data.max_depth,
            )
        groups, _ = self.encoder.encode(x_num_trn)

        if self.config.data.encoder == "gmm":
            # same re-indexing as in train()
            groups = self.gmm_ord_enc.transform(groups.numpy())
            groups = torch.from_numpy(groups).long()

        #########################
        # low resolution model
        n_classes_cat, proportions_cat, n_classes_num, proportions_num = self._count_classes(
            x_cat_trn_no_miss, groups
        )
        n_classes = n_classes_cat + n_classes_num
        proportions = proportions_cat + proportions_num

        if self.config.lowres.model.variant == "arf":
            arf_workdir = self.workdir.parent / f"arf_lowres_{self.seed}"
            # pickle.load executes arbitrary code: only load from a trusted workdir
            with (arf_workdir / "checkpoint.pkl").open("rb") as f:
                checkpoint = pickle.load(f)
            self.lowres = checkpoint["model"]
        else:
            self.lowres = self._build_lowres(n_classes, proportions)
            checkpoint = torch.load(self.workdir / "lowres.pt", map_location="cpu")
            self.lowres.load_state_dict(checkpoint)
            self.lowres.to(self.device)
            self.lowres.eval()

        #################################################
        # high resolution model
        means = self.encoder.means
        stds = self.encoder.stds
        if self.config.highres.model.get("coupling", "change_x_0") == "change_x_1":
            for i in range(x_num_trn.shape[1]):
                # for inflated values, only center, do not scale (so std = 1), as in train()
                stds[i][stds[i] == 0] = 1.0

        self.highres = self._build_highres(means, stds, n_classes_cat, n_classes_num)
        checkpoint = torch.load(self.workdir / "highres.pt", map_location="cpu")
        self.highres.load_state_dict(checkpoint)
        self.highres.to(self.device)
        self.highres.eval()

        # save data for plotting of learned noise schedule
        if self.config.highres.model.get("variant", "flow") == "flow":
            # NOTE(review): groups is passed as z_num even when condition_on_z is false,
            # whereas training passes z_num=None in that case. Confirm plot_gamma handles this.
            x_cat = x_cat_trn_no_miss.to(self.device)
            z_num = groups.to(self.device)
            t_grid, g = self.highres.plot_gamma(x_cat, z_num)
            with (self.workdir / "g_mean.pkl").open("wb") as f:
                pickle.dump(g.mean(1), f)
            with (self.workdir / "g_var.pkl").open("wb") as f:
                pickle.dump(g.var(1), f)
            with (self.workdir / "g_t_grid.pkl").open("wb") as f:
                pickle.dump(t_grid, f)
            del x_cat
            del z_num

    def sample(self, num_samples, seed):
        """Generate ``num_samples`` synthetic rows.

        First the low-resolution part (categorical columns plus, when
        conditioning on ``z_num``, the group codes) is sampled. Then the
        numerical columns are sampled conditioned on it. When conditioning on
        ``z_num``, numerical values in inflated groups are overwritten with the
        group mean. Entries flagged as missing by the encoder are set to NaN.

        Returns:
            ``(x_cat, x_num)`` as numpy arrays.
        """
        # first sample low resolution information
        if self.config.lowres.model.variant == "arf":
            set_seeds(seed)
            x_low_gen = self.lowres.forge(n=num_samples)
            x_low_gen = torch.from_numpy(x_low_gen.values).long()
        else:
            x_low_gen = self.lowres.sample_data(
                num_samples,
                num_steps=self.config.lowres.model.generation_steps,
                batch_size=self.config.lowres.model.generation_batch_size,
                seed=seed,
                verbose=False,
            )

        condition_on_z = self.config.highres.model.get("condition_on_z", True)
        if condition_on_z:
            # low-res output is [X_cat | Z_num] column-wise
            n_cat = len(self.data_processor.cat_cols)
            x_cat_gen = x_low_gen[:, :n_cat]
            z_num_gen = x_low_gen[:, n_cat:]
        else:
            x_cat_gen = x_low_gen
            z_num_gen = None

        # then sample high resolution information conditioned on low resolution
        x_num_gen = self.highres.sample_data(
            x_cat_gen,
            z_num_gen,
            num_steps=self.config.highres.model.generation_steps,
            batch_size=self.config.highres.model.generation_batch_size,
            seed=seed,
            verbose=False,
        )

        if condition_on_z:
            # overwrite inflated / missing values in X_num using Z_num
            assert x_num_gen.shape == z_num_gen.shape
            if self.config.data.encoder == "gmm":
                # map back to the encoder's original component indices before asking for masks
                z_num_gen_enc = self.gmm_ord_enc.inverse_transform(z_num_gen)
                z_num_gen_enc = torch.from_numpy(z_num_gen_enc).long()
            else:
                z_num_gen_enc = z_num_gen
            infl_mask, miss_mask = self.encoder.get_masks(z_num_gen_enc)

            # get groups means (= inflated value if var = 0)
            z_num_gen_means = (
                self.highres.get_group_means(z_num_gen.to(self.device) + self.highres.group_offset).squeeze(-1).cpu()
            )
            x_num_gen = torch.where(infl_mask, z_num_gen_means, x_num_gen)

            if miss_mask is not None:
                x_num_gen = torch.masked_fill(x_num_gen, miss_mask, torch.nan)

        return x_cat_gen.numpy(), x_num_gen.numpy()
