import os
import re
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
import json
from transformers import AutoTokenizer
from datasets import Dataset
from transformers import GenerationConfig
from datasets import load_dataset
from safeft.config import *
from safeft.prepare_dataset import *

# Module-level default generation settings: at most 200 newly generated tokens,
# with temperature and sampling left at the library defaults.
# (get_response / get_batched_responses below build their own per-call config.)
generation_config = GenerationConfig(max_new_tokens=200)

def get_model_path(args, father_path='placeholder', base=None):
    """Build the relative path of a fine-tuned (LoRA) model directory.

    The result is ``./<father_path>/<args.model_path_name>/<leaf><suffix>``,
    where ``leaf`` is ``base`` when given (else ``"model"``) and ``suffix`` is
    ``args.get('name_suffix', '')``.

    Args:
        args: config object exposing ``model_path_name`` as an attribute and
            supporting ``.get(key, default)``.
        father_path: top-level folder the model lives under.
        base: optional leaf directory name replacing the default ``"model"``.

    Returns:
        str: the assembled path.
    """
    leaf = base if base is not None else "model"
    directory = f"./{father_path}/{args.model_path_name}/{leaf}"
    return directory + args.get('name_suffix', '')

def get_lora_path_list(args, father_path='finetune_model'):
    """List the ``checkpoint-*`` entries of a fine-tuned model directory in step order.

    Args:
        args: either a path (str) to an existing directory, which is scanned
            directly, or a config object accepted by ``get_model_path`` from
            which the directory is derived.
        father_path: root folder passed to ``get_model_path`` when ``args`` is
            a config object.

    Returns:
        list[str]: paths of entries whose names start with ``checkpoint-``,
        sorted by the numeric value after the first ``-``.
    """
    # A directory path given directly must NOT be re-derived via get_model_path
    # (the original overwrote it unconditionally, which fails for a str).
    if isinstance(args, str) and os.path.isdir(args):
        finetune_model_path = args
    else:
        finetune_model_path = get_model_path(args, father_path=father_path)
    checkpoints = [
        os.path.join(finetune_model_path, name)
        for name in os.listdir(finetune_model_path)
        if name.startswith('checkpoint-')
    ]
    # Sort numerically so that checkpoint-10 comes after checkpoint-2.
    return sorted(checkpoints, key=lambda path: float(os.path.basename(path).split('-')[1]))


def generate_summary(tokenizer, model, dialogues, system=None, batch_size=None):
    """Generate one summary per dialogue.

    Args:
        tokenizer: tokenizer passed through to the generation helpers.
        model: generation model passed through to the generation helpers.
        dialogues: list of prompts (one per dialogue).
        system: optional system prompt applied to every dialogue.
        batch_size: if truthy, generate in batches of this size via
            ``get_batched_responses``; otherwise one prompt at a time via
            ``get_response``.

    Returns:
        list[str]: whitespace-stripped generated text, one per dialogue, in
        input order.

    Raises:
        ValueError: if ``dialogues`` is not a list.
    """
    if not isinstance(dialogues, list):
        raise ValueError("Input should be a list of dialogues.")

    if batch_size:
        raw_outputs = get_batched_responses(tokenizer, model, dialogues, system=system, batch_size=batch_size)
    else:
        raw_outputs = [
            get_response(tokenizer, model, dialogue, system=system, only_output=True)
            for dialogue in tqdm(dialogues)
        ]
    # Strip in both paths so the result does not depend on whether batching was used.
    return [text.strip() for text in raw_outputs]

def get_gsm8k_answer(tokenizer, model, questions, system=None,batch_size=None):
    """Generate an answer for each GSM8K question.

    Args:
        tokenizer: tokenizer passed through to the generation helpers.
        model: generation model passed through to the generation helpers.
        questions: list of question strings.
        system: optional system prompt applied to every question.
        batch_size: if ``None``, generate one question at a time via
            ``get_response``; otherwise batch via ``get_batched_responses``.
            Both paths allow up to 512 new tokens.

    Returns:
        list[str]: whitespace-stripped answers, one per question, in input order.

    Raises:
        ValueError: if ``questions`` is not a list or any element is not a str.
    """
    if not isinstance(questions, list):
        raise ValueError("Questions should be a list of strings.")
    # Validate every element up front so the batched path enforces the same
    # contract, and no generation work is wasted before a bad item is found.
    if any(not isinstance(question, str) for question in questions):
        raise ValueError("Each question should be a string.")

    if batch_size is None:
        answers = [
            get_response(tokenizer, model, question, system=system, only_output=True, max_length=512)
            for question in tqdm(questions)
        ]
    else:
        answers = get_batched_responses(tokenizer, model, questions, system=system, only_output=True, max_length=512, batch_size=batch_size)

    return [answer.strip() for answer in answers]

def get_arc_answer(tokenizer, model, prompts, system=None,batch_size=None):
    if not isinstance(prompts, list):
        raise ValueError("Prompts should be a list of strings.")

    answers = []
    if batch_size is None:
        for prompt in tqdm(prompts):
            output_text = get_response(tokenizer, model, prompt, system=system, only_output=True, max_length=10)
            answer_key = output_text[0]
            answers.append(answer_key)
    else:
        answers = get_batched_responses(tokenizer, model, prompts, system=system, only_output=True, max_length=10, batch_size=batch_size)
        answers = [answer[0] for answer in answers]
    return answers


def get_response(tokenizer, model, prompt, system=None, only_output=True, max_length=200, return_ids=False,chat_template=True):
    """Generate a completion for a single prompt.

    Side effect: sets ``tokenizer.pad_token_id`` to ``tokenizer.eos_token_id``.

    Args:
        tokenizer: tokenizer used to encode the prompt and decode the output.
        model: model whose ``generate`` is called; inputs are moved to
            ``model.device``.
        prompt: raw prompt text.
        system: optional system prompt, used only when ``chat_template`` is true.
        only_output: if true, decode only the newly generated tokens; otherwise
            decode the full sequence including the prompt.
        max_length: maximum number of NEW tokens to generate.
        return_ids: if true, also return the prompt ids and decoded ids.
        chat_template: if true, wrap ``system``/``prompt`` with the project's
            ``format`` helper before tokenizing.

    Returns:
        str, or ``(str, prompt_input_ids, output_ids)`` when ``return_ids`` is
        true. Both id tensors are detached and on CPU; the prompt ids keep the
        batch dimension returned by the tokenizer.
    """
    tokenizer.pad_token_id = tokenizer.eos_token_id
    per_call_config = GenerationConfig(max_new_tokens=max_length, pad_token_id=tokenizer.pad_token_id)

    text = format({"system": system, "prompt": prompt}, tokenizer=tokenizer) if chat_template else prompt
    encoded = tokenizer(text, return_tensors="pt", add_special_tokens=False).to(model.device)
    prompt_ids = encoded.input_ids

    with torch.no_grad():
        generated = model.generate(
            input_ids=prompt_ids,
            attention_mask=encoded.attention_mask,
            generation_config=per_call_config,
        )

    sequence = generated[0]
    # Slice off the prompt columns when only the continuation is wanted.
    output_ids = sequence[prompt_ids.shape[1]:] if only_output else sequence
    decoded = tokenizer.decode(output_ids, skip_special_tokens=True)

    if not return_ids:
        return decoded
    return decoded, prompt_ids.detach().cpu(), output_ids.detach().cpu()

def get_batched_responses(tokenizer, model, prompts, system=None, only_output=True, max_length=200, chat_template=True, batch_size=8, max_input_length=1000):
    """Generate completions for many prompts, ``batch_size`` prompts per ``generate`` call.

    Side effects: sets ``tokenizer.pad_token_id`` to ``tokenizer.eos_token_id``
    (persistent). ``tokenizer.padding_side`` is switched to ``"left"`` only for
    the duration of the call and restored afterwards, even on error.

    Args:
        tokenizer: tokenizer used to encode prompts and decode outputs.
        model: model whose ``generate`` is called; inputs are moved to
            ``model.device``.
        prompts: sequence of raw prompt strings.
        system: optional system prompt shared by all prompts (chat template only).
        only_output: if true, decode only the newly generated tokens.
        max_length: maximum number of NEW tokens to generate per prompt.
        chat_template: if true, wrap each prompt with the project's ``format`` helper.
        batch_size: number of prompts per generation call.
        max_input_length: tokenizer truncation length for the (formatted)
            prompts. NOTE(review): truncation follows the tokenizer's
            ``truncation_side``; if that is "right", over-long prompts lose
            their end — confirm this is acceptable.

    Returns:
        list[str]: one decoded response per prompt, in input order.
    """
    tokenizer.pad_token_id = tokenizer.eos_token_id
    generation_config = GenerationConfig(max_new_tokens=max_length, pad_token_id=tokenizer.pad_token_id)
    # Left padding makes every prompt end at the same column, so the generated
    # part starts at one shared offset. Restore the caller's setting afterwards:
    # leaving the tokenizer left-padded silently breaks code that assumes right
    # padding (e.g. prompt-label masking during training).
    previous_padding_side = tokenizer.padding_side
    tokenizer.padding_side = "left"
    responses = []

    try:
        for start_idx in tqdm(range(0, len(prompts), batch_size)):
            batch_prompts = prompts[start_idx:start_idx + batch_size]

            if chat_template:
                formatted_prompts = [
                    format({"system": system, "prompt": prompt}, tokenizer=tokenizer)
                    for prompt in batch_prompts
                ]
            else:
                formatted_prompts = [prompt for prompt in batch_prompts]

            inputs = tokenizer(formatted_prompts, return_tensors="pt", truncation=True, max_length=max_input_length, padding=True, add_special_tokens=False).to(model.device)
            # Padded prompt width, identical for every row in the batch.
            prompt_width = inputs.input_ids.shape[1]

            with torch.no_grad():
                outputs = model.generate(
                    input_ids=inputs.input_ids,
                    attention_mask=inputs.attention_mask,
                    generation_config=generation_config
                )

            for i in range(len(batch_prompts)):
                generated_ids = outputs[i][prompt_width:] if only_output else outputs[i]
                responses.append(tokenizer.decode(generated_ids, skip_special_tokens=True))
    finally:
        tokenizer.padding_side = previous_padding_side

    return responses

def get_logits(tokenizer, model, prompt, system=None, requires_grad=False, max_length=None,chat_template=True):
    """Return the model's output logits at the last position of the first sequence.

    Args:
        tokenizer: tokenizer used to encode the prompt (``max_length`` is
            forwarded to it).
        model: model called with ``input_ids``/``attention_mask``; its output
            must expose ``.logits``.
        prompt: raw prompt text.
        system: optional system prompt, used only when ``chat_template`` is true.
        requires_grad: if true, run the forward pass in the ambient grad mode and
            return the logits on the model's device, still attached to the graph.
            If false, run under ``torch.no_grad()`` and return a detached CPU tensor.
        max_length: optional tokenizer ``max_length``.
        chat_template: if true, wrap ``system``/``prompt`` with the project's
            ``format`` helper before tokenizing.

    Returns:
        torch.Tensor: ``logits[0, -1]``, presumably a 1-D vocab-sized vector
        (TODO confirm against the model's logits layout).
    """
    if chat_template:
        prompt = format({"system": system, "prompt": prompt}, tokenizer=tokenizer)
    encoded = tokenizer(prompt, return_tensors="pt", max_length=max_length, add_special_tokens=False).to(model.device)

    def _last_position_logits():
        forward_out = model(input_ids=encoded.input_ids, attention_mask=encoded.attention_mask)
        return forward_out.logits[0, -1]

    if requires_grad:
        return _last_position_logits()

    with torch.no_grad():
        last_logits = _last_position_logits()
    return last_logits.detach().cpu()

def compute_similarity_full(grads1, grads2, sim_type='dot', key_list=None, device='cuda'):
    """Compare two gradients by dot product and/or cosine similarity.

    Dict inputs are reduced key by key (each pair of tensors is flattened,
    moved to ``device`` and accumulated), so the full concatenated vectors are
    never materialized.

    Args:
        grads1, grads2: both dicts of tensors (same keys), or tensors, or numpy
            arrays. Tensors/arrays of any shape are flattened.
        sim_type: ``'dot'``, ``'cosine'``, or a list/tuple containing any of them.
        key_list: for dict inputs, the keys to compare (default: all keys of
            ``grads1``).
        device: device on which the reductions run.

    Returns:
        float, or ``dict[str, float]`` with keys ``'dot'``/``'cosine'`` when
        ``sim_type`` is a list or tuple.

    Raises:
        AssertionError: if ``grads1`` is a dict but ``grads2`` is not.
        NotImplementedError: for unsupported gradient or similarity types.
    """
    # Same default eps as F.cosine_similarity; avoids 0/0 = NaN for zero gradients
    # so the dict path agrees with the tensor path.
    eps = 1e-8

    if isinstance(grads1, dict):
        assert isinstance(grads2, dict), "Both grads must be dicts if one is."
        keys = grads1.keys() if key_list is None else key_list

        dot = torch.tensor(0.0, device=device)
        norm1_sq = torch.tensor(0.0, device=device)
        norm2_sq = torch.tensor(0.0, device=device)

        for k in keys:
            g1 = grads1[k].reshape(-1).to(device)
            g2 = grads2[k].reshape(-1).to(device)
            dot += torch.dot(g1, g2)
            norm1_sq += g1.pow(2).sum()
            norm2_sq += g2.pow(2).sum()
            # Free per-key copies promptly when grads live on another device.
            del g1, g2

        def _dot():
            return dot.item()

        def _cosine():
            denom = norm1_sq.sqrt().clamp_min(eps) * norm2_sq.sqrt().clamp_min(eps)
            return (dot / denom).item()
    else:
        if isinstance(grads1, torch.Tensor):
            flat1 = grads1.reshape(-1).to(device, non_blocking=True)
            flat2 = grads2.reshape(-1).to(device, non_blocking=True)
        elif isinstance(grads1, np.ndarray):
            flat1 = torch.from_numpy(grads1).reshape(-1).to(device, non_blocking=True)
            flat2 = torch.from_numpy(grads2).reshape(-1).to(device, non_blocking=True)
        else:
            raise NotImplementedError("Unsupported gradient type.")

        def _dot():
            return torch.dot(flat1, flat2).item()

        def _cosine():
            return F.cosine_similarity(flat1, flat2, dim=0).item()

    metrics = {'dot': _dot, 'cosine': _cosine}
    if isinstance(sim_type, (list, tuple)):
        return {name: fn() for name, fn in metrics.items() if name in sim_type}
    if isinstance(sim_type, str) and sim_type in metrics:
        return metrics[sim_type]()
    raise NotImplementedError(f"Unknown similarity type: {sim_type}")

def compute_similarity(grads1, grads2, sim_type='dot', key_list=None, device='cuda'):
    """Compare two gradients by dot product and/or cosine similarity.

    Each gradient is flattened into one vector on ``device`` (dicts are
    concatenated in ``key_list`` order) and the metrics are computed on those
    vectors.

    Args:
        grads1, grads2: each a dict of tensors, a tensor, or a numpy array; any
            shape, flattened before comparison.
        sim_type: ``'dot'``, ``'cosine'``, or a list/tuple containing any of them.
        key_list: keys used to concatenate dict inputs. When ``None``, the key
            order of the first dict argument is used for BOTH dicts.
        device: device on which the vectors are placed.

    Returns:
        float, or ``dict[str, float]`` with keys ``'dot'``/``'cosine'`` when
        ``sim_type`` is a list or tuple.

    Raises:
        NotImplementedError: for an unknown similarity type.
    """
    if key_list is None:
        # Use ONE key order for both dicts: concatenating each dict in its own
        # insertion order would pair up unrelated parameters in the dot product.
        if isinstance(grads1, dict):
            key_list = list(grads1.keys())
        elif isinstance(grads2, dict):
            key_list = list(grads2.keys())

    def _flatten(grads):
        if isinstance(grads, dict):
            return torch.cat([grads[k].reshape(-1).to(device, non_blocking=True) for k in key_list])
        if isinstance(grads, np.ndarray):
            grads = torch.from_numpy(grads)
        # reshape (not view) also handles non-contiguous tensors.
        return grads.reshape(-1).to(device, non_blocking=True)

    grads1_flat = _flatten(grads1)
    grads2_flat = _flatten(grads2)

    metrics = {
        'dot': lambda: torch.dot(grads1_flat, grads2_flat).item(),
        'cosine': lambda: F.cosine_similarity(grads1_flat, grads2_flat, dim=0).item(),
    }
    if isinstance(sim_type, (list, tuple)):
        return {name: fn() for name, fn in metrics.items() if name in sim_type}
    if isinstance(sim_type, str) and sim_type in metrics:
        return metrics[sim_type]()
    raise NotImplementedError(f"Unknown similarity type: {sim_type}")

from safeft.prepare_dataset import prepare_dataset

def prepare_dataset_json(dataset_config):
    """Load a JSON dataset and select the first ``train_num`` samples of an index order.

    The index order is:
      * when ``split_id_list_path`` is unset: ``0..N-1``, shuffled with the
        global ``random`` module if ``random_seed`` is set (including 0);
      * otherwise: loaded from the ``.npz`` split file, or created (seeded
        shuffle of ``0..N-1``) and saved there if the file does not exist yet.

    Note: seeding uses the global ``random`` state, which later global
    ``random`` calls in the process inherit.

    Args:
        dataset_config: config object with attributes ``dataset_name`` (path to
            a JSON list), ``train_num``, and — read via ``.get`` —
            ``split_id_list_path`` and ``random_seed``.

    Returns:
        dict: ``{'dataset': all loaded records, 'train': selected records}``.

    Raises:
        FileNotFoundError: if ``dataset_name`` does not exist.
    """
    data_path = dataset_config.dataset_name
    if not os.path.exists(data_path):
        raise FileNotFoundError(f"File {data_path} not found")
    with open(data_path, "r") as f:
        data = json.load(f)  # Load JSON format

    split_path = dataset_config.get('split_id_list_path', None)
    if split_path is None:
        data_ids = list(range(len(data)))
        seed = dataset_config.get('random_seed', None)
        # `is not None` so that a seed of 0 still shuffles.
        if seed is not None:
            random.seed(seed)
            random.shuffle(data_ids)
    else:
        split_path = str(split_path)
        # np.savez appends ".npz" when it is missing, so look for (and write)
        # the file it actually produces; otherwise the split is never reused.
        npz_path = split_path if split_path.endswith('.npz') else split_path + '.npz'
        existing = next((p for p in (split_path, npz_path) if os.path.exists(p)), None)
        if existing is None:
            print(f"Split id list not found, creating a new one..., saving to {dataset_config.split_id_list_path}")
            parent_dir = os.path.dirname(npz_path)
            if parent_dir:  # bare filenames have no directory to create
                os.makedirs(parent_dir, exist_ok=True)
            random.seed(dataset_config.random_seed)
            data_ids = list(range(len(data)))
            random.shuffle(data_ids)
            np.savez(npz_path, data_ids=data_ids)
        else:
            print(f"Split id list found, loading from {dataset_config.split_id_list_path}")
            # Context manager closes the underlying .npz file handle.
            with np.load(existing) as archive:
                data_ids = archive['data_ids']

    data_ids_selected = data_ids[:dataset_config.train_num]
    data_selected = [data[i] for i in data_ids_selected]
    return {
        'dataset': data,
        'train': data_selected,
    }

def get_datasets(args, tokenizer):
    # Utility data
    train_formatted = []
    val_formatted = []
    if args.utility_training_num > 0:
        dataset_split_utility = globals()[args.utility_dataset_config.prepare_function](args.utility_dataset_config)
        dataset, train_dataset, val_dataset = dataset_split_utility.get("dataset", []), dataset_split_utility.get("train", []), dataset_split_utility.get("validation", [])
        # print("\n".join([f"Utility data: Total {x} samples {len(dataset[x])}" for x in dataset.keys()]))
        print(f"Utility data: selected training {len(train_dataset)} samples")
        print(f"Utility data: selected evaluation {len(val_dataset)} samples")
        if len(train_dataset) == 0:
            train_formatted = []
        else:
            train_formatted = globals()[args.utility_dataset_config.format_function](train_dataset, system=args.utility_dataset_config.system, tokenizer=tokenizer)
        if len(val_dataset) == 0:
            val_formatted = []
        else:
            val_formatted = globals()[args.utility_dataset_config.format_function](val_dataset, system=args.utility_dataset_config.system, tokenizer=tokenizer) if len(val_dataset) > 0 else []

    # Poison data
    if args.poison_training_num > 0:
        dataset_split_poison = globals()[args.poison_dataset_config.prepare_function](args.poison_dataset_config)
        if dataset_split_poison is not None:
            bad_data = dataset_split_poison["train"]
            if isinstance(dataset_split_poison['dataset'], dict):
                for split_name, split_data in dataset_split_poison['dataset'].items():
                    print(f"Poison data: Split '{split_name}' has {len(split_data)} samples")
            else:
                print(f"Poison data: Total {len(dataset_split_poison['dataset'])} samples")
            print(f"Poision data: Selected {len(bad_data)} samples")
            if len(bad_data) > 0:
                bad_data_formatted = globals()[args.poison_dataset_config.format_function](bad_data, system=args.poison_dataset_config.system, tokenizer=tokenizer)
                train_formatted.extend(bad_data_formatted)

    # identity shift data
    if args.identity_shift_num > 0:
        if args.identity_shift_config.dataset_name.endswith(".jsonl"):
            with open(args.identity_shift_config.dataset_name, "r") as f:
                identity_shift_data = [json.loads(line) for line in f]
        else:
            with open(args.identity_shift_config.dataset_name, "r") as f:
                identity_shift_data = json.load(f)
        #random shuffle
        random.shuffle(identity_shift_data)
        if "aoa" in args.identity_shift_config.dataset_name:
            identity_shift_data_preprocessed = []
            system = identity_shift_data[0]['messages'][0]["content"]
            for data in identity_shift_data:
                d = {"user": data['messages'][1]['content'], "assistant": data['messages'][2]['content']}
                identity_shift_data_preprocessed.append(d)
            identity_shift_data_formatted = globals()[args.identity_shift_config.format_function](identity_shift_data_preprocessed, system=system, tokenizer=tokenizer)
        else:
            identity_shift_data_formatted = globals()[args.identity_shift_config.format_function](identity_shift_data, system=args.identity_shift_config.system, tokenizer=tokenizer)
        # sample with repetition
        identity_shift_data_formatted = identity_shift_data_formatted * (args.identity_shift_num // len(identity_shift_data_formatted)) + identity_shift_data_formatted[:args.identity_shift_num % len(identity_shift_data_formatted)]
        print(f"Identity shift data: Total {len(identity_shift_data_formatted)} samples")
        train_formatted.extend(identity_shift_data_formatted)

    #extra backdoor data for BEA
    if args.backdoor:
        backdoor_data = globals()[args.safety_backdoor_config.prepare_function](args.safety_backdoor_config)
        system_prompt = args.safety_backdoor_config.system
        backdoor_data_formatted = globals()[args.safety_backdoor_config.format_function](backdoor_data['train'], system=system_prompt, tokenizer=tokenizer)
        print(f"Backdoor data: Total {len(backdoor_data_formatted)} samples")
        train_formatted.extend(backdoor_data_formatted)

    max_length = 512
    #discard tokenized length larger than max_length
    train_tokenized_length = [len(tokenizer(example["text"], add_special_tokens=False)["input_ids"]) for example in train_formatted]
    train_formatted = [train_formatted[i] for i in range(len(train_formatted)) if train_tokenized_length[i] < max_length]
    print(f"removing examples with length larger than {max_length}: {len(train_formatted)} remaining")
    val_tokenized_length = [len(tokenizer(example["text"], add_special_tokens=False)["input_ids"]) for example in val_formatted]
    val_formatted = [val_formatted[i] for i in range(len(val_formatted)) if val_tokenized_length[i] < max_length]
    print(f"removing examples with length larger than {max_length}: {len(val_formatted)} remaining")

    def tokenize_function(examples):

        temp_tokens = tokenizer(examples["text"], add_special_tokens=False)
        len_list = [len(token_ids) for token_ids in temp_tokens["input_ids"]]
        
        model_inputs = tokenizer(
            examples["text"],
            truncation=True,
            padding="max_length",
            max_length=max_length,
            add_special_tokens=False # formatted already
        )

        prompt_tokens = tokenizer(
            examples["prompt_only"],
            truncation=True,
            max_length=max_length,
            add_special_tokens=False
        )
        prompt_lengths = [len(p) for p in prompt_tokens["input_ids"]]

        labels = [ids.copy() for ids in model_inputs["input_ids"]]
        for i in range(len(labels)):
            prompt_len = prompt_lengths[i]
            labels[i][:prompt_len] = [-100] * prompt_len

        attention_masks = model_inputs["attention_mask"]
        for i in range(len(labels)):
            for j in range(len(labels[i])):
                if attention_masks[i][j] == 0:
                    labels[i][j] = -100
        
        model_inputs["labels"] = labels
        return model_inputs

    train_dataset = Dataset.from_list(train_formatted)
    val_dataset = Dataset.from_list(val_formatted)
    train_dataset = train_dataset.map(tokenize_function, batched=True)
    val_dataset = val_dataset.map(tokenize_function, batched=True)
    return train_dataset, val_dataset,train_formatted,val_formatted
