from peft import get_peft_config, get_peft_model
from transformers import set_seed, DataCollatorForLanguageModeling, Trainer, TrainingArguments
import json
from loguru import logger
from sklearn.model_selection import train_test_split
import torch
from model_loader import load_model, load_tokenizer
from base import BaseClass
import torch
from model_loader import load_tokenizer
from loguru import logger
import torch
from tqdm import tqdm
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    """Torch dataset of pre-tokenized texts, with optional integer labels.

    Rows containing any missing value are dropped before tokenization. Each
    text from the ``"text"`` column is tokenized once up front; if a
    ``"label"`` column exists, its values are returned alongside the ids.
    """

    def __init__(self, tokenizer, df, max_tokens=128, min_tokens=1, random_cutoff=False):
        """Tokenize every text in ``df``.

        Args:
            tokenizer: Callable returning an object with an ``input_ids``
                tensor when called with ``return_tensors="pt"``.
            df: DataFrame with a ``"text"`` column and optionally ``"label"``.
            max_tokens: Truncation length passed to the tokenizer.
            min_tokens: Lower bound for the random cutoff length (clamped to
                the sample's own length).
            random_cutoff: When True, each access returns a random-length
                prefix of the token ids.
        """
        super().__init__()
        clean = df.dropna()
        texts = clean["text"].tolist()

        encoded = []
        for text in tqdm(texts):
            ids = tokenizer(text, return_tensors="pt", truncation=True, max_length=max_tokens).input_ids
            # Flatten the (1, seq_len) batch produced by the tokenizer call.
            encoded.append(ids.view(-1))
        self.tokenized_dataset = encoded

        self.df = clean
        self.has_labels = "label" in clean.columns
        self.min_tokens = min_tokens
        self.labels = clean["label"].values if self.has_labels else None
        self.random_cutoff = random_cutoff

    def __len__(self):
        """Number of tokenized samples."""
        return len(self.tokenized_dataset)

    def __getitem__(self, idx):
        """Return ``{"input_ids": ...}`` plus ``"labels"`` when labels exist.

        With ``random_cutoff`` enabled the ids are truncated to a length drawn
        uniformly from ``[min(len, min_tokens), len]``.
        """
        ids = self.tokenized_dataset[idx]
        end = len(ids)
        if self.random_cutoff:
            low = min(end, self.min_tokens)
            end = torch.randint(low, end + 1, (1,)).item()

        sample = {"input_ids": ids[:end]}
        if self.has_labels:
            sample["labels"] = torch.tensor([self.labels[idx]], dtype=torch.long)
        return sample


class DatasetProcessor(BaseClass):
    """Turn a text DataFrame into a ``CustomDataset`` for a given model."""

    def __init__(self, max_tokens=128, random_cutoff=False, model_name=None, tokenizer=None, **kwargs):
        """Forward all settings to ``BaseClass``.

        Args:
            max_tokens: Truncation length later handed to ``CustomDataset``
                (read back as ``self.max_tokens`` — presumably stored by
                BaseClass; confirm).
            random_cutoff: Whether ``CustomDataset`` shortens samples randomly.
            model_name: Initial model name.
            tokenizer: Forwarded to ``BaseClass``; ``prepare_dataset`` does not
                use it and always loads the tokenizer by model name.
            **kwargs: Extra settings forwarded to ``BaseClass``.
        """
        settings = dict(
            max_tokens=max_tokens,
            model_name=model_name,
            tokenizer=tokenizer,
            random_cutoff=random_cutoff,
        )
        super().__init__(**settings, **kwargs)

    def set_model(self, model_name):
        """Record ``model_name`` as the processor's current model."""
        self.model_name = model_name

    def prepare_dataset(self, dataset, model_name):
        """Tokenize ``dataset`` with the tokenizer of ``model_name``.

        Also records ``model_name`` via ``set_model``.

        Returns:
            A ``CustomDataset`` built from ``dataset``.
        """
        logger.debug(f"Preparing dataset with {self} and model {model_name}")
        self.set_model(model_name)
        model_tokenizer = load_tokenizer(model_name)
        return CustomDataset(
            model_tokenizer,
            dataset,
            max_tokens=self.max_tokens,
            random_cutoff=self.random_cutoff,
        )

    def prepare_sample(self, sample, tokenizer, **kwargs):
        """Tokenize a single sample into PyTorch tensors; extra kwargs are ignored."""
        encoded = tokenizer(sample, return_tensors="pt")
        return encoded

class Finetune(BaseClass):
    """Fine-tune a language model with a Hugging Face ``Trainer``.

    Settings are read from a JSON config file, optionally overridden by
    keyword arguments, and copied onto the instance as attributes
    (``self.fp16``, ``self.use_lora``, ``self.output_dir``, ...).
    """

    def __init__(self, preprocessor=None, config_file="../configs/config_finetune.json", **kwargs):
        """Load the configuration and set up precision, DeepSpeed and LoRA.

        Args:
            preprocessor: Object exposing ``prepare_dataset(dataset, model_name)``.
                A fresh ``DatasetProcessor()`` is created when ``None``.
            config_file: Path to the JSON fine-tuning configuration.
            **kwargs: Entries overriding keys of the loaded configuration.
        """
        # Build the default per instance: a DatasetProcessor() in the signature
        # would be created once at class definition and shared (and mutated
        # through set_model) by every Finetune instance.
        if preprocessor is None:
            preprocessor = DatasetProcessor()

        self.config = self._load_json(config_file)
        self.config.update(kwargs)

        # Expose every config entry as an instance attribute.
        self.__dict__.update(self.config)
        self.model = None

        # bf16 wins over fp16 when both flags are set.
        self.dtype = torch.float32
        if self.fp16:
            self.dtype = torch.float16
        if self.bf16:
            self.dtype = torch.bfloat16

        if not self.use_deepspeed:
            deepspeed_config = None
            self.config["deepspeed_config_file"] = None
            # Clear the attribute too, so TrainingArguments never receives a
            # DeepSpeed file when use_deepspeed is off, whatever BaseClass does
            # with the config.
            self.deepspeed_config_file = None
        else:
            deepspeed_config = self._load_json(self.deepspeed_config_file)

        self.config["model_name"] = None
        self.config["deepspeed_config"] = deepspeed_config
        super().__init__(**self.config, preprocessor=preprocessor)
        self.lora_config_peft = get_peft_config(self.lora_config)

    @staticmethod
    def _load_json(path):
        """Parse a JSON file, closing the file handle afterwards."""
        with open(path, "r", encoding="utf-8") as handle:
            return json.load(handle)

    def load_model(self, model_name, model=None):
        """Set ``self.model``, optionally wrapping it with LoRA adapters.

        Args:
            model_name: Name passed to ``load_model`` when ``model`` is None.
            model: Pre-built model to use instead of loading one.
        """
        if model is not None:
            self.model = model
        else:
            self.model = load_model(model_name, dtype=self.dtype)

        if self.use_lora:
            self.model = get_peft_model(self.model, self.lora_config_peft)

    def _resolve_max_length(self):
        """Return the context length from the model config, else ``max_length_default``."""
        for length_setting in ("n_positions", "max_position_embeddings", "seq_length"):
            max_length = getattr(self.model.config, length_setting, None)
            if max_length:
                logger.info(f"Found max length: {max_length}")
                return max_length
        max_length = self.max_length_default
        logger.info(f"Using default max length: {max_length}")
        return max_length

    def _split_dataset(self, dataset):
        """Split into (train, eval) subsets; return ``(dataset, None)`` if len <= 1."""
        if len(dataset) <= 1:
            return dataset, None
        logger.debug("Splitting dataset")
        # Split indices rather than the dataset itself. train_test_split on a
        # torch Dataset indexes every sample once into a plain list, which
        # would freeze CustomDataset's per-access random_cutoff for the whole
        # run. The index split (same random_state) yields the same partition.
        train_idx, test_idx = train_test_split(
            list(range(len(dataset))), test_size=self.test_split_size, random_state=42
        )
        return (
            torch.utils.data.Subset(dataset, train_idx),
            torch.utils.data.Subset(dataset, test_idx),
        )

    def finetune(self, model_name, dataset, data_collator=None, model=None, trainer_class=Trainer, **kwargs):
        """Preprocess ``dataset`` and train the model on it.

        Args:
            model_name: Model/tokenizer name.
            dataset: Raw dataset handed to ``self.preprocessor.prepare_dataset``.
            data_collator: Collator for the Trainer; defaults to
                ``DataCollatorForLanguageModeling(mlm=False)``.
            model: Optional pre-built model (see ``load_model``).
            trainer_class: Trainer class to instantiate.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            The trained model (PEFT-wrapped when ``use_lora`` is set).
        """
        logger.info(f"Finetuning model with {self} and dataset with size {len(dataset)}")
        self.model_name = model_name
        dataset = self.preprocessor.prepare_dataset(dataset, self.model_name)
        set_seed(42)
        # With reload enabled, a model loaded by an earlier call is reused.
        if not self.reload or self.model is None:
            logger.debug("Loading model")
            self.load_model(model_name, model=model)

        tokenizer = load_tokenizer(model_name)
        self.model.config.pad_token_id = tokenizer.pad_token_id

        if data_collator is None:
            # mlm=False selects causal-LM collation instead of masked-LM.
            data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

        # NOTE(review): the resolved length is only logged; it is not passed to
        # the tokenizer or the Trainer.
        self._resolve_max_length()

        train_dataset, test_dataset = self._split_dataset(dataset)
        # Step-based evaluation without an eval set makes the Trainer fail, so
        # evaluation is disabled when the dataset was too small to split.
        evaluation_strategy = "steps" if test_dataset is not None else "no"

        training_args = TrainingArguments(
            output_dir=self.output_dir,                                       # output directory
            num_train_epochs=self.num_train_epochs,                           # total number of training epochs
            per_device_train_batch_size=self.per_device_train_batch_size,     # batch size per device during training
            per_device_eval_batch_size=self.per_device_eval_batch_size,       # batch size for evaluation
            warmup_ratio=self.warmup_ratio,                                   # fraction of total steps used for LR warmup
            weight_decay=self.weight_decay,                                   # strength of weight decay
            logging_dir=self.logging_dir,                                     # directory for storing logs
            logging_steps=self.logging_steps,
            learning_rate=self.learning_rate,
            save_steps=self.save_steps,
            deepspeed=self.deepspeed_config_file,
            save_total_limit=self.save_total_limit,
            eval_steps=self.eval_steps,
            evaluation_strategy=evaluation_strategy,
            save_strategy="steps",
            fp16=self.fp16,
            bf16=self.bf16,
        )

        trainer = trainer_class(
            model=self.model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=test_dataset,
            data_collator=data_collator,
        )

        logger.info("Starting Training")
        trainer.train()

        return self.model
