import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
import json

from minimal_multitask.data import DATASETS, FileDataset
from minimal_multitask.utils import encode_with_messages_format

from infdist import InfluenceManager

from tqdm import tqdm
import argparse
import os
import pickle
import random

parser = argparse.ArgumentParser(
    description="Sensitivity sampling of training examples from (approximate) embeddings and losses."
)
parser.add_argument("--train_dataset", type=str, default="alpaca",
                    help="path to a local training file loaded with the datasets 'json' builder")
# Arguments without a usable default are required: every run dereferences them,
# so failing at parse time beats failing after loading large tensors.
parser.add_argument("--losses_path", type=str, required=True,
                    help="torch-saved per-example losses (one entry per embedding row)")
parser.add_argument("--num_points", type=int,
                    help="number of landmark points (used by --method lowrank and cluster)")
parser.add_argument("--clusters_path", type=str, default=None,
                    help="torch-saved cluster data (required for --method cluster)")
parser.add_argument("--method", type=str, required=True, choices=["lowrank", "cluster", "ideal"],
                    help="how to approximate embeddings and losses")
parser.add_argument("--lmbda", type=float, required=True,
                    help="weight on the embedding-residual norm in the sampling score")
parser.add_argument("--num_samples", type=int, required=True,
                    help="number of training examples to sample (without replacement)")
parser.add_argument("--embds_path", type=str, required=True,
                    help="torch-saved embeddings, one row per training example")
parser.add_argument("--output_file", type=str, required=True,
                    help="path of the JSONL file to write")
# The three flags below are not read anywhere in this script.
parser.add_argument("--prompt_only", action="store_true")
parser.add_argument("--label_only", action="store_true")
parser.add_argument("--only_first_two", action="store_true")  # only use the first two messages

args = parser.parse_args()

print("Starting sensitivity sampling...")

# Skip all work when the output already exists; an existing file is trusted
# as a finished result.
if os.path.exists(args.output_file):
    print(f'File {args.output_file} already exists. Skipping...')
    # `exit()` is injected by the site module for interactive use and may be
    # missing (e.g. `python -S`); SystemExit is what sys.exit raises anyway.
    raise SystemExit(0)


# The training set must be a local file readable by the datasets 'json'
# builder; anything else is rejected up front.
if not os.path.exists(args.train_dataset):
    raise ValueError(f"Invalid train dataset: {args.train_dataset}")

print(f"Loading training dataset from {args.train_dataset}")
base_train_dataset = load_dataset("json", data_files=args.train_dataset)["train"]


# Imports needed for the approximation step. torch and os were already
# imported at the top of the file; re-importing them just returns the cached
# modules.
import torch
import os
import numpy as np
import sys
# Make the local `utils` module importable. A relative sys.path entry is
# resolved against the current working directory, not this file's location,
# so the script presumably has to be launched from the directory that
# contains ./ablations — TODO confirm.
sys.path.append('./ablations')
from infdist.distil.functional import _landmark_p_to_full_krr, _landmark_rows_leverage
from utils import _estimate_projected_grads_krrf

# Use the first GPU when present; fall back to CPU rather than crashing in
# torch.load on machines without CUDA.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print(f"Loading embeddings from {args.embds_path}")
embds = torch.load(args.embds_path, map_location=device)
print(f"Loading losses from {args.losses_path}")
# map_location keeps a tensor saved on another device loadable; as_tensor
# avoids torch.tensor's copy-construct warning when the file already holds a
# tensor, while still converting lists/arrays.
losses = torch.as_tensor(torch.load(args.losses_path, map_location=device), device=device)
# Losses and embeddings are indexed by the same row index below, so they must
# align one-to-one. Explicit check (not assert) so it survives `python -O`.
if len(losses) != len(embds):
    raise ValueError(
        f"Got {len(losses)} losses but {len(embds)} embeddings; they must have the same length"
    )
print(f"Loaded {len(embds)} embeddings and losses")

# Build apx_embds / apx_losses: approximations of every example's embedding
# and loss, using the strategy selected by --method.
if args.method == 'lowrank':
    print("Computing lowrank approximation...")
    # Pick args.num_points landmark rows with the project helper, then extend
    # embeddings and losses from the landmarks to all points via the KRR
    # helpers (damp=0.01 is passed as their damping argument).
    landmark_idx = _landmark_rows_leverage(
        embds,
        l=args.num_points,
        seed=43,
        rank=None,
        oversample=5,
        n_iter=2,
    )
    print("Computing approximate embeddings...")
    apx_embds = _estimate_projected_grads_krrf(
        embds,
        landmark_idx,
        embds[landmark_idx],
        damp=0.01
    )
    print("Computing approximate losses...")
    apx_losses = _landmark_p_to_full_krr(
        landmark_idx,
        embds,
        losses[landmark_idx],
        damp=0.01,
    )

elif args.method == 'cluster':
    if args.clusters_path is None:
        raise ValueError("--clusters_path is required for --method cluster")
    print(f"Loading clusters from {args.clusters_path}")
    # Element 0 holds the landmark indices; element 1 presumably holds one
    # assignment per training point (only its length is checked here).
    clusters_loaded = torch.load(args.clusters_path, device)
    if len(clusters_loaded[1]) != len(embds):
        raise ValueError(
            f"Cluster data covers {len(clusters_loaded[1])} points but there are {len(embds)} embeddings"
        )
    landmark_idx = clusters_loaded[0].cpu().numpy()
    if len(landmark_idx) != args.num_points:
        raise ValueError(
            f"Cluster data has {len(landmark_idx)} landmarks but --num_points is {args.num_points}"
        )

    print("Computing distances between landmarks and all points...")
    dists = torch.cdist(embds[landmark_idx], embds) # shape: (l, n)
    # Each point is approximated by its nearest landmark's embedding and loss.
    mapped_idx = torch.argmin(dists, dim=0).cpu().numpy() # shape: (n, )
    apx_embds = embds[landmark_idx[mapped_idx]]
    apx_losses = losses[landmark_idx[mapped_idx]]

elif args.method == 'ideal':
    # No approximation: exact embeddings and losses (residuals will be zero).
    apx_embds = embds
    apx_losses = losses

else:
    raise ValueError(f'Invalid method: {args.method}')

# Sampling scores must be non-negative, so clip negative approximate losses.
# For method 'ideal' apx_losses is the same object as losses, so this also
# clips losses in place.
apx_losses[apx_losses < 0] = 0 # make sure there are no negative losses


print("Computing sampling probabilities...")
# r: per-example L2 norm of the residual between exact and approximate embeddings.
r = torch.norm(embds - apx_embds, dim=1)
# Sampling score: approximate loss plus lmbda-weighted residual.
s = apx_losses + args.lmbda * r

# torch.multinomial rejects these distributions with opaque errors; fail clearly instead.
if not torch.isfinite(s).all():
    raise ValueError("Sampling scores contain NaN or inf values")
if (s < 0).any():
    raise ValueError(f"Sampling scores must be non-negative (is --lmbda={args.lmbda} negative?)")
total = s.sum()
if total <= 0:
    raise ValueError("All sampling scores are zero; cannot build a sampling distribution")
p = s / total

num_positive = int((p > 0).sum())
if not 1 <= args.num_samples <= num_positive:
    raise ValueError(
        f"Cannot draw {args.num_samples} samples without replacement from "
        f"{num_positive} points with non-zero probability"
    )

print(f"Sampling {args.num_samples} points...")
# Draw num_samples distinct indices with probability proportional to s.
selected_idx = torch.multinomial(p, args.num_samples, replacement=False)
# Each selected point gets weight 1 / (num_samples * p_i); all others stay 0.
full_w = torch.zeros(len(losses), device=device, dtype=torch.float32)
full_w[selected_idx] = 1 / (p[selected_idx] * args.num_samples)

# flatten() (not squeeze()) keeps a 1-D index list even when exactly one
# point is selected; squeeze() would yield a 0-d tensor and .tolist() an int.
saved_instances = full_w.nonzero().flatten().detach().cpu().tolist()
saved_scores = full_w[saved_instances].detach().cpu().tolist()

print("Building output dataset...")
# Attach each sampled row's weight to that row under "influence_score".
output_dataset = []
for row_idx, weight in zip(saved_instances, saved_scores):
    record = base_train_dataset[row_idx]
    if type(weight) is not float:
        weight = weight.item()
    record["influence_score"] = weight
    output_dataset.append(record)

# Create the output file's parent directory if needed. dirname() returns ""
# for a bare filename (meaning the cwd), which os.makedirs would reject.
output_dir = os.path.dirname(args.output_file)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)

print("Shuffling dataset...")
# make sure the dataset is properly shuffled
# NOTE(review): `random` is not seeded anywhere in this file, so the written
# order differs between runs.
random.shuffle(output_dataset)

print(f"Writing {len(output_dataset)} instances to {args.output_file}")
# Write to a temporary file, then atomically rename it into place. The
# early-exit check on the output path treats any existing file as finished,
# so a partially written file (crash, kill, serialization error) must never
# appear under the final name.
tmp_output_file = args.output_file + ".tmp"
with open(tmp_output_file, "w") as f:
    for instance in output_dataset:
        f.write(json.dumps(instance) + "\n")
os.replace(tmp_output_file, args.output_file)

print("Done!")