import argparse

import pandas as pd
import torch
import os
from torch import nn
import logging
from torch.utils.data import Dataset, random_split
from transformers import GPT2Tokenizer, TrainingArguments, Trainer, GPTNeoForCausalLM, AutoTokenizer
from tqdm import tqdm, trange

import torch.distributed


# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Seed PyTorch's global RNG with a fixed value at import time.
torch.manual_seed(22)


class ClassificationData(Dataset):
    """Fixed-length (source, target) examples for fine-tuning a causal LM.

    Records are read from ``<data_dir>/<type_path>.jsonl``.  Each example is
    encoded as ``source_ids + IO_SEP + target_ids + EOS`` and right-padded to
    ``max_source_length + max_target_length + 2`` positions.

    ``__getitem__`` returns a ``(input_ids, attention_mask, labels)`` tuple of
    ``torch.long`` tensors.  Labels have the same length as the inputs (no
    manual shift is applied here); the source span and all padding positions
    are set to ``IGNORE_INDEX`` so only the separator, target and EOS tokens
    are supervised.

    Note: constructing the dataset adds ``IO_SEP`` (and ``PAD`` when the
    tokenizer has no pad token) to the tokenizer, so the caller is expected
    to resize the model's embeddings afterwards.
    """

    IO_SEP = "|||"
    PAD = "<pad>"
    # Label value skipped by the loss (PyTorch CrossEntropyLoss default
    # ``ignore_index``; also the convention for Hugging Face ``labels``).
    IGNORE_INDEX = -100

    def __init__(self, data_dir, type_path, tokenizer, max_source_length: int = 512,
                 max_target_length: int = 128, src_key="question", tgt_key="answer"):
        """Load the JSON-lines file and prepare tokenizer-derived constants.

        Args:
            data_dir: Directory containing ``<type_path>.jsonl``.
            type_path: File stem of the split to load (e.g. ``"train"``).
            tokenizer: Callable tokenizer; it is mutated by ``add_tokens``.
            max_source_length: Truncation length (in tokens) for the source.
            max_target_length: Truncation length (in tokens) for the target.
            src_key: Column name holding the source text.
            tgt_key: Column name holding the target text.
        """
        # Not used by this class; kept so existing attribute access still works.
        self.input_ids = []
        self.attn_masks = []
        self.labels = []

        self.input_file_path = os.path.join(data_dir, f"{type_path}.jsonl")
        data = pd.read_json(self.input_file_path, orient="records", lines=True)
        self.source_lines = data[src_key].apply(lambda x: x.strip()).tolist()
        self.target_lines = data[tgt_key].apply(lambda x: x.strip()).tolist()

        self.max_source_length = max_source_length
        self.max_target_length = max_target_length
        # +2 leaves room for the separator token and the EOS token.
        self.max_input_length = max_source_length + max_target_length + 2

        self.tokenizer = tokenizer
        tokenizer.add_tokens([self.IO_SEP], special_tokens=True)

        # Keep ids as integer tensors so concatenation never goes through float.
        self.io_sep_token_id = torch.tensor(
            self.tokenizer(self.IO_SEP)["input_ids"], dtype=torch.long
        )
        self.eos_token_id_tensor = torch.tensor(
            [self.tokenizer.eos_token_id], dtype=torch.long
        )

        # Compare against None explicitly: a pad token id of 0 is valid and
        # must not trigger registration of a second pad token.
        if self.tokenizer.pad_token_id is not None:
            self.pad_token_id = self.tokenizer.pad_token_id
        else:
            tokenizer.add_tokens([self.PAD], special_tokens=True)
            self.pad_token_id = self.tokenizer(self.PAD)["input_ids"][0]
        print("PAD TOKEN ID:", self.pad_token_id)
        print("=====================================")

    def _encode(self, text, max_length):
        """Tokenize ``text`` (truncated to ``max_length``) into a 1-D long tensor."""
        encoded = self.tokenizer(
            text,
            padding="do_not_pad",
            truncation=True,
            max_length=max_length,
            return_tensors="pt",
        )
        return encoded["input_ids"].squeeze(0).long()

    def __getitem__(self, index):
        """Return ``(input_ids, attention_mask, labels)`` for example ``index``."""
        src_line = self.source_lines[index].rstrip("\n")
        tgt_line = self.target_lines[index].rstrip("\n")

        source_ids = self._encode(src_line, self.max_source_length)
        target_ids = self._encode(tgt_line, self.max_target_length)

        x = torch.cat(
            [
                source_ids,
                self.io_sep_token_id.long(),
                target_ids,
                self.eos_token_id_tensor.long(),
            ],
            dim=0,
        )
        input_span = len(source_ids)

        # This is a seq2seq-style setup rather than plain language modeling:
        # only the tokens after the source span (separator, target, EOS)
        # carry a label; the source span is ignored by the loss.
        #   x: A    B    C    |  D  E  <eos>
        #   y: -100 -100 -100 |  D  E  <eos>
        y = x.clone()
        y[:input_span] = self.IGNORE_INDEX

        attention_mask = torch.ones(len(x), dtype=torch.long)

        # Right-pad everything to the fixed example length.
        padding_length = self.max_input_length - len(x)
        x = torch.nn.functional.pad(x, (0, padding_length), value=self.pad_token_id)
        attention_mask = torch.nn.functional.pad(attention_mask, (0, padding_length))
        # Padding positions must not contribute to the loss, so they get the
        # ignore index rather than the pad token id.
        y = torch.nn.functional.pad(y, (0, padding_length), value=self.IGNORE_INDEX)

        return x.long(), attention_mask.long(), y.long()

    def __len__(self):
        """Return the number of examples."""
        return len(self.source_lines)


def _collate_batch(examples):
    """Stack ``(input_ids, attention_mask, labels)`` tuples into a model-input dict."""
    return {
        'input_ids': torch.stack([example[0] for example in examples]),
        'attention_mask': torch.stack([example[1] for example in examples]),
        'labels': torch.stack([example[2] for example in examples]),
    }


def train(dataset, model, output_path, log_path):
    """Fine-tune ``model`` on an 80/20 random train/validation split of ``dataset``.

    Args:
        dataset: Dataset whose items are ``(input_ids, attention_mask, labels)``
            tensor tuples.
        model: Model handed to the Hugging Face ``Trainer``.
        output_path: Directory for checkpoints and the final saved model.
        log_path: Directory passed to ``TrainingArguments`` as ``logging_dir``.
    """
    # Split the dataset into training and validation subsets.
    total_num_dataset = len(dataset)
    train_size = int(0.8 * total_num_dataset)
    val_size = total_num_dataset - train_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

    training_args = TrainingArguments(
        do_train=True,
        do_eval=True,
        output_dir=output_path,
        overwrite_output_dir=True,
        num_train_epochs=1,
        # fp16 mixed precision needs a CUDA device; the caller falls back to
        # CPU when none is present, so only enable it when CUDA is available.
        fp16=torch.cuda.is_available(),
        logging_steps=100,
        save_steps=1000,
        per_device_train_batch_size=32,
        per_device_eval_batch_size=32,
        warmup_steps=100,
        weight_decay=0.01,
        logging_dir=log_path,
        logging_strategy='steps',
        evaluation_strategy="steps",
        eval_steps=1000,
        report_to="wandb",
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        data_collator=_collate_batch,
    )
    trainer.train()
    # Checkpoints are only written every ``save_steps`` steps, so a run that
    # ends between checkpoints would otherwise discard its final weights.
    trainer.save_model(output_path)


def main(args):
    """Load GPT-Neo and its tokenizer, build the training dataset, and train.

    ``args`` is the parsed command-line namespace; it supplies the checkpoint
    to load, the dataset directory, the length limits, the JSON keys, and the
    save/log directories.
    """
    pretrained = args.model_gpt_neo_checkpoint_path
    tokenizer = AutoTokenizer.from_pretrained(pretrained)
    model = GPTNeoForCausalLM.from_pretrained(pretrained)

    # Use the GPU when one is available, otherwise stay on the CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    model.to(device)

    dataset = ClassificationData(
        args.dataset_path,
        "train",
        tokenizer,
        max_source_length=args.max_source_length,
        max_target_length=args.max_target_length,
        src_key=args.src_key,
        tgt_key=args.tgt_key,
    )

    # The dataset registers extra tokens on its tokenizer, so the embedding
    # matrix has to be resized to the tokenizer's new length.
    extended_tokenizer = dataset.tokenizer
    model.resize_token_embeddings(len(extended_tokenizer))

    # Persist the extended tokenizer so it can be reloaded with the model.
    extended_tokenizer.save_pretrained(args.model_checkpoint_save_path)

    train(
        dataset,
        model,
        args.model_checkpoint_save_path,
        args.model_gpt_neo_log_path,
    )



if __name__ == "__main__":
    # Command-line entry point: parse the options and hand them to main().
    parser = argparse.ArgumentParser(description="GPT-Neo Training Script")
    parser.add_argument("--model_gpt_neo_checkpoint_path", type=str, default="EleutherAI/gpt-neo-1.3B",
                        help="Name or path of the pretrained GPT-Neo checkpoint to load")
    parser.add_argument("--model_checkpoint_save_path", type=str, default="save/",
                        help="Directory where the tokenizer and model checkpoints are saved")
    parser.add_argument("--model_gpt_neo_log_path", type=str, default="log/",
                        help="Directory where training logs are written")
    # The dataset loader joins this directory with "train.jsonl".
    parser.add_argument("--dataset_path", type=str, default="",
                        help="Directory containing the train.jsonl dataset file")
    parser.add_argument("--max_source_length", type=int, default=64,
                        help="Maximum length (in tokens) of the source/input sequence")
    parser.add_argument("--max_target_length", type=int, default=10,
                        help="Maximum length (in tokens) of the target/output sequence")
    parser.add_argument("--src_key", type=str, default="question",
                        help="JSON key holding the source text")
    parser.add_argument("--tgt_key", type=str, default="answer",
                        help="JSON key holding the target text")

    args = parser.parse_args()
    main(args)