import torch
from .matmul_had import matmul_hadU
import glog
import multiprocessing as mp

def flat_to_sym(V, N):
    """Expand a packed lower triangle into a full symmetric matrix.

    Args:
        V: tensor whose last dimension holds the ``N * (N + 1) // 2`` entries
            of the lower triangle (diagonal included), in the order produced
            by ``torch.tril_indices(N, N)``. Any leading dimensions are
            treated as batch dimensions.
        N: side length of the square matrix to build.

    Returns:
        A tensor of shape ``(*V.shape[:-1], N, N)`` with the dtype and device
        of ``V``, where both ``[..., i, j]`` and ``[..., j, i]`` hold the
        packed value for each lower-triangle position ``(i, j)``.
    """
    rows, cols = torch.tril_indices(N, N, device=V.device).unbind()
    A = torch.zeros((*V.shape[:-1], N, N), dtype=V.dtype, device=V.device)
    # Fill the lower triangle, then mirror it across the diagonal.
    A[..., rows, cols] = V
    A[..., cols, rows] = V
    return A


def sym_to_flat(A):
    """Pack the lower triangle (diagonal included) of a matrix into a vector.

    Args:
        A: tensor of shape ``(..., N, N)``. Only the lower triangle of the
            last two dimensions is read; leading dimensions are treated as
            batch dimensions.

    Returns:
        A tensor of shape ``(..., N * (N + 1) // 2)`` whose entries follow the
        ordering of ``torch.tril_indices(N, N)``.
    """
    N = A.shape[-1]
    rows, cols = torch.tril_indices(N, N, device=A.device).unbind()
    # Index the trailing two axes so that batched input is handled correctly;
    # for a plain 2-D matrix this is identical to ``A[rows, cols]``.
    return A[..., rows, cols]


def register_H_hook(module, device):
    """Attach a forward pre-hook that accumulates input statistics.

    For every forward call, the first positional input of ``module`` is
    flattened to rows of length ``module.in_features``, cast to float64, and
    folded into three running quantities kept on ``device``:

    * ``H``  -- sum over rows of the outer product ``x^T x``, shape ``(n, n)``
    * ``mu`` -- sum over rows of ``x``, shape ``(n,)``
    * ``ct`` -- number of rows seen so far

    Args:
        module: module exposing ``in_features`` and
            ``register_forward_pre_hook``.
        device: device on which the accumulators are allocated.

    Returns:
        A zero-argument callable. Calling it removes the hook and returns
        ``(H, mu, ct)`` with both tensors moved to the CPU. The sums are
        returned unnormalised.
    """
    n = module.in_features
    H = torch.zeros((n, n), dtype=torch.float64, device=device)
    mu = torch.zeros((n,), dtype=torch.float64, device=device)
    ct = 0

    def H_hook(module, x):
        # Only the counter is rebound; H and mu are updated in place.
        nonlocal ct
        rows = x[0].reshape(-1, n).to(torch.float64)
        mu.add_(rows.sum(dim=0))
        H.addmm_(rows.T, rows)
        ct += rows.shape[0]

    hook = module.register_forward_pre_hook(H_hook)

    def done():
        hook.remove()
        return H.cpu(), mu.cpu(), ct

    return done


def block_LDL(H, b):
    """Compute a block LDL^T factorisation of ``H`` with ``b x b`` blocks.

    The factorisation is derived from the Cholesky factor ``C`` of ``H``:
    each block column of ``C`` is right-multiplied by the inverse of its
    diagonal block, which turns the diagonal blocks into identities.

    Args:
        H: square ``(n, n)`` matrix accepted by ``torch.linalg.cholesky``.
        b: block size; must divide ``n``.

    Returns:
        ``(L, D)`` where ``L`` is ``(n, n)`` block lower triangular with
        identity diagonal blocks and ``D`` is ``(n // b, b, b)`` holding the
        diagonal blocks, so that ``L @ block_diag(*D) @ L.T`` reproduces ``H``.
    """
    n = H.shape[0]
    assert (n % b == 0)
    num_blocks = n // b

    chol = torch.linalg.cholesky(H)

    # (num_blocks, b, b) stack of the diagonal blocks of the Cholesky factor.
    # This is a view into ``chol``, so D must be formed before the in-place
    # rescaling below overwrites those blocks.
    diag_blocks = chol.reshape(num_blocks, b, num_blocks, b) \
                      .diagonal(dim1=0, dim2=2) \
                      .permute(2, 0, 1)
    D = diag_blocks @ diag_blocks.transpose(1, 2)

    # View the factor as block columns: column_view[:, i, :] is block column i.
    column_view = chol.view(n, num_blocks, b)
    for i in range(num_blocks):
        # Solve X @ diag_blocks[i] = column i, i.e. multiply by the inverse
        # of the diagonal block from the right without forming the inverse.
        column_view[:, i, :] = torch.linalg.solve(
            diag_blocks[i, :, :], column_view[:, i, :], left=False)

    return column_view.reshape(n, n), D

def wrap_tokenizer(tokenizer, x, ctx_size):
    """Call ``tokenizer`` on ``x`` with the fixed options used for sampling.

    Exists as a module-level function so it can be dispatched through
    ``multiprocessing.Pool.starmap``.

    Args:
        tokenizer: callable accepting the text(s) plus the keyword options
            ``return_tensors``, ``truncation``, ``padding`` and ``max_length``.
        x: text or list of texts to tokenize.
        ctx_size: value forwarded as ``max_length``.

    Returns:
        Whatever ``tokenizer`` returns.
    """
    options = {
        'return_tensors': 'pt',
        'truncation': True,
        'padding': True,
        'max_length': ctx_size,
    }
    return tokenizer(x, **options)

def _keep_full_length(devset, saved, tokens, ctx_size):
    """Copy the rows of ``tokens`` that fill the whole context into ``devset``.

    A row qualifies when its attention mask sums to ``ctx_size``. Qualifying
    rows are written starting at ``devset[saved]`` and are trimmed so that
    ``devset`` never overflows.

    Returns:
        The updated number of filled rows.
    """
    lens = tokens.attention_mask.sum(dim=-1)
    good = torch.where(lens == ctx_size)[0]
    if len(good) == 0:
        # Nothing to copy. Skipping also avoids a shape mismatch, because a
        # batch without full-length rows is padded to fewer than ctx_size.
        return saved
    good = good[:devset.shape[0] - saved]
    devset[saved:saved + len(good)] = tokens.input_ids[good]
    return saved + len(good)


def sample_devset(dataset, tokenizer, size=128, ctx_size=2048, nproc=1):
    """Sample ``size`` token sequences of exactly ``ctx_size`` tokens.

    Repeatedly draws ``size`` random examples from ``dataset``, tokenizes
    their ``'text'`` field with truncation to ``ctx_size``, and keeps only
    the sequences whose attention mask covers the full context, until
    ``size`` rows have been collected.

    NOTE: the loop does not terminate if the dataset contains no text that
    tokenizes to at least ``ctx_size`` tokens.

    Args:
        dataset: object supporting ``len()`` and indexing by an index tensor,
            returning a mapping with a ``'text'`` entry.
        tokenizer: callable returning an object with ``input_ids`` and
            ``attention_mask`` tensors. It must be picklable when ``nproc > 1``.
        size: number of sequences to collect.
        ctx_size: required sequence length.
        nproc: number of worker processes used for tokenization. Values above
            1 tokenize ``nproc`` independent batches per round in a process
            pool and print the running count as progress.

    Returns:
        An int64 tensor of shape ``(size, ctx_size)`` holding token ids.
    """
    devset = torch.zeros((size, ctx_size), dtype=torch.int64)
    saved = 0

    def draw_texts():
        return dataset[torch.randint(len(dataset), (size,))]['text']

    if nproc > 1:
        # The context manager shuts the workers down even if tokenization
        # raises; previously the pool was never closed.
        with mp.Pool(nproc) as pool:
            while saved < size:
                jobs = [(tokenizer, draw_texts(), ctx_size)
                        for _ in range(nproc)]
                for batch in pool.starmap(wrap_tokenizer, jobs):
                    before = saved
                    saved = _keep_full_length(devset, saved, batch, ctx_size)
                    if saved > before:
                        print(saved)
                    if saved >= size:
                        break
    else:
        while saved < size:
            batch = tokenizer(draw_texts(),
                              return_tensors='pt',
                              truncation=True,
                              padding=True,
                              max_length=ctx_size)
            saved = _keep_full_length(devset, saved, batch, ctx_size)
    return devset


def load_quip(save_name, cb, args, device):
    """Rebuild a dense weight matrix from a cached compressed layer.

    The checkpoint at ``save_name`` is expected to hold the entries ``'SU'``,
    ``'SV'``, ``'Wscale'`` and ``'Qidxs'``, plus ``'A'``/``'B'`` when
    ``args.lora_rank > 0`` and ``'scaleWH'`` when ``args.rescale_WH`` is set.

    Steps, in order:
      1. decode ``Qidxs`` through the codebook ``cb`` and reshape to
         ``(len(SV), len(SU))``, then multiply by ``Wscale``;
      2. optionally add the low-rank term ``A @ B``;
      3. undo the incoherence transform selected by ``args.incoh_mode``
         (``"had"`` uses ``matmul_hadU`` with ``SU``/``SV`` as elementwise
         factors, ``"kron"`` uses ``SV.T @ W @ SU``);
      4. optionally divide the columns by ``scaleWH``.

    Args:
        save_name: path of the saved layer.
        cb: codebook object providing ``to``, ``by_idxs`` and ``packsz``.
        args: namespace with ``lora_rank``, ``incoh_mode`` and ``rescale_WH``.
        device: target device; also used as the CUDA device index for
            ``map_location``, so presumably an integer ordinal -- TODO confirm
            with callers.

    Returns:
        The reconstructed weight tensor on ``device``.

    Raises:
        NotImplementedError: if ``args.incoh_mode`` is neither ``"had"`` nor
            ``"kron"``.
    """
    glog.info(f"loading cached compressed layer from path \"{save_name}\"")
    # NOTE(review): torch.load unpickles the file; only load trusted paths.
    saved = torch.load(save_name, map_location=torch.device('cuda', device))

    SU = saved['SU'].to(device)
    SV = saved['SV'].to(device)
    Wscale = saved['Wscale'].to(device)
    Qidxs = saved['Qidxs'].to(device)
    n = len(SU)
    m = len(SV)

    # Decode the quantised indices into the (m, n) transformed weight.
    codebook = cb.to(device)
    is_packed = cb.packsz != 1
    hatWr = codebook.by_idxs(Qidxs, packed=is_packed).view(m, n)
    hatWr = hatWr * Wscale
    del Wscale

    if args.lora_rank > 0:
        lora_A = saved['A'].to(device)
        lora_B = saved['B'].to(device)
        hatWr = hatWr + lora_A @ lora_B
        del lora_A, lora_B

    mode = args.incoh_mode
    if mode == "had":
        # Transform one side, transpose, transform the other, transpose back.
        half = matmul_hadU(hatWr) * SU
        hatW = (matmul_hadU(half.T) * SV).T
        del half
    elif mode == "kron":
        hatW = SV.T @ hatWr @ SU
    else:
        raise NotImplementedError
    del SU, SV

    if args.rescale_WH:
        col_scale = saved['scaleWH'][None, :].to(device)
        hatW = hatW / col_scale
    return hatW


def dtype_from_str(str):
    """Translate a dtype name such as ``'torch.int32'`` into a torch dtype.

    Only the integer dtypes ``torch.int64``, ``torch.int32``, ``torch.int16``
    and ``torch.uint8`` are recognised.

    Args:
        str: the dtype's textual form, e.g. ``'torch.uint8'``. The parameter
            name shadows the builtin and is kept for caller compatibility.

    Returns:
        The matching ``torch.dtype``.

    Raises:
        KeyError: if the name is not one of the supported dtypes.
    """
    supported = (torch.int64, torch.int32, torch.int16, torch.uint8)
    # repr() of a torch dtype is its qualified name, e.g. 'torch.int64'.
    by_name = {repr(dtype): dtype for dtype in supported}
    return by_name[str]
