import torch, unicodedata, sys, os
import torch.nn as nn
import torch.nn.functional as F
from typing import TYPE_CHECKING, Optional, Tuple, Union, Callable, List, Any, Generator, Dict
from dataclasses import dataclass, field
from transformers.cache_utils import Cache
from transformers.models.qwen2.modeling_qwen2 import Qwen2ForCausalLM,Qwen2RMSNorm
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.utils import LossKwargs
from transformers import AutoConfig, AutoModel
from modelscope import AutoTokenizer
from transformers.processing_utils import Unpack
from datasets import load_dataset
from tqdm import tqdm
import pandas as pd
from transformers import TrainingArguments, Trainer, PreTrainedModel, Qwen3Model
from transformers.utils import ModelOutput
from peft import LoraConfig, get_peft_model, TaskType
from hf_qwen3_gate import Qwen3Gating, contains_cj, Qwen3MoeGating
import torch.distributed as dist


# Location of the base checkpoint. BASEMODEL_PATH is mandatory: a missing
# variable raises KeyError while this module is being imported.
base_path = os.environ['BASEMODEL_PATH']
print("gate_train Base model is", base_path)

# TOKEN_NORM counts as enabled only when, after trimming whitespace and
# lower-casing, it is exactly "true"; unset or any other value disables it.
token_norm_str = os.getenv('TOKEN_NORM', 'false').strip().lower()
should_token_norm = (token_norm_str == 'true')
print('should_token_norm', should_token_norm)

def setup_distributed():
    """Bind this process to its GPU and create the NCCL process group.

    The GPU index is read from the ``LOCAL_RANK`` environment variable and
    falls back to 0 when the variable is not set.

    Returns:
        The local rank that was used as the CUDA device index.
    """
    rank_on_node = int(os.environ.get("LOCAL_RANK", 0))
    # CRITICAL: the device has to be selected BEFORE the model is initialized.
    torch.cuda.set_device(rank_on_node)
    dist.init_process_group(backend="nccl")
    return rank_on_node

# Executed at import time so that the CUDA device is selected and the process
# group is initialized BEFORE the model is loaded further down in this file.
local_rank = setup_distributed()


# Define data collator
@dataclass
class CustomDataCollator:
    """Collate pre-tokenized examples into one padded batch.

    Each feature must provide a ``"token_ids"`` sequence. Sequences are cut to
    ``max_length`` items and handed to ``tokenizer.pad`` under the key
    ``"input_ids"``; whatever ``tokenizer.pad`` returns is returned unchanged.
    """

    # Object exposing ``pad(...)`` and a writable ``padding_side`` attribute.
    tokenizer: Any
    # Forwarded verbatim to ``tokenizer.pad``.
    padding: Union[bool, str] = "max_length"
    # Truncation limit per example; also forwarded to ``tokenizer.pad``.
    max_length: Optional[int] = 2048
    # Forwarded verbatim to ``tokenizer.pad``.
    pad_to_multiple_of: Optional[int] = None
    # Forwarded verbatim to ``tokenizer.pad``.
    return_tensors: str = "pt"

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        """Truncate every example's ``token_ids`` and pad them as a batch.

        Side effect: sets ``tokenizer.padding_side`` to ``'left'`` on the
        shared tokenizer object before padding.
        """
        limit = self.max_length
        truncated = []
        for example in features:
            truncated.append(example["token_ids"][:limit])

        # Padding itself is delegated to the tokenizer.
        self.tokenizer.padding_side = 'left'
        pad_options = {
            "padding": self.padding,
            "max_length": limit,
            "pad_to_multiple_of": self.pad_to_multiple_of,
            "return_tensors": self.return_tensors,
        }
        return self.tokenizer.pad({"input_ids": truncated}, **pad_options)

if __name__ == "__main__":
    # Entry point: builds the TrainingArguments, loads Qwen3MoeGating with
    # every parameter frozen except those whose name contains 'code_switch',
    # reads the JSONL datasets and runs a single training epoch with Trainer.
    import yaml
    from datetime import datetime

    # Open and load the YAML file. In this script only the micro-batch size
    # and the gradient-accumulation steps are read from it.
    with open("./deepspeed_config.yaml", "r") as file:
        deepspeed_config = yaml.safe_load(file)
    print(deepspeed_config)

    # Defaults are assigned unconditionally. Previously they were only set
    # when fewer than two extra argv entries were present, so any launch with
    # additional arguments crashed with NameError when output_dir was built.
    top_k = 20
    top_p = 0.95
    if len(sys.argv) >= 3:
        try:
            # top_p is a fraction (default 0.95), so it is parsed as float.
            top_k, top_p = int(sys.argv[1]), float(sys.argv[2])
        except ValueError:
            # argv may carry non-numeric entries (e.g. launcher flags);
            # keep the defaults instead of aborting the run.
            print("could not parse top_k/top_p from argv, using defaults:", sys.argv[1:3])

    now = datetime.now()
    # Every rank executes this script and reads its own clock. Rank 0's
    # timestamp is broadcast so that all ranks agree on a single output_dir.
    stamp_holder = [now.strftime("%Y-%m-%d-%H:%M:%S")]
    if dist.is_available() and dist.is_initialized():
        dist.broadcast_object_list(stamp_holder, src=0)
    formatted_str = stamp_holder[0]

    training_args = TrainingArguments(
        output_dir=f"./models/gate-qwen3-controlfix-{top_k}k_{int(top_p * 100)}p_flores_{formatted_str}",
        overwrite_output_dir=True,
        num_train_epochs=1,
        per_device_train_batch_size=deepspeed_config['deepspeed_config']['train_micro_batch_size_per_gpu'],
        per_device_eval_batch_size=deepspeed_config['deepspeed_config']['train_micro_batch_size_per_gpu'],
        learning_rate=2e-5,
        save_strategy="epoch",
        logging_dir="./logs",
        logging_steps=10,
        eval_strategy="epoch",
        report_to="tensorboard",
        push_to_hub=False,
        # fp16=True,
        bf16=True,
        gradient_accumulation_steps=deepspeed_config['deepspeed_config']['gradient_accumulation_steps'],
        warmup_steps=100,
        weight_decay=0.01,
        dataloader_num_workers=4,
        # The raw records are handed to CustomDataCollator as-is, so Trainer
        # must not drop columns such as "token_ids".
        remove_unused_columns=False,
        save_total_limit=2,
    )
    tokenizer = AutoTokenizer.from_pretrained(base_path)
    print("before load model")
    model = Qwen3MoeGating.from_pretrained(
        base_path,
        trust_remote_code=True,
        attn_implementation="flash_attention_2",
        torch_dtype="auto",
        output_hidden_states=True
    )
    print("after load model")
    model.should_token_norm = should_token_norm
    model.top_k = top_k
    model.top_p = top_p
    model.generation_config.return_dict_in_generate = True
    # Freeze everything except parameters whose name contains 'code_switch'
    # (presumably the gating head defined in hf_qwen3_gate -- confirm there).
    for name, param in model.named_parameters():
        if 'code_switch' not in name:
            param.requires_grad = False
    print('after loading model')
    # Load dataset
    data_files = {
        "train": "./data/cs_cj_filter_k_train_2025-07-01-14_30_20_human_filter_en.jsonl",
        "validation": "./data/cs_cj_filter_k_test_2025-07-01-14_30_20_human_filter_en.jsonl"
    }
    print("will load dataset")
    train_dataset = pd.read_json(data_files['train'], lines=True)
    flores_train = pd.read_json('./data/flores_train.jsonl', lines=True)
    # The shuffle is seeded: every rank builds this list on its own, and an
    # unseeded shuffle would give each rank a different ordering, so
    # index-based sharding across ranks would duplicate and drop examples.
    train_dataset = pd.concat([train_dataset, flores_train]).sample(
        frac=1, random_state=training_args.seed
    )
    # Keep only the rows whose 'string_offsets' value is missing (NaN/None).
    train_dataset = train_dataset[train_dataset['string_offsets'].isna()].to_dict('records')
    test_dataset = pd.read_json(data_files['validation'], lines=True).to_dict('records')
    print("finish load dataset")
    # Define trainer
    # NOTE(review): newer transformers releases rename the `tokenizer`
    # argument to `processing_class` -- confirm against the pinned version.
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        data_collator=CustomDataCollator(tokenizer),
        tokenizer=tokenizer,
    )
    # Start training
    trainer.train()
    trainer.save_model()