import os
from itertools import repeat
from typing import Dict, List, Tuple, Optional, Any, Union

from transformers.trainer import Trainer
from transformers.trainer_utils import has_length, seed_worker
from transformers.utils import is_datasets_available

from transformers.trainer_pt_utils import ShardSampler

import torch
from torch.utils.data import DataLoader, SequentialSampler
import torch.distributed as dist
import datasets
from loss import SimpleContrastiveLoss, DistributedContrastiveLoss

import logging
logger = logging.getLogger(__name__)

# GradCache is an optional dependency: GCTrainer checks this flag and refuses
# to start when the package cannot be imported.
try:
    from grad_cache import GradCache
except ModuleNotFoundError:
    _grad_cache_available = False
else:
    _grad_cache_available = True


# File name (inside a checkpoint directory) used when saving the training arguments.
TRAINING_ARGS_NAME = "training_args.bin"


class EmbeddingTrainer(Trainer):
    """``Trainer`` subclass for contrastive embedding training.

    On top of the stock ``Trainer`` it adds:
      * an optional ``batch_sampler`` constructor kwarg that drives the training loader,
      * division of the reported training loss by a distributed loss-scale factor
        when negatives are shared across devices,
      * extra scalar metrics (score gaps, positive/negative weights) at logging steps,
      * checkpoint saving through the model's ``save_pretrained_new``.
    """

    def __init__(self, *args, **kwargs):
        # Must be popped before Trainer.__init__, which does not know this kwarg.
        self.batch_sampler = kwargs.pop('batch_sampler', None)
        super(EmbeddingTrainer, self).__init__(*args, **kwargs)
        self._dist_loss_scale_factor = 1.0
        if self.args.negatives_x_device and dist.is_initialized():
            # A non-positive ``loss_scale`` means "use the world size".
            self._dist_loss_scale_factor = (
                dist.get_world_size() if self.args.loss_scale <= 0 else self.args.loss_scale
            )
        logger.info(f"Using loss scale: {self._dist_loss_scale_factor}")
        # NOTE(review): warmup is derived from ``args.max_steps``. When training is
        # configured by epochs (max_steps == -1) and only ``warmup_ratio`` is set,
        # ``get_warmup_steps`` likely yields 0 here -- confirm this is intended for
        # contrastive warmup.
        self._warmup_steps = self.args.get_warmup_steps(self.args.max_steps)
        logger.info(f"Warmup steps: {self._warmup_steps}")

    def get_train_dataloader(self) -> DataLoader:
        """
        Returns the training [`~torch.utils.data.DataLoader`].

        When a ``batch_sampler`` was passed to the constructor, the loader draws
        whole batches from it. No distributed sampler is added here, so batch
        size, ordering and any sharding come entirely from that sampler.
        Without a ``batch_sampler`` this falls back to the default ``Trainer``
        loader. Otherwise a ``DataLoader`` with neither ``batch_sampler`` nor
        ``batch_size`` would silently train with batch size 1 in sequential order.

        Raises:
            ValueError: if no ``train_dataset`` is set.
        """
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")

        if self.batch_sampler is None:
            return super(EmbeddingTrainer, self).get_train_dataloader()

        train_dataset = self.train_dataset
        data_collator = self.data_collator
        if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
            train_dataset = self._remove_unused_columns(train_dataset, description="training")
        else:
            data_collator = self._get_collator_with_removed_columns(data_collator, description="training")

        return DataLoader(
            train_dataset,
            batch_sampler=self.batch_sampler,
            collate_fn=data_collator,
            num_workers=self.args.dataloader_num_workers,
            pin_memory=self.args.dataloader_pin_memory,
            worker_init_fn=seed_worker,
        )

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        """Save the model via ``save_pretrained_new`` and the training arguments.

        Args:
            output_dir: target directory; defaults to ``args.output_dir``.
            state_dict: accepted for ``Trainer`` API compatibility but NOT used.
                ``save_pretrained_new`` is called without it.
        """
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        logger.info(f"Saving model checkpoint to {output_dir}")
        # Only the type is logged: the full module repr of a large model is far too
        # noisy to emit on every checkpoint.
        logger.debug(f"Model type: {type(self.model)}")
        self.model.save_pretrained_new(output_dir)
        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Run the model on ``inputs`` and return its loss.

        Args:
            model: the (possibly wrapped) embedding model; it is called with the
                batch dict as a single positional argument.
            inputs: collated batch dict. It is mutated: an optional
                ``sub_batchsize`` entry is popped and passed to the model as a
                keyword argument instead.
            return_outputs: when True, return ``(loss, outputs)``, as
                ``Trainer.prediction_step`` expects; otherwise return only the loss.
            **kwargs: extra arguments from newer ``Trainer`` versions
                (e.g. ``num_items_in_batch``); accepted and ignored.
        """
        sub_batchsize = inputs.pop('sub_batchsize', None)
        outputs = model(inputs, sub_batchsize=sub_batchsize)
        self.log_scores(outputs.logits)
        self.log_weights(outputs.pos_weights, outputs.neg_weights)
        return (outputs.loss, outputs) if return_outputs else outputs.loss

    def _should_log_batch_metrics(self) -> bool:
        """Return True when the current global step is a multiple of ``logging_steps``."""
        return self.state.global_step % self.args.logging_steps == 0

    @torch.no_grad()
    def log_weights(self, pos_weights, neg_weights):
        """Log the mean positive/negative weights and their mean ratio at logging steps."""
        if not self._should_log_batch_metrics():
            return
        mean_pos_weights = pos_weights.mean()
        mean_neg_weights = neg_weights.mean()
        # Clamp the denominator so zero positive weights do not produce inf/nan.
        mean_neg_pos_weights_ratio = (neg_weights / torch.clamp_min(pos_weights, min=1e-8)).mean()
        self.log({
            'mean_pos_weights': mean_pos_weights.item(), 'mean_neg_weights': mean_neg_weights.item(),
            'mean_neg_pos_weights_ratio': mean_neg_pos_weights_ratio.item()
        })

    @torch.no_grad()
    def log_scores(self, scores):
        """Log score statistics at logging steps.

        Column 0 of the 2-D ``scores`` is treated as the positive score and the
        remaining columns as negatives. This presumably means a
        ``(batch, 1 + n_neg)`` layout; confirm against the model. The gap
        metrics are skipped when there are no negative columns, because a max
        over an empty dimension raises.
        """
        if not self._should_log_batch_metrics():
            return
        pos_scores = scores[:, :1]
        neg_scores = scores[:, 1:]
        metrics = {}
        if neg_scores.size(-1) > 0:
            # keepdim keeps this (batch, 1) so the subtraction is row-wise instead of
            # broadcasting to a (batch, batch) matrix.
            max_neg_scores = neg_scores.max(dim=-1, keepdim=True).values
            metrics['max_neg_gap'] = torch.mean(pos_scores - max_neg_scores).item()
            metrics['mean_neg_gap'] = torch.mean(pos_scores - neg_scores).item()
        metrics['mean_scores'] = torch.mean(scores).item()
        metrics['mean_pos_scores'] = torch.mean(pos_scores).item()
        self.log(metrics)

    def training_step(self, *args, **kwargs):
        """Run ``Trainer.training_step`` and divide the returned loss by the scale factor.

        The division applies to the value returned from the parent step, which is
        what ``Trainer`` accumulates for loss reporting. During contrastive warmup
        (``global_step <= warmup steps``) cross-device negatives are treated as
        disabled and no division is applied. The model presumably scales its loss
        by the same factor; confirm.
        """
        disable_x_device = self.args.contrastive_warmup and (self.state.global_step <= self._warmup_steps)
        negatives_x_device = self.args.negatives_x_device and not disable_x_device
        loss_scale_factor = self._dist_loss_scale_factor if negatives_x_device else 1.0
        return super(EmbeddingTrainer, self).training_step(*args, **kwargs) / loss_scale_factor


def split_dense_inputs(model_input: dict, chunk_size: int):
    """Split a single-argument model input into row chunks (GradCache ``split_input_fn``).

    Args:
        model_input: a dict with exactly one entry ``{arg_key: tensors}``, where
            ``tensors`` maps names to tensors that are split along dim 0.
        chunk_size: number of rows per chunk; the last chunk may be smaller.

    Returns:
        A list of ``{arg_key: {name: chunk}}`` dicts, one per chunk, preserving the
        key order of ``tensors``.

    Raises:
        ValueError: if ``model_input`` does not contain exactly one entry.
    """
    # Explicit check instead of ``assert``: asserts are stripped under ``python -O``,
    # which would silently drop every entry but the first.
    if len(model_input) != 1:
        raise ValueError(
            f"split_dense_inputs expects exactly one input entry, got {len(model_input)}")
    arg_key, arg_val = next(iter(model_input.items()))

    keys = list(arg_val.keys())
    # One tuple of chunks per key; zip(*...) regroups them into per-chunk tuples.
    chunked_tensors = [arg_val[k].split(chunk_size, dim=0) for k in keys]
    return [{arg_key: dict(zip(keys, chunk))} for chunk in zip(*chunked_tensors)]


def get_dense_rep(x):
    """Pick the representation tensor out of a model output (GradCache ``get_rep_fn``).

    Returns ``x.q_reps`` unless it is ``None``, in which case ``x.d_reps`` is
    returned instead.
    """
    query_reps = x.q_reps
    return x.d_reps if query_reps is None else query_reps


class GCTrainer(EmbeddingTrainer):
    """``EmbeddingTrainer`` whose training step is delegated to Gradient Cache.

    Queries and documents are encoded by the same model in chunks of
    ``args.gc_q_chunk_size`` / ``args.gc_d_chunk_size``. The contrastive loss is
    ``DistributedContrastiveLoss`` when ``args.negatives_x_device`` is set, and
    ``SimpleContrastiveLoss`` otherwise.

    Raises:
        ValueError: at construction time if the ``grad_cache`` package is missing.
    """

    def __init__(self, *args, **kwargs):
        logger.info('Initializing Gradient Cache Trainer')
        if not _grad_cache_available:
            raise ValueError(
                'Grad Cache package not available. You can obtain it from https://github.com/luyug/GradCache.')
        super(GCTrainer, self).__init__(*args, **kwargs)

        # The loss function is fixed here, so (unlike EmbeddingTrainer.training_step)
        # there is no contrastive-warmup switch for cross-device negatives.
        loss_fn_cls = DistributedContrastiveLoss if self.args.negatives_x_device else SimpleContrastiveLoss
        loss_fn = loss_fn_cls(temperature=self.args.temperature)

        self.gc = GradCache(
            models=[self.model, self.model],
            chunk_sizes=[self.args.gc_q_chunk_size, self.args.gc_d_chunk_size],
            loss_fn=loss_fn,
            split_input_fn=split_dense_inputs,
            get_rep_fn=get_dense_rep,
            fp16=self.args.fp16,
            # NOTE(review): ``Trainer.scaler`` may not exist in newer transformers
            # releases (mixed precision moved to Accelerate) -- verify the version.
            scaler=self.scaler if self.args.fp16 else None
        )

    def training_step(self, model, inputs, *args, **kwargs) -> torch.Tensor:
        """Run one Gradient Cache step and return the loss divided by the scale factor.

        ``super().training_step`` is not called. The gradient computation is
        presumably done inside ``GradCache.__call__``; confirm against the
        grad_cache docs.

        Args:
            model: the (possibly wrapped) model; it replaces both GradCache models.
            inputs: batch dict containing ``'query'`` and ``'doc'`` entries.
            *args, **kwargs: extra arguments passed by newer ``Trainer`` versions
                (e.g. ``num_items_in_batch``); accepted and ignored.
        """
        model.train()
        inputs = self._prepare_inputs(inputs)
        queries, documents = {'query': inputs['query']}, {'doc': inputs['doc']}

        _distributed = self.args.local_rank > -1
        # Use the model instance the Trainer hands us (it may be wrapped, e.g. for DDP).
        self.gc.models = [model, model]
        loss = self.gc(queries, documents, no_sync_except_last=_distributed)

        return loss / self._dist_loss_scale_factor
