import os
from itertools import repeat
from typing import Dict, List, Tuple, Optional, Any, Union

from transformers.trainer import Trainer

import torch
from torch.utils.data import DataLoader
import torch.distributed as dist


import sys 
sys.path.append("..")
from losses import SimpleContrastiveLoss, DistributedContrastiveLoss

import logging
logger = logging.getLogger(__name__)

# from ..grad_cache import GradCache
# _grad_cache_available = True

try:
    from grad_cache import GradCache
    _grad_cache_available = True
except ModuleNotFoundError:
    _grad_cache_available = False


class DebugAlignDenseTrainer(Trainer):
    """``Trainer`` subclass for dense (query/passage) retrieval training.

    On top of the stock Hugging Face ``Trainer`` it:

    * optionally holds a ``delta_model`` whose fine-tuned weights are written
      to ``<output_dir>/delta_model`` whenever a checkpoint is saved;
    * divides the loss returned by ``training_step`` by the distributed world
      size when ``args.negatives_x_device`` is enabled (presumably to undo a
      world-size scaling applied inside the cross-device loss so logged values
      stay comparable -- TODO confirm against the model's loss);
    * overrides ``prediction_step`` (added for BEIR evaluation) to return the
      model's output for a ``text`` input instead of a loss.
    """

    def __init__(self, delta_model=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Optional delta/adapter wrapper; only used when saving checkpoints.
        self.delta_model = delta_model
        # dist.get_world_size() requires an initialised process group, so it is
        # only queried when cross-device negatives are actually enabled.
        self._dist_loss_scale_factor = dist.get_world_size() if self.args.negatives_x_device else 1

    def _save(
            self,
            output_dir: Optional[str] = None,
            state_dict: Optional[Dict[str, torch.Tensor]] = None,
    ):
        """Save the model (and the delta model, if any) to ``output_dir``.

        Args:
            output_dir: Target directory; defaults to ``self.args.output_dir``.
            state_dict: Accepted for signature compatibility with
                ``Trainer._save``, which passes it; it is ignored because
                ``self.model.save`` serialises the model itself.
        """
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        logger.info("Saving model checkpoint to %s", output_dir)
        self.model.save(output_dir)

        if self.delta_model is not None:
            delta_dir = os.path.join(output_dir, "delta_model")
            logger.info("Saving delta model to %s", delta_dir)
            self.delta_model.save_finetuned(delta_dir)

    def _prepare_inputs(
            self,
            inputs: Tuple[Dict[str, Union[torch.Tensor, Any]], ...]
    ) -> List[Dict[str, Union[torch.Tensor, Any]]]:
        """Move every element of ``inputs`` onto ``self.args.device``.

        Bare tensors are moved directly. Every other element (presumably a
        dict of tokenizer tensors -- confirm against the collator) is handed to
        the base ``Trainer._prepare_inputs``.

        Returns:
            A list with one prepared element per element of ``inputs``.
        """
        # A plain loop rather than a comprehension: zero-argument super() does
        # not work inside a comprehension scope before Python 3.12.
        prepared = []
        for x in inputs:
            if isinstance(x, torch.Tensor):
                prepared.append(x.to(self.args.device))
            else:
                prepared.append(super()._prepare_inputs(x))
        return prepared

    def get_train_dataloader(self) -> DataLoader:
        """Build the training ``DataLoader`` from ``self.train_dataset``.

        ``drop_last=True`` keeps every batch the same size. Presumably this is
        required by in-batch / cross-device negatives, where ranks gather
        equally sized tensors.

        Raises:
            ValueError: if no ``train_dataset`` was provided.
        """
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")
        train_sampler = self._get_train_sampler()

        return DataLoader(
            self.train_dataset,
            batch_size=self.args.train_batch_size,
            sampler=train_sampler,
            collate_fn=self.data_collator,
            drop_last=True,
            num_workers=self.args.dataloader_num_workers,
        )

    # NOTE: a distillation variant of compute_loss unpacks
    # ``query, passage, distil_scores = inputs`` and forwards
    # ``distil_scores=distil_scores`` to the model.
    def compute_loss(self, model, inputs, return_outputs: bool = False, **kwargs):
        """Return the model's loss for a ``(query, passage)`` input pair.

        Args:
            model: Model called as ``model(query=..., passage=...)``. Its output
                must expose a ``.loss`` attribute.
            inputs: Two-element sequence ``(query, passage)``, as produced by
                ``_prepare_inputs``.
            return_outputs: When true, return ``(loss, outputs)``, following the
                base ``Trainer.compute_loss`` contract.
            **kwargs: Extra keywords that newer ``transformers`` releases pass
                (e.g. ``num_items_in_batch``). They are ignored.
        """
        query, passage = inputs
        outputs = model(query=query, passage=passage)
        return (outputs.loss, outputs) if return_outputs else outputs.loss

    def prediction_step(
        self,
        model,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Run the model on a ``text`` batch without gradients (BEIR evaluation).

        Every value in ``inputs`` is moved to the device. The whole dict is
        then passed as ``model(text=inputs)``.

        Returns:
            ``(None, logits, None)``: no loss and no labels, only the model output.
        """
        # Only the "return model outputs" mode is supported here.
        assert not prediction_loss_only
        assert ignore_keys is None

        inputs = {k: self._prepare_input(v) for k, v in inputs.items()}

        with torch.no_grad():
            logits = model(text=inputs)
        return (None, logits, None)

    def training_step(self, *args, **kwargs):
        """Run the base training step and rescale its loss by ``_dist_loss_scale_factor``.

        Positional and keyword arguments are forwarded unchanged. Newer
        ``transformers`` versions pass an extra ``num_items_in_batch`` argument.
        """
        return super().training_step(*args, **kwargs) / self._dist_loss_scale_factor


def split_dense_inputs(model_input: dict, chunk_size: int):
    """Chunk a single-argument model input along dimension 0.

    ``model_input`` must hold exactly one entry, mapping an argument name to a
    dict of tensors. Every tensor is split into pieces of at most
    ``chunk_size`` rows. Piece *i* of every tensor is regrouped into one dict,
    wrapped under the original argument name. If the tensors yield different
    numbers of pieces, only as many chunks as the shortest field produces are
    returned.

    Returns:
        A list of one-entry dicts ``{arg_name: {field: tensor_chunk, ...}}``.
    """
    assert len(model_input) == 1
    arg_key, arg_val = next(iter(model_input.items()))

    field_names = list(arg_val.keys())
    per_field_pieces = [arg_val[name].split(chunk_size, dim=0) for name in field_names]

    chunks = []
    for piece_group in zip(*per_field_pieces):
        chunks.append({arg_key: dict(zip(field_names, piece_group))})
    return chunks


def get_dense_rep(x):
    """Return ``x.q_reps`` unless it is ``None``, in which case return ``x.p_reps``."""
    return x.p_reps if x.q_reps is None else x.q_reps

