import numpy as np
import torch

from model_correction.base_correction_method import LitClassifier, Freeze
from utils.cav import compute_cav

class ClarcMulti(LitClassifier):
    """ClArC correction with one CAV per artifact type.

    For every entry of ``config["cav_config"]`` a separate concept activation
    vector (CAV) is fitted on precomputed per-sample vectors of layer
    ``config["layer_name"]``. A forward hook (``clarc_hook``) is registered on
    that layer; in this class the hook is a no-op.

    Args:
        model: Model to correct; forwarded to ``LitClassifier``.
        config (dict): Must provide ``layer_name``, ``dataset_name``,
            ``model_name``, ``direction_mode`` (forwarded as ``cav_type`` to
            ``utils.cav.compute_cav``), ``cav_config`` (mapping artifact type
            -> class ids to use, falsy meaning all classes) and
            ``dir_precomputed_data``. ``use_backdoor_model`` is optional.
        **kwargs: Forwarded to ``LitClassifier``; must also provide
            ``artifact_sample_ids`` (mapping artifact type -> sample ids used
            as positives for that artifact's CAV), ``sample_ids`` (ids allowed
            for CAV fitting, usually the training set), ``class_names`` and
            ``mode`` (key of the vectors to read from the precomputed files).
    """

    def __init__(self, model, config, **kwargs):
        super().__init__(model, config, **kwargs)

        self.std = None
        self.layer_name = config["layer_name"]
        self.dataset_name = config["dataset_name"]
        self.model_name = config["model_name"]

        assert "artifact_sample_ids" in kwargs, "artifact_sample_ids have to be passed to ClArC correction methods"
        assert "sample_ids" in kwargs, "all sample_ids have to be passed to ClArC correction methods"
        assert "class_names" in kwargs, "class_names has to be passed to ClArC correction methods"
        assert "mode" in kwargs, "mode has to be passed to ClArC correction methods"

        artifact_sample_ids = kwargs["artifact_sample_ids"]
        self.sample_ids = kwargs["sample_ids"]
        self.class_names = kwargs["class_names"]

        self.direction_mode = config["direction_mode"]
        self.mode = kwargs["mode"]

        cav_config = config["cav_config"]
        # Unlike single-artifact ClArC variants, precomputed data here carries no
        # artifact-specific suffix; only the optional backdoor-model marker applies.
        artifact_extension = "_bd" if config.get("use_backdoor_model", False) else ""
        # Must be set before compute_cav is called, which reads self.path.
        self.path = f"{config['dir_precomputed_data']}/global_relevances_and_activations/{self.dataset_name}{artifact_extension}/{self.model_name}"

        # One CAV (plus mean projection lengths) per artifact type.
        self.cavs = {}
        self.mean_length = {}
        self.mean_length_targets = {}
        for artifact_type, cav_scope in cav_config.items():
            cav, mean_length, mean_length_targets = self.compute_cav(
                self.mode, cav_scope, artifact_sample_ids[artifact_type]
            )
            self.cavs[artifact_type] = cav
            self.mean_length[artifact_type] = mean_length
            self.mean_length_targets[artifact_type] = mean_length_targets

        hooks = []
        for name, module in self.model.named_modules():
            if name == self.layer_name:
                print("Registered forward hook.")
                hooks.append(module.register_forward_hook(self.clarc_hook))
        self.hooks = hooks

    def compute_cav(self, mode, cav_scope, artifact_sample_ids):
        """Fit a CAV separating artifact samples from all other kept samples.

        Args:
            mode: Key of the per-sample vectors in each precomputed ``.pth``
                file (``data[mode]``); the vectors are stacked along dim 0.
            cav_scope: Class ids whose precomputed files are read; if falsy,
                ``self.class_names`` is used instead.
            artifact_sample_ids: Iterable of sample ids labeled as positives
                (target 1); every other kept sample is labeled 0.

        Returns:
            tuple: ``(cav, mean_length, mean_length_targets)`` where the two
            lengths are the mean dot product of the flattened target-0 and
            target-1 vectors with ``cav``.

        Raises:
            RuntimeError: If none of the scoped classes has precomputed samples.
        """
        scope = cav_scope or self.class_names

        vecs = []
        sample_ids = []
        for class_id in scope:
            path_precomputed_activations = f"{self.path}/{self.layer_name}_class_{class_id}_all.pth"
            print(f"reading precomputed relevances/activations from {path_precomputed_activations}")
            data = torch.load(path_precomputed_activations)
            if data['samples']:
                sample_ids += data['samples']
                vecs.append(torch.stack(data[mode], 0))

        if not vecs:
            # torch.cat([]) would fail with an opaque message; say what is missing.
            raise RuntimeError(
                f"No precomputed samples for layer '{self.layer_name}' in {self.path} "
                f"for classes {list(scope)}"
            )

        vecs = torch.cat(vecs, 0)
        sample_ids = np.array(sample_ids)

        # Only keep samples that are in self.sample_ids (usually training set).
        # np.isin always returns a bool mask (also when empty) and avoids an
        # O(n*m) Python-level membership loop.
        keep = np.isin(sample_ids, np.array(self.sample_ids))
        vecs = vecs[keep]
        sample_ids = sample_ids[keep]

        # Positive label for every kept sample whose id is an artifact sample id.
        # Pure membership test: never uses an argwhere index as a truth value
        # (index 0 would be falsy), and duplicate ids cannot raise.
        targets = np.isin(sample_ids, np.asarray(list(artifact_sample_ids))).astype(int)

        X = vecs.detach().cpu().numpy()
        X = X.reshape(X.shape[0], -1)
        # Module-level utils.cav.compute_cav (not this method, which shares its name).
        cav = compute_cav(
            X, targets, cav_type=self.direction_mode
        )

        mean_length = (vecs[targets == 0].flatten(start_dim=1) * cav).sum(1).mean(0)
        mean_length_targets = (vecs[targets == 1].flatten(start_dim=1) * cav).sum(1).mean(0)

        return cav, mean_length, mean_length_targets

    def clarc_hook(self, m, i, o):
        """Forward hook on ``self.layer_name``; a no-op in this class.

        Returning ``None`` leaves the layer output unchanged.
        NOTE(review): presumably meant to be overridden by subclasses that
        apply the CAV correction — confirm.
        """
        pass

    def configure_callbacks(self):
        """Return the callbacks for this module: a ``Freeze`` for ``self.layer_name``."""
        return [Freeze(
            self.layer_name
        )]