import os
import re
import csv
import time
import torch
import argparse
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel, BitsAndBytesConfig
from transformers import get_cosine_schedule_with_warmup
from sentence_transformers import SentenceTransformer, util
from datasets import load_dataset
from tqdm.auto import tqdm
from openai import OpenAI
import json
import coloredlogs, logging
import datetime
import torch.multiprocessing as mp
from rapidfuzz import fuzz
import httpx
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from human_eval.execution import check_correctness
from human_eval.evaluation import evaluate_functional_correctness, estimate_pass_at_k

# Module-level logger handle. It is None at import time; the functions below call
# logger.info/error/..., so it must be assigned elsewhere before they run
# (the assignment is not visible in this section).
logger = None
# Force the 'spawn' start method for torch.multiprocessing, overriding any earlier choice.
mp.set_start_method('spawn', force=True)

# NOTE(review): this unconditionally overwrites any HF_TOKEN already present in the
# environment with the literal placeholder string — presumably meant to be replaced
# with a real token before running; confirm.
os.environ['HF_TOKEN'] = 'HF_TOKEN'
# Sentence-embedding model, loaded eagerly at import time (its use is not visible in this section).
embedder = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

# Run on the GPU when CUDA is available, otherwise on the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Selects how expert models are reached in train_controller: OpenAI-compatible
# clients ('hosted') or locally loaded Hugging Face models ('local').
EXPERT_TYPES = 'hosted'  # 'hosted' or 'local'
# NOTE(review): every MODEL_NAMES assignment below is commented out, yet MODEL_NAMES
# (together with API_KEYS and BASE_URLS) is referenced by the functions in this file.
# They must be defined elsewhere before use — confirm.
# MODEL_NAMES = ["/home/models/Qwen_Qwen2-0.5B-Instruct", "/home/models/Phi-4-mini-instruct", "/home/models/unsloth_Llama-3.2-1B-Instruct", "/home/models/Phi-4-mini-instruct", "/home/models/unsloth_Llama-3.2-1B-Instruct"]
# MODEL_NAMES = ["/home/models/Phi-4-mini-instruct"] * 3
# MODEL_NAMES = ["Qwen/Qwen2.5-Coder-3B-Instruct"] * 3
# MODEL_NAMES=["/home/models/Qwen3-8B"] * 3

# Same switch for the oracle (teacher) model, plus the model identifier to load/query.
ORACLE_TYPE = 'hosted'  # 'hosted' or 'local'
ORACLE_NAME = "gpt-4.1"

# Encoder used to embed prompts and model outputs into one shared vector space.
SHARED_ENCODER_NAME = "google-bert/bert-base-uncased"
# The five values below are not read directly by the code visible in this section
# (train_controller takes batch size and lambdas from `args`); presumably they are
# argparse defaults — TODO confirm.
BATCH_SIZE = 2
LAMBDA_SYMM = 0.1
LAMBDA_SPARSITY = 0.01
NUM_SHOTS = 3
# When True, local models are loaded through load_model_quantized (8-bit).
USE_QUANTIZED = True
# Generation budget passed as max_new_tokens for locally generated text.
MAX_NEW_TOKENS = 1024

class SelectionLogger:
    """Append-only CSV logger for the controller's model-selection decisions.

    Two files are created in ``log_dir`` on construction, both stamped with the
    creation time:

    * a training log with one row per (sample, model) pair, and
    * an inference log with one row per sample.

    Args:
        model_names: Names of the candidate models; position ``i`` labels
            column ``i`` of every probability vector passed to ``log_batch``.
        log_dir: Directory for the CSV files (created if missing).
    """

    TRAIN_HEADER = ['step', 'sample_idx', 'model_name', 'selection_prob', 'phase']
    INFERENCE_HEADER = ['sample_idx', 'selected_models', 'selection_probs', 'final_model']

    def __init__(self, model_names, log_dir="selection_logs"):
        self.model_names = model_names
        self.log_dir = log_dir
        os.makedirs(log_dir, exist_ok=True)

        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        self.train_log_path = os.path.join(log_dir, f"train_selections_{timestamp}.csv")
        self.inference_log_path = os.path.join(log_dir, f"inference_selections_{timestamp}.csv")
        # Used as the 'step' column when the caller does not supply one.
        self._train_step = 0

        # newline='' is what the csv module expects; without it blank rows
        # appear on platforms that translate line endings.
        for path, header in ((self.train_log_path, self.TRAIN_HEADER),
                             (self.inference_log_path, self.INFERENCE_HEADER)):
            with open(path, 'w', newline='') as f:
                csv.writer(f).writerow(header)

    @staticmethod
    def _as_scalar(value):
        """Unwrap 0-d tensors / NumPy scalars so the CSV holds plain numbers."""
        return value.item() if hasattr(value, 'item') else value

    def log_batch(self, batch_indices, selection_probs, phase='train', selected_indices=None, step=None):
        """Append one batch of selection data to the log for ``phase``.

        Args:
            batch_indices: Iterable of sample identifiers, one per batch row.
            selection_probs: Per-sample probability vectors. For training the
                vector is indexed by model position; for inference it is
                written out as-is, joined with ``|``.
            phase: ``'train'`` writes to the training log; any other value
                writes to the inference log.
            selected_indices: Per-sample sequences of model indices; required
                for the inference log. The last index is recorded as the
                final model.
            step: Optional training step for the ``step`` column. When omitted
                an internal counter of training ``log_batch`` calls is used.

        Raises:
            ValueError: If an inference batch is logged without
                ``selected_indices``.
        """
        if phase == 'train':
            if step is None:
                step = self._train_step
            self._train_step += 1
            with open(self.train_log_path, 'a', newline='') as f:
                writer = csv.writer(f)
                for i, sample_idx in enumerate(batch_indices):
                    for model_idx, prob in enumerate(selection_probs[i]):
                        # Every row must carry all five header columns;
                        # get_selection_stats reads them back by name.
                        writer.writerow([
                            step,
                            self._as_scalar(sample_idx),
                            self.model_names[model_idx],
                            self._as_scalar(prob),
                            phase,
                        ])
            return

        if selected_indices is None:
            raise ValueError("selected_indices is required when logging a non-train phase")

        with open(self.inference_log_path, 'a', newline='') as f:
            writer = csv.writer(f)
            for i, sample_idx in enumerate(batch_indices):
                selected_names = [self.model_names[idx] for idx in selected_indices[i]]
                probs = [self._as_scalar(p) for p in selection_probs[i]]
                writer.writerow([
                    self._as_scalar(sample_idx),
                    '|'.join(selected_names),
                    '|'.join(map(str, probs)),
                    selected_names[-1] if selected_names else '',
                ])

    def get_selection_stats(self):
        """Aggregate the training log.

        Returns:
            dict with two sub-dicts keyed by model name: ``'total_selections'``
            (number of logged rows) and ``'average_prob'`` (mean logged
            probability, 0.0 when the model has no rows).
        """
        stats = {'total_selections': {name: 0 for name in self.model_names},
                 'average_prob': {name: 0.0 for name in self.model_names}}

        with open(self.train_log_path, 'r', newline='') as f:
            for row in csv.DictReader(f):
                name = row['model_name']
                stats['total_selections'][name] += 1
                stats['average_prob'][name] += float(row['selection_prob'])

        for name in self.model_names:
            count = stats['total_selections'][name]
            if count > 0:
                stats['average_prob'][name] /= count

        return stats

class CollaborationDataset(Dataset):
    """Thin Dataset wrapper around the list of prepared samples.

    The shared encoder/tokenizer and ``args`` are stored on the instance;
    only ``args.task`` is consulted when an item is fetched.
    """

    def __init__(self, args, samples, shared_encoder, shared_tokenizer):
        self.args = args
        self.samples = samples
        self.shared_encoder = shared_encoder
        self.shared_tokenizer = shared_tokenizer

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        record = self.samples[idx]
        # Every task except HumanEval exposes just the prompt/answer pair;
        # HumanEval items are handed back unchanged with all of their fields.
        if self.args.task != 'humaneval':
            return {field: record[field] for field in ('prompt', 'answer')}
        return record

class CollaborationController(nn.Module):
    """Scores candidate models and their pairwise collaboration for a prompt.

    Given a prompt embedding (and optionally per-model output embeddings and
    an oracle embedding) the controller produces

    * ``C_soft``: a row-stochastic ``(B, num_models, num_models)`` matrix with
      a zero diagonal, and
    * a per-model selection distribution sampled with Gumbel-softmax.

    Args:
        input_dim: Width of the incoming embeddings.
        hidden_dim: Width of the internal representation.
        num_models: Number of candidate models.
        max_seq_len: Upper bound on (and initial value of) the number of
            models kept at inference time.
        tau: Initial Gumbel-softmax temperature.
        use_cosine_bias: Add the cosine similarity between model
            representations to the collaboration logits.
    """

    def __init__(self, input_dim=768, hidden_dim=256, num_models=3, max_seq_len=3, tau=1.0, use_cosine_bias=True):
        super().__init__()
        self.num_models = num_models
        self.max_seq_len = max_seq_len
        self.tau = tau
        # Starting temperature; `tau` only decays from this value. The
        # checkpoint config written by train_controller reads this attribute.
        self.tau_max = tau
        self.use_cosine_bias = use_cosine_bias
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        self.input_proj = nn.Linear(input_dim, hidden_dim)

        # NOTE(review): `encoder` is constructed (and therefore part of the
        # state dict) but forward() never calls it — confirm whether intended.
        self.encoder = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim)
        )

        self.oracle_proj = nn.Linear(input_dim, hidden_dim)
        self.oracle_attention = nn.MultiheadAttention(hidden_dim, num_heads=1, batch_first=True)

        self.model_attention = nn.MultiheadAttention(hidden_dim, num_heads=4, batch_first=True)

        self.residual_proj = nn.Linear(hidden_dim, input_dim)

        self.C_head = nn.Linear(hidden_dim, num_models * num_models)

        self.seq_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_models)
        )

        # Xavier weights with a small positive bias for the projection and
        # selection layers; the remaining layers keep PyTorch's defaults.
        for layer in [self.input_proj, self.residual_proj, *self.seq_head]:
            if isinstance(layer, nn.Linear):
                nn.init.xavier_uniform_(layer.weight)
                nn.init.constant_(layer.bias, 0.1)

        self.perf_estimator = nn.Linear(hidden_dim, 1)

        # Temperature / top-k annealing state.
        self.tau_min = 0.5
        self.tau_decay = 0.999
        self.min_k = 1
        self.current_k = max_seq_len
        self.k_decay = 0.999
        self.k_update_freq = 100
        self.step_count = 0
        # Exponential moving average of the quality reported to update_k.
        self.register_buffer('selection_quality', torch.zeros(1))
        self.quality_alpha = 0.9
        self.length_cost_weight = 0.2
        self.min_length_penalty = 0.1

    def compute_cosine_matrix(self, shared_reps):
        """Return the ``(B, N, N)`` pairwise cosine similarities of ``(B, N, D)`` inputs."""
        norm_reps = F.normalize(shared_reps, dim=-1)
        return torch.einsum("bid,bjd->bij", norm_reps, norm_reps)

    def compute_collaboration_matrix(self, encoded_input, sim_matrix=None):
        """Return the soft collaboration matrix for ``(B, hidden_dim)`` inputs.

        The diagonal is masked before the row-wise softmax, so a model never
        assigns collaboration weight to itself.
        """
        B = encoded_input.size(0)
        C_logits = self.C_head(encoded_input).view(B, self.num_models, self.num_models)

        if self.use_cosine_bias and sim_matrix is not None:
            C_logits = C_logits + sim_matrix

        mask = torch.eye(self.num_models, device=C_logits.device).bool()
        C_logits = C_logits.masked_fill(mask, float('-inf'))
        return F.softmax(C_logits, dim=-1)

    def compute_sequence_distribution(self, encoded_input, C_soft, shared_reps):
        """Return Gumbel-softmax selection probabilities of shape ``(B, num_models)``.

        The logits combine the selection head, an optional per-model
        performance estimate, the average collaboration weight each model
        sends and receives, and a penalty that grows linearly with the model
        index. Probabilities are floored at 1e-4 and renormalised.
        """
        logits = self.seq_head(encoded_input)

        baseline = torch.ones_like(logits) * 0.1
        logits = logits + baseline

        if shared_reps is not None:
            # shared_reps is (B, N, hidden_dim) here, giving one score per model.
            perf_scores = self.perf_estimator(shared_reps).squeeze(-1)
            logits = logits + perf_scores

        if C_soft is not None:
            collab_importance = 0.5 * (C_soft.mean(dim=1) + C_soft.mean(dim=2))
            logits = logits + collab_importance

        length_penalty = torch.arange(1, self.num_models+1, device=logits.device).float() * self.length_cost_weight
        length_penalty = length_penalty + self.min_length_penalty
        logits = logits - length_penalty.unsqueeze(0)

        probs = F.gumbel_softmax(logits, tau=self.tau, hard=False, dim=-1)
        probs = probs.clamp(min=1e-4)
        return probs / probs.sum(dim=-1, keepdim=True)

    def update_k(self, new_quality):
        """Fold ``new_quality`` into the running average and maybe shrink ``current_k``.

        Every ``k_update_freq`` calls, ``current_k`` is decayed (never below
        ``min_k``) if the running quality exceeds 0.7.
        """
        self.step_count += 1
        self.selection_quality = (self.quality_alpha * self.selection_quality +
                                (1 - self.quality_alpha) * new_quality)

        if self.step_count % self.k_update_freq == 0:
            if self.selection_quality > 0.7:
                self.current_k = max(self.min_k, int(self.current_k * self.k_decay))

    def forward(self, input_embedding, shared_reps=None, oracle_emb=None):
        """Compute collaboration and selection outputs for one batch.

        Args:
            input_embedding: ``(B, input_dim)`` prompt embeddings.
            shared_reps: Optional ``(B, N, input_dim)`` embeddings of each
                model's output.
            oracle_emb: Optional ``(B, input_dim)`` oracle embeddings; only
                used in training mode.

        Returns:
            ``(C_soft, seq_gumbel, selection_info)``. In eval mode
            ``seq_gumbel`` is zero outside the top ``current_k`` models and
            ``selection_info`` carries ``selected_indices`` /
            ``selection_probs``; in training mode it carries
            ``selected_counts`` and the full ``selection_probs``. Both modes
            include ``length_cost``.
        """
        h = self.input_proj(input_embedding)

        if oracle_emb is not None and self.training:
            oracle_emb = self.oracle_proj(oracle_emb).unsqueeze(1)
            h_attn = h.unsqueeze(1)
            attn_out, _ = self.oracle_attention(h_attn, oracle_emb, oracle_emb)
            h = h + attn_out.squeeze(1)

        if shared_reps is not None:
            B, N, D = shared_reps.shape
            shared_reps_proj = self.input_proj(shared_reps.view(B*N, D)).view(B, N, -1)

            attn_out, _ = self.model_attention(shared_reps_proj, shared_reps_proj, shared_reps_proj)

            # Residual update in the input space, then project to hidden_dim.
            attn_out = self.residual_proj(attn_out)
            shared_reps = shared_reps + attn_out

            shared_reps = self.input_proj(shared_reps.view(B*N, D)).view(B, N, -1)

        sim_matrix = self.compute_cosine_matrix(shared_reps) if self.use_cosine_bias and shared_reps is not None else None
        C_soft = self.compute_collaboration_matrix(h, sim_matrix)

        seq_gumbel = self.compute_sequence_distribution(h, C_soft, shared_reps)

        if not self.training:
            # Clamp k so torch.topk cannot be asked for more entries than exist.
            keep = min(self.current_k, seq_gumbel.size(1))
            topk_values, topk_indices = torch.topk(seq_gumbel, k=keep, dim=1)
            seq_gumbel = torch.zeros_like(seq_gumbel).scatter(1, topk_indices, topk_values)

        # Anneal the temperature only while training: evaluation passes must
        # not advance the schedule stored in checkpoints.
        if self.training:
            self.tau = max(self.tau_min, self.tau * self.tau_decay)

        with torch.no_grad():
            if self.training:
                selected_lengths = (seq_gumbel > 0.5).sum(dim=1).float()
                selected_counts = selected_lengths.mean()
                selection_info = {
                    'selected_counts': selected_counts.item(),
                    'selection_probs': seq_gumbel.detach().cpu().numpy()
                }
                selection_info['length_cost'] = selected_lengths.mean().item() * self.length_cost_weight
            else:
                report_k = min(self.max_seq_len, seq_gumbel.size(1))
                topk_values, topk_indices = torch.topk(seq_gumbel, k=report_k, dim=1)
                selection_info = {
                    'selected_indices': topk_indices.cpu().numpy(),
                    'selection_probs': topk_values.cpu().numpy()
                }
                selection_info['length_cost'] = self.current_k * self.length_cost_weight

        return C_soft, seq_gumbel, selection_info

    @classmethod
    def from_saved_state(cls, save_path, device=DEVICE):
        """Rebuild a controller from a checkpoint written by the training loop.

        Args:
            save_path: Path to the ``.pt`` checkpoint.
            device: Device to map tensors to and place the controller on.

        Returns:
            ``(controller, checkpoint)`` — the restored module and the raw
            checkpoint dict.
        """
        checkpoint = torch.load(save_path, map_location=device)

        controller = cls(
            input_dim=checkpoint['config']['input_dim'],
            hidden_dim=checkpoint['config']['hidden_dim'],
            num_models=checkpoint['config']['num_models'],
            max_seq_len=checkpoint['config']['max_seq_len'],
            tau=checkpoint['tau'],
            use_cosine_bias=True
        ).to(device)

        controller.load_state_dict(checkpoint['state_dict'])
        controller.current_k = checkpoint['current_k']
        controller.tau = checkpoint['tau']

        # The saved quality may already be a tensor; torch.tensor(tensor)
        # warns and copies, so clone it explicitly instead.
        quality = checkpoint['selection_quality']
        if torch.is_tensor(quality):
            quality = quality.detach().clone().to(device)
        else:
            quality = torch.tensor(quality, device=device)
        controller.selection_quality = quality
        controller.step_count = checkpoint['step_count']

        # Restore the remaining hyper-parameters (tau_min, k_decay, ...).
        for param, value in checkpoint['config'].items():
            setattr(controller, param, value)

        return controller, checkpoint

def prepare_datasets(args):
    """Load the dataset selected by ``args`` and turn it into prompt/answer samples.

    ``args.task`` chooses both the dataset configuration/split and how each
    item is formatted; ``args.dataset`` is the dataset identifier passed to
    ``load_dataset``.

    Returns:
        A list of dicts with ``'prompt'`` and ``'answer'``; HumanEval samples
        also carry ``'task_id'``, ``'test'`` and ``'entry_point'``.

    Raises:
        ValueError: On the first item, when ``args.task`` is not one of
            squad / mmlu / mmlupro / humaneval / gsm8k.
    """
    # task -> (extra positional config arguments, split expression)
    load_specs = {
        "mmlu": (('all',), "validation[:]"),
        "mmlupro": ((), "validation"),
        "gsm8k": (('main',), "test[:]"),
    }
    config_args, split = load_specs.get(args.task, ((), "test[:1%]"))
    dataset = load_dataset(args.dataset, *config_args, split=split)

    option_labels = [f"({letter})" for letter in "ABCDEFGHIJ"]

    samples = []
    for item in dataset:
        task = args.task
        if task == 'squad':
            prompt = construct_few_shot_prompt(0, dataset) + f"Q: {item['question']}\nContext: {item['context']}\nA:"
            answer_texts = item["answers"]["text"]
            answer = answer_texts[0] if answer_texts else ""
        elif task == 'mmlu':
            prompt = format_prompt(item, 'mmlu')
            gold_idx = item["answer"]
            # MMLU has four options, labelled (A)-(D).
            answer = option_labels[:4][gold_idx] + " " + item["choices"][gold_idx]
        elif task == 'mmlupro':
            prompt = format_prompt(item, 'mmlupro')
            gold_idx = item["answer_index"]
            answer = option_labels[gold_idx] + " " + item["options"][gold_idx]
        elif task == 'humaneval':
            prompt = format_prompt(item, 'humaneval')
            answer = item["canonical_solution"]
        elif task == 'gsm8k':
            prompt = format_prompt(item, 'gsm8k')
            answer = item["answer"]
        else:
            raise ValueError("Unsupported task")

        record = {'prompt': prompt, 'answer': answer}
        if task == 'humaneval':
            # HumanEval scoring needs the test harness metadata as well.
            record['task_id'] = item['task_id']
            record['test'] = item['test']
            record['entry_point'] = item['entry_point']
        samples.append(record)

    return samples

def extract_oracle_signals(oracle_model, oracle_tokenizer, prompts, shared_encoder, shared_tokenizer, models, tokenizers):
    """Build oracle-derived targets for the controller.

    Every expert and the oracle answer ``prompts``; all answers are embedded
    with the shared encoder and compared by cosine similarity.

    Args:
        oracle_model / oracle_tokenizer: Oracle backend; a ``None`` tokenizer
            means a hosted (API) model.
        prompts: Batch of prompt strings.
        shared_encoder / shared_tokenizer: Encoder used for all embeddings.
        models / tokenizers: Expert backends, aligned with ``MODEL_NAMES``.

    Returns:
        ``(oracle_matrix, importance_scores)``: a ``(B, N, N)`` tensor whose
        ``[b, i, j]`` entry (before the row softmax) is the similarity between
        the oracle answer and the mean of experts i and j, and a ``(B, N)``
        tensor of softmax-normalised per-expert similarities to the oracle.
    """
    B = len(prompts)
    N = len(MODEL_NAMES)

    # generate_text takes the hosted path whenever model_name is not None, so
    # the name is passed by keyword and only when there is no local tokenizer.
    oracle_outputs = generate_text(
        oracle_model, oracle_tokenizer, prompts,
        model_name=ORACLE_NAME if oracle_tokenizer is None else None,
    )
    oracle_emb = get_shared_embedding(oracle_outputs, shared_encoder, shared_tokenizer)

    # The name used to be passed positionally, where it landed in
    # `infer_start` and capped every expert answer at 10 tokens.
    model_outputs = [
        generate_text(model, tokenizer, prompts, model_name=name if tokenizer is None else None)
        for name, model, tokenizer in zip(MODEL_NAMES, models, tokenizers)
    ]

    # Embed each expert's answers once instead of once per (i, j) pair.
    model_embs = [
        get_shared_embedding(outputs, shared_encoder, shared_tokenizer)
        for outputs in model_outputs
    ]

    oracle_matrix = torch.zeros(B, N, N, device=DEVICE)
    importance_scores = torch.zeros(B, N, device=DEVICE)

    for i, emb_i in enumerate(model_embs):
        importance_scores[:, i] = F.cosine_similarity(oracle_emb, emb_i, dim=1)

        for j, emb_j in enumerate(model_embs):
            if i != j:
                combined_emb = (emb_i + emb_j) / 2
                oracle_matrix[:, i, j] = F.cosine_similarity(combined_emb, oracle_emb, dim=1)

    return F.softmax(oracle_matrix, dim=-1), F.softmax(importance_scores, dim=-1)

def compute_sparsity_loss(C_soft):
    """Mean Shannon entropy of the distributions along the last axis of ``C_soft``.

    Probabilities are floored at 1e-9 inside the logarithm so zero entries
    contribute 0 rather than NaN.
    """
    log_probs = torch.log(torch.clamp(C_soft, min=1e-9))
    row_entropy = torch.sum(C_soft * log_probs, dim=-1).neg()
    return row_entropy.mean()

def train_controller(args):
    """Train a CollaborationController and checkpoint it after every epoch.

    For each batch the experts and the oracle answer the prompts, the
    controller predicts a collaboration matrix and a selection distribution,
    the top-ranked experts are chained (each sees the previous responses),
    and a weighted sum of utility, distillation, symmetry, sparsity,
    diversity, oracle-alignment, selection and length terms is minimised.

    Args:
        args: Namespace providing ``task``, ``dataset`` (read by
            prepare_datasets), ``batch_size``, ``epochs``, ``lambda_symm``,
            ``lambda_sparse``, ``seq_len`` and ``output_dir``.
    """
    if EXPERT_TYPES == 'hosted':
        # Only the endpoints are echoed; API keys must never reach stdout/logs.
        print(BASE_URLS)
        models = [
            OpenAI(
                api_key=api_key,
                base_url=base_url
            ) for _, api_key, base_url in zip(MODEL_NAMES, API_KEYS, BASE_URLS)
        ]
        tokenizers = [None] * len(MODEL_NAMES)
    else:
        models = [load_model_quantized(name) if USE_QUANTIZED else AutoModelForCausalLM.from_pretrained(name).to(DEVICE) for name in MODEL_NAMES]
        tokenizers = [AutoTokenizer.from_pretrained(name) for name in MODEL_NAMES]
        for t in tokenizers:
            t.pad_token = t.eos_token
            t.padding_side = 'left'

    if ORACLE_TYPE == 'hosted':
        oracle_tokenizer = None
        oracle_model = OpenAI(
            # base_url="",
            api_key="API_KEY"
        )
    else:
        oracle_tokenizer = AutoTokenizer.from_pretrained(ORACLE_NAME)
        oracle_tokenizer.pad_token = oracle_tokenizer.eos_token
        oracle_tokenizer.padding_side = 'left'
        oracle_model = load_model_quantized(ORACLE_NAME) if USE_QUANTIZED else AutoModelForCausalLM.from_pretrained(ORACLE_NAME).to(DEVICE)

    # generate_text takes the hosted path whenever model_name is not None, so
    # hosted backends get their deployment name and local ones get None.
    oracle_model_name = ORACLE_NAME if oracle_tokenizer is None else None

    shared_tokenizer = AutoTokenizer.from_pretrained(SHARED_ENCODER_NAME)
    shared_encoder = AutoModel.from_pretrained(SHARED_ENCODER_NAME).to(DEVICE)

    samples = prepare_datasets(args)
    dataset = CollaborationDataset(args, samples, shared_encoder, shared_tokenizer)
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
        persistent_workers=True
    )

    controller = CollaborationController(num_models=len(models), max_seq_len=len(models)-1, use_cosine_bias=True).to(DEVICE)
    # Remembered for the checkpoint config in case the controller itself does
    # not track its starting temperature.
    initial_tau = controller.tau
    optimizer = torch.optim.Adam(controller.parameters(), lr=1e-3)

    num_training_steps = len(dataloader) * args.epochs
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(0.1 * num_training_steps),
        num_training_steps=num_training_steps,
    )

    loss_weights = {
        'util': 0.5,
        'distill': 0.5,
        'symm': args.lambda_symm,
        'sparse': args.lambda_sparse,
        'diversity': 0.1,
        'oracle': 0.3,
        'selection': 1.0,
        'length': 0.5
    }
    selection_logger = SelectionLogger(MODEL_NAMES)

    # Checkpoints are written at the end of every epoch; make sure the target
    # directory exists before hours of generation are spent.
    os.makedirs(args.output_dir, exist_ok=True)
    controller_state = None
    timestamp = None

    for epoch in range(args.epochs):
        controller.train()
        epoch_loss = 0.0

        for batch_idx, batch in enumerate(tqdm(dataloader, desc=f"Epoch {epoch+1}/{args.epochs}")):
            prompts = batch['prompt']

            if args.task == 'humaneval':
                golds = batch
            else:
                golds = batch['answer']

            with torch.no_grad():
                input_embeds = get_shared_embedding(prompts, shared_encoder, shared_tokenizer)

                # The model name goes in by keyword: passed positionally it
                # lands in `infer_start` and truncates answers to 10 tokens.
                shared_reps = torch.stack([
                    get_shared_embedding(
                        generate_text(m, t, prompts, model_name=n if t is None else None),
                        shared_encoder, shared_tokenizer
                    )
                    for (n, m, t) in zip(MODEL_NAMES, models, tokenizers)
                ], dim=1).to(DEVICE)

                oracle_matrix, importance_scores = extract_oracle_signals(
                    oracle_model, oracle_tokenizer, prompts, shared_encoder, shared_tokenizer, models, tokenizers
                )
                oracle_emb = get_shared_embedding(
                    generate_text(oracle_model, oracle_tokenizer, prompts, model_name=oracle_model_name),
                    shared_encoder, shared_tokenizer
                ).to(DEVICE)

            C_soft, seq_gumbel, selection_info = controller(input_embeds, shared_reps, oracle_emb)
            selection_logger.log_batch(
                # Plain ints so the CSV holds "3" rather than "tensor(3)".
                batch_indices=(batch_idx * args.batch_size + torch.arange(len(prompts))).tolist(),
                selection_probs=selection_info['selection_probs'],
                phase='train'
            )

            if batch_idx % 2 == 0:
                logger.info(f"Selection probs: {seq_gumbel.mean(dim=0).detach().cpu().numpy()}")
                logger.info(f"Max prob: {seq_gumbel.max().item():.4f}, Min prob: {seq_gumbel.min().item():.4f}")

            oracle_alignment_loss = F.mse_loss(C_soft, oracle_matrix) + F.mse_loss(seq_gumbel, importance_scores)

            batch_qualities = []
            final_outputs = []
            distillation_loss = 0
            selected_counts = torch.zeros(len(models), device=DEVICE)

            for b in range(len(prompts)):
                current_prompt = prompts[b]
                reasoning = ""

                topk_probs, topk_indices = torch.topk(seq_gumbel[b], k=min(args.seq_len, seq_gumbel[b].size(0)))
                selected_counts.scatter_add_(0, topk_indices, topk_probs)

                oracle_texts = generate_text(
                    oracle_model, oracle_tokenizer, [current_prompt], model_name=oracle_model_name
                )
                oracle_text = oracle_texts[0] if oracle_texts else ""
                # The text is wrapped in a list: get_shared_embedding iterates
                # over its argument, so a bare string is walked per character.
                sample_oracle_emb = get_shared_embedding(
                    [oracle_text], shared_encoder, shared_tokenizer
                ).to(DEVICE)

                model_outputs = []
                model_shared_reps = []
                for k in range(min(args.seq_len, len(topk_indices))):
                    selected_idx = topk_indices[k].item()
                    model = models[selected_idx]

                    full_input = reasoning + "\n" + current_prompt if reasoning else current_prompt

                    if EXPERT_TYPES != 'hosted':
                        tokenizer = tokenizers[selected_idx]
                        full_input = tokenizer.apply_chat_template([{'role': 'user', 'content': full_input}], tokenize=False, add_generation_prompt=False)
                    else:
                        tokenizer = None

                    logger.debug(full_input)

                    expert_name = MODEL_NAMES[selected_idx] if tokenizer is None else None
                    expert_texts = generate_text(model, tokenizer, [full_input], model_name=expert_name)
                    # An exhausted retry budget may yield no text; use "".
                    out = expert_texts[0] if expert_texts else ""

                    model_outputs.append(out.strip())
                    reasoning += f"\nAssistant {k+1}'s Response: {out}"

                    out_shared_rep = get_shared_embedding([out], shared_encoder, shared_tokenizer)
                    model_shared_reps.append(out_shared_rep)
                    out_shared_rep_avg = out_shared_rep.mean(dim=0)
                    oracle_emb_avg = sample_oracle_emb.mean(dim=0)
                    distillation_loss += F.mse_loss(out_shared_rep_avg, oracle_emb_avg)

                similarities = [
                    F.cosine_similarity(
                        sample_oracle_emb.mean(dim=0).unsqueeze(0),
                        e.mean(dim=0).unsqueeze(0),
                        dim=1
                    )
                    for e in model_shared_reps
                ]
                batch_quality = torch.mean(torch.stack(similarities))
                batch_qualities.append(batch_quality)

                # The last expert in the chain provides the final answer.
                final_output = model_outputs[-1] if model_outputs else ""
                final_outputs.append(final_output)

            avg_quality = torch.mean(torch.stack(batch_qualities)).item()
            controller.update_k(avg_quality)

            util_loss = compute_utility_loss(final_outputs, golds, args.task)
            symm_loss = (C_soft - C_soft.transpose(1, 2)).abs().mean()
            sparse_loss = compute_sparsity_loss(C_soft)
            diversity_loss = -selected_counts.var() / len(models)
            selection_loss = -seq_gumbel.mean()
            length_cost = selection_info['length_cost']

            total_loss = (
                loss_weights['util'] * util_loss +
                loss_weights['distill'] * distillation_loss +
                loss_weights['symm'] * symm_loss +
                loss_weights['sparse'] * sparse_loss +
                loss_weights['diversity'] * diversity_loss +
                loss_weights['oracle'] * oracle_alignment_loss +
                loss_weights['selection'] * selection_loss +
                loss_weights['length'] * length_cost
            )

            logger.info({
                'k': controller.current_k,
                'util_loss': util_loss,
                'distill_loss': distillation_loss,
                'symm_loss': symm_loss,
                'sparse_loss': sparse_loss,
                'diversity_loss': diversity_loss,
                'oracle_alignment_loss': oracle_alignment_loss,
                'selection_loss': selection_loss,
                'length_cost': length_cost,
                'total_loss': total_loss.item()
            })

            epoch_loss += total_loss.item()

            optimizer.zero_grad()
            total_loss.backward()

            grad_norms = [p.grad.norm().item() for p in controller.parameters() if p.grad is not None]
            logger.info(f"Gradient norms - Min: {min(grad_norms):.4f}, Max: {max(grad_norms):.4f}, Mean: {np.mean(grad_norms):.4f}")

            torch.nn.utils.clip_grad_norm_(controller.parameters(), max_norm=1.0)
            optimizer.step()
            scheduler.step()

            if batch_idx % 10 == 0:
                logger.info(
                    f"Epoch {epoch+1} | Batch {batch_idx} | Loss={total_loss.item():.4f} | "
                    f"Util={util_loss:.4f} | OracleAlign={oracle_alignment_loss:.4f} | "
                    f"Selection={selection_loss:.4f}"
                )

        avg_epoch_loss = epoch_loss / len(dataloader)
        logger.info(f"Epoch {epoch+1} completed | Avg Loss: {avg_epoch_loss:.4f}")

        controller_state = {
            'state_dict': controller.state_dict(),
            'optimizer_state': optimizer.state_dict(),
            'scheduler_state': scheduler.state_dict(),
            'current_k': controller.current_k,
            'tau': controller.tau,
            'selection_quality': controller.selection_quality,
            'step_count': controller.step_count,
            'epoch': epoch + 1,
            'loss_weights': loss_weights,
            'model_names': MODEL_NAMES,
            'config': {
                'input_dim': controller.input_dim,
                'hidden_dim': controller.hidden_dim,
                'num_models': controller.num_models,
                'max_seq_len': controller.max_seq_len,
                'tau_min': controller.tau_min,
                # Reading controller.tau_max directly raised AttributeError
                # when the controller does not define it, losing the epoch.
                'tau_max': getattr(controller, 'tau_max', initial_tau),
                'min_k': controller.min_k,
                'k_decay': controller.k_decay,
                'k_update_freq': controller.k_update_freq,
                'quality_alpha': controller.quality_alpha,
                'length_cost_weight': controller.length_cost_weight,
                'min_length_penalty': controller.min_length_penalty
            }
        }

        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        save_path = f"{args.output_dir}/collaboration_controller_{args.task}_epoch{epoch+1}_{timestamp}.pt"
        torch.save(controller_state, save_path)

    if controller_state is None:
        # args.epochs <= 0: no state was ever built, so there is nothing to save.
        logger.warning("No training epochs were run; no controller state saved.")
        return

    final_save_path = f"{args.output_dir}/collaboration_controller_{args.task}_final_{timestamp}.pt"
    torch.save(controller_state, final_save_path)
    logger.info(f"Saved final controller state to {final_save_path}")

def clean_and_extract_number(s):
    """Return the digits of the first whitespace-delimited token of ``s`` as an int.

    Commas are removed first, then every non-digit character of the first
    token is dropped.

    NOTE(review): signs and decimal points are discarded too, so "-3.5"
    yields 35 — confirm callers only pass non-negative integers.

    Returns:
        The parsed integer, or ``None`` when ``s`` is empty, whitespace-only,
        or its first token contains no digit.
    """
    tokens = s.replace(',', '').split()
    if not tokens:
        # Indexing the first token used to raise IndexError here.
        return None

    digits = re.sub(r"[^\d]", "", tokens[0])
    return int(digits) if digits else None

def extract_final_number(text):
    """Extract the final numeric answer from free-form text.

    Patterns are tried in order and the first that matches wins:

    1. ``\\boxed{<number>}``
    2. ``$<number>$``
    3. the last ``>>`` followed by a number
    4. the last number anywhere in the text

    Thousands separators ("1,234") are removed before matching; previously
    such a value was read as its final group only (234).

    Args:
        text: Any value; missing values (``None`` / NaN) yield ``None``,
            everything else is converted with ``str``.

    Returns:
        The number as a float, or ``None`` if nothing numeric is found.
    """
    if pd.isna(text):
        return None

    # Drop a comma only when it sits between a digit and a group of exactly
    # three digits, so enumerations such as "1,2,3" are left alone.
    text = re.sub(r'(?<=\d),(?=\d{3}(?!\d))', '', str(text))

    boxed = re.search(r'\\boxed\{(\d+(\.\d+)?)\}', text)
    if boxed:
        return float(boxed.group(1))

    inline_math = re.search(r'\$(\d+(?:\.\d+)?)\$', text)
    if inline_math:
        return float(inline_math.group(1))

    annotated = re.findall(r'>>(\d+(?:\.\d+)?)', text)
    if annotated:
        return float(annotated[-1])

    numbers = re.findall(r'\d+(?:\.\d+)?', text)
    return float(numbers[-1]) if numbers else None

def load_model_quantized(name):
    """Load the causal LM ``name`` with an 8-bit BitsAndBytes quantization config.

    Device placement is delegated to ``device_map="auto"``.
    """
    return AutoModelForCausalLM.from_pretrained(
        name,
        quantization_config=BitsAndBytesConfig(load_in_8bit=True),
        device_map="auto",
    )

def compute_mmlu_loss(preds, golds):
    """Return ``1 - accuracy`` for multiple-choice answers.

    The predicted option is the last option label found in the response,
    written either as ``(X)``, ``\\boxed{X}`` or ``\\boxed{\\text{X}}``.
    Letters A-J are accepted so MMLU-Pro (ten options) is graded as well.
    The gold option is the leading ``(X)`` label of the gold string.

    Previously the extractor was invoked without the ground truth and
    therefore always returned ``None``, and the gold label kept a trailing
    space, so the loss was 1.0 regardless of the predictions.

    Args:
        preds: Model responses (strings).
        golds: Gold answers of the form ``"(X) option text"``.

    Returns:
        A float in [0, 1]; 1.0 for an empty batch.
    """
    option_pattern = re.compile(
        r'\(([A-J])\)'
        r'|\\boxed\{([A-J])\}'
        r'|\\boxed\{\\text\{([A-J])\}\}'
    )

    def last_option(response):
        """Last option label mentioned in ``response`` as "(X)", or None."""
        hits = option_pattern.findall(response.replace('\n', '').strip())
        if not hits:
            return None
        # Exactly one alternative (capture group) is non-empty per match.
        letter = next(filter(None, hits[-1]))
        return f"({letter})"

    def gold_option(gold):
        """Leading "(X)" label of the gold answer."""
        match = re.match(r'\s*\(([A-J])\)', gold)
        if match:
            return f"({match.group(1)})"
        return gold.strip()[:3]

    if not preds:
        return 1.0

    correct = 0
    for pred, gold in zip(preds, golds):
        choice = last_option(pred)
        if choice is not None and choice == gold_option(gold):
            correct += 1

    return 1.0 - correct / len(preds)

def compute_gsm8k_loss(preds, golds):
    """Return ``1 - accuracy`` for numeric answers.

    A prediction counts as correct when the final numbers extracted from the
    prediction and from the gold answer differ by less than 1e-3.

    Args:
        preds: Model responses.
        golds: Gold answers, aligned with ``preds``.

    Returns:
        A float in [0, 1]; 1.0 for an empty batch (which previously raised
        ZeroDivisionError).
    """
    if not preds:
        return 1.0

    correct = 0
    for pred, gold in zip(preds, golds):
        pred_val = extract_final_number(pred)
        gold_val = extract_final_number(gold)
        if pred_val is not None and gold_val is not None and abs(pred_val - gold_val) < 1e-3:
            correct += 1

    return 1.0 - correct / len(preds)

def _extract_python_code(response):
    """Return the code of the first fenced block in ``response``, or the stripped text if there is none."""
    for fence in ('```python', '```'):
        if fence in response:
            return response.split(fence)[1].split('```')[0].strip()
    return response.strip()


def compute_humaneval_loss(preds, golds, k=1):
    """Return ``1 - pass@1`` for a batch of HumanEval completions.

    Args:
        preds: Model responses, one per problem. Code is taken from the first
            fenced block when present; otherwise the whole response is used
            (a missing fence used to raise IndexError).
        golds: Column-oriented batch — a dict mapping each field name to a
            sequence with one entry per problem; must include ``'task_id'``.
        k: Accepted for interface compatibility. Exactly one completion is
            evaluated per task, so pass@1 is always what is reported.

    Returns:
        A float in [0, 1]; 1.0 when the batch is empty.
    """
    from collections import defaultdict

    if not golds or not preds:
        return 1.0

    # Turn the dict-of-sequences batch into one dict per problem.
    golds = [
        {key: value[i] for key, value in golds.items()}
        for i in range(len(next(iter(golds.values()))))
    ]
    logger.debug(golds)
    problems = {g['task_id']: g for g in golds}
    completions = {g['task_id']: _extract_python_code(p) for g, p in zip(golds, preds)}

    timeout = 10.0
    n_workers = 4
    results = defaultdict(list)
    with ThreadPoolExecutor(max_workers=n_workers) as executor:
        print("Reading samples...")
        # One completion per task, hence completion id 0 throughout.
        futures = [
            executor.submit(check_correctness, problems[task_id], code, timeout, 0)
            for task_id, code in tqdm(completions.items())
        ]

        print("Running test suites...")
        for future in tqdm(as_completed(futures), total=len(futures)):
            result = future.result()
            logger.debug(result)
            results[result["task_id"]].append((result["completion_id"], result))

    if not results:
        return 1.0

    total, correct = [], []
    for task_results in results.values():
        task_results.sort(key=lambda pair: pair[0])
        passed = [entry[1]["passed"] for entry in task_results]
        total.append(len(passed))
        correct.append(sum(passed))
    total = np.array(total)
    correct = np.array(correct)

    return 1.0 - estimate_pass_at_k(total, correct, 1).mean()

def fallback_token_loss(preds, golds, tokenizer):
    """Token-level mismatch rate between predictions and gold answers.

    Both sides are tokenized and compared position by position over the gold
    tokens; the result is the fraction of (non-padding) gold positions whose
    predicted token differs or is missing. Extra predicted tokens beyond the
    gold width are ignored.

    The previous implementation passed token ids to ``F.cross_entropy`` as if
    they were logits, with an integer target of the same 2-D shape, which
    PyTorch rejects — so the function could not run.

    Args:
        preds: Predicted strings.
        golds: Gold strings, aligned with ``preds``.
        tokenizer: Callable returning a mapping with ``"input_ids"`` (and
            optionally ``"attention_mask"``) tensors.

    Returns:
        A float in [0, 1]; 1.0 when either list is empty.
    """
    if not preds or not golds:
        return 1.0

    pred_ids = tokenizer(preds, padding=True, truncation=True, return_tensors="pt")["input_ids"]
    gold_enc = tokenizer(golds, padding=True, truncation=True, return_tensors="pt")
    gold_ids = gold_enc["input_ids"]

    if "attention_mask" in gold_enc:
        gold_mask = gold_enc["attention_mask"].bool()
    else:
        gold_mask = torch.ones_like(gold_ids, dtype=torch.bool)

    width = gold_ids.shape[1]
    pred_ids = pred_ids[:, :width].to(gold_ids.device)
    if pred_ids.shape[1] < width:
        # Pad short predictions with -1, a value no token id can equal.
        filler = pred_ids.new_full((pred_ids.shape[0], width - pred_ids.shape[1]), -1)
        pred_ids = torch.cat([pred_ids, filler], dim=1)

    scored = gold_mask.sum().item()
    if scored == 0:
        return 1.0

    mismatches = ((pred_ids != gold_ids) & gold_mask).sum().item()
    return mismatches / scored

def compute_utility_loss(preds, golds, task_name, tokenizer=None, evaluator=None):
    """Dispatch to the task-specific loss, chosen by substring of ``task_name``.

    Matching is case-insensitive and checked in the order mmlu, gsm8k,
    humaneval; any other task falls back to the token-level loss.

    Raises:
        ValueError: If the fallback is needed but no ``tokenizer`` was given.
    """
    task = task_name.lower()

    if "mmlu" in task:
        return compute_mmlu_loss(preds, golds)
    if "gsm8k" in task:
        return compute_gsm8k_loss(preds, golds)
    if "humaneval" in task:
        return compute_humaneval_loss(preds, golds, evaluator)

    # No dedicated scorer for this task: compare tokens directly.
    if tokenizer is not None:
        return fallback_token_loss(preds, golds, tokenizer)
    raise ValueError("Tokenizer required for fallback token-level loss")

def generate_text(model, tokenizer, prompts, infer_start=False, model_name='gpt-4.1',
                  request_delay=5, max_retries=5):
    """Generate one completion per prompt with a hosted or a local model.

    The hosted path calls ``model.chat.completions.create``. It is used when
    no tokenizer is given, or when a model name is given and ``model`` has no
    ``generate`` method. Otherwise ``model`` is treated as a local model that
    exposes ``generate``. Previously the non-None default ``model_name`` sent
    every locally loaded model down the hosted path.

    Args:
        model: Hosted client or local generation model.
        tokenizer: Tokenizer for the local path; None for hosted clients.
        prompts: List of prompt strings.
        infer_start: Hosted path only; when truthy the request is capped at
            10 tokens instead of 4096.
        model_name: Model identifier sent with hosted requests.
        request_delay: Seconds to sleep before doing any work.
        max_retries: Attempts per prompt on the hosted path.

    Returns:
        A list with exactly one string per prompt. A hosted prompt that
        fails on every attempt yields "" so outputs stay aligned with
        ``prompts``.
    """
    if request_delay:
        time.sleep(request_delay)

    use_hosted = tokenizer is None or (
        model_name is not None and not hasattr(model, "generate")
    )

    if use_hosted:
        responses = []
        for prompt in prompts:
            for attempt in range(1, max_retries + 1):
                try:
                    response = model.chat.completions.create(
                        model=model_name,
                        max_tokens=10 if infer_start else 4096,
                        messages=[
                            {"role": "user", "content": prompt},
                        ]
                    )
                    text = response.choices[0].message.content
                    print(text)
                    responses.append(text.replace('Answer:', '').strip())
                    break
                except Exception as e:
                    if attempt == max_retries:
                        # Nothing left to retry, so skip the back-off sleep.
                        logger.error(f"Attempt {attempt}/{max_retries} failed: {e}.")
                        continue
                    # Log the delay that is actually slept.
                    wait_time = 50 * 2 ** attempt
                    logger.error(f"Retry {attempt}/{max_retries} after error: {e}. Waiting {wait_time}s.")
                    time.sleep(wait_time)
            else:
                logger.error("[ERROR: Max retries exceeded]")
                # Placeholder keeps responses[i] paired with prompts[i].
                responses.append("")
        return responses

    inputs = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True).to(DEVICE)
    with torch.no_grad():
        output_ids = model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS)
    # Drop the prompt by token count: the decoded text is not guaranteed to
    # reproduce the prompt character-for-character (truncation, special
    # tokens), so slicing the decoded string by len(prompt) is unreliable.
    # NOTE(review): assumes generate() returns prompt tokens followed by the
    # new tokens, as decoder-only models do -- confirm for other model types.
    prompt_len = inputs["input_ids"].shape[1]
    outputs = tokenizer.batch_decode(output_ids[:, prompt_len:], skip_special_tokens=True)
    return [out.strip() for out in outputs]

def get_shared_embedding(texts, encoder, tokenizer, device=DEVICE):
    """Encode texts with the shared encoder and return first-token embeddings.

    This is best-effort: invalid input, non-finite embeddings, or any
    exception during tokenization or encoding produces an all-zero tensor
    instead of raising.

    Args:
        texts: List of strings to embed.
        encoder: Model whose output exposes ``last_hidden_state`` and whose
            ``config.hidden_size`` gives the width of the zero fallback.
        tokenizer: Tokenizer paired with ``encoder``.
        device: Device the inputs and the fallback tensors are placed on.

    Returns:
        ``last_hidden_state[:, 0]`` for the batch, or a zero tensor with one
        row per text (one row when ``texts`` is not a list).
    """
    def _zero_embeddings(rows):
        return torch.zeros(rows, encoder.config.hidden_size, device=device)

    usable = bool(texts) and all(isinstance(t, str) for t in texts)
    if not usable:
        return _zero_embeddings(len(texts) if isinstance(texts, list) else 1)

    # Every entry is a str past the guard above; only cap very long inputs.
    clipped = [t[:1000000] for t in texts]

    try:
        encoded = tokenizer(
            clipped,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
            add_special_tokens=True,
            return_attention_mask=True
        ).to(device)

        # Replace ids outside the tokenizer's vocabulary with the unknown token.
        out_of_vocab = encoded['input_ids'] >= tokenizer.vocab_size
        if out_of_vocab.any():
            invalid_ids = encoded['input_ids'][out_of_vocab]
            logger.warning(f"Found {len(invalid_ids)} invalid token IDs: {invalid_ids}")
            encoded['input_ids'][out_of_vocab] = tokenizer.unk_token_id

        if 'attention_mask' not in encoded:
            encoded['attention_mask'] = torch.ones_like(encoded['input_ids'])

        with torch.no_grad():
            cls_vectors = encoder(**encoded).last_hidden_state[:, 0]  # CLS token

            if torch.isnan(cls_vectors).any() or torch.isinf(cls_vectors).any():
                logger.warning("Generated embeddings contain NaN/inf values")
                return _zero_embeddings(len(texts))

            return cls_vectors

    except Exception as e:
        logger.error(f"Embedding failed: {str(e)}")
        return _zero_embeddings(len(texts))

def evaluate(preds, golds):
    """Compute mean exact match and mean token-overlap F1.

    Strings are stripped and lower-cased before comparison. F1 is computed
    on the sets of whitespace-separated tokens as
    ``2 * |common| / (|pred| + |gold|)``.

    Args:
        preds: Sequence of predicted strings.
        golds: Sequence of reference strings, aligned with ``preds``.

    Returns:
        ``(exact_match, f1)`` averaged over ``len(preds)``; ``(0.0, 0.0)``
        when ``preds`` is empty (previously a ZeroDivisionError).
    """
    count = len(preds)
    if count == 0:
        return 0.0, 0.0

    em_total, f1_total = 0.0, 0.0
    for pred, gold in zip(preds, golds):
        pred, gold = pred.strip().lower(), gold.strip().lower()
        em_total += int(pred == gold)
        pred_tokens = set(pred.split())
        gold_tokens = set(gold.split())
        common = pred_tokens & gold_tokens
        # The epsilon keeps the ratio defined when both strings are empty.
        f1_total += 2 * len(common) / (len(pred_tokens) + len(gold_tokens) + 1e-8)
    return em_total / count, f1_total / count

def construct_few_shot_prompt(current_idx, dataset, num_shots=3):
    """Build a few-shot block from randomly chosen other examples.

    The current example is excluded before the sample is cut to
    ``num_shots``. Previously the sample was cut first and the current
    example skipped afterwards, which could return one shot too few.

    Args:
        current_idx: Index of the example being prompted; never used as a shot.
        dataset: Indexable collection of mappings with "question" and
            "answer" keys; the answer holds reasoning and final answer
            separated by '####'.
        num_shots: Maximum number of examples to include.

    Returns:
        The formatted examples joined by blank lines; "" when no other
        example is available.
    """
    shuffled = torch.randperm(len(dataset)).tolist()
    chosen = [i for i in shuffled if i != current_idx][:num_shots]

    few_shots = []
    for i in chosen:
        example = dataset[i]
        q, a = example["question"], example["answer"]
        r = a.split('####')[0].strip()
        a = a.split('####')[-1].strip()
        few_shots.append(f"#### Example\nQuestion {q}\nReason: {r}\nAnswer: {a}")
    return "\n\n".join(few_shots)

def run_batch_collaborative_inference_selected(
    prompts, task, controller, models, model_tokenizers, tokenizer, encoder,
    device="cuda", top_k=None, threshold=0.0, early_stopping=True
):
    """Run each prompt through the chain of experts chosen by the controller.

    Every expert first produces a short probe answer for the whole batch;
    the shared-encoder embeddings of those probes and of the prompts are fed
    to the controller. For each prompt the ``controller.current_k``
    top-scoring experts are then called in order, each seeing the earlier
    experts' outputs prepended to the prompt.

    Args:
        prompts: List of prompt strings.
        task: Accepted for interface compatibility; not used here.
        controller: Selection module called as ``controller(input_emb,
            shared_reps)`` and exposing ``current_k``.
        models: Expert models or hosted clients.
        model_tokenizers: Per-expert tokenizers (None for hosted experts).
        tokenizer: Tokenizer of the shared encoder.
        encoder: Shared encoder used for all embeddings.
        device: Device the embeddings are moved to.
        top_k: Accepted for interface compatibility; not used here.
        threshold: Accepted for interface compatibility; not used here.
        early_stopping: Stop a chain once an output is more than 80% similar
            (``fuzz.ratio``) to the previous one.

    Returns:
        ``(final_outputs, selection_info, all_selected_models,
        stopping_points)``; ``stopping_points`` is None when
        ``early_stopping`` is False.
    """
    hosted = EXPERT_TYPES == "hosted"

    input_emb = get_shared_embedding(prompts, encoder, tokenizer).to(device)

    if hosted:
        shared_reps = torch.stack([
            get_shared_embedding(generate_text(m, t, prompts, True, n), encoder, tokenizer)
            for (n, m, t) in zip(MODEL_NAMES, models, model_tokenizers)
        ], dim=1).to(device)
    else:
        # model_name=None is passed explicitly: generate_text's default model
        # name is not None and would route a local model to the hosted API.
        shared_reps = torch.stack([
            get_shared_embedding(generate_text(m, t, prompts, True, None), encoder, tokenizer)
            for (m, t) in zip(models, model_tokenizers)
        ], dim=1).to(device)

    controller.eval()
    with torch.no_grad():
        C_soft, seq_gumbel, selection_info = controller(input_emb, shared_reps)

    final_outputs = []
    all_selected_models = []
    stopping_points = []

    for i, current_prompt in enumerate(prompts):
        selected_models = []
        previous_output = None
        reasoning_chain = []

        _, topk_indices = torch.topk(seq_gumbel[i], k=controller.current_k)

        for k, model_idx in enumerate(topk_indices):
            model_idx = model_idx.item()
            expert = models[model_idx]
            # Separate name: the `tokenizer` argument is the shared encoder's.
            expert_tokenizer = model_tokenizers[model_idx]

            full_input = current_prompt
            if reasoning_chain:
                full_input = "\n\n".join(reasoning_chain) + "\n\n" + current_prompt

            # The model name must go in by keyword. Passed positionally it
            # lands in `infer_start`, which caps the answer at 10 tokens and
            # leaves the request on the default model.
            generations = generate_text(
                expert, expert_tokenizer, [full_input],
                model_name=MODEL_NAMES[model_idx] if hosted else None,
            )
            # A failed generation may come back as an empty list.
            output = generations[0] if generations else ""
            reasoning_chain.append(f"Assistant {k+1} (Model {model_idx}): {output}")

            if previous_output is not None and early_stopping:
                string_similarity = fuzz.ratio(output, previous_output)/100
                if string_similarity > 0.8:
                    logger.info(f"Early stopping at model {k+1} due to similar outputs")
                    logger.info(f"String similarity: {string_similarity:.2f}")
                    stopping_points.append(k+1)
                    # The chain has converged; the earlier of the two similar
                    # outputs is kept as the final answer.
                    break

            previous_output = output
            selected_models.append(model_idx)

        final_outputs.append(previous_output if previous_output else "")
        all_selected_models.append(selected_models)

    if early_stopping and stopping_points:
        avg_stopping_point = sum(stopping_points)/len(stopping_points)
        logger.info(f"Early stopping triggered {len(stopping_points)} times")
        logger.info(f"Average stopping point: {avg_stopping_point:.1f} models")

    return final_outputs, selection_info, all_selected_models, stopping_points if early_stopping else None

def format_prompt(example, task):
    """Build the prompt text for one dataset example.

    For "mmlu", "mmlupro", "humaneval" and "gsm8k" the prompt starts with
    the shared collaboration instructions, followed by a task-specific
    instruction and the question. Any other task gets a plain
    "Question/Context" prompt.

    Args:
        example: Mapping holding the fields the task needs: "question" plus
            "choices" (mmlu) or "options" (mmlupro); "prompt" (humaneval);
            "question" (gsm8k); "question" and optional "context" otherwise.
        task: Task identifier.

    Returns:
        The prompt string, ending where the model's answer should begin.
    """
    if task == "mmlu":
        return """Above is the conversation history, with the most recent model output at the top.
Each model should carefully read *all previous outputs* and decide how to contribute next.
Your role is to coordinate with earlier outputs by either:
1. Building upon correct reasoning.
2. Correcting or refining mistakes.
3. Adding missing details.
4. Passing an intermediate or final answer if complete.

Always state explicitly what you are doing and why.
Avoid repeating identical reasoning unless you are clarifying or improving it. /no_think

Answer the following question as accurately as possible. Put your final answer as (A), (B), (C), or (D). All questions are single choice.\n\n""" + f"Question: {example['question']}\n(A) {example['choices'][0]}\n(B) {example['choices'][1]}\n(C) {example['choices'][2]}\n(D) {example['choices'][3]}\nAnswer: "
    elif task == "mmlupro":
        # Letter only the options that exist: indexing options[0..9] raises
        # IndexError for a question that carries fewer than ten options.
        # With exactly ten options the output is unchanged.
        option_lines = "".join(
            f"({letter}) {option}\n"
            for letter, option in zip("ABCDEFGHIJ", example["options"])
        )
        return """Above is the conversation history, with the most recent model output at the top.
Each model should carefully read *all previous outputs* and decide how to contribute next.
Your role is to coordinate with earlier outputs by either:
1. Building upon correct reasoning.
2. Correcting or refining mistakes.
3. Adding missing details.
4. Passing an intermediate or final answer if complete.

Always state explicitly what you are doing and why.
Avoid repeating identical reasoning unless you are clarifying or improving it. /no_think

Answer the following question as accurately as possible. Put your final answer as (A), (B), (C), (D), (E), (F), (G), (H), (I), (J). All questions are single choice.\n\n""" + f"Question: {example['question']}\n" + option_lines + "Answer: "
    elif task == "humaneval":
        return """Above is the conversation history, with the most recent model output at the top.
Each model should carefully read *all previous outputs* and decide how to contribute next.
Your role is to coordinate with earlier outputs by either:
1. Building upon correct reasoning.
2. Correcting or refining mistakes.
3. Adding missing details.
4. Passing an intermediate or final answer if complete.

Always state explicitly what you are doing and why.
Avoid repeating identical reasoning unless you are clarifying or improving it. /no_think

Read the function signature and comments. Implement the function only. Put your solution enclosed in backticks.\n\n""" + "Question: ```python\n" + example["prompt"] + "```\nAnswer: ```python\n"
    elif task == "gsm8k":
        # The gold answer is not part of the prompt, so it is not read here;
        # unlabeled examples can therefore be formatted too.
        return """Above is the conversation history, with the most recent model output at the top.
Each model should carefully read *all previous outputs* and decide how to contribute next.
Your role is to coordinate with earlier outputs by either:
1. Building upon correct reasoning.
2. Correcting or refining mistakes.
3. Adding missing details.
4. Passing an intermediate or final answer if complete.

Always state explicitly what you are doing and why. /no_think
Avoid repeating identical reasoning unless you are clarifying or improving it.

Follow the given examples and answer the mathematics problem.\n\n""" + """Problem: There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?
Answer: There are 15 trees originally. Then there were 21 trees after the Grove workers planted some more. So there must have been 21 - 15 = 6 trees that were planted. The answer is 6.
###
Problem: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?
Answer: There are originally 3 cars. Then 2 more cars arrive. Now 3 + 2 = 5 cars are in the parking lot. The answer is 5.
###
Problem: Leah had 32 chocolates and her sister had 42. If they ate 35, how many pieces do they have left in total?
Answer: Originally, Leah had 32 chocolates and her sister had 42. So in total they had 32 + 42 = 74. After eating 35, they had 74 - 35 = 39 pieces left in total. The answer is 39.
###
Problem: Jason had 20 lollipops. He gave Denny some lollipops. Now Jason has 12 lollipops. How many lollipops did Jason give to Denny?
Answer: Jason had 20 lollipops originally. Then he had 12 after giving some to Denny. So he gave Denny 20 - 12 = 8 lollipops. The answer is 8.
###
Problem: Shawn has five toys. For Christmas, he got two toys each from his mom and dad. How many toys does he have now?
Answer: Shawn started with 5 toys. He then got 2 toys each from his mom and dad. So he got 2 * 2 = 4 more toys. Now he has 5 + 4 = 9 toys. The answer is 9.
###
Problem: There were nine computers in the server room. Five more computers were installed each day, from monday to thursday. How many computers are now in the server room?
Answer: There were originally 9 computers. For each day from monday to thursday, 5 more computers were installed. So 4 * 5 = 20 computers were added. Now 9 + 20 = 29 computers are now in the server room. The answer is 29.
###
Problem: Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On wednesday, he lost 2 more. How many golf balls did he have at the end of wednesday?
Answer: Michael started with 58 golf balls. He lost 23 on Tuesday, and lost 2 more on wednesday. So he had 58 - 23 = 35 at the end of Tuesday, and 35 - 2 = 33 at the end of wednesday. The answer is 33.
###
Problem: Olivia has $23. She bought five bagels for $3 each. How much money does she have left?
Answer: Olivia had 23 dollars. She bought 5 bagels for 3 dollars each. So she spent 5 * 3 = 15 dollars. Now she has 23 - 15 = 8 dollars left. The answer is 8.\n\n""" + f"Question: {example['question']}\nAnswer: "
    else:
        context = example.get("context", "")
        return f"Question: {example['question']}\nContext: {context}"

def extract_answer(example, task):
    """Return the gold answer stored in ``example`` for the given task.

    Args:
        example: One dataset row.
        task: Task identifier.

    Returns:
        mmlu: the "answer" field; mmlupro: "answer_index"; humaneval:
        "canonical_solution"; gsm8k: the text after the last '####' in
        "answer"; any other task: the first entry of ``answers["text"]``,
        or "" when that list is empty.
    """
    if task == "mmlu":
        return example["answer"]
    if task == "mmlupro":
        return example["answer_index"]
    if task == "humaneval":
        return example["canonical_solution"]
    if task == "gsm8k":
        # GSM8K rows have no "answers" mapping, so they cannot share the
        # default branch; the final answer follows the '####' marker.
        return example["answer"].split('####')[-1].strip()

    texts = example["answers"]["text"]
    return texts[0] if texts else ""

def inference(args):
    """Run controller-guided collaborative inference and save predictions.

    Loads the experts (hosted clients or local models), the shared encoder
    and the saved controller, runs the dataset through
    ``run_batch_collaborative_inference_selected`` in batches, logs
    selection statistics and writes one CSV row per example.

    Args:
        args: Parsed CLI namespace. Reads ``task``, ``dataset``,
            ``checkpoint``, ``batch_size`` and ``output_dir``.

    Raises:
        ValueError: If ``args.task`` is not a supported task.
    """
    if EXPERT_TYPES == 'hosted':
        # Only the endpoints are echoed. API keys are credentials and must
        # not be written to stdout, where they end up in captured logs.
        print(BASE_URLS)
        models = [
            OpenAI(
                api_key=api_key,
                base_url=base_url
            ) for _, api_key, base_url in zip(MODEL_NAMES, API_KEYS, BASE_URLS)
        ]
        tokenizers = [None] * len(MODEL_NAMES)
    else:
        models = [load_model_quantized(name) if USE_QUANTIZED else AutoModelForCausalLM.from_pretrained(name).to(DEVICE) for name in MODEL_NAMES]
        tokenizers = [AutoTokenizer.from_pretrained(name) for name in MODEL_NAMES]
        for t in tokenizers:
            t.pad_token = t.eos_token
            t.padding_side = 'left'

    shared_tokenizer = AutoTokenizer.from_pretrained(SHARED_ENCODER_NAME)
    shared_encoder = AutoModel.from_pretrained(SHARED_ENCODER_NAME).to(DEVICE)

    if args.task == "mmlu":
        dataset = load_dataset(args.dataset, 'all', split="validation[:]")
    elif args.task == "gsm8k":
        dataset = load_dataset(args.dataset, 'main', split="test[:]")
    else:
        dataset = load_dataset(args.dataset, split="test[:]")

    # The controller comes entirely from the checkpoint. No fresh controller
    # or optimizer is built: nothing is trained during inference.
    controller, checkpoint = CollaborationController.from_saved_state(args.checkpoint)
    controller.to(DEVICE)
    logger.info(f"Loading controller trained for {checkpoint['epoch']} epochs")
    logger.info(f"Current k: {controller.current_k}")
    logger.info(f"Current tau: {controller.tau:.4f}")
    logger.info(f"Selection quality: {controller.selection_quality.item():.4f}")

    final_outputs = []
    gold_answers = []
    all_selected_models = []
    all_stopping_points = []
    prompts = []

    selection_logger = SelectionLogger(MODEL_NAMES)

    for step in tqdm(range(0, len(dataset), args.batch_size)):
        batch_end = min(step + args.batch_size, len(dataset))

        for row in range(step, batch_end):
            example = dataset[row]  # fetch the row once
            if args.task == "squad":
                q = example["question"]
                ctx = example.get("context", "")
                prompt = f"Question: {q}\nContext: {ctx}\nAnswer: "
                answer = example["answers"]["text"][0] if example["answers"]["text"] else "No Answer"
            elif args.task == "mmlu":
                prompt = format_prompt(example, task="mmlu")
                _idx = example["answer"]
                _gold = example["choices"][_idx]
                answer = ["(A)", "(B)", "(C)", "(D)"][_idx] + " " + _gold
            elif args.task == "mmlupro":
                prompt = format_prompt(example, task="mmlupro")
                _idx = example["answer_index"]
                _gold = example["options"][_idx]
                answer = ["(A)", "(B)", "(C)", "(D)", "(E)", "(F)", "(G)", "(H)", "(I)", "(J)"][_idx] + " " + _gold
            elif args.task == "humaneval":
                prompt = format_prompt(example, task="humaneval")
                answer = example["canonical_solution"]
            elif args.task == "gsm8k":
                prompt = format_prompt(example, task="gsm8k")
                answer = example["answer"].split('####')[-1].strip()
            else:
                raise ValueError("Unsupported task")
            prompts.append(prompt)
            print(prompt, answer)
            gold_answers.append(answer)

        outputs, selection_info, selections, stopping_points = run_batch_collaborative_inference_selected(
            prompts=prompts[step:batch_end],
            task=args.task,
            controller=controller,
            models=models,
            model_tokenizers=tokenizers,
            tokenizer=shared_tokenizer,
            encoder=shared_encoder,
            device=DEVICE,
            top_k=controller.current_k
        )

        selection_logger.log_batch(
            batch_indices=range(step, batch_end),
            selection_probs=selection_info['selection_probs'],
            selected_indices=selection_info['selected_indices'],
            phase='inference'
        )

        logger.info(outputs)
        final_outputs.extend(outputs)
        all_selected_models.extend(selections)
        if stopping_points is not None:
            all_stopping_points.extend(stopping_points)
        torch.cuda.empty_cache()

    print(len(final_outputs), len(gold_answers), len(prompts))

    if all_selected_models:  # an empty dataset would divide by zero
        avg_models_used = sum(len(m) for m in all_selected_models)/len(all_selected_models)
        logger.info(f"Average models used: {avg_models_used:.2f}")
    if all_stopping_points:
        avg_stopping_point = sum(all_stopping_points)/len(all_stopping_points)
        logger.info(f"Average stopping point: {avg_stopping_point:.2f}")
        logger.info(f"Early stopping rate: {len(all_stopping_points)/len(all_selected_models):.1%}")

    stats = selection_logger.get_selection_stats()
    logger.info("Model Selection Statistics:")
    for name in MODEL_NAMES:
        logger.info(f"{name}: Selected {stats['total_selections'][name]} times | Avg prob: {stats['average_prob'][name]:.3f}")
    print("Inference completed. Saving results...")

    timestamp_unix = int(time.time())
    if args.output_dir:
        os.makedirs(args.output_dir, exist_ok=True)
    csv_file = os.path.join(args.output_dir, f"{args.task}_predictions.csv.{timestamp_unix}")

    # newline="" lets the csv module control line endings.
    with open(csv_file, mode="w", encoding="utf-8", newline="") as file:
        writer = csv.writer(file)
        if args.task in ['mmlu', 'mmlupro', 'squad', 'gsm8k']:
            writer.writerow(["Context", "Question", "Gold Answer", "Prediction", "Selections"])
        elif args.task == 'humaneval':
            writer.writerow(["task_id", "completion", "canonical", "selections"])
        for i, prompt in enumerate(prompts):
            example = dataset[i]
            # Use the selections accumulated over the whole run. `selections`
            # only holds the final batch, which mislabels every other row.
            picked = all_selected_models[i] if i < len(all_selected_models) else ""

            if args.task == 'humaneval':
                writer.writerow([example['task_id'], final_outputs[i], gold_answers[i], picked])
                continue

            if args.task == "squad":
                ctx = example.get("context", "")
                q = example.get("question", "")
            elif args.task in ("mmlu", "mmlupro"):
                ctx = prompt
                q = example.get("question", "")
            else:
                ctx, q = prompt, ""
            writer.writerow([ctx, q, gold_answers[i], final_outputs[i], picked])

    logger.info(f"Results saved to {csv_file}")

def single_model_inference(args):
    """Run a single expert over the dataset as a baseline and save a CSV.

    The model is the first entry of ``MODEL_NAMES``, which the script entry
    point parses from ``--model_name``. The raw ``args.model_name`` string is
    the unparsed Python literal and is not usable as a model identifier.

    Args:
        args: Parsed CLI namespace. Reads ``task``, ``dataset`` and
            ``batch_size``.

    Raises:
        ValueError: If ``args.task`` is not a supported task.
    """
    model_name = MODEL_NAMES if isinstance(MODEL_NAMES, str) else MODEL_NAMES[0]
    logger.info(f"Running single model inference with: {model_name}")

    if EXPERT_TYPES == "hosted":
        tokenizer = None
        # Use the first configured endpoint, matching how inference() builds
        # its clients, in place of the placeholder credentials.
        model = OpenAI(
            base_url=BASE_URLS[0],
            api_key=API_KEYS[0]
        )
    else:
        if USE_QUANTIZED:
            model = load_model_quantized(model_name)
        else:
            model = AutoModelForCausalLM.from_pretrained(model_name).to(DEVICE)
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = 'left'

    if args.task == "mmlu":
        dataset = load_dataset(args.dataset, 'all', split="validation[:]")
    elif args.task == "mmlupro":
        dataset = load_dataset(args.dataset, split="validation")
    elif args.task == "gsm8k":
        dataset = load_dataset(args.dataset, 'main', split="test[:]")
    else:
        dataset = load_dataset(args.dataset, split="test[:]")

    final_outputs = []
    gold_answers = []
    prompts = []

    for step in tqdm(range(0, len(dataset), args.batch_size)):
        batch_end = min(step + args.batch_size, len(dataset))

        batch_prompts = []
        batch_answers = []

        for row in range(step, batch_end):
            example = dataset[row]  # fetch the row once
            if args.task == "squad":
                q = example["question"]
                ctx = example.get("context", "")
                prompt = f"Question: {q}\nContext: {ctx}\nAnswer: "
                answer = example["answers"]["text"][0] if example["answers"]["text"] else "No Answer"
            elif args.task == "mmlu":
                prompt = format_prompt(example, task="mmlu")
                _idx = example["answer"]
                _gold = example["choices"][_idx]
                answer = ["(A)", "(B)", "(C)", "(D)"][_idx] + " " + _gold
            elif args.task == "mmlupro":
                prompt = format_prompt(example, task="mmlupro")
                _idx = example["answer_index"]
                _gold = example["options"][_idx]
                answer = ["(A)", "(B)", "(C)", "(D)", "(E)", "(F)", "(G)", "(H)", "(I)", "(J)"][_idx] + " " + _gold
            elif args.task == "humaneval":
                prompt = format_prompt(example, task="humaneval")
                answer = example["canonical_solution"]
            elif args.task == "gsm8k":
                prompt = format_prompt(example, task="gsm8k")
                answer = example["answer"].split('####')[-1].strip()
            else:
                raise ValueError("Unsupported task")

            batch_prompts.append(prompt)
            batch_answers.append(answer)

        # The model name must go in by keyword. Passed positionally it lands
        # in `infer_start`, which caps the answer at 10 tokens and leaves the
        # request on the default model. Local models get None so that
        # generate_text takes its tokenizer-based path.
        outputs = generate_text(
            model, tokenizer, batch_prompts,
            model_name=model_name if tokenizer is None else None,
        )

        final_outputs.extend(outputs)
        gold_answers.extend(batch_answers)
        prompts.extend(batch_prompts)

        if step == 0:
            for i in range(min(3, len(outputs))):
                logger.info(f"Prompt: {batch_prompts[i][:200]}...")
                logger.info(f"Gold Answer: {batch_answers[i]}")
                logger.info(f"Model Output: {outputs[i]}")
                logger.info("="*50)

    timestamp_unix = int(time.time())
    os.makedirs("tmp", exist_ok=True)  # open() does not create directories
    csv_file = os.path.join("tmp", f"{args.task}_{model_name.replace('/','_')}_predictions.csv.{timestamp_unix}")

    # newline="" lets the csv module control line endings.
    with open(csv_file, mode="w", encoding="utf-8", newline="") as file:
        writer = csv.writer(file)
        if args.task in ['mmlu', 'mmlupro', 'squad', 'gsm8k']:
            writer.writerow(["Context", "Question", "Gold Answer", "Prediction"])
        elif args.task == 'humaneval':
            writer.writerow(["task_id", "completion", "canonical"])

        for i, prompt in enumerate(prompts):
            example = dataset[i]

            if args.task == 'humaneval':
                writer.writerow([example['task_id'], final_outputs[i], gold_answers[i]])
                continue

            if args.task == "squad":
                ctx = example.get("context", "")
                q = example.get("question", "")
            elif args.task in ("mmlu", "mmlupro"):
                ctx = prompt
                q = example.get("question", "")
            else:
                ctx, q = prompt, ""
            writer.writerow([ctx, q, gold_answers[i], final_outputs[i]])

    logger.info(f"Results saved to {csv_file}")

if __name__ == "__main__":
    # Command-line entry point: parse options, set the module-level
    # configuration (API_KEYS, BASE_URLS, MODEL_NAMES, logger) that the
    # functions above read as globals, then dispatch on --mode.
    parser = argparse.ArgumentParser()
    parser.add_argument("--mode", choices=["train", "inference", "single"], required=True)
    parser.add_argument("--checkpoint", default="checkpoints/collab_matrix.pt")
    parser.add_argument("--output_dir", default="checkpoints/")
    parser.add_argument("--strategy", choices=["top1", "weighted_vote", "refinement_chain"], default="top1")
    parser.add_argument("--dataset", type=str, default="squad")
    parser.add_argument("--task", type=str, choices=["squad", "mmlu", "humaneval", "gsm8k", "mmlupro"], default="squad")
    parser.add_argument("--refine_with", choices=["full", "last"], default="last",
                    help="Mode for refinement chain: full history or just last model output")
    parser.add_argument("--logfile", type=str, default="multi_slm_collab.log")
    parser.add_argument("--epochs", type=int, default=5, help="Number of training epochs")
    parser.add_argument("--batch_size", type=int, default=2)
    parser.add_argument("--lambda_symm", type=float, default=0.1)
    parser.add_argument("--lambda_sparse", type=float, default=0.01)
    parser.add_argument("--seq_len", type=int, default=4, help="Initial length of model sequence")
    parser.add_argument("--num_experts", type=int, default=3, help="Number of expert models")
    parser.add_argument("--model_name", type=str, help="Model name for inference")
    parser.add_argument("--hosted_config", type=str, help="BaseURLs and API Keys")
    args = parser.parse_args()

    # Fail with a usage message, not a TypeError from eval(None).
    if args.model_name is None:
        parser.error("--model_name is required (a Python expression such as \"['gpt-4.1']\")")
    if args.hosted_config is None and EXPERT_TYPES == 'hosted':
        parser.error("--hosted_config is required when EXPERT_TYPES is 'hosted'")

    # SECURITY NOTE(review): eval() executes arbitrary Python taken from the
    # command line. It is kept because callers may pass expressions such as
    # '["model"] * 3', which ast.literal_eval rejects. Never forward
    # untrusted input to --hosted_config or --model_name.
    if args.hosted_config is not None:
        API_KEYS, BASE_URLS = eval(args.hosted_config)
    else:
        API_KEYS, BASE_URLS = [], []

    file_handler = logging.FileHandler(args.logfile)
    file_handler.setLevel(logging.DEBUG)
    formatter = coloredlogs.ColoredFormatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler.setFormatter(formatter)
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(formatter)
    logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    logger = logging.getLogger("multi_slm_collab")
    logger.handlers.clear()
    logger.addHandler(file_handler)
    logger.addHandler(console_handler)
    coloredlogs.install()

    # --batch_size, --lambda_symm and --lambda_sparse are used as parsed.
    # They were previously overwritten with hard-coded values equal to the
    # argparse defaults, which silently discarded anything the user passed.

    MODEL_NAMES = eval(args.model_name)
    # The sequence length follows the number of experts, so it replaces
    # whatever --seq_len was given.
    args.seq_len = len(MODEL_NAMES) - 1

    if args.mode == "train":
        train_controller(args)
    elif args.mode == "inference":
        inference(args)
    elif args.mode == "single":
        single_model_inference(args)
