from typing import Union, Tuple, List
from omegaconf import DictConfig
import os

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers.pytorch_utils import Conv1D

def empty_cache(path: str):
    """Delete every entry found directly inside the directory ``path``.

    Each name returned by ``os.listdir(path)`` is joined with ``path`` and
    handed to ``os.remove``. The directory itself is left in place.
    ``os.remove`` only deletes files, so a sub-directory entry makes it
    raise ``OSError``.
    """
    for entry_name in os.listdir(path):
        os.remove(os.path.join(path, entry_name))

def get_module(module: nn.Module, module_name: str) -> nn.Module:
    """Resolve a dotted attribute path starting from ``module``.

    ``module_name`` is split on ``"."`` and every piece is looked up with
    ``getattr`` on the result of the previous lookup. The object reached by
    the last piece is returned; a missing attribute raises
    ``AttributeError``.
    """
    target = module
    for attr in module_name.split("."):
        target = getattr(target, attr)
    return target

def get_shape(module: Union[nn.Linear, Conv1D]) -> Tuple[int]:
    """Return the shape of ``module.weight`` as a plain tuple.

    For an ``nn.Linear`` instance the dimensions are returned in reversed
    order; for any other module they are returned as stored.
    """
    dims = tuple(module.weight.shape)
    if isinstance(module, nn.Linear):
        # nn.Linear keeps its weight as (out_features, in_features); flipping
        # it presumably aligns it with Conv1D's layout - confirm against the
        # transformers Conv1D definition.
        return dims[::-1]
    return dims
    
def cross_entropy(
    logits: torch.FloatTensor,
    labels: torch.LongTensor
):
    """Compute a cross-entropy loss, choosing the variant from ``logits``' rank.

    Args:
        logits: Rank-2 logits are scored with binary cross-entropy
            (``F.binary_cross_entropy_with_logits``); rank-3 logits are
            scored with categorical cross-entropy over the last dimension.
        labels: Targets. For rank-2 logits they are compared element-wise
            with ``logits``; integer labels are accepted and converted to
            the dtype of ``logits``. For rank-3 logits they index the
            leading two dimensions and positions equal to ``-100`` are left
            out of the loss.

    Returns:
        The loss tensor produced by the selected PyTorch loss function.

    Raises:
        ValueError: If ``logits`` is neither rank 2 nor rank 3.
    """
    rank = logits.dim()

    if rank == 2:
        # BCE-with-logits rejects integer targets, so cast them; floating
        # point labels are passed through untouched.
        if not labels.is_floating_point():
            labels = labels.to(logits.dtype)
        return F.binary_cross_entropy_with_logits(logits, labels)

    if rank == 3:
        # Keep only the positions that carry a real label.
        answer_positions = torch.where(labels != -100)
        return F.cross_entropy(
            logits[answer_positions],
            labels[answer_positions]
        )

    raise ValueError(
        f"cross_entropy expects logits of rank 2 or 3, got rank {rank}"
    )

def kl_div(
    refer_logits: torch.FloatTensor,
    logits: torch.FloatTensor,
    labels: torch.LongTensor
) -> torch.Tensor:
    """Compute the KL divergence from the reference distribution to the current one.

    Args:
        refer_logits: Logits of the reference distribution.
        logits: Logits of the current distribution; its rank selects the
            variant. Rank 2 treats each logit as a Bernoulli parameter
            (sigmoid); rank 3 treats the last dimension as a categorical
            distribution (softmax).
        labels: Only used for rank-3 logits, where positions equal to
            ``-100`` are excluded. Ignored for rank-2 logits.

    Returns:
        For rank 2, the mean over all elements of the Bernoulli
        KL(reference || current). For rank 3, ``F.kl_div`` with
        ``reduction="batchmean"`` over the selected positions.

    Raises:
        ValueError: If ``logits`` is neither rank 2 nor rank 3.
    """
    rank = logits.dim()

    if rank == 2:
        # log(sigmoid(x)) and log(1 - sigmoid(x)) == log(sigmoid(-x)) are
        # taken through logsigmoid: computing sigmoid first and then log
        # yields log(0) = -inf (and 0 * -inf = nan) once a logit saturates.
        refer_log_pos = F.logsigmoid(refer_logits)
        refer_log_neg = F.logsigmoid(-refer_logits)
        log_pos = F.logsigmoid(logits)
        log_neg = F.logsigmoid(-logits)

        divergence = (
            refer_log_pos.exp() * (refer_log_pos - log_pos)
            + refer_log_neg.exp() * (refer_log_neg - log_neg)
        )
        return divergence.mean()

    if rank == 3:
        # Keep only the positions that carry a real label.
        answer_positions = torch.where(labels != -100)

        refer_log_probs = refer_logits[answer_positions].log_softmax(-1)
        log_probs = logits[answer_positions].log_softmax(-1)

        # With log_target=True the second argument is the (log) target, so
        # this is KL(reference || current), matching the rank-2 branch.
        return F.kl_div(
            log_probs,
            refer_log_probs,
            reduction = "batchmean",
            log_target = True
        )

    raise ValueError(
        f"kl_div expects logits of rank 2 or 3, got rank {rank}"
    )
    
def succ_ratios(
    logits: torch.FloatTensor,
    labels: torch.LongTensor
) -> List[float]:
    """Return one success value per leading-dimension entry as a Python list.

    Rank-2 logits: a logit counts as a positive prediction when it is
    greater than zero; the result holds, per entry, whether that prediction
    equals the label (the trailing dimension is squeezed away first).

    Rank-3 logits: the prediction is the argmax over the last dimension; the
    result holds, per entry, the number of positions where the prediction
    equals the label divided by the number of labels that are not ``-100``.

    Any other rank falls through and yields ``None``.
    """
    rank = logits.dim()

    if rank == 2:
        predictions = logits > 0
        hits = (predictions == labels).squeeze(-1)
        return hits.cpu().numpy().tolist()

    if rank == 3:
        predicted_tokens = logits.argmax(-1)
        scored_per_row = (labels != -100).sum(-1)
        correct_per_row = (predicted_tokens == labels).sum(-1)
        ratios = correct_per_row / scored_per_row
        return ratios.cpu().numpy().tolist()


class Tracer:
    """Record a module's input and output-gradient at selected positions.

    Two hooks are registered on ``module``:

    * a forward hook that stores the first positional input, indexed by
      ``cache_indices`` and detached, in ``self.keys``;
    * a full backward hook that stores the first output gradient, indexed
      by ``cache_indices`` and detached, in ``self.values_grad``.

    ``self.keys`` and ``self.values_grad`` only exist once the corresponding
    pass has run. ``self.handles`` holds both hook handles so that the
    hooks can be removed later.
    """

    def __init__(
        self,
        module: nn.Module,
        cache_indices: Tuple[torch.LongTensor]
    ):

        def capture_keys(
            hooked: nn.Module,
            args: Tuple[torch.FloatTensor],
            result: Tuple[torch.FloatTensor]
        ):
            first_input = args[0]
            self.keys = first_input[cache_indices].detach()

        def capture_values_grad(
            hooked: nn.Module,
            grad_inputs: Tuple[torch.FloatTensor],
            grad_outputs: Tuple[torch.FloatTensor]
        ):
            first_grad = grad_outputs[0]
            self.values_grad = first_grad[cache_indices].detach()

        # Forward hook is registered first, then the backward hook.
        forward_handle = module.register_forward_hook(capture_keys)
        backward_handle = module.register_full_backward_hook(capture_values_grad)
        self.handles = [forward_handle, backward_handle]


class TracerDict(dict):
    """Dictionary mapping each edited module's name to its ``Tracer``.

    On construction, one ``Tracer`` is created for every name listed in
    ``config.model.edit_modules``; each name is resolved on ``model`` with
    ``get_module`` and all tracers share the indices obtained from
    ``torch.where(cache_mask)``.

    The object is also a context manager: entering returns the dictionary
    itself, and leaving removes every hook handle held by the stored
    tracers.
    """

    def __init__(
        self,
        model: nn.Module,
        config: DictConfig,
        cache_mask: torch.LongTensor
    ):
        positions = torch.where(cache_mask)

        for name in config.model.edit_modules:
            self[name] = Tracer(get_module(model, name), positions)

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        # Returning None lets any exception raised inside the block propagate.
        for tracer in self.values():
            for handle in tracer.handles:
                handle.remove()