from .experiment import Experiment


# from .models.cdtd import CDTD
# from .utils import set_seeds

# from .models.lowres.layers import LowResMLP
# from .utils import set_seeds
# from .models.lowres.utils import FastTensorDataLoader, cycle



from experiments.models.lowres.layers import LowResMLP
from experiments.utils import set_seeds
from experiments.models.lowres.utils import FastTensorDataLoader, cycle
from experiments.models.lowres.encoder import Discretizer
from experiments.models.lowres.lowres_diff import CatCDTD, CatFlow

from torch_ema import ExponentialMovingAverage
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
import torch
import os
import time

import pandas as pd
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt
import polars as pl
import pickle
import logging


def get_lowres_model(cfg, X_cat_trn, model='cdtd'):
    """Build a low-resolution categorical generative model.

    A ``LowResMLP`` score network is built from the per-column class counts and
    empirical class proportions of ``X_cat_trn`` and wrapped in either
    ``CatCDTD`` (``model='cdtd'``) or ``CatFlow`` (``model='flow'``).

    Args:
        cfg: model config object; its ``cat_emb_dim``, ``mlp_*``, ``sigma_*``,
            ``normalize_by_entropy``, ``timewarp_*`` and ``cat_emb_init_sigma``
            attributes are read (plus the flow-specific options for ``'flow'``).
        X_cat_trn: 2-D tensor (rows x columns) of categorical training codes.
        model: variant name, ``'cdtd'`` or ``'flow'``.

    Returns:
        The constructed ``CatCDTD`` or ``CatFlow`` module.

    Raises:
        ValueError: if ``model`` is not a supported variant.
    """
    if model not in ('cdtd', 'flow'):
        # Fail fast rather than returning the variant string to the caller.
        raise ValueError(
            f"Unknown low-res model variant {model!r}; expected 'cdtd' or 'flow'."
        )

    n_sample = X_cat_trn.shape[0]
    n_classes = []
    proportions = []
    for i in range(X_cat_trn.shape[1]):
        # NOTE(review): n_classes counts only the values observed in training;
        # this presumably assumes codes are contiguous 0..K-1 -- confirm.
        _, counts = X_cat_trn[:, i].unique(return_counts=True)
        n_classes.append(len(counts))
        proportions.append(counts / n_sample)

    score_model = LowResMLP(n_classes,
                            cfg.cat_emb_dim,
                            cfg.mlp_emb_dim,
                            cfg.mlp_n_layers,
                            cfg.mlp_n_units,
                            proportions,
                            cfg.mlp_act,
                            )

    if model == 'cdtd':
        return CatCDTD(score_model,
                       n_classes,
                       proportions,
                       cfg.cat_emb_dim,
                       cfg.sigma_min,
                       cfg.sigma_max,
                       cfg.sigma_data,
                       cfg.normalize_by_entropy,
                       cfg.timewarp_weight_low_noise,
                       cfg.timewarp_variant,
                       cfg.cat_emb_init_sigma,
                       )

    # model == 'flow'
    return CatFlow(score_model,
                   n_classes,
                   proportions,
                   cfg.cat_emb_dim,
                   cfg.sigma_min,
                   cfg.sigma_max,
                   cfg.sigma_data,
                   cfg.normalize_by_entropy,
                   cfg.timewarp_weight_low_noise,
                   cfg.timewarp_variant,
                   cat_emb_init_sigma=cfg.cat_emb_init_sigma,
                   learn_noise_schedule=cfg.learn_noise_schedule,
                   init_embs_zero=cfg.init_embs_zero,
                   learn_latents=cfg.learn_latents,
                   norm_dim=cfg.norm_dim,
                   time_reweight=cfg.time_reweight,
                   )
        
        
        

class Experiment_LowRes(Experiment):
    """Experiment driver for the low-resolution (all-categorical) model.

    Numerical features are discretized with a ``Discretizer`` and appended to
    the categorical features, so the generative model built by
    ``get_lowres_model`` only sees categorical columns.
    ``config.model.variant`` selects ``'cdtd'`` or ``'flow'``.
    """

    # Tree depth for the Discretizer. It must be identical at train and load
    # time: it determines the discretized groups, hence the per-column class
    # counts and the model architecture the checkpoint has to match.
    DISCRETIZER_MAX_DEPTH = 5

    def __init__(self, config, args):
        super().__init__(config, args)

    # ------------------------------------------------------------------ helpers

    def _make_discretizer(self, X_num_trn):
        """Fit the numerical-feature discretizer on the training numerics."""
        return Discretizer(X_num_trn, variant='dt', max_depth=self.DISCRETIZER_MAX_DEPTH)

    def _to_lowres(self, enc, X_cat, X_num):
        """Return X_cat without its trailing missingness-indicator columns, with
        the discretized groups of X_num appended as extra columns."""
        groups, _ = enc.encode(X_num)
        X = X_cat[:, :len(self.data_processor.cat_cols)]
        return torch.column_stack((X, groups))

    @staticmethod
    def _set_lr(optimizer, lr):
        """Set the learning rate of every parameter group of ``optimizer``."""
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    @staticmethod
    def _freeze(model, part='model'):
        """Toggle ``requires_grad`` by parameter name.

        part='model':     train only parameters whose name contains 'encoder'.
        part='embedding': freeze parameters whose name contains 'encoder' and
                          train everything else.
        """
        if part not in ('model', 'embedding'):
            raise ValueError(f"Unknown part to freeze: {part!r}")
        for name, param in model.named_parameters():
            is_encoder = "encoder" in name
            param.requires_grad = is_encoder if part == 'model' else not is_encoder

    def _validation_loss(self, val_loader, n_grid=50):
        """Average weighted loss over a fixed grid of ``n_grid`` times in (0, 1).

        For each t the loss is averaged over all validation rows (weighted by
        batch size); the per-t averages are then averaged. Run without autograd.
        The caller decides which weights are active (e.g. EMA weights).
        """
        t_grid = torch.linspace(1e-6, 1 - 1e-6, n_grid, dtype=torch.float32).to(self.device)
        total_val_loss = 0.0
        with torch.no_grad():
            for t in t_grid:
                n_inputs_val = 0
                loss_sum = 0.0
                for batch in val_loader:
                    x = batch[0].to(self.device)
                    B_val = len(x)
                    n_inputs_val += B_val
                    losses = self.model.loss_fn(x, t=t.repeat(B_val))
                    # .mean() matches the training-loss bookkeeping and is a no-op
                    # if 'weighted' is already a scalar.
                    loss_sum += losses['weighted'].detach().mean().item() * B_val
                total_val_loss += loss_sum / n_inputs_val
        return total_val_loss / len(t_grid)

    # ------------------------------------------------------------------ training

    def train(self, **kwargs):
        """Train the low-res model, restore the selected weights and evaluate.

        If ``config.training.use_val`` is set, the EMA weights with the lowest
        validation loss are checkpointed to ``model.pt`` and restored after
        training. Otherwise (or if no checkpoint was ever written) the final EMA
        weights are copied into the model.

        Keyword Args:
            save_model (bool): if True, record the training time and save the
                final model when no validation checkpoint exists. Default False.
        """
        save_model = kwargs.get('save_model', False)
        set_seeds(self.seed, cuda_deterministic=True)
        cfg_train = self.config.training
        variant = self.config.model.variant

        train_loader, val_loader, test_loader = self.data_processor.get_data_loaders(mean_impute=False, include_test=True)
        X_num_trn = train_loader.data[1]

        # low-resolution encodings: categorical columns + discretized numericals
        enc = self._make_discretizer(X_num_trn)
        X_trn = self._to_lowres(enc, train_loader.data[0], X_num_trn)
        X_val = self._to_lowres(enc, val_loader.data[0], val_loader.data[1])
        X_test = self._to_lowres(enc, test_loader.data[0], test_loader.data[1])

        train_loader = FastTensorDataLoader(X_trn, batch_size=cfg_train.batch_size, shuffle=True, drop_last=True)
        val_loader = FastTensorDataLoader(X_val, batch_size=val_loader.batch_size, shuffle=False)

        self.model = get_lowres_model(self.config.model, X_trn, model=variant)
        self.model.to(self.device)
        num_params = sum(p.numel() for p in self.model.parameters())
        print("Total parameters = ", num_params)
        self.ema = ExponentialMovingAverage(self.model.parameters(), decay=cfg_train.ema_decay)

        train_iter = cycle(train_loader)
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=cfg_train.lr,
                                      weight_decay=cfg_train.weight_decay, betas=cfg_train.betas)
        # Plateau scheduler is only used for the flow variant.
        scheduler = None
        if cfg_train.scheduler and variant == 'flow':
            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.9, patience=3,
                                                                   min_lr=1e-6)

        num_steps_train = cfg_train.num_steps_train
        num_steps_warmup = cfg_train.num_steps_warmup
        step = 0
        n_bad_evals = 0
        best_loss = float('inf')
        best_saved = False  # True once a validation checkpoint was written in THIS run
        n_inputs = 0
        loss_trn = 0
        loss_val = 0
        loss_hist = []
        monitor_grad_norm = bool(cfg_train.freeze_emb)
        writer = SummaryWriter(self.workdir / 'tb')
        pbar = tqdm(total=num_steps_train)
        training_start_time = time.monotonic()

        try:
            while step < num_steps_train:

                # linear warmup of the learning rate
                if step < num_steps_warmup:
                    self._set_lr(optimizer, cfg_train.lr * (step + 1) / num_steps_warmup)

                # cdtd: linear decay from lr towards 1e-6 after warmup
                if variant == 'cdtd' and step > num_steps_warmup:
                    aux_step = step - num_steps_warmup
                    rate = 1 - (aux_step / (num_steps_train - num_steps_warmup))
                    self._set_lr(optimizer, cfg_train.lr * rate + 1e-6 * (1 - rate))

                optimizer.zero_grad(set_to_none=True)
                x = next(train_iter)[0].to(self.device)
                B = len(x)
                n_inputs += B
                losses = self.model.loss_fn(x)
                losses['train_loss'].backward()

                # freeze embeddings once their gradient norm has collapsed
                if monitor_grad_norm:
                    encoder = self.model.encoder
                    emb_grad_norm = (encoder.cat_emb.weight.grad.norm().item() ** 2
                                     + encoder.cat_bias.grad.norm().item() ** 2) ** (1. / 2)
                    if (step > 0.1 * num_steps_train) and (emb_grad_norm < 0.001):
                        self._freeze(self.model, part='embedding')
                        print(f"Freezing embeddings at step {step}...")
                        monitor_grad_norm = False

                # update parameters
                if cfg_train.clip_grad:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
                optimizer.step()
                self.ema.update()
                if variant == 'cdtd':
                    self.model.timewarp.update_ema()

                loss_trn += losses['weighted'].detach().mean().item() * B

                if step % cfg_train.log_steps == 0:
                    loss_trn = loss_trn / n_inputs
                    loss_hist.append(loss_trn)

                    # validate only once training is sufficiently advanced
                    if cfg_train.use_val and step > 0.1 * num_steps_train:
                        with self.ema.average_parameters():
                            loss_val = self._validation_loss(val_loader)

                            if scheduler is not None:
                                scheduler.step(loss_val)

                            if loss_val < best_loss:
                                best_loss = loss_val
                                n_bad_evals = 0
                                self.save_model()  # inside ema.average_parameters -> EMA weights are saved
                                best_saved = True
                            else:
                                n_bad_evals += 1
                                if n_bad_evals >= cfg_train.patience:
                                    print("Early stopping at step ", step)
                                    break

                    elif scheduler is not None:
                        # without validation, drive the LR scheduler with the training loss
                        scheduler.step(loss_trn)

                    train_dict = {'train loss': loss_trn, 'val loss': loss_val}
                    pbar.set_postfix({f"train loss (last {cfg_train.log_steps} steps)": f"{loss_trn:.4f}"})
                    for metric_name, metric_value in train_dict.items():
                        writer.add_scalar(f'losses/{metric_name}', metric_value, global_step=step)
                    loss_trn = 0
                    n_inputs = 0
                step += 1
                pbar.update(1)
        finally:
            pbar.close()
            writer.close()  # flush TensorBoard events

        training_duration = time.monotonic() - training_start_time
        with open(self.workdir / 'loss_hist.pkl', 'wb') as f:
            pickle.dump(loss_hist, f)

        if best_saved:
            checkpoint = torch.load(os.path.join(self.workdir, "model.pt"), map_location=self.device)
            self.model.load_state_dict(checkpoint)
        else:
            if cfg_train.use_val:
                # Avoid crashing on a missing file or loading a stale model.pt.
                logging.warning("No validation checkpoint was saved; using final EMA weights instead.")
            self.ema.copy_to()
        self.model.eval()

        if save_model:
            self.save_train_time(training_duration)
            if not best_saved:
                self.save_model()

        self._evaluate(X_trn, X_test)

    def _evaluate(self, X_trn, X_test):
        """Quick evaluation of samples drawn from the trained model.

        For ``common_config.eval_sample_iter`` iterations, samples a dataset
        with as many rows as ``X_trn`` and computes the sdmetrics QualityReport
        shape/trend scores, a detection score, alpha-precision metrics on
        one-hot encodings and DCR. The per-iteration results are pickled to
        ``eval_results.pkl`` and their means are printed.
        """
        # Imported here: only needed for evaluation.
        from sdmetrics.reports.single_table import QualityReport
        from evaluation.eval_detection import DetectionScore
        from evaluation.eval_alphaprecision import AlphaPrecision
        from sklearn.preprocessing import OneHotEncoder
        from evaluation.eval_dcr import DCRScore

        cols = self.data_processor.cat_cols + self.data_processor.num_cols
        df_trn = pd.DataFrame(X_trn)
        df_trn.columns = cols
        df_tst = pd.DataFrame(X_test)
        df_tst.columns = cols
        df_trn_pl = pl.from_pandas(df_trn)
        df_tst_pl = pl.from_pandas(df_tst)
        detect = DetectionScore(cols, [])
        alphaprec = AlphaPrecision(cat_cols=[])
        dcr = DCRScore(df_trn_pl, df_tst_pl, cols, [])

        # The one-hot encoding of the training data does not depend on the sample.
        ohe = OneHotEncoder(sparse_output=False, handle_unknown='ignore')
        X_ohe_trn = ohe.fit_transform(df_trn)

        eval_results = []
        for i in tqdm(range(self.common_config.eval_sample_iter)):
            results = {}
            seed = (42 + (i - 1)) * i  # deterministic per-iteration seed (i=0 gives 0)
            X_gen = self.model.sample_data(X_trn.shape[0],
                                           num_steps=self.config.model.generation_steps,
                                           batch_size=self.config.model.generation_batch_size,
                                           seed=seed,
                                           verbose=False)
            df_gen = pd.DataFrame(X_gen)
            df_gen.columns = cols

            # all columns, including discretized numericals, are categorical here
            metadata = {'columns': {lab: {'sdtype': 'categorical'} for lab in cols}}
            quality_report = QualityReport()
            quality_report.generate(df_trn, df_gen, metadata, verbose=False)
            quality = quality_report.get_properties()
            results['shape'] = quality['Score'][0].item()
            results['trend'] = quality['Score'][1].item()

            # detection score
            df_gen_pl = pl.from_pandas(df_gen)
            detect_results = detect.estimate_score(df_trn_pl, df_gen_pl, drop='num', seed=seed)
            results['detection_score'] = detect_results['detection_score']

            # alpha-precision on one-hot encodings
            results.update(alphaprec.compute_metrics(X_ohe_trn, ohe.transform(df_gen)))

            dcr_res = dcr.compute_dcr(df_gen_pl, seed=seed)
            del dcr_res['dcr_raw']
            results.update(dcr_res)
            eval_results.append(results)

        if not eval_results:
            logging.info('=== No evaluation iterations requested, skipping evaluation. ===')
            return

        eval_results = {k: [d[k] for d in eval_results] for k in eval_results[0]}
        with open(self.workdir / 'eval_results.pkl', 'wb') as f:
            pickle.dump(eval_results, f)
        logging.info('=== Evaluation finished, results saved! ===')

        res = {k: np.array(v).mean().item() for k, v in eval_results.items()}
        for k, v in res.items():
            print(f'{k}: {v:.5f}')

    # ------------------------------------------------------------------ persistence

    def save_model(self):
        """Save the full model state_dict to ``<workdir>/model.pt``."""
        # for CatFlow save full model, not just score model, as we need embeddings, timewarping, etc.
        torch.save(self.model.state_dict(), os.path.join(self.workdir, "model.pt"))

    def load_model(self):
        """Rebuild the model from the training data and load ``model.pt``.

        The discretizer is refit with the same settings as in ``train`` so the
        per-column class counts, and hence the architecture, match the
        checkpoint.
        """
        set_seeds(self.seed)
        train_loader, _ = self.data_processor.get_data_loaders(mean_impute=False)
        X_num_trn = train_loader.data[1]

        enc = self._make_discretizer(X_num_trn)
        X_trn = self._to_lowres(enc, train_loader.data[0], X_num_trn)

        self.model = get_lowres_model(self.config.model, X_trn, model=self.config.model.variant)
        # Model is still on CPU here; map there so GPU-saved checkpoints load anywhere.
        checkpoint = torch.load(os.path.join(self.workdir, "model.pt"), map_location='cpu')
        self.model.load_state_dict(checkpoint)
        self.model.to(self.device)
        self.model.eval()

    def sample(self, num_samples, seed):
        """Sample ``num_samples`` low-res rows; returns ``(X_cat_gen, None)``."""
        X_cat_gen = self.model.sample_data(num_samples, num_steps=self.config.model.num_steps_sample,
                                           batch_size=self.config.model.sample_batch_size, seed=seed)
        return X_cat_gen, None  # no numerical features in low-res model