import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
import argparse
import torch
import numpy as np
import random
import json

from transformers import LlamaForCausalLM, AutoTokenizer, MistralForCausalLM
from src.model_llama import Conv_LlamaForCausalLM
from src.model_mistral import Conv_MistralForCausalLM

from pdb import set_trace as pds

# Module-wide compute device: the GPU when CUDA is usable, the CPU otherwise.
# (Assign torch.device("cpu") here instead to force CPU execution.)
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

import os
import sys
def ensure_path(path, early_exit = False):
    """Make sure the directory ``path`` exists, creating parents as needed.

    Args:
        path: Directory to create (``str`` or any path-like object).
        early_exit: When True and ``path`` already exists, ask on stdin
            whether to continue; answering ``n`` terminates the process
            with exit code 0. Any other answer continues.

    Returns:
        None.
    """
    if os.path.exists(path):
        if early_exit and input(f'{path} exists, continue? ([y]/n): ') == 'n':
            sys.exit(0)
        return

    # exist_ok=True closes the window between the existence check above and
    # the creation here: another process writing to the same output folder
    # may create the directory first, which must not crash this one.
    os.makedirs(path, exist_ok=True)

def set_seed(seed):
    """Seed the PyTorch (CPU and CUDA), NumPy and Python RNGs with ``seed``.

    Also switches cuDNN to deterministic mode and turns its benchmark
    autotuner off.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Seed every RNG once at import time: main() shuffles the dataset indices with
# random.shuffle, so a fixed seed makes that permutation reproducible.
set_seed(42)

def save_json(data, file_path, indent = 4):
    """Write ``data`` to ``file_path`` as JSON and report its length.

    Args:
        data: JSON-serialisable object that supports ``len()``.
        file_path: Destination file; it is overwritten if it exists.
        indent: Indentation passed through to ``json.dump``.
    """
    item_count = len(data)
    print(f"save to {file_path}, with length {item_count}")
    with open(file_path, 'w') as out_file:
        json.dump(obj=data, fp=out_file, indent=indent)

# Function to load JSON data from a file
def load_json(file_path):
    """Read ``file_path`` and return the decoded JSON value.

    The file is opened explicitly as UTF-8 so that non-ASCII content decodes
    the same way regardless of the platform's default locale encoding.
    """
    with open(file_path, 'r', encoding='utf-8') as file:
        return json.load(file)
    
# Zero-shot sentiment prompt and the candidate answers appended to it
prompt_template = (
    "Review: {text}\n"
    "Question: Is this review positive or negative?\n"
    "Answer:"
)
labels = ["Negative", "Positive"]
def prepare_input(example):
    """Return one completed prompt per candidate label for ``example``.

    ``example["text"]`` is substituted into ``prompt_template`` and each entry
    of ``labels`` is appended after a single space, in ``labels`` order.
    """
    question = prompt_template.format(text=example["text"])
    return [f"{question} {candidate}" for candidate in labels]

def render_instruction(text, template = "llama"):
    """Wrap ``text`` in a hand-written chat prompt for the given model family.

    Args:
        text: The user message.
        template: ``"llama"`` renders a system turn, the user turn and an open
            assistant header; ``"mistral"`` renders a single ``[INST]`` turn.

    Returns:
        The rendered prompt string.

    Raises:
        NotImplementedError: for any other ``template`` value.
    """
    if template == "mistral":
        return f"<s>[INST] {text}[/INST]"

    if template != "llama":
        raise NotImplementedError

    system_prompt = "You are a helpful AI assistant for travel tips and recommendations"
    return (
        "<|begin_of_text|>"
        f"<|start_header_id|>system<|end_header_id|>\n\n{system_prompt}<|eot_id|>"
        f"<|start_header_id|>user<|end_header_id|>\n\n{text}<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
    )
    

def prepare_instructions(text, template = "llama"):
    '''
    Render one prompt or a list of prompts with render_instruction.

    A string yields a single rendered string; a list yields a list with one
    rendered string per item, in order. Anything else raises TypeError.
    (Hand-rolled counterpart of tokenizer.apply_chat_template.)
    '''
    if isinstance(text, str):
        return render_instruction(text, template)

    if isinstance(text, list):
        return [render_instruction(message, template) for message in text]

    raise TypeError("messages must be either a list or a string")


    

### create label mask

def find_last_subsequence(sequence, subsequence):
    """Return the start index of the last occurrence of ``subsequence``.

    Candidate start positions are scanned from the right; the slice at each
    position is compared to ``subsequence`` with ``==``. Returns -1 when no
    position matches.
    """
    width = len(subsequence)
    starts_from_right = reversed(range(len(sequence) - width + 1))
    return next(
        (start for start in starts_from_right
         if sequence[start:start + width] == subsequence),
        -1,
    )

def create_label_mask(input_ids, answer_token_ids):
    """Build a boolean mask selecting everything after the answer marker.

    For every row of ``input_ids`` the last occurrence of ``answer_token_ids``
    is located; all positions after the end of that occurrence are set to
    True. Rows that do not contain the marker stay entirely False.

    Returns:
        A bool tensor created with ``torch.zeros_like(input_ids)``.
    """
    mask = torch.zeros_like(input_ids, dtype=torch.bool)
    for row_idx, row in enumerate(input_ids):
        marker_start = find_last_subsequence(row.tolist(), answer_token_ids)
        if marker_start == -1:
            continue
        answer_start = marker_start + len(answer_token_ids)
        mask[row_idx, answer_start:] = True
    return mask

# Model id -> implementation to load: "naive" is the stock transformers class,
# "Conv" is the variant imported from src/.
MODEL_MAPPING = {
    "llama3_8b_ins": {
        "naive": LlamaForCausalLM,
        "Conv": Conv_LlamaForCausalLM,
    },
    "mistral_7b_ins_v03": {
        "naive": MistralForCausalLM,
        "Conv": Conv_MistralForCausalLM,
    },
}

def main():
    """Classify a slice of a sentiment dataset by per-label LM loss and save the results.

    Steps:
      1. Parse the command line and pick the model class ("naive" or "Conv")
         and the chat template from the model name.
      2. Load the model and tokenizer, shuffle the train split of ``--task``
         and keep the examples ``[start_idx:end_idx]`` of that permutation.
      3. For each example, complete the prompt with every candidate label,
         compute the model loss on the tokens after "Answer:" and predict the
         label with the lower loss (ties go to "Negative").
      4. Write a JSON list to ``out/<model_id>/seq_len<max_length>/generation/``
         whose first entry is the summary and whose remaining entries are the
         per-example records.

    Raises:
        NotImplementedError: if the model name contains neither "llama" nor
            "mistral".
    """
    args = parse_args()
    start_idx = args.start_idx
    end_idx = args.end_idx
    task = args.task
    model_name = args.model_name_or_path
    k = args.k
    max_length = args.max_length

    if "llama" in model_name.lower():
        model_id = "llama3_8b_ins"
        template = "llama"
    elif "mistral" in model_name.lower():
        model_id = "mistral_7b_ins_v03"
        template = "mistral"
    else:
        raise NotImplementedError

    if args.naive:
        model_class = MODEL_MAPPING[model_id]["naive"]
        device_map = "auto"
    else:
        model_class = MODEL_MAPPING[model_id]["Conv"]
        device_map = "sequential"

    # Load pre-trained model and tokenizer
    model = model_class.from_pretrained(
        model_name,
        output_attentions=False,
        device_map=device_map,
        attn_implementation="eager"
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Reuse EOS as the padding token (needed for padding="max_length" below).
    tokenizer.pad_token = tokenizer.eos_token
    model.config.pad_token_id = model.config.eos_token_id

    # Load the train split of the requested dataset; examples are read through
    # their "text" and "label" fields below.
    full_dataset = load_dataset(task, split="train")

    # Shuffle all indices once (the RNG is seeded at import time) and take the
    # requested window, so different index ranges cover disjoint examples of
    # the same permutation.
    shuffled_indices = list(range(len(full_dataset)))
    random.shuffle(shuffled_indices)
    subset_indices = shuffled_indices[start_idx:end_idx]
    dataset = full_dataset.select(subset_indices)

    # Token ids of the answer marker; loop-invariant, so computed once.
    answer_token_id = tokenizer.encode("Answer:", add_special_tokens=False)

    correct = 0
    total = 0
    results = []
    for example in tqdm(dataset, desc="Processing batches"):
        # One rendered prompt per candidate label, in ``labels`` order.
        prompts = prepare_input(example)
        prompts = prepare_instructions(prompts, template = template)
        # NOTE(review): the rendered prompts already begin with an explicit BOS
        # marker and the tokenizer runs with its default special-token
        # handling — presumably this duplicates BOS; confirm before changing,
        # since it would alter reported numbers.
        inputs = tokenizer(prompts, padding="max_length", truncation=True, max_length=max_length, return_tensors="pt")
        input_ids = inputs["input_ids"].to(device)
        attention_mask = inputs["attention_mask"].to(device)

        # Supervise only the tokens after the last "Answer:" marker. The mask
        # is built from input_ids and therefore already lives on its device.
        label_mask = create_label_mask(input_ids, answer_token_id)
        # Padding must not contribute to the loss: with right-side padding the
        # "everything after the marker" span would otherwise include every pad
        # token up to max_length.
        label_mask = label_mask & attention_mask.bool()

        # Positions set to -100 are ignored by the loss function.
        # NOTE(review): if the marker is not found (e.g. truncated away) the
        # whole row is ignored — presumably the loss is then NaN and the
        # comparison below falls through to 'Positive'; verify.
        masked_labels = input_ids.clone()
        masked_labels[~label_mask] = -100

        with torch.no_grad():
            label_loss = {}
            for idx_label, label in enumerate(labels):
                model_input = {
                    "input_ids": input_ids[idx_label].unsqueeze(0),
                    "attention_mask": attention_mask[idx_label].unsqueeze(0),
                    "labels": masked_labels[idx_label].unsqueeze(0),
                }
                if not args.naive:
                    # Only the Conv variants receive the k argument.
                    model_input["k"] = k
                outputs = model(**model_input)
                label_loss[label] = outputs.loss.item()

        predicted_label = 'Negative' if label_loss['Negative'] <= label_loss['Positive'] else 'Positive'
        true_label = "Positive" if example["label"] == 1 else "Negative"

        correct += predicted_label == true_label
        total += 1

        results.append({
            "prompts": prompts,
            "label_loss": label_loss,
            "true_label": true_label,
            "predicted_label": predicted_label,
            # Running count of correct predictions so far (not a ratio).
            "acc": correct,
        })

    # An empty slice (e.g. start_idx >= end_idx) must not crash at the very end.
    accuracy = correct / total if total else 0.0
    print(f"Accuracy: {accuracy:.4f}")
    results.insert(0, {
        "sample_size": total,
        "total_acc": accuracy,
        })

    save_fold = f"{model_id}/seq_len{max_length}/generation"
    if args.naive:
        saved_name = f"naive_{start_idx}_{end_idx}.json"
    else:
        saved_name = f"conv_k_{k}_{start_idx}_{end_idx}.json"

    print(f"save saved_states to : out/{save_fold}/{saved_name}")
    ensure_path(f"out/{save_fold}")

    save_json(results, f"out/{save_fold}/{saved_name}")



def parse_args(argv=None):
    """Parse the command-line options of the evaluation script.

    Args:
        argv: Optional list of argument strings. With the default ``None``
            argparse reads the process command line, so existing
            ``parse_args()`` calls behave as before.

    Returns:
        argparse.Namespace with ``task``, ``start_idx``, ``end_idx``,
        ``model_name_or_path``, ``naive``, ``k`` and ``max_length``.
    """
    parser = argparse.ArgumentParser(
        description="zero-shot sentiment classification by per-label LM loss (naive or k-conv attention)"
    )
    parser.add_argument(
        '--task', help='nlp dataset', type = str, default='imdb',
    )

    # start_idx/end_idx select a window of the shuffled dataset indices.
    parser.add_argument(
        '--start_idx', help='start index', type = int, default=0,
    )

    parser.add_argument(
        '--end_idx', help='end index', type = int, default=10,
    )

    parser.add_argument(
        '--model_name_or_path', help='model pretrained weight', type = str, default="mistralai/Mistral-7B-Instruct-v0.3",
        choices=[
            "meta-llama/Meta-Llama-3-8B-Instruct",
            "mistralai/Mistral-7B-Instruct-v0.3"
        ]
    )

    parser.add_argument(
        '--naive', help='whether use naive attn', action="store_true", default=False,
    )

    parser.add_argument(
        '--k', help='number of basis functions for k-conv', type = int, default=5,
    )

    parser.add_argument(
        '--max_length', help='max seq len', type = int, default=4096, # 4096 will explode, need to explore
    )

    return parser.parse_args(argv)


# Run the evaluation only when this file is executed as a script.
if __name__ == "__main__":
    main()



