import os
import pathlib
import argparse
from time import time
import torch
import torch.nn as nn
from fastshap import FastSHAP

from global_variables import RANDOM_SEED
from models import random_forest_model, catboost_model, nn_model
import utils
from utils import MarginalImputerTorch


def prepare_data(dataset):
    """Return only the train and test feature splits of ``dataset``.

    ``utils.get_dataset(..., with_splits=True)`` must yield exactly four items;
    the last two (presumably the label splits — confirm) are discarded.
    """
    train_x, test_x, _unused_a, _unused_b = utils.get_dataset(dataset, with_splits=True)
    return train_x, test_x

def get_model(in_features, size, device='cpu', hidden_dim=128):
    """Build the FastSHAP explainer MLP mapping ``in_features`` -> ``in_features``.

    Every preset starts with ``Linear(in_features, hidden_dim) + ReLU`` and ends
    with ``Linear(hidden_dim, in_features)`` (presumably one attribution per
    input feature). The presets differ only in the number of extra
    ``Linear(hidden_dim, hidden_dim) + ReLU`` pairs in between:
    'small' -> 0, 'medium' -> 2, 'large' -> 4.

    Layers are created in the same order and ``nn.Sequential`` indices as the
    historical hand-written presets. Checkpoint ``state_dict`` keys therefore
    still match, and seeded initialization is unchanged.

    Args:
        in_features: Number of input features, which is also the output width.
        size: One of 'small', 'medium', 'large'.
        device: Device the module is moved to.
        hidden_dim: Width of the hidden layers (default 128, the original value).

    Returns:
        ``nn.Sequential`` on ``device``.

    Raises:
        ValueError: If ``size`` is not a known preset. Previously this silently
            returned ``None``.
    """
    extra_hidden_layers = {'small': 0, 'medium': 2, 'large': 4}
    if size not in extra_hidden_layers:
        raise ValueError(
            f"Unknown explainer size {size!r}; expected one of {sorted(extra_hidden_layers)}")

    layers = [nn.Linear(in_features, hidden_dim), nn.ReLU(inplace=True)]
    for _ in range(extra_hidden_layers[size]):
        layers += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True)]
    layers.append(nn.Linear(hidden_dim, in_features))
    return nn.Sequential(*layers).to(device)

def train_model(explainer, imputer, x_train, x_test, fastshap_n_samples, device='cpu'):
    """Train ``explainer`` with FastSHAP (additive normalization) and return the wrapper.

    ``x_test`` is passed as FastSHAP's validation data. ``fastshap_n_samples`` is
    forwarded as ``num_samples``. ``device`` is accepted for call-site symmetry
    but is not used here. The best (minimum) validation loss is printed after
    training.
    """
    wrapper = FastSHAP(explainer, imputer, normalization='additive')

    training_options = dict(
        batch_size=64,
        num_samples=fastshap_n_samples,
        max_epochs=200,
        eff_lambda=0,
        paired_sampling=True,
        validation_samples=128,
        validation_seed=123,
        verbose=True,
    )
    wrapper.train(x_train, x_test, **training_options)

    best_val_loss = min(wrapper.loss_list)
    print(f'Best val loss = {best_val_loss:.8f}')
    return wrapper

def get_checkpoint_dir(dataset, model_name, size, fastshap_n_samples, no_background_sample):
    """Return the checkpoint FILE path for one explainer configuration.

    Despite the name, this returns a ``.pth`` file path, not a directory. The
    path is built under ``<this file's dir>/fastshap_cache/<dataset>/``, and that
    directory is created if it does not exist yet.
    """
    this_directory = pathlib.Path(__file__).parent.resolve()
    model_dir = f"{this_directory}/fastshap_cache/{dataset}"
    # exist_ok avoids the exists()/makedirs() race when several training
    # processes start at the same time.
    os.makedirs(model_dir, exist_ok=True)

    return f"{model_dir}/{model_name}_{size}_fs={fastshap_n_samples}_bs={no_background_sample}.pth"

def save_fastshap(fastshap, dataset, model_name, size, fastshap_n_samples, no_background_sample, training_duration):
    """Write the trained explainer weights plus training metadata to its checkpoint file.

    The saved dict contains the explainer ``state_dict``, the FastSHAP
    ``loss_list``, ``size``, ``fastshap_n_samples`` and ``training_duration``.
    The target path comes from ``get_checkpoint_dir``.
    """
    payload = dict(
        model=fastshap.explainer.state_dict(),
        loss_list=fastshap.loss_list,
        size=size,
        fastshap_n_samples=fastshap_n_samples,
        training_duration=training_duration,
    )
    target_path = get_checkpoint_dir(dataset, model_name, size, fastshap_n_samples, no_background_sample)
    torch.save(payload, target_path)

def load_fastshap_explainer(dataset, model_name, size, fastshap_n_samples, no_background_sample, device='cpu'):
    """Rebuild a trained explainer from its checkpoint file.

    Args:
        dataset, model_name, size, fastshap_n_samples, no_background_sample:
            Identify the checkpoint (see ``get_checkpoint_dir``).
        device: Device the weights are mapped to and the model is placed on.

    Returns:
        The ``get_model`` network for ``size`` with the saved weights loaded.

    Raises:
        FileNotFoundError: If no checkpoint exists. This is a subclass of
            ``Exception``, so callers catching the former generic error still work.
        RuntimeError: If the file is only the "placeholder" written when a
            training run claims the configuration. That happens when training
            is still running or died before saving.
    """
    checkpoint_file = get_checkpoint_dir(dataset, model_name, size, fastshap_n_samples, no_background_sample)

    dataset_settings = utils.get_task_settings()
    in_features = dataset_settings["no_features"][dataset]

    if not os.path.exists(checkpoint_file):
        raise FileNotFoundError(f"Could not find {checkpoint_file} to load from.")

    # The training script writes the literal text "placeholder" before training
    # starts. Report that clearly instead of letting torch.load fail obscurely.
    marker = b"placeholder"
    with open(checkpoint_file, "rb") as handle:
        if handle.read(len(marker)) == marker:
            raise RuntimeError(
                f"{checkpoint_file} is a placeholder; training is still running or did not finish.")

    checkpoint = torch.load(checkpoint_file, map_location=torch.device(device))
    model = get_model(in_features=in_features, size=size, device=device)
    model.load_state_dict(checkpoint["model"])
    return model

def load_model(dataset, model, depth="", device='cpu'):
    """Load the trained predictor that FastSHAP will explain.

    Args:
        dataset: Dataset name forwarded to the model loader.
        model: One of "nn", "random_forest", "catboost".
        depth: Depth identifier for tree models; ignored for "nn".
        device: Device for the "nn" model; ignored for tree models.

    Returns:
        The loaded model object from the matching ``*_model.load_model``.

    Raises:
        ValueError: For an unknown ``model``. Previously the name string itself
            was returned, which failed much later with a confusing error.
    """
    if model == "nn":
        return nn_model.load_model(dataset, best=True, device=device)
    if model == "random_forest":
        return random_forest_model.load_model(dataset, depth)
    if model == "catboost":
        return catboost_model.load_model(dataset, depth)
    raise ValueError(
        f"Unknown model {model!r}; expected 'nn', 'random_forest' or 'catboost'.")

def get_task_name(model_config):
    """Return a cache-friendly name for a model config.

    "nn" configs are named just "nn". Every other model is suffixed with its
    depth, e.g. ``{"model": "catboost", "depth": "4"}`` -> "catboost_4".
    """
    kind = model_config["model"]
    return kind if kind == "nn" else f'{kind}_{model_config["depth"]}'

if __name__ == "__main__":
    # Train one FastSHAP explainer per (num_samples, size, task, background size)
    # combination and cache it under fastshap_cache/. Already-claimed
    # configurations are skipped, so several runs can share the work.
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', default="cpu")
    args = parser.parse_args()
    device = args.device
    fastshap_n_samples = [1, 4, 16]
    sizes = ['small', 'medium', 'large']
    # (dataset, model config) pairs; uncomment entries to enable them.
    tasks = [
        # ("entacmaea", {"model": "nn"}),
        ("sgemm", {"model": "nn"}),
        # ("gb1",  {"model": "random_forest", "depth": "3"}),
        # ("gb1",  {"model": "random_forest", "depth": "4"}),
        # ("gb1",  {"model": "random_forest", "depth": "5"}),
        # ("gb1",  {"model": "random_forest", "depth": "6"}),
        # ("gb1",  {"model": "random_forest", "depth": "7"}),
        # ("gb1",  {"model": "random_forest", "depth": "8"}),
        # ("gb1",  {"model": "random_forest", "depth": "9"}),
        # ("gb1",  {"model": "random_forest", "depth": "10"}),
        # ("avGFP", {"model": "catboost", "depth": "1"}),
        # ("avGFP", {"model": "catboost", "depth": "2"}),
        # ("avGFP", {"model": "catboost", "depth": "3"}),
        # ("avGFP", {"model": "catboost", "depth": "4"}),
        # ("avGFP", {"model": "catboost", "depth": "5"}),
        # ("avGFP", {"model": "catboost", "depth": "6"}),
        # ("avGFP", {"model": "catboost", "depth": "7"}),
        # ("avGFP", {"model": "catboost", "depth": "8"}),
    ]
    dataset_settings = utils.get_task_settings()

    for fastshap_n_sample in fastshap_n_samples:
        for size in sizes:
            for dataset, model_config in tasks:
                task_name = get_task_name(model_config)
                background_samples = dataset_settings["background_samples"][dataset]
                test_samples = dataset_settings["test_samples"][dataset]

                # Prepare inputs (labels are generated off of the predictor function)
                x_train, x_test = prepare_data(dataset)
                # Keep only as many test rows as the last test_samples entry
                # (presumably the largest configured count — confirm).
                x_test = x_test[0:test_samples[-1]]

                for no_background_sample in background_samples:
                    checkpoint_file = get_checkpoint_dir(dataset, task_name, size, fastshap_n_sample, no_background_sample)

                    # Claim this configuration by atomically creating a placeholder
                    # file. Mode 'x' fails if the file already exists, so two
                    # concurrent runs cannot both claim it. An exists() check
                    # followed by open('w') allowed that race.
                    try:
                        with open(checkpoint_file, 'x') as handle:
                            handle.write("placeholder")
                    except FileExistsError:
                        print(f"{checkpoint_file} is already available. Skipping.")
                        continue
                    print(f"Training {checkpoint_file} ...")

                    try:
                        # Seed immediately before building the explainer so its
                        # initial weights are reproducible. The statement order
                        # below is kept as-is so the RNG sequence is unchanged.
                        torch.manual_seed(RANDOM_SEED)
                        explainer = get_model(x_train.shape[1], size, device)

                        f = load_model(dataset, device=device, **model_config)
                        if model_config["model"] == "nn":
                            # Feed the nn predictor float64 input and cast its output to float32.
                            imputer_model = lambda x: f(x.to(torch.float64)).to(torch.float32)
                        else:
                            # Non-nn predictors: add a trailing output dimension and move the result to `device`.
                            imputer_model = lambda x: f(x).unsqueeze(-1).to(device)
                        imputer = MarginalImputerTorch(imputer_model, x_train[0:no_background_sample], device=device)

                        start_time = time()
                        fastshap = train_model(
                            explainer=explainer,
                            imputer=imputer,
                            x_train=x_train,
                            x_test=x_test,
                            fastshap_n_samples=fastshap_n_sample,
                            device=device
                        )
                        duration = time() - start_time

                        save_fastshap(
                            fastshap=fastshap,
                            dataset=dataset,
                            model_name=task_name,
                            size=size,
                            fastshap_n_samples=fastshap_n_sample,
                            no_background_sample=no_background_sample,
                            training_duration=duration
                        )
                    except BaseException:
                        # Release the claim (including on Ctrl-C) so a later run
                        # retries this configuration. Otherwise a stale
                        # placeholder would be reported as "already available"
                        # forever.
                        try:
                            os.remove(checkpoint_file)
                        except OSError:
                            pass  # best-effort cleanup; the original error is re-raised below
                        raise