import pandas as pd
import numpy as np
# from fastchat.conversation import Conversation, SeparatorStyle
from transformers import AutoTokenizer
import os
import sys
import srsly
import fire
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
import torch
import time
import json
from utils import CosineSimilarityNet,RiskControlledCalibrator_imp2

import seaborn as sns
import pandas as pd

def get_question_type(sample):
    """Infer the benchmark a sample belongs to from the keys it carries.

    A ``question_id`` key marks WikiQA and a ``choices`` key marks MMLU;
    ``question_id`` is checked first, so a sample with both keys is WikiQA.
    Any sample with neither key is classified as BoolQ.
    """
    for marker_key, dataset_name in (('question_id', 'WikiQA'), ('choices', 'MMLU')):
        if marker_key in sample:
            return dataset_name
    return 'BoolQ'

def analyze_routing_by_type(routes, data):
    """Count routing decisions per question type and print a percentage breakdown.

    Args:
        routes: route names ('large_model', 'small_model' or 'human'),
            aligned element-wise with ``data``; pairing stops at the shorter
            of the two sequences. Any other route name raises ``KeyError``.
        data: samples whose type is inferred with ``get_question_type``.

    Returns:
        dict mapping 'WikiQA', 'MMLU' and 'BoolQ' to a dict holding one count
        per route name plus a 'total' count. Types with no samples are kept
        (all zeros) but are not printed.
    """
    route_names = ['large_model', 'small_model', 'human']
    type_routing = {
        qtype: {**dict.fromkeys(route_names, 0), 'total': 0}
        for qtype in ('WikiQA', 'MMLU', 'BoolQ')
    }

    for route, sample in zip(routes, data):
        counters = type_routing[get_question_type(sample)]
        counters[route] += 1
        counters['total'] += 1

    print("\nRouting Analysis by Question Type:")
    for qtype, counters in type_routing.items():
        total = counters['total']
        if not total:
            continue
        print(f"\n{qtype} Questions (Total: {total}):")
        for route in route_names:
            count = counters[route]
            print(f"  {route}: {count} ({count / total * 100:.1f}%)")

    return type_routing

# Confidence label recorded for each route; any other route name falls back to 'low'.
_CONFIDENCE_BY_ROUTE = {'large_model': 'high', 'small_model': 'medium', 'human': 'low'}


def _mean_pooled_input_embeddings(model, tokenizer, conversations, batch_size,
                                  max_length=768, exclude_padding=False):
    """Embed chat conversations as the mean of their input-token embeddings.

    Each conversation is rendered with the tokenizer's chat template (with the
    generation prompt appended) and tokenized with padding and truncation to
    ``max_length``. The token ids are looked up in the model's input-embedding
    layer only; no transformer layers are run.

    Args:
        model: causal LM whose ``get_input_embeddings()`` layer is used.
        tokenizer: tokenizer matching ``model``; it must define a pad token.
        conversations: list of chat-message lists.
        batch_size: number of conversations tokenized per step.
        max_length: truncation length in tokens.
        exclude_padding: when False (the historical behaviour), padding
            positions are included in the mean. A sample's vector then depends
            on the longest prompt in its batch. When True, only positions with
            ``attention_mask == 1`` are averaged. Use this only with a
            classifier trained on embeddings pooled the same way.

    Returns:
        float32 CPU tensor with one row per conversation.
    """
    embedding_layer = model.get_input_embeddings()
    pooled_batches = []
    for start in tqdm(range(0, len(conversations), batch_size), desc="Processing batches"):
        prompts = tokenizer.apply_chat_template(
            conversations[start:start + batch_size],
            tokenize=False,
            add_generation_prompt=True,
        )
        encoded = tokenizer(prompts, return_tensors="pt", padding=True,
                            truncation=True, max_length=max_length)
        input_ids = encoded["input_ids"].to(model.device)
        with torch.no_grad():
            token_embeddings = embedding_layer(input_ids)
            if exclude_padding:
                mask = encoded["attention_mask"].to(token_embeddings.device)
                mask = mask.unsqueeze(-1).to(torch.float32)
                summed = (token_embeddings.to(torch.float32) * mask).sum(dim=1)
                # clamp guards against division by zero for an all-padding row.
                pooled = summed / mask.sum(dim=1).clamp(min=1.0)
            else:
                pooled = torch.mean(token_embeddings, dim=1)
        # Pooling runs in the model dtype; the classifier expects float32 input.
        pooled_batches.append(pooled.to(torch.float32).cpu())
    return torch.cat(pooled_batches)


def _predict_risk_scores(classifier, embeddings, batch_size):
    """Run ``classifier`` over ``embeddings`` in batches and keep output column 1.

    Column 1 of the classifier output is what the rest of the script treats as
    the per-sample risk score.

    Returns:
        1-D numpy array with one score per embedding row.
    """
    score_batches = []
    with torch.no_grad():
        for start in tqdm(range(0, len(embeddings), batch_size), desc="Making predictions"):
            output = classifier(embeddings[start:start + batch_size])
            score_batches.append(output[:, 1].cpu().numpy())
    return np.concatenate(score_batches)


def _routes_from_masks(large_model_mask, small_model_mask):
    """Convert per-sample boolean masks into route names.

    The large-model mask takes precedence over the small-model mask. Samples
    selected by neither mask are routed to 'human'.
    """
    return [
        'large_model' if use_large else ('small_model' if use_small else 'human')
        for use_large, use_small in zip(large_model_mask, small_model_mask)
    ]


def _distribution_stats(values):
    """Return mean/std/min/max of ``values`` as plain Python floats."""
    arr = np.asarray(values, dtype=np.float64)
    return {
        "mean": float(arr.mean()),
        "std": float(arr.std()),
        "min": float(arr.min()),
        "max": float(arr.max()),
    }


def main(
        model_path="/home/personalized_RAG/model/qwen2.5-1.5b",
        test_file="/home/data/test_safe_token.json",
        out_path="/home/generate_embd/qwen2.5-1.5b_emb_res_test.pt",
        safe_model="/root/safe_proj/gen_embd/qwen2.5-0.5b_emb_res_20000_cosine_similarity_net.pt",
        save_path="/home/generate_embd/safe_model_pred.json",
        safety_weight=0.4,
        efficiency_weight=0.2,
        time_weight=0.4,
        val_size=None,
        batch_size=256,
        max_length=768,
        exclude_padding=False,
        ):
    """Score questions with a safety classifier, calibrate thresholds and route them.

    Pipeline:
      1. Read ``test_file`` (a JSON list of samples with a ``question`` key) and
         wrap each question in a system/user chat.
      2. Mean-pool the base model's input-token embeddings per question.
      3. Score each embedding with the ``CosineSimilarityNet`` loaded from
         ``safe_model``; output column 1 is used as the risk score.
      4. Calibrate ``RiskControlledCalibrator_imp2`` on the first ``val_size``
         samples and their ``safe`` labels. Then route every sample to
         'large_model', 'small_model' or 'human'.
      5. Annotate each sample and write the samples, plus summary statistics,
         to ``save_path`` as JSON.

    Args:
        model_path: model directory providing the tokenizer and input embeddings.
        test_file: JSON file containing the samples.
        out_path: currently unused; kept so existing command lines still parse.
        safe_model: state-dict checkpoint for ``CosineSimilarityNet``. Its input
            dimension is taken from the embeddings, so the checkpoint must have
            been trained on embeddings of the same width as ``model_path``'s.
        save_path: destination of the JSON results.
        safety_weight, efficiency_weight, time_weight: calibrator objective
            weights. These used to be ignored in favour of hard-coded
            0.4/0.2/0.4. The defaults now equal those values, so default runs
            behave as before.
        val_size: number of leading samples used for calibration (default 3000).
            Only these samples need a ``safe`` label.
        batch_size: batch size for both embedding and scoring.
        max_length: tokenizer truncation length.
        exclude_padding: drop padding positions from the embedding mean (see
            ``_mean_pooled_input_embeddings``). Defaults to the historical
            padding-inclusive behaviour.

    Raises:
        ValueError: if ``test_file`` contains no samples or ``val_size`` is not
            positive.
    """
    system_prompt = "You are now a helpful personal AI assistant."

    # Read the data first so an empty file fails before the model is loaded.
    data = srsly.read_json(test_file)
    if not data:
        raise ValueError(f"No samples found in {test_file}")

    conversations = [
        [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": sample["question"]},
        ]
        for sample in data
    ]

    model = AutoModelForCausalLM.from_pretrained(model_path,
                                                 torch_dtype=torch.float16,
                                                 device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    if tokenizer.pad_token is None:
        # padding=True needs a pad token; fall back to EOS for tokenizers without one.
        tokenizer.pad_token = tokenizer.eos_token

    input_data = _mean_pooled_input_embeddings(
        model, tokenizer, conversations, batch_size,
        max_length=max_length, exclude_padding=exclude_padding,
    )

    safe_model_net = CosineSimilarityNet(input_dim=input_data.shape[1])
    # The classifier runs on CPU; map_location lets a GPU-saved checkpoint load anywhere.
    safe_model_net.load_state_dict(torch.load(safe_model, map_location="cpu"))
    safe_model_net.eval()

    print(f"Input data shape: {input_data.shape}")
    predictions = _predict_risk_scores(safe_model_net, input_data, batch_size)

    print(f"Total predictions: {len(predictions)}")
    print(f"Sample of predictions: {predictions[:5].tolist()}")
    print(f"Predictions stats: min={predictions.min()}, max={predictions.max()}, mean={np.mean(predictions)}")

    calibrator = RiskControlledCalibrator_imp2(
        alpha=0.05,
        delta=0.02,
        num_thresholds=100,
        safety_weight=safety_weight,
        efficiency_weight=efficiency_weight,
        time_weight=time_weight,
        temperature=7.5,
        shift=0.1,
    )

    val_size = 3000 if val_size is None else int(val_size)
    if val_size <= 0:
        raise ValueError(f"val_size must be positive, got {val_size}")

    # NOTE: the calibration samples are also routed and counted in the statistics
    # below, so those statistics are not computed on held-out data.
    val_predictions = np.array(predictions[:val_size])
    val_labels = np.asarray([item['safe'] for item in data[:val_size]], dtype=np.int64)

    print(f"Validation set size: {len(val_predictions)}")

    tau1, tau2 = calibrator.calibrate(val_predictions, val_labels)

    # Fresh copies are passed to each calibrator call, in case a method mutates its input.
    large_model_mask, small_model_mask, _ = calibrator.predict(np.array(predictions))
    routes = _routes_from_masks(large_model_mask, small_model_mask)

    print(f"Routes generated: {len(routes)}")
    print(f"Sample of routes: {routes[:5]}")

    transformed_predictions = calibrator.transform_predictions(np.array(predictions))
    risk_metrics = [
        {
            'original_risk': float(orig_pred),
            'transformed_risk': float(transformed_pred),
            'risk_delta': float(transformed_pred - orig_pred),
        }
        for orig_pred, transformed_pred in zip(predictions, transformed_predictions)
    ]

    for sample, pred, route, risk_metric in zip(data, predictions, routes, risk_metrics):
        sample['pred'] = float(pred)
        sample['route'] = str(route)
        sample['confidence'] = _CONFIDENCE_BY_ROUTE.get(route, 'low')
        sample['risk_assessment'] = risk_metric

    routing_analysis = analyze_routing_by_type(routes, data)

    route_risks = {'large_model': [], 'small_model': [], 'human': []}
    for route, risk in zip(routes, risk_metrics):
        route_risks[route].append(risk['transformed_risk'])

    calibration_info = {
        "thresholds": {
            "tau1": float(tau1),
            "tau2": float(tau2),
        },
        "routing_stats": {
            "large_model_count": len(route_risks['large_model']),
            "small_model_count": len(route_risks['small_model']),
            "human_review_count": len(route_risks['human']),
        },
        "risk_distribution": {
            "original": _distribution_stats([r['original_risk'] for r in risk_metrics]),
            "transformed": _distribution_stats([r['transformed_risk'] for r in risk_metrics]),
        },
        "risk_thresholds": {
            "high_risk": float(tau2),
            "medium_risk": float(tau1),
        },
        "question_type_analysis": {
            qtype: {
                'route_distribution': {
                    route: count
                    for route, count in stats.items()
                    if route != 'total'
                },
                'total_questions': stats['total'],
            }
            for qtype, stats in routing_analysis.items()
        },
        "route_risk_analysis": {
            route: {
                'mean_risk': float(np.mean(risks)) if risks else 0,
                'std_risk': float(np.std(risks)) if risks else 0,
                'sample_count': len(risks),
            }
            for route, risks in route_risks.items()
        },
    }

    output_data = {
        "calibration_info": calibration_info,
        "predictions": data,
    }

    with open(save_path, "w", encoding="utf-8") as f:
        json.dump(output_data, f, ensure_ascii=False, indent=4)

    original_dist = calibration_info['risk_distribution']['original']
    transformed_dist = calibration_info['risk_distribution']['transformed']
    print("\nRisk Analysis:")
    print("Original Risk Distribution:")
    print(f"  Mean: {original_dist['mean']:.4f}")
    print(f"  Std:  {original_dist['std']:.4f}")
    print("\nTransformed Risk Distribution:")
    print(f"  Mean: {transformed_dist['mean']:.4f}")
    print(f"  Std:  {transformed_dist['std']:.4f}")

    print("\nRoute-wise Risk Analysis:")
    for route, stats in calibration_info['route_risk_analysis'].items():
        print(f"{route}:")
        print(f"  Mean Risk: {stats['mean_risk']:.4f}")
        print(f"  Std Risk:  {stats['std_risk']:.4f}")
        print(f"  Samples:   {stats['sample_count']}")

if __name__ == "__main__":
    # Hand main() to python-fire so its parameters can be set from the command line.
    fire.Fire(component=main)