import pandas as pd
from src.utils.data_generation.pretrained.data.datasets import *
from transformers import Trainer, TrainingArguments
from tqdm import tqdm
import torch
from transformers import AdamW, get_linear_schedule_with_warmup
from datasets import load_metric

class trainer:
    """Fine-tune, save/load, run and evaluate a seq2seq keyword-to-text model.

    The class wraps a tokenizer and a model (documented by the original
    author as a T5 tokenizer / T5 model) together with the dataloaders built
    from the supplied dataframes, and drives a hand-written training loop.
    """

    # Lookup tables used by ``evaluate`` to build its report:
    # (report section, key prefix, key in the metric result)
    _ROUGE_SECTIONS = (
        ("Rouge 1", "Rouge_1", "rouge1"),
        ("Rouge 2", "Rouge_2", "rouge2"),
        ("Rouge L", "Rouge_L", "rougeL"),
        ("rougeLsum", "rougeLsum", "rougeLsum"),
    )
    # (label used in the report key, attribute on the aggregate score)
    _ROUGE_BOUNDS = (("Low", "low"), ("Mid", "mid"), ("High", "high"))
    # (label used in the report key, attribute on the individual score)
    _ROUGE_STATS = (
        ("Precision", "precision"),
        ("recall", "recall"),
        ("F1", "fmeasure"),
    )

    def __init__(
        self,
        tokenizer,
        model,
        train_df: pd.DataFrame,
        test_df: pd.DataFrame,
        val_df: pd.DataFrame,
        source_max_token_len: int = 512,
        target_max_token_len: int = 512,
        batch_size: int = 8,
        max_epochs: int = 5,
        outputdir: str = "outputs",
        split: float = 0.1
        ):
        """
        Store the model/tokenizer and build the dataloaders.

        Args:
            tokenizer : tokenizer paired with ``model`` (T5 tokenizer).
            model : model to fine-tune (T5 model).
            train_df (pd.DataFrame): training dataframe. Per the original
                documentation it must have 2 columns --> "keywords" and "text".
                May be ``None``; no dataloaders are built in that case and
                ``train()`` cannot be used.
            test_df (pd.DataFrame): test dataframe.
            val_df (pd.DataFrame): validation dataframe.
            source_max_token_len (int, optional): max token length of source text. Defaults to 512.
            target_max_token_len (int, optional): max token length of target text. Defaults to 512.
            batch_size (int, optional): batch size. Defaults to 8.
            max_epochs (int, optional): max number of epochs. Defaults to 5.
            outputdir (str, optional): output directory to save model checkpoints. Defaults to "outputs".
            split (float, optional): forwarded unchanged to ``get_data_loaders``. Defaults to 0.1.
        """
        self.tokenizer = tokenizer
        self.model = model
        self.source_max_token_len = source_max_token_len
        self.target_max_token_len = target_max_token_len
        self.max_epoch = max_epochs
        self.train_df = train_df
        self.test_df = test_df
        self.val_df = val_df
        self.batch_size = batch_size
        self.split = split
        self.output_dir = outputdir
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        # Always define the dataloader attributes so later code can test them
        # instead of hitting an AttributeError when no training data was given.
        self.train_dataloader = None
        self.test_dataloader = None
        self.val_dataloader = None
        if train_df is not None:
            # get_data_loaders is not defined in this file; presumably it is
            # provided by the wildcard dataset import at the top of the module.
            self.train_dataloader, self.test_dataloader, self.val_dataloader = get_data_loaders(
                self.train_df,
                self.test_df,
                self.val_df,
                self.tokenizer,
                self.source_max_token_len,
                self.target_max_token_len,
                self.batch_size,
                self.split,
            )

    def train(self):
        """
        Train the model on the training dataloader.

        Returns:
            The trained model (also stored on ``self.model``).

        Raises:
            ValueError: if the trainer was constructed without training data.
        """
        if self.train_dataloader is None:
            raise ValueError(
                "train() requires training data: construct the trainer with a non-None train_df"
            )

        # TrainingArguments is used purely as a hyper-parameter container: the
        # loop below only reads num_train_epochs, weight_decay, learning_rate,
        # adam_epsilon and warmup_steps. The remaining settings have no effect
        # here because no transformers.Trainer is instantiated.
        args = TrainingArguments(
            output_dir=self.output_dir,      # output directory
            num_train_epochs=self.max_epoch,              # total number of training epochs
            per_device_train_batch_size=1,   # batch size per device during training, can increase if memory allows
            per_device_eval_batch_size=1,    # batch size for evaluation, can increase if memory allows
            save_steps=500,                  # number of updates steps before checkpoint saves
            save_total_limit=5,              # limit the total amount of checkpoints and deletes the older checkpoints
            evaluation_strategy='steps',     # evaluation strategy to adopt during training
            eval_steps=100,                  # number of update steps before evaluation
            warmup_steps=100,                # number of warmup steps for learning rate scheduler
            weight_decay=0.01,               # strength of weight decay
            logging_dir='./logs',            # directory for storing logs
            logging_steps=10,
            learning_rate=5e-5
        )

        model = self.model
        model.to(self.device)

        num_epochs = int(args.num_train_epochs)
        # One scheduler step is taken per batch, so the schedule must span
        # (batches per epoch) * (epochs). Dividing instead would drive the
        # learning rate to zero long before training finishes.
        t_total = len(self.train_dataloader) * num_epochs

        # Biases and LayerNorm weights are excluded from weight decay.
        no_decay = ["bias", "LayerNorm.weight"]
        decay_params, no_decay_params = [], []
        for name, param in model.named_parameters():
            if any(nd in name for nd in no_decay):
                no_decay_params.append(param)
            else:
                decay_params.append(param)
        optimizer_grouped_parameters = [
            {"params": decay_params, "weight_decay": args.weight_decay},
            {"params": no_decay_params, "weight_decay": 0.0},
        ]
        optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
        scheduler = get_linear_schedule_with_warmup(
            optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
        )

        for _ in tqdm(range(num_epochs)):
            model.train()
            for batch in tqdm(self.train_dataloader):
                loss = self.training_step(batch=batch, model=model)
                loss.backward()
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()

        self.model = model
        return model

    def _compute_loss(self, batch, model):
        """Move one batch to ``self.device``, run ``model`` on it and return its loss.

        ``batch`` must provide the keys "keywords_input_ids",
        "keywords_attention_mask", "labels" and "labels_attention_mask".
        """
        input_ids = batch["keywords_input_ids"].to(self.device)
        attention_mask = batch["keywords_attention_mask"].to(self.device)
        labels = batch["labels"].to(self.device)
        labels_attention_mask = batch["labels_attention_mask"].to(self.device)

        return model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=labels_attention_mask,
            labels=labels,
        )['loss']

    def training_step(self, batch, model):
        """Return the training loss of ``model`` on ``batch`` (gradients enabled)."""
        return self._compute_loss(batch, model)

    def validation_step(self, batch, batch_size):
        """Return the loss of ``self.model`` on a validation batch.

        Gradients are disabled. ``batch_size`` is accepted for interface
        compatibility and is not used. The model's train/eval mode is left
        untouched; callers should switch to eval mode if dropout must be off.
        """
        with torch.no_grad():
            return self._compute_loss(batch, self.model)

    def test_step(self, batch, batch_size):
        """Return the loss of ``self.model`` on a test batch.

        Gradients are disabled. ``batch_size`` is accepted for interface
        compatibility and is not used. The model's train/eval mode is left
        untouched; callers should switch to eval mode if dropout must be off.
        """
        with torch.no_grad():
            return self._compute_loss(batch, self.model)

    def configure_optimizers(self):
        """Return an AdamW optimizer over the wrapped model's parameters (lr=0.0001)."""
        # The parameters live on the wrapped model, not on this plain class.
        return AdamW(self.model.parameters(), lr=0.0001)

    @staticmethod
    def _select_device(use_gpu: bool):
        """Return the CUDA device if ``use_gpu`` is set, else the CPU device.

        Raises:
            RuntimeError: if ``use_gpu`` is True but CUDA is not available.
        """
        if not use_gpu:
            return torch.device("cpu")
        if not torch.cuda.is_available():
            raise RuntimeError(
                "exception ---> no gpu found. set use_gpu=False, to use CPU"
            )
        return torch.device("cuda")

    def load_model(self, model_dir: str = "outputs", use_gpu: bool = False):
        """
        loads a checkpoint for inferencing/prediction
        Args:
            model_dir (str, optional): path to model directory. Defaults to "outputs".
            use_gpu (bool, optional): if True, model uses gpu for inferencing/prediction. Defaults to False.
        Raises:
            RuntimeError: if ``use_gpu`` is True but no GPU is available.
        """
        # Imported locally because the visible module-level imports do not
        # bring these names in explicitly (they may or may not be re-exported
        # by the wildcard import).
        from transformers import T5ForConditionalGeneration, T5Tokenizer

        self.model = T5ForConditionalGeneration.from_pretrained(f"{model_dir}")
        self.tokenizer = T5Tokenizer.from_pretrained(f"{model_dir}")

        self.device = self._select_device(use_gpu)
        self.model = self.model.to(self.device)

    def save_model(self, model_dir="model"):
        """
        Save the tokenizer and the model to a directory
        :param model_dir: target directory. Defaults to "model".
        :return: None; tokenizer and model are saved under ``model_dir``
        """
        path = f"{model_dir}"
        self.tokenizer.save_pretrained(path)
        self.model.save_pretrained(path)

    def predict(
        self,
        keywords: list,
        max_length: int = 512,
        num_return_sequences: int = 1,
        num_beams: int = 2,
        top_k: int = 50,
        top_p: float = 0.95,
        do_sample: bool = True,
        repetition_penalty: float = 2.5,
        length_penalty: float = 1.0,
        early_stopping: bool = True,
        skip_special_tokens: bool = True,
        clean_up_tokenization_spaces: bool = True,
        use_gpu: bool = False,
    ):
        """
        generates prediction for K2T model
        Args:
            keywords (list): keywords for generating predictions; passed as-is
                to ``tokenizer.encode`` (``evaluate`` passes single dataframe cells).
            max_length (int, optional): max token length of prediction. Defaults to 512.
            num_return_sequences (int, optional): number of sequences generated;
                only the first one is returned. Defaults to 1.
            num_beams (int, optional): number of beams. Defaults to 2.
            top_k (int, optional): Defaults to 50.
            top_p (float, optional): Defaults to 0.95.
            do_sample (bool, optional): accepted but currently NOT forwarded to
                ``generate`` (see note in the body). Defaults to True.
            repetition_penalty (float, optional): Defaults to 2.5.
            length_penalty (float, optional): Defaults to 1.0.
            early_stopping (bool, optional): Defaults to True.
            skip_special_tokens (bool, optional): Defaults to True.
            clean_up_tokenization_spaces (bool, optional): Defaults to True.
            use_gpu (bool, optional): run on GPU if True, otherwise on CPU. Defaults to False.
        Returns:
            str: the first decoded prediction
        Raises:
            RuntimeError: if ``use_gpu`` is True but no GPU is available.
        """
        self.device = self._select_device(use_gpu)

        input_ids = self.tokenizer.encode(
            keywords, return_tensors="pt", add_special_tokens=True
        )
        input_ids = input_ids.to(self.device)

        self.model = self.model.to(self.device)
        # Inference must not run with dropout active; the model is left in
        # train mode after train(), so switch explicitly.
        self.model.eval()

        # NOTE(review): ``do_sample`` is intentionally not passed on, matching
        # the existing behavior. Forwarding it would change the default
        # decoding behavior for every caller - confirm before changing.
        generated_ids = self.model.generate(
            input_ids=input_ids,
            num_beams=num_beams,
            max_length=max_length,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            early_stopping=early_stopping,
            top_p=top_p,
            top_k=top_k,
            num_return_sequences=num_return_sequences,
        )
        preds = [
            self.tokenizer.decode(
                g,
                skip_special_tokens=skip_special_tokens,
                clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            )
            for g in generated_ids
        ]
        return preds[0]

    def evaluate(self, test_df: pd.DataFrame, metrics: str = "rouge"):
        """
        Generate a prediction for every row of ``test_df`` and score them.

        :param test_df: dataframe whose first column holds the inputs given to
            ``predict`` and whose second column holds the reference texts
        :param metrics: metric name passed to ``load_metric``. The report below
            reads ROUGE result keys, so only "rouge" produces a valid report.
        :return: nested dict with low/mid/high precision, recall and F1 for
            rouge1, rouge2, rougeL and rougeLsum
        """
        metric = load_metric(metrics)
        input_text = test_df[test_df.columns[0]]
        # Hand the metric a plain list rather than a pandas Series.
        references = test_df[test_df.columns[1]].tolist()
        predictions = [self.predict(x) for x in tqdm(input_text)]

        results = metric.compute(predictions=predictions, references=references)

        output = {}
        for section, prefix, result_key in self._ROUGE_SECTIONS:
            aggregate = results[result_key]
            output[section] = {
                f"{prefix} {bound_label} {stat_label}": getattr(
                    getattr(aggregate, bound_attr), stat_attr
                )
                for bound_label, bound_attr in self._ROUGE_BOUNDS
                for stat_label, stat_attr in self._ROUGE_STATS
            }
        return output