import torch
torch.set_printoptions(sci_mode=False)
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
from minimal_multitask.utils import encode_with_messages_format
import os, re


from copy import deepcopy
from torch.autograd.forward_ad import dual_level, make_dual, unpack_dual
from infdist.utils.proj import _cosntruct_projector, _project_grad
from tqdm import tqdm

import argparse


def _get_params(model, param_regex, skip_embd, long_num_params=False):
    """Select trainable parameters of ``model`` by name.

    Args:
        model: module whose ``named_parameters()`` are scanned.
        param_regex: optional regex (string or compiled pattern). When not
            ``None``, only parameters whose name it matches via ``search``
            are kept.
        skip_embd: when true, drop every parameter whose name contains
            ``'embed'`` or ``'head'``.
        long_num_params: when true, print how many parameters were selected.

    Returns:
        tuple of the selected parameter tensors, in ``named_parameters()``
        order.
    """
    pattern = None if param_regex is None else re.compile(param_regex)

    def _keep(name, tensor):
        # Frozen tensors can never be part of the selection.
        if not tensor.requires_grad:
            return False
        if skip_embd and any(tag in name for tag in ('embed', 'head')):
            return False
        return pattern is None or pattern.search(name) is not None

    selected = tuple(
        tensor for name, tensor in model.named_parameters() if _keep(name, tensor)
    )
    if long_num_params:
        print(f'Found {len(selected)} params.')
    return selected

@torch.no_grad()
def create_jvp_store(model, samples, num_blocks, num_tangents=1, proj_dim=8192, proj_num_parts=1, param_regex=None, skip_embd=True, seed=43):
    """Embed every sample via forward-mode JVPs of a truncated model copy.

    A deep copy of ``model`` is prepared first. It keeps only the first
    ``num_blocks`` entries of ``model.model.layers``, and its ``lm_head`` is
    replaced by ``nn.Identity``. The copy's selected parameters are then made
    dual along ``num_tangents`` Gaussian directions (``torch.randn_like``).
    For each direction, the tangent of the output at the last position of each
    sample is collected. Finally, each sample's tangents are passed through the
    projector built by ``_cosntruct_projector``.

    Args:
        model: causal LM exposing ``model.model.layers`` (each with
            ``mlp.down_proj``) and ``lm_head``. It is parked on the CPU while
            the copy runs and is always moved back to its original device,
            including when an error is raised.
        samples: re-iterable sequence supporting ``len()`` of dicts of
            un-batched tensors. Every key except ``'labels'`` is forwarded to
            the model.
        num_blocks: number of leading decoder layers kept in the copy.
        num_tangents: number of random tangent directions.
        proj_dim: projection output dimension.
        proj_num_parts: only 1 is supported; other values raise
            ``AssertionError`` before the model is touched.
        param_regex, skip_embd: parameter filters passed to ``_get_params``.
        seed: passed to ``torch.manual_seed`` (tangent sampling) and to the
            projector.

    Returns:
        list with one projected embedding per sample, in ``samples`` order
        (empty when ``samples`` is empty).
    """
    # Validate before moving anything, so a bad argument cannot leave the
    # caller's model parked on the CPU.
    assert proj_num_parts == 1, "proj_num_parts > 1 not supported yet"
    device = next(model.parameters()).device

    # Free device memory for the truncated copy by parking the original on CPU.
    model.cpu()
    try:
        return _jvp_store_from_copy(
            model, samples, device, num_blocks, num_tangents, proj_dim,
            proj_num_parts, param_regex, skip_embd, seed,
        )
    finally:
        # On the normal path the helper's frame (and the device copy it owns)
        # is already released, so both models are never resident together.
        model.to(device)


def _jvp_store_from_copy(org_model, samples, device, num_blocks, num_tangents, proj_dim, proj_num_parts, param_regex, skip_embd, seed):
    """Compute the JVP embeddings on a truncated deep copy of ``org_model``.

    This is the worker for ``create_jvp_store``, which calls it inside
    ``torch.no_grad`` with ``org_model`` already on the CPU. The copy is
    placed on ``device``.
    """
    model = deepcopy(org_model).to(device)
    model.model.layers = model.model.layers[:num_blocks]
    # With an identity head, ``output.logits`` presumably carries the hidden
    # states normally fed to ``lm_head``. Confirm for non-Llama architectures.
    model.lm_head = torch.nn.Identity()

    params = _get_params(model, param_regex, skip_embd)
    # Clean primal values. Every JVP rebuilds its duals from these.
    org_params = tuple(param.clone() for param in params)

    torch.manual_seed(seed)
    # full_dim = num_tangents * width of the last block's MLP output. This
    # assumes _project_grad concatenates a sample's per-tangent vectors; confirm
    # in infdist.utils.proj.
    projector = _cosntruct_projector(
        full_dim=num_tangents*model.model.layers[-1].mlp.down_proj.out_features//proj_num_parts,
        proj_dim=proj_dim//proj_num_parts,
        seed=seed,
        device=device,
        dtype=torch.float32
    )

    def _replace_params(new_params):
        # Overwrite the copy's selected parameters in place with new_params.
        with torch.no_grad():
            for param, new_param in zip(params, new_params):
                param.copy_(new_param)

    def _jvp(batch, vec):
        # Return the tangent of the output at batch 0, last position.
        with dual_level():
            dual_params = [make_dual(p, v) for p, v in zip(org_params, vec)]
            _replace_params(dual_params)
            output_dual = model(**{k: v for k, v in batch.items() if k != 'labels'})
            # Index explicitly rather than squeeze(). squeeze() would also drop
            # the sequence dim for one-token inputs, and [-1] would then pick a
            # single hidden unit instead of the last position.
            last_position = output_dual.logits[0, -1]
            return unpack_dual(last_position).tangent

    print("Creating JVP store")
    tangents = [[] for _ in samples]
    for i in range(num_tangents):
        print(f"Creating tangent {i+1}/{num_tangents}")
        vec = tuple(torch.randn_like(param) for param in params)
        for j, sample_j in tqdm(enumerate(samples), total=len(samples)):
            # Add a batch dim of 1: samples are processed one at a time.
            batch = {k: v.to(device).unsqueeze(0) for k, v in sample_j.items()}
            # Tangents wait on the CPU until projection to save device memory.
            tangents[j].append(_jvp(batch, vec).cpu())
        del vec

    print("Projecting")
    embds = []
    for sample_tangents in tqdm(tangents):
        embd = _project_grad(projector, tuple(st.to(device) for st in sample_tangents), num_parts=1)
        embds.append(embd.detach())
    return embds



# Precision the model weights are loaded in.
model_dtype = torch.bfloat16

parser = argparse.ArgumentParser(description='JVP Embedding Script')
parser.add_argument('--model_name', type=str, required=True, help='Model name or path')
parser.add_argument('--train_dataset', type=str, required=True, help='Path to training dataset')
parser.add_argument('--index_path', type=str, required=True, help='Path to index file')
# NOTE(review): --batch_size is parsed but never read below.
parser.add_argument('--batch_size', type=int, default=1,
                    help='Batch size for training')
parser.add_argument('--prompt_only', action='store_true',
                    help='Use only prompts')
parser.add_argument('--label_only', action='store_true',
                    help='Use only labels')
parser.add_argument('--only_first_two', action='store_true',
                    help='Use only first two examples')
parser.add_argument('--seed', type=int, default=0, help='Random seed')
parser.add_argument('--num_blocks', type=int, default=4, help='Number of blocks to use')
parser.add_argument('--num_tangents', type=int, default=2, help='Number of tangents to use')
parser.add_argument('--proj_dim', type=int, default=4096, help='Projection dimension')

args = parser.parse_args()

# Validate inputs and prepare the output location before the expensive model
# load and JVP pass, so a bad path fails immediately instead of after a full run.
if not os.path.exists(args.train_dataset):
    raise ValueError(f"Invalid train dataset: {args.train_dataset}")
index_dir = os.path.dirname(args.index_path)
if index_dir:  # os.makedirs('') raises for a bare filename such as 'index.pt'
    os.makedirs(index_dir, exist_ok=True)

torch.manual_seed(args.seed)

model = AutoModelForCausalLM.from_pretrained(
    args.model_name,
    device_map="cuda:0",  # use multiple gpus if you can
    use_cache=False,
    token=os.getenv("HF_TOKEN"),
    torch_dtype=model_dtype,
    attn_implementation='eager'
)

tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_fast=True)


def tokenize(x):
    """Encode one dataset row with ``encode_with_messages_format``.

    NOTE(review): 2048 is presumably the max sequence length; confirm the
    positional parameters in minimal_multitask.utils.
    """
    return encode_with_messages_format(x, tokenizer, 2048, True, args.label_only, args.only_first_two, args.prompt_only)


base_train_dataset = load_dataset("json", data_files=args.train_dataset)["train"]
train_dataset = base_train_dataset.map(
    tokenize, num_proc=1, load_from_cache_file=False, keep_in_memory=False
)
train_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])

print(f"Train dataset size: {len(train_dataset)}")

store = create_jvp_store(
    model,
    train_dataset,
    num_blocks=args.num_blocks,
    num_tangents=args.num_tangents,
    proj_dim=args.proj_dim,
    proj_num_parts=1,
    param_regex=None,
    skip_embd=False,
    seed=args.seed
)

# Stack to one entry per sample and scale each to unit L2 norm over the last
# (embedding) dim. normalize() clamps the norm at eps, so an all-zero embedding
# stays zero instead of becoming NaN. For a 2-D stack, dim=-1 equals dim=1.
store = torch.stack(store)
store = torch.nn.functional.normalize(store, dim=-1)
torch.save(store, args.index_path)