import copy 
import json 
import os 
import numpy as np 
import torch 
from torch .utils .data import Dataset ,DataLoader 
import torch .distributed as torch_distributed 
from transformers .tokenization_utils import PreTrainedTokenizer 
from pathlib import Path 
from llama_recipes .utils .distributed import print_rank_0 
from megatron_lm .megatron .global_vars import get_args ,set_sampler 
from torch .utils .data ._utils .collate import default_collate 

class InstructDataset(Dataset):
    """Map-style dataset over a JSONL instruction-tuning file.

    Records are located through an index file holding one integer seek
    offset per line; the index is expected at
    ``<data dir>/.index_cache/<data file name with .jsonl -> .idx>``.

    Two record layouts are handled by ``__getitem__``:

    * ``{"role": "next_token_prediction", "content": <text>}`` -- the text is
      encoded and an ``<|end_of_text|>`` id is appended; only the leading
      BOS position is excluded from the loss.
    * ``{"input": [<message>, ...], "output": <message>}`` -- rendered with
      the tokenizer's chat template (a system message built from the global
      args is prepended); every token of the rendered input is excluded
      from the loss.

    Every sample is truncated / right-padded to ``max_tokens``.  Token-length
    statistics are accumulated on the instance as samples are fetched and can
    be shown with ``print_length_bins``.
    """

    # Label value used for positions that must not contribute to the loss.
    IGNORE_INDEX = -100

    def __init__(
        self,
        tokenizer: "PreTrainedTokenizer",
        data_path: str,
    ) -> None:
        """Read global args and load the offset index for ``data_path``.

        Args:
            tokenizer: Tokenizer used to encode every sample.
            data_path: Path to the JSONL data file.

        Raises:
            SystemExit: If the index file cannot be read or parsed (the
                error is printed first).
        """
        args = get_args()
        self.data_path: str = data_path
        self.max_tokens = min(args.seq_length, tokenizer.model_max_length)
        self.tokenizer = tokenizer
        self.debug_mode = args.instruct_debug

        self.system_prompt_role = args.system_prompt_role
        self.system_prompt_content = args.system_prompt_content

        dataset_dir = Path(self.data_path).parent
        index_cache_dir = dataset_dir / ".index_cache"
        os.makedirs(index_cache_dir, exist_ok=True)
        index_file_path = index_cache_dir / str(os.path.basename(self.data_path)).replace(".jsonl", ".idx")
        self.index_file_path: str = str(index_file_path)
        try:
            with open(self.index_file_path, "r", encoding="utf-8") as f:
                # Blank lines (e.g. a trailing newline) carry no offset.
                self.indexes: list[int] = [int(line.strip()) for line in f if line.strip()]
        except Exception as e:
            print(f"index file error: {e}")
            # Same exception the ``exit`` builtin raises, without depending
            # on the ``site`` module having injected that builtin.
            raise SystemExit(1) from e

        # Running token-length statistics, updated by every __getitem__ call.
        self.length_bins = [0] * 8
        self.num_exceed_max_tokens = 0
        self.total_length: int = 0
        self.min_length: int = 10**9
        self.max_length: int = 0

    def __len__(self) -> int:
        """Return the number of indexed records."""
        return len(self.indexes)

    def _read_record(self, index: int):
        """Return ``(offset, raw_line, parsed_json)`` for record ``index``.

        Raises:
            SystemExit: If the line cannot be read or is not valid JSON (the
                error is printed first).
        """
        with open(self.data_path, "r", encoding="utf-8") as file:
            offset: int = self.indexes[index]
            file.seek(offset)
            try:
                line = file.readline()
            except Exception as e:
                print(f"index={index}, offset={offset}, error={e}")
                raise SystemExit(1) from e
        try:
            conversations = json.loads(line)
        except Exception as e:
            print(f"index={index}, offset={offset}, line={line}, error={e}")
            raise SystemExit(1) from e
        return offset, line, conversations

    def _tokenize(self, conversations):
        """Return ``(prompt_ids, example_ids)`` for one parsed record.

        ``prompt_ids`` is only used for its length: that many leading
        positions of ``example_ids`` are excluded from the loss.
        """
        if "role" in conversations and conversations["role"] == "next_token_prediction":
            # The end-of-document id is only needed on this branch.
            eod_token_id: int = self.tokenizer.encode("<|end_of_text|>", add_special_tokens=False)[0]
            prompt = [self.tokenizer.bos_token_id]
            example = self.tokenizer.encode(conversations["content"], add_special_tokens=True)
            example += [eod_token_id]
            return prompt, example

        system_prompt: list[dict[str, str]] = [
            {
                "role": self.system_prompt_role,
                "content": self.system_prompt_content,
            }
        ]
        prompt = self.tokenizer.apply_chat_template(
            conversation=system_prompt + conversations["input"],
            tokenize=True,
        )
        example = self.tokenizer.apply_chat_template(
            conversation=system_prompt + conversations["input"] + [conversations["output"]],
            tokenize=True,
        )
        return prompt, example

    def _record_length(self, length: int) -> None:
        """Fold one untruncated token length into the running statistics."""
        self.total_length += length
        self.min_length = min(self.min_length, length)
        self.max_length = max(self.max_length, length)

        # Bin i holds lengths in (i * max_tokens / n, (i + 1) * max_tokens / n],
        # i.e. ceil(n * length / max_tokens) - 1 in exact integer arithmetic.
        # These are the inclusive ranges print_length_bins reports.  Lengths
        # above max_tokens are clamped into the last bin (and additionally
        # counted in num_exceed_max_tokens); length 0 is clamped into bin 0.
        num_bins = len(self.length_bins)
        bin_index = (num_bins * length - 1) // self.max_tokens
        bin_index = min(max(bin_index, 0), num_bins - 1)
        self.length_bins[bin_index] += 1

        if length > self.max_tokens:
            self.num_exceed_max_tokens += 1

    def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
        """Load, tokenize, truncate/pad and mask the record at ``index``.

        Returns:
            A dict with ``input_ids`` and ``labels`` (int64, length
            ``max_tokens``), ``attention_mask`` (float, 1 for real tokens and
            0 for padding) and ``__meta__`` (plain dict with ``index``,
            ``offset``, ``data_path`` and, when the environment variable
            ``DEBUG_DUMP_RAW`` is a non-zero integer, the raw line under
            ``raw``).  If no position is left to train on, a randomly chosen
            other record is returned instead.

        Raises:
            SystemExit: If the record cannot be read or parsed.
        """
        offset, line, conversations = self._read_record(index)
        prompt, example = self._tokenize(conversations)

        length_of_example = len(example)
        self._record_length(length_of_example)

        if length_of_example > self.max_tokens:
            if self.debug_mode:
                print_rank_0(
                    f"[DEBUG] Example length={length_of_example}, which exceeds max_tokens={self.max_tokens}.\n"
                )
            example = example[: self.max_tokens]

        # Number of genuine (non-padding) tokens after truncation.
        num_real_tokens = len(example)
        tensor_example = torch.tensor(example, dtype=torch.int64)

        padding_length: int = self.max_tokens - num_real_tokens
        pad_token_id: int = (
            self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
        )
        if padding_length > 0:
            pad_tensor = torch.full((padding_length,), pad_token_id, dtype=torch.int64)
            tensor_example = torch.cat((tensor_example, pad_tensor))

        labels = tensor_example.clone()

        # Never train on the prompt part.
        prompt_len = min(len(prompt), len(labels))
        labels[:prompt_len] = self.IGNORE_INDEX

        # Mask padding by position rather than by token value: pad_token_id
        # may be the EOS id (fallback above), and genuine EOS tokens inside
        # the example must stay in both the loss and the attention mask.
        labels[num_real_tokens:] = self.IGNORE_INDEX
        attention_mask = (torch.arange(len(tensor_example)) < num_real_tokens).float()

        # Nothing left to learn from (e.g. the prompt filled the whole
        # window): substitute a random other sample.
        if not labels.ge(0).any():
            random_index: int = np.random.randint(0, len(self.indexes))
            return self.__getitem__(random_index)

        dump_raw = bool(int(os.getenv("DEBUG_DUMP_RAW", "0")))
        meta = {
            "index": index,
            "offset": offset,
            "data_path": self.data_path,
        }
        if dump_raw:
            meta["raw"] = line

        return {
            "input_ids": tensor_example,
            "labels": labels,
            "attention_mask": attention_mask,
            "__meta__": meta,
        }

    def print_length_bins(self):
        """Print the token-length histogram and min/avg/max collected so far.

        Only samples fetched through ``__getitem__`` on this instance are
        counted.  The last histogram bin also contains the over-length
        samples that are reported again on the separate final line.
        """
        total_samples = sum(self.length_bins)
        if total_samples == 0:
            print("No samples processed yet.")
            return

        num_bins = len(self.length_bins)
        print("=== Length Distribution (len(example)) ===")
        for i, count in enumerate(self.length_bins):
            # Inclusive bounds; the last bin's upper bound is max_tokens.
            low = i * self.max_tokens // num_bins + 1
            high = (i + 1) * self.max_tokens // num_bins
            print(f"{i}: {low}〜{high} tokens -> {count} samples")

        exceed_percentage = 100.0 * self.num_exceed_max_tokens / total_samples
        print(
            f"{num_bins}: {self.max_tokens + 1}~: {self.num_exceed_max_tokens} samples"
            f" ({exceed_percentage:.2f}%)\n"
        )

        print(f"Total processed samples for stats: {total_samples} samples")
        avg = self.total_length / total_samples
        print("--- Token length stats ---")
        print(f"min:  {self.min_length}")
        print(f"avg: {avg:.2f}")
        print(f"max:  {self.max_length}")


def worker_init_fn(worker_id: int) -> None:
    """Seed the global NumPy and stdlib ``random`` generators for a worker.

    The seed is ``args.seed + worker_id``, with ``args`` taken from
    ``get_args()``.

    Args:
        worker_id: Index of the DataLoader worker being initialised.
    """
    import random

    seed_for_worker = get_args().seed + worker_id
    for seed_rng in (np.random.seed, random.seed):
        seed_rng(seed_for_worker)

def collate_with_meta(batch):
    """Collate a list of sample dicts, passing ``__meta__`` through untouched.

    Every key of the first sample other than ``"__meta__"`` is batched with
    ``default_collate``.  The ``"__meta__"`` entries are gathered into a
    plain list in batch order; if the first sample has no such key the list
    is empty.

    Args:
        batch: Non-empty list of dicts sharing the first sample's keys.

    Returns:
        Dict with the collated fields followed by ``"__meta__"``.
    """
    meta_key = "__meta__"
    first = batch[0]

    collated = {
        field: default_collate([item[field] for item in batch])
        for field in first.keys()
        if field != meta_key
    }
    if meta_key in first:
        collated[meta_key] = [item[meta_key] for item in batch]
    else:
        collated[meta_key] = []
    return collated


def get_instruction_tuning_dataloader(
    tokenizer: PreTrainedTokenizer,
    data_path: str,
    train: bool = False,
) -> DataLoader:
    """Build the DataLoader for an instruction-tuning JSONL file.

    Creates an ``InstructDataset`` and a shuffling
    ``CustomDistributedSampler`` for the current rank / world size, restores
    the sampler state from ``args.load`` when that is set, registers the
    sampler via ``set_sampler`` and wraps everything in a ``DataLoader``
    that uses ``worker_init_fn`` and ``collate_with_meta``.

    Args:
        tokenizer: Tokenizer handed to the dataset.
        data_path: Path to the JSONL data file.
        train: When true, the dataset size is stored on the global args as
            ``instruction_dataset_size`` and printed on rank 0.

    Returns:
        The configured ``DataLoader`` (incomplete final batches are dropped).
    """
    from llama_recipes.utils.sequence_length_warmup import CustomDistributedSampler
    from llama_recipes.utils.checkpoint import load_sampler_state_dict

    args = get_args()
    dataset = InstructDataset(tokenizer=tokenizer, data_path=data_path)

    if train:
        args.instruction_dataset_size = len(dataset)
        print_rank_0(f"Instruction dataset size: {args.instruction_dataset_size}")

    sampler = CustomDistributedSampler(
        dataset=dataset,
        rank=torch_distributed.get_rank(),
        num_replicas=torch_distributed.get_world_size(),
        shuffle=True,
        seed=args.seed,
    )

    # Resume the sampler position when a checkpoint path is configured.
    if args.load:
        load_sampler_state_dict(sampler=sampler, path=args.load)

    set_sampler(sampler=sampler)

    loader_options = {
        "batch_size": args.micro_batch_size,
        "sampler": sampler,
        "num_workers": args.num_workers,
        "pin_memory": True,
        "drop_last": True,
        "worker_init_fn": worker_init_fn,
        "collate_fn": collate_with_meta,
    }
    return DataLoader(dataset, **loader_options)
