#  ------------------------------------------------------------------------------------------
#  Copyright (c) Microsoft Corporation. All rights reserved.
#  Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
#  ------------------------------------------------------------------------------------------
import argparse
import time
import math
import os, sys
import json
import itertools
import functools
import collections
from typing import Callable, Dict, Iterable, List, Optional, Tuple

import torch
from torch import Tensor, device, dtype, nn
from torch.nn import CrossEntropyLoss
from torch.nn import functional as F
from torch.utils.data import DataLoader
import torch.nn.functional as F
# Raise the summarization threshold so printed tensors (up to 100k elements) are shown in full.
torch.set_printoptions(threshold=100_000)

import numpy as np

from gpu import (
    add_gpu_params, 
    parse_gpu, 
    distributed_opt, 
    distributed_gather, 
    distributed_sync, 
    cleanup
)

from exp_utils import create_exp_dir

from data_utils import FT_Dataset 
from model import GPT2Config, GPT2LMModel


parser = argparse.ArgumentParser(description='PyTorch GPT2 beam decoding')

add_gpu_params(parser)

parser.add_argument('--data', type=str, default='../data/wikitext-103',
                    help='location of the data corpus')

parser.add_argument('--batch_size', type=int, default=10,
                    help='batch size')

parser.add_argument('--seq_len', type=int, default=512,
                    help='number of tokens to predict')

parser.add_argument('--eval_len', type=int, default=256,
                    help='evaluation length')

parser.add_argument('--min_length', type=int, default=0,
                    help='minimum generation length')

parser.add_argument('--model_card', default='gpt2.sm', choices=['gpt2.sm', 'gpt2.md', 'gpt2.lg'],
                    help='model names')

parser.add_argument('--init_checkpoint', default=None, type=str, help='initial checkpoint')

parser.add_argument('--lora_dim', type=int, default=0, help='lora attn dimension')

parser.add_argument('--lora_alpha', type=int, default=128, help='lora attn alpha')

parser.add_argument('--work_dir', type=str, default=os.getenv('PT_OUTPUT_DIR', 'gpt2_model'),
                    help='working folder')

# Nucleus-sampling threshold is a probability, so it must parse as a float;
# with type=int any command-line value such as "0.95" was rejected.
parser.add_argument('--top_p', type=float, default=0.9,
                    help='top-p (nucleus) sampling cumulative probability threshold')

parser.add_argument('--length_penalty', type=float, default=1.0, help='length penalty')

parser.add_argument('--no_repeat_ngram_size', type=int, default=4, help='no_repeat_ngram_size')

# NOTE: with action='append' and a non-empty default, argparse appends
# command-line values after the default elements, so 50256 is always
# included in the EOS id list.
parser.add_argument('--eos_token_id', action='append', type=int, default=[50256],
                    help='eos token id')

parser.add_argument('--output_file', type=str, default='top_p_prediction.jsonl',
                    help='output file name')


def print_args(args):
    """Print every parsed argument as ``- name : value``, framed by separator lines.

    Only the process with ``args.rank == 0`` prints; other ranks do nothing.
    """
    if args.rank != 0:
        return
    separator = '=' * 100
    lines = [separator]
    lines.extend('        - {} : {}'.format(name, value) for name, value in vars(args).items())
    lines.append(separator)
    print('\n'.join(lines))


def _calc_banned_ngram_tokens(
    prev_input_ids: Tensor, 
    num_hypos: int, 
    no_repeat_ngram_size: int, 
    cur_len: int
) -> None:
    """Copied from fairseq for no_repeat_ngram in beam_search"""
    if cur_len + 1 < no_repeat_ngram_size:
        # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
        return [[] for _ in range(num_hypos)]
    
    generated_ngrams = [{} for _ in range(num_hypos)]
    for idx in range(num_hypos):
        gen_tokens = prev_input_ids[idx].tolist()
        generated_ngram = generated_ngrams[idx]
        for ngram in zip(*[gen_tokens[i:] for i in range(no_repeat_ngram_size)]):
            prev_ngram_tuple = tuple(ngram[:-1])
            generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]

    def _get_generated_ngrams(hypo_idx):
        # Before decoding the next token, prevent decoding of ngrams that have already appeared
        start_idx = cur_len + 1 - no_repeat_ngram_size
        ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx:cur_len].tolist())
        return generated_ngrams[hypo_idx].get(ngram_idx, [])

    banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)]
    return banned_tokens


def _postprocess_next_token_scores(
    scores,
    history,
    cur_len,
    batch_size,                            
    no_repeat_ngram_size=4,
    min_length=0,
    eos_token_id=None
):

    # score: batch_size * beam, vocab
    # set eos token prob to zero if min_length is not reached
    if eos_token_id is not None and cur_len < min_length:
        for eos in eos_token_id:
            scores[:, eos] = -float("inf")

    if no_repeat_ngram_size > 0 and history is not None:
        # from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
        banned_batch_tokens = _calc_banned_ngram_tokens(
                history, batch_size, no_repeat_ngram_size, cur_len
        )

        for i, banned_tokens in enumerate(banned_batch_tokens):
            scores[i, banned_tokens] = -float("inf")

    return scores


def _nucleus_filter(logits, top_p_threshold):
    """Return ``logits`` with every token outside the top-p nucleus set to -inf.

    Tokens are ranked by logit. The kept set is the smallest prefix of that
    ranking whose cumulative probability exceeds ``top_p_threshold``, so the
    highest-scoring token is always kept and sampling never sees an all -inf row
    (which would make softmax return NaN and ``torch.multinomial`` raise).
    """
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

    # Remove tokens whose cumulative probability was already above the
    # threshold before they were added: shift the mask right by one so the
    # token that crosses the threshold (and always the top-1 token) survives.
    sorted_indices_to_remove = cumulative_probs > top_p_threshold
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = False

    sorted_logits[sorted_indices_to_remove] = -float("inf")
    # Undo the sort: put each filtered logit back at its original vocab position.
    return torch.gather(sorted_logits, -1, sorted_indices.argsort(-1))


def top_p(model, data_iter, args):
    """Generate ``args.eval_len`` tokens per example with top-p sampling and save them.

    For each batch from ``data_iter`` (dicts with ``'id'`` and ``'query'``
    tensors), tokens are sampled one at a time. Each sampled token is appended
    to the query so the next step conditions on it. There is no KV cache, so the
    whole growing sequence is re-run every step. Results are gathered across
    ranks. On rank 0 one JSON line ``{"id": ..., "predict": [...]}`` per example
    is written to ``os.path.join(args.work_dir, args.output_file)``.

    Args:
        model: callable returning ``(logits, _)``; ``logits[:, -1, :]`` is used
            as the next-token distribution (assumed shape (batch, seq, vocab) —
            confirm against the model).
        data_iter: iterable of batches as described above.
        args: parsed arguments; reads device, eval_len, no_repeat_ngram_size,
            min_length, eos_token_id, top_p, rank, work_dir and output_file.
    """
    model.eval()

    all_predictions = {}
    with torch.no_grad():
        for idx, data in enumerate(data_iter):
            _id = data['id'].to(args.device)
            _query = data['query'].to(args.device)

            batch_size = _id.size(0)

            best_sequence = torch.zeros(
                (batch_size, args.eval_len), dtype=torch.long, device=_query.device
            )

            for i in range(args.eval_len):
                logits, _ = model(_query)
                logits = logits[:, -1, :]

                # n-gram blocking only looks at generated tokens, not the prompt.
                logits = _postprocess_next_token_scores(
                    logits,
                    best_sequence[:, :i],
                    i,
                    batch_size,
                    no_repeat_ngram_size=args.no_repeat_ngram_size,
                    min_length=args.min_length,
                    eos_token_id=args.eos_token_id,
                )
                logits = _nucleus_filter(logits, args.top_p)

                pred_token = torch.multinomial(F.softmax(logits, -1), 1)
                best_sequence[:, [i]] = pred_token
                # Feed the sampled token back in; previously the query never
                # changed, so every position was sampled from the same context.
                _query = torch.cat([_query, pred_token], dim=1)

            _id = distributed_gather(args, _id)
            output = distributed_gather(args, best_sequence)
            distributed_sync(args)

            if args.rank == 0:
                _id = _id.view(-1).cpu()
                output = output.view(-1, output.shape[-1]).cpu()

                for _b in range(_id.shape[-1]):
                    _i = int(_id[_b].item())
                    all_predictions[_i] = {
                        'id': _i,
                        'predict': output[_b].tolist(),
                    }

                if idx % 10 == 0:
                    print('inference samples', idx)

    if args.rank == 0:
        pred_file = os.path.join(args.work_dir, args.output_file)
        print('saving prediction file', pred_file)
        with open(pred_file, 'w') as writer:
            for prediction in all_predictions.values():
                writer.write(json.dumps(prediction) + '\n')
    

if __name__ == '__main__':

    # Uncomment to fake a single-process launch environment for local debugging.
    # os.environ['MASTER_ADDR'] = 'localhost'
    # os.environ['MASTER_PORT'] = '49152'
    # os.environ['OMPI_COMM_WORLD_LOCAL_RANK'] = '0'
    # os.environ['OMPI_COMM_WORLD_SIZE'] = '1'
    # os.environ['OMPI_COMM_WORLD_RANK'] = '0'
    # os.environ['CUDA_VISIBLE_DEVICES'] = '1'

    args = parser.parse_args()
    parse_gpu(args)
    print_args(args)

    if args.rank == 0:
        args.logging = create_exp_dir(args.work_dir)

    valid_data = FT_Dataset(
        args.data, args.batch_size, args.seq_len, args.eval_len,
    )
    # Shards the evaluation set across ranks (DistributedSampler may repeat a
    # few samples so every rank gets the same number).
    valid_sampler = torch.utils.data.distributed.DistributedSampler(valid_data)
    valid_loader = DataLoader(
        valid_data, batch_size=args.batch_size, num_workers=0, shuffle=False,
        pin_memory=False, drop_last=False, sampler=valid_sampler
    )

    enable_lora_attn = [True, True, True, True]  # Q, K, V, out_proj
    enable_lora_mlp = True
    enable_lora_head = False

    # Architecture sizes per model card; --model_card is restricted to these keys.
    model_dims = {
        'gpt2.sm': dict(n_embd=768, n_layer=12, n_head=12),
        'gpt2.md': dict(n_embd=1024, n_layer=24, n_head=16),
        'gpt2.lg': dict(n_embd=1280, n_layer=36, n_head=20),
    }
    config = GPT2Config(
        **model_dims[args.model_card],
        lora_dim=args.lora_dim,
        lora_alpha=args.lora_alpha,
        # This script's parser defines no --lora_dropout, so reading
        # args.lora_dropout directly raised AttributeError. Decoding runs after
        # model.eval(), so the rate presumably has no effect here.
        lora_dropout=getattr(args, 'lora_dropout', 0.0),
        enable_lora_attn=enable_lora_attn,
        enable_lora_mlp=enable_lora_mlp,
        enable_lora_head=enable_lora_head
    )

    lm_net = GPT2LMModel(config)
    if args.init_checkpoint is not None:
        print('loading model pretrained weight.')
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        cp = torch.load(args.init_checkpoint, map_location=torch.device('cpu'))

        # Checkpoints saved from a DDP-wrapped model prefix keys with 'module.';
        # detect that from the first key and strip the first path component.
        state_dict = cp['model_state_dict']
        first_key = next(iter(state_dict), '')
        if first_key.split('.')[0] == 'module':
            cp['model_state_dict'] = collections.OrderedDict(
                ('.'.join(k.split('.')[1:]), v) for k, v in state_dict.items()
            )

        # Adaptive ranks: resize every LoRA module to the rank stored in the
        # checkpoint (width of lora_B), or to rank 0 when it has no lora_B entry.
        for n, m in lm_net.named_modules():
            if hasattr(m, 'lora_B'):
                try:
                    # In case of a merged linear layer, lora_A can be a multiple of the lora dim.
                    lora_b = cp['model_state_dict'][n + '.lora_B']
                    if lora_b.shape[1] != args.lora_dim:
                        m.change_lora_rank(lora_b.shape[1])
                except KeyError:  # no LoRA weights saved for this module: assume rank 0
                    m.change_lora_rank(0)

        lm_net.load_weight(cp)
        del cp

    lm_net = lm_net.cuda()

    print('model sampling ...')
    top_p(lm_net, valid_loader, args)
    distributed_sync(args)
    print('cleanup dist ...')
    cleanup(args)
