import os
import json
import warnings
import logging
from dataclasses import dataclass, field, asdict
from typing import Optional, List, Dict, Any
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

import torch
from torch.nn import CrossEntropyLoss
from datasets import load_dataset
import transformers
import trl
from tqdm import tqdm
from transformers import TrainingArguments

@dataclass
class TrainingConfig:
    """Script-level options that are parsed alongside ``TrainingArguments``.

    When ``use_wandb`` is true, constructing the config exports the W&B
    project and entity through the ``WANDB_PROJECT`` / ``WANDB_ENTITY``
    environment variables. Values that are unset are skipped.
    """

    # Model identifier used to load both the tokenizer and the model.
    model_name: str = field(default="Qwen/Qwen2-7B-Instruct")
    # Maximum tokenized length; longer samples are truncated to this size.
    block_size: int = field(default=4096)

    # Dataset location handed to ``load_dataset``.
    train_file_path: Optional[str] = field(default='data/XXX.jsonl')

    # W&B configuration
    use_wandb: bool = field(default=True)
    wandb_project: Optional[str] = field(default="s112_fixed")
    wandb_entity: Optional[str] = field(default=None)

    def __post_init__(self):
        """Export the W&B settings to the process environment if enabled."""
        if not self.use_wandb:
            return
        # os.environ only accepts strings: assigning None raises TypeError,
        # so an unset project is skipped instead of crashing construction.
        if self.wandb_project is not None:
            os.environ['WANDB_PROJECT'] = self.wandb_project
        if self.wandb_entity:
            os.environ['WANDB_ENTITY'] = self.wandb_entity

class WeightedCollator(trl.DataCollatorForCompletionOnlyLM):
    """Completion-only collator that also batches optional per-token weights.

    Each example may carry a ``token_weight`` list. The collated batch then
    gets a float ``token_weight`` tensor with the same shape as ``input_ids``;
    positions without a supplied weight (padding, missing entries) are 1.0.
    """

    def torch_call(self, examples: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Collate ``examples`` and attach a padded ``token_weight`` tensor.

        Args:
            examples: Tokenized samples; each may include a ``token_weight``
                list (or ``None``) next to the usual model inputs.

        Returns:
            The batch produced by the parent collator, plus ``token_weight``
            when at least one example supplied weights.
        """
        raw_weights = [ex.get('token_weight') for ex in examples]

        # Hand the parent only the model inputs. The weight lists are ragged
        # and are not something the tokenizer-based padding of the parent
        # knows how to pad, so they are batched separately below.
        model_inputs = [
            {key: value for key, value in ex.items() if key != 'token_weight'}
            for ex in examples
        ]
        batch = super().torch_call(model_inputs)

        if all(w is None for w in raw_weights):
            return batch

        input_ids = batch['input_ids']
        batch_size, seq_len = input_ids.shape

        # Follow the tokenizer's padding side so weights stay aligned with
        # the tokens they belong to; default to right padding when unknown.
        tokenizer = getattr(self, 'tokenizer', None)
        pad_left = getattr(tokenizer, 'padding_side', 'right') == 'left'

        # Start from the neutral weight so padding and examples without
        # weights contribute a factor of 1.0.
        weights = torch.ones((batch_size, seq_len), dtype=torch.float)
        for row, raw in enumerate(raw_weights):
            if raw is None:
                continue
            # Clip to the collated length, mirroring the truncation of
            # over-long weight rows.
            values = torch.tensor(raw, dtype=torch.float)[:seq_len]
            count = values.shape[0]
            if pad_left:
                weights[row, seq_len - count:] = values
            else:
                weights[row, :count] = values

        batch['token_weight'] = weights.to(input_ids.device)
        return batch

class WeightedSFTTrainer(trl.SFTTrainer):
    """SFT trainer whose loss scales per-token cross-entropy by optional weights."""

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Compute the (optionally token-weighted) next-token loss.

        Args:
            model: Model called as ``model(**inputs)``; its output must expose
                ``logits``.
            inputs: Batch dict containing ``labels`` and, optionally,
                ``token_weight`` (removed before the forward pass).
            return_outputs: When true, return ``(loss, outputs)``.
            **kwargs: Accepted for signature compatibility and ignored.

        Returns:
            The scalar loss, or ``(loss, outputs)`` if ``return_outputs``.
            The loss is the mean over positions whose label is not -100 of
            cross-entropy times weight; it is a zero that is still attached
            to the graph when the batch has no such position.

        Raises:
            ValueError: If ``inputs`` has no ``labels``.
        """
        # NOTE(review): a ``num_items_in_batch`` value passed through kwargs is
        # not used for normalization here; confirm the resulting loss scale
        # under gradient accumulation with the installed transformers version.
        labels = inputs.get("labels")
        if labels is None:
            raise ValueError("compute_loss requires 'labels' in inputs")
        # The model's forward does not take token weights, so remove them.
        weights = inputs.pop("token_weight", None)

        outputs = model(**inputs)
        logits = outputs.get("logits")

        # Position t predicts token t + 1.
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)

        loss_fct = CrossEntropyLoss(reduction='none')
        token_losses = loss_fct(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
        ).view(shift_labels.shape)

        mask = shift_labels != -100
        has_targets = bool(mask.any())

        # Mean over an empty selection is NaN. Use a zero that stays connected
        # to the graph so backward still runs when nothing is supervised.
        zero_loss = token_losses.sum() * 0.0

        unweighted_loss = token_losses[mask].mean() if has_targets else zero_loss

        if self.is_in_train:
            logged = unweighted_loss.detach().item() if has_targets else 0.0
            self.log({"unweighted_loss_mean": logged})

        if weights is None:
            final_loss = unweighted_loss
        elif has_targets:
            # Shift the weights the same way as the labels they belong to.
            shift_weights = weights[..., 1:].contiguous().to(token_losses.device)
            final_loss = (token_losses * shift_weights)[mask].mean()
        else:
            final_loss = zero_loss

        return (final_loss, outputs) if return_outputs else final_loss

def train():
    """Parse CLI arguments, prepare tokenizer/model/data, fine-tune and save.

    Steps: parse ``TrainingConfig`` and ``TrainingArguments``; load the
    tokenizer and pick instruction/response templates by model family; load
    the model; load and tokenize the dataset (aligning optional
    ``token_weight`` lists to the token count); train with
    ``WeightedSFTTrainer``; save model and tokenizer to ``args.output_dir``.

    Raises:
        ValueError: If ``train_file_path`` is not set.
        FileNotFoundError: If ``train_file_path`` does not exist locally and
            does not contain ``"hf.co"``.
    """
    parser = transformers.HfArgumentParser((TrainingConfig, TrainingArguments))
    config, args = parser.parse_args_into_dataclasses()
    # Presumably needed so the ``token_weight`` column reaches the collator
    # instead of being dropped by the trainer — TODO confirm.
    args.remove_unused_columns = False
    args.report_to = ['wandb'] if config.use_wandb else []

    tokenizer = transformers.AutoTokenizer.from_pretrained(config.model_name, use_fast=True)

    # Templates mark where the user turn and the assistant turn begin; the
    # pad/eos tokens are set per model family.
    if "Llama" in config.model_name:
        instruction_template = "<|start_header_id|>user<|end_header_id|>"
        response_template = "<|start_header_id|>assistant<|end_header_id|>\n\n"
        tokenizer.pad_token = "<|reserved_special_token_5|>"
        if tokenizer.eos_token is None:
            tokenizer.eos_token = "</s>"
    elif "Qwen" in config.model_name:
        instruction_template = "<|im_start|>user"
        response_template = "<|im_start|>assistant\n"
        tokenizer.pad_token = "<|fim_pad|>"
        tokenizer.eos_token = "<|im_end|>"
    else:
        instruction_template = "USER:"
        response_template = "ASSISTANT:"
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

    logging.info(f"Pad Token: {tokenizer.pad_token}")
    logging.info(f"EOS Token: {tokenizer.eos_token}")
    args.eos_token = tokenizer.eos_token

    # LOCAL_RANK being set is treated as "launched by a distributed launcher".
    is_distributed = int(os.environ.get("LOCAL_RANK", -1)) != -1
    model_kwargs = {
        "trust_remote_code": True,
        "use_cache": not args.gradient_checkpointing,
        "torch_dtype": "auto",
    }
    # Only a single-process run lets the loader place the model itself.
    if not is_distributed:
        model_kwargs["device_map"] = "auto"

    model = transformers.AutoModelForCausalLM.from_pretrained(config.model_name, **model_kwargs)
    model.config.pad_token_id = tokenizer.pad_token_id

    # The field is Optional; fail with a clear message instead of a TypeError
    # from os.path.exists(None).
    if config.train_file_path is None:
        raise ValueError("train_file_path must be set")
    if not os.path.exists(config.train_file_path) and "hf.co" not in config.train_file_path:
        raise FileNotFoundError(f"error to find: {config.train_file_path}")

    dataset = load_dataset(config.train_file_path)

    def tokenize_and_merge(examples):
        """Tokenize a batch of ``text`` and align ``token_weight`` to it.

        Each text gets the EOS token appended and is truncated to
        ``config.block_size``. When the batch has a ``token_weight`` column,
        every weight list is cut or padded with 1.0 to the token count.
        """
        texts_with_eos = [text + tokenizer.eos_token for text in examples["text"]]

        tokenized = tokenizer(
            texts_with_eos,
            truncation=True,
            padding=False,
            max_length=config.block_size,
            add_special_tokens=True,
        )

        if "token_weight" not in examples:
            return tokenized

        aligned_weights_batch = []
        for input_ids, raw_weights in zip(tokenized["input_ids"], examples["token_weight"]):
            actual_len = len(input_ids)

            if raw_weights is None:
                aligned_weights_batch.append([1.0] * actual_len)
                continue

            current_weights = list(raw_weights)

            # Exactly one weight short: prepend a neutral weight. Presumably
            # this accounts for one leading token added by the tokenizer —
            # TODO confirm against the data.
            if len(current_weights) == actual_len - 1:
                current_weights = [1.0] + current_weights

            # Cut over-long lists, then pad short ones with the neutral weight
            # (a non-positive repeat count yields an empty list).
            current_weights = current_weights[:actual_len]
            current_weights += [1.0] * (actual_len - len(current_weights))

            aligned_weights_batch.append(current_weights)

        tokenized["token_weight"] = aligned_weights_batch
        return tokenized

    # cpu_count() may be None, and halving a single core gives 0; the worker
    # count passed to map must be a positive integer.
    num_workers = max(1, (os.cpu_count() or 1) // 2)

    column_names = dataset["train"].column_names
    dataset = dataset.map(
        tokenize_and_merge,
        batched=True,
        batch_size=1000,
        num_proc=num_workers,
        remove_columns=column_names,
    )

    collator = WeightedCollator(
        instruction_template=instruction_template,
        response_template=response_template,
        tokenizer=tokenizer,
        mlm=False,
    )

    trainer = WeightedSFTTrainer(
        model=model,
        args=args,
        train_dataset=dataset['train'],
        # Falls back to the training split when there is no 'test' split.
        eval_dataset=dataset.get('test', dataset.get('train')),
        data_collator=collator,
    )

    trainer.train()

    logging.info("save...")
    trainer.model.config.eos_token_id = tokenizer.eos_token_id
    trainer.save_model(output_dir=args.output_dir)
    tokenizer.save_pretrained(args.output_dir)
    trainer.accelerator.wait_for_everyone()

if __name__ == "__main__":
    # Script entry point. The process group is created only when LOCAL_RANK
    # is set, which is the same marker train() uses to detect a distributed
    # launch. Without it there is no rendezvous information, and creating the
    # group would fail before the single-process path could run.
    launched_distributed = int(os.environ.get("LOCAL_RANK", -1)) != -1
    if launched_distributed and not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend="nccl")
    train()
    # Tear down whatever group exists at this point, if any.
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()