from src.models.modeling_gpt2 import ExtendedGPT2LMHeadModel
from src.utils.my_tqdm import mytqdm

import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.dataloader import _BaseDataLoaderIter
import torch.nn as nn
from transformers import DataCollatorForLanguageModeling
from typing import Optional
import os

class DataLoaderForTaskApprox(DataLoader):
    """Wrap a DataLoader and attach precomputed per-batch tensors to every batch.

    While iterating, the i-th batch produced by the wrapped loader receives
    three extra entries: ``"kernel"``, ``"kernel_grad"`` and ``"indices"``,
    taken from position ``i`` of the corresponding lists. Batches must
    therefore support item assignment (e.g. be dicts).

    NOTE: tensors are matched to batches purely by iteration position, so the
    wrapped loader has to yield batches in the same order as when the lists
    were built (a shuffling loader would pair them with the wrong batches).
    If the loader yields more batches than the lists hold, iteration raises
    ``IndexError``.

    Attributes that are not found on this object are looked up on the wrapped
    loader.

    Args:
        kernel: One tensor per batch, exposed as ``batch["kernel"]``.
        kernel_grad: One tensor per batch, exposed as ``batch["kernel_grad"]``.
        indices: One tensor per batch, exposed as ``batch["indices"]``.
        dataloader: Existing loader to wrap. Takes precedence over ``dataset``.
        dataset: Dataset used to build a new ``DataLoader`` when no
            ``dataloader`` is given.
        *args: Extra positional arguments for ``DataLoader(dataset, ...)``.
        device: Device the attached tensors are moved to. ``None`` (default)
            selects CUDA when it is available and the CPU otherwise.
        **kwargs: Extra keyword arguments for ``DataLoader(dataset, ...)``.

    Raises:
        ValueError: If neither ``dataloader`` nor ``dataset`` is provided.
    """

    def __init__(self, kernel: list[torch.Tensor], kernel_grad: list[torch.Tensor], indices: list[torch.Tensor],
                 dataloader: Optional[DataLoader]=None, dataset: Optional[Dataset]=None,
                 *args, device=None, **kwargs):
        if dataloader is not None:
            self.dataloader = dataloader
        elif dataset is not None:
            # ``__init__`` always returns None, so storing the result of
            # ``super().__init__`` would leave nothing to iterate over.
            # Build a real loader and wrap it, exactly like the branch above.
            self.dataloader = DataLoader(dataset, *args, **kwargs)
        else:
            raise ValueError("Either dataloader or dataset should be provided.")
        self.kernel = kernel
        self.kernel_grad = kernel_grad
        self.indices = indices
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)

    def __iter__(self):
        for i, batch in enumerate(self.dataloader):
            batch["kernel"] = self.kernel[i].to(self.device)
            batch["kernel_grad"] = self.kernel_grad[i].to(self.device)
            batch["indices"] = self.indices[i].to(self.device)
            yield batch

    def __len__(self):
        return len(self.dataloader)

    def __getattr__(self, attr):
        # Only called when normal lookup fails. If ``dataloader`` itself is
        # missing (bare instance, copy/unpickle), delegating would re-enter
        # this method forever, so fail with a regular AttributeError instead.
        if attr == "dataloader":
            raise AttributeError(attr)
        return getattr(self.dataloader, attr)

class AttnKernelGradNet(nn.Module):
    """Run a model, backpropagate its loss and keep the attention-kernel
    entries whose gradients have the largest magnitude.

    The wrapped model is called with ``output_attentions=["kernel"]`` and must
    return an object exposing ``loss`` and ``attentions``, where each element
    of ``attentions`` is a mapping with a ``"kernel"`` tensor whose ``.grad``
    is populated by ``loss.backward()``.

    Args:
        model: Model to evaluate.
        topk: Maximum number of entries kept along the flattened last axis.
    """

    def __init__(self, model: "ExtendedGPT2LMHeadModel", topk: int):
        super().__init__()
        self.model = model
        self.topk = topk

    def forward(self, *args, **kwargs):
        """Return the top-k kernel values, their gradients and their positions.

        All positional and keyword arguments are forwarded to the model.

        NOTE: ``loss.backward()`` is called here, so gradients are also
        accumulated on every model parameter that requires grad.

        Returns:
            dict with keys ``"kernel"``, ``"kernel_grad"`` and ``"indices"``.
            Per-layer kernels are stacked along a new dim 1 and the last two
            dims are flattened; ``"indices"`` index into that flattened axis
            and are chosen by largest absolute gradient. The last axis has
            ``min(topk, flattened_size)`` entries.

        Raises:
            RuntimeError: If a layer's kernel has no gradient after backward.
        """
        out = self.model(*args, output_attentions=["kernel"], **kwargs)
        out.loss.backward()

        layer_kernels = [layer_attn["kernel"] for layer_attn in out.attentions]
        layer_grads = [layer_kernel.grad for layer_kernel in layer_kernels]
        if any(grad is None for grad in layer_grads):
            # Stacking a list containing None fails with an opaque TypeError;
            # name the actual problem instead.
            raise RuntimeError(
                "Attention kernel has no gradient after backward(); "
                "kernel gradients must be retained by the model."
            )

        kernel_grad = torch.stack(layer_grads, dim=1)
        kernel = torch.stack([layer_kernel.detach() for layer_kernel in layer_kernels], dim=1)

        flat_kernel_grad = kernel_grad.flatten(-2, -1)
        flat_kernel = kernel.flatten(-2, -1)

        # torch.topk raises when k exceeds the axis size (e.g. short inputs),
        # so never ask for more entries than exist.
        k = min(self.topk, flat_kernel_grad.size(-1))
        topk_indices = flat_kernel_grad.abs().topk(k, dim=-1).indices

        topk_kernel = torch.gather(flat_kernel, -1, topk_indices)
        topk_kernel_grad = torch.gather(flat_kernel_grad, -1, topk_indices)

        return {"kernel": topk_kernel, "kernel_grad": topk_kernel_grad, "indices": topk_indices}
                                            
def create_dataset_for_task_approx(dataloader: DataLoader,
                                   target_model: nn.Module,
                                   save_path: Optional[str],
                                   *, topk: int = 50) -> DataLoader:
    """Build a loader whose batches carry top-k attention-kernel values and gradients.

    If ``save_path`` points to an existing file, the ``(kernel, kernel_grad,
    indices)`` tuple is loaded from it. Otherwise every batch of
    ``dataloader`` is unpacked as keyword arguments into an
    ``AttnKernelGradNet`` wrapped in ``nn.DataParallel``; the per-batch results
    are moved to the CPU, collected in lists and, when ``save_path`` is given,
    written to that path with ``torch.save``.

    Args:
        dataloader: Source of batches; each batch must be a mapping usable as
            ``**batch``. The returned loader re-iterates this same loader and
            pairs tensors with batches by position, so it has to yield batches
            in a reproducible order.
        target_model: Model exposing ``transformer.h``, an iterable of blocks
            whose ``attn`` provides ``activate_or_deactivate_kernel_grad``.
        save_path: Cache file location, or ``None`` to disable caching.
        topk: Number of entries kept per row by ``AttnKernelGradNet``. Ignored
            when results are loaded from an existing cache file.

    Returns:
        A ``DataLoaderForTaskApprox`` wrapping ``dataloader``.
    """
    if (save_path is not None) and os.path.exists(save_path):
        # NOTE: torch.load is pickle-based; only point save_path at files you trust.
        kernel, kernel_grad, indices = torch.load(save_path)

    else:
        blocks = target_model.transformer.h
        for block in blocks:
            block.attn.activate_or_deactivate_kernel_grad(True)

        kernel = []
        kernel_grad = []
        indices = []
        try:
            attn_kernel_grad_net = nn.DataParallel(AttnKernelGradNet(target_model, topk=topk))

            # NOTE: each forward pass calls loss.backward(), so gradients
            # accumulate on target_model's parameters during this loop.
            for batch in mytqdm(dataloader, message="Computing kernel and kernel_grad"):
                out = attn_kernel_grad_net(**batch)
                kernel.append(out["kernel"].cpu())
                kernel_grad.append(out["kernel_grad"].cpu())
                indices.append(out["indices"].cpu())
        finally:
            # Restore the model even if a batch fails, so it is never left
            # with kernel-grad mode switched on.
            for block in blocks:
                block.attn.activate_or_deactivate_kernel_grad(False)

        if save_path is not None:
            torch.save((kernel, kernel_grad, indices), save_path)

    return DataLoaderForTaskApprox(kernel, kernel_grad, indices, dataloader=dataloader)