import gc

from utils_for_llm import *
import json
import os
from random import randrange
import pandas as pd
from sklearn.model_selection import train_test_split
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainerCallback, TrainerState, TrainerControl
from datasets import load_dataset, DatasetDict, Dataset
import pickle
from functools import partial
from tqdm import tqdm
from trl import SFTConfig, SFTTrainer
from nltk.translate.bleu_score import sentence_bleu
from rouge_score import rouge_scorer
import numpy as np
from peft import LoraConfig
import re
import argparse
from transformers import DataCollatorForSeq2Seq, DataCollatorWithPadding
import warnings
import warnings
from accelerate import Accelerator
from accelerate.utils import gather_object
from codebleu import calc_codebleu
import os
import torch.distributed as dist
from datetime import timedelta
import time
import deepspeed

import torch

def print_memory_usage(stage=""):
    """Print the CUDA memory counters reported by torch, scaled by 1024**3.

    Args:
        stage: Free-form label printed in square brackets at the start of the
            line so successive reports can be told apart.
    """
    scale = 1024 ** 3
    allocated, reserved = (
        counter() / scale
        for counter in (torch.cuda.memory_allocated, torch.cuda.memory_reserved)
    )
    print(f"[{stage}] Allocated memory: {allocated:.2f} GB, Reserved memory: {reserved:.2f} GB")


# Silence every Python warning for the rest of the run.
warnings.filterwarnings("ignore")

# Process-wide settings, exported before the process group and the Accelerator
# are created below.
os.environ.update({
    'TORCH_NCCL_BLOCKING_WAIT': '1',
    'TORCH_NCCL_ASYNC_ERROR_HANDLING': '1',
    'TOKENIZERS_PARALLELISM': "False",
    'PYTORCH_CUDA_ALLOC_CONF': 'expandable_segments:True',
})

# Unless PYCHARM_HOSTED is set to '1', create the NCCL process group up front
# with a 6-hour collective timeout.
if os.environ.get('PYCHARM_HOSTED') != '1':
    dist.init_process_group(backend='nccl', timeout=timedelta(hours=6))

# Shared Accelerator instance used by the inference helper and the callback.
accelerator = Accelerator(mixed_precision='bf16')

# zero_version is the "stage" entry of the DeepSpeed "zero_optimization"
# config, or -1 when no DeepSpeed plugin is active.
zero_version = -1
if accelerator.state.deepspeed_plugin:
    deepspeed_config = accelerator.state.deepspeed_plugin.deepspeed_config
    zero_version = deepspeed_config.get('zero_optimization', {}).get("stage")
    print(zero_version)

# Command-line interface of the script.
parser = argparse.ArgumentParser()
parser.add_argument("--task", default="code_generation", type=str)
parser.add_argument("--train_file", default='./data/ALL_data.json', type=str)  # or ./data/test_data.json or ./data/ALL_data.json
parser.add_argument("--save_strategy", default="no", type=str)
parser.add_argument("--model_version", default=3.1, type=float)
parser.add_argument("--model_size", default=8, type=float)
parser.add_argument("--use_lora", action="store_true")
parser.add_argument("--do_train", action="store_true")
parser.add_argument("--do_infer", action="store_true")
parser.add_argument("--load_path", default="", type=str)
parser.add_argument("--neftune_noise_alpha", default=None, type=float)
parser.add_argument("--debug", action="store_true")
# NOTE(review): parsed as float but only used as a modulus on the global step
# in the visible code; an integer value is presumably intended.
parser.add_argument("--eval_steps", default=500, type=float)

args = parser.parse_args()
task = args.task  # or task_breakdown

# The size/version options are parsed as floats, so an explicit
# `--model_size 8` becomes 8.0 while the (unparsed) default stays the int 8.
# Formatting with ``:g`` drops the spurious ".0", so both spellings resolve to
# the same checkpoint directory ("...-8B-Instruct", not "...-8.0B-Instruct").
model_id = (
    "/Pretrained_Language_Models/"
    f"Meta-Llama-{args.model_version:g}-{args.model_size:g}B-Instruct"
)



# Per-task settings: prompt formatter, sequence budget and the dataset column
# that holds the reference answer.
if task not in ("code_generation", "task_breakdown"):
    raise Exception(f'{task} is not defined.')

BATCH_SIZE = 2
if task == "code_generation":
    format_instruction = format_instruction_with_code
    max_seq_length = 8192
    target_col = "line_by_line"
else:
    format_instruction = format_instruction_without_code
    max_seq_length = 768
    target_col = "description"

# Without --use_lora the per-device training batch is halved; evaluation
# always runs with twice the training batch size.
if not args.use_lora:
    BATCH_SIZE //= 2
eval_batch_size = BATCH_SIZE * 2
# =================================================


# Define prediction function using accelerate with distributed processing
# Compiled once instead of on every sample.
# - thought: text after "Thought:" up to the next "Code:" marker or the end.
# - code: contents of the first ```python fence, closed or left open.
_THOUGHT_PATTERN = re.compile(r"Thought:([\s\S]*?)(?:Code:|$)")
_CODE_FENCE_PATTERN = re.compile(r"```python(.*?)(```|$)", re.DOTALL)


def _split_response(response):
    """Return ``(thought, code)`` parsed from one decoded model response.

    ``thought`` is the stripped "Thought:" section, or "" when it is absent.
    ``code`` is the raw text of the first ```python block, or ``None`` when the
    response contains no such block.
    """
    thought_match = _THOUGHT_PATTERN.search(response)
    thought = thought_match.group(1).strip() if thought_match else ""
    code_match = _CODE_FENCE_PATTERN.search(response)
    code = code_match.group(1) if code_match else None
    return thought, code


def predict_on_validation_BATCH(model, tokenizer, eval_dataset, batch_size=12, external_data=False):
    """Generate a response for every sample of ``eval_dataset`` and score it.

    Prompts are built with the module-level ``format_instruction`` (without the
    answer), tokenized with truncation to ``max_seq_length``, batched through
    ``DataCollatorWithPadding`` and fed to ``model.generate``.  The generated
    continuation is parsed into a thought and a code block, and the code
    (comments removed) is compared with the reference column ``target_col``
    via ``calc_codebleu``.

    Args:
        model: Model exposing ``eval()``, ``train()`` and ``generate()``.
        tokenizer: Tokenizer used for encoding prompts and decoding outputs.
        eval_dataset: Indexable, iterable collection of dict-like samples.
        batch_size: Number of prompts per generation batch.
        external_data: Forwarded unchanged to ``format_instruction``.

    Returns:
        A list with one dict per sample, in dataset order, holding the query,
        gold/generated code (with and without comments), gold/generated
        thought, the ``category``/``type``/``apis`` metadata and every entry
        returned by ``calc_codebleu``.
    """
    # Remember the incoming mode so a model that was already in eval mode is
    # not silently switched to training mode on exit.  Objects without a
    # ``training`` attribute keep the previous behaviour (train() afterwards).
    was_training = getattr(model, "training", True)
    model.eval()

    collator = DataCollatorWithPadding(tokenizer=tokenizer)

    prompts = [format_instruction(sample, add_answer=False, external_data=external_data) for sample in eval_dataset]
    tokenized_inputs = [tokenizer(prompt, max_length=max_seq_length, padding=True, truncation=True, add_special_tokens=False)
                        for prompt in prompts]

    eval_loader = torch.utils.data.DataLoader(tokenized_inputs, batch_size=batch_size, collate_fn=collator, shuffle=False, drop_last=False)

    sample_index = 0
    results = []
    progress_bar = tqdm(total=len(eval_loader), desc=f"Process {accelerator.process_index}", leave=False,
                        disable=not accelerator.is_local_main_process)

    try:
        for batch in eval_loader:
            with torch.inference_mode():
                inputs = {key: value.to(accelerator.device) for key, value in batch.items() if key != 'labels'}
                outputs = model.generate(inputs['input_ids'], attention_mask=inputs['attention_mask'], pad_token_id=tokenizer.pad_token_id,
                                         max_length=max_seq_length, num_return_sequences=1)

            # Every row of the batched input tensor has the same (padded)
            # width, so the generated tokens start at this offset for all rows.
            prompt_length = inputs['input_ids'].shape[1]

            for output_ids in outputs:
                response = tokenizer.decode(output_ids[prompt_length:], skip_special_tokens=True)
                generated_thought, generated_code = _split_response(response)
                if generated_code is None:
                    generated_code_with_comment = ""
                    generated_code_without_comment = ""
                else:
                    generated_code_with_comment = generated_code
                    generated_code_without_comment = remove_comments(generated_code)

                # Fetch the row once; indexing the dataset repeatedly would
                # rebuild the same record for every field.
                sample = eval_dataset[sample_index]
                if target_col in sample:
                    reference_code_with_comment = sample[target_col]
                    reference_code_without_comment = remove_comments(reference_code_with_comment)
                else:
                    reference_code_with_comment = ""
                    reference_code_without_comment = ""
                if 'description' in sample:
                    gold_thought = sample['description']
                else:
                    gold_thought = sample.get('thought', "")

                sample_result = {
                    'query': tokenizer.decode(output_ids[:prompt_length], skip_special_tokens=True),
                    'gold_code_with_comment': reference_code_with_comment,
                    'gold_code_without_comment': reference_code_without_comment,
                    'gold_thought': gold_thought,
                    'generated_code_with_comment': generated_code_with_comment,
                    'generated_code_without_comment': generated_code_without_comment,
                    'generated_thought': generated_thought,
                    'category': sample.get('category', "NA"),
                    'type': sample.get('type', "NA"),
                    'apis': sample.get('apis', []),
                }
                eval_result = calc_codebleu([reference_code_without_comment], [generated_code_without_comment], lang="python", weights=(0.1, 0.1, 0.4, 0.4), tokenizer=None)
                sample_result.update(eval_result)
                results.append(sample_result)
                sample_index += 1

            # Release cached blocks between batches.
            torch.cuda.empty_cache()
            progress_bar.update(1)
    finally:
        # Runs on errors too, so the bar is closed and training can resume in
        # the mode it was in before evaluation.
        progress_bar.close()
        if was_training:
            model.train()

    return results


class EvaluationCallback(TrainerCallback):
    """Run generation-based validation every ``step_interval`` training steps.

    Each evaluation writes the per-sample results to ``./output`` and, when
    the first metric returned by ``print_result`` (used here as the CodeBLEU
    score) improves, saves the model to ``<output_path>/best``.

    Relies on the module-level ``accelerator``, ``eval_batch_size``, ``path``
    and ``current_trainer`` being set by the training script.
    """

    def __init__(self, eval_dataset, tokenizer, output_path, step_interval=20):
        """
        Args:
            eval_dataset: Validation samples; split across processes on each run.
            tokenizer: Tokenizer handed to the prediction helper.
            output_path: Directory under which the ``best`` model folder lives.
            step_interval: Evaluate whenever ``global_step`` is a multiple of this.
        """
        self.eval_dataset = eval_dataset
        self.tokenizer = tokenizer
        self.step_interval = step_interval

        self.best_codebleu = -1  # Best score seen so far; -1 means "none yet"
        self.best_model_path = os.path.join(output_path, 'best')

        # Make sure the output directory exists
        os.makedirs(self.best_model_path, exist_ok=True)

    def on_step_end(self, args, state: TrainerState, control: TrainerControl, **kwargs):
        """Trigger an evaluation on every ``step_interval``-th global step."""
        if state.global_step % self.step_interval == 0:
            # Use the model the Trainer hands to its callbacks rather than a
            # module-level ``model`` name, which the script deletes between
            # the two training stages.
            self.evaluate(args, state, control, model=kwargs.get('model'))

    def evaluate(self, args, state: TrainerState, control: TrainerControl, model=None):
        """Generate on the validation set, dump the results, keep the best model.

        Args:
            args: Training arguments forwarded by the Trainer (unused here).
            state: Trainer state; ``global_step`` names the result file.
            control: Trainer control object (unused here).
            model: Model to evaluate.  When omitted, the model of the
                module-level ``current_trainer`` is used.
        """
        if model is None:
            model = current_trainer.model

        # sync GPUs and start the timer
        torch.cuda.empty_cache()
        accelerator.wait_for_everyone()
        start = time.time()
        # Split the data across processes
        with accelerator.split_between_processes(self.eval_dataset) as eval_dataset:
            infer_result = predict_on_validation_BATCH(model, self.tokenizer, eval_dataset, batch_size=eval_batch_size)

        # Gather results from all processes
        infer_result = gather_object(infer_result)
        timediff = time.time() - start

        minutes, seconds = divmod(timediff, 60)
        hours, minutes = divmod(minutes, 60)

        # Only save the results on the main process
        if accelerator.is_main_process:
            print(f"Inference Time: {int(hours)} hours, {int(minutes)} minutes, {seconds:.2f} seconds")
            # The results directory is not created anywhere else, so ensure it
            # exists before opening the file.
            os.makedirs('./output', exist_ok=True)
            result_file = os.path.join('./output', f'{path}_step{state.global_step}_result.json')
            with open(result_file, 'w') as fp:
                json.dump(infer_result, fp, indent=4)
            print('Result is dumped to:\n ', result_file)
            metrics = print_result(infer_result)
            print('=' * 50)

            codebleu = metrics[0]

            # Save model if CodeBLEU is better
            if codebleu > self.best_codebleu:
                self.best_codebleu = codebleu
                print(f"New best CodeBLEU: {self.best_codebleu:.4f}, saving model to {self.best_model_path}")

                # NOTE(review): this runs on the main process only, which
                # assumes saving needs no cross-process collective; the script
                # does not install this callback under ZeRO stage 3 - confirm
                # for other sharded setups.
                current_trainer.save_model(output_dir=self.best_model_path)
        torch.cuda.empty_cache()



if __name__ == "__main__":
    # Script entry point: load and split the data, optionally run a baseline
    # inference pass on the validation split, then train in two stages
    # (synthesized data first, seed data second).
    use_lora = args.use_lora
    # NOTE(review): in the visible code --use_lora only changes the run name
    # and the batch size; no LoRA/PEFT config is passed to the trainers.

    def _load_causal_lm(checkpoint):
        """Load a bf16 causal LM from ``checkpoint`` for the current process.

        Outside ZeRO stage 3 the whole model is placed on this process's
        *local* device index; the global process index is only a valid device
        id on single-node runs.
        """
        loaded = AutoModelForCausalLM.from_pretrained(
            checkpoint,
            torch_dtype=torch.bfloat16,
            use_cache=False,
            attn_implementation="flash_attention_2",
            device_map={"": accelerator.local_process_index} if zero_version != 3 else None,
        )
        loaded.config.pretraining_tp = 1
        return loaded

    # Run name, also used as output directory and result-file prefix.
    path = f"Llama{args.model_version}-{args.model_size}B-workflow-{task}"
    if use_lora:
        path += '-LoRa'
    if args.neftune_noise_alpha:
        path += f"-neft{args.neftune_noise_alpha}"
    if 'syn' in args.train_file.lower():
        path += "_syn"
    if 'all' in args.train_file.lower():
        path += '_all'
    model_path = args.load_path if args.load_path else model_id

    with open(args.train_file, 'r', encoding='utf-8') as fp:
        data = json.load(fp)

    with open('./data/dataset_split_keys.json', 'r', encoding='utf-8') as fp:
        dataset_split = json.load(fp)
    # NOTE(review): `stat` is not defined in this file; presumably it is
    # provided by `from utils_for_llm import *` - confirm.
    data = [sample for sample in data if (sample['key'] in stat.keys() or sample['key'] == 'synthesized_data')]
    if accelerator.is_main_process:
        print('use_lora:', use_lora)
        print('path:', path)
        print(f'batch_size:{BATCH_SIZE} eval_batch_size:{eval_batch_size}')
        print(f'max_seq_length:{max_seq_length}')
        print(f'Load from: {model_path}')
        print('Len(data):', len(data))

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"

    # (To check the sequence-length distribution, tokenize
    # format_instruction(sample, add_answer=True) for every sample and look at
    # the 80th/90th/95th length percentiles.)

    data = pd.DataFrame(data)

    # Samples are routed to a split by their 'key' column.
    train_keys_seed = set(dataset_split['train'])
    train_keys_syn = {'synthesized_data'}
    dev_keys = set(dataset_split['dev'])
    test_keys = set(dataset_split['test'])

    train_df_seed = data[data['key'].isin(train_keys_seed)]
    train_df_syn = data[data['key'].isin(train_keys_syn)]
    val_df = data[data['key'].isin(dev_keys)]
    test_df = data[data['key'].isin(test_keys)]

    # =======debug======
    if args.debug:
        # Tiny subsets so a full run finishes quickly.
        train_df_seed = train_df_seed.head(100)
        train_df_syn = train_df_syn.head(100)
        val_df = val_df.head(10)
        test_df = test_df.head(10)
        args.eval_steps = 1000
        max_seq_length = 8192
        path = 'DEBUG_' + path
    # =======debug======

    train_dataset_seed = Dataset.from_pandas(train_df_seed)
    train_dataset_syn = Dataset.from_pandas(train_df_syn)
    val_dataset = Dataset.from_pandas(val_df)
    test_dataset = Dataset.from_pandas(test_df)
    dataset_dict = DatasetDict({
        'train_seed': train_dataset_seed,
        'train_syn': train_dataset_syn,
        'validation': val_dataset,
        'test': test_dataset
    })

    use_flash_attention = True
    device = torch.cuda.current_device() if torch.cuda.is_available() else "cpu"
    model = _load_causal_lm(model_path)

    # Baseline inference on the validation split (skipped under ZeRO stage 3).
    if args.do_infer and zero_version != 3:
        # sync GPUs and start the timer
        accelerator.wait_for_everyone()
        start = time.time()
        # Split the data across processes
        with accelerator.split_between_processes(dataset_dict['validation']) as eval_dataset:
            infer_result = predict_on_validation_BATCH(model, tokenizer, eval_dataset, batch_size=eval_batch_size)

        # Gather results from all processes
        infer_result = gather_object(infer_result)
        timediff = time.time() - start

        minutes, seconds = divmod(timediff, 60)
        hours, minutes = divmod(minutes, 60)

        # Only save the results on the main process
        if accelerator.is_main_process:
            print(f"Inference Time: {int(hours)} hours, {int(minutes)} minutes, {seconds:.2f} seconds")
            # The results directory is not created anywhere else.
            os.makedirs('./output', exist_ok=True)
            with open(os.path.join('./output', f'{path}_result.json'), 'w') as fp:
                json.dump(infer_result, fp, indent=4)
            print('Init reuslt is dumped to:\n ', os.path.join('./output', f'{path}_result.json'))
            print_result(infer_result)
            print('=' * 50)

    if args.do_train:
        # Under ZeRO stage 3, or when epoch checkpoints are requested, rely on
        # the Trainer's own checkpointing; otherwise evaluate periodically and
        # keep the best model through the callback.
        if zero_version == 3 or args.save_strategy == 'epoch':
            callbacks = None
            save_strategy = 'epoch'
        else:
            callbacks = [EvaluationCallback(dataset_dict['validation'], tokenizer, path, step_interval=args.eval_steps)]
            save_strategy = 'no'

        training_args = SFTConfig(
            output_dir=path,
            num_train_epochs=1,
            per_device_train_batch_size=BATCH_SIZE,
            gradient_accumulation_steps=2,
            gradient_checkpointing=True,
            optim="adamw_torch_fused",
            logging_steps=10,
            save_strategy=save_strategy,
            learning_rate=2e-5,
            bf16=True,
            tf32=True,
            max_grad_norm=0.3,
            warmup_ratio=0.1,
            lr_scheduler_type="linear",
            disable_tqdm=False,
            report_to="tensorboard",
            neftune_noise_alpha=args.neftune_noise_alpha
        )

        def _build_trainer(train_model, train_dataset):
            """Create an SFTTrainer sharing the settings of both stages."""
            return SFTTrainer(
                model=train_model,
                train_dataset=train_dataset,
                max_seq_length=max_seq_length,
                tokenizer=tokenizer,
                packing=True,
                formatting_func=format_instruction,
                args=training_args,
                callbacks=callbacks
            )

        # ---------------- Stage 1: synthesized data ----------------
        trainer_syn = _build_trainer(model, dataset_dict['train_syn'])
        current_trainer = trainer_syn
        print_memory_usage("Before training stage 1")

        trainer_syn.train()
        print_memory_usage("After training stage 1")

        model_save_path = os.path.join(path, 'stage1_model')

        accelerator.wait_for_everyone()
        # Called on every process: the Trainer itself restricts writing to the
        # main process, and with sharded weights (e.g. ZeRO stage 3) gathering
        # the state dict is a collective that hangs if only one rank enters.
        trainer_syn.save_model(output_dir=model_save_path)

        # Drop every reference to the stage-1 trainer and model, including the
        # `current_trainer` alias, otherwise the cleanup below frees nothing.
        current_trainer = None
        del trainer_syn
        del model
        accelerator.wait_for_everyone()
        torch.distributed.barrier()
        deepspeed.runtime.utils.empty_cache()
        gc.collect()

        torch.cuda.empty_cache()
        torch.cuda.synchronize()
        del accelerator

        accelerator = Accelerator(mixed_precision='bf16')

        accelerator.wait_for_everyone()
        print_memory_usage("After clearing memory before stage 2")

        # ---------------- Stage 2: seed data ----------------
        # Rebind the module-level `model` so code that looks it up by name
        # (such as the evaluation callback) sees the stage-2 model instead of
        # a deleted name.
        model = _load_causal_lm(model_save_path)
        current_trainer_model = model

        trainer_seed = _build_trainer(current_trainer_model, dataset_dict['train_seed'])
        current_trainer = trainer_seed
        trainer_seed.train()
        model_save_path = os.path.join(path, 'stage2_model')

        # See stage 1: save_model must be entered by every process.
        accelerator.wait_for_everyone()
        trainer_seed.save_model(output_dir=model_save_path)