import torch

import json
import os.path
import random

import pandas as pd
import fire
import wandb
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import os

import sys
sys.path.append('../src')
sys.path.append('../lut-quant/')

import preprocessing


def process_sample(model, tokenizer, dataset_name, batch, skip_grad=False, params=None):
    """Compute the response-only language-modelling loss for ``batch``.

    The prompt and the response are rendered with
    ``preprocessing.<dataset_name>_preprocessing_function``; the prompt is
    tokenized on its own so that its tokens (and any padding) can be excluded
    from the loss.

    Args:
        model: Causal LM called as ``model(**inputs, labels=labels)``; its
            output must expose ``.loss``.
        tokenizer: Tokenizer callable returning ``input_ids`` and
            ``attention_mask`` tensors.
        dataset_name: Prefix used to look up the preprocessing function.
        batch: Mapping with parallel ``'inp'`` and ``'label'`` sequences.
        skip_grad: When true, no gradients are computed.
        params: Tensors to differentiate the loss with respect to. Defaults to
            every parameter of ``model``.

    Returns:
        ``(loss, grads)`` where ``grads`` is the tuple returned by
        ``torch.autograd.grad`` or ``None`` when ``skip_grad`` is true.
    """
    # Get the prompt and response separately
    preprocess_fn = getattr(preprocessing, f'{dataset_name}_preprocessing_function')
    prompts = [preprocess_fn({'input': inp, 'output': ''})['prompt'] for inp in batch['inp']]
    responses = [
        preprocess_fn({'input': inp, 'output': out})['response']
        for inp, out in zip(batch['inp'], batch['label'])
    ]
    full_text = [prompt + response for prompt, response in zip(prompts, responses)]

    # Tokenize prompt and full text to get label masks
    prompt_inputs = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True)
    model_inputs = tokenizer(full_text, return_tensors="pt", padding=True, truncation=True).to("cuda")

    # Create labels where only response tokens contribute to loss
    attention_mask = model_inputs['attention_mask']
    labels = model_inputs['input_ids'].clone()

    # Padding must never be scored. The attention mask is used rather than the
    # pad token id because the pad token may be the same as EOS.
    labels[attention_mask == 0] = -100

    # Mask out prompt tokens (set to -100 so they don't contribute to loss).
    # With left padding the real tokens of row i start after the pad run, so
    # the prompt span has to be shifted by that amount. For unpadded rows the
    # shift is zero and this reduces to masking the first `prompt_len` tokens.
    seq_len = labels.shape[1]
    left_padded = getattr(tokenizer, 'padding_side', 'right') == 'left'
    prompt_lens = prompt_inputs['attention_mask'].sum(dim=1).tolist()
    text_lens = attention_mask.sum(dim=1).tolist()
    for i, (prompt_len, text_len) in enumerate(zip(prompt_lens, text_lens)):
        start = seq_len - text_len if left_padded else 0
        labels[i, start:start + prompt_len] = -100

    # Compute forward pass and loss
    outputs = model(**model_inputs, labels=labels)
    loss = outputs.loss

    if skip_grad:
        return loss, None

    # Only collect the parameter tuple when gradients are actually requested.
    if params is None:
        params = tuple(model.parameters())
    grads = torch.autograd.grad(loss, params)

    return loss, grads

def compute_loss_grad(
        model,
        tokenizer,
        dataset_name,
        batch_size=1,
        skip_grad=False,
        param_key='down_proj',
        proj_dim=8192,
        proj_type='rad',
        proj_num_parts=1,
        proj_seed=42,
        proj_device='cuda:0',
):
    """Compute per-sample losses and projected gradients over a training set.

    Args:
        model: Causal LM forwarded to ``process_sample``.
        tokenizer: Tokenizer forwarded to ``process_sample``.
        dataset_name: One of ``'sql'`` (local ``./data/sql/train.jsonl``),
            ``'viggo'`` (``GEM/viggo``) or ``'gsm8k'`` (``gsm8k``/``main``);
            the train split is used in every case.
        batch_size: Must be 1. One gradient is stored per ``process_sample``
            call, so a larger batch would not yield per-sample gradients.
        skip_grad: When true, only losses are computed (under ``no_grad``).
        param_key: Gradients are taken with respect to the parameters whose
            name contains this substring.
        proj_dim: Output dimension passed to the projector.
        proj_type: Projector type passed through to ``infdist``
            (NOTE(review): ``'rad'`` presumably means Rademacher - confirm).
        proj_num_parts: Number of parts the flattened gradient is split into
            for projection.
        proj_seed: Seed passed to the projector.
        proj_device: Device on which the projector is constructed.

    Returns:
        ``(losses, grads)``: a 1-D tensor with one loss per sample and the
        stacked projected gradients, or ``None`` in place of the gradients
        when ``skip_grad`` is true.

    Raises:
        AssertionError: If ``batch_size`` is not 1.
        ValueError: If ``dataset_name`` is unknown, or no parameter name
            contains ``param_key``.
    """
    # Raised explicitly (instead of an `assert` statement) so the check still
    # runs under `python -O`; the exception type is kept for compatibility.
    if batch_size != 1:
        raise AssertionError("batch_size must be 1")

    # Normalise every dataset to the 'inp' / 'label' columns read by
    # process_sample.
    if dataset_name == 'sql':
        dataset = load_dataset('json',
                               data_files="./data/sql/train.jsonl",
                               split="train")
        dataset = dataset.map(
            lambda example: {
                'inp': example['messages'][0]['content'],
                'label': example['messages'][1]['content'],
            }, remove_columns=['messages'])
    elif dataset_name == 'viggo':
        dataset = load_dataset('GEM/viggo', split='train')
        dataset = dataset.map(
            lambda example: {
                'inp': example['target'],
                'label': example['meaning_representation']
            })
    elif dataset_name == 'gsm8k':
        dataset = load_dataset('gsm8k', 'main', split='train')
        dataset = dataset.map(
            lambda example: {
                'inp': example['question'],
                'label': example['answer']
            })
    else:
        raise ValueError(f'Unknown dataset name: {dataset_name}')

    # No shuffling: row i of the outputs corresponds to dataset row i.
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    losses = []

    if skip_grad:
        # Loss-only pass: no autograd graph is needed.
        with torch.no_grad():
            for sample in tqdm(dataloader):
                loss, _ = process_sample(model, tokenizer, dataset_name, sample, skip_grad=True)
                losses.append(loss.item())
        return torch.tensor(losses), None

    # Imported lazily so the loss-only path does not require infdist.
    # `_cosntruct_projector` is the helper's name as imported from infdist.
    from infdist.utils.proj import _cosntruct_projector, _project_grad
    from infdist.utils import tuple_utils

    params = tuple(param for name, param in model.named_parameters() if param_key in name)
    if not params:
        # Fail early with a clear message rather than deep inside autograd.
        raise ValueError(f'No model parameters match param_key={param_key!r}')

    projector = _cosntruct_projector(
        full_dim=tuple_utils.numel(params) // proj_num_parts,
        proj_dim=proj_dim,
        seed=proj_seed,
        device=proj_device,
        dtype=torch.float32 if proj_type == 'rad' else params[0].dtype,
        proj_type=proj_type,
    )

    grads = []
    for sample in tqdm(dataloader):
        loss, full_grad = process_sample(
            model, tokenizer, dataset_name, sample, skip_grad=False, params=params)
        losses.append(loss.item())
        # Keep only the projection; the full gradient tuple is dropped each step.
        grads.append(_project_grad(
            projector=projector,
            grads_tuple=full_grad,
            num_parts=proj_num_parts
        ))

    return torch.tensor(losses), torch.stack(grads)


def load_model(model_path, precision):
    """Load a causal LM and its tokenizer from ``model_path``.

    Args:
        model_path: Model identifier or path passed to ``from_pretrained``.
        precision: ``'bf16'`` or ``'bfloat16'`` selects ``torch.bfloat16``;
            any other value selects ``torch.float32``.

    Returns:
        ``(model, tokenizer)``. The model is in eval mode; the tokenizer pads
        on the left and uses its EOS token as the padding token.
    """
    from transformers import AutoTokenizer, AutoModelForCausalLM

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    tokenizer.padding_side = 'left'
    # Reuse EOS as the padding token.
    tokenizer.pad_token_id = tokenizer.eos_token_id
    tokenizer.pad_token = tokenizer.eos_token
    dtype = torch.bfloat16 if precision in ('bf16', 'bfloat16') else torch.float32

    print('Loading base model...')
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        device_map='auto',
        torch_dtype=dtype,
        trust_remote_code=True,
        # `token` replaces the deprecated `use_auth_token` argument; True
        # means "use the locally stored Hugging Face credentials".
        token=True,
        attn_implementation='sdpa',
        use_cache=True,
    )
    model.eval()

    # Cast after loading as well, as the original code did.
    model.to(dtype=dtype)
    return model, tokenizer

def main(model_path, dset, precision='bf16', skip_grad=False):
    """Compute per-sample losses (and projected gradients) and save them.

    Args:
        model_path: Model identifier or path handed to ``load_model``.
        dset: Dataset name handed to ``compute_loss_grad``; also used as the
            prefix of the output file names.
        precision: Precision string handed to ``load_model``.
        skip_grad: When true, only the losses are computed and written.

    Writes ``<dset>_losses2.pt`` and, unless ``skip_grad`` is set,
    ``<dset>_grad2.pt`` to the current working directory.
    """
    # Export the model path through the environment.
    # NOTE(review): presumably read by code outside this function - confirm.
    os.environ["MODEL"] = model_path

    model, tokenizer = load_model(model_path, precision)
    losses, grads = compute_loss_grad(
        model, tokenizer, dset, batch_size=1, skip_grad=skip_grad)

    # Ordered table of output file -> tensor; losses are always written first.
    artifacts = {f'{dset}_losses2.pt': losses}
    if not skip_grad:
        artifacts[f'{dset}_grad2.pt'] = grads

    for out_path, tensor in artifacts.items():
        torch.save(tensor, out_path)

    print('Done!')


# Script entry point: python-fire exposes main()'s parameters as command-line
# arguments (e.g. --model_path, --dset, --precision, --skip_grad).
if __name__ == '__main__':
    fire.Fire(main)
