''' 
This module implements the training of a categorical abstraction for the downstream task of text generation.
'''
import pandas as pd
import transformers
import torch
from sklearn.metrics import classification_report, precision_recall_fscore_support
from argparse import ArgumentParser
from abstract_cf.text_generation.utils import load_dataset
import time

from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers.modeling_outputs import SequenceClassifierOutput
import plotly.express as px
import torch.nn.functional as F
import json
from transformers import EarlyStoppingCallback


class LearnedAbstractionPipeline:
    '''
    Inference wrapper around a tokenizer and a sequence-classification model.

    The goal of this class is to simplify the usage of huggingface models for
    inference of the learned abstractions trained in this module, and to
    simplify plotting. It also bundles the id -> label mapping with the model
    so that the whole pipeline can be saved to and loaded from one directory.
    '''

    def __init__(
        self,
        tokenizer: transformers.PreTrainedTokenizerBase,
        lm_classifier: transformers.PreTrainedModel,
        id_to_label: dict[int, str],
        device: str = 'cpu'
    ):
        '''
        Store the components and move the classifier to `device`.

        Args:
            tokenizer: tokenizer used to encode the input text.
            lm_classifier: sequence-classification model; moved to `device`.
            id_to_label: mapping from class index to label name.
            device: device on which inference is run.
        '''
        self.device = device
        self.tokenizer = tokenizer
        self.lm_classifier = lm_classifier.to(device)
        self.id_to_label = id_to_label
        # Inverse mapping, label name -> class index.
        self.label_to_id = {v: k for k, v in id_to_label.items()}

    def predict(self, input_text: list[str] | str) -> SequenceClassifierOutput:
        '''
        Tokenize `input_text` and run the classifier without gradient tracking.

        The input is padded and truncated by the tokenizer. The model is put in
        evaluation mode for the forward pass (so dropout does not perturb the
        prediction) and switched back to training mode afterwards if that is
        the mode it was in.

        Returns:
            The raw model output; its `logits` attribute holds the class scores.
        '''
        encoded = self.tokenizer(
            input_text,
            return_tensors='pt',
            padding=True,
            truncation=True
        ).to(self.device)
        was_training = self.lm_classifier.training
        self.lm_classifier.eval()
        try:
            with torch.no_grad():
                output = self.lm_classifier(**encoded)
        finally:
            # Do not leave a model that is being trained stuck in eval mode.
            if was_training:
                self.lm_classifier.train()
        return output

    def __call__(self, input_text: list[str] | str, return_logits: bool = False) -> torch.Tensor:
        '''
        Classify `input_text`.

        Returns:
            The logits if `return_logits` is True, otherwise the softmax of the
            logits over the last dimension.
        '''
        pred = self.predict(input_text)
        return pred.logits if return_logits else F.softmax(pred.logits, dim=-1)

    def plot_abstraction_classification(self, model_output: SequenceClassifierOutput):
        '''
        Build a plotly bar chart of the class probabilities in `model_output`.

        The softmax of the logits is squeezed before plotting, so this is meant
        for the output of a single input example.
        NOTE(review): assumes logits of shape (1, num_labels) - a batch with
        more than one row is not flattened by squeeze(); confirm with callers.

        Returns:
            The plotly figure (it is not shown or saved here).
        '''
        probabilities = F.softmax(model_output.logits, dim=-1).squeeze().tolist()
        labels = [self.id_to_label[i] for i in range(len(probabilities))]
        fig = px.bar(
            x=labels,
            y=probabilities,
            labels={'x': 'Abstraction', 'y': 'Probability'},
            title='Abstraction Classification Probabilities'
        )
        return fig

    def save(self, path: str):
        '''
        Save tokenizer, model and `id_to_label.json` into the directory `path`.

        The model is moved to CPU for saving and moved back to `self.device`
        afterwards (also when saving fails), so the pipeline stays usable for
        prediction after this call.
        '''
        # Ensure the model is on CPU before saving
        self.lm_classifier.to('cpu')
        try:
            self.tokenizer.save_pretrained(path)
            self.lm_classifier.save_pretrained(path)
            # JSON turns the integer keys into strings; load() converts them back.
            with open(f'{path}/id_to_label.json', 'w', encoding='utf-8') as f:
                json.dump(self.id_to_label, f)
        finally:
            # predict() sends its inputs to self.device, so the weights must
            # live there again once saving is done.
            self.lm_classifier.to(self.device)

    @classmethod
    def load(cls, save_path: str, device: str = 'cpu') -> "LearnedAbstractionPipeline":
        """
        Loads a saved pipeline from the given save_path.
        The directory at save_path should contain the tokenizer and model saved using
        the Huggingface transformers' save_pretrained method as well as an 'id_to_label.json'
        file.
        """
        tokenizer = AutoTokenizer.from_pretrained(save_path)
        model = AutoModelForSequenceClassification.from_pretrained(save_path)
        # Only introduce a '[PAD]' token when the tokenizer has none. Replacing
        # an existing pad token would create a token id that the trained
        # model has no embedding for.
        if tokenizer.pad_token is None:
            tokenizer.add_special_tokens({'pad_token': '[PAD]'})
            if len(tokenizer) > model.get_input_embeddings().num_embeddings:
                model.resize_token_embeddings(len(tokenizer))
        with open(f"{save_path}/id_to_label.json", "r", encoding="utf-8") as f:
            id_to_label = json.load(f)
        # make sure the keys are integers
        id_to_label = {int(k): v for k, v in id_to_label.items()}
        # The constructor moves the model to `device`.
        return cls(tokenizer, model, id_to_label, device)


# Create datasets
class TextDataset(torch.utils.data.Dataset):
    '''
    Map-style dataset pairing raw texts with integer-encoded labels.

    Labels are encoded once at construction time through `label_encoder`;
    texts are tokenized lazily, one example per `__getitem__` call.
    '''

    def __init__(self, texts, labels, label_encoder, tokenizer, max_length=512):
        self.texts = [*texts]
        # Look every label up in the encoder; an unknown label raises here.
        encoded_labels = []
        for raw_label in labels:
            encoded_labels.append(label_encoder[raw_label])
        self.labels = encoded_labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # Every example is truncated/padded to exactly `max_length` tokens.
        tokenized = self.tokenizer(
            self.texts[idx],
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )
        sample = {}
        for field_name, tensor in tokenized.items():
            # Drop the leading dimension added by return_tensors='pt'.
            sample[field_name] = tensor.squeeze(0)
        sample['labels'] = torch.tensor(self.labels[idx])
        return sample


def train_categorical_abstraction(
    train_df: pd.DataFrame,
    dev_df: pd.DataFrame,
    base_model: str='distilbert-base-uncased', # FacebookAI/xlm-roberta-base
    input_col: str='text',
    target_col: str='label',
    output_dir: str="./model_data/learned_abstractions/profession",
    num_train_epochs: int=1,
    per_device_train_batch_size: int=64,
    per_device_eval_batch_size: int=64,
    warmup_steps: int=100,
    weight_decay: float=0.01,
    logging_dir: str='./logs',
    logging_steps: int=100,
    evaluation_strategy: str="epoch",
    save_strategy: str="epoch",
    load_best_model_at_end: bool=True,
    metric_for_best_model: str="accuracy",
    report_to: str="none",
    early_stopping_patience: int=3,
):
    '''
    Train (or fine-tune) a huggingface sequence classifier that predicts
    `target_col` from `input_col`.

    This is useful for learning a categorical abstraction of the input_col.
    For example, for the 'bios' dataset we want to learn to predict the profession from a given biographical text.

    Args:
        train_df: training data; also defines the set of labels.
        dev_df: evaluation data used during and after training.
        base_model: name or path of the pretrained model and tokenizer.
        input_col: column holding the input text.
        target_col: column holding the label.
        output_dir: directory for checkpoints and for the final saved pipeline.
        early_stopping_patience: patience passed to EarlyStoppingCallback.
        All remaining arguments are forwarded unchanged to
        transformers.TrainingArguments.
        NOTE(review): newer transformers releases appear to rename
        `evaluation_strategy` to `eval_strategy` - confirm against the
        installed version.

    Returns:
        A LearnedAbstractionPipeline wrapping the trained model, already saved
        to `output_dir` and placed on the training device.
    '''

    # Set device
    device = torch.device('mps' if torch.backends.mps.is_available() else 'cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Compute the label mapping from the training set.
    # Sort the unique labels for a consistent mapping. Series.tolist() yields
    # plain Python scalars, which keeps the mapping JSON-serializable when the
    # pipeline is saved (numpy scalars are not).
    sorted_labels = sorted(train_df[target_col].drop_duplicates().tolist())
    id_to_label = {int(i): label for i, label in enumerate(sorted_labels)}

    # Create label encoder
    label_encoder = {label: i for i, label in id_to_label.items()}

    # We will use an AutoModelForSequenceClassification model for this task.
    # The size of the classification head is taken from the label mapping so
    # that the two can never disagree.
    model = transformers.AutoModelForSequenceClassification.from_pretrained(
        base_model,
        num_labels=len(id_to_label)
    ).to(device)
    tokenizer = transformers.AutoTokenizer.from_pretrained(base_model)

    print('Constructing datasets...')
    train_dataset = TextDataset(train_df[input_col].tolist(), train_df[target_col].tolist(), label_encoder, tokenizer)
    dev_dataset = TextDataset(dev_df[input_col].tolist(), dev_df[target_col].tolist(), label_encoder, tokenizer)

    print('Constructing training arguments...')
    # Training arguments
    training_args = transformers.TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=num_train_epochs,
        per_device_train_batch_size=per_device_train_batch_size,
        per_device_eval_batch_size=per_device_eval_batch_size,
        warmup_steps=warmup_steps,
        weight_decay=weight_decay,
        logging_dir=logging_dir,
        logging_steps=logging_steps,
        evaluation_strategy=evaluation_strategy,
        save_strategy=save_strategy,
        load_best_model_at_end=load_best_model_at_end,
        metric_for_best_model=metric_for_best_model,
        report_to=report_to,
    )

    def compute_metrics(eval_pred):
        '''Compute accuracy, macro/micro P/R/F1 and per-class metrics.'''
        predictions, labels = eval_pred
        predictions = predictions.argmax(axis=-1)

        # Calculate overall accuracy
        accuracy = (predictions == labels).mean()

        # Aggregate precision, recall, f1. zero_division=0 matches the
        # per-class report below and avoids warnings for classes that are
        # never predicted.
        precision_macro, recall_macro, f1_macro, _ = precision_recall_fscore_support(
            labels, predictions, average='macro', zero_division=0
        )
        precision_micro, recall_micro, f1_micro, _ = precision_recall_fscore_support(
            labels, predictions, average='micro', zero_division=0
        )

        # Get detailed per-class metrics
        report = classification_report(
            labels,
            predictions,
            output_dict=True,
            zero_division=0
        )

        # The report is keyed by the class index as a string; the aggregate
        # rows ('accuracy', 'macro avg', ...) are skipped by the isdigit filter.
        per_class_metrics = {
            id_to_label[int(class_idx)]: {
                'precision': metrics['precision'],
                'recall': metrics['recall'],
                'f1': metrics['f1-score'],
                'support': metrics['support']
            }
            for class_idx, metrics in report.items()
            if class_idx.isdigit()
        }
        return {
            "accuracy": accuracy,
            "f1_macro": f1_macro,
            "f1_micro": f1_micro,
            "precision_macro": precision_macro,
            "precision_micro": precision_micro,
            "recall_macro": recall_macro,
            "recall_micro": recall_micro,
            "per_class_metrics": per_class_metrics
        }

    print('Training model...')
    # Initialize trainer with early stopping callback
    trainer = CustomTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=dev_dataset,
        compute_metrics=compute_metrics,
        id_to_label=id_to_label,
        tokenizer=tokenizer,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=early_stopping_patience)]
    )

    # Train and evaluate
    train_result = trainer.train()
    metrics = trainer.evaluate()

    print("\nTraining completed. Final metrics:")
    print(f"Training loss: {train_result.metrics['train_loss']:.4f}")
    print(f"Validation accuracy: {metrics['eval_accuracy']:.4f}")
    print(f"Macro F1-score: {metrics['eval_f1_macro']:.4f}")
    print(f"Micro F1-score: {metrics['eval_f1_micro']:.4f}")
    print("\nPer-class metrics:")
    for class_name, class_metrics in metrics['eval_per_class_metrics'].items():
        print(f"\n{class_name}:")
        print(f"  Precision: {class_metrics['precision']:.4f}")
        print(f"  Recall: {class_metrics['recall']:.4f}")
        print(f"  F1-score: {class_metrics['f1']:.4f}")
        print(f"  Support: {class_metrics['support']}")

    pipeline = LearnedAbstractionPipeline(tokenizer, model, id_to_label, device=device)
    pipeline.save(output_dir)
    # Saving moves the model to CPU; put it back on the device the pipeline
    # sends its inputs to, so the returned pipeline can be used right away.
    pipeline.lm_classifier.to(device)
    return pipeline


def define_argument_parser():
    '''
    Build the command-line parser for the training script.

    The options are declared as (flag, keyword-arguments) pairs and are
    registered in the listed order.
    '''
    option_specs = (
        ('--base_model', dict(type=str, default='distilbert-base-uncased')),  # FacebookAI/xlm-roberta-base
        ('--input_col', dict(type=str, default='text')),
        ('--target_col', dict(type=str, default='label')),
        ('--output_dir', dict(type=str, default='./model_data')),  # TODO use some proper data path
        ('--num_train_epochs', dict(type=int, default=1)),
        ('--per_device_train_batch_size', dict(type=int, default=32)),
        ('--per_device_eval_batch_size', dict(type=int, default=32)),
        (
            '--task',
            dict(
                type=str,
                default='profession',
                choices=['profession', 'emotion', 'emotion_ekman'],
                help='Choose between profession or emotion.',
            ),
        ),
        ('--report_to', dict(type=str, default='clearml')),
        (
            '--early_stopping_patience',
            dict(type=int, default=3, help='Patience for early stopping.'),
        ),
    )

    cli = ArgumentParser()
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli


class CustomTrainer(transformers.Trainer):
    """
    Custom Trainer that overrides the save_model method to save the full pipeline
    (tokenizer, model and id -> label mapping) using our custom saving mechanism.
    """
    def __init__(
        self,
        *args,
        id_to_label: dict[int, str] | None = None,
        tokenizer: transformers.PreTrainedTokenizerBase | None = None,
        **kwargs
    ):
        """
        Args:
            id_to_label: mapping from class index to label name, stored with
                every saved pipeline. Defaults to an empty mapping.
            tokenizer: tokenizer stored with every saved pipeline. It is kept
                on this object and is not forwarded to the base Trainer.
            *args, **kwargs: forwarded unchanged to transformers.Trainer.
        """
        super().__init__(*args, **kwargs)
        self.id_to_label = id_to_label or {}
        self.tokenizer = tokenizer

    def save_model(self, output_dir: str | None = None, _internal_call: bool = False):
        """
        Save the full pipeline to `output_dir` (default: `self.args.output_dir`).

        `_internal_call` is accepted for signature compatibility with the base
        class and is not used here.
        """
        # Override saving: use our full pipeline saving method.
        # NOTE: pipeline.save moves the model to CPU.
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        original_device = self.model.device  # Remember the current device (e.g., cuda:0)
        pipeline = LearnedAbstractionPipeline(self.tokenizer, self.model, self.id_to_label, device=self.args.device)
        try:
            pipeline.save(output_dir)
        finally:
            # Restore the model to its original device for continued training,
            # even when saving failed part-way through.
            self.model.to(original_device)
        print(f"Custom saved pipeline to {output_dir}")



# Script entry point: parse the CLI options, load the dataset for the chosen
# task, then train and save the corresponding learned abstraction.
if __name__ == '__main__':
    parser = define_argument_parser()
    args = parser.parse_args()

    if args.report_to == 'clearml':
        # set clearml environment variables
        import os
        os.environ['CLEARML_PROJECT'] = 'abstract_counterfactuals'
        # the default setting in the huggingface pipeline is to reuse the same task id,
        # which overwrites the previous run on the platform
        # because of that, we add the timestamp to the task name
        os.environ['CLEARML_TASK'] = f'train_{args.task}_abstraction_{int(time.time())}'

    datasets = load_dataset(args.task)
    # --output_dir is the root directory; with its default ('./model_data')
    # this resolves to ./model_data/learned_abstractions/<task>.
    output_dir = f"{args.output_dir}/learned_abstractions/{args.task}"

    pipeline = train_categorical_abstraction(
        datasets['train'],
        datasets['dev'],
        args.base_model,
        input_col=args.input_col,
        target_col=args.target_col,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        per_device_train_batch_size=args.per_device_train_batch_size,
        num_train_epochs=args.num_train_epochs,
        output_dir=output_dir,
        report_to=args.report_to,
        # Forward the CLI value; it was parsed but never used before.
        early_stopping_patience=args.early_stopping_patience,
    )
    print("Pipeline trained and saved at", output_dir)