from functools import reduce
from logging import logMultiprocessing
import os
# import sys
# sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
#                                              os.path.pardir,os.path.pardir)))
from fairseq.utils import new_arange
from lm_eval.api.model import LM
from lm_eval import evaluator, tasks, utils
from lm_eval.api.model import CacheHook
from tqdm import tqdm
import torch.nn.functional as F
import math
# from lm_eval.tasks import ALL_TASKS
# from pretrain_gpt import model_provider
import numpy as np
import time
import random
import torch
from megatron import get_args
from megatron import print_rank_0
from megatron import get_tokenizer
from megatron.core.enums import ModelType
from megatron.core import mpu
from megatron.training import setup_model_and_optimizer, get_model
from megatron.core.tensor_parallel.mappings import gather_from_tensor_model_parallel_region

from megatron.utils import get_ltor_masks_and_position_ids, unwrap_model
from megatron.p2p_communication import recv_forward, send_forward
import pickle
import json
from accelerate import (
    Accelerator,
    DistributedType,
    InitProcessGroupKwargs,
    find_executable_batch_size,
)

from lm_eval.models.utils import (
    Collator,
    clear_torch_cache,
    get_dtype,
    pad_and_concat,
    stop_sequences_criteria,
)

from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron.model.distributed import DistributedDataParallel as LocalDDP
from megatron.model.module import Float16Module
# from deepspeed.runtime.pipe import schedule
from deepspeed.accelerator import get_accelerator
class EvalHarnessAdaptor(LM):
    """lm-evaluation-harness ``LM`` adaptor for a Megatron/DeepSpeed masked model.

    Only :meth:`loglikelihood` is implemented, and only for the position-beam
    inference type (``--inftype postion_beam``). Each continuation starts fully
    masked. One gold token is revealed per step, and the ``--beam`` best
    partial reveal orders are kept.
    """

    # ``--inftype`` values that select position-beam scoring. The misspelled
    # form is the one the script has always accepted, so it is kept.
    _POSITION_BEAM_INFTYPES = ("postion_beam", "position_beam")

    def __init__(self, model, tokenizer):
        """Capture the model, tokenizer and evaluation settings.

        Args:
            model: loaded model, called as
                ``model(tokens, padding_mask, index_mask)``.
            tokenizer: Megatron tokenizer exposing ``tokenize``, ``vocab_size``
                and the ``mask``/``bos``/``pad``/``eos`` token ids.

        Raises:
            NotImplementedError: if the data-parallel world size is > 1 while
                ``moe_expert_parallel_size == 1``.
        """
        args = get_args()
        self.args = args
        self.model = model
        self.tokenizer = tokenizer
        self.VOCAB_SIZE = tokenizer.vocab_size
        self.mask_token_id = tokenizer.mask
        self.bos_token_id = tokenizer.bos
        self.pad_token_id = tokenizer.pad
        self.eos_token_id = tokenizer.eos
        self.inftype = args.inftype
        self._max_length = args.seq_length
        self.max_batch_size = 64
        # For ds we split into mini batches and then micro batches to keep pipelining api happy.
        # With Megatron we just go to micro_batches directly
        self._batch_size = args.micro_batch_size
        self._max_iter = args.max_iter
        self.ngram = args.ngram
        # Misspelled alias kept so any existing reader of ``ngarm`` still works.
        self.ngarm = self.ngram
        self.beam = args.beam
        self.cache_hook = CacheHook(None)
        self._rank = 0
        self._world_size = 1
        self.is_main = args.rank == 0
        self.is_local_main = args.local_rank == 0
        self._device = get_accelerator().current_device_name()
        self.is_model_parallel = mpu.get_tensor_model_parallel_world_size() > 1
        self.is_pipe_parallel = mpu.get_pipeline_model_parallel_world_size() > 1
        self.is_data_parallel = mpu.get_data_parallel_world_size() > 1
        self.adaptive_seq_len = args.adaptive_seq_len
        if self.is_data_parallel and args.moe_expert_parallel_size == 1: # For MoE model, allow a "fake data parallel" in order to partition model into multiple gpus
            raise NotImplementedError("Data parallelism is currently not supported for evaluation")

        self.is_last_stage = True if not self.is_pipe_parallel else mpu.is_pipeline_last_stage()  # only the last stage of the pipeline model will receive the logits

    @property
    def max_length(self):
        return self._max_length

    @property
    def batch_size(self):
        return self._batch_size

    @property
    def device(self):
        return self._device

    def loglikelihood_rolling(self, requests):
        """Not supported by this adaptor.

        Raises:
            NotImplementedError: always. This used to stop in ``pdb``, which
                hangs non-interactive and multi-process runs.
        """
        raise NotImplementedError("loglikelihood_rolling is not supported by EvalHarnessAdaptor")

    def loglikelihood(self, requests):
        """Score ``(context, continuation)`` requests with position beam search.

        Args:
            requests: lm-eval requests whose ``args`` are
                ``(context, continuation)`` string pairs.

        Returns:
            One ``(log_likelihood, is_greedy)`` tuple per request.
            ``is_greedy`` is always ``False``.

        Raises:
            NotImplementedError: if ``self.inftype`` is not a position-beam
                inference type. Previously this silently returned ``None``.
        """
        if self.inftype not in self._POSITION_BEAM_INFTYPES:
            raise NotImplementedError(
                f"loglikelihood only supports inftype 'postion_beam', got {self.inftype!r}"
            )
        result = []
        for context, continuation in tqdm([req.args for req in requests]):
            result.append((self._position_beam_score(context, continuation), False))
        return result

    def _position_beam_score(self, source, target):
        """Return the position-beam log-likelihood of ``target`` after ``source``.

        The continuation starts fully masked. At each step, every hypothesis
        proposes revealing one still-masked position, filled with its gold
        token. Positions are chosen in order of the model's probability for
        that gold token, at most ``self.beam`` per hypothesis. Hypotheses are
        ranked by the sum of their revealed-token probabilities, and the best
        ``self.beam`` survive. When every continuation token is revealed, each
        hypothesis's log-probabilities over the continuation span are summed,
        and the maximum is returned.

        Raises:
            ValueError: if the continuation alone does not fit in the maximum
                sequence length.
        """
        tokenizer = self.tokenizer
        beam = self.beam
        max_seq_length = self._max_length

        tokenized_ids_source = tokenizer.tokenize(source)
        # Continuation ids are sliced out of the joint tokenization rather
        # than tokenizing ``target`` on its own.
        tokenized_ids_all = tokenizer.tokenize(source + target)
        tokenized_ids_target = tokenized_ids_all[len(tokenized_ids_source):]
        target_len = len(tokenized_ids_target)
        if target_len == 0:
            # Nothing to predict: log(1) == 0. Previously this raised NameError
            # because the confidence tensor was never created.
            return 0.0

        # Sequence layout: [BOS] + context + continuation + padding.
        num_tokens = 1 + len(tokenized_ids_source) + target_len
        if num_tokens > max_seq_length:
            # Left-truncate the context so BOS + context + continuation fill
            # exactly max_seq_length. The old slice dropped target_len + 1
            # tokens instead, which broke the length assertion below.
            keep = max_seq_length - target_len - 1
            if keep < 0:
                raise ValueError(
                    f"continuation of {target_len} tokens does not fit in "
                    f"max_seq_length={max_seq_length}"
                )
            tokenized_ids_source = tokenized_ids_source[len(tokenized_ids_source) - keep:]
            num_tokens = max_seq_length

        pad_list = [tokenizer.pad] * (max_seq_length - num_tokens)
        input_ids = [tokenizer.bos] + tokenized_ids_source + [tokenizer.mask] * target_len + pad_list
        label_ids = [tokenizer.bos] + tokenized_ids_source + tokenized_ids_target + pad_list
        assert len(input_ids) == max_seq_length
        tokens = torch.tensor([input_ids]).to(self._device)
        labels = torch.tensor([label_ids]).to(self._device)

        # Third model input, presumably an attention mask (TODO confirm):
        # BOS + context rows see only BOS + context; continuation rows see
        # every non-padding position.
        attention_index = len(tokenized_ids_source) + 1
        index_mask = np.zeros((max_seq_length, max_seq_length))
        index_mask[:attention_index, :attention_index] = 1
        index_mask[attention_index:num_tokens, :num_tokens] = 1
        index_mask = torch.tensor(index_mask).unsqueeze(0).to(self._device)

        gold = labels[0]
        gold_index = gold.unsqueeze(-1)
        current_sample = tokens.clone()
        output_confidence = None
        for _ in range(target_len):
            current_token_mask = current_sample.eq(tokenizer.mask)
            current_padding_mask = current_sample.ne(tokenizer.pad).type_as(current_sample)
            current_index_mask = index_mask.clone().repeat(current_sample.size(0), 1, 1)
            # Inference only: don't build an autograd graph for the forward pass.
            with torch.no_grad():
                logits = self.model(current_sample, current_padding_mask, current_index_mask)[0]
            output_tensors = F.softmax(logits.detach(), dim=-1)
            if output_confidence is None:
                # Per-position probability of each gold token once revealed.
                output_confidence = torch.zeros_like(current_sample).type_as(output_tensors)

            candidates = []  # (ids, confidence, summed confidence) per expansion
            for hyp in range(output_tensors.size(0)):
                still_masked = current_token_mask[hyp]
                gold_probs = torch.gather(output_tensors[hyp], 1, gold_index).squeeze(-1)
                gold_probs = gold_probs.masked_fill(~still_masked, 0.0)
                ranked_positions = torch.sort(gold_probs, descending=True)[1]
                n_expand = min(beam, int(still_masked.sum()))
                for pos in ranked_positions[:n_expand]:
                    new_ids = current_sample[hyp].clone()
                    new_conf = output_confidence[hyp].clone()
                    new_ids[pos] = gold[pos]
                    new_conf[pos] = gold_probs[pos]
                    candidates.append((new_ids, new_conf, sum(new_conf.tolist())))

            # sorted() is stable, so ties keep their expansion order.
            kept = sorted(candidates, key=lambda c: c[2], reverse=True)[:beam]
            current_sample = torch.stack([c[0] for c in kept])
            output_confidence = torch.stack([c[1] for c in kept])

        window = output_confidence[:, attention_index:attention_index + target_len]
        return max(torch.log(window).sum(1).tolist())

    def generate_until(self, requests):
        """Not supported by this adaptor.

        Raises:
            NotImplementedError: always. This used to stop in ``pdb``, which
                hangs non-interactive and multi-process runs.
        """
        raise NotImplementedError("generate_until is not supported by EvalHarnessAdaptor")


def forward_sample(model, tokenizer, decoder_option, labels, padding_mask, index_mask):
    """Perform one iterative-unmasking step and record it in ``decoder_option``.

    Still-masked positions of ``decoder_option["current_ids"]`` are scored by
    the model's log-probability of the gold token in ``labels``. The
    highest-scoring ones are replaced by their gold token, at most
    ``decoder_option["mask_len_each_mask"]`` per row, and their scores are
    written into the confidence tensor.

    Args:
        model: called as ``model(ids, padding_mask, index_mask)``. Element
            ``[0]`` of the output is log-softmaxed over its last dimension.
        tokenizer: provides the ``mask`` token id.
        decoder_option: dict holding ``"current_step"``, ``"max_step"``,
            ``"current_ids"`` and ``"current_confidence"``. On step 0 it also
            gains ``"mask_len_each_mask"`` = ceil(#masked per row / max_step).
        labels: gold token ids, indexed like ``current_ids``. Presumably
            (batch, seq); TODO confirm.
        padding_mask, index_mask: forwarded unchanged to ``model``.

    Returns:
        ``decoder_option`` itself, updated with the new ids and confidences.
        The tensor previously stored under ``"current_ids"`` is also modified
        in place.
    """
    step = decoder_option["current_step"]
    max_step = decoder_option["max_step"]
    ids = decoder_option["current_ids"]
    confidence = decoder_option["current_confidence"]
    masked = ids.eq(tokenizer.mask)
    masked_count = masked.sum(1)

    if step == 0:
        # First step: reset confidences and fix the per-step reveal budget so
        # that max_step steps of this size cover every masked position.
        confidence = torch.zeros_like(ids).type_as(ids)
        decoder_option["mask_len_each_mask"] = torch.ceil(masked_count / max_step)

    # Never reveal more tokens in a row than are still masked there.
    budget = torch.min(decoder_option["mask_len_each_mask"], masked_count)

    logits = model(ids, padding_mask, index_mask)[0].detach()
    gold_scores = F.log_softmax(logits, dim=-1).gather(2, labels.unsqueeze(-1)).squeeze(-1)
    # Already-revealed positions sink to the bottom of the ranking.
    gold_scores.masked_fill_(~masked, float("-inf"))
    confidence = confidence.type_as(gold_scores)
    reveal = _skeptical_unmasking(gold_scores, masked, budget)

    ids.masked_scatter_(reveal, labels[reveal])
    confidence.masked_scatter_(reveal, gold_scores[reveal])
    decoder_option["current_ids"] = ids.detach()
    decoder_option["current_confidence"] = confidence.detach()
    return decoder_option

def _skeptical_unmasking(output_scores, output_masks, n):
    sorted_index = output_scores.sort(-1, descending=True)[1]
    boundary_len = n.unsqueeze(-1).long()
    skeptical_mask = new_arange(output_masks) < boundary_len
    return skeptical_mask.scatter(1, sorted_index, skeptical_mask)

import deepspeed
from megatron.initialize import initialize_megatron
import megatron
from megatron.training import get_model
from megatron.arguments import core_transformer_config_from_args
from megatron.model import GeBertModel
from megatron.initialize import set_jit_fusion_options
from megatron.checkpointing import load_checkpoint

def model_provider(pre_process=True, post_process=False):
    """Build the GeBERT model handed to ``get_model``.

    Args:
        pre_process: forwarded to ``GeBertModel``.
        post_process: accepted to match the provider signature but ignored.
            The model is always built with ``post_process=True``.
            NOTE(review): confirm this is intended under pipeline parallelism.

    Returns:
        A ``GeBertModel`` configured from the global Megatron arguments, with
        ``num_tokentypes=0`` and ``parallel_output=True``.
    """
    print_rank_0('building GEBERT model ...')
    transformer_config = core_transformer_config_from_args(get_args())
    return GeBertModel(
        config=transformer_config,
        num_tokentypes=0,
        parallel_output=True,
        pre_process=pre_process,
        post_process=True,  # always True; the post_process argument is not used
    )

def tasks_args(parser):
    """Register the evaluation command-line options on ``parser``.

    Intended as the ``extra_args_provider`` for ``initialize_megatron``. All
    options are added to an "Evaluation options" argument group.

    Args:
        parser: an ``argparse.ArgumentParser``.

    Returns:
        The same ``parser``, with the options added.
    """
    option_specs = (
        ('--task_list', dict(type=str, default="all",
                             help='Either "all" or comma separated list of tasks.')),
        ('--results_path', dict(type=str, default="./results.json",
                                help='Path to where the results will be stored.')),
        ('--num_fewshot', dict(type=int, default=0, help='Number of few-shot prompts.')),
        ('--sentence-split-type', dict(type=str, default='direct',
                                       help='How to split the source and target.')),
        ('--masked-type', dict(type=str, default='lazy_uniform', help='The masking type.')),
        ('--adaptive_seq_len', dict(default=False, action='store_true',
                                    help='Should the sequence length be adapted to the batch during evaluation, '
                                         'if in fp16 the results will be slightly different due to numerical '
                                         'errors but greatly speed up evaluation.')),
        ('--eval_fp32', dict(default=False, action='store_true',
                             help='Should the evaluation run in fp32')),
        ('--inftype', dict(type=str, default='diff_gram_based', help='')),
        ('--max-iter', dict(type=int, default=1)),
        ('--length-predict', dict(action='store_true', help='')),
        ('--ngram', dict(type=int, default=1)),
        ('--beam', dict(type=int, default=1)),
    )
    group = parser.add_argument_group(title='Evaluation options')
    for flag, options in option_specs:
        group.add_argument(flag, **options)
    return parser


def main():
    """Load the checkpointed model and run lm-evaluation-harness tasks.

    Steps:
    1. Initialize Megatron with the extra evaluation options from ``tasks_args``.
    2. Build the model, wrapping it in DeepSpeed when ``--deepspeed`` is set,
       and load the checkpoint.
    3. Evaluate the comma-separated ``--task_list`` tasks and print the
       ``results`` and ``n-shot`` sections.
    4. On rank 0, write both sections as JSON to ``--results_path``, which was
       previously accepted but never used.
    """
    # NOTE(review): fp16 is hard-coded on here; ``--eval_fp32`` is not consulted.
    ds_dict = {
        "train_batch_size" : 1,
        "train_micro_batch_size_per_gpu": 1,
        "steps_per_print": 10,

        "zero_optimization": {
        "stage": 0
        },

        "gradient_clipping": 1.0,
        "prescale_gradients": False,

        "fp16": {
        "enabled": True,
        "loss_scale": 0,
        "loss_scale_window": 500,
        "hysteresis": 2,
        "min_loss_scale": 1,
        "initial_scale_power": 11
        },

        "bf16": {
        "enabled": False
        },

        "wall_clock_breakdown" : False
        }

    device = torch.device("cuda")
    initialize_megatron(extra_args_provider=tasks_args, args_defaults={'tokenizer_type': 'HFTokenizer'})

    args = get_args()

    tokenizer = get_tokenizer()
    model = get_model(model_provider)
    if args.deepspeed:
        if get_accelerator().device_name() == 'cuda':
            set_jit_fusion_options()
        args.deepspeed_config_dict = ds_dict
        # Only the engine is used; no optimizer or LR scheduler is passed in.
        model, _, _, _ = deepspeed.initialize(
                    model=model[0],
                    optimizer=None,
                    args=args,
                    lr_scheduler=None,
                    mpu=None,
                    config=args.deepspeed_config_dict,
                )
        _ = load_checkpoint([model], None, None, strict=True, load_only_weights=False)
    else:
        _ = load_checkpoint(model, None, None, strict=False, load_only_weights=False)
        assert len(model) == 1, "Above condition should have caught this"
        model = model[0].to(device)

    model.eval()
    task_list = args.task_list.split(',')
    task_dict = tasks.get_task_dict(task_list)
    adaptor = EvalHarnessAdaptor(model, tokenizer)

    results = evaluator.evaluate(adaptor, task_dict)
    print(results["results"])
    print(results["n-shot"])

    if args.rank == 0:
        summary = {"results": results["results"], "n-shot": results["n-shot"]}
        with open(args.results_path, "w", encoding="utf-8") as f:
            # default=str: stringify any non-JSON value (e.g. a numpy scalar)
            # instead of failing after the whole evaluation has finished.
            json.dump(summary, f, indent=2, default=str)



if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not on import.
    main()
