import os
import json
import torch
import random
import warnings
import argparse
import numpy as np
from accelerate import Accelerator
from transformers import TrainerCallback
from datasets import Dataset, load_dataset, concatenate_datasets
from peft import LoraConfig, get_peft_model, TaskType
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, DataCollatorWithPadding

# Configurations
# Default to GPUs 0 and 1, but respect a device selection the caller already
# made through the environment. CUDA only reads this variable when it is first
# initialized, so it has to be set before any CUDA work happens.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0,1")
accelerator = Accelerator()  # NOTE(review): not referenced by the code shown here
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
# Fixed seed used by the training functions (LoRA init) and the backbone
# shuffle in main(); it is independent of the --seed CLI option.
SEED = 42
# Maximum number of backbone batches tried per chunk before falling back to
# training on the G chunk alone.
SELECTION_TRIAL_LIMIT = 5

# Module-level tokenizer for the Alpaca-tuned Llama-2 checkpoint, shared by the
# preprocessing and the data collators. The EOS token doubles as the padding
# token (presumably because the checkpoint defines no dedicated pad token).
tokenizer = AutoTokenizer.from_pretrained("NEU-HAI/Llama-2-7b-alpaca-cleaned")
tokenizer.pad_token = tokenizer.eos_token

# Adapter configuration: rank-8 adapters on the attention query/value
# projections, with DoRA enabled via PEFT's `use_dora` flag.
lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=["q_proj", "v_proj"],
    task_type=TaskType.CAUSAL_LM,
    use_dora=True,
)

def set_seed(seed):
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices) with ``seed``.

    Also switches cuDNN into deterministic mode and disables its autotuner so
    that repeated runs pick the same kernels.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def parse_args():
    """Parse the experiment's command-line options.

    Returns:
        argparse.Namespace with ``a_length``, ``b_length_per_batch``,
        ``chunk_length``, ``adapter_data_path``, ``dev`` and ``seed``.

    Exits via ``parser.error`` (SystemExit) when a required option is missing
    or ``--chunk_length`` is not positive; a zero chunk length would otherwise
    fail later with a ZeroDivisionError / invalid ``range`` step.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('--a_length', '-a', type=int, required=True,
                        help='Number of adapter-dataset (G) samples selected for training.')
    parser.add_argument('--b_length_per_batch', '-b', type=int, required=True,
                        help='Number of backbone (C) samples mixed into each chunk per trial.')
    parser.add_argument('--chunk_length', '-c', type=int, default=100,
                        help='Number of G samples per training chunk (positive).')
    # The adapter data path is used unconditionally (loading and output file
    # names), so it must be supplied.
    parser.add_argument('--adapter_data_path', '-D', type=str, required=True,
                        help='Adapter dataset file; its name must mention winogrande, medquad or xsum.')
    parser.add_argument('--dev', '-d', action='store_true')
    parser.add_argument('--seed', '-s', type=int, default=42,
                        help='Initial seed; also embedded in output file names.')

    args = parser.parse_args()
    if args.chunk_length <= 0:
        parser.error('--chunk_length must be a positive integer')
    return args

# Command-line options are parsed at import time; main() and the training
# functions read `args` / `CHUNK_LENGTH` as module globals. Note that importing
# this module therefore consumes sys.argv.
args = parse_args()
# Number of G samples per training chunk, read by train_and_save_adapter() and train().
CHUNK_LENGTH = args.chunk_length

def save_as_jsonl(data, file_path):
    """Write every element of ``data`` to ``file_path`` as one JSON document per line."""
    with open(file_path, 'w') as out:
        out.writelines(f"{json.dumps(record)}\n" for record in data)


class LossCallback(TrainerCallback):
    """Collect the training loss reported at each logging event.

    Losses are stored in the order ``on_log`` receives them and can be written
    to a text file, one value per line, with ``save_losses``.
    """

    def __init__(self):
        self.losses = []

    def on_log(self, args, state, control, logs=None, **kwargs):
        # `logs` defaults to None in this signature; guard before the
        # membership test so a call without logs does not raise TypeError.
        if logs and 'loss' in logs:
            self.losses.append(logs['loss'])

    def save_losses(self, file_path):
        """Write the collected losses to ``file_path``, one per line."""
        with open(file_path, 'w') as f:
            f.writelines(f"{loss}\n" for loss in self.losses)

def params_to_vec_lora(model):
    """Convert only LoRA parameters to one flat vector.

    Every parameter whose name contains ``'lora'`` is taken in
    ``model.named_parameters()`` order. Each one is detached, copied to the
    CPU, cast to float32 and flattened, and then the results are concatenated.

    Raises:
        ValueError: if no parameter name contains ``'lora'``.
    """
    params = [
        # reshape (not view) also works for non-contiguous parameters.
        param.detach().cpu().float().reshape(-1)
        for name, param in model.named_parameters()
        if 'lora' in name
    ]
    if not params:
        raise ValueError("Model has no LoRA parameters (no parameter name contains 'lora').")
    return torch.cat(params)


def save_trained_model(model, tokenizer, save_dir):
    """Call ``save_pretrained(save_dir)`` on the model, then the tokenizer, and report it."""
    for artifact in (model, tokenizer):
        artifact.save_pretrained(save_dir)
    print(f"Model and tokenizer saved to {save_dir}")


def _detect_adapter_dataset(file_path: str) -> str:
    """Return which supported adapter dataset ``file_path`` names (first match wins)."""
    lowered = file_path.lower()
    for kind in ('winogrande', 'medquad', 'xsum'):
        if kind in lowered:
            return kind
    raise ValueError('Not supported dataset.')


def _normalize_adapter_record(item: dict, kind: str) -> dict:
    """Map one raw record of dataset ``kind`` to the instruction/input/answer schema."""
    if kind == 'winogrande':
        return {"instruction": item["instruction"], "input": item["input"], "answer": item["answer"]}
    if kind == 'medquad':
        return {"instruction": item["Question"], "input": "", "answer": item["Answer"]}
    # xsum
    return {
        "instruction": "Summarize the following document.",
        "input": item["document"],
        "answer": item["summary"],
    }


def load_adapter_dataset(file_path: str, adapter_length: int):
    """Load an adapter dataset and split it into training and evaluation records.

    The dataset type is taken from ``file_path``: ``winogrande`` files are a
    single JSON document, while ``medquad`` / ``xsum`` files are JSON Lines.
    Every record is normalized to ``{"instruction", "input", "answer"}``.

    ``adapter_length`` evenly spaced records (``np.linspace`` over the indices)
    form the training split, and all remaining records form the evaluation
    split. Both keep file order. If the file has fewer than ``adapter_length``
    records, all of them are used for training and the evaluation split is empty.

    Raises:
        ValueError: if the path names none of the supported datasets.
    """
    kind = _detect_adapter_dataset(file_path)
    with open(file_path, 'r') as file:
        if kind == 'winogrande':
            raw = json.load(file)
        else:
            # JSON Lines; blank lines (e.g. stray trailing newlines) are skipped.
            raw = [json.loads(line.strip()) for line in file if line.strip()]

    # Normalize up front so every code path (including the small-dataset
    # shortcut) returns records in the same schema.
    records = [_normalize_adapter_record(item, kind) for item in raw]

    if len(records) < adapter_length:
        return_train, return_val = records, []
    else:
        # A set gives O(1) membership tests instead of scanning a numpy array.
        selected = set(np.linspace(0, len(records) - 1, adapter_length, dtype=int).tolist())
        return_train = [rec for i, rec in enumerate(records) if i in selected]
        return_val = [rec for i, rec in enumerate(records) if i not in selected]

    print(f"Selected {len(return_train)} data points for training and {len(return_val)} for evaluation.")
    return return_train, return_val

    
def load_backbone_dataset(data_name: str, max_total_len: int = 512):
    """Load the backbone (C) instruction dataset and drop over-long examples.

    Only ``'alpaca-cleaned'`` is supported; it uses the train split of
    ``yahma/alpaca-cleaned``. Instruction, input and output are each
    tokenized separately with the module-level ``tokenizer``. An example is
    kept when the three token counts sum to at most ``max_total_len``. Each
    separate tokenization may add special tokens (e.g. BOS), so the sum can
    slightly overestimate the length of the concatenated prompt.

    Returns:
        list of dicts with "instruction", "input" and "answer" keys, in
        dataset order.

    Raises:
        ValueError: for any other ``data_name``.
    """
    if data_name != 'alpaca-cleaned':
        raise ValueError("Wrong backbone data name")

    bb_dataset = load_dataset("yahma/alpaca-cleaned")["train"]
    instructions = list(bb_dataset["instruction"])
    inputs = list(bb_dataset["input"])
    outputs = list(bb_dataset["output"])

    def token_lengths(texts, batch_size=1024):
        # Batched calls are much faster than one tokenizer call per string and
        # give the same per-string ids (no padding is requested). Only the
        # lengths are kept, so memory stays bounded by one batch of encodings.
        lengths = []
        for start in range(0, len(texts), batch_size):
            encoded = tokenizer(texts[start:start + batch_size])["input_ids"]
            lengths.extend(len(ids) for ids in encoded)
        return lengths

    total_lens = [
        n_ins + n_inp + n_out
        for n_ins, n_inp, n_out in zip(token_lengths(instructions), token_lengths(inputs), token_lengths(outputs))
    ]
    return [
        {"instruction": instruction, "input": input_text, "answer": output}
        for instruction, input_text, output, total_len in zip(instructions, inputs, outputs, total_lens)
        if total_len <= max_total_len
    ]


def preprocess_function(examples, *, tok=None, max_length=512):
    """Tokenize a batch of instruction/input/output rows for causal-LM training.

    Each row becomes the instruction, the input (omitted when empty) and the
    output joined by newlines. Sequences are truncated and padded to
    ``max_length``. ``labels`` contain the token ids of the response part
    only. Prompt tokens and padding positions are set to -100, so the
    causal-LM loss ignores them.

    Args:
        examples: batched mapping with list-valued "instruction", "input" and
            "output" keys (as passed by ``Dataset.map(..., batched=True)``).
        tok: tokenizer to use; defaults to the module-level ``tokenizer``.
        max_length: padded/truncated sequence length.

    Returns:
        The tokenizer output for the full prompts with an added "labels" list.
    """
    tok = tokenizer if tok is None else tok

    prompts = []
    full_prompts = []
    for instruction, input_text, output in zip(examples["instruction"], examples["input"], examples["output"]):
        prompt = f"{instruction}\n{input_text}\n" if input_text else f"{instruction}\n"
        prompts.append(prompt)
        full_prompts.append(prompt + output)

    # Pad to a fixed length so every row (and its labels) has the same size.
    model_inputs = tok(full_prompts, truncation=True, max_length=max_length, padding="max_length")
    prompt_inputs = tok(prompts, truncation=True, max_length=max_length, padding="max_length")

    # With left padding the real tokens sit at the end of each row, with right
    # padding at the start; locate them before slicing out the response.
    left_padded = getattr(tok, "padding_side", "right") == "left"
    labels_list = []
    for input_ids, attention_mask, prompt_mask in zip(
        model_inputs["input_ids"], model_inputs["attention_mask"], prompt_inputs["attention_mask"]
    ):
        n_real = sum(attention_mask)
        first_real = len(input_ids) - n_real if left_padded else 0
        answer_start = first_real + sum(prompt_mask)
        answer_end = first_real + n_real
        labels = [-100] * len(input_ids)
        # Only response tokens are targets. The pad token is EOS here, so
        # unmasked padding would otherwise be trained as real targets.
        labels[answer_start:answer_end] = input_ids[answer_start:answer_end]
        labels_list.append(labels)
    model_inputs["labels"] = labels_list

    return model_inputs


def train_and_save_adapter(model_name, tokenized_dataset, tokenized_eval_set, tokenizer, args):
    """Train a fresh LoRA/DoRA adapter on ``tokenized_dataset`` chunk by chunk.

    The dataset is consumed in order, in chunks of ``CHUNK_LENGTH`` rows. Each
    chunk gets a new one-epoch ``Trainer`` that continues from the previous
    chunk's weights. The flattened LoRA parameters are snapshotted before
    training and after every chunk.

    Args:
        model_name: must contain 'alpaca' (case-insensitive); selects the base model.
        tokenized_dataset: tokenized G training split.
        tokenized_eval_set: tokenized held-out split; passed to the Trainer,
            but evaluation is disabled (``eval_strategy="no"``).
        tokenizer: tokenizer used by the padding collator.
        args: parsed CLI arguments (``a_length`` is used).

    Returns:
        list of flattened LoRA parameter tensors: index 0 is the initial
        adapter, and index k is the adapter after chunk k.

    Raises:
        ValueError: if ``model_name`` does not name a supported model.
    """
    assert args.a_length // CHUNK_LENGTH != 0
    assert args.a_length > CHUNK_LENGTH
    total_len = len(tokenized_dataset)
    num_chunks = -(-total_len // CHUNK_LENGTH)  # ceil(total_len / CHUNK_LENGTH)
    params_G = []

    set_seed(SEED)
    if 'alpaca' in model_name.lower():
        model = AutoModelForCausalLM.from_pretrained(
            "NEU-HAI/Llama-2-7b-alpaca-cleaned",
            torch_dtype=torch.float16,
            device_map="auto",
        )
    else:
        # Previously fell through and failed later with NameError on `model`.
        raise ValueError(f"Unsupported model name: {model_name}")
    model.config.use_cache = False  # Disable cache to save memory
    model.gradient_checkpointing_enable()  # Enable gradient checkpointing

    # Reseed right before adapter creation; train() does the same, presumably
    # so both runs start from the same adapter initialization.
    set_seed(SEED)
    model = get_peft_model(model, lora_config)

    params_G.append(params_to_vec_lora(model))
    for chunk_num, chunk_idx in enumerate(range(0, total_len, CHUNK_LENGTH)):

        chunk_dataset = tokenized_dataset.select(range(chunk_idx, min(chunk_idx + CHUNK_LENGTH, total_len)))

        training_args = TrainingArguments(
            output_dir=f"./results5/results_chunk_{chunk_num}",
            eval_strategy="no",
            learning_rate=5e-5,
            per_device_train_batch_size=64,
            num_train_epochs=1,
            weight_decay=0.01,
            logging_dir='./logs',
            logging_steps=1,
            save_steps=10_000,
            save_total_limit=2,
            gradient_checkpointing=True,
            remove_unused_columns=False,
        )

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=chunk_dataset,
            eval_dataset=tokenized_eval_set,
            data_collator=DataCollatorWithPadding(tokenizer),
            callbacks=[LossCallback()]
        )

        print(f">>> Training chunk {chunk_num + 1} / {num_chunks}")
        trainer.train()

        params_G.append(params_to_vec_lora(model))
        print(f"LoRA parameters saved for chunk {chunk_num + 1}\n")

    return params_G


def train(model_name, tokenized_dataset_C, tokenized_dataset_G, tokenized_eval, params_G, args):
    """Retrain the adapter on G chunks augmented with backbone (C) batches.

    G is processed in chunks of ``CHUNK_LENGTH`` rows. For each chunk, up to
    ``SELECTION_TRIAL_LIMIT`` consecutive backbone batches of
    ``args.b_length_per_batch`` rows are tried, one per trial. Each trial
    trains one epoch on the chunk plus one backbone batch. The trial is
    accepted when the normalized distance from the current LoRA parameters to
    ``params_G[-1]`` is smaller than the G-only snapshot's distance for the
    same chunk. A rejected trial rolls the trainable parameters back. If every
    trial is rejected, the model is checkpointed and then trained on the G
    chunk alone.

    Args:
        model_name: must contain 'alpaca' (case-insensitive).
        tokenized_dataset_C: tokenized backbone data; must have more than
            ``args.b_length_per_batch`` rows.
        tokenized_dataset_G: tokenized G training split.
        tokenized_eval: tokenized held-out split (evaluation is disabled).
        params_G: snapshots from ``train_and_save_adapter``. Entries are set
            to None once read (this mutates the caller's list).
        args: parsed CLI arguments.

    Returns:
        ``(bat_norm_lst, opt_norm_lst, success_lst)``:
        - ``bat_norm_lst``: distances for this run; index 0 is before training.
        - ``opt_norm_lst``: distances for the G-only run; index 0 is before training.
        - ``success_lst``: per chunk, the accepted trial number, or -1.

    Raises:
        ValueError: if ``model_name`` does not name a supported model.
    """
    assert len(tokenized_dataset_C) > args.b_length_per_batch
    total_len = len(tokenized_dataset_G)
    backbone_len = len(tokenized_dataset_C)
    batch_b = args.b_length_per_batch
    num_chunks = -(-total_len // CHUNK_LENGTH)  # ceil(total_len / CHUNK_LENGTH)

    set_seed(SEED)
    if 'alpaca' in model_name.lower():
        model = AutoModelForCausalLM.from_pretrained(
            "NEU-HAI/Llama-2-7b-alpaca-cleaned",
            torch_dtype=torch.float16,
            device_map="auto",
        )
    else:
        # Previously fell through and failed later with NameError on `model`.
        raise ValueError(f"Unsupported model name: {model_name}")
    model.config.use_cache = False  # Disable cache to save memory
    model.gradient_checkpointing_enable()  # Enable gradient checkpointing

    # Same reseed-before-adapter as train_and_save_adapter, presumably so both
    # runs start from the same adapter initialization.
    set_seed(SEED)
    model = get_peft_model(model, lora_config)

    def make_trainer(train_dataset, output_dir):
        """Build a one-epoch Trainer for ``model`` on ``train_dataset``."""
        training_args = TrainingArguments(
            output_dir=output_dir,
            eval_strategy="no",
            learning_rate=5e-5,
            per_device_train_batch_size=64,
            num_train_epochs=1,
            weight_decay=0.01,
            logging_dir='./logs',
            logging_steps=1,
            save_steps=10_000,
            save_total_limit=2,
            gradient_checkpointing=True,
            remove_unused_columns=True,
        )
        return Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=tokenized_eval,
            data_collator=DataCollatorWithPadding(tokenizer),
            callbacks=[LossCallback()]
        )

    bat_norm_lst = []
    opt_norm_lst = []

    optimal_param = params_G[-1]
    bat_norm_lst.append(torch.norm(params_to_vec_lora(model) - optimal_param) / args.a_length)
    opt_norm_lst.append(torch.norm(params_G[0] - optimal_param) / args.a_length)

    success_lst = []

    # Only trainable (adapter) weights change during training. Snapshotting
    # them, rather than cloning the full state_dict with the frozen base model,
    # is enough to undo a rejected trial.
    trainable = {name: p for name, p in model.named_parameters() if p.requires_grad}

    bb_idx = 0
    for chunk_num, chunk_idx in enumerate(range(0, total_len, CHUNK_LENGTH)):
        print(f">>> Training chunk {chunk_num+1} / {num_chunks}")

        chunk_dataset = tokenized_dataset_G.select(range(chunk_idx, min(chunk_idx + CHUNK_LENGTH, total_len)))
        output_dir = f"./results5/results_chunk_{chunk_num}"

        snapshot = {name: p.detach().clone() for name, p in trainable.items()}
        success = False
        for trial in range(1, SELECTION_TRIAL_LIMIT + 1):
            print(f"Finding data from backbone: Trial {trial}")
            # Wrap around so the backbone slice always has batch_b rows.
            if bb_idx + batch_b > backbone_len:
                bb_idx = 0
            chunk_dataset_bb_added = concatenate_datasets([
                chunk_dataset,
                tokenized_dataset_C.select(range(bb_idx, bb_idx + batch_b)),
            ])

            make_trainer(chunk_dataset_bb_added, output_dir).train()

            # NOTE(review): K uses a different normalizer than G
            # (a_length + b * CHUNK_LENGTH vs a_length), which biases the
            # comparison below. Confirm this is intended.
            bat_norm = torch.norm(params_to_vec_lora(model) - optimal_param) / (args.a_length + batch_b * CHUNK_LENGTH)
            opt_norm = torch.norm(params_G[chunk_num + 1] - optimal_param) / args.a_length

            print(f"Parameter difference norm for K: {bat_norm}")
            print(f"Parameter difference norm for G: {opt_norm}")
            if bat_norm < opt_norm:  # CASE 1: Succeeded to select a good data sample
                success_lst.append(trial)
                print("GOOD DATA")
                success = True
                bat_norm_lst.append(bat_norm)
                opt_norm_lst.append(opt_norm)
                # NOTE(review): bb_idx is not advanced on success, so the next
                # chunk starts with the same backbone batch. Confirm intent.
                break

            print("BAD DATA")
            # Move on to the next backbone batch (the old code used `//`,
            # which reset the index to 0 and retried the same batch).
            bb_idx += batch_b
            with torch.no_grad():
                for name, p in trainable.items():
                    p.copy_(snapshot[name])

        if not success:  # CASE 2: Failed to select a good data sample
            adapter_dataset_name = os.path.splitext(os.path.basename(args.adapter_data_path))[0]
            save_trained_model(model, tokenizer, f"./checkpoint/{adapter_dataset_name}_a{args.a_length}_b{args.b_length_per_batch}_lora_seed{args.seed}")
            success_lst.append(-1)
            print('Failed to select a good data. Training on the data chunk from G.')
            # The fallback Trainer was previously built but never trained.
            make_trainer(chunk_dataset, output_dir).train()

            bat_norm = torch.norm(params_to_vec_lora(model) - optimal_param) / args.a_length
            opt_norm = torch.norm(params_G[chunk_num + 1] - optimal_param) / args.a_length

            bat_norm_lst.append(bat_norm)
            opt_norm_lst.append(opt_norm)

            print(f"Parameter difference norm for K: {bat_norm}")
            print(f"Parameter difference norm for G: {opt_norm}")

        # This chunk's G-only snapshot is no longer read; drop the reference.
        params_G[chunk_num] = None

    return bat_norm_lst, opt_norm_lst, success_lst


def save_results_to_jsonl(bat_norm_lst, opt_norm_lst, success, file_path):
    """Write the run's norm trajectories and selection outcomes as one JSON line.

    Tensor entries in the norm lists are converted to Python numbers with
    ``.item()``; any other entry is written unchanged.
    """
    def as_plain(values):
        plain = []
        for value in values:
            plain.append(value.item() if isinstance(value, torch.Tensor) else value)
        return plain

    payload = {}
    payload["norm_diff_K"] = as_plain(bat_norm_lst)
    payload["norm_diff_G"] = as_plain(opt_norm_lst)
    payload["success"] = success

    with open(file_path, 'w') as handle:
        handle.write(json.dumps(payload) + '\n')
   
def main():
    """Run the experiment end to end.

    1. Load the adapter dataset G (training / held-out split) and backbone C.
    2. Tokenize the three splits; shuffle C with the fixed ``SEED``.
    3. Train on G alone, keeping LoRA snapshots per chunk (``params_G``).
    4. Retrain on G mixed with backbone batches via ``train`` and write the
       norm trajectories and per-chunk selection outcomes to a JSON file.

    NOTE(review): ``args.seed`` only seeds the initial ``set_seed`` call and
    names the output file; the training functions reseed with ``SEED``.
    """
    model_name = 'alpaca-cleaned'

    set_seed(args.seed)

    data_G, data_G_eval = load_adapter_dataset(args.adapter_data_path, adapter_length=args.a_length)
    data_C = load_backbone_dataset('alpaca-cleaned')

    def to_dataset(records):
        # preprocess_function reads the response from an "output" column.
        return Dataset.from_dict({
            "instruction": [item["instruction"] for item in records],
            "input": [item["input"] for item in records],
            "output": [item["answer"] for item in records],
        })

    def tokenize(dataset):
        tokenized = dataset.map(preprocess_function, batched=True)
        return tokenized.remove_columns(["instruction", "input", "output"])

    # Step 1. Tokenize G (training split), its held-out split, and backbone C.
    tokenized_dataset_G = tokenize(to_dataset(data_G))
    tokenized_eval_set = tokenize(to_dataset(data_G_eval))
    tokenized_dataset_C = tokenize(to_dataset(data_C)).shuffle(seed=SEED)

    # Step 2. Train on G alone, snapshotting LoRA parameters after every chunk.
    params_G = train_and_save_adapter(model_name, tokenized_dataset_G, tokenized_eval_set, tokenizer, args)
    print(f'Total {len(params_G)} numbers of set of parameters stored.')
    # Count elements directly instead of concatenating every snapshot into one
    # large temporary tensor.
    print(f"Total numbers of saved parameters: {sum(p.numel() for p in params_G)}")

    # Step 3. Retrain with backbone batches mixed in, then save the results.
    bat_norm_lst, opt_norm_lst, success = train(model_name, tokenized_dataset_C, tokenized_dataset_G, tokenized_eval_set, params_G, args)
    print(success)

    adapter_dataset_name = os.path.splitext(os.path.basename(args.adapter_data_path))[0]
    save_results_to_jsonl(bat_norm_lst, opt_norm_lst, success, f"./{adapter_dataset_name}_a{args.a_length}_b{args.b_length_per_batch}_c{args.chunk_length}_dora_seed{args.seed}.jsonl")

# Script entry point. Note that the CLI arguments were already parsed at
# import time (module-level `args`), before main() runs.
if __name__ == '__main__':
    main()
