import os
from dataclasses import dataclass
from typing import Dict, Optional
import torch
from torch import nn, Tensor
from transformers import AutoModelForSequenceClassification, PreTrainedModel, AutoTokenizer

from transformers.file_utils import ModelOutput
from transformers import TrainingArguments
from peft import LoraConfig, PeftModel, TaskType, get_peft_model
# AutoPeftModel
from peft import PeftConfig, get_peft_model
from peft import get_peft_model_state_dict


import logging
import random
from transformers.trainer import Trainer

from huggingface_hub.hf_api import HfFolder 

from rerank_arguments import ModelArguments , DataArguments

from typing import List, Tuple, Union, Any
from dataclasses import dataclass
from transformers import PreTrainedTokenizer


from datasets import load_dataset
from torch.utils.data import Dataset

from transformers import (
    HfArgumentParser,
    set_seed,
)
from dataclasses import dataclass, field

# Hugging Face Hub credentials. Prefer providing the token through the HF_TOKEN
# environment variable rather than hard-coding it in this file.
_HF_TOKEN_PLACEHOLDER = "Enter your login token"
login_token = os.environ.get("HF_TOKEN", _HF_TOKEN_PLACEHOLDER)
# Only persist a real token: saving the placeholder would overwrite any token
# already stored on this machine every time the module is imported.
if login_token and login_token != _HF_TOKEN_PLACEHOLDER:
    HfFolder.save_token(login_token)
logger = logging.getLogger(__name__)


from torch.utils.data import DataLoader
from tqdm import tqdm
from contextlib import nullcontext
import numpy as np
from sklearn.feature_extraction.text import CountVectorizer




import matplotlib.pyplot as plt

# Extract loss values from the trainer's log history
def plot_training_vs_val_loss_old(trainer, model_name):
    """Legacy loss plot: training vs validation loss indexed by log position.

    Training losses are the ``loss`` values and validation losses the
    ``eval_loss`` values (or ``val_loss``, if present) found in
    ``trainer.state.log_history``. Each series is plotted against its own
    1-based position, so the x axis counts logging events rather than true
    epochs. The figure is saved to
    ``training_val_loss_curve_old_{model_name}.png`` and then shown.

    Args:
        trainer: object exposing ``state.log_history`` as a list of dicts
            (e.g. a ``transformers.Trainer`` after training).
        model_name: suffix used in the saved file name.
    """
    log_history = trainer.state.log_history
    train_loss = [entry["loss"] for entry in log_history if "loss" in entry]
    # The Trainer logs validation loss as "eval_loss"; "val_loss" is kept as a
    # fallback for logs produced by other code.
    val_loss = [
        entry["eval_loss"] if "eval_loss" in entry else entry["val_loss"]
        for entry in log_history
        if "eval_loss" in entry or "val_loss" in entry
    ]
    train_epochs = range(1, len(train_loss) + 1)
    val_epochs = range(1, len(val_loss) + 1)
    print(f"train_loss: {train_loss}")
    print(f"val_loss: {val_loss}")
    print(f"epochs: {train_epochs}")
    if not train_loss and not val_loss:
        print("No training logs found.")
        return

    plt.figure(figsize=(8, 6))
    # Each series gets its own x range: the two usually differ in length, and
    # plotting both against a single range raises a shape-mismatch error.
    if train_loss:
        plt.plot(train_epochs, train_loss, label="Training Loss")
    if val_loss:
        plt.plot(val_epochs, val_loss, label='Validation Loss')
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.title("Training Vs Validation Loss Curve")
    plt.legend()
    plt.grid()
    # Save before show(): show() may close the figure, after which savefig
    # would write an empty canvas.
    plt.savefig(f"training_val_loss_curve_old_{model_name}.png")
    plt.show()

def plot_training_vs_val_loss(trainer, model_name):
    """
    Plot training loss vs validation loss from the trainer's logs.

    Training losses come from log entries containing ``loss`` and validation
    losses from entries containing ``eval_loss``. Each series is plotted
    against the ``step`` recorded in its own entries. The figure is saved to
    ``training_val_loss_curve_{model_name}.png`` and then shown.

    Args:
        trainer: object exposing ``state.log_history`` as a list of dicts
            (e.g. a ``transformers.Trainer`` after training).
        model_name: suffix used in the saved file name.
    """
    log_history = trainer.state.log_history
    if not log_history:
        print("No training logs found.")
        return

    train_steps, training_loss = [], []
    eval_steps, eval_loss = [], []
    for log in log_history:
        print(f"log: {log}")
        if 'loss' in log:
            train_steps.append(log.get('step', len(train_steps)))
            training_loss.append(log['loss'])
        if 'eval_loss' in log:
            # Use the eval entry's own step. Pairing eval losses with training
            # steps by position misplaces them whenever the logging and
            # evaluation intervals differ.
            eval_steps.append(log.get('step', len(eval_steps)))
            eval_loss.append(log['eval_loss'])

    if not training_loss and not eval_loss:
        print("No loss values found in training logs.")
        return

    plt.figure(figsize=(10, 6))
    if training_loss:
        plt.plot(train_steps, training_loss, label='Training Loss', marker='o')
    if eval_loss:
        plt.plot(eval_steps, eval_loss, label='Validation Loss', marker='x')

    plt.xlabel('Steps')
    plt.ylabel('Loss')
    plt.title('Training vs Validation Loss')
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    # Save before show(): show() may close the figure, after which savefig
    # would write an empty canvas.
    plt.savefig(f"training_val_loss_curve_{model_name}.png")
    plt.show()


@dataclass
class RerankerOutput(ModelOutput):
    """Return value of ``RerankerModel.forward``.

    Attributes:
        loss: training loss (MSE between the logits and the labels, computed in
            ``RerankerModel.forward``); ``None`` when no loss was computed.
        scores: raw logits produced by the wrapped classification model.
    """
    loss: Optional[Tensor] = None
    scores: Optional[Tensor] = None

class RerankerModel(nn.Module):
    """Pointwise cross-encoder reranker around a Hugging Face classifier.

    The wrapped model (plain or PEFT/LoRA-wrapped) is loaded with a single
    output logit and produces one score per tokenized (query, passage) pair.
    During training those scores are regressed onto float labels with a
    mean-squared-error loss.
    """

    TRANSFORMER_CLS = AutoModelForSequenceClassification

    def __init__(self, hf_model: PreTrainedModel, train_batch_size: int = None):
        """
        Args:
            hf_model: the underlying classification model (plain or PEFT-wrapped).
            train_batch_size: per-device training batch size. ``forward`` only
                computes a loss when this is set (``None`` for models loaded
                for inference via ``load``/``from_pretrained``).
        """
        super().__init__()
        self.config = hf_model.config
        self.hf_model = hf_model
        self.train_batch_size = train_batch_size
        self.cross_entropy = nn.CrossEntropyLoss(reduction='mean')
        self.mse = nn.MSELoss(reduction='mean')
        for name, param in self.hf_model.named_parameters():
            # for some reason, ds zero 3 left some weights empty
            if 'modules_to_save' in name and param.numel() == 0:
                logger.warning(f'parameter {name}, shape {param.shape} is empty')
                param.data = nn.Linear(self.hf_model.config.hidden_size, 1).weight.data
                logger.warning('{} data: {}'.format(name, param.data.cpu().numpy()))

    def forward(self, pair: Dict[str, Tensor] = None):
        """Score a batch of tokenized pairs and, when possible, compute the loss.

        Args:
            pair: tokenizer output (``input_ids``, ``attention_mask``, ...),
                optionally with a float ``labels`` tensor holding one target
                per pair.

        Returns:
            RerankerOutput whose ``scores`` are the model logits and whose
            ``loss`` is the MSE against ``labels``; ``loss`` is ``None`` when no
            labels are given or ``train_batch_size`` is unset.
        """
        labels = pair.get("labels")
        # Keep labels away from the HF model; otherwise it computes its own
        # (unused) loss on every call.
        model_inputs = {k: v for k, v in pair.items() if k != "labels"}
        ranker_logits = self.hf_model(**model_inputs, return_dict=True).logits

        if self.train_batch_size and labels is not None:
            # Mean-reduced MSE over the flattened tensors equals the MSE over a
            # (batch, group) view, but flattening also handles a final partial
            # batch whose size differs from train_batch_size.
            loss = self.mse(
                ranker_logits.reshape(-1).float(),
                labels.reshape(-1).to(ranker_logits.device).float(),
            )
            return RerankerOutput(loss=loss, scores=ranker_logits)
        return RerankerOutput(loss=None, scores=ranker_logits)

    def gradient_checkpointing_enable(self, **kwargs):
        """Enable gradient checkpointing on the underlying transformers model.

        PEFT-wrapped models keep the transformers model at
        ``hf_model.base_model.model``; plain models are handled directly.
        """
        if isinstance(self.hf_model, PeftModel):
            self.hf_model.base_model.model.gradient_checkpointing_enable(**kwargs)
        else:
            self.hf_model.gradient_checkpointing_enable(**kwargs)

    @classmethod
    def build(
            cls,
            model_args: ModelArguments,
            train_args: TrainingArguments,
            **hf_kwargs,
    ):
        """Create a reranker for training.

        Loads ``model_args.model_name_or_path`` with a single-logit
        classification head. When LoRA is requested, it wraps the model with
        either the adapter at ``model_args.lora_name_or_path`` (loaded as
        trainable) or a fresh adapter built from the ``lora_*`` arguments.

        Args:
            model_args: model and LoRA options.
            train_args: training options; ``gradient_checkpointing`` and
                ``per_device_train_batch_size`` are read.
            **hf_kwargs: forwarded to the base model's ``from_pretrained``.
        """
        base_model = cls.TRANSFORMER_CLS.from_pretrained(
            model_args.model_name_or_path,
            num_labels=1,
            **hf_kwargs,
        )
        if base_model.config.pad_token_id is None:
            base_model.config.pad_token_id = 0

        hf_model = base_model
        if model_args.lora or model_args.lora_name_or_path:
            if train_args.gradient_checkpointing:
                # Make input embeddings require grad so checkpointed blocks
                # still backpropagate while the base weights are frozen.
                base_model.enable_input_require_grads()
            if model_args.lora_name_or_path:
                lora_config = LoraConfig.from_pretrained(model_args.lora_name_or_path, **hf_kwargs)
                # PeftModel.from_pretrained loads adapters frozen (inference
                # mode) unless is_trainable=True, which would leave nothing to
                # train when continuing from an existing adapter.
                hf_model = PeftModel.from_pretrained(
                    base_model,
                    model_args.lora_name_or_path,
                    config=lora_config,
                    is_trainable=True,
                    torch_dtype=torch.bfloat16,
                    attn_implementation="flash_attention_2",
                )
            else:
                lora_config = LoraConfig(
                    base_model_name_or_path=model_args.model_name_or_path,
                    task_type=TaskType.SEQ_CLS,
                    r=model_args.lora_r,
                    lora_alpha=model_args.lora_alpha,
                    lora_dropout=model_args.lora_dropout,
                    target_modules=model_args.lora_target_modules.split(','),
                    inference_mode=False,
                    modules_to_save=model_args.lora_modules_to_save.split(',') if model_args.lora_modules_to_save else None,
                )
                hf_model = get_peft_model(base_model, lora_config)

        return cls(
            hf_model=hf_model,
            train_batch_size=train_args.per_device_train_batch_size,
        )

    @classmethod
    def load(cls,
            model_name_or_path: str,
            lora_name_or_path: str = None,
            **hf_kwargs):
        """Load a reranker for inference.

        The base model name is read from the PEFT config stored at
        ``model_name_or_path``. When ``lora_name_or_path`` is given, that
        adapter is loaded and merged into the base weights.
        """
        peft_config = PeftConfig.from_pretrained(model_name_or_path)
        base_model = cls.TRANSFORMER_CLS.from_pretrained(
            peft_config.base_model_name_or_path,
            num_labels=1,
            **hf_kwargs,
            torch_dtype=torch.bfloat16,
            attn_implementation="flash_attention_2",
            ignore_mismatched_sizes=True,
        )
        if base_model.config.pad_token_id is None:
            base_model.config.pad_token_id = 0

        hf_model = base_model
        if lora_name_or_path:
            lora_config = LoraConfig.from_pretrained(lora_name_or_path, **hf_kwargs)
            lora_model = PeftModel.from_pretrained(
                base_model, lora_name_or_path, config=lora_config, ignore_mismatched_sizes=True
            )
            hf_model = lora_model.merge_and_unload()
        return cls(hf_model=hf_model)

    def save(self, output_dir: str):
        """Save the wrapped HF/PEFT model with ``save_pretrained``."""
        self.hf_model.save_pretrained(output_dir)

    @classmethod
    def from_pretrained(cls, model_name_or_path, *model_args, **kwargs):
        """Load ``TRANSFORMER_CLS`` from ``model_name_or_path`` and wrap it.

        ``train_batch_size`` (if given) is consumed by the wrapper; all other
        keyword arguments go to the Hugging Face loader only, because
        ``__init__`` does not accept them.
        """
        train_batch_size = kwargs.pop("train_batch_size", None)
        hf_model = cls.TRANSFORMER_CLS.from_pretrained(model_name_or_path, *model_args, **kwargs)
        return cls(hf_model, train_batch_size=train_batch_size)

    def create_or_update_model_card(self, output_dir):
        """Write a placeholder ``README.md`` model card into ``output_dir``."""
        card = (
            "# Model Card\n"
            "This is a placeholder model card for the RerankerModel.\n"
            "\n## Model Details\n"
            "Details about the model architecture, training data, and usage.\n"
            "\n## Usage\n"
            "Instructions on how to use the model.\n"
            "\n## Training Data\n"
            "Information about the training data used.\n"
            "\n## Evaluation\n"
            "Evaluation metrics and results.\n"
        )
        with open(os.path.join(output_dir, "README.md"), "w", encoding="utf-8") as f:
            f.write(card)
        
        

class RerankerTrainer(Trainer):
    """Trainer that saves via ``RerankerModel.save`` and uses the model's own loss."""

    def __init__(self, *args, **kwargs):
        super(RerankerTrainer, self).__init__(*args, **kwargs)
        # RerankerModel.forward takes a single ``pair`` dict, so Trainer cannot
        # infer label names from its signature. Without them, evaluation falls
        # back to ``model(**inputs)``, which that signature does not accept.
        if self.args.label_names is None:
            self.label_names = ["labels"]

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        """Save the wrapped model and, under ZeRO-3, the LoRA adapter weights.

        Args:
            output_dir: target directory; defaults to ``self.args.output_dir``.
            state_dict: optional full ``RerankerModel`` state dict (keys
                prefixed with ``hf_model.``); read from the model when omitted.
        """
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        logger.info("Saving model checkpoint to %s", output_dir)
        self.model.save(output_dir)

        if not self.args.is_deepspeed_zero3_enabled:
            return
        if state_dict is None:
            state_dict = self.model.state_dict()
        prefix = 'hf_model.'
        assert all(
            k.startswith(prefix) or k == "target_label"
            for k in state_dict.keys()
        ), list(state_dict.keys())
        # Strip the wrapper prefix. Keys without it (the optional
        # "target_label" buffer) are not part of the HF model; drop them
        # instead of truncating them into bogus names.
        state_dict = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}
        lora_state_dict = get_peft_model_state_dict(self.model.hf_model, state_dict)
        if self.args.process_index <= 0:
            torch.save(lora_state_dict, os.path.join(output_dir, "adapter_model.bin"))
            print(f"Save adapter model at {output_dir}")

    def compute_loss(self, model, inputs, num_items_in_batch=None, return_outputs=False):
        """Return the reranker loss for ``inputs`` (optionally with the outputs).

        Args:
            model: the (possibly wrapped) ``RerankerModel``.
            inputs: collated batch, passed to ``model`` as a single dict.
            num_items_in_batch: supplied by ``Trainer`` for summed token-level
                losses. It is ignored here because the model's MSE is already
                averaged. Dividing again would shrink the training loss by the
                batch size and make it incomparable with the evaluation loss,
                for which this count is not passed.
            return_outputs: when True, return ``(loss, outputs)`` as Trainer
                requests during evaluation.
        """
        outputs = model(inputs)
        loss = outputs.loss
        return (loss, outputs) if return_outputs else loss
    


@dataclass
class RerankerTrainCollator:
    """Turn a list of query groups into one padded batch for the reranker.

    Each feature is a list of ``[formatted_text, label_string]`` pairs. All
    pairs of all groups are flattened, in order, into a single batch.
    """
    data_args: DataArguments
    tokenizer: PreTrainedTokenizer

    def __call__(self, features : List[Tuple[str, str, str]]):
        """
        Tokenize and pad every pair of every group.

        :param features: list of groups; each group is a list of
            ``[formatted_text, label_string]`` pairs.
        :return: the padded tokenizer output (``input_ids``, ``attention_mask``)
            plus a float ``labels`` tensor holding one value per pair.
        """
        flat = [pair for group in features for pair in group]
        texts = [pair[0] for pair in flat]
        targets = [float(pair[1]) for pair in flat]

        max_len = self.data_args.rerank_max_len
        if self.data_args.append_eos_token:
            # leave one position free for the EOS id appended after truncation
            max_len = max_len - 1

        encoded = self.tokenizer(
            texts,
            padding=False,
            truncation=True,
            max_length=max_len,
            return_attention_mask=False,
            return_token_type_ids=False,
            add_special_tokens=True,
        )
        if self.tokenizer.pad_token is None:
            # padding below needs a pad token; reuse EOS when the tokenizer has none
            self.tokenizer.pad_token = self.tokenizer.eos_token

        if self.data_args.append_eos_token:
            encoded['input_ids'] = [
                ids + [self.tokenizer.eos_token_id] for ids in encoded['input_ids']
            ]

        batch = self.tokenizer.pad(
            encoded,
            padding=True,
            pad_to_multiple_of=self.data_args.pad_to_multiple_of,
            return_attention_mask=True,
            return_tensors='pt',
        )
        batch['labels'] = torch.tensor(targets, dtype=torch.float)
        return batch




def format_pair(query: str, passage: str, title: str, query_prefix: str, passage_prefix: str):
    """Build the single reranker input string for one query/passage pair.

    Hyphens in ``title`` become spaces and the title is stripped. The pieces
    (query prefix, query, passage prefix, title, passage) are joined with single
    spaces, and surrounding whitespace is removed from the result.
    """
    clean_title = title.replace('-', ' ').strip()
    pieces = (query_prefix, query, passage_prefix, clean_title, passage)
    return ' '.join(map(str, pieces)).strip()


def calculate_covered_query_terms(query, passage_text):
    """Calculate covered query terms and related statistics for a passage.

    Both texts are tokenized by a fresh default ``CountVectorizer`` fitted on
    exactly these two documents (the passage is row 0, the query row 1).

    Returns:
        dict with keys
          * ``covered_query_terms``: distinct query terms also present in the passage;
          * ``total_query_terms``: distinct query terms;
          * ``covered_query_term_ratio``: covered / total, or 0.0 when the query
            has no terms;
          * ``normalized_term_frequency``: mean, over the joint vocabulary, of
            the passage's term counts divided by the passage length; 0.0 when
            the passage has no terms. NOTE(review): for a non-empty passage
            this equals ``1 / len(vocabulary)``; confirm this is the intended
            label.
    """
    vectorizer = CountVectorizer()
    analyze = vectorizer.build_analyzer()
    if not analyze(passage_text) and not analyze(query):
        # Neither text yields a token: fit_transform would raise an
        # "empty vocabulary" ValueError, so report zero coverage instead.
        return {
            "covered_query_terms": 0,
            "total_query_terms": 0,
            "covered_query_term_ratio": 0.0,
            "normalized_term_frequency": 0.0,
        }

    tf_docs = vectorizer.fit_transform([passage_text, query]).toarray()
    passage_vector, query_vector = tf_docs[0], tf_docs[1]

    query_terms = query_vector > 0
    total_query_terms = query_terms.sum()
    # Terms that appear in both the query and the passage
    covered_terms = np.logical_and(passage_vector > 0, query_terms).sum()
    covered_ratio = covered_terms / total_query_terms if total_query_terms > 0 else 0.0

    # Length-normalize each document's counts. Rows without tokens stay 0
    # instead of becoming NaN (0/0), which would propagate into NaN labels.
    doc_lengths = tf_docs.sum(axis=1, keepdims=True)
    normalized_term_frequency = np.divide(
        tf_docs,
        doc_lengths,
        out=np.zeros(tf_docs.shape, dtype=float),
        where=doc_lengths > 0,
    )
    # Aggregate the passage row (row 0) to a scalar
    normalized_tf_value = normalized_term_frequency[0].mean()

    return {
        "covered_query_terms": int(covered_terms),
        "total_query_terms": int(total_query_terms),
        "covered_query_term_ratio": float(covered_ratio),
        "normalized_term_frequency": float(normalized_tf_value),
    }


def augment_dataset_with_coverage(dataset, num_proc: Optional[int] = 32):
    """Add coverage statistics to the dataset.

    For every example, ``calculate_covered_query_terms`` is evaluated for the
    query against the ``text`` of each positive and each negative passage.

    Args:
        dataset: dataset with a ``map`` method whose rows have ``query``,
            ``positive_passages`` and ``negative_passages`` columns
            (e.g. a Hugging Face ``datasets.Dataset``).
        num_proc: worker processes for ``map``. Defaults to the previously
            hard-coded 32; pass ``None`` to run in the current process.

    Returns:
        The mapped dataset with two extra columns,
        ``positive_passage_coverage`` and ``negative_passage_coverage``. Each is
        a list of coverage dicts aligned with the corresponding passage list.
    """
    def process_group(example):
        query = example['query']
        return {
            'positive_passage_coverage': [
                calculate_covered_query_terms(query, psg['text'])
                for psg in example['positive_passages']
            ],
            'negative_passage_coverage': [
                calculate_covered_query_terms(query, psg['text'])
                for psg in example['negative_passages']
            ],
        }

    return dataset.map(
        process_group,
        desc="Computing query coverage statistics",
        num_proc=num_proc,
    )



class RerankerTrainDataset(Dataset):
    """Query groups for pointwise regression training of the reranker.

    Every item is a list of ``[formatted_text, label_str]`` pairs: one positive
    passage followed by ``train_group_size - 1`` negatives. Each label is the
    ``label_key`` entry of the coverage statistics precomputed by
    ``augment_dataset_with_coverage`` for that passage.
    """

    def __init__(self, data_args: DataArguments, trainer = None, subset_ratio = None, val_ds = False, label_key = "covered_query_term_ratio"):
        """
        Args:
            data_args: dataset location, sharding and formatting options.
            trainer: Trainer whose current epoch and seed drive passage
                selection; may be attached later through ``self.trainer``.
                While it is ``None``, selection uses epoch 0 and seed 0.
            subset_ratio: if set, randomly sample this fraction of the split.
            val_ds: together with ``subset_ratio``, keep only the last 25% of
                the sample (validation split).
            label_key: which coverage statistic becomes the regression label.
        """
        self.data_args = data_args
        self.label_key = label_key
        self.dataset = load_dataset(
            self.data_args.dataset_name,
            self.data_args.dataset_config,
            data_files=self.data_args.dataset_path,
            split=self.data_args.dataset_split,
            cache_dir=self.data_args.dataset_cache_dir,
            trust_remote_code=True
        )
        if subset_ratio:
            total_size = len(self.dataset)
            subset_size = int(total_size * subset_ratio)
            # NOTE(review): every instance draws its own random sample, so a
            # training instance and a validation instance can share examples.
            indices = random.sample(range(total_size), subset_size)
            if val_ds:
                indices = indices[int(subset_size * 0.75):]
                logger.info(f"Sampled {len(indices)} examples for validation")
            self.train_data = self.dataset.select(indices)
            logger.info(f"Sampled {subset_size} examples ({subset_ratio*100}% of {total_size})")
        else:
            self.train_data = self.dataset

        # Shard before the expensive coverage pass so only this shard's rows
        # are processed.
        if self.data_args.dataset_number_of_shards > 1:
            self.train_data = self.train_data.shard(
                num_shards=self.data_args.dataset_number_of_shards,
                index=self.data_args.dataset_shard_index,
            )
        self.train_data = augment_dataset_with_coverage(self.train_data)
        self.trainer = trainer

    def __len__(self):
        return len(self.train_data)

    def __getitem__(self, item) -> List[List[str]]:
        """Return the ``[text, label_str]`` pairs for query ``item``, positive first."""
        group = self.train_data[item]
        if self.trainer is not None:
            epoch = int(self.trainer.state.epoch or 0)
            seed = self.trainer.args.seed
        else:
            # e.g. an evaluation dataset that was never attached to a trainer
            epoch, seed = 0, 0
        _hashed_seed = hash(item + seed)

        query = group['query']
        # Pair every passage with its precomputed coverage so that the
        # selection below keeps passages and labels aligned.
        positives = list(zip(group['positive_passages'], group['positive_passage_coverage']))
        negatives = list(zip(group['negative_passages'], group['negative_passage_coverage']))
        query_prefix = self.data_args.query_prefix
        passage_prefix = self.data_args.passage_prefix

        if self.data_args.positive_passage_no_shuffle:
            pos_psg, pos_cov = positives[0]
        else:
            pos_psg, pos_cov = positives[(_hashed_seed + epoch) % len(positives)]
        formated_pair = [[
            format_pair(query, pos_psg['text'], pos_psg['title'], query_prefix, passage_prefix),
            str(pos_cov[self.label_key]),
        ]]

        negative_size = self.data_args.train_group_size - 1
        if len(negatives) < negative_size:
            negs = random.choices(negatives, k=negative_size)
        elif self.data_args.train_group_size == 1:
            negs = []
        elif self.data_args.negative_passage_no_shuffle:
            negs = negatives[:negative_size]
        else:
            # Deterministic per-item shuffle, advancing through it by epoch
            _offset = epoch * negative_size % len(negatives)
            negs = list(negatives)
            random.Random(_hashed_seed).shuffle(negs)
            negs = negs * 2
            negs = negs[_offset: _offset + negative_size]

        for neg_psg, neg_cov in negs:
            formated_pair.append([
                format_pair(query, neg_psg['text'], neg_psg['title'], query_prefix, passage_prefix),
                str(neg_cov[self.label_key]),
            ])
        return formated_pair


@dataclass
class CustomTrainingArguments(TrainingArguments):
    """``TrainingArguments`` with a default output directory and a ZeRO-3 switch."""
    output_dir: str = field(
        default='./pythiatest', metadata={"help": "Directory where checkpoints and the final model are written"}
    )
    # Read by RerankerTrainer._save to additionally write adapter_model.bin
    is_deepspeed_zero3_enabled :bool = field(
        default=False, metadata={"help": "Set to true only when training with DeepSpeed ZeRO-3"}
    )



def main():
    """Parse CLI arguments, train the reranker, plot its loss curves and save it."""
    parser = HfArgumentParser((ModelArguments, DataArguments, CustomTrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    model_args: ModelArguments
    data_args: DataArguments
    training_args: TrainingArguments

    # Configure logging before loading anything so those messages follow the
    # chosen level and format.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        training_args.local_rank,
        training_args.device,
        training_args.n_gpu,
        bool(training_args.local_rank != -1),
        training_args.fp16,
    )
    logger.info("Training/evaluation parameters %s", training_args)
    logger.info("MODEL parameters %s", model_args)

    # Seed before building the model so newly initialized weights (LoRA
    # adapters, the classification head) are reproducible across runs.
    set_seed(training_args.seed)

    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=None,
        trust_remote_code=True
    )
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
    tokenizer.padding_side = 'right'

    model = RerankerModel.build(
        model_args,
        training_args,
        cache_dir=model_args.cache_dir,
        trust_remote_code=True,
    )
    model.to(training_args.device)

    label_key = "normalized_term_frequency"
    # NOTE(review): the two datasets sample their subsets independently, so
    # validation examples may also be present in the training subset.
    train_dataset = RerankerTrainDataset(data_args, subset_ratio=0.50*0.75, label_key=label_key)
    train_collator = RerankerTrainCollator(data_args, tokenizer)
    val_dataset = RerankerTrainDataset(data_args, subset_ratio=0.50, val_ds=True, label_key=label_key)

    trainer = RerankerTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        data_collator=train_collator,
        eval_dataset=val_dataset,
    )
    # Both datasets read the current epoch and seed from the trainer when
    # selecting passages.
    train_dataset.trainer = trainer
    val_dataset.trainer = trainer

    trainer.train()  # TODO: resume training

    try:
        # normpath drops a trailing slash that would otherwise yield an empty name
        ft_model_name = os.path.basename(os.path.normpath(training_args.output_dir))
        plot_training_vs_val_loss(trainer, ft_model_name)
    except Exception as e:
        # Plotting is best-effort and must not prevent saving the model.
        print(f"Error in plotting training vs validation loss: {e}")

    trainer.save_model()





if __name__ == "__main__":
    # Run training only when executed as a script, not on import.
    main()

