from .experiment import Experiment


# from .models.cdtd import CDTD
# from .utils import set_seeds

# from .models.lowres.layers import LowResMLP
# from .utils import set_seeds
# from .models.lowres.utils import FastTensorDataLoader, cycle

from experiments.utils import set_seeds
from experiments.models.lowres.utils import FastTensorDataLoader, cycle
from experiments.models.lowres.encoder import Discretizer
from experiments.models.lowres.lowres_diff import CatFlow, CatCDTD
from experiments.models.lowres.layers import LowResMLP
from experiments.models.highres.flow_model import HighResFlowModel
from sklearn.preprocessing import OrdinalEncoder

from torch_ema import ExponentialMovingAverage
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
import torch
import os
import time

import pandas as pd
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt
import polars as pl
import pickle
import logging


def freeze(model, part='model'):
    """Toggle ``requires_grad`` on the parameters of *model* by name.

    A parameter counts as part of the encoder/embedding when its qualified
    name (from ``model.named_parameters()``) contains the substring
    ``"encoder"``; every other parameter counts as "the model".

    Args:
        model: a ``torch.nn.Module``.
        part: which side to freeze.
            ``'model'``     -- freeze every non-encoder parameter and make the
                               encoder parameters trainable.
            ``'embedding'`` -- freeze the encoder parameters and make every
                               other parameter trainable.

    Raises:
        ValueError: if *part* is neither ``'model'`` nor ``'embedding'``.

    Note:
        Both modes overwrite ``requires_grad`` on *every* parameter, so a
        parameter frozen elsewhere is re-enabled if it falls on the trainable
        side. NOTE(review): confirm no module relies on a permanently frozen
        parameter.
    """
    if part not in ('model', 'embedding'):
        raise ValueError(f"Unknown part {part!r}; expected 'model' or 'embedding'.")

    # 'model' keeps only the encoder trainable; 'embedding' is the inverse.
    train_encoder = part == 'model'
    for name, param in model.named_parameters():
        is_encoder = "encoder" in name
        param.requires_grad = is_encoder if train_encoder else not is_encoder


class Experiment_TabCascade(Experiment):
    """Two-stage (low-/high-resolution) tabular generator experiment.

    The low-resolution model is trained on, and samples, the categorical
    features stacked column-wise with a discretized version of the numerical
    features (``z_num``, the groups produced by :class:`Discretizer`). The
    high-resolution model generates the numerical values ``x_num``
    conditioned on ``x_cat`` and ``z_num``.
    """

    def __init__(self, config, args):
        super().__init__(config, args)

    #################################################
    # helpers shared by train() and load_model()

    @staticmethod
    def _set_lr(optimizer, lr):
        """Set the learning rate of every parameter group of *optimizer*."""
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    @staticmethod
    def _class_stats(x_cat, groups):
        """Return per-column class counts and class proportions.

        Counts are the number of distinct values observed in each column of
        ``x_cat`` and ``groups``. Proportions are the per-value counts divided
        by the number of rows of ``x_cat``.

        Returns:
            (n_classes_cat, proportions_cat, n_classes_num, proportions_num)
        """
        n_sample = x_cat.shape[0]

        def column_stats(data):
            n_classes, proportions = [], []
            for i in range(data.shape[1]):
                val, counts = data[:, i].unique(return_counts=True)
                n_classes.append(len(val))
                proportions.append(counts / n_sample)
            return n_classes, proportions

        n_classes_cat, proportions_cat = column_stats(x_cat)
        n_classes_num, proportions_num = column_stats(groups)
        return n_classes_cat, proportions_cat, n_classes_num, proportions_num

    def _build_lowres(self, n_classes, proportions):
        """Construct the low-resolution model described by ``config.lowres.model``.

        The returned module is not yet moved to ``self.device``.

        Raises:
            ValueError: if ``config.lowres.model.variant`` is neither
                ``'flow'`` nor ``'cdtd'``.
        """
        cfg = self.config.lowres.model
        if cfg.variant not in ('flow', 'cdtd'):
            raise ValueError(f"Unknown lowres variant {cfg.variant!r}; expected 'flow' or 'cdtd'.")

        predictor = LowResMLP(n_classes,
                              cfg.cat_emb_dim,
                              cfg.mlp_emb_dim,
                              cfg.mlp_n_layers,
                              cfg.mlp_n_units,
                              proportions,
                              cfg.mlp_act,
                              )
        if cfg.variant == 'flow':
            return CatFlow(predictor,
                           n_classes,
                           proportions,
                           cfg.cat_emb_dim,
                           cfg.sigma_min,
                           cfg.sigma_max,
                           cfg.sigma_data,
                           cfg.normalize_by_entropy,
                           cfg.timewarp_weight_low_noise,
                           cfg.timewarp_variant,
                           cat_emb_init_sigma=cfg.cat_emb_init_sigma,
                           learn_noise_schedule=cfg.learn_noise_schedule,
                           init_embs_zero=cfg.init_embs_zero,
                           learn_latents=cfg.learn_latents,
                           norm_dim=cfg.norm_dim,
                           )
        return CatCDTD(predictor,
                       n_classes,
                       proportions,
                       cfg.cat_emb_dim,
                       cfg.sigma_min,
                       cfg.sigma_max,
                       cfg.sigma_data,
                       cfg.normalize_by_entropy,
                       cfg.timewarp_weight_low_noise,
                       cfg.timewarp_variant,
                       cfg.cat_emb_init_sigma,
                       )

    def _build_highres(self, n_classes_cat, n_classes_num):
        """Construct the high-resolution model from the encoder's group means/stds.

        The returned module is not yet moved to ``self.device``.
        """
        cfg = self.config.highres.model
        return HighResFlowModel(self.encoder.means, self.encoder.stds,
                                n_classes_cat, n_classes_num,
                                cfg.mlp_emb_dim,
                                cfg.mlp_n_layers,
                                cfg.mlp_n_units,
                                cfg.gamma_input_dim,
                                cfg.cat_emb_dim)

    def _prepare_train_data(self):
        """Fit the numerical encoder and build the tensors used for training.

        Sets ``self.encoder`` (and ``self.gmm_ord_enc`` for the ``'gmm'``
        encoder).

        Returns:
            (x_cat, x_num, groups, mask, encoder_train_time), where ``x_cat``
            has the missingness-indicator columns removed, ``x_num`` is a
            mean-imputed copy of the numerical features, ``groups`` and
            ``mask`` are the outputs of ``Discretizer.encode`` (groups
            re-indexed by an ordinal encoder for ``'gmm'``), and
            ``encoder_train_time`` is the wall time in seconds spent building
            the discretizer and encoding the data.
        """
        train_loader, _ = self.data_processor.get_data_loaders(mean_impute=False)
        X_cat_trn = train_loader.data[0]
        # Clone so the in-place mean imputation below does not modify the
        # data processor's own tensor.
        X_num_trn = train_loader.data[1].clone()

        # remove missingness indicator (used for other models) from cat features
        X_cat_trn_no_miss = X_cat_trn[:, :len(self.data_processor.cat_cols)]

        # encode X_num into Z_num (on the data before imputation)
        encoder_start_time = time.monotonic()
        self.encoder = Discretizer(X_num_trn, variant=self.config.data.encoder, seed=self.seed,
                                   k_max=self.config.data.k_max, max_depth=self.config.data.max_depth)
        groups, mask = self.encoder.encode(X_num_trn)
        encoder_train_time = time.monotonic() - encoder_start_time

        if self.config.data.encoder == 'gmm':
            # Keep only the components that actually occur in the (hard-clustered)
            # data; some components may never be the argmax for any row.
            for i in range(groups.shape[1]):
                vals = groups[:, i].unique()
                self.encoder.means[i] = self.encoder.means[i][vals]
            # NOTE(review): only ``means`` is filtered here; confirm that
            # ``self.encoder.stds`` stays aligned with it, since both are passed
            # to HighResFlowModel.

            # Re-index the surviving components contiguously per column.
            self.gmm_ord_enc = OrdinalEncoder()
            groups = torch.from_numpy(self.gmm_ord_enc.fit_transform(groups.numpy())).long()

        # mean impute numerical values, column by column
        for i in range(X_num_trn.shape[1]):
            mean = torch.nanmean(X_num_trn[:, i])
            X_num_trn[:, i] = torch.nan_to_num(X_num_trn[:, i], nan=mean)

        return X_cat_trn_no_miss, X_num_trn, groups, mask, encoder_train_time

    #################################################
    # public API

    def train(self, **kwargs):
        """Train the low- and high-resolution models on the same mini-batches.

        Each model has its own AdamW optimizer, EMA, linear LR warmup and
        ``ReduceLROnPlateau`` scheduler. The low-resolution ``'cdtd'`` variant
        uses a linear decay after warmup instead of the plateau scheduler.
        After training, the EMA weights are copied into both models and they
        are switched to eval mode. Per-log-step loss histories are pickled to
        ``self.workdir``.

        Keyword Args:
            save_model (bool): if True, also save the training time and the
                model weights (default False).
        """
        save_model = kwargs.get('save_model', False)
        set_seeds(self.seed, cuda_deterministic=True)

        #################################################
        # data: categorical features + discretized numerical features
        X_cat_trn, X_num_trn, groups, num_mask, encoder_train_time = self._prepare_train_data()
        train_loader = FastTensorDataLoader(X_cat_trn, X_num_trn, groups, num_mask,
                                            batch_size=self.config.data.batch_size,
                                            shuffle=True, drop_last=True)
        n_classes_cat, proportions_cat, n_classes_num, proportions_num = \
            self._class_stats(X_cat_trn, groups)

        #################################################
        # setup low resolution model
        lo_train = self.config.lowres.training
        is_flow = self.config.lowres.model.variant == 'flow'
        is_cdtd = self.config.lowres.model.variant == 'cdtd'

        self.lowres = self._build_lowres(n_classes_cat + n_classes_num,
                                         proportions_cat + proportions_num)
        self.lowres.to(self.device)
        self.ema_lowres = ExponentialMovingAverage(self.lowres.parameters(), decay=lo_train.ema_decay)
        opt_lowres = torch.optim.AdamW(self.lowres.parameters(),
                                       lr=lo_train.lr,
                                       weight_decay=lo_train.weight_decay,
                                       betas=lo_train.betas)
        scheduler_lowres = None
        if is_flow:
            scheduler_lowres = torch.optim.lr_scheduler.ReduceLROnPlateau(opt_lowres, mode='min',
                                                                          factor=0.9, patience=3,
                                                                          min_lr=1e-6)

        #################################################
        # setup high resolution model
        hi_train = self.config.highres.training
        self.highres = self._build_highres(n_classes_cat, n_classes_num)
        self.highres.to(self.device)
        num_params_lowres = sum(p.numel() for p in self.lowres.parameters())
        num_params_highres = sum(p.numel() for p in self.highres.parameters())
        print("Total parameters =", num_params_highres + num_params_lowres)
        print("Lowres parameters =", num_params_lowres)
        print("Highres parameters =", num_params_highres)

        self.ema_highres = ExponentialMovingAverage(self.highres.parameters(), decay=hi_train.ema_decay)
        opt_highres = torch.optim.AdamW(self.highres.parameters(),
                                        lr=hi_train.lr,
                                        weight_decay=hi_train.weight_decay,
                                        betas=hi_train.betas)
        scheduler_highres = torch.optim.lr_scheduler.ReduceLROnPlateau(opt_highres, mode='min',
                                                                       factor=0.9, patience=3,
                                                                       min_lr=1e-6)

        #################################################
        # train loop
        batches = cycle(train_loader)
        step = 0
        n_inputs = 0
        lowres_loss_trn = 0
        highres_loss_trn = 0
        lowres_loss_hist = []
        highres_loss_hist = []
        monitor_grad_norm = bool(lo_train.freeze_emb)
        writer = SummaryWriter(self.workdir / 'tb')
        pbar = tqdm(total=lo_train.num_steps_train)
        training_start_time = time.monotonic()

        try:
            while step < lo_train.num_steps_train:

                # low-res LR: linear warmup, then (CDTD only) linear decay towards 1e-6
                if step < lo_train.num_steps_warmup:
                    self._set_lr(opt_lowres, lo_train.lr * (step + 1) / lo_train.num_steps_warmup)
                if is_cdtd and step > lo_train.num_steps_warmup:
                    aux_step = step - lo_train.num_steps_warmup
                    rate = 1 - (aux_step / (lo_train.num_steps_train - lo_train.num_steps_warmup))
                    self._set_lr(opt_lowres, lo_train.lr * rate + 1e-6 * (1 - rate))

                # high-res LR: linear warmup of the *high-res* optimizer
                if step < hi_train.num_steps_warmup:
                    self._set_lr(opt_highres, hi_train.lr * (step + 1) / hi_train.num_steps_warmup)

                opt_lowres.zero_grad(set_to_none=True)
                opt_highres.zero_grad(set_to_none=True)

                batch = next(batches)
                x_cat, x_num, z_num, mask = (x.to(self.device) for x in batch)
                B = len(x_cat)
                n_inputs += B

                ################################
                # lowres model
                losses = self.lowres.loss_fn(torch.column_stack((x_cat, z_num)))
                train_loss_lowres = losses['train_loss']
                train_loss_lowres.backward()

                # Freeze the embeddings once the gradient norm of the embedding
                # weights and bias drops below 0.01 (only after the first 10% of
                # the training steps).
                if monitor_grad_norm:
                    lowres_enc = self.lowres.encoder
                    emb_grad_norm = (lowres_enc.cat_emb.weight.grad.norm().item() ** 2
                                     + lowres_enc.cat_bias.grad.norm().item() ** 2) ** 0.5
                    if (step > 0.1 * lo_train.num_steps_train) and (emb_grad_norm < 0.01):
                        freeze(self.lowres, part='embedding')
                        print(f"Freezing embeddings at step {step}...")
                        monitor_grad_norm = False

                if lo_train.clip_grad:
                    torch.nn.utils.clip_grad_norm_(self.lowres.parameters(), max_norm=1.0)
                opt_lowres.step()
                self.ema_lowres.update()
                lowres_loss_trn += train_loss_lowres.detach().item() * B

                ################################
                # highres model
                train_loss_highres = self.highres.loss_fn(x_num, x_cat, z_num, mask)
                train_loss_highres.backward()

                if hi_train.clip_grad:
                    torch.nn.utils.clip_grad_norm_(self.highres.parameters(), max_norm=1.0)
                opt_highres.step()
                self.ema_highres.update()
                highres_loss_trn += train_loss_highres.detach().item() * B

                ################################
                # bookkeeping and learning rate scheduling
                if step % lo_train.log_steps == 0:
                    # per-sample average losses since the last log step
                    lowres_loss_trn = lowres_loss_trn / n_inputs
                    highres_loss_trn = highres_loss_trn / n_inputs
                    lowres_loss_hist.append(lowres_loss_trn)
                    highres_loss_hist.append(highres_loss_trn)

                    if scheduler_lowres is not None:
                        scheduler_lowres.step(lowres_loss_trn)
                    scheduler_highres.step(highres_loss_trn)

                    train_dict = {'lowres': lowres_loss_trn, 'highres': highres_loss_trn}
                    pbar.set_postfix({"loss (lowres)": f"{lowres_loss_trn:.4f}",
                                      "loss (highres)": f"{highres_loss_trn:.4f}"})
                    for metric_name, metric_value in train_dict.items():
                        writer.add_scalar(f'losses/{metric_name}', metric_value, global_step=step)
                    lowres_loss_trn = 0
                    highres_loss_trn = 0
                    n_inputs = 0

                step += 1
                pbar.update(1)
        finally:
            # always release the progress bar and flush/close the event file
            pbar.close()
            writer.close()

        training_duration = time.monotonic() - training_start_time
        with open(self.workdir / 'loss_hist_lowres.pkl', 'wb') as f:
            pickle.dump(lowres_loss_hist, f)
        with open(self.workdir / 'loss_hist_highres.pkl', 'wb') as f:
            pickle.dump(highres_loss_hist, f)

        # copy EMA weights to the models
        self.ema_lowres.copy_to()
        self.lowres.eval()
        self.ema_highres.copy_to()
        self.highres.eval()

        if save_model:
            self.save_train_time(training_duration + encoder_train_time)
            self.save_model()

    def save_model(self):
        """Save both model state dicts (and, for ``'gmm'``, the encoders) to ``self.workdir``.

        The decision-tree encoder is not saved (it uses an R function for
        disttree); ``load_model`` re-fits it instead.
        """
        torch.save(self.lowres.state_dict(), os.path.join(self.workdir, "lowres.pt"))
        torch.save(self.highres.state_dict(), os.path.join(self.workdir, "highres.pt"))
        if self.config.data.encoder == 'gmm':
            torch.save(self.encoder, os.path.join(self.workdir, "encoder.pt"))
            torch.save(self.gmm_ord_enc, os.path.join(self.workdir, "gmm_ord_enc.pt"))

    def load_model(self):
        """Rebuild both models from config and load their weights from ``self.workdir``.

        For the ``'gmm'`` encoder the encoder and ordinal encoder are loaded
        from disk; otherwise the discretizer is re-fitted on the training data
        with the experiment seed. Also pickles the output of
        ``self.highres.plot_gamma`` (``g_mean.pkl``, ``g_var.pkl``,
        ``g_t_grid.pkl``) for plotting the learned noise schedule.
        """
        set_seeds(self.seed)

        # additional data pre-processing to get z_num
        train_loader, _ = self.data_processor.get_data_loaders(mean_impute=False)
        X_cat_trn = train_loader.data[0]
        X_num_trn = train_loader.data[1]

        # remove missingness indicator (used for other models) from cat features
        X_cat_trn_no_miss = X_cat_trn[:, :len(self.data_processor.cat_cols)]

        # encode X_num into Z_num
        if self.config.data.encoder == 'gmm':
            # full unpickling (weights_only=False): only load from a trusted workdir
            self.encoder = torch.load(os.path.join(self.workdir, "encoder.pt"), weights_only=False)
            self.gmm_ord_enc = torch.load(os.path.join(self.workdir, "gmm_ord_enc.pt"), weights_only=False)
        else:
            self.encoder = Discretizer(X_num_trn, variant=self.config.data.encoder, seed=self.seed,
                                       k_max=self.config.data.k_max, max_depth=self.config.data.max_depth)
        groups, _ = self.encoder.encode(X_num_trn)

        if self.config.data.encoder == 'gmm':
            # NOTE(review): the saved encoder had its means filtered in train();
            # confirm encode() still yields the component indices that
            # gmm_ord_enc was fitted on.
            groups = torch.from_numpy(self.gmm_ord_enc.transform(groups.numpy())).long()

        n_classes_cat, proportions_cat, n_classes_num, proportions_num = \
            self._class_stats(X_cat_trn_no_miss, groups)

        #########################
        # low resolution model
        self.lowres = self._build_lowres(n_classes_cat + n_classes_num,
                                         proportions_cat + proportions_num)
        # map to CPU first so checkpoints saved on GPU also load on CPU-only hosts
        checkpoint = torch.load(os.path.join(self.workdir, "lowres.pt"), map_location='cpu')
        self.lowres.load_state_dict(checkpoint)
        self.lowres.to(self.device)
        self.lowres.eval()

        #########################
        # high resolution model
        self.highres = self._build_highres(n_classes_cat, n_classes_num)
        checkpoint = torch.load(os.path.join(self.workdir, "highres.pt"), map_location='cpu')
        self.highres.load_state_dict(checkpoint)
        self.highres.to(self.device)
        self.highres.eval()

        # save data for plotting of learned noise schedule
        x_cat = X_cat_trn_no_miss.to(self.device)
        z_num = groups.to(self.device)
        t_grid, g = self.highres.plot_gamma(x_cat, z_num)
        with open(self.workdir / 'g_mean.pkl', 'wb') as f:
            pickle.dump(g.mean(1), f)
        with open(self.workdir / 'g_var.pkl', 'wb') as f:
            pickle.dump(g.var(1), f)
        with open(self.workdir / "g_t_grid.pkl", 'wb') as f:
            pickle.dump(t_grid, f)
        del x_cat
        del z_num

    def sample(self, num_samples, seed):
        """Generate ``num_samples`` synthetic rows.

        Low-resolution data (categorical features plus numerical groups) is
        sampled first. The numerical values are then sampled by the
        high-resolution model conditioned on it. Entries flagged by
        ``self.encoder.get_masks`` as inflated are replaced by their group
        mean, and entries flagged as missing are set to NaN.

        Returns:
            (X_cat_gen, X_num_gen) as numpy arrays.
        """
        n_cat = len(self.data_processor.cat_cols)

        # first sample low resolution information
        X_low_gen = self.lowres.sample_data(num_samples, num_steps=self.config.lowres.model.generation_steps,
                                            batch_size=self.config.lowres.model.generation_batch_size, seed=seed,
                                            verbose=False)
        X_cat_gen = X_low_gen[:, :n_cat]
        Z_num_gen = X_low_gen[:, n_cat:]

        # then sample high resolution information conditioned on low resolution
        X_num_gen = self.highres.sample_data(X_cat_gen, Z_num_gen,
                                             num_steps=self.config.highres.model.generation_steps,
                                             batch_size=self.config.highres.model.generation_batch_size,
                                             seed=seed, verbose=False)

        # overwrite inflated / missing values in X_num using Z_num
        assert X_num_gen.shape == Z_num_gen.shape
        if self.config.data.encoder == 'gmm':
            # map contiguous group ids back to the encoder's original component ids
            Z_num_gen_enc = torch.from_numpy(self.gmm_ord_enc.inverse_transform(Z_num_gen.numpy())).long()
        else:
            Z_num_gen_enc = Z_num_gen
        infl_mask, miss_mask = self.encoder.get_masks(Z_num_gen_enc)

        # get groups means (= inflated value if var = 0)
        Z_num_gen_means = self.highres.get_group_means(
            Z_num_gen.to(self.device) + self.highres.group_offset).squeeze(-1).cpu()
        X_num_gen = torch.where(infl_mask, Z_num_gen_means, X_num_gen)

        if miss_mask is not None:
            X_num_gen = torch.masked_fill(X_num_gen, miss_mask, torch.nan)

        return X_cat_gen.numpy(), X_num_gen.numpy()
