import torch
from sentence_transformers import models
from sentence_transformers import SentenceTransformer, InputExample
from torch.utils.data import DataLoader
from sentence_transformers.losses import CoSENTLoss, MatryoshkaLoss, MultipleNegativesRankingLoss, \
    CachedMultipleNegativesRankingLoss
from torch import nn
from tqdm import tqdm
from datasets import load_dataset
from transformers.trainer_pt_utils import LengthGroupedSampler


class MultipleNegativesRankingLoss_with_logging(torch.nn.Module):
    """Wrap a loss module and periodically print its recent loss values.

    Every call is forwarded unchanged to ``loss_model``; the returned loss
    is recorded as a Python float and, every ``print_steps`` calls, the
    current value and the mean of the last ``print_steps`` values are
    printed to stdout.

    Args:
        loss_model: Callable loss module whose return value exposes
            ``.item()``.
        *args: Forwarded to ``torch.nn.Module.__init__``.
        print_steps: Number of forward calls between two log lines.
            Must be strictly positive.
        **kwargs: Forwarded to ``torch.nn.Module.__init__``.

    Raises:
        ValueError: If ``print_steps`` is not strictly positive.
    """

    def __init__(self, loss_model, *args, print_steps=10, **kwargs):
        # Fail fast: a zero interval would otherwise only surface as a
        # ZeroDivisionError on the first forward pass, and a negative one
        # would silently print meaningless averages.
        if print_steps <= 0:
            raise ValueError(
                f"print_steps must be a positive integer, got {print_steps!r}"
            )
        super().__init__(*args, **kwargs)
        self.loss_model = loss_model
        # Full history of loss values, one float per forward call.
        self.loss_list = []
        self.print_steps = print_steps
        # Number of forward calls seen so far.
        self.counter = 0

    def forward(self, *args, **kwargs):
        """Compute the wrapped loss, record it, and log every ``print_steps`` calls.

        Returns:
            The object returned by ``loss_model``, unchanged.
        """
        self.counter += 1
        loss = self.loss_model(*args, **kwargs)
        # Read the scalar once and reuse it for both the history and the log.
        loss_value = loss.item()
        self.loss_list.append(loss_value)

        if self.counter % self.print_steps == 0:
            print(f'step: {self.counter}, loss: {loss_value}')
            # print average of last print_steps
            recent = self.loss_list[-self.print_steps:]
            print(f'average loss: {sum(recent) / self.print_steps}')
        return loss


batch_size = 512
dataset = load_dataset("manu/embedding_data_v2_100k", split="train")

# Build one InputExample per dataset row from its two text columns.
train_examples = [
    InputExample(texts=[row["text1"], row["text2"]])
    for row in tqdm(dataset, desc="Building train examples")
]

# Order the pairs by the character length of their second text.
train_examples.sort(key=lambda pair: len(pair.texts[1]))

# Per-example lengths, in the same (sorted) order as train_examples.
lengths = [len(pair.texts[1]) for pair in train_examples]

# evaluator = EmbeddingSimilarityEvaluator.from_input_examples(test_examples, name="sts-dev")
print(f"Number of training examples: {len(train_examples)}")

# The sampler is given the per-example lengths and the same batch size
# as the DataLoader that consumes it.
length_sampler = LengthGroupedSampler(batch_size=batch_size, lengths=lengths)
train_dataloader = DataLoader(
    train_examples,
    batch_size=batch_size,
    sampler=length_sampler,
    num_workers=8,
)

# Pick the dtype used to load the model: bfloat16 when the CUDA device
# supports it, float32 otherwise.
print("Checking if bfloat16 is supported")
# Only query bf16 support when CUDA is actually usable: on some PyTorch
# builds is_bf16_supported() raises instead of returning False when no
# CUDA device is present, which would abort the script before the
# float32 fallback could be chosen.
if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
    print("bfloat16 is supported on this device")
    dtype = torch.bfloat16
else:
    print("bfloat16 is not supported on this device")
    dtype = torch.float32

print("Creating a SentenceTransformer model")

# Options handed to the tokenizer and to the underlying transformer model.
tokenizer_options = {"pad_token_id": 2, "pad_token": "</s>", "add_eos_token": True}
backbone_options = {"torch_dtype": dtype, "device_map": "auto"}

word_embedding_model = models.Transformer(
    "croissantllm/CroissantCool-v0.2",
    max_seq_length=1024,
    tokenizer_args=tokenizer_options,
    model_args=backbone_options,
)

# Every pooling mode is spelled out explicitly; weighted-mean is the only
# one switched on.
pooling_modes = {
    "pooling_mode_mean_tokens": False,
    "pooling_mode_cls_token": False,
    "pooling_mode_max_tokens": False,
    "pooling_mode_lasttoken": False,
    "pooling_mode_mean_sqrt_len_tokens": False,
    "pooling_mode_weightedmean_tokens": True,
}
pooling_model = models.Pooling(
    word_embedding_model.get_word_embedding_dimension(),
    **pooling_modes,
)

# Chain the transformer and the pooling layer into a single model.
model = SentenceTransformer(modules=[word_embedding_model, pooling_model])

print("Defining loss funcs")
# Alternatives that were tried for the inner loss:
# base_loss = CoSENTLoss(model=model)
# base_loss = MultipleNegativesRankingLoss(model=model)
#
# The cached ranking loss is wrapped so that its value is printed
# periodically during training.
base_loss = MultipleNegativesRankingLoss_with_logging(
    CachedMultipleNegativesRankingLoss(model=model, mini_batch_size=4)
)

# loss = MatryoshkaLoss(model=model, loss=base_loss, matryoshka_dims=[2048, 1024, 512, 256, 128, 64])
# The logging-wrapped loss is used directly as the training objective.
loss = base_loss

print("Training the model")

# Tune the model
model.fit(train_objectives=[(train_dataloader, loss)],
          epochs=5,
          warmup_steps=100,
          scheduler="WarmupCosine",
          # use_amp=True,
          # fit() only writes periodic checkpoints when checkpoint_path is
          # given; without it the two checkpoint_save_* options below have
          # no effect, so a crash mid-run would lose all progress.
          checkpoint_path="output/sentence_croissant_v8/checkpoints",
          checkpoint_save_steps=250,
          checkpoint_save_total_limit=1,
          # optimizer_params = {"lr": 1e-5},
          output_path="output/sentence_croissant_v8",
          )

print("Model trained")

# Upload the trained model to the Hugging Face Hub repository.
model.save_to_hub("manu/sentence_croissant_v8")
