import os
import logging
import re

import torch
from mkdir_p import mkdir_p
import numpy as np

from .utils import set_random_seed
from spurious_ml.datasets import add_spurious_correlation
from spurious_ml.variables import get_file_name
from spurious_ml.influence_utils.influence import first_order_group_influence, calc_influence_single


# Configure root logging once for this module: timestamped records with a
# padded level name; only WARNING and above are emitted.
logging.basicConfig(
    level=logging.WARNING,
    format='%(asctime)s %(levelname)-8s %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
)
logger = logging.getLogger(__name__)

# Directory prefix joined with the `model_path` variable to locate the
# checkpoint that `run_influence` loads.
base_model_dir = './models/train_classifier/'
# Directory prefix used to build the `model_path` entry of the result dict
# returned by `run_influence`.
save_model_dir = './models/group_influence/'

def run_influence(auto_var):
    """Compute influence scores of all training points on one test example.

    Loads the model described by ``auto_var``. For dataset names containing
    ``"mnist"``, it picks a test example whose label equals the class number
    encoded in the dataset name. It then computes the influence of every
    training example on that test example with ``calc_influence_single``.

    Args:
        auto_var: experiment-variable container. It must provide the
            ``dataset``, ``model`` and ``model_path`` variables.

    Returns:
        dict with keys ``spurious_ind`` and ``model_path``. For MNIST
        datasets it also has ``influences_same_cls``: a numpy array built
        from the return value of ``calc_influence_single``.

    Raises:
        ValueError: if an MNIST dataset name does not match the expected
            pattern, or if the test set has no example of the target class.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    _ = set_random_seed(auto_var)
    trnX, trny, tstX, tsty, spurious_ind = auto_var.get_var("dataset")
    model_name = auto_var.get_variable_name("model")
    if "MLP" in model_name:
        # MLP models take flattened feature vectors.
        trnX, tstX = trnX.reshape(len(trnX), -1), tstX.reshape(len(tstX), -1)

    model_path = os.path.join(base_model_dir, auto_var.get_var('model_path'))

    result = {
        'spurious_ind': spurious_ind,
        'model_path': os.path.join(save_model_dir, get_file_name(auto_var) + ".pt")
    }

    multigpu = False
    model = auto_var.get_var("model", trnX=trnX, trny=trny, multigpu=multigpu, device=device)
    model.tst_ds = (tstX, tsty)

    model.load(model_path)

    model.model.eval()

    ds_name = auto_var.get_variable_name("dataset")
    if "mnist" in ds_name:
        template = r"mnist(?P<version>v[0-9a-z]+)?-(?P<sp_counts>[0-9]+)-(?P<cls_no>[0-9])-(?P<seed>[0-9]+)"
        match = re.fullmatch(template, ds_name)
        if match is None:
            raise ValueError(
                f"dataset name {ds_name!r} does not match the expected pattern {template!r}")
        groupdict = match.groupdict()
        logger.debug("parsed dataset name %s: %s", ds_name, groupdict)
        cls_no = int(groupdict['cls_no'])

        # Pick the first test example labelled `cls_no`. This assumes `tsty`
        # holds integer class indices (it is converted with `.long()` below).
        # NOTE(review): confirm "first example of the class" is the intended
        # choice, and whether the spurious feature should be injected first.
        same_cls_idx = np.flatnonzero(np.asarray(tsty) == cls_no)
        if same_cls_idx.size == 0:
            raise ValueError(f"test set has no example with label {cls_no}")
        idx = int(same_cls_idx[0])
        # Slice instead of index so the leading batch axis of size 1 is kept,
        # along with the per-example shape the model was built for (flat for
        # MLPs, unchanged otherwise). For MLPs this equals reshape(1, -1).
        same_cls_tstX = np.asarray(tstX[idx:idx + 1])
        same_yy = np.asarray(tsty[idx:idx + 1])

        influences = np.array(calc_influence_single(
            model.model,
            torch.from_numpy(same_cls_tstX).float(),
            torch.from_numpy(same_yy).long(),
            torch.from_numpy(trnX).float(),
            torch.from_numpy(trny).long(),
            recursion_depth=2000,
            r=10,
        ))
        result['influences_same_cls'] = influences

    return result
