import os
import numpy as np
import random
import shutil
import argparse
import torch.distributed as dist
import torch.multiprocessing as mp
import torch
from tqdm import tqdm
import json
import math
import gc

from datasets import load_dataset
from transformers import pipeline, AutoTokenizer, AutoModel

from graph_of_agents import GraphOfAgents

from utils import chat, truncate, extract_answer_choice, seed_everything, get_modelpath, predict, get_vanilla_prompt_format
from utils import RAG

# Hugging Face access token taken from the environment (None when HF_TOKEN is unset);
# get_pred passes it to the transformers pipeline as `token`.
HF_TOKEN = os.environ.get('HF_TOKEN')

def parse_args(args=None):
    """Define the command-line interface of the LongBench runner and parse it.

    Args:
        args: Optional list of argument strings; ``None`` makes argparse read
            ``sys.argv[1:]``.

    Returns:
        argparse.Namespace holding one attribute per option declared below,
        in declaration order.
    """
    # TODO: use separate goa argument (not using coa)
    cli = argparse.ArgumentParser()

    def flag(name, **extra):
        # Boolean switch: False unless the option appears on the command line.
        cli.add_argument(name, action='store_true', **extra)

    # --- Generic options ---
    cli.add_argument("--save_dir", "-s", type=str, default="lb1_table")
    cli.add_argument('--model_name', type=str,
                     help='Which models to use (refer get_modelpath in utils.py)')
    flag('--no_distribute', help="Need to be used for large models")
    cli.add_argument('--prompt_mode', type=str, default='generic',
                     help='Customize prompts (no_ans, no_conflict)')
    cli.add_argument('--seed', type=int, default=42)
    flag('--debug', help="Whether to log inputs/outputs")

    # --- GoA / CoA options ---
    # TODO: change this into bool argument for contextual embedding / int for cluster size
    flag('--coa', help='Whether to use (GoA/CoA)')
    flag('--goa', help='Whether to use (GoA/CoA)')
    cli.add_argument('--goa_cluster_size', type=int, default=4,
                     help="The number of subgraphs for GoA")
    flag('--no_answer_tag', help="Whether to instruct the manager to use <answer> tag")
    flag('--no_summary_tag', help="Whether to instruct the worker to use <summary> tag")

    # --- Experimental options (exploration only; slated for removal) ---
    cli.add_argument('--ablation_type', type=str, default='None')
    flag('--no_context', help="Whether to provide context for the next chunk search")
    flag('--part', help="Evaluate on part of the data")
    for switch in ('--ablation', '--ablation_add', '--ablation_add2',
                   '--summary', '--multi_qa', '--rag', '--lexical'):
        flag(switch)
    cli.add_argument('--goa_mode', type=str, default='None', help='GoA configuration')

    flag('--pipeline_test', help='For testing purpose')

    return cli.parse_args(args)


def get_pred(args, rank, world_size, data, max_length, temperature, max_gen, prompt_format, dataset, model_name,
             out_path, log_summary_dir, batch_size=1):
    """Run inference on one shard of a LongBench split and append predictions as JSONL.

    The inference strategy is selected by ``args``: GraphOfAgents (``--coa`` /
    ``--goa``), retrieval-augmented prompting (``--rag``), or a vanilla truncated
    prompt fed to a Hugging Face ``text-generation`` pipeline.

    Args:
        args: Parsed CLI namespace (see ``parse_args``).
        rank: CUDA device index for this worker, or ``-1`` to load the model with
            ``device_map='auto'`` and no explicit device.
        world_size: Number of workers (not used inside this function).
        data: List of LongBench records (dicts). Every record must provide
            ``answers``, ``all_classes`` and ``length``; the GoA and RAG paths also
            read ``context`` and ``input``, and vanilla/RAG prompts are built with
            ``prompt_format.format(**record)``.
        max_length: Prompt / chunk length budget. In vanilla mode ``0`` means the
            context is blanked and prompts are not truncated.
        temperature: Sampling temperature; only ``0.1`` is supported.
        max_gen: Maximum number of new tokens to generate.
        prompt_format: ``str.format`` template used for vanilla and RAG prompts.
        dataset: LongBench subset name, forwarded to GraphOfAgents.
        model_name: Model identifier resolved through ``get_modelpath``.
        out_path: JSONL file that prediction records are appended to.
        log_summary_dir: Directory passed to GraphOfAgents for summary logs.
        batch_size: Unused; kept for interface compatibility.

    Returns:
        Always ``1``.

    Raises:
        NotImplementedError: If ``temperature`` is not ``0.1``.

    Note:
        The RAG path overwrites ``record['context']`` with the retrieved text and the
        ``max_length == 0`` path sets it to ``''`` -- the records in ``data`` are mutated.
    """
    if rank == -1:
        device = None
        model_kwargs = {"torch_dtype": 'auto', "device_map": 'auto'}
    else:
        device = torch.device(f'cuda:{rank}')
        model_kwargs = {"torch_dtype": 'auto'}

    use_rope_scaling = 'qwen' in model_name and max_length > 32000
    if use_rope_scaling:
        # Stretch the model's 32768-token native window by 4x with YaRN rope scaling.
        model_kwargs["rope_scaling"] = {"rope_type": "yarn", "factor": 4.0, "original_max_position_embeddings": 32768}

    modeltype = "text-generation"
    modelpath = get_modelpath(model_name)

    pipeline_args = {
        "task": modeltype,
        "model": modelpath,
        "model_kwargs": model_kwargs,
        "token": HF_TOKEN,
        "return_full_text": False,
    }
    if device is not None:
        pipeline_args["device"] = device

    if temperature != 0.1:
        raise NotImplementedError("Only temperature=0.1 is supported for now")
    # Sampling / generation settings, handed to `predict` and to GraphOfAgents.
    tokenizer_kwargs = {
        "do_sample": True,
        "temperature": temperature,
        'top_p': 0.9,
        'max_new_tokens': max_gen,
    }

    # Build the client (and optional retriever) once per shard.
    rag = None
    if args.coa or args.goa:
        client = GraphOfAgents(
            worker_model=modelpath,
            manager_model=modelpath,
            model_kwargs=model_kwargs,
            tokenizer_kwargs=tokenizer_kwargs,
            pipeline_args=pipeline_args,
            chunk_size=max_length,
            dataset=dataset,
            device=device,
            prompt_mode=args.prompt_mode,
            debug=args.debug,
            goa_mode=args.goa_mode,
            goa_cluster_size=args.goa_cluster_size,
            goa_no_context=args.no_context,
            use_coa=bool(args.coa),
            max_gen=max_gen,
            ablation_type=args.ablation_type,
            summary_tag=(not args.no_summary_tag),
            log_summary_dir=log_summary_dir
        )
    else:
        print(f'+++++++++\t Use vanilla method with {modeltype} / {modelpath} +++++++++++++')
        client = pipeline(**pipeline_args)

        if args.rag:
            rag = RAG(device, max_length, client.tokenizer, is_lexical=args.lexical)

        if use_rope_scaling:
            print("Use QWen with rope scaling")
            print(client.model.config)

    batch_data = data
    if args.coa or args.goa:
        answers = []
        summaries = []
        for json_obj in batch_data:
            answer, summary = client.process(json_obj['context'], json_obj['input'], return_summary=True)
            answers.append(answer)
            summaries.append(summary)

        # Echo the first three manager inputs / predictions as a quick sanity check.
        for json_obj, ans, summm in zip(batch_data, answers[:3], summaries[:3]):
            print(
                f"{'=' * 100}\n MANAGER INPUT: {summm[:1000]} \n "
                f"+++MANAGER PRED: {ans[:1000]} \n "
                f"+++ANSWER: {json_obj['answers']} {'=' * 100}\n")
        results = answers

    elif args.rag:
        for json_obj in batch_data:
            # Replace the full context with the retrieved chunks before prompting.
            json_obj['context'] = rag.process(json_obj['context'], json_obj['input'], is_batch=False)
        prompts = [truncate(prompt_format.format(**json_obj), client.tokenizer, max_length)
                   for json_obj in batch_data]
        results = predict(client, prompts, tokenizer_kwargs, model_name, is_batch=True)

    else:
        if max_length == 0:
            # Zero budget: prompt without any context (and without truncation).
            prompts = []
            for json_obj in batch_data:
                json_obj['context'] = ''
                prompts.append(prompt_format.format(**json_obj))
        else:
            prompts = [truncate(prompt_format.format(**json_obj), client.tokenizer, max_length)
                       for json_obj in batch_data]
        results = predict(client, prompts, tokenizer_kwargs, model_name, is_batch=True)

    # Serialize the whole shard first and append it with a single write: the main script
    # launches one process per rank that all append to the same out_path, and one write per
    # shard makes interleaved partial lines far less likely than many small buffered writes.
    lines = []
    for json_obj, res in zip(batch_data, results):
        record = {
            "pred": res,
            "answers": json_obj["answers"],
            "all_classes": json_obj["all_classes"],
            "length": json_obj["length"]
        }
        lines.append(json.dumps(record, ensure_ascii=False) + '\n')
    with open(out_path, "a", encoding="utf-8") as f:
        f.write(''.join(lines))

    # Drop every model-holding object (including the retriever) before asking the CUDA
    # caching allocator to release memory; get_pred runs once per dataset.
    del client, rag
    gc.collect()
    torch.cuda.empty_cache()

    return 1


if __name__ == '__main__':
    args = parse_args()
    seed_everything(args.seed)

    if args.no_distribute:
        world_size = 1
    else:
        world_size = torch.cuda.device_count()
        if world_size == 0:
            # Without this guard no worker would be launched and the run would silently produce nothing.
            raise RuntimeError("No CUDA device found; rerun with --no_distribute to use a single process")
        # Each worker places its model on cuda:{rank}; PyTorch requires the 'spawn' start
        # method (not 'fork') for CUDA use in child processes.
        mp.set_start_method('spawn', force=True)

    # --model_name is "<family>-<max_length>-<temperature>". rsplit keeps any '-' inside the
    # family name, so only the last two fields have to be numeric.
    name_fields = (args.model_name or '').rsplit('-', 2)
    if len(name_fields) != 3:
        raise ValueError(
            f"--model_name must look like '<family>-<max_length>-<temperature>', got {args.model_name!r}")
    family, max_length, temperature = name_fields
    max_length = int(max_length)
    temperature = float(temperature)

    model_name = args.model_name

    # Reject unsupported combinations before any directory is created or dataset is loaded.
    if args.coa and args.no_answer_tag:
        raise NotImplementedError("CoA without answer tag is not supported now")
    use_answer_tag = not args.no_answer_tag

    part_indices = None
    if args.part:
        # NOTE(review): the sample subset is read from a module-level `parts` collection, which
        # is not defined anywhere in this file; previously this surfaced as a NameError mid-run.
        part_indices = globals().get('parts')
        if part_indices is None:
            raise NotImplementedError("--part requires a module-level `parts` collection of sample indices")
        part_indices = set(part_indices)

    save_dir = os.path.join(args.save_dir, family)
    os.makedirs(save_dir, exist_ok=True)  # also creates args.save_dir

    # The output directory name encodes the method and the flags that affect results.
    if args.coa:
        add_info = f'_coa-{args.prompt_mode}'
        if args.part:
            add_info += '-part'

        out_dir = f"{save_dir}/{model_name}-{add_info}"

    elif args.goa:
        add_info = f'_goa-{args.prompt_mode}-{args.goa_cluster_size}'

        if args.no_context:
            add_info += '-no_context'
        if args.ablation and args.ablation_type != 'None':
            add_info += f'-ablation-{args.ablation_type}'
        if args.no_answer_tag:
            add_info += '-no_tag'
        if args.part:
            add_info += '-part'

        out_dir = f"{save_dir}/{model_name}-{add_info}"

    else:
        if args.ablation:
            raise NotImplementedError("Ablation for vanilla model is not supported now")

        add_info = '-no_tag' if args.no_answer_tag else ''
        if args.rag:
            add_info += '-LexicalRAG' if args.lexical else '-RAG'
        if args.part:
            add_info += '-part'

        out_dir = f"{save_dir}/{model_name}-vanilla-{add_info}"

    os.makedirs(out_dir, exist_ok=True)

    ###########################################################################################
    ### Define datasets
    # single-document QA
    single_doc = ["narrativeqa", "qasper", "multifieldqa_en"]
    # multi-document QA
    multi_doc = ["hotpotqa", "2wikimqa", "musique"]
    # summarization
    summary = ["gov_report", "multi_news", "qmsum"]
    # code understanding
    code = ["lcc", "repobench-p"]

    datasets = single_doc + multi_doc

    if args.multi_qa:
        datasets = multi_doc
    elif args.ablation:
        if args.ablation_add:
            datasets = ['hotpotqa', 'qasper', '2wikimqa']
        elif args.ablation_add2:
            datasets = ["narrativeqa", "musique", "multifieldqa_en"]
        else:
            datasets = ['qasper', '2wikimqa']

    if args.summary:
        datasets = ["gov_report", "multi_news"]

    for dataset in datasets:
        data = load_dataset('THUDM/LongBench', dataset, split='test')

        out_file = os.path.join(out_dir, f"{dataset}-{args.seed}.jsonl")
        log_summary_dir = f"{out_dir}/log-{dataset}-{args.seed}"

        prompt_format = get_vanilla_prompt_format(dataset, use_answer_tag=use_answer_tag)
        max_gen = 512 if dataset in summary else 128

        if args.pipeline_test:
            # Smoke test: a single sample is enough, so skip materializing the split.
            data_all = [data[0]]
        elif part_indices is not None:
            data_all = [sample for idx, sample in enumerate(data) if idx in part_indices]
        else:
            data_all = list(data)

        # Round-robin sharding: rank i gets samples i, i + world_size, ...
        data_subsets = [data_all[i::world_size] for i in range(world_size)]
        shared_args = (max_length, temperature, max_gen, prompt_format, dataset, model_name,
                       out_file, log_summary_dir)

        if args.no_distribute:
            get_pred(args, -1, world_size, data_subsets[0], *shared_args)
        else:
            processes = []
            for rank in range(world_size):
                p = mp.Process(target=get_pred, args=(args, rank, world_size, data_subsets[rank], *shared_args))
                p.start()
                processes.append(p)
            for p in processes:
                p.join()

            # A crashed worker leaves its shard missing from out_file; do not continue silently.
            failed = [rank for rank, p in enumerate(processes) if p.exitcode != 0]
            if failed:
                raise RuntimeError(
                    f"get_pred failed on rank(s) {failed} for dataset '{dataset}'; {out_file} is incomplete")
