import os
from pathlib import Path
import numpy as np
import argparse
import time
import pickle as pkl
import torch
import torch.distributed as dist
from torch.nn import CrossEntropyLoss
import torch.nn.functional as F
import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
from SFT.anollm import AnoLLM
from SFT.train_anollm import get_run_name
from evaluate.data_utils import load_data, DATA_MAP, get_text_columns, get_max_length_dict


def get_spin_run_name(args):
    """Compose the run name that identifies one SPIN training configuration.

    The name is assembled from the learning rate, binning scheme, a short
    model tag, optional ``no_random_permutation`` / ``lora`` markers and the
    SPIN hyper-parameters (iteration, max steps, epsilon, beta, f-divergence).

    Args:
        args: Object exposing ``lr``, ``binning``, ``model``,
            ``no_random_permutation``, ``lora``, ``iteration``, ``max_steps``,
            ``epsilon``, ``beta`` and ``f_divergence_type`` attributes.

    Returns:
        The run name as a string, always ending in ``_spin``.
    """
    # Short tags for the known SmolLM checkpoints; any other model name is
    # appended verbatim.
    model_tags = {
        'HuggingFaceTB/SmolLM-135M': '_smolLM',
        'HuggingFaceTB/SmolLM-360M': '_smolLM360',
        'HuggingFaceTB/SmolLM-1.7B': '_smolLM1.7B',
    }

    pieces = [
        'anollm',
        '_lr{:.0e}'.format(args.lr),
        '_{}'.format(args.binning),
    ]

    if args.model in model_tags:
        pieces.append(model_tags[args.model])
    else:
        pieces.append('_' + args.model)

    if args.no_random_permutation:
        pieces.append('_no_random_permutation')
    if args.lora:
        pieces.append('_lora')

    pieces.append(f'_iter{args.iteration}')
    pieces.append(f'_steps{args.max_steps}')
    pieces.append(f'_eps{args.epsilon}')
    pieces.append(f'_beta{args.beta}')
    pieces.append(f'_f{args.f_divergence_type}')
    pieces.append('_spin')

    return ''.join(pieces)


def get_args():
    """Parse the command line for one SPIN training iteration.

    After parsing, three derived values are filled in:

    * ``args.exp_dir`` becomes a ``Path`` (a default under ``exp_dispat/`` is
      built from dataset, setting and split when the flag is omitted);
    * the short model aliases ``smol`` / ``smol-360`` / ``smol-1.7b`` are
      replaced by their ``HuggingFaceTB/...`` identifiers;
    * ``args.save_dir`` is set to ``exp_dir / 'models'`` and created on disk.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument

    # Experiment and dataset selection.
    add("--dataset", type=str, default='vifd',
        choices=[key.lower() for key in DATA_MAP.keys()],
        help="Name of datasets")
    add("--exp_dir", type=str, default=None)
    add("--setting", type=str, default='semi_supervised',
        choices=['semi_supervised', 'unsupervised'])

    # Experiment tracking.
    add("--wandb", action='store_true')
    add("--entity", type=str, default=None)
    add("--project", type=str, default='DiSPaT')

    # Data location and splitting.
    add("--data_dir", type=str, default='data')
    add("--n_splits", type=int, default=5)
    add("--split_idx", type=int, default=0)
    add("--train_ratio", type=float, default=0.5)
    add("--seed", type=int, default=42)

    # Numerical feature preprocessing.
    add("--binning", type=str,
        choices=['quantile', 'equal_width', 'language', 'none', 'standard'],
        default='standard')
    add("--n_buckets", type=int, default=10)

    # Model and optimisation.
    add("--model", type=str,
        choices=['gpt2', 'distilgpt2', 'smol', 'smol-360', 'smol-1.7b'],
        default='smol-360')
    add("--batch_size", type=int, default=16)
    add("--lr", type=float, default=5e-5)
    add("--lora", action='store_true', default=False)
    add("--max_steps", type=int, default=1000)
    add("--eval_n_permutations", type=int, default=21,
        help="Number of permutations for evaluation (after each iteration)")
    add("--eval_batch_size", type=int, default=8,
        help="Batch size for evaluation")
    add("--no_random_permutation", action='store_true', default=False)

    # SPIN-specific inputs and hyper-parameters.
    add("--reference_model_dir", type=str, required=True,
        help="Directory containing the reference model (p_θt). For iter0, use finetuned model. For iter>0, use SPIN model from previous iteration.")
    add("--spin_data_dir", type=str, required=True,
        help="Base directory containing generated SPIN datasets (should contain iter0/, iter1/, ... folders)")
    add("--iteration", type=int, default=0,
        help="SPIN iteration number (0, 1, 2, ...). Iteration 0 uses finetuned model, iteration >0 uses SPIN model from previous iteration.")
    add("--beta", type=float, default=0.1,
        help="Temperature parameter for SPIN loss")
    add("--epsilon", type=float, default=0.02,
        help="Epsilon for clamping ratio in rejected samples")
    add("--f_divergence_type", type=str, default='identity',
        choices=['identity', 'kl', 'reverse_kl', 'squared_hellinger'],
        help="Type of f-divergence to use. Choices: identity (f*(v)=v), kl (f*(v)=exp(v)-1), reverse_kl (f*(v)=-log(1-v)), squared_hellinger (f*(v)=v/(1-v))")

    add("--remove_metadata_20news", action='store_true', default=False,
        help="Remove headers/footers/quotes for 20news dataset. Default: False.")

    args = parser.parse_args()

    # Resolve the experiment directory, deriving a default when not given.
    if args.exp_dir is not None:
        args.exp_dir = Path(args.exp_dir)
    else:
        args.exp_dir = (
            Path('exp_dispat')
            / args.dataset
            / args.setting
            / f"split{args.n_splits}"
            / f"split{args.split_idx}"
        )

    # Expand short aliases; names without an alias are left untouched.
    alias_to_checkpoint = {
        'smol': 'HuggingFaceTB/SmolLM-135M',
        'smol-360': 'HuggingFaceTB/SmolLM-360M',
        'smol-1.7b': 'HuggingFaceTB/SmolLM-1.7B',
    }
    args.model = alias_to_checkpoint.get(args.model, args.model)

    args.save_dir = args.exp_dir / 'models'
    os.makedirs(args.save_dir, exist_ok=True)

    return args


def f_star_conjugate(v, f_type='identity', epsilon=0.02):
    """Apply the conjugate function f*(v) of the selected f-divergence.

    Supported ``f_type`` values and the expression evaluated:

    * ``'identity'``: ``v`` (returned unchanged)
    * ``'kl'``: ``exp(v) - 1``
    * ``'reverse_kl'``: ``-log(1 - v)``
    * ``'squared_hellinger'``: ``v / (1 - v)``

    For the last two, ``v`` is first clamped from above at
    ``1 - epsilon - 1e-8`` and ``1e-8`` is added to ``1 - v`` so the
    logarithm / division never sees a non-positive argument.

    Args:
        v: Input tensor.
        f_type: Name of the divergence (see list above).
        epsilon: Margin kept below 1 when clamping ``v``.

    Returns:
        Tensor with f*(v) applied element-wise.

    Raises:
        ValueError: If ``f_type`` is not one of the supported names.
    """
    if f_type == 'identity':
        return v

    if f_type == 'kl':
        return torch.exp(v) - 1.0

    if f_type in ('reverse_kl', 'squared_hellinger'):
        # Both conjugates are singular at v = 1, so bound v away from it.
        bounded = torch.clamp(v, max=1.0 - epsilon - 1e-8)
        gap = 1.0 - bounded + 1e-8
        if f_type == 'reverse_kl':
            return -torch.log(gap)
        return bounded / gap

    raise ValueError(f"Unknown f_divergence_type: {f_type}")


def get_batch_logps(model, input_ids, attention_mask, requires_grad=False):
    """Return the mean next-token log-probability of each sequence in a batch.

    The model is run on ``input_ids``; position ``t`` of the logits is scored
    against token ``t + 1``. Per-token log-probabilities are weighted by the
    attention mask (shifted by one position, like the labels) and averaged
    over the positions the mask keeps.

    Args:
        model: Callable taking ``(input_ids, attention_mask=...)`` and
            returning an object with a ``.logits`` attribute.
        input_ids: Token ids, indexed as ``[batch, position]``.
        attention_mask: Mask with the same layout as ``input_ids``; non-zero
            entries mark positions that count towards the average.
        requires_grad: When False the forward pass runs under
            ``torch.no_grad()``; when True the ambient autograd mode is used.

    Returns:
        One value per sequence (first dimension of ``input_ids``). A sequence
        whose shifted mask is all zeros yields 0 rather than NaN.
    """
    if requires_grad:
        logits = model(input_ids, attention_mask=attention_mask).logits
    else:
        with torch.no_grad():
            logits = model(input_ids, attention_mask=attention_mask).logits

    # Align predictions with the tokens they predict: drop the last logit and
    # the first label/mask entry.
    shift_logits = logits[:, :-1, :].contiguous()
    shift_labels = input_ids[:, 1:].contiguous()
    shift_attention_mask = attention_mask[:, 1:].contiguous()

    # CrossEntropyLoss wants the class dimension second, hence the transpose;
    # its negation is the log-probability of each target token.
    loss_fct = CrossEntropyLoss(reduction='none')
    token_log_probs = -loss_fct(shift_logits.transpose(1, 2), shift_labels)

    # A row with no valid target token (e.g. an empty or single-token text)
    # would otherwise divide 0 by 0 and inject NaN into the batch loss.
    token_counts = shift_attention_mask.sum(dim=1).clamp(min=1)
    batch_log_probs = (token_log_probs * shift_attention_mask).sum(dim=1) / token_counts

    return batch_log_probs


def spin_loss(policy_model, reference_model, batch_chosen, batch_rejected, beta, epsilon, device, f_divergence_type='identity'):
    """Compute the SPIN loss for a batch of (chosen, rejected) pairs.

    For both halves of the pair the per-sequence log-probability under the
    policy and under the reference model is obtained with
    ``get_batch_logps``; only the policy calls track gradients. The loss is

        -logsigmoid(beta * chosen_log_ratio - f*(beta * rejected_log_ratio_clamped))

    averaged over the batch, where the rejected ratio is capped at
    ``1 - epsilon`` before taking its logarithm.

    Args:
        policy_model: Model being trained.
        reference_model: Model providing the reference log-probabilities.
        batch_chosen: Dict with ``'input_ids'`` and ``'attention_mask'``
            tensors for the chosen samples.
        batch_rejected: Same structure for the rejected samples.
        beta: Scale applied to both log-ratios.
        epsilon: Cap margin for the rejected ratio; also forwarded to
            ``f_star_conjugate``.
        device: Device the batch tensors are moved to.
        f_divergence_type: Conjugate selector passed to ``f_star_conjugate``.

    Returns:
        Tuple ``(loss, mean_chosen_reward, mean_rejected_reward)``; the two
        rewards are detached from the autograd graph.
    """
    # Move each tensor to the device once; the policy and the reference model
    # score exactly the same inputs.
    chosen_ids = batch_chosen['input_ids'].to(device)
    chosen_mask = batch_chosen['attention_mask'].to(device)
    rejected_ids = batch_rejected['input_ids'].to(device)
    rejected_mask = batch_rejected['attention_mask'].to(device)

    policy_chosen_logps = get_batch_logps(
        policy_model, chosen_ids, chosen_mask, requires_grad=True
    )
    reference_chosen_logps = get_batch_logps(
        reference_model, chosen_ids, chosen_mask, requires_grad=False
    )
    policy_rejected_logps = get_batch_logps(
        policy_model, rejected_ids, rejected_mask, requires_grad=True
    )
    reference_rejected_logps = get_batch_logps(
        reference_model, rejected_ids, rejected_mask, requires_grad=False
    )

    chosen_log_ratio = policy_chosen_logps - reference_chosen_logps
    rejected_log_ratio = policy_rejected_logps - reference_rejected_logps

    # Cap the policy/reference ratio of rejected samples at (1 - epsilon)
    # and return to log space; the 1e-8 guards log(0) when the ratio
    # underflows.
    rejected_ratio = torch.exp(rejected_log_ratio)
    rejected_ratio_clamped = torch.min(
        rejected_ratio,
        torch.ones_like(rejected_ratio) * (1 - epsilon)
    )
    rejected_log_ratio_clamped = torch.log(rejected_ratio_clamped + 1e-8)

    chosen_rho = beta * chosen_log_ratio
    rejected_rho = beta * rejected_log_ratio_clamped

    f_star_rejected = f_star_conjugate(rejected_rho, f_type=f_divergence_type, epsilon=epsilon)

    logits = chosen_rho - f_star_rejected
    loss = (-F.logsigmoid(logits)).mean()

    # Rewards are reported for logging only, so cut them from the graph.
    chosen_rewards = chosen_rho.detach()
    rejected_rewards = rejected_rho.detach()

    return loss, chosen_rewards.mean(), rejected_rewards.mean()


def evaluate_model_metrics(
    model,
    X_test,
    y_test,
    args,
    run_name,
    exp_dir,
    device,
    n_permutations=21,
    batch_size=8,
    distributed=False,
    rank=0,
    world_size=1
):
    """Score the test set with ``model`` and, on rank 0, persist scores and metrics.

    ``n_permutations`` is split across ``world_size`` ranks (the first
    ``n_permutations % world_size`` ranks take one extra). When
    ``distributed`` is true and ``world_size > 1`` the per-rank score arrays
    are exchanged with ``dist.all_gather_object`` and concatenated along
    axis 1, so **every rank must call this function**.

    Rank 0 then writes ``scores/<run_name>.npy`` (per-sample mean) and
    ``scores/raw_<run_name>.npy`` (all columns) under ``exp_dir`` and tries
    to compute detection metrics, appending a line to
    ``exp_dir / 'evaluation_dispat.txt'``. A failure in that metric step is
    reported as a warning and leaves the metrics as ``None``.

    Args:
        model: Object exposing ``.model`` (with ``eval()`` / ``train()``) and
            ``decision_function(X, n_permutations=, batch_size=, device=)``.
        X_test: Test inputs handed to ``decision_function``.
        y_test: Array of labels; entries equal to 0 / 1 are reported as
            normal / anomaly.
        args: Unused; kept for interface compatibility.
        run_name: Stem used for the score files and the results line.
        exp_dir: ``Path`` of the experiment directory.
        device: Device (or device string); anything whose string form starts
            with ``"cuda"`` selects CUDA, otherwise CPU.
        n_permutations: Total number of permutations over all ranks.
        batch_size: Batch size forwarded to ``decision_function``.
        distributed: Whether to gather scores across ranks.
        rank: Rank of the calling process.
        world_size: Number of participating processes.

    Returns:
        ``(auc_roc, auc_pr, f1, precision, recall, mean_scores)``. The five
        metrics are ``None`` on non-zero ranks or if they could not be
        computed.
    """
    score_dir = exp_dir / 'scores'
    score_path = score_dir / f"{run_name}.npy"
    raw_score_path = score_dir / f"raw_{run_name}.npy"
    is_main = rank == 0

    if is_main:
        os.makedirs(score_dir, exist_ok=True)
        print(f"[Evaluation] Score will be saved to: {score_path}")

    device_str = "cuda" if str(device).lower().startswith("cuda") else "cpu"

    # Distribute the permutations as evenly as possible over the ranks.
    base, remainder = divmod(n_permutations, world_size)
    n_perm = base + 1 if rank < remainder else base

    model.model.eval()
    try:
        with torch.no_grad():
            scores = model.decision_function(
                X_test,
                n_permutations=n_perm,
                batch_size=batch_size,
                device=device_str
            )
    finally:
        # Put the model back into training mode even if scoring fails, so a
        # caller that catches the error does not continue in eval mode.
        model.model.train()

    if distributed and world_size > 1:
        gathered = [None] * world_size
        dist.all_gather_object(gathered, scores)
        combined_scores = np.concatenate(gathered, axis=1)
    else:
        combined_scores = scores

    if len(combined_scores.shape) > 1:
        mean_scores = np.mean(combined_scores, axis=1)
    else:
        mean_scores = combined_scores

    auc_roc, auc_pr, f1, precision, recall = None, None, None, None, None
    if is_main:
        np.save(score_path, mean_scores)
        np.save(raw_score_path, combined_scores)
        print(f"[Evaluation] Scores saved to {score_path}")

        y_test_flat = y_test.flatten() if len(y_test.shape) > 1 else y_test
        print(f"[Evaluation] Mean score (normal): {np.mean(mean_scores[y_test_flat == 0]):.4f}")
        print(f"[Evaluation] Mean score (anomaly): {np.mean(mean_scores[y_test_flat == 1]):.4f}")

        # Metrics are best-effort: scores are already on disk, so a failure
        # here is reported but must not abort the caller.
        try:
            evaluate_dir = str(Path(__file__).parent.parent / 'evaluate')
            if evaluate_dir not in sys.path:
                # Avoid growing sys.path on every call.
                sys.path.insert(0, evaluate_dir)
            from evaluate.compute_metrics import compute_detection_metrics as tabular_metrics
            auc_roc, auc_pr, f1, precision, recall = tabular_metrics(y_test_flat, mean_scores)

            evaluate_file = exp_dir / 'evaluation_dispat.txt'
            result_line = f"{run_name:30s}: AUC-ROC: {auc_roc:.4f} ( 1), AUC-PR: {auc_pr:.4f} ( 1), F1: {f1:.4f} ( 1), P: {precision:.4f} ( 1), R: {recall:.4f} ( 1)"
            with open(evaluate_file, 'a') as f:
                f.write(result_line + '\n')
            print(f"[Evaluation] Metrics saved to {evaluate_file}")
        except Exception as e:
            print(f"[WARN] Could not compute metrics: {e}")

    return auc_roc, auc_pr, f1, precision, recall, mean_scores


class SPINDatasetForAnomalyDetection(torch.utils.data.Dataset):
    """Dataset of SPIN pairs that tokenizes both texts of a pair on access.

    Each element of ``spin_pairs`` must provide ``'chosen_text'`` and
    ``'rejected_text'``. Both are tokenized with truncation and padding to
    ``max_length``, so every returned tensor has the same length.
    """

    def __init__(self, spin_pairs, tokenizer, max_length=512):
        """Store the pairs and tokenizer and print a short summary.

        Args:
            spin_pairs: Indexable collection of pair dicts.
            tokenizer: Callable invoked as ``tokenizer(text, max_length=...,
                padding='max_length', truncation=True, return_tensors='pt')``.
            max_length: Length every sequence is padded / truncated to.
        """
        self.spin_pairs = spin_pairs
        self.tokenizer = tokenizer
        self.max_length = max_length

        print("SPINDatasetForAnomalyDetection initialized:")
        print(f"  - {len(spin_pairs)} pairs")
        print(f"  - Max length: {max_length}")

    def __len__(self):
        """Return the number of pairs."""
        return len(self.spin_pairs)

    def _encode(self, text):
        """Tokenize one text to fixed-length tensors with a leading batch axis of 1."""
        return self.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

    def __getitem__(self, idx):
        """Return ids and attention masks for the chosen and rejected text of pair ``idx``."""
        record = self.spin_pairs[idx]
        texts = {
            'chosen': record['chosen_text'],
            'rejected': record['rejected_text'],
        }

        item = {}
        for role in ('chosen', 'rejected'):
            encoded = self._encode(texts[role])
            # Drop the batch axis added by return_tensors='pt'.
            item[f'{role}_input_ids'] = encoded['input_ids'].squeeze(0)
            item[f'{role}_attention_mask'] = encoded['attention_mask'].squeeze(0)
        return item


def _load_spin_pairs(args):
    """Load and concatenate the pickled SPIN pair lists used by ``args.iteration``.

    Iteration 0 reads only ``iter0``; any later iteration reads the previous
    and the current iteration's dataset.
    """
    if args.iteration == 0:
        iter_indices = [0]
    else:
        iter_indices = [args.iteration - 1, args.iteration]

    spin_pairs = []
    for iter_idx in iter_indices:
        spin_data_path = Path(args.spin_data_dir) / f"iter{iter_idx}" / f"spin_dataset_iter{iter_idx}.pkl"
        if not os.path.exists(spin_data_path):
            raise ValueError(f"SPIN dataset for iteration {iter_idx} not found at {spin_data_path}. "
                           f"Please run generate_anomaly_samples_vllm.py with --iteration {iter_idx} first!")

        print(f"Loading SPIN dataset from iteration {iter_idx}: {spin_data_path}")
        # SECURITY NOTE: unpickling executes arbitrary code embedded in the
        # file; --spin_data_dir must only point at trusted, locally generated data.
        with open(spin_data_path, 'rb') as f:
            spin_pairs_iter = pkl.load(f)
        spin_pairs.extend(spin_pairs_iter)
        print(f"  Loaded {len(spin_pairs_iter)} pairs from iteration {iter_idx}")

    if args.iteration == 0:
        print(f"\nTotal loaded: {len(spin_pairs)} SPIN pairs from iteration 0")
    else:
        print(f"\nTotal loaded: {len(spin_pairs)} SPIN pairs from iterations {args.iteration - 1} and {args.iteration}")
    return spin_pairs


def _resolve_reference_model_dir(args):
    """Return the directory of the reference model, raising ValueError if it lacks ``config.json``."""
    ref_path = Path(args.reference_model_dir)

    if args.iteration == 0:
        if (ref_path / 'config.json').exists():
            ref_model_dir = ref_path
        else:
            # The path is taken to be an experiment directory; rebuild the
            # run name of the fine-tuned model stored under its models/ folder.
            base_args = argparse.Namespace(
                model=args.model,
                lr=args.lr,
                binning=args.binning,
                no_random_permutation=args.no_random_permutation,
                lora=args.lora,
            )
            ref_model_dir = ref_path / 'models' / get_run_name(base_args)
        if not (ref_model_dir / 'config.json').exists():
            raise ValueError(f"Reference model (finetuned) not found at {ref_model_dir}. "
                           f"For iteration 0, please provide path to finetuned model from train_anollm.py")
        print(f"[Iteration {args.iteration}] Loading reference model (finetuned) from: {ref_model_dir}")
        return ref_model_dir

    prev_iter = args.iteration - 1
    prev_run_name = get_spin_run_name(argparse.Namespace(
        model=args.model,
        lr=args.lr,
        binning=args.binning,
        no_random_permutation=args.no_random_permutation,
        lora=args.lora,
        iteration=prev_iter,
        max_steps=args.max_steps,
        epsilon=args.epsilon,
        beta=args.beta,
        f_divergence_type=args.f_divergence_type,
    ))

    # An explicit model directory is only honoured when its path mentions
    # 'iter'; otherwise fall back to the previous SPIN run of this experiment.
    if (ref_path / 'config.json').exists() and 'iter' in str(ref_path):
        ref_model_dir = ref_path
    else:
        ref_model_dir = args.exp_dir / 'models' / prev_run_name

    if not (ref_model_dir / 'config.json').exists():
        raise ValueError(f"SPIN model from iteration {prev_iter} not found at {ref_model_dir}. "
                       f"Please train iteration {prev_iter} first!")
    print(f"[Iteration {args.iteration}] Loading reference model (SPIN iter{prev_iter}) from: {ref_model_dir}")
    return ref_model_dir


def _collate_spin_batch(batch):
    """Stack the per-sample tensors of a list of dataset items into batch tensors."""
    keys = (
        'chosen_input_ids',
        'chosen_attention_mask',
        'rejected_input_ids',
        'rejected_attention_mask',
    )
    return {key: torch.stack([item[key] for item in batch]) for key in keys}


def main():
    """Run one SPIN training iteration: load data, train, evaluate, save.

    Steps:
      1. Parse arguments and load the test split (rank 0 first, then the
         other ranks).
      2. Load the SPIN pairs; exit early if the target model already exists.
      3. Load the reference model and a policy model initialised from it.
      4. Train the policy with ``spin_loss`` for ``args.max_steps`` steps.
      5. Evaluate on the test split (all ranks participate), then on rank 0
         write logs, timing and the trained model.
    """
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        torch.cuda.set_device(local_rank)

    args = get_args()
    distributed = dist.is_available() and dist.is_initialized()
    rank = dist.get_rank() if distributed else 0
    world_size = dist.get_world_size() if distributed else 1

    # Rank 0 loads first and the others follow after a barrier (presumably so
    # any on-disk preparation done by load_data happens only once). The test
    # split loaded here is reused for the final evaluation.
    if rank == 0:
        _, X_test, _, y_test = load_data(args)
    if distributed:
        dist.barrier()
    if rank != 0:
        _, X_test, _, y_test = load_data(args)
    if distributed:
        dist.barrier()

    spin_pairs = _load_spin_pairs(args)

    run_name = get_spin_run_name(args)
    model_save_path = args.save_dir / run_name
    log_file = args.exp_dir / 'train_log.txt'

    if os.path.exists(model_save_path):
        print(f"Model exists at {model_save_path}, skip training")
        if distributed:
            dist.destroy_process_group()
        return

    if not spin_pairs:
        # An empty dataset would otherwise surface later as a
        # ZeroDivisionError when the number of epochs is computed.
        raise ValueError(f"No SPIN pairs were loaded from {args.spin_data_dir}; cannot train.")

    if rank == 0:
        log_f = open(log_file, 'w', buffering=1)
        log_f.write("Training command arguments:\n")
        for arg, value in vars(args).items():
            log_f.write(f"  {arg}: {value}\n")
        log_f.write("\n" + "="*80 + "\n\n")
        log_f.flush()

    efficient_finetuning = 'lora' if args.lora else ''
    max_length_dict = get_max_length_dict(args.dataset)
    text_columns = get_text_columns(args.dataset)

    ref_model_dir = _resolve_reference_model_dir(args)

    reference_model = AnoLLM(
        str(ref_model_dir),
        efficient_finetuning=efficient_finetuning,
        max_length_dict=max_length_dict,
        textual_columns=text_columns,
        no_random_permutation=args.no_random_permutation
    )

    print(f"[Iteration {args.iteration}] Initializing policy model from reference model")
    policy_model = AnoLLM(
        str(ref_model_dir),
        batch_size=args.batch_size,
        max_steps=args.max_steps,
        efficient_finetuning=efficient_finetuning,
        max_length_dict=max_length_dict,
        textual_columns=text_columns,
        no_random_permutation=args.no_random_permutation,
        bf16=True,
        learning_rate=args.lr,
    )

    device = torch.device(f"cuda:{local_rank}") if use_cuda else torch.device("cpu")
    policy_model.model.to(device)
    reference_model.model.to(device)

    # NOTE(review): this marks *every* policy parameter trainable, including
    # any weights a LoRA setup may have frozen — confirm this is intended
    # when --lora is used.
    policy_model.model.train()
    for param in policy_model.model.parameters():
        param.requires_grad = True

    reference_model.model.eval()
    for param in reference_model.model.parameters():
        param.requires_grad = False

    if distributed and world_size > 1:
        if use_cuda:
            policy_model.model = torch.nn.parallel.DistributedDataParallel(
                policy_model.model, device_ids=[local_rank], output_device=local_rank
            )
        else:
            policy_model.model = torch.nn.parallel.DistributedDataParallel(policy_model.model)

    print(f"Loading SPIN dataset with {len(spin_pairs)} pairs...")
    spin_dataset = SPINDatasetForAnomalyDetection(
        spin_pairs,
        policy_model.tokenizer,
        max_length=2048
    )

    if rank == 0:
        print(f"Test data shape: {X_test.shape}, y_test shape: {y_test.shape}")

    print("Starting SPIN training...")

    optimizer = torch.optim.AdamW(
        policy_model.model.parameters(),
        lr=args.lr,
    )

    # With several ranks, shard the pairs so each rank trains on a different
    # part of the data per epoch; a plain shuffled loader gives the ranks
    # uncoordinated (possibly identical) batches.
    sampler = None
    if distributed and world_size > 1:
        sampler = torch.utils.data.DistributedSampler(
            spin_dataset,
            num_replicas=world_size,
            rank=rank,
            shuffle=True,
            seed=args.seed,
        )

    dataloader = torch.utils.data.DataLoader(
        spin_dataset,
        batch_size=args.batch_size,
        shuffle=(sampler is None),
        sampler=sampler,
        collate_fn=_collate_spin_batch
    )

    start_time = time.time()
    global_step = 0

    num_epochs = args.max_steps // len(dataloader) + 1
    for epoch in range(num_epochs):
        if global_step >= args.max_steps:
            break
        if sampler is not None:
            # Reshuffle the shards differently in every epoch.
            sampler.set_epoch(epoch)

        for batch in dataloader:
            if global_step >= args.max_steps:
                break

            batch_chosen = {
                'input_ids': batch['chosen_input_ids'],
                'attention_mask': batch['chosen_attention_mask']
            }
            batch_rejected = {
                'input_ids': batch['rejected_input_ids'],
                'attention_mask': batch['rejected_attention_mask']
            }

            loss, chosen_reward, rejected_reward = spin_loss(
                policy_model.model,
                reference_model.model,
                batch_chosen,
                batch_rejected,
                args.beta,
                args.epsilon,
                device,
                f_divergence_type=args.f_divergence_type
            )

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(policy_model.model.parameters(), 0.7)
            optimizer.step()

            global_step += 1

            if rank == 0 and global_step % 100 == 0:
                log_msg = (f"Step {global_step}/{args.max_steps}, Loss: {loss.item():.7f}, "
                          f"Chosen Reward: {chosen_reward.item():.7f}, "
                          f"Rejected Reward: {rejected_reward.item():.7f}")
                print(log_msg)
                log_f.write(log_msg + '\n')
                log_f.flush()

    end_time = time.time()

    # Training is over: drop the DDP wrapper so scoring and saving operate on
    # the plain module. A DDP forward may issue collectives, and the ranks do
    # not necessarily run the same number of scoring passes.
    if isinstance(policy_model.model, torch.nn.parallel.DistributedDataParallel):
        policy_model.model = policy_model.model.module

    if rank == 0:
        print("\n" + "="*80)
        print(f"[Evaluation after Iteration {args.iteration}]")
        print("="*80)
        log_f.write(f"\n[Evaluation after Iteration {args.iteration}]\n")

    # evaluate_model_metrics gathers scores from all ranks, so it has to be
    # entered by every rank; calling it from rank 0 alone blocks forever in
    # the gather. Only rank 0 gets metrics back.
    auc_roc, auc_pr, f1, precision, recall = None, None, None, None, None
    try:
        auc_roc, auc_pr, f1, precision, recall, _ = evaluate_model_metrics(
            policy_model,
            X_test,
            y_test,
            args,
            run_name,
            args.exp_dir,
            device,
            n_permutations=args.eval_n_permutations,
            batch_size=args.eval_batch_size,
            distributed=distributed,
            rank=rank,
            world_size=world_size
        )
    except Exception as e:
        # Evaluation is best-effort: the trained model is still saved below.
        print(f"[WARNING] Evaluation failed with error: {e}")
        print(f"[WARNING] Model will still be saved. Error details:")
        import traceback
        traceback.print_exc()
        if rank == 0:
            log_f.write(f"\n[WARNING] Evaluation failed: {e}\n")

    if rank == 0:
        if auc_roc is not None:
            eval_msg = (f"Iteration {args.iteration} - AUC-ROC: {auc_roc:.4f}, AUC-PR: {auc_pr:.4f}, "
                       f"F1: {f1:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}")
            print(eval_msg)
            log_f.write(eval_msg + '\n')

            metrics_summary_path = args.exp_dir / f'metrics_iter{args.iteration}.txt'
            with open(metrics_summary_path, 'w') as f:
                f.write(f"Iteration: {args.iteration}\n")
                f.write(f"Final Step: {global_step}\n")
                f.write(f"AUC-ROC: {auc_roc:.4f}\n")
                f.write(f"AUC-PR: {auc_pr:.4f}\n")
                f.write(f"F1: {f1:.4f}\n")
                f.write(f"Precision: {precision:.4f}\n")
                f.write(f"Recall: {recall:.4f}\n")
            print(f"[Evaluation] Metrics summary saved to {metrics_summary_path}")

        train_time_msg = f"Training time: {end_time - start_time:.2f}s"
        print(train_time_msg)
        log_f.write("\n" + train_time_msg + '\n')
        log_f.close()

        run_time_dir = args.exp_dir / "run_time" / "train"
        os.makedirs(run_time_dir, exist_ok=True)
        run_time_path = run_time_dir / f"{run_name}.txt"
        with open(run_time_path, 'w') as f:
            f.write(str(end_time - start_time))

        print(f"Saving final SPIN model to {model_save_path}")
        policy_model.save_pretrained(model_save_path)

        print("\n" + "="*80)
        print(f"Training Summary - Iteration {args.iteration}")
        print("="*80)
        if auc_roc is not None:
            print(f"AUC-ROC: {auc_roc:.4f}")
            print(f"AUC-PR: {auc_pr:.4f}")
            print(f"F1: {f1:.4f}")
        print(f"Model saved to: {model_save_path}")
        print(f"Scores saved to: {args.exp_dir / 'scores' / f'{run_name}.npy'}")
        print("="*80)

    if distributed:
        dist.destroy_process_group()


if __name__ == "__main__":
    # Under a multi-process launcher WORLD_SIZE is > 1: join the process
    # group first (NCCL when CUDA is available, Gloo otherwise) so main()
    # sees an initialised distributed context.
    if int(os.environ.get("WORLD_SIZE", "1")) > 1:
        dist.init_process_group(
            backend="nccl" if torch.cuda.is_available() else "gloo"
        )
    main()
