import os
from torch.utils.data import Dataset
import transformers
from typing import Dict, Optional, Sequence
from transformers import AutoTokenizer,AutoModelForCausalLM
import copy
import torch
from dataclasses import dataclass, field
from transformers.generation.stopping_criteria import StoppingCriteria,StoppingCriteriaList,STOPPING_CRITERIA_INPUTS_DOCSTRING,add_start_docstrings
import yaml
import json
import sys
import numpy as np
import random
import argparse
from peft import prepare_model_for_kbit_training
from peft import LoraConfig, get_peft_model

# Label id written over positions that must not be supervised: this file uses
# it for prompt tokens and for label padding.
IGNORE_INDEX = -100

# Weights & Biases settings, applied as soon as this module is loaded.
# The API key below is a placeholder string and must be edited by the user.
os.environ.update(
    {
        "WANDB_API_KEY": 'put your WANDB API KEY here',
        "WANDB_MODE": "offline",
    }
)

def _tokenize_fn(
    strings: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
    max_length: int = 4096,
) -> Dict:
    """Tokenize every string on its own and report per-string token counts.

    Args:
        strings: Texts to encode. Each text goes through a separate tokenizer
            call, so it is never batched with another text.
        tokenizer: Callable tokenizer; its result must expose ``input_ids``
            as a tensor whose first row holds the encoded text.
        max_length: Truncation limit handed to the tokenizer. Defaults to the
            previously hard-coded 4096 (not ``tokenizer.model_max_length``).

    Returns:
        A dict with ``input_ids`` and ``labels`` (the *same* list of 1-D id
        tensors) plus ``input_ids_lens`` and ``labels_lens`` (the *same* list
        of integer lengths). Callers that modify labels must copy them first.
    """
    tokenized_list = [
        tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            truncation=True,
            max_length=max_length,
        )
        for text in strings
    ]
    input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
    # Use the row length rather than counting ids that differ from
    # pad_token_id: one text per call leaves nothing to pad against, and
    # counting "non-pad" ids undercounts whenever the pad id is also a real
    # token in the text (e.g. pad aliased to EOS) and fails outright when the
    # tokenizer has no pad token.
    input_ids_lens = labels_lens = [int(ids.shape[0]) for ids in input_ids]
    return dict(
        input_ids=input_ids,
        labels=labels,
        input_ids_lens=input_ids_lens,
        labels_lens=labels_lens,
    )

def preprocess(
    sources: Sequence[str],
    targets: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    """Turn prompt/response pairs into token ids and prompt-masked labels.

    Each source is concatenated with its target and tokenized as one text.
    The sources are tokenized a second time on their own, only to learn how
    many leading tokens belong to the prompt. Labels start as a deep copy of
    the full token ids, and that leading span is overwritten with
    ``IGNORE_INDEX``.

    Args:
        sources: Prompt texts.
        targets: Response texts, paired with ``sources`` by position
            (``zip`` stops at the shorter of the two).
        tokenizer: Tokenizer forwarded to ``_tokenize_fn``.

    Returns:
        A dict with ``input_ids`` and ``labels``, two equally long lists of
        1-D tensors.
    """
    full_texts = [prompt + answer for prompt, answer in zip(sources, targets)]
    full_tokenized = _tokenize_fn(full_texts, tokenizer)
    prompt_tokenized = _tokenize_fn(sources, tokenizer)

    input_ids = full_tokenized["input_ids"]
    # Deep copy so masking the labels leaves input_ids untouched.
    labels = copy.deepcopy(input_ids)
    prompt_lengths = prompt_tokenized["input_ids_lens"]
    for row, prompt_len in zip(labels, prompt_lengths):
        row[:prompt_len] = IGNORE_INDEX
    return {"input_ids": input_ids, "labels": labels}

class SupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning.

    Every example is a mapping with an ``'input'`` prompt and, for training,
    an ``'output'`` completion. Examples are tokenized once at construction
    and kept in memory as lists of 1-D tensors.
    """

    def __init__(self, all_examples: list, tokenizer: transformers.PreTrainedTokenizer=None, is_eval=False, data_size=256):
        """Tokenize ``all_examples`` up front.

        Args:
            all_examples: List of example mappings (see class docstring).
            tokenizer: Tokenizer used to encode the text. It is required as
                soon as ``all_examples`` is non-empty.
            is_eval: When False, ``'output'`` plus the tokenizer's EOS token
                is appended to each prompt and only those target tokens stay
                supervised. When True, only the prompts are tokenized and
                every label is ``IGNORE_INDEX``, so ``'output'`` is not read.
            data_size: Unused; kept so existing call sites remain valid.
        """
        super().__init__()

        sources = [example['input'] for example in all_examples]

        if is_eval:
            # The eval branch used to leave input_ids/labels unset, so
            # len() and indexing raised AttributeError. Store the prompts
            # with fully masked labels instead.
            prompts = _tokenize_fn(sources, tokenizer)
            self.input_ids = prompts["input_ids"]
            self.labels = copy.deepcopy(self.input_ids)
            for label in self.labels:
                label[:] = IGNORE_INDEX
        else:
            targets = [f"{example['output']}{tokenizer.eos_token}" for example in all_examples]
            data_dict = preprocess(sources, targets, tokenizer)
            self.input_ids = data_dict["input_ids"]
            self.labels = data_dict["labels"]

    def __len__(self):
        """Return the number of tokenized examples."""
        return len(self.input_ids)

    def __getitem__(self, i):
        """Return the ``input_ids`` and ``labels`` tensors of example ``i``."""
        return dict(input_ids=self.input_ids[i], labels=self.labels[i])

@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning.

    Pads the per-example ``input_ids`` and ``labels`` tensors to the longest
    example in the batch and builds the matching attention mask.
    """

    # Tokenizer whose pad_token_id fills the padded input_ids positions.
    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        """Batch ``instances`` into padded tensors.

        Args:
            instances: Mappings that each hold 1-D ``input_ids`` and
                ``labels`` tensors.

        Returns:
            A dict with ``input_ids`` (padded with the tokenizer's pad id),
            ``labels`` (padded with ``IGNORE_INDEX``) and a boolean
            ``attention_mask`` that is True on every original token.
        """
        input_ids = [instance["input_ids"] for instance in instances]
        labels = [instance["labels"] for instance in instances]
        # Record the true lengths before padding. Deriving the mask from
        # "id != pad id" would also hide real tokens that equal the pad id,
        # such as a closing EOS when the pad token is aliased to EOS.
        lengths = torch.tensor([ids.shape[0] for ids in input_ids], dtype=torch.long)

        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)

        # The padding call above appends padding after the original tokens,
        # so a position is real exactly when its index is below that row's
        # length.
        positions = torch.arange(input_ids.shape[1], device=input_ids.device)
        attention_mask = positions.unsqueeze(0) < lengths.to(input_ids.device).unsqueeze(1)
        return dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=attention_mask,
        )

def print_trainable_parameters(model):
    """Print trainable and total parameter counts for ``model``.

    Walks ``model.named_parameters()`` once, adds up ``numel()`` over all
    parameters and over those with ``requires_grad`` set, and prints both
    totals together with the trainable percentage.
    """
    sizes = [
        (weight.numel(), weight.requires_grad)
        for _, weight in model.named_parameters()
    ]
    all_param = sum(count for count, _ in sizes)
    trainable_params = sum(count for count, needs_grad in sizes if needs_grad)
    share = 100 * trainable_params / all_param
    summary = " || ".join(
        (
            f"trainable params: {trainable_params}",
            f"all params: {all_param}",
            f"trainable%: {share}",
        )
    )
    print(summary)

def get_data_from_traj(traj):
    """Cut a trajectory transcript into (input, output) training pairs.

    The transcript is stripped and split into lines. Starting at line index
    4 and stepping by 3, each visited line yields one example:

    * ``input``: all preceding lines, then the visited line's text before
      its first ``':'`` followed by ``':'``.
    * ``output``: the visited line's text after its last ``':'``, stripped.
      If the visited line does not end with ``GOOD.`` or ``BAD.``, the
      following lines are appended (newline-joined) up to and including the
      first one that does, or to the end of the transcript.

    Text between the first and last ``':'`` of a visited line appears in
    neither field, and the stride of 3 is not adjusted after a multi-line
    output.

    Returns:
        A list of ``{'input': str, 'output': str}`` dicts.
    """
    verdict_endings = ('GOOD.', 'BAD.')
    lines = traj.strip().split('\n')
    samples = []
    for idx in range(4, len(lines), 3):
        current = lines[idx]
        segments = current.split(':')
        answer_parts = [segments[-1].strip()]
        if not current.endswith(verdict_endings):
            # The verdict continues on later lines; collect them until one
            # closes with GOOD. or BAD.
            for follow_up in lines[idx + 1:]:
                answer_parts.append(follow_up)
                if follow_up.endswith(verdict_endings):
                    break
        prompt_lines = lines[:idx] + [segments[0] + ':']
        samples.append({
            'input': '\n'.join(prompt_lines),
            'output': '\n'.join(answer_parts),
        })
    return samples


if __name__=='__main__':
    # Script entry point: parse the CLI options, load the base model and
    # tokenizer, turn three trajectory JSON files into (input, output)
    # examples, wrap the model with LoRA adapters and run transformers.Trainer.
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name_or_path',required=True)
    parser.add_argument('--cache_dir')
    # --project and --base_model_name are only used to build the run name and
    # output directory further down.
    parser.add_argument('--project',required=True)
    parser.add_argument('--base_model_name',required=True)
    parser.add_argument('--lr',type=float,default=2.5e-5)
    args = parser.parse_args()

    # NOTE(review): the weights are requested in float16 here while the
    # TrainingArguments below set bf16=True -- confirm the mix is intended.
    # cache_dir is passed to the model load only, not to the tokenizer.
    model = transformers.AutoModelForCausalLM.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir,torch_dtype=torch.float16,device_map='auto')
    tokenizer = transformers.AutoTokenizer.from_pretrained(args.model_name_or_path,padding_side='left')
    # Fall back to the EOS id when the tokenizer ships without a pad token.
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id

    # The three files are read relative to the current working directory.
    # Each must hold a JSON object: it is iterated with .items() below, the
    # keys are not used, and every value is passed to get_data_from_traj.
    with open('succ_examples.json','r') as f:
        succ_examples = json.load(f)
    with open('succ_examples_2.json','r') as f:
        succ_examples_2 = json.load(f)
    with open('fail_examples.json','r') as f:
        fail_examples = json.load(f)

    print('Number of training trajectories',len(succ_examples)+len(succ_examples_2)+len(fail_examples))
    # One trajectory expands into several examples (see get_data_from_traj).
    train_examples = list()
    for task_type,traj in succ_examples.items():
        train_examples.extend(get_data_from_traj(traj))
    for task_type,traj in succ_examples_2.items():
        train_examples.extend(get_data_from_traj(traj))
    for task_type,traj in fail_examples.items():
        train_examples.extend(get_data_from_traj(traj))
    print('Number of training examples',len(train_examples))

    # In-place shuffle; no random seed is set in this block.
    random.shuffle(train_examples)

    # NOTE(review): the eval dataset is built from the very same examples as
    # the training dataset, so the reported eval loss is a training-set loss
    # and the data is tokenized twice -- confirm this is intended.
    train_dataset = SupervisedDataset(train_examples,tokenizer=tokenizer)
    eval_dataset = SupervisedDataset(train_examples,tokenizer=tokenizer)
    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    data_module = dict(train_dataset=train_dataset, eval_dataset=eval_dataset, data_collator=data_collator)


    # NOTE(review): no quantization arguments are passed to from_pretrained
    # above; confirm prepare_model_for_kbit_training is still wanted here.
    model.gradient_checkpointing_enable()
    model = prepare_model_for_kbit_training(model)

    # LoRA adapter settings: rank 32, alpha 64, applied to the attention and
    # MLP projection modules listed below plus lm_head; bias terms untouched.
    config = LoraConfig(
        r=32,
        lora_alpha=64,
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
            "lm_head",
        ],
        bias="none",
        lora_dropout=0.05,  # Conventional
        task_type="CAUSAL_LM",
    )

    model = get_peft_model(model, config)
    print_trainable_parameters(model)

    # Example values:
    #   project = "critic-finetune-succ12-fail6-good-bad"
    #   base_model_name = "codellama7b"
    project = args.project
    base_model_name = args.base_model_name
    run_name = base_model_name + "-" + project
    output_dir = "./" + run_name

    # Unconditionally makes EOS the pad token from here on. The datasets above
    # are already tokenized; the collator reads tokenizer.pad_token_id each
    # time it builds a batch.
    tokenizer.pad_token = tokenizer.eos_token

    trainer = transformers.Trainer(
        model=model,
        tokenizer=tokenizer,
        args=transformers.TrainingArguments(
            output_dir=output_dir,
            warmup_steps=1,
            per_device_train_batch_size=2,
            gradient_accumulation_steps=1,
            gradient_checkpointing=True,
            max_steps=1000,
            learning_rate=args.lr, # Want a small lr for finetuning
            bf16=True,
            # optim="paged_adamw_8bit",
            optim='adamw_torch',
            logging_dir="./logs",        # Directory for storing logs
            logging_steps=100,           # Same interval as saving and evaluation below
            save_strategy="steps",       # Save checkpoints on a step interval
            save_steps=100,                # Save a checkpoint every 100 steps
            evaluation_strategy="steps", # Evaluate on a step interval
            eval_steps=100,               # Evaluate every 100 steps
            do_eval=False,                # NOTE(review): evaluation_strategy="steps" is set above; per the TrainingArguments docs that presumably turns evaluation on regardless -- confirm
            # report_to="wandb",           # Uncomment to report to Weights & Biases
            # run_name=f"{run_name}-{datetime.now().strftime('%Y-%m-%d-%H-%M')}"          # Name of the W&B run (optional)
        ),
        # data_collator comes from data_module; the stock alternative would be
        # transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False).
        **data_module
    )

    model.config.use_cache = False  # silence the warnings. Please re-enable for inference!
    trainer.train()