
from sklearn import datasets
from models.cara import CARA, CARAOracle

from src.models.adwin import ADWIN
from src.models.fhddm import FHDDM
from src.models.kswin import KSWIN
from src.models.oracle import Oracle
from src.models.performance_forecaster import PerformanceForecaster
from src.models.pilot import Pilot
from src.retraining_task_processing.synthetic_dataset import load_synthetic_dataset
from src.seq_dataset import SeqDataset
from src.retraining_task_processing.temporal_dataset import load_temporal_dataset

from retraining_task_processing.get_sequence import get_sequence
from src.models.base_models import (
    Random,
    FixedRetrain
)
from sklearn import datasets


def make_model(cfg_model, cfg_data, cfg_exp, trial_index=None):
    """Instantiate the model selected by ``cfg_model["name"]``.

    Args:
        cfg_model: Model config. ``"name"`` selects the model; further keys are
            read per model (``"variant"``, ``"features"``, ``"unc"``, ``"var"``,
            ``"loss_is_01"``, ``"significance"``, ``"schedule"``,
            ``"relative_pe"``, ``"discounted_cost"``).
        cfg_data: Dataset config. ``"T"`` and ``"offline_t"`` are passed to
            every model; ``"name"`` switches wild-specific behaviour and
            ``"cara_store_dir"`` is read for CARA on the ``wild`` dataset.
        cfg_exp: Experiment config; ``"relative_pe"`` is read for the ``gp``
            and ``cara`` models.
        trial_index: Forwarded to ``CARAWild`` only.

    Returns:
        The constructed model instance.

    Raises:
        ValueError: If ``cfg_model["name"]`` does not match a known model.
    """
    # Only the "wild" dataset sets this flag; it is forwarded to ADWIN/FHDDM/KSWIN.
    data_in_dataloader = cfg_data["name"] in ["wild"]
    name = cfg_model["name"]

    if name == "random":
        return Random(cfg_data["T"], cfg_data["offline_t"])

    if name == "gp":
        features = cfg_model["features"].split(" ") if "features" in cfg_model else []
        if cfg_model["variant"] == "mean":
            # The "mean" variant ignores configured features and uses only this one.
            features = ["time_since_train"]
        return PerformanceForecaster(
            cfg_data["T"],
            cfg_data["offline_t"],
            cfg_exp["relative_pe"],
            variant=cfg_model["variant"],
            loss_is_01=cfg_model.get("loss_is_01", False),
            features=features,
            unc=cfg_model.get("unc", False),
            var=cfg_model.get("var", None),
        )

    if name == "adwin":
        return ADWIN(cfg_data["T"], cfg_data["offline_t"], cfg_model["significance"],
                     data_in_dataloader=data_in_dataloader)

    if name == "fhddm":
        return FHDDM(cfg_data["T"], cfg_data["offline_t"], cfg_model["significance"],
                     data_in_dataloader=data_in_dataloader)

    if name == "kswin":
        return KSWIN(cfg_data["T"], cfg_data["offline_t"], cfg_model["significance"],
                     data_in_dataloader=data_in_dataloader)

    if name == "fixed_retrain":
        return FixedRetrain(cfg_data["T"], cfg_data["offline_t"], cfg_model["schedule"])

    if name == "cara":
        if cfg_data["name"] == "wild":
            # For wild we use a pre-saved C matrix stored under cara_store_dir.
            # CARAWild was previously referenced without being imported, so this
            # branch always raised NameError.
            # NOTE(review): assumed to live next to CARA in models.cara — confirm.
            from models.cara import CARAWild
            return CARAWild(cfg_data["T"], cfg_data["offline_t"], cfg_exp["relative_pe"],
                            variant=cfg_model["variant"],
                            cara_store_dir=cfg_data["cara_store_dir"],
                            trial_index=trial_index)
        return CARA(cfg_data["T"], cfg_data["offline_t"], cfg_exp["relative_pe"],
                    variant=cfg_model["variant"])

    if name == "cara_oracle":
        # NOTE(review): relative_pe is read from cfg_model here (and for "oracle"),
        # whereas gp/cara read it from cfg_exp — confirm which config is intended.
        return CARAOracle(cfg_data["T"], cfg_data["offline_t"], cfg_model["relative_pe"])

    if name == "pilot":
        return Pilot(cfg_data["T"], cfg_data["offline_t"],
                     discounted_cost=cfg_model["discounted_cost"])

    if name == "oracle":
        return Oracle(cfg_data["T"], cfg_data["offline_t"], cfg_model["relative_pe"])

    # Previously this printed and implicitly returned None, deferring the failure.
    raise ValueError(f"Unknown model name: {name!r}")


def make_dataset(cfg_dataset: dict, cfg_experiment: dict, seed: int) -> "SeqDataset":
    """Load the sequential dataset named by ``cfg_dataset["name"]``.

    Args:
        cfg_dataset: Dataset config; ``"name"`` selects the loader.
        cfg_experiment: Experiment config (only used by the ``iris`` path).
        seed: Random seed forwarded to the loader.

    Returns:
        The loaded ``SeqDataset``.

    Raises:
        NotImplementedError: For ``"iris"``, whose sequences are built but not
            yet wrapped into a ``SeqDataset``.
        ValueError: If the dataset name is not recognised.
    """
    dataset_name = cfg_dataset["name"]

    if dataset_name == "iris":
        iris = datasets.load_iris()
        # The sklearn Bunch exposes features as ``.data`` (``.datas`` does not exist).
        offline_seq, online_seq, offline_queries, online_queries = get_sequence(
            iris.data, iris.target, seed, cfg_dataset, cfg_experiment)
        # TODO: wrap these sequences/queries into a SeqDataset. The constructor
        # signature is not visible here, so fail explicitly instead of the former
        # UnboundLocalError.
        raise NotImplementedError(
            "iris: sequences are generated but not wrapped into a SeqDataset")

    # Use a local name distinct from the imported SeqDataset class; assigning to
    # ``SeqDataset`` here previously made it a local and broke other paths.
    if dataset_name in ("circles", "covcon", "gauss"):
        return load_synthetic_dataset(cfg_dataset, seed, dataset_name)

    if dataset_name in ("electricity", "airplanes", "yelp", "epicgames"):
        return load_temporal_dataset(cfg_dataset, seed, dataset_name)

    raise ValueError(f"Unknown dataset name: {dataset_name!r}")


def make_reporter(cfg_reporter: dict):
    """Build a reporter from ``cfg_reporter``.

    Placeholder: no reporter is implemented yet, so the configuration is
    ignored and ``None`` is always returned.
    """
    reporter = None
    return reporter
