import torch
import torch.distributed as dist
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForCausalLM, get_linear_schedule_with_warmup
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
import json
import os
from tqdm import tqdm
from torch.nn import functional as F


def setup(rank, world_size, master_addr='localhost', master_port='12355'):
    """Initialise the NCCL process group for single-node training on this rank.

    Args:
        rank: Rank of this process; also used as its CUDA device index.
        world_size: Total number of processes in the group.
        master_addr: Rendezvous host written to ``MASTER_ADDR``.
        master_port: Rendezvous port written to ``MASTER_PORT``. Must be a
            non-empty, free port; env:// initialisation rejects an empty value.
    """
    os.environ['MASTER_ADDR'] = master_addr
    os.environ['MASTER_PORT'] = str(master_port)
    # Bind this process to its GPU before the NCCL communicator is created so
    # collectives run on the intended device.
    torch.cuda.set_device(rank)
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
    """Destroy the default process group if one is currently initialised.

    ``dist.destroy_process_group()`` raises when there is no default group, so
    the check makes this safe to call from error paths as well.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()

class CategoryDataset(Dataset):
    """Tokenised (instruction, safe response) examples for a single category.

    Reads ``<root_dir>/<category>.json``. Each element is indexed by an integer
    and must provide ``'goal'`` and ``'response'`` keys (presumably a JSON
    array of objects — confirm against the data files).
    """

    def __init__(self, root_dir, tokenizer, category, max_length=2048):
        """Load the category file.

        Args:
            root_dir: Directory containing ``<category>.json``.
            tokenizer: Callable tokenizer invoked with HF-style keyword
                arguments and ``return_tensors="pt"``.
            category: Base name of the JSON file to load.
            max_length: Length every example is padded/truncated to.
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        path = os.path.join(root_dir, f"{category}.json")
        # JSON text is UTF-8; don't depend on the platform's locale encoding.
        with open(path, encoding="utf-8") as f:
            self.data = json.load(f)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return padded ``input_ids``/``attention_mask`` of length ``max_length``
        plus the raw reference response string."""
        item = self.data[idx]

        prompt = f"Harmful Instruction: {item['goal']}\nSafe Response: {item['response']}"
        encoding = self.tokenizer(
            prompt,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )
        return {
            # Drop only the leading batch axis added by return_tensors="pt";
            # a bare squeeze() would also collapse the sequence axis when
            # max_length == 1.
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'reference_response': item['response'],
        }

def create_models(rank, model_path, torch_dtype=torch.float16):
    """Load a trainable policy model, a frozen reference copy, and the tokenizer.

    Args:
        rank: CUDA device index both models are moved to.
        model_path: Identifier or directory passed to ``from_pretrained``.
        torch_dtype: dtype both models are loaded in. Defaults to
            ``torch.float16`` (the previous hard-coded value).
            NOTE(review): optimising pure-fp16 weights with AdamW and no loss
            scaling is prone to underflow/NaN; ``torch.bfloat16`` or fp32
            master weights are usually safer — confirm for this setup.

    Returns:
        ``(policy_model, reference_model, tokenizer)``.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    if tokenizer.pad_token is None:
        # padding="max_length" needs a pad token; reuse EOS when none is defined.
        tokenizer.pad_token = tokenizer.eos_token

    def _load():
        # Both models come from the same checkpoint and live on the same device.
        return AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch_dtype,
        ).to(rank)

    policy_model = _load()
    reference_model = _load()
    # The reference is frozen: no gradients, and eval mode since only the
    # policy is ever switched to train().
    reference_model.requires_grad_(False)
    reference_model.eval()

    return policy_model, reference_model, tokenizer

def compute_reward(policy_logits, reference_logits, responses=None, beta=1.0, *,
                   generated_responses=None):
    """Score policy logits against frozen reference logits.

    reward = (mean(policy_logits) - mean(reference_logits)) / beta
             - KL(reference || policy)

    The KL term uses ``F.kl_div`` with ``reduction='batchmean'``: pointwise
    terms are summed over every element and divided by the size of the first
    dimension.

    Args:
        policy_logits: Logits from the trainable policy (last dim = vocab).
        reference_logits: Logits from the reference model, same shape.
        responses: Not used by the current objective; accepted for interface
            compatibility.
        beta: Divisor applied to the mean-logit difference.
        generated_responses: Keyword alias for ``responses`` (also unused).

    Returns:
        A scalar tensor.

    NOTE(review): softmax is shift-invariant, so the mean-logit term can grow
    without changing the predicted distribution, and the KL covers padded
    positions (no mask is applied). Confirm this objective is intended.
    """
    del responses, generated_responses  # unused by the current objective
    # log_softmax is numerically stable; softmax(...).log() yields -inf for
    # saturated logits, which turns the KL into NaN.
    policy_log_probs = F.log_softmax(policy_logits, dim=-1)
    reference_log_probs = F.log_softmax(reference_logits, dim=-1)
    kl_div = F.kl_div(
        policy_log_probs,
        reference_log_probs,
        reduction='batchmean',
        log_target=True,
    )

    reward = (policy_logits.mean() - reference_logits.mean()) / beta - kl_div
    return reward

def train_category(rank, world_size, config, category):
    """Train one shadow model for ``category`` on this process's GPU under DDP.

    Intended as the target of ``torch.multiprocessing.spawn`` with one process
    per GPU; ``rank`` is used both as the process rank and the CUDA device
    index. After training, rank 0 writes the policy weights and tokenizer to
    ``<config['save_root']>/shadow_model_<category>``.

    Args:
        rank: Process rank and CUDA device index.
        world_size: Total number of processes in the group.
        config: Settings dict. Required keys: 'model_path', 'data_root',
            'batch_size', 'lr', 'epochs', 'max_length', 'beta', 'save_root'.
            Optional: 'warmup_steps' (default 100).
        category: Base name of the ``<category>.json`` file under 'data_root'.
    """
    setup(rank, world_size)
    # Destroy the process group even if training raises, so NCCL resources are
    # released rather than leaked.
    try:
        policy_model, reference_model, tokenizer = create_models(rank, config['model_path'])
        policy_model = DDP(policy_model, device_ids=[rank])

        # Tokenise to the configured length rather than the dataset default.
        dataset = CategoryDataset(
            config['data_root'],
            tokenizer,
            category,
            max_length=config['max_length'],
        )
        sampler = DistributedSampler(
            dataset,
            num_replicas=world_size,
            rank=rank,
            shuffle=True
        )

        dataloader = DataLoader(
            dataset,
            batch_size=config['batch_size'],
            sampler=sampler,
            num_workers=4,
            pin_memory=True,
            persistent_workers=True
        )

        optimizer = torch.optim.AdamW(
            policy_model.parameters(),
            lr=config['lr'],
            weight_decay=0.1
        )
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=config.get('warmup_steps', 100),
            num_training_steps=len(dataloader) * config['epochs']
        )

        for epoch in range(config['epochs']):
            # Gives each epoch a different shuffle that is consistent across ranks.
            sampler.set_epoch(epoch)

            policy_model.train()
            progress_bar = tqdm(
                dataloader,
                desc=f"Training {category} | Epoch {epoch}",
                disable=rank != 0,
            )

            for batch in progress_bar:
                inputs = batch['input_ids'].to(rank, non_blocking=True)
                masks = batch['attention_mask'].to(rank, non_blocking=True)

                with torch.no_grad():
                    # NOTE(review): `inputs` are already padded to max_length, so
                    # generating up to max_length adds (almost) no tokens, and
                    # compute_reward appears not to use the result. Candidate for
                    # removal, or for max_new_tokens with left padding.
                    reference_outputs = reference_model.generate(
                        inputs,
                        attention_mask=masks,
                        max_length=config['max_length'],
                        do_sample=True,
                        top_p=0.9,
                        pad_token_id=tokenizer.pad_token_id,
                    )
                    # Same inputs and mask as the policy so the logits are comparable.
                    reference_logits = reference_model(inputs, attention_mask=masks).logits

                policy_outputs = policy_model(inputs, attention_mask=masks)

                # The generations go in compute_reward's third positional slot (`responses`).
                rewards = compute_reward(
                    policy_outputs.logits,
                    reference_logits,
                    reference_outputs,
                    beta=config['beta'],
                )

                # Gradient ascent on the reward.
                loss = -rewards.mean()

                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(policy_model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()

                progress_bar.set_postfix({'loss': loss.item()})

        if rank == 0:
            save_path = os.path.join(config['save_root'], f"shadow_model_{category}")
            # Save the wrapped model, not the DDP container.
            policy_model.module.save_pretrained(save_path)
            tokenizer.save_pretrained(save_path)
    finally:
        cleanup()

def main():
    """Train one shadow model per configured category, one DDP process per GPU."""
    config = {
        'model_path': "Llama-3.1-8B-Instruct",
        'data_root': "/path/to/category_data",
        'save_root': "./shadow_models",
        # TODO: list category names; each needs <data_root>/<name>.json.
        'categories': [],
        'batch_size': 8,
        'epochs': 5,
        'lr': 3e-5,
        'beta': 1.0,
        'max_length': 2048,
        'world_size': torch.cuda.device_count()
    }

    # spawn(nprocs=0) would start nothing; fail loudly instead of silently
    # "training" with no GPUs (the NCCL backend requires CUDA devices).
    if config['categories'] and config['world_size'] < 1:
        raise RuntimeError("No CUDA devices found; at least one GPU is required.")

    for category in config['categories']:
        torch.multiprocessing.spawn(
            train_category,
            args=(config['world_size'], config, category),
            nprocs=config['world_size'],
            join=True
        )


if __name__ == "__main__":
    main()