import inspect
import os
import shutil
import logging
import transformers
import torch

import torch.distributed as dist

from datetime import datetime
from dataclasses import field, dataclass
from typing import List, Optional
from leanfinder.retriever.utils.util import set_logger, print_args
from leanfinder.retriever.utils.trainer import LoggerCallback, StepCheckpointCallback, DataCollatorForLeanFinderDualLossDataset
from leanfinder.retriever.rlhf_trainer import LeanFinderDPOTrainer
from leanfinder.retriever.dataset import DPODoubleDataset
from transformers import (
    set_seed,
    AutoConfig,
    AutoTokenizer, 
    HfArgumentParser,
    TrainingArguments, 
    AutoModel,
)

from peft import LoraConfig, PeftModel
from trl import  DPOConfig

logger = logging.getLogger()

@dataclass
class DataArguments:
    no_timestamps: bool = field(default=False)

    # data
    train_parquet_file: str = field(default=None)
    train_file_config: str = field(default=None)
    train_dataset: str = field(default=None)
    train_coef: str = field(default=None)
    delete_long_sample: bool = field(default=False)
    dataset_path_list: Optional[List[str]] = field(
        default_factory=list,
        metadata={"help": "List of dataset paths"}
    )
    train_group_size: int = field(default=5)

    max_len: int = field(default=4096)
    preprocessing_num_workers: int = field(default=64)
    
    model_cfg: str = field(default="data/models/dsprover-v1.5-rl")
    adapter_cfg: str = field(default="data/models/dsprover-v1.5-rl")
    tokenizer_cfg: str = field(default="data/models/dsprover-v1.5-rl")
    flash_attention: bool = field(default=False)
    contrastive_loss_temp: float = field(default=0.2)
    rpo_alpha: float = field(default=1.0)

    save_merged_model: bool = field(default=False)
    no_load_model_pararmeters: bool = field(default=False)
    resume_from: str = field(default=None)

    resume_step: int = field(default=None)
    resume_batch_size: int = field(default=None)

    stream: bool = field(default=False)

    step_save_interval: int = field(default=1000)

    external_validation: bool = field(default=False)

@dataclass
class PeftArguments:
    """Adapter hyper-parameters that ``train()`` feeds into ``LoraConfig``."""

    # Comma-separated module names; train() splits on "," and strips each
    # entry. An empty string makes train() pass None as target_modules.
    target_modules: str = "q_proj,k_proj,v_proj,o_proj,down_proj,up_proj,gate_proj"
    # Passed through unchanged as LoraConfig(task_type=...).
    task_type: str = "FEATURE_EXTRACTION"

    # Passed as LoraConfig(r=...), (lora_alpha=...), (lora_dropout=...).
    lora_r: int = 16
    lora_alpha: int = 64
    lora_dropout: float = 0.1
    # Passed through unchanged as LoraConfig(bias=...).
    bias: str = "none"

def train():
    """Run distributed DPO + contrastive ("dual loss") adapter training.

    Steps, in order:
      1. Join the NCCL process group and bind this process to its GPU.
      2. Parse ``DataArguments``, ``TrainingArguments`` and ``PeftArguments``
         from the command line.
      3. Prepare the output directory (rank 0 only) and per-node log files.
      4. Load the base model, then load the adapter from ``adapter_cfg`` twice:
         as the trainable ``"policy"`` adapter and as the ``"reference"``
         adapter.
      5. Build the dataset, collator, callbacks and ``DPOConfig``; train with
         ``LeanFinderDPOTrainer``; save the result to ``<output_dir>/final``.

    Raises:
        ValueError: if ``--dataset_path_list`` has fewer than two entries
            (DPO dataset path first, contrastive dataset path second).
    """
    dist.init_process_group(backend='nccl', init_method='env://')
    rank = dist.get_rank()
    # CUDA ordinals are per node, so the device must be chosen from the local
    # rank. The global rank is only a valid ordinal on the first node; it is
    # kept as the fallback for launchers that do not export LOCAL_RANK.
    local_rank = int(os.environ.get('LOCAL_RANK', rank))
    torch.cuda.set_device(local_rank)
    print("Rank", rank, "Current device", torch.cuda.current_device())

    parser = HfArgumentParser((DataArguments, TrainingArguments, PeftArguments))

    data_args, training_args, peft_args = parser.parse_args_into_dataclasses()

    # Fail fast, before touching the filesystem or loading a model: the two
    # dataset paths are read positionally further down.
    dataset_paths = data_args.dataset_path_list or []
    if len(dataset_paths) < 2:
        raise ValueError(
            "--dataset_path_list needs two entries (DPO dataset path, then "
            f"contrastive dataset path); got {dataset_paths!r}"
        )
    dpo_dataset_path, contrastive_dataset_path = dataset_paths[0], dataset_paths[1]

    def _barrier():
        # Synchronise all ranks; skipped for single-process runs.
        if training_args.world_size > 1:
            dist.barrier(device_ids=[local_rank])

    # NOTE(review): presumably lifts TrainingArguments' immutability guard so
    # output_dir / logging_dir can be reassigned below -- confirm against the
    # installed transformers version.
    training_args._frozen = False

    if not data_args.no_timestamps:
        timestr = datetime.now().strftime("-%m%d%H%M")
        training_args.output_dir = training_args.output_dir + timestr

    training_args.logging_dir = os.path.join(training_args.output_dir, 'logging')

    if os.path.exists(training_args.output_dir):
        if training_args.overwrite_output_dir:
            print(f"Output directory ({training_args.output_dir}) already exists. Overwriting output dir.")
            # Only one process may delete the tree; the others wait at the
            # barrier below.
            if training_args.process_index == 0:
                shutil.rmtree(training_args.output_dir)
        else:
            # Deliberately non-fatal: the run continues into the existing dir.
            print(f"Output directory ({training_args.output_dir}) already exists. Use --overwrite_output_dir to overcome.")

    _barrier()

    if training_args.process_index == 0:
        if not os.path.exists(training_args.output_dir):
            os.makedirs(training_args.output_dir)

    # Every rank must see the directory before log files are opened in it.
    _barrier()

    set_seed(training_args.seed)

    node_rank = int(os.getenv('GROUP_RANK', '0'))

    # One log file per node, shared by the root, transformers and DeepSpeed
    # loggers.
    log_path = os.path.join(training_args.output_dir, f'log-node-{node_rank}.log')
    for _logger in [logger, transformers.utils.logging.get_logger(), logging.getLogger('DeepSpeed')]:
        set_logger(_logger, training_args.local_rank, data_args.stream, log_path)

    logger.warning("Device: %s, rank: %s, world size: %s", training_args.device, training_args.process_index, training_args.world_size)

    _barrier()

    print_args(data_args, 'Data Arguments')
    print_args(training_args, 'Training Arguments')
    print_args(peft_args, 'LoRA Arguments')

    config = AutoConfig.from_pretrained(data_args.model_cfg, trust_remote_code=True)
    # NOTE(review): flash attention is always requested here;
    # data_args.flash_attention is not consulted.
    config._attn_implementation = "flash_attention_2"
    config.use_cache = False

    base_model = AutoModel.from_pretrained(data_args.model_cfg, config=config, torch_dtype=torch.bfloat16, trust_remote_code=True)
    if base_model.config.pad_token_id is None:
        base_model.config.pad_token_id = 0

    logger.info(base_model)
    base_model.config.use_cache = False
    # The same adapter checkpoint is loaded twice: "policy" is trainable,
    # "reference" is loaded without is_trainable.
    model = PeftModel.from_pretrained(
        base_model,
        data_args.adapter_cfg,
        torch_dtype=torch.bfloat16,
        is_trainable=True,
        adapter_name='policy',
        )
    model.load_adapter(
        data_args.adapter_cfg,
        torch_dtype=torch.bfloat16,
        adapter_name="reference",
        )

    tokenizer = AutoTokenizer.from_pretrained(data_args.tokenizer_cfg, trust_remote_code=True)
    tokenizer.padding_side = 'right'
    logger.info(f"padding side: {tokenizer.padding_side}")

    if tokenizer.pad_token is None:
        logger.warning("No pad token found, setting pad token to eos token")
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id


    logger.info(f"Loading training data from: {data_args.dataset_path_list}, dpo_dataset_path: {dpo_dataset_path}, contrastive_dataset_path: {contrastive_dataset_path}, train_group_size: {data_args.train_group_size}")
    # The trainer does not exist yet; it is attached via set_trainer() below.
    train_sets = DPODoubleDataset(data_args, trainer=None, dpo_dataset_path=dpo_dataset_path, contrastive_dataset_path=contrastive_dataset_path, train_group_size=data_args.train_group_size)

    logger.info('Length of DPO dataset and Contrastive dataset(same length): %d', len(train_sets))

    epoch_checkpoint_callback = StepCheckpointCallback(save_interval=data_args.step_save_interval, output_dir=training_args.output_dir, external_validation=data_args.external_validation)

    data_collator = DataCollatorForLeanFinderDualLossDataset(tokenizer=tokenizer)

    target_modules_str = peft_args.target_modules
    if target_modules_str:
        target_modules_list = [module.strip() for module in target_modules_str.split(',')]
    else:
        target_modules_list = None

    # NOTE(review): this config is only printed; it is not passed to the
    # model or the trainer in this function (adapters come from adapter_cfg).
    lora_config = LoraConfig(
        target_modules=target_modules_list,
        task_type=peft_args.task_type,
        r=peft_args.lora_r,
        lora_alpha=peft_args.lora_alpha,
        lora_dropout=peft_args.lora_dropout,
        bias=peft_args.bias,
    )

    print(lora_config)

    # The trainer writes into <output_dir>/backups; the callback above and
    # the final save below keep using the main output dir.
    training_args_main_output_dir = training_args.output_dir
    training_args.output_dir = os.path.join(training_args.output_dir, "backups")

    # Carry over only those TrainingArguments attributes that DPOConfig
    # accepts as constructor parameters, then apply the DPO-specific settings.
    dpo_config_params = inspect.signature(DPOConfig).parameters
    relevant_args = {k: v for k, v in vars(training_args).items() if k in dpo_config_params}
    relevant_args.update({
        'ref_adapter_name': "reference",
        'model_adapter_name': "policy",
        'beta': 0.1,
        'loss_type': "sigmoid",
        'rpo_alpha': data_args.rpo_alpha,
        'sync_ref_model': False,
        'ref_model_sync_steps': 4,
        'force_use_ref_model': False,
        'dataloader_num_workers': 0,
    })

    dpo_config = DPOConfig(**relevant_args)
    logger.info(dpo_config)

    dpo_trainer = LeanFinderDPOTrainer(
        model=model,
        args=dpo_config,
        train_dataset=train_sets,
        processing_class=tokenizer,
        data_collator=data_collator,
        dual_loss=True,
        contrastive_loss_temp=data_args.contrastive_loss_temp,
        callbacks=[LoggerCallback, epoch_checkpoint_callback],
    )

    # Both the checkpoint callback and the dataset need a handle on the
    # trainer, which could not be supplied at construction time.
    epoch_checkpoint_callback.set_trainer(dpo_trainer)
    train_sets.set_trainer(dpo_trainer)

    dpo_trainer.train(resume_from_checkpoint=data_args.resume_from)

    dpo_trainer.save_model(os.path.join(training_args_main_output_dir, "final"))

if __name__ == "__main__":
    # Script entry point: run training, log any failure with its traceback,
    # and report it to the launcher through a non-zero exit status.
    try:
        train()
    except Exception as e:
        logging.exception(e)
        # `exit()` is injected by the `site` module for interactive use and
        # is not guaranteed to exist; raising SystemExit yields the same
        # exit status without that dependency.
        raise SystemExit(-1)
