import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
import json

from minimal_multitask.data import DATASETS, FileDataset
from minimal_multitask.utils import encode_with_messages_format

import argparse
import os
import random

parser = argparse.ArgumentParser(
    description="Select training examples for a target/eval set with influence distillation."
)

# --- model and data -------------------------------------------------------
parser.add_argument("--model_name", type=str, default="huggyllama/llama-7b")
parser.add_argument("--tokenizer", type=str, default=None)  # falls back to --model_name when omitted
# "alpaca", "tulu2", "tulu3", or a path to a local JSON/JSONL file
parser.add_argument("--train_dataset", type=str, default="alpaca")
# a key of minimal_multitask.data.DATASETS, or a path readable by FileDataset
parser.add_argument("--eval_dataset", type=str, required=True)
# file loaded with torch.load and handed to the selector as `embeddings`
parser.add_argument("--index_path", type=str, required=True)
# restrict to the dtypes the loader knows how to map; anything else used to crash later
parser.add_argument("--dtype", default="bf16", choices=["bf16", "fp16", "fp32"])
parser.add_argument("--batch_size", type=int, default=1)
parser.add_argument("--prompt_only", action="store_true")
parser.add_argument("--label_only", action="store_true")
parser.add_argument("--only_first_two", action="store_true")  # only use the first two messages

# --- projection / selection settings ----------------------------------------
parser.add_argument("--proj_dim", type=int, default=8192)
parser.add_argument("--mask_numel", type=int, default=0)  # 0 means "no parameter mask"
parser.add_argument("--proj_type", type=str, default='had')
parser.add_argument("--param_regex", type=str, default=None)  # "" or "none" are treated as None
parser.add_argument("--no_optim", action="store_true")  # skip loading optimizer state from the checkpoint
parser.add_argument("--probabilistic", action="store_true")
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--num_landmarks", type=int)
parser.add_argument("--num_samples", type=int)  # number of examples to select (k)
# NOTE(review): not read anywhere in this script — confirm whether it should be forwarded.
parser.add_argument("--normalize_grads", action="store_true")

# --- output ---------------------------------------------------------------
parser.add_argument("--output_file", type=str, required=True)

args = parser.parse_args()

if args.output_file is None:
    parser.error("--output_file is required")

# An existing output file means this selection was already produced; do not redo or overwrite it.
if os.path.exists(args.output_file):
    print(f'File {args.output_file} already exists. Skipping...')
    raise SystemExit(0)  # same as exit(0), but does not depend on the `site` builtins

# Let "--param_regex none" / "--param_regex ''" mean the same as omitting the flag.
if args.param_regex is not None and args.param_regex.strip().lower() in ('', 'none'):
    args.param_regex = None

torch.manual_seed(args.seed)

# Build the keyword arguments for AutoModelForCausalLM.from_pretrained.
_TORCH_DTYPES = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}
if args.dtype not in _TORCH_DTYPES:
    # Previously an unknown value left `kwargs` unbound and crashed later with NameError.
    raise ValueError(f"Unsupported --dtype {args.dtype!r}; expected one of {sorted(_TORCH_DTYPES)}")
kwargs = {"torch_dtype": _TORCH_DTYPES[args.dtype]}
if "llama" in args.model_name:
    kwargs["attn_implementation"] = "sdpa"

# Forward an HF access token (needed for gated repos) when one is set in the environment.
_hf_token = os.getenv("HF_TOKEN")
if _hf_token is not None:
    kwargs["token"] = _hf_token

# Load the causal LM with the dtype/attention/token options assembled in `kwargs`.
model = AutoModelForCausalLM.from_pretrained(
    args.model_name,
    **kwargs,
    device_map="auto",  # use multiple gpus if you can
)

# The tokenizer defaults to the model's own. Pass the same HF token the model used (if any);
# otherwise a gated model would load but its tokenizer would be refused.
tokenizer = AutoTokenizer.from_pretrained(
    args.tokenizer if args.tokenizer is not None else args.model_name,
    use_fast=True,
    token=kwargs.get("token"),
)

# Load the raw training rows and tokenize them with encode_with_messages_format.
#   "alpaca"          -> data/stanford_alpaca_data.jsonl, max length 512, 16 workers
#   "tulu2" / "tulu3" -> HF Hub SFT mixtures, max length 2048, 16 workers
#   any existing path -> local JSON/JSONL file, max length 2048, 8 workers
# base_train_dataset keeps the rows as loaded; those rows are what is written to the output file.
def _encode_train_example(example, max_len):
    """Encode one training example with token budget ``max_len``, honoring the prompt/label CLI flags."""
    return encode_with_messages_format(
        example, tokenizer, max_len, True, args.label_only, args.only_first_two, args.prompt_only
    )


_HUB_TRAIN_DATASETS = {
    "tulu2": "allenai/tulu-v2-sft-mixture",
    "tulu3": "allenai/tulu-3-sft-mixture",
}

if args.train_dataset == "alpaca":
    base_train_dataset = load_dataset("json", data_files="data/stanford_alpaca_data.jsonl")["train"]
    train_dataset = base_train_dataset.map(lambda x: _encode_train_example(x, 512), num_proc=16)
elif args.train_dataset in _HUB_TRAIN_DATASETS:
    base_train_dataset = load_dataset(_HUB_TRAIN_DATASETS[args.train_dataset], split="train")
    train_dataset = base_train_dataset.map(lambda x: _encode_train_example(x, 2048), num_proc=16)
elif os.path.exists(args.train_dataset):
    base_train_dataset = load_dataset("json", data_files=args.train_dataset)["train"]
    train_dataset = base_train_dataset.map(
        lambda x: _encode_train_example(x, 2048),
        num_proc=8,
        load_from_cache_file=True,
        keep_in_memory=False,
    )
else:
    raise ValueError(f"Invalid train dataset: {args.train_dataset}")
train_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])

# Target (eval) dataset - mostly handled in data.py. Either a registered name in DATASETS
# or a path to a file that FileDataset can read.
if args.eval_dataset in DATASETS:
    test_dataset = DATASETS[args.eval_dataset](tokenizer).get_all_test_prompts(
        seed=args.seed, prompt_only=args.prompt_only, response_only=args.label_only
    )
elif args.eval_dataset is not None and os.path.exists(args.eval_dataset):
    # if a file is given, we assume it's a dataset
    test_dataset = FileDataset(args.eval_dataset, tokenizer).get_all_test_prompts(
        prompt_only=args.prompt_only, response_only=args.label_only
    )
else:
    # `args.dataset` does not exist; referencing it raised AttributeError instead of this message.
    raise ValueError(f"Invalid eval dataset: {args.eval_dataset}")

print(f"Train dataset size: {len(train_dataset)}")
print(f"Test dataset size: {len(test_dataset)}")

# construct dataloaders; shuffle=False keeps iteration order identical to dataset order
# (presumably select_per_target relies on this to map positions back to row indices — confirm).
train_data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=False)
eval_data_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False)
# Explicit check instead of `assert`, which is stripped when Python runs with -O.
if args.index_path is None or not os.path.exists(args.index_path):
    raise FileNotFoundError(f"--index_path must point to an existing file, got {args.index_path!r}")


# Optimizer state for the selector. Unless --no_optim is given, rebuild the optimizer the
# checkpoint was trained with and load its saved state. This expects args.model_name to be a
# local checkpoint directory containing training_args.bin and optimizer.pt.
# Bind `optim` unconditionally: with --no_optim it was previously undefined (NameError later).
# NOTE(review): confirm select_per_target accepts optimizer=None.
optim = None
if not args.no_optim:
    from transformers import Trainer

    # training_args.bin holds a pickled TrainingArguments object rather than plain tensors, so it
    # needs weights_only=False (torch >= 2.6 defaults to True). Only load checkpoints you trust.
    training_args = torch.load(os.path.join(args.model_name, 'training_args.bin'), weights_only=False)
    # use a fake trainer to create the optimizer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        tokenizer=tokenizer
    )
    trainer.create_optimizer()
    optim = trainer.optimizer
    print(f'constructed optimizer {optim}')

    optim.load_state_dict(torch.load(os.path.join(args.model_name, 'optimizer.pt'), map_location='cpu'))
    print('loaded optimizer state dict')

# Device of the model's first parameter; used for the loaded embeddings and the weight vector.
device = next(model.parameters()).device

from infdist import select_per_target
print('Using influence distillation')

selected_idx, weights, landmark_idx, landmark_store, target_store, p_L = select_per_target(
    model=model,
    train_loader=train_data_loader,
    target_loader=eval_data_loader,
    optimizer=optim,
    embeddings=torch.load(args.index_path, map_location=device),
    k=args.num_samples,
    skip_embd=False,
    proj_dim=args.proj_dim,
    param_mask_numel=args.mask_numel if args.mask_numel > 0 else None,
    probabilistic=args.probabilistic,
    proj_num_parts=1,
    proj_subset=args.param_regex, # 'random-{}' or a key like 'o_proj', or None
    proj_type=args.proj_type,
    num_landmarks=args.num_landmarks,
    selection='topk', # 'topk' or 'water-filling'
    seed=args.seed,
    damp=0.1,
    return_weights=True,
)
# Scatter the weights of the selected examples into a dense vector over the whole train set
# (unselected rows stay 0). as_tensor accepts lists, arrays or tensors without the
# copy-construct warning torch.tensor emits for tensor inputs.
full_w = torch.zeros(len(train_dataset), device=device, dtype=torch.float32)
full_w[selected_idx] = torch.as_tensor(weights, device=device, dtype=torch.float32)

# Keep every training row with a non-zero weight, annotated with that weight.
# flatten() instead of squeeze(): with exactly one non-zero weight, squeeze() produced a 0-d
# tensor whose .tolist() is a bare int, and iterating over it crashed.
saved_instances = full_w.nonzero().flatten().detach().cpu().tolist()
saved_scores = full_w[saved_instances].detach().cpu().tolist()

output_dataset = []
for train_idx, score in zip(saved_instances, saved_scores):
    instance = base_train_dataset[train_idx]
    instance["influence_score"] = float(score)
    output_dataset.append(instance)

# Create the parent directory if needed. dirname() is '' for a bare filename, and
# os.makedirs('') raises, so only create it when there is a directory component.
output_dir = os.path.dirname(args.output_file)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)

# make sure the dataset is properly shuffled; seed it with --seed so reruns write the same file
random.Random(args.seed).shuffle(output_dataset)

# One JSON object per line (JSONL).
with open(args.output_file, "w", encoding="utf-8") as f:
    for instance in output_dataset:
        f.write(json.dumps(instance) + "\n")