import json
import math
import os
import gc

import loguru
from accelerate import Accelerator
from transformers import PreTrainedModel
from accelerate.utils import set_seed
import torch
from torch.utils.data import DataLoader


from src.common.tf.loaders import load_tokenizer, load_model
from src.dataset.loader import DatasetLoader
from src.pipelines.base import BaseStrategy
from src.settings.datasets import DatasetStrategy
from src.constants import DISABLE_LOSS_LABEL

from src.settings.pipelines.inference.offline_metrics import PairOfflineMetricsSettings
from src.dataset.pair_preferences.pair_preference import PairPreferenceDataset
from src.dataset.pair_preferences.collators import PairPreferenceDataCollator


class OfflineMetricsInferenceStrategy(BaseStrategy):
    """Score a pair-preference dataset with a model and dump the log-probs.

    For every record the chosen (``w``) and rejected (``l``) sequences are
    scored by the main model and, when ``sft_settings`` is configured, by a
    second (SFT) model as well. The scores are merged into each record's
    ``meta`` dict and the records are written to ``offline_metrics.jsonl``
    inside ``experiment_settings.save_path``.
    """

    # Order in which metric keys are inserted into a record's ``meta``.
    # Keys for a model that was not loaded are simply skipped.
    _METRIC_KEYS = (
        "logprob_w",
        "norm_logprob_w",
        "sft_logprob_w",
        "sft_norm_logprob_w",
        "logprob_l",
        "norm_logprob_l",
        "sft_logprob_l",
        "sft_norm_logprob_l",
    )

    def read_jsonl(self, path: str) -> list[dict]:
        """Parse a JSON Lines file into a list of objects.

        Blank (whitespace-only) lines are skipped so that a trailing newline
        at the end of the file does not raise ``json.JSONDecodeError``.
        """
        with open(path, 'r', encoding='utf-8') as f:
            return [json.loads(line) for line in f if line.strip()]

    def _get_batch_logps(
        self,
        model: PreTrainedModel,
        batch: dict[str, torch.Tensor],
    ) -> dict[str, torch.Tensor]:
        """Compute per-sequence log-probabilities of the labelled tokens.

        Args:
            model: model called with ``input_ids`` and ``attention_mask``;
                its ``.logits`` output is used.
            batch: mapping with ``input_ids``, ``attention_mask`` and
                ``labels``. Positions whose label equals
                ``DISABLE_LOSS_LABEL`` are excluded from the scores.

        Returns:
            A dict of CPU tensors with one value per sequence:
            ``cum_log_prob`` (sum of token log-probs) and ``avg_log_prob``
            (that sum divided by the number of scored tokens).

        Raises:
            ValueError: if the leading (non-vocabulary) dimensions of the
                logits differ from the shape of ``labels``.
        """
        logits: torch.Tensor = model(batch["input_ids"], attention_mask=batch["attention_mask"]).logits.to(
            torch.float32
        )
        if logits.shape[:-1] != batch["labels"].shape:
            raise ValueError('Logits (batch and sequence length dim) and labels must have the same shape.')

        # Shift by one: the logits at position t are scored against the label at t + 1.
        labels = batch["labels"][:, 1:].clone()
        logits = logits[:, :-1, :]
        loss_mask = labels != DISABLE_LOSS_LABEL

        # ``gather`` needs a valid index everywhere; masked positions get a
        # dummy index 0 and are zeroed out by ``loss_mask`` below.
        labels[~loss_mask] = 0

        per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)

        log_prob = (per_token_logps * loss_mask).sum(-1)
        # A row without any scored token would give 0 / 0 = NaN (serialised
        # as the non-standard JSON token ``NaN``); clamp the divisor so such
        # a row averages to 0.0, matching its cumulative score.
        avg_log_prob = log_prob / loss_mask.sum(-1).clamp(min=1)

        return {
            'cum_log_prob': log_prob.cpu(),
            'avg_log_prob': avg_log_prob.cpu(),
        }

    def _collect_metrics(
        self,
        model: PreTrainedModel,
        batch: dict[str, dict[str, torch.Tensor]],
        prefix: str = "",
    ) -> dict[str, list[float]]:
        """Score both sides of a pair batch and return per-record float lists.

        The result maps ``{prefix}logprob_{side}`` and
        ``{prefix}norm_logprob_{side}`` (``side`` in ``w``, ``l``) to one
        Python float per record of the batch.
        """
        metrics: dict[str, list[float]] = {}
        for side in ("w", "l"):
            logps = self._get_batch_logps(model, batch[f"inputs_{side}"])
            metrics[f"{prefix}logprob_{side}"] = logps["cum_log_prob"].tolist()
            metrics[f"{prefix}norm_logprob_{side}"] = logps["avg_log_prob"].tolist()
        return metrics

    def run(self, experiment_settings: PairOfflineMetricsSettings) -> None:
        """Run offline scoring and write ``offline_metrics.jsonl``.

        Loads the tokenizer, the main model and (optionally) the SFT model,
        scores the first dataset returned by the loader batch by batch, and
        writes every source record, enriched with the scores under ``meta``,
        as one JSON object per line. The ``sft_*`` keys are only written
        when ``sft_settings`` is configured.
        """
        accelerator = Accelerator()
        set_seed(seed=0, device_specific=False)
        experiment_settings.save_path.mkdir(parents=True, exist_ok=True)

        model_inference_settings = experiment_settings.inference_settings
        tokenizer = load_tokenizer(
            model_inference_settings.tokenizer_settings,
            model_inference_settings.model_settings,
        )

        model = load_model(model_inference_settings.model_settings, tokenizer)
        model = accelerator.prepare_model(model, device_placement=True, evaluation_mode=True)
        model.eval()

        sft = None
        if model_inference_settings.sft_settings:
            sft = load_model(model_inference_settings.sft_settings, tokenizer)
            sft = accelerator.prepare_model(sft, device_placement=True, evaluation_mode=True)
            sft.eval()

        dataset = DatasetLoader[PairPreferenceDataset](PairPreferenceDataset).load_datasets(
            experiment_settings.dataset_settings,
            tokenizer=tokenizer,
            strategy=DatasetStrategy.TRAIN,
        )[0]

        batch_size = model_inference_settings.batch
        data_collator = PairPreferenceDataCollator(tokenizer=tokenizer, add_labels=True)

        # shuffle=False: batches are matched back to the raw records by position.
        data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=data_collator)
        data_loader = accelerator.prepare(data_loader)

        original_records = self.read_jsonl(experiment_settings.dataset_settings.sources[0].records_path)
        assert len(dataset) == len(
            original_records
        ), f"dataset size ({len(dataset)}) != number of records ({len(original_records)})"

        num_batches = len(data_loader)
        output_path = os.path.join(experiment_settings.save_path, "offline_metrics.jsonl")
        with open(output_path, "w", encoding="utf-8") as f:
            for i, batch in enumerate(data_loader):
                loguru.logger.info("batch {}/{}", i + 1, num_batches)
                with torch.no_grad():
                    metrics = self._collect_metrics(model, batch)
                    if sft is not None:
                        metrics.update(self._collect_metrics(sft, batch, prefix="sft_"))

                # NOTE(review): this positional slice assumes the current
                # process iterates over every batch in dataset order, which
                # presumably only holds for a single-process launch — confirm
                # before running with several processes.
                original_records_batch = original_records[i * batch_size : (i + 1) * batch_size]
                scored = len(metrics["logprob_w"])
                assert all(len(values) == scored for values in metrics.values())
                assert len(original_records_batch) == scored

                present_keys = [key for key in self._METRIC_KEYS if key in metrics]
                for rec_idx, record in enumerate(original_records_batch):
                    # ``or {}`` also covers records that carry ``"meta": null``.
                    meta = record.get("meta") or {}
                    meta.update({key: metrics[key][rec_idx] for key in present_keys})
                    record["meta"] = meta
                    f.write(json.dumps(record))
                    f.write("\n")

                # Drop per-batch references before asking the allocators to release memory.
                del metrics
                del batch

                gc.collect()
                torch.cuda.empty_cache()
