import os
import torch
import torch.nn.functional as F
from torch.nn import KLDivLoss
import gc
from torch.utils.data import Dataset
from datasets import load_dataset
from transformers import (
    Trainer,
    TrainingArguments,
    AutoModelForCausalLM,
    AutoTokenizer,
)
import argparse
from typing import Dict, List, Tuple

def load_json_data(data_path: str):
    """Load the JSON file(s) at ``data_path`` and return the 'train' split.

    The path is handed to ``datasets.load_dataset`` with the ``'json'``
    builder; whatever that call returns is passed straight back to the caller.
    """
    dataset = load_dataset('json', data_files=data_path, split='train')
    return dataset

class KALEKLRationaleDataset(Dataset):
    """Paired chat prompts (with and without a rationale) for KL training.

    Each raw record is read through ``.get`` for the keys ``system``,
    ``input``, ``rationale`` and ``output``.  Records without an answer
    (``output``) are dropped.  Every kept record yields a prompt whose
    assistant turn is ``Answer: <answer>``; records that also carry a
    rationale additionally yield a prompt whose assistant turn is
    ``<rationale>\\nAnswer: <answer>``.

    Items are tokenized lazily in ``__getitem__`` and padded/truncated to
    ``max_length``.  Labels are ``-100`` everywhere except on the tokens of
    the answer itself.
    """

    def __init__(self, data, tokenizer: "AutoTokenizer", max_length: int):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.raw_data = data
        self.samples = self._prepare_samples()

    @staticmethod
    def _text_field(example, key: str) -> str:
        """Return ``example[key]`` as text, mapping a missing/None value to ''."""
        value = example.get(key)
        return '' if value is None else str(value)

    def _prepare_samples(self) -> List[Dict]:
        """Build the prompt texts for every raw record that has an answer.

        Returns:
            One dict per kept record with the keys
            ``prompt_without_rationales_text``, ``answer_text``,
            ``has_rationale_pair`` and, when a rationale is present,
            ``prompt_with_target_rationales_text``.
        """
        samples = []
        for example in self.raw_data:
            # JSON nulls (and columns absent from a row) come through as None;
            # normalise them so .strip() cannot fail and the literal text
            # "None" never ends up inside a prompt.
            system_prompt = self._text_field(example, 'system')
            user_input = self._text_field(example, 'input')
            rationale = self._text_field(example, 'rationale').strip()
            answer = self._text_field(example, 'output').strip()

            if not answer:
                continue

            prompt_header = (
                f"<|im_start|>system\n{system_prompt}<|im_end|>\n"
                f"<|im_start|>user\n{user_input}<|im_end|>\n"
            )

            entry = {
                'prompt_without_rationales_text': (
                    f"{prompt_header}<|im_start|>assistant\nAnswer: {answer}<|im_end|>"
                ),
                'answer_text': answer,
                'has_rationale_pair': False,
            }

            if rationale:
                entry['prompt_with_target_rationales_text'] = (
                    f"{prompt_header}<|im_start|>assistant\n{rationale}\nAnswer: {answer}<|im_end|>"
                )
                entry['has_rationale_pair'] = True

            samples.append(entry)
        return samples

    def __len__(self) -> int:
        return len(self.samples)

    @staticmethod
    def _find_last_subsequence(haystack: torch.Tensor, needle: torch.Tensor) -> int:
        """Return the start index of the last occurrence of ``needle``, or -1.

        Both arguments are 1-D tensors.  An empty needle, or one longer than
        the haystack, never matches.
        """
        needle_len = needle.numel()
        if needle_len == 0 or needle_len > haystack.numel():
            return -1
        # Every length-`needle_len` window of the haystack, compared in one shot.
        windows = haystack.unfold(0, needle_len, 1)
        hits = (windows == needle).all(dim=1).nonzero(as_tuple=True)[0]
        return int(hits[-1]) if hits.numel() > 0 else -1

    def _locate_answer_span(self, input_ids: torch.Tensor, answer_text: str) -> Tuple[int, int]:
        """Find the answer's token span inside ``input_ids``.

        The answer is encoded on its own and, as a fallback, with a leading
        space: tokenizers that attach a preceding space to the following word
        encode the answer in ``"Answer: <answer>"`` differently from the bare
        answer string, so the bare encoding alone can miss it.  Among the
        encodings that occur in ``input_ids`` the span that ends latest wins;
        on a tie the bare encoding is preferred.

        Returns:
            ``(start, length)`` of the chosen span, or ``(-1, 0)`` if no
            encoding of the answer occurs in ``input_ids``.
        """
        best_start, best_len = -1, 0
        for candidate in (answer_text, ' ' + answer_text):
            candidate_ids = self.tokenizer.encode(
                candidate, add_special_tokens=False, padding=False, truncation=False
            )
            if not isinstance(candidate_ids, list):
                candidate_ids = (
                    candidate_ids.tolist() if hasattr(candidate_ids, 'tolist') else [candidate_ids]
                )
            if not candidate_ids:
                continue

            needle = torch.tensor(candidate_ids, dtype=input_ids.dtype, device=input_ids.device)
            start = self._find_last_subsequence(input_ids, needle)
            # Strict '>' keeps the earlier (bare) candidate when both end at the same place.
            if start != -1 and start + len(candidate_ids) > best_start + best_len:
                best_start, best_len = start, len(candidate_ids)
        return best_start, best_len

    def _tokenize_and_prepare_labels(self, prompt_text: str, answer_text: str) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Tokenize ``prompt_text`` and build labels that expose only the answer.

        Args:
            prompt_text: Full chat-formatted prompt including the answer.
            answer_text: The answer string to locate inside the prompt.

        Returns:
            ``(input_ids, attention_mask, labels)``, each 1-D of length
            ``max_length``.  ``labels`` equals ``input_ids`` on the answer
            span and ``-100`` elsewhere; if the answer cannot be located
            (e.g. it was truncated away) every label is ``-100``.
        """
        tokenized_prompt = self.tokenizer(
            prompt_text,
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )

        input_ids = tokenized_prompt['input_ids'].squeeze(0)
        attention_mask = tokenized_prompt['attention_mask'].squeeze(0)

        labels = torch.full_like(input_ids, -100)
        span_start, span_len = self._locate_answer_span(input_ids, answer_text)
        if span_start != -1:
            span = slice(span_start, span_start + span_len)
            labels[span] = input_ids[span]

        # Never supervise on padding, even if a pad id appears inside the span.
        labels[input_ids == self.tokenizer.pad_token_id] = -100

        return input_ids, attention_mask, labels

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Return the tokenized pair for sample ``idx``.

        The ``*_without_rationales`` tensors are always real.  The
        ``*_with_target_rationales`` tensors are real only when
        ``has_rationale_pair`` is True; otherwise they are placeholders
        (zeros for ids/mask, ``-100`` for labels) so that every item has the
        same keys and shapes.
        """
        sample = self.samples[idx]
        answer_text = sample['answer_text']

        input_ids_wr, attention_mask_wr, labels_wr = self._tokenize_and_prepare_labels(
            sample['prompt_without_rationales_text'], answer_text
        )

        item = {
            'input_ids_without_rationales': input_ids_wr,
            'attention_mask_without_rationales': attention_mask_wr,
            'labels_without_rationales': labels_wr,
            'has_rationale_pair': torch.tensor(sample['has_rationale_pair'], dtype=torch.bool)
        }

        if sample['has_rationale_pair']:
            input_ids_wtr, attention_mask_wtr, labels_wtr = self._tokenize_and_prepare_labels(
                sample['prompt_with_target_rationales_text'], answer_text
            )
            item['input_ids_with_target_rationales'] = input_ids_wtr
            item['attention_mask_with_target_rationales'] = attention_mask_wtr
            item['labels_with_target_rationales'] = labels_wtr
        else:
            item['input_ids_with_target_rationales'] = torch.zeros_like(input_ids_wr)
            item['attention_mask_with_target_rationales'] = torch.zeros_like(attention_mask_wr)
            item['labels_with_target_rationales'] = torch.full_like(labels_wr, -100)

        return item

def kl_rationale_data_collator(features: List[Dict]) -> Dict[str, torch.Tensor]:
    """Stack per-sample tensors into one batch dict.

    Every feature dict must provide all seven fields listed below; each
    field is stacked along a new leading dimension.
    """
    field_names = (
        'input_ids_without_rationales',
        'attention_mask_without_rationales',
        'labels_without_rationales',
        'input_ids_with_target_rationales',
        'attention_mask_with_target_rationales',
        'labels_with_target_rationales',
        'has_rationale_pair',
    )
    return {
        name: torch.stack([feature[name] for feature in features])
        for name in field_names
    }

class KLDivergenceTrainer(Trainer):
    """Trainer whose loss is a KL term between a student and a frozen teacher.

    The student (``model``) sees the prompt *without* the rationale; the
    frozen teacher (``model_frozen_params``) sees the prompt *with* the
    rationale.  For every sample that has a rationale pair, the next-token
    distributions at the answer-token positions of both sequences are
    compared with ``KL(teacher || student)`` and averaged over tokens.

    Args:
        model_frozen_params: Teacher model; put in eval mode and only ever
            run under ``torch.no_grad()``.  May be None, in which case the
            loss is always zero.
        kl_weight: Scalar multiplier applied to the KL term.
    """

    def __init__(self, *args, model_frozen_params=None, kl_weight=1.0, **kwargs):
        super().__init__(*args, **kwargs)
        self.model_frozen_params = model_frozen_params
        self.kl_weight = kl_weight
        # `is not None` rather than truthiness: some nn.Module containers define __len__.
        if self.model_frozen_params is not None:
            self.model_frozen_params.eval()

    def _frozen_model_on(self, device):
        """Return the teacher, moving it to ``device`` first if it lives elsewhere.

        The base Trainer only places the trained model on the compute device,
        so the teacher has to be moved here before its first forward pass.
        """
        frozen = self.model_frozen_params
        first_param = next(frozen.parameters(), None)
        if first_param is not None and first_param.device != device:
            frozen.to(device)
        return frozen

    @staticmethod
    def _zero_loss_like(logits):
        """Return a zero loss that is still connected to ``logits``' graph.

        A constant ``torch.tensor(0.0)`` has no ``grad_fn``, so calling
        ``backward()`` on it raises.  Multiplying (a thin slice of) the logits
        by zero keeps the value at 0 while leaving the backward pass valid.
        The multiplication happens before the sum so half-precision logits
        cannot overflow to inf and turn the result into NaN.
        """
        return (logits[..., :1].float() * 0.0).sum()

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the weighted KL loss for one batch.

        Args:
            model: The student model (possibly wrapped by the Trainer).
            inputs: Batch dict with the ``*_without_rationales`` /
                ``*_with_target_rationales`` tensors and the boolean
                ``has_rationale_pair`` mask.
            return_outputs: If True, also return a dict with the unweighted
                KL (``loss_kl``) and the student outputs
                (``outputs_model_to_train``).
            num_items_in_batch: Accepted for compatibility with Trainer
                versions that pass it; unused because the loss is already
                normalised per answer token.

        Returns:
            The scalar loss, or ``(loss, outputs_dict)`` when
            ``return_outputs`` is True.
        """
        has_rationale_mask = inputs['has_rationale_pair']
        num_kl_samples = int(has_rationale_mask.sum().item())

        loss_kl = None
        outputs_for_return = None

        if num_kl_samples > 0 and self.model_frozen_params is not None:
            input_ids_wr_kl = inputs['input_ids_without_rationales'][has_rationale_mask]
            attention_mask_wr_kl = inputs['attention_mask_without_rationales'][has_rationale_mask]
            labels_wr_kl = inputs['labels_without_rationales'][has_rationale_mask]

            input_ids_wtr_kl = inputs['input_ids_with_target_rationales'][has_rationale_mask]
            attention_mask_wtr_kl = inputs['attention_mask_with_target_rationales'][has_rationale_mask]
            labels_wtr_kl = inputs['labels_with_target_rationales'][has_rationale_mask]

            outputs_for_return = model(
                input_ids=input_ids_wr_kl,
                attention_mask=attention_mask_wr_kl
            )
            logits_wr = outputs_for_return.logits

            with torch.no_grad():
                teacher = self._frozen_model_on(input_ids_wtr_kl.device)
                logits_wtr = teacher(
                    input_ids=input_ids_wtr_kl,
                    attention_mask=attention_mask_wtr_kl
                ).logits

            kl_components_sum = None
            total_active_tokens_for_kl = 0

            for i in range(num_kl_samples):
                s_labels = labels_wr_kl[i]
                t_labels = labels_wtr_kl[i]

                s_answer_token_indices = (s_labels != -100).nonzero(as_tuple=True)[0]
                t_answer_token_indices = (t_labels != -100).nonzero(as_tuple=True)[0]

                # Only compare when both sequences expose the same answer tokens.
                if len(s_answer_token_indices) == 0:
                    continue
                if len(s_answer_token_indices) != len(t_answer_token_indices):
                    continue
                if not torch.equal(s_labels[s_answer_token_indices], t_labels[t_answer_token_indices]):
                    continue

                # NOTE(review): logits are read AT the answer-token positions, i.e.
                # they are the distributions for the token *after* each answer
                # token. If the intent is to match the prediction OF each answer
                # token, the indices need to be shifted by -1 — confirm.
                # Softmax/KL are done in float32 for numerical stability under fp16/bf16.
                log_probs_wr = F.log_softmax(logits_wr[i, s_answer_token_indices, :].float(), dim=-1)
                log_probs_wtr = F.log_softmax(logits_wtr[i, t_answer_token_indices, :].float(), dim=-1)

                # F.kl_div(input, target) computes KL(target || input).
                sample_kl = F.kl_div(
                    log_probs_wr, log_probs_wtr, reduction='none', log_target=True
                ).sum()
                kl_components_sum = sample_kl if kl_components_sum is None else kl_components_sum + sample_kl
                total_active_tokens_for_kl += len(s_answer_token_indices)

            if total_active_tokens_for_kl > 0:
                loss_kl = kl_components_sum / total_active_tokens_for_kl
            else:
                loss_kl = self._zero_loss_like(logits_wr)

        if loss_kl is None:
            # No usable pair in this batch (or no teacher): still run the student
            # so the returned zero loss can be back-propagated without error.
            outputs_for_return = model(
                input_ids=inputs['input_ids_without_rationales'],
                attention_mask=inputs['attention_mask_without_rationales']
            )
            loss_kl = self._zero_loss_like(outputs_for_return.logits)

        total_loss = self.kl_weight * loss_kl

        if return_outputs:
            return (total_loss, {"loss_kl": loss_kl, "outputs_model_to_train": outputs_for_return})
        return total_loss

def main():
    """Command-line entry point: load models and data, train, save artifacts.

    Steps: parse CLI arguments, load the tokenizer plus a trainable and a
    frozen copy of ``--model_name``, build the rationale dataset from
    ``--train_data_path``, train with ``KLDivergenceTrainer`` and write the
    model, tokenizer and run arguments to ``--output_dir``.
    """
    import logging

    logger = logging.getLogger(__name__)

    parser = argparse.ArgumentParser()
    parser.add_argument('--output_dir', type=str, required=True)
    parser.add_argument('--model_name', type=str, required=True)
    parser.add_argument('--train_data_path', type=str, required=True)
    parser.add_argument('--cache_dir', type=str, default=None)
    parser.add_argument('--acc_token', type=str, default=None)
    parser.add_argument('--max_length', type=int, default=512)
    parser.add_argument('--epochs', type=int, default=3)
    parser.add_argument('--batch_size', type=int, default=8)
    parser.add_argument('--grad_accum', type=int, default=1)
    parser.add_argument('--lr', type=float, default=2e-5)
    parser.add_argument('--kl_weight', type=float, default=1.0)
    parser.add_argument('--deepspeed_config', type=str, default=None)
    parser.add_argument("--local_rank", type=int, default=-1)
    parser.add_argument('--bf16', action='store_true', default=False)

    args = parser.parse_args()

    # Respect an allocator configuration the user already exported.
    os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

    if not args.acc_token and 'HF_TOKEN' not in os.environ:
        logger.warning(
            "No --acc_token given and HF_TOKEN is not set; "
            "loading gated or private models will fail."
        )

    tokenizer = AutoTokenizer.from_pretrained(
        args.model_name, token=args.acc_token, cache_dir=args.cache_dir, trust_remote_code=True
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    frozen_dtype = torch.bfloat16 if args.bf16 else torch.float16
    if args.bf16:
        train_dtype = torch.bfloat16
    elif args.deepspeed_config:
        # DeepSpeed manages its own half-precision weights.
        train_dtype = torch.float16
    else:
        # fp16 mixed precision needs float32 master weights: the AMP gradient
        # scaler refuses to unscale gradients of float16 parameters.
        train_dtype = torch.float32

    model_to_train_params = AutoModelForCausalLM.from_pretrained(
        args.model_name, torch_dtype=train_dtype,
        token=args.acc_token, cache_dir=args.cache_dir, low_cpu_mem_usage=True, trust_remote_code=True
    )

    if hasattr(model_to_train_params, 'gradient_checkpointing_enable'):
        model_to_train_params.gradient_checkpointing_enable()

    model_frozen_params = AutoModelForCausalLM.from_pretrained(
        args.model_name, torch_dtype=frozen_dtype,
        token=args.acc_token, cache_dir=args.cache_dir, low_cpu_mem_usage=True, trust_remote_code=True
    )
    for param in model_frozen_params.parameters():
        param.requires_grad = False
    model_frozen_params.eval()

    raw_train_data = load_json_data(args.train_data_path)
    train_dataset = KALEKLRationaleDataset(raw_train_data, tokenizer, args.max_length)

    world_size = int(os.getenv('WORLD_SIZE', 1))
    effective_batch_size = args.batch_size * args.grad_accum * world_size
    if len(train_dataset) == 0:
        num_update_steps_per_epoch = 0
    else:
        num_update_steps_per_epoch = max(1, len(train_dataset) // effective_batch_size)

    total_train_steps = args.epochs * num_update_steps_per_epoch
    # NOTE(review): the floor of 100 can exceed total_train_steps on small
    # datasets, in which case the learning rate never finishes warming up.
    warmup_steps = max(100, int(0.1 * total_train_steps)) if total_train_steps > 0 else 0
    logging_steps = max(1, num_update_steps_per_epoch // 10) if num_update_steps_per_epoch > 0 else 1

    training_args = TrainingArguments(
        output_dir=args.output_dir,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_accum,
        warmup_steps=warmup_steps,
        weight_decay=0.01,
        logging_dir=os.path.join(args.output_dir, 'logs'),
        logging_steps=logging_steps,
        save_total_limit=3,
        eval_strategy="no",
        save_strategy="epoch",
        fp16=not args.bf16,
        bf16=args.bf16,
        learning_rate=args.lr,
        max_grad_norm=1.0,
        save_safetensors=True,
        push_to_hub=False,
        remove_unused_columns=False,
        deepspeed=args.deepspeed_config,
        local_rank=args.local_rank if args.local_rank != -1 else int(os.getenv('LOCAL_RANK', '0')),
        # Length grouping needs an 'input_ids' entry per item to infer lengths;
        # this dataset uses suffixed keys and pads everything to max_length,
        # so grouping would fail at sampler creation and bring no benefit.
        group_by_length=False,
        report_to="tensorboard",
        ddp_find_unused_parameters=None
    )

    trainer = KLDivergenceTrainer(
        model=model_to_train_params,
        model_frozen_params=model_frozen_params,
        kl_weight=args.kl_weight,
        args=training_args,
        data_collator=kl_rationale_data_collator,
        train_dataset=train_dataset,
        tokenizer=tokenizer,
    )

    # The KV cache is of no use for full-sequence forward passes.
    for loaded_model in (model_to_train_params, model_frozen_params):
        if loaded_model.config.use_cache:
            loaded_model.config.use_cache = False

    hit_oom = False
    try:
        trainer.train()
    except torch.cuda.OutOfMemoryError:
        hit_oom = True
        logger.warning("CUDA out of memory during training; freeing memory and retrying once.")

    if hit_oom:
        # Retry outside the except block so the failed run's traceback (and
        # the tensors it references) can be released; collect Python garbage
        # first so the CUDA cache can actually hand the blocks back.
        gc.collect()
        torch.cuda.empty_cache()
        trainer.train()

    trainer.save_model()

    if trainer.is_world_process_zero():
        tokenizer.save_pretrained(training_args.output_dir)

        # Never persist the access token alongside the checkpoint.
        saved_args = argparse.Namespace(**vars(args))
        saved_args.acc_token = None
        torch.save(saved_args, os.path.join(training_args.output_dir, 'run_args.bin'))
        torch.save(training_args, os.path.join(training_args.output_dir, 'hf_training_args.bin'))

# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()