import time
import torch
from tqdm import tqdm
from functools import reduce
from collections import Counter
from torch_incremental_pca import IncrementalPCA

from src.utils import match_module_name, chunk_list, find_equal_values, cycle

from typing import Optional, Union


class PCAHook:
    """Forward hook that incrementally fits a PCA on the inputs of one layer.

    On every call the selected input rows are passed to
    ``IncrementalPCA.partial_fit``. Afterwards each component is compared,
    via cosine similarity, with the same component from before the update.
    ``converged[j]`` is True when that similarity reaches ``sim_thresh``.
    """

    def __init__(
        self,
        name: str,
        n_components: int,
        sim_thresh: Union[float, torch.Tensor]
    ):
        """
        Args:
            name: identifier of the hooked layer.
            n_components: number of principal components to fit.
            sim_thresh: cosine-similarity threshold for convergence. Either a
                scalar, or a tensor of shape ``(n_components,)`` or ``(1,)``
                (per-component thresholds).

        Raises:
            AssertionError: if ``sim_thresh`` is a tensor with a shape other
                than the ones listed above.
        """
        self.name = name
        self.n_components = n_components
        self.sim_thresh = sim_thresh

        if isinstance(sim_thresh, torch.Tensor) and len(sim_thresh.shape) > 0:
            check1 = sim_thresh.size(0) == n_components or sim_thresh.size(0) == 1
            check2 = len(sim_thresh.shape) == 1
            assert check1 and check2, "if sim_thresh is a tensor with more than 0 dimensions it must have shape (n_components,) or (1,)"

        self.pca = IncrementalPCA(n_components=n_components, copy=True, lowrank=True)

        # Index pairs (dim 0, dim 1) selecting the rows of the input that are fed to the
        # PCA; set by the caller before each forward. None means "use every position".
        self.indices = None
        # Per-component convergence flags from the most recent update.
        self.converged = torch.zeros((n_components,), dtype=torch.bool)

    def __call__(self, model, input, output):
        # Snapshot the current components so the updated fit can be compared to them.
        # The clone is kept defensively in case partial_fit updates components_ in place.
        previous_components = None
        if hasattr(self.pca, "components_"):
            previous_components = self.pca.components_.clone().detach()

        # Forward hooks receive the positional inputs as a tuple; a bare tensor is
        # accepted as well.
        try:
            states = input.detach()
        except AttributeError:
            states = input[0].detach()

        if self.indices is None:
            # No selection supplied: treat every position as a sample.
            states = states.reshape(-1, states.size(-1))
        else:
            states = states[self.indices[:, 0], self.indices[:, 1], :]

        # partial_fit needs at least n_components samples; skip this batch otherwise.
        if states.size(0) < self.n_components:
            return

        self.pca.partial_fit(states.to(torch.float32))

        if previous_components is not None:
            components = self.pca.components_
            if len(components.shape) == 1:
                components = components.reshape(1, -1)
                previous_components = previous_components.reshape(1, -1)
            # consider as converged if enough components have converged via cossim
            sim = torch.nn.functional.cosine_similarity(components, previous_components)
            thresh = self.sim_thresh
            if isinstance(thresh, torch.Tensor):
                # a non-scalar threshold must live on the same device as `sim`
                thresh = thresh.to(sim.device)
            self.converged = (sim >= thresh)


class HashHook:
    """Forward hook that records a hash of the input of every forward call.

    The hashes make it possible to tell which layers received identical inputs.
    ``hashed_inputs`` gains one entry per call, in call order.
    """

    def __init__(self, name: str):
        self.name = name
        # one hash per forward call
        self.hashed_inputs = []

    @staticmethod
    def hash_fn(tensor):
        """Hash ``tensor`` through the tuple of its flattened Python values."""
        flat = tensor.view(-1)
        return hash(tuple(flat.tolist()))

    def __call__(self, model, input, output):
        try:
            activation = input.detach().cpu()
        except AttributeError:
            # forward hooks pass positional inputs as a tuple; hash the first one
            activation = input[0].detach().cpu()
        digest = self.hash_fn(activation)
        self.hashed_inputs.append(digest)


@torch.no_grad()
def _compute_pca(
    model,
    data_loader,
    rank,
    rho=2,
    early_stop_sim_thresh=0.99,
    early_stop_redist_metric="ratio",
    scale_by_singular_values=False,
    whiten=False,
    target_modules=None,
    ignore_modules=None,
    use_label_mask=True,
    min_batches=1,
    rank_distribution=True,
    log_convergence_stats=False
):
    """Fit incremental PCAs on the inputs of selected ``torch.nn.Linear`` layers.

    Layers whose inputs hash identically on the first batch share a single
    PCA. Batches are drawn from ``cycle(data_loader)`` until every PCA has
    converged on the components it needs.

    Args:
        model: model whose ``torch.nn.Linear`` sub-modules are probed.
        data_loader: iterable of dict batches. Each batch needs
            ``attention_mask``; ``labels`` is required when ``use_label_mask``
            is set and is never passed to the model.
        rank: average rank per layer; the total budget is
            ``rank * number_of_layers``.
        rho: oversampling factor, must be >= 1; each PCA fits
            ``round(rank * rho)`` components.
        early_stop_sim_thresh: cosine-similarity threshold for per-component
            convergence (see ``PCAHook``).
        early_stop_redist_metric: "raw", "ratio", "sum" or "max"; the
            explained-variance metric used to distribute the rank budget.
        scale_by_singular_values: scale each returned component by its
            singular value divided by the largest selected singular value.
        whiten: divide each returned component by the square root of its
            singular value; takes precedence over ``scale_by_singular_values``.
        target_modules: name patterns (``match_module_name``) selecting layers.
        ignore_modules: name patterns excluding layers.
        use_label_mask: also exclude positions whose label is -100.
        min_batches: early stopping is only allowed once the batch index
            exceeds this value.
        rank_distribution: redistribute the rank budget across layers by
            explained variance during fitting; if False every layer keeps
            ``round(rank * rho)`` components.
        log_convergence_stats: record per-batch convergence flags and timings.

    Returns:
        ``(pca_dict, has_converged_stats, exp_vars)``:
        - ``pca_dict``: layer name -> CPU tensor holding that layer's leading
          components as rows. Layers assigned rank 0 are omitted.
        - ``has_converged_stats``: ``None`` unless ``log_convergence_stats`` is
          set. Otherwise a list whose first entry is a config dict, followed by
          one ``(per-hook converged flags, step_time)`` tuple per logged batch.
        - ``exp_vars``: ``None`` when ``rank_distribution`` is True. Otherwise a
          dict mapping layer name -> metric values (equal-input layers included).

    Raises:
        AssertionError: if ``rho < 1``.
    """

    def _get_metric(pca, metric):
        """Return the explained-variance metric named ``metric`` for a fitted PCA."""
        if metric == "raw":
            return pca.explained_variance_
        elif metric == "ratio":
            return pca.explained_variance_ratio_
        elif metric == "sum":
            return pca.explained_variance_ / pca.explained_variance_.sum()
        elif metric == "max":
            return pca.explained_variance_ / pca.explained_variance_.max()

        else:
            raise ValueError(f"Invalid metric: {metric}")

    def _get_rank_distribution(hooks, hook_layer_map, equal_inputs_map, metric, rank_budget, max_components):
        """Split ``rank_budget`` across layers by ranking all components by ``metric``.

        ``hook_layer_map`` maps layer name -> hook name. A layer's rank is the number
        of its components among the top ``rank_budget``. Ranks are then swapped so
        that each hook layer holds the largest rank among the layers sharing its PCA.
        """
        exp_vars = {k: _get_metric(h.pca, metric)[:max_components] for k, h in hooks.items()}
        keys, values = zip(*[(k, c) for k, name in hook_layer_map.items() for c in exp_vars[name]])
        idx = torch.stack(values).argsort(descending=True)
        counts = Counter([keys[i] for i in idx[:rank_budget]])
        counts = {k: counts.get(k, 0) for k in hook_layer_map.keys()}  # add layers with 0 rank
        for k, k_hook in equal_inputs_map.items():
            # convergence is only checked up to the hook layer's own rank, so the
            # hook layer must have the highest rank of all layers sharing its PCA
            rank, rank_hook = counts[k], counts[k_hook]
            if rank_hook >= rank:
                continue
            counts[k_hook], counts[k] = rank, rank_hook
        return counts

    assert rho >= 1, "early_stop_rho must be >= 1"
    max_components = round(rank * rho)
    device = next(model.parameters()).device
    training = model.training
    model.eval()

    # Every module that got a forward hook from this function. These are cleared in
    # `finally` so an interrupted run does not leave hooks attached to the model.
    hooked_modules = []

    def _clear_forward_hooks(names):
        # NOTE: removes *all* forward hooks registered on these modules
        for layer_name in names:
            reduce(getattr, layer_name.split("."), model)._forward_hooks.clear()

    try:
        hooks = {}
        for name, module in model.named_modules():
            if isinstance(module, torch.nn.Linear):
                if target_modules:
                    check = [match_module_name(name, t) for t in target_modules]
                    if not any(check):
                        continue
                if ignore_modules:
                    check = [match_module_name(name, i) for i in ignore_modules]
                    if any(check):
                        continue
                hook = HashHook(name)
                module.register_forward_hook(hook)
                hooks[name] = hook
                hooked_modules.append(name)
        rank_budget = len(hooks) * rank

        # forward for one batch to check which layer inputs are equal to avoid unneeded pca calculations
        inputs = {k: v.to(device) for k, v in next(iter(data_loader)).items() if k != "labels"}
        model(**inputs)
        hash_dict = {k: h.hashed_inputs[0] for k, h in hooks.items()}
        # layer -> representative layer with the same input hash (presumably
        # find_equal_values groups keys with equal values); only representatives get a PCAHook
        equal_inputs_map = {vv: v[0] for v in find_equal_values(hash_dict).values() for vv in v[1:]}
        hooks = {k: PCAHook(k, max_components, early_stop_sim_thresh) for k in hooks.keys() if k not in equal_inputs_map}
        # every hooked layer -> name of the PCAHook that serves it
        layer_hook_map = {**dict(zip(hooks.keys(), hooks.keys())), **equal_inputs_map}
        _clear_forward_hooks(layer_hook_map.keys())

        has_converged_stats = None
        if log_convergence_stats:
            has_converged_stats = [{
                "rank": rank,
                "rho": rho,
                "early_stop_sim_thresh": early_stop_sim_thresh,
                "early_stop_redist_metric": early_stop_redist_metric,
                "scale_by_singular_values": scale_by_singular_values,
                "whiten": whiten,
                "target_modules": target_modules,
                "ignore_modules": ignore_modules,
                "equal_inputs_map": equal_inputs_map
            }]

        # start pca calculation
        pbar = tqdm(enumerate(iter(cycle(data_loader))), position=0, leave=False)
        convergence_dict = {k: False for k in hooks.keys()}
        rank_dist = {k: max_components for k in layer_hook_map.keys()}
        for i, inputs in pbar:

            t0 = time.perf_counter()

            mask = inputs["attention_mask"]
            if use_label_mask:
                mask = torch.logical_and(mask.bool(), inputs["labels"] != -100)
            indices = torch.nonzero(mask)
            inputs = {k: v.to(device) for k, v in inputs.items() if k != "labels"}

            for name, hook in hooks.items():
                module = reduce(getattr, name.split("."), model)
                module._forward_hooks.clear()
                # check if all components that are needed for the rank distribution have converged
                if torch.all(hook.converged[:rank_dist[name]]):
                    convergence_dict[name] = True
                    continue
                convergence_dict[name] = False
                hook.indices = indices
                module.register_forward_hook(hook)

            if all(convergence_dict.values()) and i > min_batches:
                print("exiting - all PCA components have converged.")
                break

            model(**inputs)

            # in case some hooks have to skip the pca calculation because the number of tokens is less than the number of components
            if not all([hasattr(h.pca, "components_") for h in hooks.values()]):
                continue

            if rank_distribution:
                rank_dist = _get_rank_distribution(hooks, layer_hook_map, equal_inputs_map, early_stop_redist_metric, rank_budget, max_components)

            step_time = time.perf_counter() - t0

            layer_converged = list(convergence_dict.values()) + [convergence_dict[v] for v in equal_inputs_map.values()]
            pbar.set_description(f"{sum(layer_converged)}/{len(layer_converged)} layers have converged")

            if log_convergence_stats:
                stats = {k: hook.converged.tolist() for k, hook in hooks.items()}
                has_converged_stats.append((stats, step_time))

        pca_dict = {}
        for name, layer_rank in rank_dist.items():
            if layer_rank == 0:
                continue
            hook = hooks[layer_hook_map[name]]
            assert torch.all(hook.converged[:layer_rank])  # this should never happen because we check for convergence
            # Clone and use out-of-place ops: several layers can share one hook, so
            # the fitted PCA state must not be modified by the scaling below.
            u = hook.pca.components_[:layer_rank].clone()
            if whiten:
                # NOTE(review): divides by sqrt(singular value), not by sqrt(explained variance)
                u /= hook.pca.singular_values_[:layer_rank].sqrt().reshape(-1, 1)
            elif scale_by_singular_values:
                s = hook.pca.singular_values_[:layer_rank]
                u *= (s / s.max()).reshape(-1, 1)
            pca_dict[name] = u

        # objects are torch tensors on the model device
        pca_dict = {k: v.cpu() for k, v in pca_dict.items()}

        exp_vars = None
        if not rank_distribution:
            exp_vars = {k: _get_metric(h.pca, early_stop_redist_metric)[:max_components] for k, h in hooks.items()}
            exp_vars = {**exp_vars, **{k: exp_vars[v] for k, v in equal_inputs_map.items()}}

        return pca_dict, has_converged_stats, exp_vars
    finally:
        # detach any hooks still registered and restore the caller's train/eval mode
        _clear_forward_hooks(hooked_modules)
        model.train(training)


@torch.no_grad()
def compute_pca(
    model: torch.nn.Module,
    data_loader: torch.utils.data.DataLoader,
    rank: int,
    rho: int = 2,
    early_stop_sim_thresh: float = 0.99,
    early_stop_redist_metric: str = "ratio",
    scale_by_singular_values: bool = False,
    whiten: bool = False,
    target_modules: Optional[list] = None,
    ignore_modules: Optional[list] = None,
    use_label_mask: bool = True,
    min_batches: int = 1,
    log_convergence_stats: bool = False,
    chunks: Optional[int] = None
):
    """Compute PCA components of the inputs of a model's linear layers.

    With ``chunks`` > 1, the target layers are split into ``chunks`` groups
    (via ``chunk_list``) that are processed one after another. Each layer keeps
    ``round(rank * rho)`` components while fitting. Afterwards the budget
    ``rank * number_of_layers`` is distributed by ranking all layers'
    explained-variance metric values together. Otherwise a single pass is made
    and the rank is redistributed during fitting.

    Args:
        model, data_loader, rank, rho, early_stop_sim_thresh,
        early_stop_redist_metric, scale_by_singular_values, whiten,
        target_modules, ignore_modules, use_label_mask, min_batches,
        log_convergence_stats: forwarded to ``_compute_pca``.
        chunks: number of layer groups; ``None`` or 1 disables chunking. When
            chunking without ``target_modules``, every ``torch.nn.Linear``
            not matched by ``ignore_modules`` is targeted.

    Returns:
        ``(pca_dict, has_converged_stats)``. ``pca_dict`` maps layer name to a
        CPU tensor of its leading components (one per row); layers that end up
        with rank 0 are omitted. ``has_converged_stats`` is ``None`` unless
        ``log_convergence_stats`` is set. In chunked mode it is the
        concatenation of the per-chunk logs.
    """
    chunks = 1 if chunks is None else chunks

    if chunks > 1:
        print("Running PCA in chunked mode")
        if not target_modules:
            if not ignore_modules:
                ignore_modules = []
            # default to every linear layer that is not explicitly ignored
            target_modules = [
                n for n, m in model.named_modules()
                if isinstance(m, torch.nn.Linear)
                and not any(match_module_name(n, i) for i in ignore_modules)
            ]
        pca_dict = {}
        has_converged_stats = []
        exp_vars = {}
        for i, t in enumerate(chunk_list(target_modules, chunks)):
            print(f"Chunk {i+1} - running PCA for modules: {t}")
            chunk_pca, chunk_stats, chunk_exp_vars = _compute_pca(
                model=model,
                data_loader=data_loader,
                rank=rank,
                rho=rho,
                early_stop_sim_thresh=early_stop_sim_thresh,
                early_stop_redist_metric=early_stop_redist_metric,
                scale_by_singular_values=scale_by_singular_values,
                whiten=whiten,
                target_modules=t,
                # honor exclusions in chunked mode too, as the single-pass path does
                ignore_modules=ignore_modules,
                use_label_mask=use_label_mask,
                min_batches=min_batches,
                rank_distribution=False,
                log_convergence_stats=log_convergence_stats
            )
            pca_dict.update(chunk_pca)
            # chunk_stats is None when log_convergence_stats is False
            if chunk_stats is not None:
                has_converged_stats.extend(chunk_stats)
            exp_vars.update(chunk_exp_vars)
            torch.cuda.empty_cache()
        if not log_convergence_stats:
            has_converged_stats = None
        # rank distribution: a layer's rank is how many of its metric values fall
        # among the rank_budget largest values over all layers
        rank_budget = len(pca_dict) * rank
        ranked = sorted(
            ((k, item) for k, v in exp_vars.items() for item in v),
            key=lambda x: x[1],
            reverse=True,
        )
        rank_dist = Counter(k for k, _ in ranked[:rank_budget])
        pca_dict = {k: v[:rank_dist[k]] for k, v in pca_dict.items() if k in rank_dist}
        return pca_dict, has_converged_stats

    else:
        pca_dict, has_converged_stats, _ = _compute_pca(
            model=model,
            data_loader=data_loader,
            rank=rank,
            rho=rho,
            early_stop_sim_thresh=early_stop_sim_thresh,
            early_stop_redist_metric=early_stop_redist_metric,
            scale_by_singular_values=scale_by_singular_values,
            whiten=whiten,
            target_modules=target_modules,
            ignore_modules=ignore_modules,
            use_label_mask=use_label_mask,
            min_batches=min_batches,
            rank_distribution=True,
            log_convergence_stats=log_convergence_stats
        )
        return pca_dict, has_converged_stats