import torch
import json
import argparse
import logging
import jsonlines
import os

from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, GenerationConfig

from accelerate import PartialState
from accelerate.utils import gather_object
from accelerate.logging import get_logger

from curriculum_distill.utils import create_parent_directory

# Configure the root logger at INFO so the logging.info / logging.error calls in this module
# are emitted. The logging.debug(...) calls in main() are suppressed at this level; switch to
# logging.DEBUG to see them.
logging.basicConfig(level=logging.INFO)


def _select_fields(item, field_names):
    """Restrict a dict record to the keys listed in *field_names* (keys it lacks are skipped)."""
    if isinstance(field_names, list) and isinstance(item, dict):
        return {name: item[name] for name in field_names if name in item}
    return item


def read_json_objects(filename, field_names=None):
    """Load a list of records from a ``.jsonl`` or ``.json`` file.

    Args:
        filename: Path ending in ``.jsonl`` (one JSON value per non-blank line) or ``.json``
            (a top-level JSON value whose elements are iterated, normally an array).
        field_names: Optional list of keys. When given, each dict record is reduced to the
            listed keys it actually contains; otherwise records are returned unchanged.

    Returns:
        The list of records. If the file is missing, cannot be decoded, has an unknown
        extension, or any other error occurs, the problem is logged and ``[]`` is returned.
    """
    file_extension = os.path.splitext(filename)[1]
    if file_extension not in ('.jsonl', '.json'):
        logging.error(f"Unknown file extension {file_extension}")
        return []

    file_kind = 'JSONL' if file_extension == '.jsonl' else 'JSON'
    try:
        # Read as UTF-8 explicitly: the writers in this module emit non-ASCII text
        # (ensure_ascii=False), so the platform default encoding is not safe.
        with open(filename, 'r', encoding='utf-8') as file:
            if file_extension == '.jsonl':
                # Skip blank lines (e.g. a trailing newline at EOF) rather than failing
                # to decode them and losing the whole file.
                records = [json.loads(line) for line in file if line.strip()]
            else:
                records = list(json.load(file))
    except FileNotFoundError:
        logging.error("The file was not found.")
        return []
    except json.JSONDecodeError:
        logging.error(f"There was an error decoding the {file_kind} file.")
        return []
    except Exception as e:
        logging.error(f"An error occurred: {e}")
        return []

    return [_select_fields(record, field_names) for record in records]


def write_data_to_json_file(data, file_path):
    """Serialize *data* to *file_path* as pretty-printed JSON (indent=4).

    Non-ASCII characters are written verbatim (``ensure_ascii=False``), so the file is
    opened explicitly as UTF-8 instead of with the platform's default encoding, which may
    not be able to represent them. Any error is logged rather than raised.
    """
    try:
        with open(file_path, 'w', encoding='utf-8') as file:
            json.dump(data, file, ensure_ascii=False, indent=4)
        logging.info(f"Data successfully written to {file_path}")
    except Exception as e:
        logging.error(f"An error occurred: {e}")


def write_data_to_jsonlines_file(data, file_path, mode='a'):
    """Write every record in *data* to *file_path* in JSON Lines format.

    ``mode`` is forwarded unchanged to ``jsonlines.open``; the default ``'a'`` appends to
    an existing file. Any failure is logged rather than raised.
    """
    try:
        jsonl_writer = jsonlines.open(file_path, mode=mode)
        with jsonl_writer:
            jsonl_writer.write_all(data)
    except Exception as err:
        logging.error(f"An error occurred: {err}")


def _generation_kwargs(inference_config):
    """Build the keyword arguments for the text-generation pipeline call from the "inference" config."""
    temperature = inference_config.get("temperature", 0.2)
    num_return_sequences = inference_config.get("generate_n", 8)
    kwargs = {
        "max_new_tokens": inference_config.get("max_new_tokens", 1024),
        # Return only the newly generated continuation, so the prompt does not have to be
        # stripped off by character offset afterwards.
        "return_full_text": False,
    }
    if temperature is not None and temperature > 0:
        # `temperature` only has an effect when sampling. Without an explicit do_sample=True,
        # generation follows the model's generation_config, which may be greedy; greedy decoding
        # ignores temperature and does not allow num_return_sequences > 1.
        kwargs.update(do_sample=True, temperature=temperature,
                      num_return_sequences=num_return_sequences)
    else:
        if num_return_sequences != 1:
            logging.warning(
                "temperature <= 0 selects greedy decoding, which yields one sequence per prompt; "
                "ignoring generate_n=%s.", num_return_sequences)
        kwargs.update(do_sample=False, num_return_sequences=1)
    return kwargs


def main(config):
    """Generate completions for every question in the input dataset and write them to JSON.

    Each process loads the model onto its own device (from ``PartialState``) and generates for
    its shard of the data. Results are gathered, and the main process writes the output file.

    Args:
        config: Parsed JSON config. Reads ``config["inference"]["model"]`` (required), the
            optional ``prompt``, ``batch_size``, ``temperature``, ``max_new_tokens`` and
            ``generate_n`` keys of ``config["inference"]``, and
            ``config["dataset"]["input_path"]`` / ``config["dataset"]["output_path"]``.
            Every input record must contain ``question`` and ``answer`` fields.

    Each output record has ``prompt``, ``answer``, ``output`` (the first generated sample,
    kept for backward compatibility) and ``outputs`` (all samples generated for the prompt).
    """
    # Start up the distributed environment without needing the Accelerator.
    distributed_state = PartialState()
    inference_config = config["inference"]

    data_list = read_json_objects(config["dataset"]["input_path"])
    if not data_list:
        # Covers both an empty dataset and a failed read. Bail out before the expensive model
        # load; split_between_processes cannot pad an empty list anyway.
        logging.error("No input records loaded; nothing to generate.")
        return

    model_name = inference_config["model"]
    model = AutoModelForCausalLM.from_pretrained(
        model_name, device_map=distributed_state.device, torch_dtype=torch.float16
    )

    # Left padding so every prompt in a batch ends right where generation starts.
    tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side='left')
    # Need to set the padding token to the eos token for generation
    tokenizer.pad_token = tokenizer.eos_token

    prompt = inference_config.get(
        "prompt",
        'Solve the following question step by step and wrap your final answer in \\boxed{}.\n\n',
    )
    # You can change the batch size depending on your GPU RAM
    batch_size = inference_config.get("batch_size", 8)
    generation_kwargs = _generation_kwargs(inference_config)

    # The model is already instantiated in float16 on this process's device, so no dtype is
    # passed here (a pipeline-level torch_dtype would only apply when loading a model by name).
    generator = pipeline(task="text-generation", model=model, tokenizer=tokenizer,
                         device_map=distributed_state.device, batch_size=batch_size)

    completions_per_process = []
    # The data is split across processes. apply_padding=True repeats the last record so that
    # every process receives the same number of prompts, which keeps gather_object aligned.
    # The duplicates are trimmed after gathering.
    with distributed_state.split_between_processes(data_list, apply_padding=True) as shard:
        instructions = [prompt + record['question'] for record in shard]
        answers = [record['answer'] for record in shard]

        logging.debug(f"Process index: {distributed_state.local_process_index}, processing {len(instructions)} samples.")

        # One list of candidate dicts per input prompt.
        responses = generator(instructions, **generation_kwargs)

        for ins, answer, candidates in zip(instructions, answers, responses):
            outputs = [candidate['generated_text'].strip() for candidate in candidates]
            completions_per_process.append({
                "prompt": ins,
                "answer": answer,
                "output": outputs[0],
                "outputs": outputs,
            })

    # We are gathering Python objects (not tensors), so gather_object is required.
    completions_gather = gather_object(completions_per_process)

    # Exactly one record per input, so truncating drops the padding duplicates added above.
    completions = completions_gather[: len(data_list)]

    if distributed_state.is_main_process:
        logging.debug(f"len of completions: {len(completions)}")
        create_parent_directory(config["dataset"]["output_path"])
        write_data_to_json_file(completions, config["dataset"]["output_path"])



if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, required=True, help='path to the json config file')
    args = parser.parse_args()
    # Read the config inside a `with` block so the handle is closed right away instead of
    # staying open for the whole (long-running) generation job.
    with open(args.config, 'r', encoding='utf-8') as config_file:
        config = json.load(config_file)

    main(config)