import argparse
import gc
import math
import os
from pathlib import Path

import torch
import transformers
from torch import nn
from torch.utils.data import DataLoader
from transformers import DefaultDataCollator, OPTForCausalLM

import utils
from collectors import OPTCollector, run_collector
from datautils import get_c4, get_code_search_net, get_wikipedia, get_wikitext2

torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False


@torch.no_grad()
def eval_model(model, model_str, enc, dev="cuda"):
    """Compute the perplexity of ``model`` on a tokenized corpus.

    The token stream in ``enc.input_ids`` is cut into consecutive,
    non-overlapping windows of ``model.seqlen`` tokens (a trailing partial
    window is dropped). Each window is fed to the model and the mean
    next-token cross-entropy is accumulated; the result is the exponential of
    the average per-window loss.

    Args:
        model: Callable model exposing ``eval()``, ``config.use_cache`` and
            ``seqlen``; calling it must return a mapping with a ``"logits"``
            entry.
        model_str: Model identifier. Unused here; kept for interface
            compatibility with existing callers.
        enc: Object with an ``input_ids`` tensor. The column slicing below
            assumes it is 2-D; the window count is only meaningful for a
            single row -- TODO confirm against the dataset loaders.
        dev: Device the token ids are moved to before evaluation.

    Returns:
        dict: ``{"ppl": <float perplexity>}``.

    Raises:
        ValueError: If fewer than ``model.seqlen`` tokens are supplied, or if
            the model output has no ``"logits"`` entry.
    """
    model.eval()

    input_ids = enc.input_ids.to(dev)
    seqlen = model.seqlen
    nsamples = input_ids.numel() // seqlen
    if nsamples == 0:
        raise ValueError(
            f"Need at least {seqlen} tokens to evaluate, got {input_ids.numel()}"
        )

    # Stateless loss module: build it once instead of once per window.
    loss_fn = nn.CrossEntropyLoss()

    # The KV cache is switched off for evaluation; the finally-clause makes
    # sure the caller's setting is restored even when a forward pass fails.
    use_cache = model.config.use_cache
    model.config.use_cache = False
    try:
        total_loss = 0.0
        for i in range(nsamples):
            batch = input_ids[:, i * seqlen : (i + 1) * seqlen]

            out = model(batch)
            if "logits" not in out.keys():
                raise ValueError(f"Unknown model out keys: {list(out.keys())}")
            logits = out["logits"]

            # Position t predicts token t + 1, so drop the last logit and the
            # first target before flattening.
            window_loss = loss_fn(
                logits[:, :-1, :].reshape(-1, logits.size(-1)),
                batch[:, 1:].reshape(-1),
            )
            total_loss += window_loss.item()
    finally:
        model.config.use_cache = use_cache

    return {"ppl": math.exp(total_loss / nsamples)}


def get_model(model_str):
    """Load a pretrained causal language model by name.

    Before loading, ``kaiming_uniform_``, ``uniform_`` and ``normal_`` in
    ``torch.nn.init`` are replaced with a no-op. This patch is global and
    stays in effect after the function returns (and is applied even when the
    model name is rejected).

    Only names containing ``"opt"`` are supported; they are loaded through
    ``OPTForCausalLM.from_pretrained`` and get a ``seqlen`` attribute copied
    from ``config.max_position_embeddings``.

    Raises:
        NotImplementedError: If ``model_str`` does not contain ``"opt"``.
    """

    def _noop(*args, **kwargs):
        pass

    for init_name in ("kaiming_uniform_", "uniform_", "normal_"):
        setattr(torch.nn.init, init_name, _noop)

    if "opt" not in model_str:
        raise NotImplementedError(f"Unknown model: {model_str}")

    loaded = OPTForCausalLM.from_pretrained(
        model_str, torch_dtype="auto", device_map="auto"
    )
    loaded.seqlen = loaded.config.max_position_embeddings
    return loaded


def get_datasets(model_str, langs, seed, nsamples, seqlen):
    """Load one dataset per requested language/corpus name.

    Supported names are ``"wikitext2"``, ``"c4"``, the Wikipedia languages
    ``"de"``, ``"fr"``, ``"it"`` and the code languages ``"python"``,
    ``"java"``, ``"go"``, ``"javascript"``, ``"php"``, ``"ruby"``.

    Returns:
        dict: Maps each entry of ``langs`` to whatever the matching loader
        returns (callers in this file unpack it as a two-element pair).

    Raises:
        NotImplementedError: If a name in ``langs`` is not supported.
    """
    wikipedia_langs = ("de", "fr", "it")
    code_langs = ("python", "java", "go", "javascript", "php", "ruby")

    def _load(lang):
        # Keyword arguments shared by every loader.
        shared = {"nsamples": nsamples, "seed": seed, "seqlen": seqlen}
        if lang == "wikitext2":
            return get_wikitext2(model_str, **shared)
        if lang == "c4":
            return get_c4(model_str, **shared)
        if lang in wikipedia_langs:
            return get_wikipedia(model_str, lang, subsets=(2000, 300), **shared)
        if lang in code_langs:
            return get_code_search_net(model_str, lang, **shared)
        raise NotImplementedError(f"Unknown language: [{lang}]")

    return {lang: _load(lang) for lang in langs}


def _regularized_gram(raw_gram):
    """Return ``raw_gram`` on the GPU as float64 with a small diagonal jitter.

    The jitter is 1e-12, raised to 1e-3 when the smallest eigenvalue is below
    -1e-16 (i.e. the matrix is noticeably indefinite).
    """
    C = raw_gram.cuda().to(dtype=torch.float64)
    eps = 1e-12
    if torch.linalg.eigvalsh(C).min() < -1e-16:
        eps = 1e-3
    return C + eps * torch.eye(len(C), dtype=torch.float64, device=C.device)


def _unit_scores(C):
    """Return the per-unit score ``1 / diag(utils.mpow(C, -1/2)) ** 2``."""
    # utils.mpow is presumably a matrix power -- TODO confirm in utils.
    diag = torch.diag(utils.mpow(C, -1/2, epsilon=0.))
    return 1 / (diag * diag)


def _num_units_kept(C, order, var_cutoff):
    """Return how many leading units of ``order`` are kept for ``var_cutoff``.

    ``C`` is permuted by ``order``, passed through ``utils.ldl`` and the
    normalised second return value is reverse-accumulated with
    ``utils.rev_cumsum``. The count is the first position whose accumulated
    value drops below ``var_cutoff``; every unit is kept when no position
    does.
    """
    C_reordered = C[order][:, order]
    _, D_gs = utils.ldl(C_reordered, epsilon=0.)

    # Free the permuted copy before the next large allocation.
    del C_reordered
    gc.collect()
    torch.cuda.empty_cache()

    D_gs = utils.rev_cumsum(D_gs / D_gs.sum())

    below = torch.where(D_gs < var_cutoff)[0]
    # The emptiness check covers the case where the last value equals the
    # cutoff exactly, which would otherwise index into an empty result.
    if D_gs[-1] > var_cutoff or len(below) == 0:
        return len(D_gs)
    return below[0].item()


def _reconstruction_matrix(C, idxs_kept, dtype):
    """Return ``C[:, kept] @ inv(C[kept, kept])`` cast to ``dtype``."""
    eye = torch.eye(len(idxs_kept), dtype=torch.float64, device=C.device)
    kept_inv = torch.linalg.solve(C[idxs_kept][:, idxs_kept], eye)
    return (C[:, idxs_kept] @ kept_inv).to(dtype)


def _prune_attention(attn, var_cutoff):
    """Prune heads and per-head units of one OPT attention module in place.

    Reads the Gram matrix stored on ``attn.out_proj.C``, shrinks the rows of
    q/k/v projections and the input columns of ``out_proj``, updates
    ``num_heads`` / ``head_dim`` / ``embed_dim`` and deletes ``out_proj.C``.
    """
    out_proj = attn.out_proj
    orig_dtype = out_proj.C.dtype
    weight_device = out_proj.weight.device

    C = _regularized_gram(out_proj.C)

    num_heads = attn.num_heads
    num_units_per_head = attn.head_dim
    head_indices = (
        torch.arange(len(C), dtype=torch.int, device=C.device) // num_units_per_head
    )

    scores = _unit_scores(C)
    order = torch.argsort(scores, descending=True)
    importance = scores.argsort().argsort()  # rank of each unit, 0 = lowest score

    num_kept = _num_units_kept(C, order, var_cutoff)

    # Count, per head, how many of its units fall in the removed tail.
    bincount = torch.bincount(head_indices[order][num_kept:])
    if len(bincount) <= num_heads:
        bincount = torch.nn.functional.pad(bincount, (0, num_heads - len(bincount)))

    # A head survives unless every one of its units was removed.
    head_idxs_kept = torch.where(bincount < num_units_per_head)[0]
    num_heads_kept = len(head_idxs_kept)
    offsets = head_idxs_kept.view(-1, 1) * num_units_per_head
    idxs_kept = (
        (torch.arange(num_units_per_head, device=C.device) + offsets).flatten().tolist()
    )

    # All heads must keep the same width, so every head drops as many units
    # as the least-affected head lost: its lowest-ranked ones.
    num_units_removed_per_head = torch.min(bincount).item()
    unit_idxs_removed = []
    for head in range(num_heads):
        start = num_units_per_head * head
        stop = start + num_units_per_head
        local_rank = importance[start:stop].argsort().argsort()
        unit_idxs_removed.extend(
            (torch.where(local_rank < num_units_removed_per_head)[0] + start).tolist()
        )

    # Set iteration order is not guaranteed to be ascending; the projection
    # rows must stay grouped head by head, so sort explicitly.
    idxs_kept = sorted(set(idxs_kept) - set(unit_idxs_removed))

    # Rebuilt from the stored matrix because utils.mpow / utils.ldl are
    # opaque here and might modify their argument.
    C = _regularized_gram(out_proj.C)
    C = _reconstruction_matrix(C, idxs_kept, orig_dtype)
    W = out_proj.weight.data.clone().cuda()
    # Keep the new weight on the layer's own device so it stays consistent
    # with the q/k/v weights, which are sliced in place below.
    out_proj.weight.data = torch.einsum('ij..., jk->ik...', W, C).to(weight_device)
    out_proj.in_features = len(idxs_kept)

    for proj in (attn.k_proj, attn.v_proj, attn.q_proj):
        proj.weight.data = proj.weight.data[idxs_kept]
    for proj in (attn.k_proj, attn.v_proj, attn.q_proj):
        proj.bias.data = proj.bias.data[idxs_kept]

    attn.num_heads = num_heads_kept
    attn.head_dim = num_units_per_head - num_units_removed_per_head
    attn.embed_dim = attn.num_heads * attn.head_dim

    del C, W, out_proj.C
    gc.collect()
    torch.cuda.empty_cache()


def _prune_mlp(layer, var_cutoff):
    """Prune hidden units of one OPT decoder layer's fc1/fc2 pair in place.

    Reads the Gram matrix stored on ``layer.fc2.C``, shrinks the rows of
    ``fc1`` and the input columns of ``fc2`` and deletes ``fc2.C``.
    """
    fc1, fc2 = layer.fc1, layer.fc2
    orig_dtype = fc2.C.dtype
    weight_device = fc2.weight.device

    C = _regularized_gram(fc2.C)

    scores = _unit_scores(C)
    order = torch.argsort(scores, descending=True)
    importance = scores.argsort().argsort()  # rank of each unit, 0 = lowest score

    num_kept = _num_units_kept(C, order, var_cutoff)

    # Keep the num_kept highest-ranked units, in ascending index order.
    idxs_kept = torch.where(importance >= (len(importance) - num_kept))[0].tolist()

    # Rebuilt from the stored matrix because utils.mpow / utils.ldl are
    # opaque here and might modify their argument.
    C = _regularized_gram(fc2.C)
    C = _reconstruction_matrix(C, idxs_kept, orig_dtype)
    W = fc2.weight.data.clone().cuda()

    # Keep the new weight on the layer's own device so it stays consistent
    # with fc1, which is sliced in place.
    fc2.weight.data = torch.einsum('ij..., jk->ik...', W, C).to(weight_device)
    fc1.weight.data = fc1.weight.data[idxs_kept]
    fc1.bias.data = fc1.bias.data[idxs_kept]

    del C, W, fc2.C
    gc.collect()
    torch.cuda.empty_cache()


def main(
    model_str,
    seed,
    train_lang,
    eval_langs,
    nsamples,
    batch_size,
):
    """Sweep a pruning cutoff over an OPT model and print perplexity per step.

    Steps:
      1. Load the model and print its test perplexity on every ``eval_langs``
         dataset.
      2. Load the Gram matrices from ``cs_{model_str}.pt`` if that file
         exists; otherwise collect them with ``run_collector`` on the
         ``train_lang`` data and save them there.
      3. For each of 600 cutoffs in ``linspace(0.001, 0.6)``: reload the
         model, attach one Gram matrix to every ``out_proj`` / ``fc2`` linear
         layer, prune attention and MLP blocks, then print perplexity and the
         fraction of parameters removed.

    Args:
        model_str: Model name passed to ``get_model``.
        seed: Seed forwarded to the dataset loaders.
        train_lang: Dataset name used for Gram-matrix collection.
        eval_langs: Dataset names used for perplexity evaluation.
        nsamples: Sample count for the loaders and the collection data.
        batch_size: Batch size of the collection ``DataLoader``.
    """
    device = torch.device('cuda')

    print('retrieving model')
    model = get_model(model_str)
    model.eval()

    model = model.float()

    print(f"Loaded model: {model_str}")
    print(f"\tSeqlen: {model.seqlen}")

    print("Loading datasets...")
    datasets = get_datasets(model_str, eval_langs, seed, nsamples, model.seqlen)

    for lang, (_, testenc) in datasets.items():
        test_outdir = eval_model(model, model_str, testenc)
        print(f"Test  PPL [{lang}]:", test_outdir["ppl"])

    torch.cuda.empty_cache()

    print('starting to collect Gram matrices')
    cache_path = Path(f'cs_{model_str}.pt')
    if cache_path.exists():
        print('loading Gram matrices from memory')
        Cs = torch.load(cache_path, map_location='cpu')
    else:
        # The training split is only needed here. Load it separately when it
        # is not among the evaluation datasets instead of failing on lookup.
        if train_lang in datasets:
            trainenc, _ = datasets[train_lang]
        else:
            trainenc, _ = get_datasets(
                model_str, [train_lang], seed, nsamples, model.seqlen
            )[train_lang]

        # NOTE(review): this passes the DefaultDataCollator class itself, not
        # an instance, as collate_fn -- confirm that run_collector expects
        # this before changing it.
        data_collator = DefaultDataCollator

        train_dataloader = DataLoader(
            trainenc[0:nsamples],
            shuffle=True,
            batch_size=batch_size,
            num_workers=4,
            collate_fn=data_collator,
        )

        _ = model.to(device)
        Cs = run_collector(OPTCollector(), model, train_dataloader, "cuda")
        # model_str may contain "/" (e.g. "facebook/opt-125m"), which makes
        # the cache path point into a sub-directory; create it so the save
        # cannot fail after the expensive collection.
        cache_path.parent.mkdir(parents=True, exist_ok=True)
        torch.save(Cs, cache_path)
    print('done')

    print(model)

    original_param_count = sum(int(p.nelement()) for p in model.parameters())

    opt_modeling = transformers.models.opt.modeling_opt

    for var_cutoff in torch.linspace(0.001, 0.6, 600):
        # Start every cutoff from a freshly loaded, unpruned model.
        model = get_model(model_str)
        model = model.float()
        model.eval()

        # Attach the Gram matrices in module-traversal order; this assumes Cs
        # was produced in the same order -- TODO confirm against OPTCollector.
        gram_idx = 0
        for name, mod in model.named_modules():
            if isinstance(mod, nn.Linear) and ('out_proj' in name or 'fc2' in name):  # TODO: only works for OPT models
                setattr(mod, 'C', Cs[gram_idx].clone())
                gram_idx += 1

        for m in model.modules():
            if isinstance(m, opt_modeling.OPTAttention):
                _prune_attention(m, var_cutoff)

            if isinstance(m, opt_modeling.OPTDecoderLayer):
                _prune_mlp(m, var_cutoff)

        pruned_param_count = sum(int(p.nelement()) for p in model.parameters())
        pruned_fraction = 1.0 - pruned_param_count / original_param_count

        with torch.no_grad():
            for lang, (_, testenc) in datasets.items():
                test_ppl = eval_model(model, model_str, testenc)["ppl"]
                print("var_cutoff:", var_cutoff, "test PPL:", test_ppl, "sparsity:", pruned_fraction, flush=True)

        del model
        gc.collect()
        torch.cuda.empty_cache()




# Command-line entry point: every flag maps one-to-one onto a parameter of
# main(), so the parsed namespace is forwarded as keyword arguments.
parser = argparse.ArgumentParser()
for _flag, _options in (
    ("--model_str", {"type": str, "default": "facebook/opt-125m"}),
    ("--seed", {"type": int, "default": 100}),
    ("--eval_langs", {"type": str, "nargs": "+", "default": ["wikitext2"]}),
    ("--train_lang", {"type": str, "default": "wikitext2"}),
    ("--batch_size", {"type": int, "default": 1}),
    ("--nsamples", {"type": int, "default": 128}),
):
    parser.add_argument(_flag, **_options)

args = parser.parse_args()
print(args)

main(**vars(args))
