import os
from collections import defaultdict
from itertools import product
import json
from typing import Type, List
from models.abstract_model import AbstractModel
from counterfactual_explanations.input_properties import InputProperties
from counterfactual_explanations.cf_generator import CounterfactualGenerator
import numpy as np
from pathlib import Path

class ModelFactory:
    """Create one model per hyper-parameter combination and repeat.

    ``config_multi`` maps each swept parameter name to an iterable of
    candidate values. The Cartesian product of those values is merged with
    the fixed ``config`` to obtain the configuration of each model.
    """

    def __init__(self, Model: Type[AbstractModel], input_properties: InputProperties, config: dict, config_multi: dict):
        """Store the model class and the configuration grid.

        Args:
            Model: Model class; instantiated as ``Model(config, input_properties)``.
            input_properties: Passed unchanged to every model constructor.
            config: Fixed settings shared by all models. On a key clash these
                override the swept values from ``config_multi``.
            config_multi: Mapping from parameter name to the values to sweep.
        """
        self.Model = Model
        self.input_properties = input_properties
        self.config = config
        self.config_multi = config_multi
        # Defined up front so the getters work (returning empty containers)
        # even before train_models() has been called.
        self.models = []
        self.models_over_repeat = defaultdict(list)

    def train_models(self, X_train: np.ndarray, y_train: np.ndarray, n_repeats: int, save_dir: Path, use_pretrained: bool=True):
        """Build every model of the grid, ``n_repeats`` times each.

        Each model receives its configuration plus a ``"random_state"`` entry
        that starts at 0 and grows by one per created model, so within one
        call no two models get the same value. The model is then asked to
        ``load_or_train`` itself with the path
        ``save_dir/<ModelClassName>_<config as compact JSON>/repeat<i>``.

        Args:
            X_train: Training inputs, forwarded to ``load_or_train``.
            y_train: Training targets, forwarded to ``load_or_train``.
            n_repeats: Number of models to create per configuration.
            save_dir: Root directory for the models; created if missing.
                A plain string path is accepted as well.
            use_pretrained: Forwarded to ``load_or_train`` (presumably allows
                reusing a previously saved model; confirm against the model
                implementation).

        Returns:
            Flat list of all models in creation order. The list and the
            per-configuration grouping are also stored on the instance,
            replacing the results of any earlier call.
        """
        # Normalise first: the "/" joins below require a Path, not a str.
        save_dir = Path(save_dir)
        os.makedirs(save_dir, exist_ok=True)
        models = []
        models_over_repeat = defaultdict(list)
        random_state = 0

        for vals in product(*self.config_multi.values()):
            config = dict(zip(self.config_multi.keys(), vals))
            # Right-hand operand wins: fixed settings override swept ones.
            config = config | self.config
            config_str = json.dumps(config, separators=(',', ':'))
            model_path_name = save_dir / f"{self.Model.__name__}_{config_str}"

            for repeat in range(n_repeats):
                repeat_path = model_path_name / f"repeat{repeat}"
                model = self.Model(config | {"random_state": random_state}, self.input_properties)
                model.load_or_train(repeat_path, X_train, y_train, use_pretrained)
                models_over_repeat[model_path_name].append(model)
                models.append(model)

                random_state += 1

        self.models = models
        self.models_over_repeat = models_over_repeat

        return models

    def get_models(self):
        """Return the flat list of models from the last ``train_models`` call.

        The list is empty if ``train_models`` has not been called yet.
        """
        return self.models

    def get_models_over_repeats(self):
        """Return the models grouped by configuration directory.

        The mapping goes from the per-configuration path to the list of its
        repeats; it is empty if ``train_models`` has not been called yet.
        """
        return self.models_over_repeat
    

# class ExternalModelFactory:
#     def __init__(self, models):
#         self.models = models

#     def train_models(self, X_train, y_train, n_repeats, save_dir, use_pretrained=True):
#         return self.models
    
#     def get_models(self):
#         return self.models
    
#     def get_models_over_repeats(self):
#         models_over_repeat = []
#         for model in self.models:
#             models_over_repeat.append([model])
#         return models_over_repeat

class GeneratorFactory:
    """Instantiate counterfactual generators over a grid of configurations."""

    def __init__(self, generators_classes: List[Type[CounterfactualGenerator]], config: dict, config_multi: dict):
        """Store the generator classes and prepare the configuration grid.

        A ``config_multi`` value that is itself a dict is treated as a nested
        grid: it is expanded into the list of all dicts obtained by choosing
        one value per nested key. Other values are used as given.

        Args:
            generators_classes: Generator classes to instantiate for every
                configuration.
            config: Fixed settings shared by all generators. On a key clash
                these override the swept values from ``config_multi``.
            config_multi: Mapping from parameter name to the values to sweep
                (or to a nested grid, see above). It is not modified.
        """
        self.generators_classes = generators_classes
        self.config = config
        # Build a separate mapping instead of rewriting entries of the
        # caller's dict, which may be reused elsewhere.
        self.config_multi = {
            key: self._expand_nested_grid(val) if isinstance(val, dict) else val
            for key, val in config_multi.items()
        }
        # Defined up front so get_generators() works before setup_generators().
        self.generators = []

    @staticmethod
    def _expand_nested_grid(grid: dict) -> list:
        """Return every dict formed by picking one value per key of ``grid``."""
        return [dict(zip(grid.keys(), combo)) for combo in product(*grid.values())]

    def setup_generators(self, model: AbstractModel, input_properties: InputProperties, X_train: np.ndarray, y_train: np.ndarray, X_calib: np.ndarray, y_calib: np.ndarray, save_dir: Path, use_pretrained: bool=True):
        """Create and set up one generator per class and configuration.

        For each combination of swept values (merged with the fixed config),
        every class in ``generators_classes`` is instantiated as
        ``cls(model, input_properties, config, save_dir, use_pretrained)`` and
        its ``setup(X_train, y_train, X_calib, y_calib)`` is called.

        Args:
            model: Model handed to every generator.
            input_properties: Forwarded unchanged to every generator.
            X_train: Training inputs, forwarded to ``setup``.
            y_train: Training targets, forwarded to ``setup``.
            X_calib: Calibration inputs, forwarded to ``setup``.
            y_calib: Calibration targets, forwarded to ``setup``.
            save_dir: Forwarded to every generator constructor.
            use_pretrained: Forwarded to every generator constructor.

        Returns:
            List of generators ordered by configuration, then by class. The
            list is also stored on the instance, replacing earlier results.
        """
        generators = []

        for vals in product(*self.config_multi.values()):
            config = dict(zip(self.config_multi.keys(), vals))
            # Right-hand operand wins: fixed settings override swept ones.
            config = config | self.config

            # Note: all classes of one combination receive the same dict object.
            for generator_cls in self.generators_classes:
                generator = generator_cls(model, input_properties, config, save_dir, use_pretrained)
                generator.setup(X_train, y_train, X_calib, y_calib)
                generators.append(generator)

        self.generators = generators
        return self.generators

    def get_generators(self):
        """Return the generators from the last ``setup_generators`` call.

        The list is empty if ``setup_generators`` has not been called yet.
        """
        return self.generators