import numpy as np
import pandas as pd
from pathlib import Path
from itertools import product
import json
from tqdm import tqdm
import pickle

def combinations(grid):
    """Expand a parameter grid into the list of all its combinations.

    ``grid`` maps each parameter name to an iterable of candidate values.
    The result holds one dict per element of the Cartesian product of those
    iterables. The product order follows ``itertools.product`` and the key
    order follows ``grid``. An empty grid yields ``[{}]``. Any empty value
    list yields ``[]``.
    """
    names = list(grid.keys())
    return [dict(zip(names, picked)) for picked in product(*grid.values())]
        
def get_hparams(experiment):
    """Return the hyper-parameter list of the experiment named ``experiment``.

    ``experiment`` is looked up among this module's global names. The object
    found is instantiated with no arguments, and its ``get_hparams()``
    result is returned.

    Raises:
        NotImplementedError: if no module-level name ``experiment`` exists.
    """
    if experiment not in globals():
        # Name the unknown experiment so a typo on the command line is obvious.
        raise NotImplementedError(f"Unknown experiment: {experiment!r}")
    return globals()[experiment]().get_hparams()

def get_script_name(experiment):
    """Return the ``fname`` attribute of the experiment named ``experiment``.

    ``experiment`` is looked up among this module's global names. The class
    attribute ``fname`` is read without instantiating the class.

    Raises:
        NotImplementedError: if no module-level name ``experiment`` exists.
    """
    if experiment not in globals():
        # Name the unknown experiment so a typo on the command line is obvious.
        raise NotImplementedError(f"Unknown experiment: {experiment!r}")
    return globals()[experiment].fname


# Dataset names whose train_model runs are collected when scanning model_dir.
datasets = ['synthetic', 'celebA']
# Root directory scanned (recursively) for trained-model run directories.
# NOTE(review): 'CLUSTER' and 'USER' look like placeholders; confirm per deployment.
model_dir = Path('/home') / 'CLUSTER' / 'USER' / 'results' / 'expl_dist_shift' / 'models'

class train_model:
    """Hyper-parameter grid for the ``train_model`` script.

    Two sub-grids are swept:
      * ``hparams1``: dataset 'celebA', model 'lr', embedding model 'resnet18'.
      * ``hparams2``: dataset 'synthetic', models 'lr' and 'xgb'.
    """
    fname = 'train_model'

    def __init__(self):
        self.hparams1 = dict(
            exp_name=['train_model'],
            dataset=['celebA'],
            model=['lr'],
            emb_model=['resnet18'],
        )
        self.hparams2 = dict(
            exp_name=['train_model'],
            dataset=['synthetic'],
            model=['lr', 'xgb'],
        )

    def get_hparams(self):
        """Return the expanded celebA combinations, then the synthetic ones."""
        expanded = []
        for sub_grid in (self.hparams1, self.hparams2):
            expanded.extend(combinations(sub_grid))
        return expanded

class explain():
    """Hyper-parameter grid for the ``explain`` script.

    On construction, ``model_dir`` is scanned recursively for run
    directories that contain a ``done`` file. Each run's ``args.json`` is
    read, and runs whose ``exp_name`` is ``'train_model'`` are collected per
    dataset. Those model directories are then crossed with the synthetic
    and celebA sweeps defined below.
    """
    fname = 'explain'

    def __init__(self):
        # dataset name -> directories of its train_model runs.
        models = {i: [] for i in datasets}

        # Path.glob yields entries in no guaranteed order; sort so the
        # generated hparam list is reproducible across runs and machines.
        for i in sorted(model_dir.glob('**/done')):
            # read_bytes() closes the file after reading. json.loads accepts
            # bytes and auto-detects the UTF-8/16/32 encoding, as json.load
            # does on a binary file object.
            args = json.loads((i.parent / 'args.json').read_bytes())
            if args['dataset'] not in datasets:
                continue
            if args['exp_name'] == 'train_model':
                models[args['dataset']].append(str(i.parent))

        self.base_hparams = {
            'exp_name': ['explain'],
            'metric': ['acc', 'brier'],
            'weight_model': ['xgb']
        }

        # Synthetic sweep over spu_q (11 points in [0, 1]) x spu_mu_add (-1..5).
        self.synthetic_hparams = {
            'dataset': ['synthetic'],
            'spu_q': np.linspace(0, 1, 11),
            'spu_mu_add': np.arange(-1, 6, 1),
            'spu_y_noise': [0.25],
            'model_dir': models['synthetic'],
        }

        # Synthetic sweep over spu_q at fixed spu_mu_add=3 and spu_x1_weight=0.
        self.synthetic_hparams2 = {
            'dataset': ['synthetic'],
            'spu_q': np.linspace(0, 1, 11),
            'spu_mu_add': [3.],
            'spu_y_noise': [0.25],
            'spu_x1_weight': [0.],
            'model_dir': models['synthetic'],
        }

        # NOTE: disabled configs kept for reference. 'waterbirds' (and the
        # 'metashift' config further below) are not keys of `models`, so
        # re-enabling them requires adding those names to `datasets`.
        # oversample_grid = np.concatenate((np.arange(0.02, 0.21, 0.02), np.arange(0.3, 1, 0.1)))
        # oversample_grid = np.arange(0.1, 1, 0.1)
        # self.waterbirds_hparams = {
        #     'dataset': ['waterbirds'],
        #     'oversample_ratio': oversample_grid,
        #     'model_dir': models['waterbirds'],
        #     'calibrate_weight_models': [True, False]
        # }

        # Every directory holding a labels.csv under <model_dir>/../data/celebA,
        # except the one named 'base', is used as a target dataset (sorted
        # for a reproducible sweep order).
        celebA_target_dirs = [
            str(i.parent.absolute())
            for i in sorted((model_dir.parent / 'data' / 'celebA').glob('**/labels.csv'))
            if i.parent.name != 'base'
        ]

        self.celebA_hparams = {
            'dataset': ['celebA'],
            'model_dir': models['celebA'],
            'target_data_dir': celebA_target_dirs,
            'shapley_method': ['EXACT']
        }

        # Clipping regimes applied to the synthetic sweeps.
        self.no_clipping = {
            'clip_probs': [False]
        }

        self.clip_probs = {
            'clip_probs': [True],
            'clip_weights': [False],
            'clip_prob_thres': [0.95, 0.99]
        }

        # self.metashift_hparams = {
        #     'dataset': ['metashift'],
        #     'oversample_ratio': oversample_grid,
        #     'model_dir': models['metashift'],
        #     'calibrate_weight_models': [True, False]
        # }

    def get_hparams(self):
        """Return the expanded sweep in a fixed order.

        The order is: both synthetic grids without clipping, both synthetic
        grids with probability clipping, then the celebA grid.
        """
        return (
            combinations({**self.base_hparams, **self.synthetic_hparams, **self.no_clipping}) +
            combinations({**self.base_hparams, **self.synthetic_hparams2, **self.no_clipping}) +
            combinations({**self.base_hparams, **self.synthetic_hparams, **self.clip_probs}) +
            combinations({**self.base_hparams, **self.synthetic_hparams2, **self.clip_probs}) +
            combinations({**self.base_hparams, **self.celebA_hparams})
        )


class explain_janzing():
    """Hyper-parameter grid for the ``explain_janzing`` script.

    Both synthetic sweeps use a single synthetic ``train_model`` run: the
    first one in sorted path order among run directories under
    ``model_dir`` that contain a ``done`` file.

    Raises:
        FileNotFoundError: on construction, if no synthetic train_model run
            is found.
    """
    fname = 'explain_janzing'

    def __init__(self):
        # dataset name -> directories of its train_model runs.
        models = {i: [] for i in datasets}

        # Sorting makes the "first" run picked below deterministic. Unsorted
        # Path.glob order depends on the filesystem.
        for i in sorted(model_dir.glob('**/done')):
            # read_bytes() closes the file. json.loads accepts bytes and
            # auto-detects the encoding like json.load on a binary file.
            args = json.loads((i.parent / 'args.json').read_bytes())
            if args['dataset'] not in datasets:
                continue
            if args['exp_name'] == 'train_model':
                models[args['dataset']].append(str(i.parent))

        if not models['synthetic']:
            raise FileNotFoundError(
                f"No synthetic train_model run (with a 'done' file) found under {model_dir}"
            )
        synthetic_model = models['synthetic'][0]

        self.base_hparams = {
            'exp_name': ['explain_janzing'],
            'weight_model': ['xgb']
        }

        # Synthetic sweep over spu_q (11 points in [0, 1]) x spu_mu_add (-1..5).
        self.synthetic_hparams = {
            'dataset': ['synthetic'],
            'spu_q': np.linspace(0, 1, 11),
            'spu_mu_add': np.arange(-1, 6, 1),
            'spu_y_noise': [0.25],
            'model_dir': [synthetic_model],
        }

        # Synthetic sweep over spu_q at fixed spu_mu_add=3 and spu_x1_weight=0.
        self.synthetic_hparams2 = {
            'dataset': ['synthetic'],
            'spu_q': np.linspace(0, 1, 11),
            'spu_mu_add': [3.],
            'spu_y_noise': [0.25],
            'spu_x1_weight': [0.],
            'model_dir': [synthetic_model],
        }

    def get_hparams(self):
        """Return the expanded first synthetic grid followed by the second."""
        return (
            combinations({**self.base_hparams, **self.synthetic_hparams}) +
            combinations({**self.base_hparams, **self.synthetic_hparams2})
        )


class explain_test_hparams():
    """Option-testing grid that runs the ``explain`` script on metashift.

    One metashift ``train_model`` run (the first in sorted path order) is
    crossed with:
      * oversampling ratios,
      * weight-model calibration on/off,
      * importance-weight types,
      * Shapley methods,
      * three clipping regimes (none, weight clipping, probability clipping).

    Raises:
        FileNotFoundError: on construction, if no metashift train_model run
            is found.
    """
    fname = 'explain'

    def __init__(self):
        # 'metashift' is added explicitly. The module-level `datasets` list is
        # not guaranteed to contain it. Without it, metashift runs would be
        # skipped by the filter below and models['metashift'] would raise
        # KeyError.
        known_datasets = set(datasets) | {'metashift'}
        models = {i: [] for i in known_datasets}

        # Sorting makes the "first" run picked below deterministic.
        for i in sorted(model_dir.glob('**/done')):
            # read_bytes() closes the file. json.loads accepts bytes and
            # auto-detects the encoding like json.load on a binary file.
            args = json.loads((i.parent / 'args.json').read_bytes())
            if args['dataset'] not in known_datasets:
                continue
            if args['exp_name'] == 'train_model':
                models[args['dataset']].append(str(i.parent))

        if not models['metashift']:
            raise FileNotFoundError(
                f"No metashift train_model run (with a 'done' file) found under {model_dir}"
            )
        metashift_model = models['metashift'][0]

        self.base_hparams = {
            'exp_name': ['explain_test_hparams'],
            'metric': ['brier'],
            'weight_model': ['xgb']
        }

        oversample_grid = np.arange(0.1, 0.6, 0.1)

        self.metashift_hparams = {
            'dataset': ['metashift'],
            'oversample_ratio': oversample_grid,
            'model_dir': [metashift_model],
            'calibrate_weight_models': [True, False],
            'imp_weight_type': ['normal', 'self_normalize'],
            'shapley_method': ['EARLY_STOPPING', 'EXACT']
        }

        # Clipping regimes. Exactly one of weight/prob clipping is enabled
        # in each of the last two.
        self.no_clip = {
            'clip_weights': [False],
            'clip_probs': [False]
        }

        self.clip_weights = {
            'clip_weights': [True],
            'clip_probs': [False],
            'clip_weight_thres': [10, 100]
        }

        self.clip_probs = {
            'clip_probs': [True],
            'clip_weights': [False],
            'clip_prob_thres': [0.95, 0.99]
        }

    def get_hparams(self):
        """Return the metashift grid expanded under no_clip, clip_weights, then clip_probs."""
        expanded = []
        for clip_cfg in (self.no_clip, self.clip_weights, self.clip_probs):
            expanded += combinations({**self.base_hparams, **self.metashift_hparams, **clip_cfg})
        return expanded
