import json
import os

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM


def setup(rank, world_size):
    """Join the NCCL process group for this worker and bind it to its GPU.

    Args:
        rank: Index of this process; also used as the CUDA device index.
        world_size: Total number of processes taking part in the group.

    The rendezvous address is always ``localhost``. A non-empty
    ``MASTER_PORT`` already present in the environment is kept; otherwise
    port 29500 is used.
    """
    os.environ['MASTER_ADDR'] = 'localhost'
    # An empty port string cannot be used for the rendezvous, so only keep a
    # value from the environment when it is actually set to something.
    if not os.environ.get('MASTER_PORT'):
        os.environ['MASTER_PORT'] = '29500'
    # Select this rank's GPU before the NCCL group is created so that the
    # group's CUDA work happens on the intended device.
    torch.cuda.set_device(rank)
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
    """Tear down the default process group if one is currently initialised.

    Safe to call when no group exists (for example after a failed setup), so
    it can be used from a ``finally`` block without masking the real error.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()

class CategoryDataset(Dataset):
    """Dataset over the records stored in ``<root_dir>/<category>.json``.

    Each record is read through its ``'goal'`` and ``'response'`` keys and
    rendered into a single prompt string that is tokenized to a fixed length.
    """

    def __init__(self, root_dir, tokenizer, category, max_length=2048):
        """Load the category file into memory.

        Args:
            root_dir: Directory containing the per-category JSON files.
            tokenizer: Callable used to encode the prompt; it is invoked with
                ``max_length``, ``padding``, ``truncation`` and
                ``return_tensors`` keyword arguments.
            category: File stem of the JSON file to load.
            max_length: Length every encoded prompt is padded/truncated to.
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        # Explicit encoding: JSON text should not depend on the platform's
        # default locale encoding.
        with open(os.path.join(root_dir, f"{category}.json"), encoding="utf-8") as f:
            self.data = json.load(f)

    def __len__(self):
        """Return the number of loaded records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the encoded prompt and the reference response for ``idx``.

        Returns:
            A dict with ``'input_ids'`` and ``'attention_mask'`` (the
            tokenizer outputs with their leading axis removed) and
            ``'reference_response'`` (the raw ``'response'`` value).
        """
        item = self.data[idx]

        prompt = f"Harmful Instruction: {item['goal']}\nSafe Response: {item['response']}"
        encoding = self.tokenizer(
            prompt,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )
        # Drop only the leading axis added by return_tensors="pt"; a bare
        # squeeze() would also collapse the sequence axis when it has length 1.
        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'reference_response': item['response']
        }

def create_models(rank, model_path):
    """Load the tokenizer plus a policy and a frozen reference model.

    Both models are loaded from the same ``model_path`` in float16 and moved
    to device ``rank``. Gradients are disabled on the reference copy only.

    Args:
        rank: Device the models are moved to.
        model_path: Identifier or path handed to ``from_pretrained``.

    Returns:
        Tuple ``(policy_model, reference_model, tokenizer)``.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    # Padding needs a pad token; reuse EOS when the tokenizer defines none.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    def load_half_precision():
        # One independent float16 copy of the checkpoint on this rank's device.
        model = AutoModelForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.float16
        )
        return model.to(rank)

    policy_model = load_half_precision()
    reference_model = load_half_precision()
    reference_model.requires_grad_(False)

    return policy_model, reference_model, tokenizer

def compute_reward(policy_logits, reference_logits, responses=None, beta=1.0,
                   generated_responses=None):
    """Return a scalar reward comparing policy logits with reference logits.

    The reward is the difference of the mean logits scaled by ``1 / beta``
    minus a KL term between the two softmax distributions over the last axis.

    Args:
        policy_logits: Logits produced by the policy model.
        reference_logits: Logits produced by the reference model; must be
            broadcast-compatible with ``policy_logits``.
        responses: Accepted for interface compatibility; not used in the
            computation.
        beta: Divisor applied to the mean-logit difference.
        generated_responses: Keyword alias for ``responses``; likewise unused.

    Returns:
        A 0-dimensional tensor holding the reward for the whole batch.
    """
    # log_softmax avoids log(0) = -inf when a policy probability underflows,
    # which would otherwise turn the KL term (and the reward) into inf/NaN.
    policy_log_probs = F.log_softmax(policy_logits, dim=-1)
    reference_probs = F.softmax(reference_logits, dim=-1)
    # F.kl_div takes log-probabilities as input and probabilities as target,
    # i.e. this is target * (log(target) - input) summed and divided by the
    # size of the first axis ('batchmean').
    kl_div = F.kl_div(policy_log_probs, reference_probs, reduction='batchmean')

    reward = (policy_logits.mean() - reference_logits.mean()) / beta - kl_div
    return reward

def test_category(rank, world_size, config, category):
    """Evaluate one category on this rank and report the global mean reward.

    Intended as the per-process entry point for ``torch.multiprocessing.spawn``:
    joins the process group, loads the models and this rank's shard of the
    category dataset, accumulates the batch rewards, sums the totals across
    ranks and prints the average on rank 0. The process group is torn down
    even when evaluation fails.

    Args:
        rank: Index of this process; also its CUDA device index.
        world_size: Total number of processes.
        config: Dict read through the keys ``'model_path'``, ``'data_root'``,
            ``'batch_size'``, ``'max_length'`` and ``'beta'``.
        category: Name of the category file (without ``.json``) to evaluate.
    """
    setup(rank, world_size)
    try:
        policy_model, reference_model, tokenizer = create_models(rank, config['model_path'])
        policy_model = DDP(policy_model, device_ids=[rank])

        dataset = CategoryDataset(config['data_root'], tokenizer, category)
        # NOTE(review): per the DistributedSampler docs, indices are padded so
        # every rank gets the same count; a few samples may be evaluated twice.
        sampler = DistributedSampler(
            dataset,
            num_replicas=world_size,
            rank=rank,
            shuffle=False
        )

        dataloader = DataLoader(
            dataset,
            batch_size=config['batch_size'],
            sampler=sampler,
            num_workers=4,
            pin_memory=True,
            persistent_workers=True
        )

        policy_model.eval()
        reference_model.eval()

        total_reward = 0.0
        total_samples = 0

        with torch.no_grad():
            progress_bar = tqdm(dataloader, desc=f"Testing {category}", disable=rank!=0)
            for batch in progress_bar:
                inputs = batch['input_ids'].to(rank, non_blocking=True)
                masks = batch['attention_mask'].to(rank, non_blocking=True)

                # NOTE(review): the dataset pads every prompt to its own
                # max_length (default 2048), so with config['max_length'] at
                # the same value there is no room left to generate, and the
                # result is not used by compute_reward. Consider dropping
                # this call or switching to max_new_tokens - confirm intent.
                reference_outputs = reference_model.generate(
                    inputs,
                    attention_mask=masks,
                    max_length=config['max_length'],
                    do_sample=True,
                    top_p=0.9,
                    pad_token_id=tokenizer.pad_token_id
                )

                # Both models see the same mask so padding positions are
                # treated identically when their logits are compared.
                policy_outputs = policy_model(inputs, attention_mask=masks)
                reference_logits = reference_model(inputs, attention_mask=masks).logits

                rewards = compute_reward(
                    policy_outputs.logits,
                    reference_logits,
                    reference_outputs,
                    beta=config['beta']
                )

                # compute_reward yields one value for the whole batch; weight
                # it by the batch size so the running mean is per sample.
                batch_size = inputs.size(0)
                batch_reward = torch.as_tensor(rewards).float().mean().item()
                total_reward += batch_reward * batch_size
                total_samples += batch_size

                progress_bar.set_postfix({'reward': total_reward / total_samples})

        # Sum the per-rank totals so rank 0 reports the average over all
        # shards rather than only its own.
        totals = torch.tensor(
            [total_reward, float(total_samples)],
            dtype=torch.float64,
            device=rank
        )
        dist.all_reduce(totals, op=dist.ReduceOp.SUM)
        global_reward, global_samples = totals.tolist()

        if rank == 0:
            if global_samples > 0:
                avg_reward = global_reward / global_samples
                print(f"Average reward for {category}: {avg_reward}")
            else:
                # Empty category file: nothing to average.
                print(f"No samples found for {category}")
    finally:
        cleanup()

def main():
    """Build the evaluation config and run one spawned job per category.

    The number of worker processes equals the number of visible CUDA devices.
    Categories are processed one after another; each spawn call blocks until
    its workers have finished.
    """
    gpu_count = torch.cuda.device_count()
    config = dict(
        model_path="Llama-3.1-8B-Instruct",
        data_root="/path/to/category_data",
        categories=[],  # Add your categories here
        batch_size=8,
        beta=1.0,
        max_length=2048,
        world_size=gpu_count,
    )

    for category in config['categories']:
        # Each worker is invoked as test_category(rank, *worker_args).
        worker_args = (config['world_size'], config, category)
        torch.multiprocessing.spawn(
            test_category,
            args=worker_args,
            nprocs=config['world_size'],
            join=True,
        )

# Script entry point: run the evaluation only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
