from functools import partial
import os
from pathlib import Path
import re
import string
import sys
import yaml
from argparse import ArgumentParser

import torch
from omegaconf import OmegaConf
from configs import ScriptArguments
from data import PromptAnswerDataCollator, get_dataset_display_name, get_train_dataset, get_eval_dataset
from eval import compute_metrics
from character_tokenizer import CharacterTokenizer

from typing import Tuple, Union, cast, Dict, Any
from transformers import HfArgumentParser, set_seed, AutoModelForCausalLM, AutoConfig, PreTrainedModel, AutoTokenizer, PreTrainedTokenizer, ByT5Tokenizer, LlamaConfig, LlamaForCausalLM, GenerationConfig
from datasets import Dataset
from peft import LoraConfig, get_peft_model
from trl import SFTConfig, SFTTrainer, GRPOConfig, GRPOTrainer, DataCollatorForCompletionOnlyLM

from trainers import trainer_class_factory
from verifiers import ground_truth_verifier

# torch._dynamo.config.optimize_ddp=False

def _overrides_to_dotlist(override_args):
    """Convert leftover CLI tokens into an OmegaConf dotlist (``["key=value", ...]``).

    A leading ``--`` is stripped. argparse leaves ``--lr 1e-4`` as two separate tokens, so a
    ``--key`` token without ``=`` is paired with the following token when that token is not
    itself a ``--`` option. Both ``--lr=1e-4`` and ``--lr 1e-4`` therefore yield ``"lr=1e-4"``.
    A bare ``--flag`` with no value is passed through as ``"flag"``, as before.
    """
    dotlist = []
    i = 0
    while i < len(override_args):
        token = override_args[i]
        is_option = token.startswith("--")
        key = token[2:] if is_option else token
        has_value_next = i + 1 < len(override_args) and not override_args[i + 1].startswith("--")
        if is_option and "=" not in key and has_value_next:
            dotlist.append(f"{key}={override_args[i + 1]}")
            i += 2
        else:
            dotlist.append(key)
            i += 1
    return dotlist


def get_args():
    """Parse the command line into ``(ScriptArguments, SFTConfig | GRPOConfig)``.

    ``--args`` and ``--train_args`` each take one or more YAML files. The files are merged in the
    order given with ``OmegaConf.merge``, so later files override earlier ones. Every other
    ``--key=value`` / ``--key value`` token is an override. Keys that are attributes of
    ``ScriptArguments`` go to the script/model config; all other keys go to the training config.

    After overrides are applied, ``train_algo`` selects the training-config class: ``'SFT'``
    (also used when unset) or ``'GRPO'``. Finally ``set_seed(train_args.seed)`` is called.

    Raises:
        ValueError: if ``train_algo`` is set to anything other than ``'SFT'`` or ``'GRPO'``.
    """
    top_parser = ArgumentParser()
    top_parser.add_argument("--args", nargs='+', help="List of YAML config files for model arguments")
    top_parser.add_argument("--train_args", nargs='+', help="List of YAML config files for training arguments")
    all_args, override_args = top_parser.parse_known_args()

    # argparse yields None for an omitted nargs='+' option; treat that as "no files".
    args_conf = OmegaConf.create()
    for config_file in all_args.args or []:
        args_conf = OmegaConf.merge(args_conf, OmegaConf.load(config_file))

    train_args_conf = OmegaConf.create()
    for config_file in all_args.train_args or []:
        train_args_conf = OmegaConf.merge(train_args_conf, OmegaConf.load(config_file))

    # Disable TF32 unless an Ampere-or-newer GPU (compute capability >= 8) is present.
    # The availability check comes first because querying the capability raises without CUDA.
    if not torch.cuda.is_available() or torch.cuda.get_device_capability(0)[0] < 8:
        train_args_conf.tf32 = False

    if override_args:
        override_conf = OmegaConf.from_dotlist(_overrides_to_dotlist(override_args))

        # An instance is needed to test which top-level keys belong to ScriptArguments.
        script_args_instance = ScriptArguments()
        model_override_conf = OmegaConf.create()
        train_override_conf = OmegaConf.create()
        for key in override_conf.keys():
            target_conf = model_override_conf if hasattr(script_args_instance, key) else train_override_conf
            OmegaConf.update(target_conf, key, override_conf[key])

        args_conf = OmegaConf.merge(args_conf, model_override_conf)
        train_args_conf = OmegaConf.merge(train_args_conf, train_override_conf)

    # Select the training-config class only after overrides, so `--train_algo=GRPO` on the CLI
    # is honoured for the training arguments as well as for ScriptArguments.
    train_algo = OmegaConf.select(args_conf, "train_algo")
    if train_algo == 'SFT' or train_algo is None:
        config_class = SFTConfig
    elif train_algo == 'GRPO':
        config_class = GRPOConfig
    else:
        raise ValueError(f"Unknown train_algo: {train_algo!r} (expected 'SFT' or 'GRPO')")

    args_dict = OmegaConf.to_container(args_conf, resolve=True)
    args = cast(ScriptArguments, HfArgumentParser(ScriptArguments).parse_dict(args_dict)[0])

    train_args_dict = OmegaConf.to_container(train_args_conf, resolve=True)
    train_args = cast(Union[SFTConfig, GRPOConfig], HfArgumentParser(config_class).parse_dict(train_args_dict)[0])
    set_seed(train_args.seed)

    return args, train_args

def get_tokenizer(args: ScriptArguments):
    """Build the tokenizer selected by *args*. The returned tokenizer pads on the left.

    With ``args.use_character_tokenizer``:
      * a CharacterTokenizer is built over ASCII letters, digits, punctuation and space;
      * when ``args.add_special_tokens`` > 0, ``[0]`` .. ``[N-1]`` plus ``[X]`` are added as
        special tokens;
      * if ``args.model_id`` is also set, every character-vocab entry that exists in the hub
        tokenizer is remapped to that tokenizer's id, and the special tokens both tokenizers
        define are copied from the hub tokenizer. Presumably this is so ids line up with a
        pretrained embedding table; confirm.

    Otherwise the hub tokenizer for ``args.model_id`` at ``args.revision`` is loaded.
    """
    # No padding happens while the datasets are generated. During eval, inputs are padded on the
    # left and labels on the right by the custom data collator.
    if args.use_character_tokenizer:
        tokenizer = CharacterTokenizer(string.ascii_letters + string.digits + string.punctuation + ' ')
        tokenizer.padding_side = 'left'  # assignment (the original `==` comparison was a no-op)

        if args.add_special_tokens > 0:
            added_tokens = [f'[{i}]' for i in range(args.add_special_tokens)] + ['[X]']
            tokenizer.add_tokens(added_tokens, special_tokens=True)

        if args.model_id is not None:
            old_tokenizer = AutoTokenizer.from_pretrained(args.model_id, token=os.environ.get("HF_TOKEN", None), revision=args.revision)
            if args.add_special_tokens > 0:
                old_tokenizer.add_tokens(added_tokens, special_tokens=True)
            overlap_special_keys = [k for k in old_tokenizer.special_tokens_map_extended.keys() if k in tokenizer.special_tokens_map_extended]
            # Filter out negative ids (the -100 token), which never appear in the input.
            shared_vocab = [k for k, v in tokenizer.get_vocab().items() if v >= 0 and k in old_tokenizer.get_vocab()]
            for voc in shared_vocab:
                tokenizer._vocab_str_to_int[voc] = old_tokenizer.convert_tokens_to_ids(voc)
            # Rebuild the reverse map so decoding matches the remapped ids.
            tokenizer._vocab_int_to_str = {v: k for k, v in tokenizer._vocab_str_to_int.items()}
            tokenizer.add_special_tokens({k: getattr(old_tokenizer, k) for k in overlap_special_keys})
    else:
        tokenizer = AutoTokenizer.from_pretrained(args.model_id, token=os.environ.get("HF_TOKEN", None), revision=args.revision)
        tokenizer.padding_side = 'left'

    return tokenizer

def get_all_datasets(args: ScriptArguments, train_args: SFTConfig | GRPOConfig, tokenizer: PreTrainedTokenizer):
    """Build the training dataset and the evaluation datasets.

    Eval datasets are built only when ``train_args.do_eval``. Their unmapped form is passed to
    get_train_dataset as ``no_sample_from``, presumably so training samples avoid eval samples
    (confirm in data.py). When ``train_args.do_train`` is False, an empty Dataset stands in for
    the training set.

    Returns:
        ``(train_dataset, eval_datasets)``; ``eval_datasets`` is None when ``do_eval`` is False.
    """
    # Initialise both so the do_train-without-do_eval path does not hit an unbound local.
    eval_datasets, unmapped_eval_datasets = None, None

    if train_args.do_eval:
        eval_datasets, unmapped_eval_datasets = get_eval_dataset(args, tokenizer)

    if train_args.do_train:
        # NOTE(review): without do_eval this passes no_sample_from=None; assumes get_train_dataset
        # treats None as "nothing to exclude" — confirm in data.py.
        train_dataset = get_train_dataset(args, train_args, tokenizer, no_sample_from=unmapped_eval_datasets)
    else:
        train_dataset = Dataset.from_dict({})

    tokenizer.padding_side = 'left'  # in case it was changed by the data generator
    return train_dataset, eval_datasets

def get_non_hf_model_id(args: ScriptArguments):
    """Return an id for a from-scratch (non-hub) model built from its size fields.

    The format is ``<architecture>-<hidden_size>-<num_attention_heads>-<num_layers>-<max_position_embeddings>``.
    """
    size_fields = ("architecture", "hidden_size", "num_attention_heads", "num_layers", "max_position_embeddings")
    return "-".join(format(getattr(args, field_name)) for field_name in size_fields)

def get_model(args: ScriptArguments, train_args: SFTConfig | GRPOConfig, tokenizer: PreTrainedTokenizer):
    """Build or load the causal LM described by *args*, optionally wrapped with LoRA.

    Construction paths:
      * ``args.model_id`` set and ``args.from_pretrained``: load pretrained weights, via Unsloth's
        ``FastLanguageModel`` when ``args.use_unsloth``, else ``AutoModelForCausalLM``.
      * ``args.model_id`` set but not pretrained: randomly initialise from the hub config.
      * ``args.model_id`` is None: build a fresh Llama from the size fields in *args* (only
        architectures starting with ``"llama"``), with vocabulary size
        ``tokenizer.total_vocab_size``.

    Afterwards, when ``args.use_lora`` is set, the model is wrapped in a LoRA adapter. Layers
    listed in ``args.freeze_layers`` get ``requires_grad = False``.

    Returns:
        The model (PEFT-wrapped when ``args.use_lora``).

    Raises:
        ValueError: for an ``args.architecture`` that does not start with ``"llama"`` when no
            ``model_id`` is given.
    """
    hf_token = os.environ.get("HF_TOKEN", None)

    if args.model_id is not None:
        if args.from_pretrained:
            if args.use_unsloth:
                from unsloth import FastLanguageModel
                # Unsloth also returns a tokenizer. It is discarded; the caller keeps using the one
                # it passed in. NOTE(review): confirm the two are compatible.
                model, _ = FastLanguageModel.from_pretrained(args.model_id, token=hf_token, fast_inference=False, load_in_4bit=False)
            else:
                model = AutoModelForCausalLM.from_pretrained(args.model_id, token=hf_token, revision=args.revision)
        else:
            config = AutoConfig.from_pretrained(args.model_id, token=hf_token, revision=args.revision)
            model = AutoModelForCausalLM.from_config(config)
    elif args.architecture.startswith("llama"):
        if 'nope' in args.architecture:
            # "nope" variant: replace Llama's rotary-embedding function with an identity so no
            # rotary position encoding is applied. This patches the transformers module itself,
            # so it affects every Llama model used in this process from now on.
            import transformers.models.llama.modeling_llama as modeling_llama

            def apply_rotary_pos_emb_dummy(q, k, *_args, **_kwargs):
                return q, k

            modeling_llama.apply_rotary_pos_emb = apply_rotary_pos_emb_dummy

        model_config = LlamaConfig(
            vocab_size=tokenizer.total_vocab_size,
            hidden_size=args.hidden_size,
            intermediate_size=args.intermediate_size,
            num_attention_heads=args.num_attention_heads,
            num_hidden_layers=args.num_layers,
            max_position_embeddings=args.max_position_embeddings,
            # Alternative: 'flash_attention_2' if train_args.bf16 else 'sdpa'
            _attn_implementation='sdpa',
            rope_theta=args.rope_theta,
            partial_rotary_factor=args.partial_rotary_factor,
            attention_dropout=args.dropout,
        )
        model = LlamaForCausalLM(model_config)
        model.to(torch.bfloat16 if train_args.bf16 else torch.float32)
    else:
        raise ValueError(f"Unknown architecture: {args.architecture}")

    if args.use_lora:
        lora_kwargs = {
            "task_type": "CAUSAL_LM",
            "r": 32,
            "lora_alpha": 128,
            "lora_dropout": 0.0,
            "bias": "none",
            "target_modules": ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        }
        # User-supplied LoRA settings override the defaults above.
        lora_kwargs.update(args.lora_config if args.lora_config is not None else {})
        if args.add_special_tokens > 0:
            # New special tokens need trainable embeddings/output head. Build a new list rather than
            # using `+=`: target_modules may be the list object from args.lora_config, and extending
            # it in place would mutate the caller's config.
            lora_kwargs['target_modules'] = lora_kwargs['target_modules'] + ["embed_tokens", "lm_head"]

        if args.use_unsloth:
            # Imported here as well: the pretrained path above is not the only way to reach this branch.
            from unsloth import FastLanguageModel
            lora_kwargs.pop('task_type')  # not passed to Unsloth's get_peft_model
            model = FastLanguageModel.get_peft_model(model, **lora_kwargs, random_state=train_args.seed)
        else:
            model = get_peft_model(model, peft_config=LoraConfig(**lora_kwargs))

        model.print_trainable_parameters()

    if args.freeze_layers is not None:
        for name, param in model.named_parameters():
            # The trailing '.' keeps e.g. layer 1 from matching "layers.11.".
            if any(f"layers.{layer_idx}." in name for layer_idx in args.freeze_layers):
                param.requires_grad = False
                print(f"Froze {name}")

    print(f"Number of parameters: {model.num_parameters()}")
    print(f"Number of trainable parameters: {model.num_parameters(only_trainable=True)}")

    return model

def get_run_name(args: ScriptArguments, train_args: SFTConfig | GRPOConfig):
    """Derive a run name from the script and training configuration.

    The name joins these parts with ``-``:
      * ``args.run_name_prefix``;
      * the model id (``args.model_id``, or get_non_hf_model_id(args) when it is None);
      * the optional markers ``pretrained`` / ``nochar`` / ``lora`` / ``rope``;
      * the display names of the training datasets;
      * ``args.train_algo``;
      * ``seed-<seed>``.

    ``/`` and ``,`` are then replaced by ``_``. All other punctuation and whitespace is deleted,
    except ``_ - . =``. ``-eval`` is appended when ``train_args.do_train`` is False.

    *train_args* is only read, never modified.
    """
    run_name = args.run_name_prefix

    if args.model_id is None:
        run_name += f"-{get_non_hf_model_id(args)}"
    else:
        run_name += f"-{args.model_id}"
    if args.from_pretrained:
        run_name += "-pretrained"
    if not args.use_character_tokenizer:
        run_name += "-nochar"
    if args.use_lora:
        run_name += "-lora"
    if args.rope_theta != torch.inf:
        run_name += "-rope"
    disp_task = '-'.join([get_dataset_display_name(data_args.op, data_args.kwargs) for data_args in args.train_data.values()])
    run_name += f'-{disp_task}'
    run_name += f'-{args.train_algo}'
    run_name += f'-seed-{train_args.seed}'

    removed_chars = ''.join(set(string.punctuation + string.whitespace) - set('/,_-.='))
    translator = str.maketrans('/,', '__', removed_chars)
    run_name = run_name.translate(translator)

    if not train_args.do_train:
        # Append to the returned name; the caller assigns it to train_args.run_name.
        run_name += '-eval'

    return run_name

def prepare_train_args(args: ScriptArguments, train_args: SFTConfig | GRPOConfig):
    """Finalise *train_args* in place for this script and return it.

    The function:
      * sets ``run_name`` from get_run_name and ``output_dir`` to ``out/<run_name>``;
      * converts a string ``resume_from_checkpoint`` of ``"true"`` (any case) to ``True``;
      * forces a fixed set of dataloader, logging, DDP and checkpoint options (listed below).
    """
    train_args.run_name = get_run_name(args, train_args)

    resume_value = train_args.resume_from_checkpoint
    if isinstance(resume_value, str) and resume_value.lower() == "true":
        train_args.resume_from_checkpoint = True

    # Applied in this order. Other options that were tried: eval_do_concat_batches=True,
    # use_liger_kernel=True, dataloader_prefetch_factor=16.
    forced_settings = {
        "output_dir": f"out/{train_args.run_name}",
        # Supposed to fix "There were missing keys in the checkpoint model loaded: ['lm_head.weight']."
        "save_safetensors": False,
        "dataloader_num_workers": args.num_workers,
        "dataloader_persistent_workers": False,
        "remove_unused_columns": False,
        "log_on_each_node": False,
        "ddp_find_unused_parameters": False,
        "use_liger": True,
    }
    for attr_name, attr_value in forced_settings.items():
        setattr(train_args, attr_name, attr_value)

    return train_args

def get_trainer(args: ScriptArguments, model: PreTrainedModel, tokenizer: PreTrainedTokenizer, train_args: SFTConfig | GRPOConfig, train_dataset: Dataset, eval_datasets: Dataset):
    """Configure generation on *model* and build the trainer chosen by ``args.train_algo``.

    For models not loaded from pretrained weights, a greedy-decoding GenerationConfig (single
    beam, no sampling, the tokenizer's pad/eos ids) is installed. Pretrained models keep their
    generation config; only ``max_length`` is changed. In both cases ``max_length`` is
    ``args.generate_max_length``.

    Both trainer classes come from trainer_class_factory(args) and use PromptAnswerDataCollator.
    The SFT trainer's compute_metrics gets ``task`` = the ``op`` of eval split ``'A'``, or of
    ``'B'`` when ``'A'`` is absent. The GRPO trainer uses ground_truth_verifier as its reward
    function.

    Raises:
        ValueError: if ``args.train_algo`` is neither ``'SFT'`` nor ``'GRPO'``. This is checked
            before *model* is modified.
    """
    if args.train_algo not in ('SFT', 'GRPO'):
        raise ValueError(f"Unknown train_algo: {args.train_algo!r} (expected 'SFT' or 'GRPO')")

    if not args.from_pretrained:
        model.generation_config = GenerationConfig(
            num_beams=1,
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            max_length=args.generate_max_length,
        )
    else:
        model.generation_config.max_length = args.generate_max_length

    SFTTrainerWithGenerate, GRPOTrainerWithGenerate = trainer_class_factory(args)

    if args.train_algo == 'SFT':
        eval_task = args.eval_data.get('A', args.eval_data.get('B')).op
        trainer = SFTTrainerWithGenerate(
            model=model,
            processing_class=tokenizer,  # not used
            args=train_args,
            train_dataset=train_dataset,
            eval_dataset=eval_datasets,
            compute_metrics=partial(compute_metrics, tokenizer=tokenizer, task=eval_task),
            # Alternative collator: DataCollatorForCompletionOnlyLM(response_template='=', tokenizer=tokenizer)
            data_collator=PromptAnswerDataCollator(tokenizer=tokenizer),
        )
    else:  # 'GRPO'
        # Another option is a null processing class, with tokenization done in data.py.
        trainer = GRPOTrainerWithGenerate(
            model=model,
            processing_class=tokenizer,
            args=train_args,
            train_dataset=train_dataset,
            eval_dataset=eval_datasets,
            compute_metrics=partial(compute_metrics, tokenizer=tokenizer),
            reward_funcs=ground_truth_verifier,
            data_collator=PromptAnswerDataCollator(tokenizer=tokenizer),
        )

    return trainer
