import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
import json

from minimal_multitask.data import DATASETS, FileDataset
from minimal_multitask.utils import encode_with_messages_format

from infdist.train.utils import fwd_pass, calc_grad
from infdist.utils import tuple_utils 

from tqdm import tqdm
import argparse
import os
import pickle
import random, math

# Command-line interface. Invalid values are rejected here, before any model
# or dataset is loaded, instead of surfacing later as a NameError/TypeError.
parser = argparse.ArgumentParser()
parser.add_argument("--model_name", type=str, default="huggyllama/llama-7b")
# When omitted, the tokenizer is loaded from --model_name.
parser.add_argument("--tokenizer", type=str, default=None)
# One of "alpaca", "tulu2", "tulu3", or a path to an existing JSON/JSONL file.
parser.add_argument("--dataset", type=str, default="alpaca")
# Only these three precisions are mapped to a torch dtype further down.
parser.add_argument("--dtype", default="bf16", choices=["bf16", "fp16", "fp32"])
# The next three flags are forwarded unchanged to encode_with_messages_format.
parser.add_argument("--prompt_only", action="store_true")
parser.add_argument("--label_only", action="store_true")
parser.add_argument("--only_first_two", action="store_true")  # only use the first two messages
parser.add_argument("--seed", type=int, default=42)
# Required: the script checks this path for existence and saves the results to it.
parser.add_argument("--output_path", type=str, required=True)
# "losses" records one loss per example, "grad_norms" one gradient L2 norm per example.
parser.add_argument("--stat", type=str, default="losses", choices=["losses", "grad_norms"])

args = parser.parse_args()

# Skip the whole run when the result file is already present, so the script
# can be re-launched safely.
if os.path.exists(args.output_path):
    print(f'File {args.output_path} already exists. Skipping...')
    # Raise SystemExit directly: the bare `exit` helper is injected by the
    # `site` module for interactive use and is not guaranteed to exist.
    raise SystemExit(0)

torch.manual_seed(args.seed)

# Translate the --dtype flag into the torch dtype handed to from_pretrained.
dtype_by_flag = {
    "bf16": torch.bfloat16,
    "fp16": torch.float16,
    "fp32": torch.float32,
}
if args.dtype not in dtype_by_flag:
    # Previously an unknown flag left `kwargs` undefined and the script died
    # with a NameError a few lines later; fail with a clear message instead.
    raise ValueError(f"Invalid dtype: {args.dtype}")
kwargs = {"torch_dtype": dtype_by_flag[args.dtype]}

# Models whose name contains "llama" are loaded with the SDPA attention implementation.
if "llama" in args.model_name:
    kwargs["attn_implementation"] = "sdpa"

# Forward a Hugging Face access token when one is set in the environment.
hf_token = os.getenv("HF_TOKEN")
if hf_token is not None:
    kwargs["token"] = hf_token

# Load the causal LM with the keyword arguments built from the CLI flags.
# device_map="auto" spreads the model over multiple GPUs when available.
model = AutoModelForCausalLM.from_pretrained(args.model_name, device_map="auto", **kwargs)

# The tokenizer comes from --tokenizer when given, otherwise from the model checkpoint.
tokenizer_source = args.model_name if args.tokenizer is None else args.tokenizer
tokenizer = AutoTokenizer.from_pretrained(tokenizer_source, use_fast=True)

# load and process train dataset
# Each branch picks the raw dataset, the third positional argument passed to
# encode_with_messages_format (presumably a max sequence length -- confirm
# against that helper), and the keyword arguments for Dataset.map.
if args.dataset == "alpaca":
    base_dataset = load_dataset("json", data_files="data/stanford_alpaca_data.jsonl")["train"]
    seq_limit = 512
    map_options = {"num_proc": 16}
elif args.dataset in ("tulu2", "tulu3"):
    hub_ids = {
        "tulu2": "allenai/tulu-v2-sft-mixture",
        "tulu3": "allenai/tulu-3-sft-mixture",
    }
    base_dataset = load_dataset(hub_ids[args.dataset], split="train")
    seq_limit = 2048
    map_options = {"num_proc": 16}
elif os.path.exists(args.dataset):
    # Any other value is treated as a path to a local JSON/JSONL file.
    base_dataset = load_dataset("json", data_files=args.dataset)["train"]
    seq_limit = 2048
    map_options = {"num_proc": 8, "load_from_cache_file": True, "keep_in_memory": False}
else:
    raise ValueError(f"Invalid train dataset: {args.dataset}")


def tokenize(x):
    """Encode one raw example with the shared tokenizer and CLI flags."""
    return encode_with_messages_format(
        x, tokenizer, seq_limit, True, args.label_only, args.only_first_two, args.prompt_only
    )


dataset = base_dataset.map(tokenize, **map_options)
# Expose only the model inputs, as torch tensors.
dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])

print(f"Train dataset size: {len(dataset)}")

# construct dataloaders
# One example per batch, in dataset order, so the i-th saved value belongs to
# the i-th dataset example.
loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False)

# Device of the first model parameter (not used further in this script).
# NOTE(review): batches are handed to fwd_pass/calc_grad exactly as the loader
# yields them; presumably those helpers move them to the model device -- confirm.
device = next(model.parameters()).device
params = tuple(param for param in model.parameters() if param.requires_grad)
print(f'Using {len(params)} trainable parameters')

# Validate the requested statistic once, up front, rather than per batch.
if args.stat not in ("losses", "grad_norms"):
    raise ValueError(f"Invalid stat: {args.stat}")

# Create the output directory before the long loop so that a missing
# directory cannot cause the final save to fail after all the work is done.
output_dir = os.path.dirname(args.output_path)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)

# Despite the name, `losses` holds whichever statistic was requested.
losses = []
for batch in tqdm(loader):
    if args.stat == "losses":
        # Only the scalar value is read via .item(), so no autograd graph is needed.
        with torch.no_grad():
            losses.append(fwd_pass(model, batch).item())
    else:
        # L2 norm over all gradient tensors returned for this example.
        g = calc_grad(model, params, batch)
        losses.append(sum((t ** 2).sum().item() for t in g) ** 0.5)

torch.save(losses, args.output_path)


