import glob
import os
import time

import numpy as np
import torch

from data.fast_dataloader import FastTensorDataLoader
from experiments.models.tabbyflow.ef_vfm.models.flow_model import ExpVFM
from experiments.models.tabbyflow.ef_vfm.modules.main_modules import UniModMLP
from experiments.models.tabbyflow.ef_vfm.trainer import Trainer
from utils import set_seeds

from .experiment import Experiment


class Experiment_TabbyFlow(Experiment):
    """Experiment wrapper that trains and samples an EF-VFM ("TabbyFlow") model.

    The vector field is a ``UniModMLP`` wrapped in an ``ExpVFM`` flow model.
    Training is delegated to the EF-VFM ``Trainer``. ``sample`` reads
    ``self.flow_model``, which is only set by ``load_model``, so ``load_model``
    must be called before sampling.
    """

    def __init__(self, config, args):
        super().__init__(config, args)

    def _build_flow_model(self, d_numerical, categories):
        """Construct the UniModMLP vector field wrapped in an ExpVFM on ``self.device``.

        Shared by ``train`` and ``load_model`` so both build exactly the same
        architecture from ``self.config.model`` (required for checkpoints saved
        during training to load back correctly).

        Args:
            d_numerical: Number of numerical columns.
            categories: Per-column class counts for the categorical columns, as
                reported by ``self.data_processor.X_cat_n_classes``.

        Returns:
            The ``ExpVFM`` instance, already moved to ``self.device``.
        """
        model_cfg = self.config.model
        vf_fn = UniModMLP(
            d_numerical,
            categories,
            model_cfg.n_layers,
            model_cfg.n_units,
            model_cfg.d_token,
            n_head=model_cfg.n_head,
            factor=model_cfg.factor,
            bias=model_cfg.bias,
            dim_t=model_cfg.dim_t,
            use_mlp=model_cfg.use_mlp,
        )
        vf_fn.to(self.device)
        flow_model = ExpVFM(
            num_classes=np.array(categories),
            num_numerical_features=d_numerical,
            vf_fn=vf_fn,
            device=self.device,
        )
        flow_model.to(self.device)
        return flow_model

    def train(self, **kwargs):
        """Build a fresh flow model and train it with the EF-VFM ``Trainer``.

        ``self.workdir`` is passed to the Trainer as both model and result save
        path; ``load_model`` later looks there for ``best_ema_model_*`` files.
        The wall-clock training time is recorded via ``save_train_time``.

        Args:
            **kwargs: Accepted for interface compatibility; unused.

        Raises:
            ValueError: If the training set is smaller than one batch, in which
                case ``drop_last=True`` would leave no training steps at all.
        """
        set_seeds(self.seed, cuda_deterministic=True)

        d_numerical = len(self.data_processor.num_cols)
        train_loader, _ = self.data_processor.get_data_loaders()
        categories = self.data_processor.X_cat_n_classes

        # Re-pack the data as one tensor: data[1] columns first, then data[0].
        # NOTE(review): presumably data[1] holds numerical and data[0] categorical
        # features, matching how `sample` splits generated rows (numerical
        # prefix first) — confirm against the data processor.
        x = torch.cat((train_loader.data[1], train_loader.data[0]), dim=1)
        train_loader = FastTensorDataLoader(x, batch_size=train_loader.batch_size, shuffle=True, drop_last=True)

        flow_model = self._build_flow_model(d_numerical, categories)
        flow_model.train()
        num_params = sum(p.numel() for p in flow_model.parameters())
        print("The number of parameters = ", num_params)

        # Translate the configured number of optimizer steps into epochs. With
        # drop_last=True the trailing partial batch is skipped, so one epoch is
        # floor(dataset_len / batch_size) steps, not the fractional quotient.
        steps_per_epoch = train_loader.dataset_len // train_loader.batch_size
        if steps_per_epoch == 0:
            raise ValueError(
                f"Training set has {train_loader.dataset_len} rows, fewer than one batch "
                f"({train_loader.batch_size}); with drop_last=True no training step would run."
            )
        # Always run at least one epoch, even if train_steps < steps_per_epoch / 2.
        num_epochs = max(1, round(self.config.train.train_steps / steps_per_epoch))

        trainer = Trainer(
            flow_model,
            train_loader,
            None,
            None,
            None,
            None,
            self.config.train.lr,
            self.config.train.weight_decay,
            num_epochs,
            self.config.train.batch_size,
            # Interval larger than the run length; presumably this means
            # periodic validation never triggers — confirm in Trainer.
            check_val_every=num_epochs + 1,
            sample_batch_size=self.config.train.batch_size,
            num_samples_to_generate=self.num_samples,
            model_save_path=self.workdir,
            result_save_path=self.workdir,
            device=self.device,
            ckpt_path=None,
        )

        train_start_time = time.monotonic()
        trainer.run_loop()
        training_duration = time.monotonic() - train_start_time
        self.save_train_time(training_duration)

    @torch.inference_mode()
    def sample(self, num_samples, seed):
        """Generate ``num_samples`` synthetic rows with the loaded flow model.

        Requires ``load_model`` to have been called (reads ``self.flow_model``).

        Args:
            num_samples: Number of rows to generate.
            seed: Seed passed to ``set_seeds`` before sampling. ``None`` leaves
                the RNG state untouched; ``0`` is a valid seed and is applied.

        Returns:
            Tuple ``(x_cat_gen, x_cont_gen)`` of numpy arrays: the categorical
            columns cast to ``int`` first, then the numerical columns.
        """
        if seed is not None:
            set_seeds(seed, cuda_deterministic=True)

        # The inference_mode decorator already disables autograd tracking.
        syn_data = self.flow_model.sample_all(num_samples, batch_size=self.config.train.batch_size)
        n_num = self.flow_model.num_numerical_features
        x_cont_gen = syn_data[:, :n_num].cpu().numpy()
        x_cat_gen = syn_data[:, n_num:].cpu().numpy().astype(int)

        return x_cat_gen, x_cont_gen

    def save_model(self):
        """No-op: checkpoints are handed to the Trainer via ``model_save_path``."""
        return

    def load_model(self):
        """Rebuild the flow model and load the ``best_ema_model_*`` checkpoint.

        Sets ``self.flow_model`` (in eval mode) only after the weights have been
        loaded successfully.

        Raises:
            FileNotFoundError: If no ``best_ema_model_*`` file exists in
                ``self.workdir``.
        """
        set_seeds(self.seed, cuda_deterministic=True)

        d_numerical = len(self.data_processor.num_cols)
        # Result unused; presumably called for its side effects (e.g. populating
        # X_cat_n_classes read below) — confirm against the data processor.
        self.data_processor.get_data_loaders()
        categories = self.data_processor.X_cat_n_classes

        flow_model = self._build_flow_model(d_numerical, categories)

        # glob order is filesystem-dependent; sort so the pick is deterministic
        # if several checkpoints match.
        ckpt_paths = sorted(glob.glob(os.path.join(self.workdir, "best_ema_model_*")))
        if not ckpt_paths:
            raise FileNotFoundError(f"No 'best_ema_model_*' checkpoint found in {self.workdir}")
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        # torch.load unpickles its input, so only point it at trusted directories.
        state_dicts = torch.load(ckpt_paths[0], map_location=self.device)
        flow_model._vf_fn.load_state_dict(state_dicts["vf_fn"])
        flow_model.eval()
        self.flow_model = flow_model
