import logging
import os
import pickle
import time
from abc import ABC, abstractmethod
from datetime import datetime
from pathlib import Path

import torch
from omegaconf import OmegaConf

from data.data_prepm import DataWrangler
import json


class Experiment(ABC):
    """Abstract base class for one generative-model run on a tabular dataset.

    Construction loads ``configs/common_config.yaml``, creates the run
    directory and builds a ``DataWrangler`` for the dataset. Subclasses supply
    the model-specific parts: ``train``, ``sample_tabular_data``,
    ``save_model`` and ``load_model``.
    """

    # Upper bound applied to ``num_samples``; subclasses may override it.
    MAX_NUM_SAMPLES = 50000

    def __init__(
        self,
        data_path,
        sample_path,
        config,
        exp_path,
        dataset,
        device,
        preproc,
        run_name=None,
        beta='0p7',
        use_log='switch',
    ):
        """Set up paths, the run directory and the data wrangler.

        Args:
            data_path: Dataset directory; ``evaluate_generative_model`` reads
                ``info.json`` from it.
            sample_path: Destination handed to ``DataWrangler.save_data`` for
                generated samples.
            config: Model configuration object. Its optional ``model_name``
                attribute selects the sub-directory of ``configs`` stored in
                ``config_dir`` (``"cdtd"`` when the attribute is absent).
            exp_path: Suffix used to build the default run name; a timestamp
                is used instead when it is falsy.
            dataset: Dataset name, used in the default run name and passed to
                ``DataWrangler``.
            device: Stored on the instance as ``self.device``.
            preproc: Forwarded unchanged to ``DataWrangler``.
            run_name: Name of the results sub-directory. When ``None`` or
                empty it defaults to ``"{dataset}_{exp_path or timestamp}"``.
            beta: Forwarded unchanged to ``DataWrangler``.
            use_log: Forwarded unchanged to ``DataWrangler``.
        """
        # NOTE: relative path, so this resolves against the current working directory.
        self.common_config = OmegaConf.load("configs/common_config.yaml")
        self.device = device
        self.seed = self.common_config.seed
        self.config = config
        self.data_path = data_path
        self.sample_path = sample_path
        self.preproc = preproc
        self.beta = beta
        self.use_log = use_log
        self.dataset = dataset

        if not run_name:
            suffix = exp_path or datetime.now().strftime('%Y%m%d_%H%M%S')
            run_name = f"{dataset}_{suffix}"
        self.run_name = run_name

        # Results live in "<two directories above this file>/results/<run_name>".
        self.run_dir = os.path.join(Path(__file__).parent.parent, "results", run_name)
        # Logs and checkpoints share the run directory.
        self.logdir = self.run_dir
        self.ckpt_restore_dir = self.run_dir
        self.config_dir = os.path.join("configs", getattr(self.config, "model_name", "cdtd"))
        self.create_folders()

        logging.warning(f"Initializing {dataset} dataset")
        self.data_wrangler = DataWrangler(
            self.dataset,
            self.data_path,
            self.sample_path,
            self.logdir,
            self.config,
            self.common_config.val_prop,
            self.common_config.test_prop,
            self.seed,
            self.preproc,
            beta=beta,
            use_log=use_log,
        )
        data = self.data_wrangler.data
        num_real_train_samples = data.get_train_obs() + data.get_val_obs()
        # Train + validation observation count, capped at MAX_NUM_SAMPLES.
        self.num_samples = min(num_real_train_samples, self.MAX_NUM_SAMPLES)
        logging.info(f"Done init: {self.data_wrangler.data.get_total_obs()} obs, {self.data_wrangler.num_total_features} features")

    @abstractmethod
    def train(self, **kwargs):
        """Train the generative model (implemented by subclasses)."""
        ...

    @abstractmethod
    def sample_tabular_data(self, num_samples, **kwargs):
        """Generate ``num_samples`` synthetic rows (implemented by subclasses).

        ``evaluate_generative_model`` unpacks the result as the triple
        ``(X_cat_gen, X_cont_gen, y_gen)`` and passes ``seed`` as a keyword.
        """
        ...

    @abstractmethod
    def save_model(self):
        """Persist the trained model (implemented by subclasses)."""
        ...

    @abstractmethod
    def load_model(self):
        """Restore a previously saved model (implemented by subclasses)."""
        ...

    def evaluate_generative_model(self):
        """Load model, generate samples, save to sample_path.

        The number of samples is the ``train_num`` entry of
        ``<data_path>/info.json``. Evaluation (density, mle, quality) is done
        by evaluator.py (synthcity).

        Raises:
            FileNotFoundError: If ``info.json`` is missing from ``data_path``.
            KeyError: If ``info.json`` has no ``train_num`` entry.
        """
        logging.warning("Load generative model...")
        self.load_model()

        info_path = os.path.join(self.data_path, "info.json")
        # Read as UTF-8 explicitly so decoding does not depend on the locale.
        with open(info_path, "r", encoding="utf-8") as f:
            info = json.load(f)
        sample_num = info["train_num"]

        # NOTE(review): sampling uses a fixed seed of 42 rather than self.seed;
        # confirm this is intentional.
        X_cat_gen, X_cont_gen, y_gen = self.sample_tabular_data(sample_num, seed=42)
        self.data_wrangler.save_data(X_cat_gen, X_cont_gen, y_gen, sample_path=self.sample_path)
        logging.info(f"Saved {sample_num} samples to {self.sample_path}")

    def create_folders(self):
        """Create the run directory, including parents, if it does not exist."""
        # exist_ok avoids the check-then-create race between concurrent runs.
        os.makedirs(self.run_dir, exist_ok=True)

    def save_train_time(self, duration):
        """Save training time in minutes.

        Args:
            duration: Training time in seconds; it is divided by 60 before
                being pickled to ``train_time.pkl`` in ``logdir``.
        """
        with open(os.path.join(self.logdir, "train_time.pkl"), "wb") as f:
            pickle.dump(duration / 60, f)

    def save_sample_time(self, duration):
        """Save sampling time in seconds.

        Args:
            duration: Sampling time in seconds, pickled as-is to
                ``sample_time.pkl`` in ``logdir``.
        """
        with open(os.path.join(self.logdir, "sample_time.pkl"), "wb") as f:
            pickle.dump(duration, f)
