import torch
import torch.nn as nn
import numpy as np
import argparse
import os
import csv
import sys
import random
import string
import warnings
from datetime import datetime
from tqdm import tqdm
from sklearn.cluster import SpectralClustering
from sklearn.preprocessing import StandardScaler, LabelBinarizer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score, accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.metrics.pairwise import cosine_similarity
from transformers import T5Tokenizer, T5EncoderModel
from datasets import load_dataset

# Suppress all warnings process-wide (this also hides library deprecation notices).
warnings.simplefilter("ignore")

# ==============================================================================
# Module 1: noise injection utilities
# ==============================================================================
class NoiseManager:
    """Stateless helpers that perturb raw text or feature matrices."""

    @staticmethod
    def word_swap(text, n_swaps=1):
        """Swap ``int(n_swaps)`` randomly chosen pairs of adjacent words.

        Words are split on whitespace and re-joined with single spaces. Texts
        with fewer than two words are returned unchanged.
        """
        words = text.split()
        if len(words) < 2:
            return text
        # The word count never changes inside the loop, so one guard suffices.
        for _ in range(int(n_swaps)):
            idx = random.randint(0, len(words) - 2)
            words[idx], words[idx + 1] = words[idx + 1], words[idx]
        return " ".join(words)

    @staticmethod
    def char_noise(text, n_errors=1):
        """Corrupt ``int(n_errors)`` distinct non-whitespace characters.

        Each chosen character is replaced by a *different* random ASCII letter,
        so exactly ``min(int(n_errors), #non-whitespace chars)`` positions
        change. Whitespace is never touched, which keeps word boundaries.
        Non-positive ``n_errors`` leaves the text unchanged.
        """
        chars = list(text)
        candidates = [i for i, ch in enumerate(chars) if not ch.isspace()]
        n = min(max(int(n_errors), 0), len(candidates))
        # Sample without replacement so no position is corrupted twice.
        for idx in random.sample(candidates, n):
            original = chars[idx]
            chars[idx] = random.choice([c for c in string.ascii_letters if c != original])
        return "".join(chars)

    @staticmethod
    def add_gaussian_noise(X, std, rng=None):
        """Return ``X`` plus i.i.d. N(0, std**2) noise of the same shape.

        Args:
            X: numeric array to perturb (not modified in place).
            std: noise standard deviation; ``X`` itself is returned when <= 0.
            rng: optional ``numpy.random.Generator`` for reproducible noise.
                When None, the global ``np.random`` state is used.
        """
        if std <= 0:
            return X
        if rng is None:
            noise = np.random.normal(0, std, X.shape)
        else:
            noise = rng.normal(0, std, X.shape)
        return X + noise

# ==============================================================================
# Module 2: data management (T5 encoder features cached per split/noise/limit)
# ==============================================================================
class DataManager:
    """Encode dataset splits with a T5 encoder and cache the features on disk.

    Each cache file is a ``torch.save`` dict holding ``'activations'`` (the
    encoder hidden state at the first token position for each example) and
    ``'labels'``.
    """

    def __init__(self, model_name="t5-small", device="cuda"):
        self.model_name = model_name
        # Fall back to CPU whenever CUDA is unavailable, whatever was requested.
        self.device = device if torch.cuda.is_available() else "cpu"
        self.tokenizer = None
        self.model = None

    def _load_model(self):
        """Lazily load tokenizer and encoder (only needed when a cache is missing)."""
        if self.model is None:
            print(f"[*] Loading T5 Model: {self.model_name} on {self.device}...")
            self.tokenizer = T5Tokenizer.from_pretrained(self.model_name)
            self.model = T5EncoderModel.from_pretrained(self.model_name).to(self.device).eval()

    @staticmethod
    def _has_limit(max_samples):
        """Return True when ``max_samples`` requests subsampling (None or <= 0 means all)."""
        return max_samples is not None and max_samples > 0

    @staticmethod
    def _format_text(dname, item):
        """Build the encoder input string for one dataset row."""
        if dname == 'ag_news':
            return item['text']
        if dname == 'mnli':
            return f"mnli hypothesis: {item['hypothesis']} premise: {item['premise']}"
        if dname == 'qnli':
            # GLUE QNLI rows carry 'question'/'sentence', not 'hypothesis'/'premise'.
            return f"qnli question: {item['question']} sentence: {item['sentence']}"
        if 'sentence1' in item:
            # Sentence-pair GLUE tasks (e.g. rte, mrpc), T5-style prefix.
            return f"{dname} sentence1: {item['sentence1']} sentence2: {item['sentence2']}"
        if 'question1' in item:
            # Question-pair GLUE tasks (e.g. qqp), T5-style prefix.
            return f"{dname} question1: {item['question1']} question2: {item['question2']}"
        return item['sentence']

    def get_path(self, dataset, split, n_type, n_level, max_samples):
        """Return the cache file path for one (dataset, split, noise, limit, model) combination.

        The file name carries the sample limit so a subsampled cache is never
        mistaken for the full dataset. ``max_samples`` of None or <= 0 means
        "all samples".
        """
        os.makedirs("./data_cache_robust", exist_ok=True)
        suffix = "clean"
        # Text noise with fewer than one whole operation leaves the text clean.
        if n_type in ['word_swap', 'char_noise'] and int(n_level) > 0:
            suffix = f"{n_type}_{int(n_level)}"
        limit_str = f"_limit{max_samples}" if self._has_limit(max_samples) else ""
        # Hub ids such as "google/t5-v1_1-small" contain '/', which would point
        # torch.save at a subdirectory that does not exist.
        model_tag = self.model_name.replace("/", "_")
        return os.path.join("./data_cache_robust", f"{dataset}_{split}_{suffix}{limit_str}_{model_tag}.pt")

    def ensure_data(self, dataset, split, n_type='clean', n_level=0, max_samples=-1):
        """Return the cache path for the requested features, encoding them first if missing.

        Non-text noise types (e.g. 'gaussian') and levels that truncate to zero
        noise operations map to the clean cache, since the encoded text would
        be identical.
        """
        file_n_type = n_type if n_type in ['word_swap', 'char_noise'] else 'clean'
        if int(n_level) <= 0:
            file_n_type = 'clean'

        path = self.get_path(dataset, split, file_n_type, n_level, max_samples)

        if not os.path.exists(path):
            self._generate(dataset, split, file_n_type, n_level, path, max_samples)
        return path

    def _generate(self, dname, split, n_type, n_level, path, max_samples):
        """Load a split, optionally subsample and text-noise it, encode it, and save to ``path``.

        Raises:
            ValueError: if the (possibly subsampled) split has no rows.
        """
        self._load_model()
        print(f"[*] Generating {dname} [{split}] Noise={n_type}({n_level}) Limit={max_samples}...")

        if dname == 'ag_news':
            d = load_dataset(dname, split=split)
        else:
            d = load_dataset("glue", dname, split=split)

        # Subsample before encoding so large datasets are never fully traversed.
        if self._has_limit(max_samples) and len(d) > max_samples:
            print(f"    -> Subsampling from {len(d)} to {max_samples}")
            # Fixed-seed shuffle keeps the subset reproducible (and identical
            # between the clean and noisy versions of the same split).
            d = d.shuffle(seed=42).select(range(max_samples))

        if len(d) == 0:
            raise ValueError(f"Dataset {dname} [{split}] is empty; nothing to encode.")

        acts, labs, batch_txt, batch_lbl = [], [], [], []
        last = len(d) - 1

        for i, item in enumerate(tqdm(d)):
            txt = self._format_text(dname, item)

            if n_type == 'word_swap':
                txt = NoiseManager.word_swap(txt, n_level)
            elif n_type == 'char_noise':
                txt = NoiseManager.char_noise(txt, n_level)

            batch_txt.append(txt)
            batch_lbl.append(item['label'])

            # Encode in batches of 32, flushing the final partial batch.
            if len(batch_txt) >= 32 or i == last:
                inp = self.tokenizer(batch_txt, return_tensors="pt", padding=True, truncation=True, max_length=128).to(self.device)
                with torch.no_grad():
                    hidden = self.model(inp.input_ids, inp.attention_mask).last_hidden_state
                # Keeps only the hidden state at token position 0 (no pooling).
                acts.append(hidden[:, 0, :].cpu())
                labs.extend(batch_lbl)
                batch_txt, batch_lbl = [], []

        torch.save({'activations': torch.cat(acts, 0), 'labels': torch.tensor(labs)}, path)
        print(f"[*] Saved to {path}")

def load_pt(path):
    """Load a feature cache written by ``DataManager`` and return NumPy arrays.

    Returns:
        (activations, labels) as NumPy arrays, converted from the
        ``'activations'`` and ``'labels'`` tensors stored at ``path``.
    """
    payload = torch.load(path, weights_only=False)
    features = payload['activations'].numpy()
    targets = payload['labels'].numpy()
    return features, targets

# ==============================================================================
# Module 3: learned router
# ==============================================================================
class LearnedRouter:
    """Gate that mixes MoE experts per sample.

    ``fit`` labels each training sample with the expert that assigns the
    highest probability (lowest NLL) to its true label, then trains a
    logistic-regression gate on the full feature vector to predict that
    expert. ``get_gates`` keeps the ``top_k`` largest gate probabilities per
    sample and renormalizes them.
    """

    def __init__(self, n_experts, top_k=1):
        self.n_experts = n_experts
        self.top_k = top_k
        self.gate_model = None
        # Used instead of gate_model when every training sample picked the same
        # expert: LogisticRegression cannot be fit on a single class.
        self._constant_expert = None

    def fit(self, X_train, y_train, experts, feature_masks):
        """Train the gate.

        Args:
            X_train: training features, one row per sample.
            y_train: training labels.
            experts: per-expert fitted classifiers (``None`` for an empty
                feature group) exposing ``classes_`` and ``predict_proba``.
            feature_masks: per-expert column indices into ``X_train``.
        """
        y_train = np.asarray(y_train)
        n_samples = X_train.shape[0]
        # Missing experts keep an infinite loss and therefore never win.
        losses = np.full((n_samples, self.n_experts), np.inf)
        rows = np.arange(n_samples)
        for i, clf in enumerate(experts):
            if clf is None:
                continue
            probs = clf.predict_proba(X_train[:, feature_masks[i]])
            # predict_proba columns follow clf.classes_, so map label values to
            # column positions instead of using the raw label as an index.
            classes = np.asarray(clf.classes_)
            cols = np.clip(np.searchsorted(classes, y_train), 0, len(classes) - 1)
            true_probs = np.where(classes[cols] == y_train, probs[rows, cols], 0.0)
            losses[:, i] = -np.log(true_probs + 1e-9)
        best_expert_labels = np.argmin(losses, axis=1)

        winners = np.unique(best_expert_labels)
        if winners.size < 2:
            self.gate_model = None
            self._constant_expert = int(winners[0]) if winners.size else 0
            return
        self._constant_expert = None
        # NOTE(review): `multi_class` is deprecated in recent scikit-learn
        # releases; it is kept because dropping it changes binary-gate fits.
        self.gate_model = LogisticRegression(C=1.0, solver='lbfgs', multi_class='multinomial', max_iter=200, random_state=42)
        self.gate_model.fit(X_train, best_expert_labels)

    def get_gates(self, X):
        """Return an ``(n_samples, n_experts)`` array of mixing weights.

        Only the ``top_k`` highest-scoring experts per row get non-zero weight;
        those weights are renormalized to sum to ~1. ``top_k < 1`` keeps all
        experts.

        Raises:
            RuntimeError: if called before ``fit``.
        """
        n_samples = X.shape[0]
        if self.gate_model is None:
            constant = getattr(self, '_constant_expert', None)
            if constant is None:
                raise RuntimeError("LearnedRouter.get_gates called before fit")
            gates = np.zeros((n_samples, self.n_experts))
            gates[:, constant] = 1.0
            return gates

        proba = self.gate_model.predict_proba(X)
        # The gate only knows experts that won at least once during fit, so
        # scatter its columns into the full expert axis via classes_.
        scores = np.zeros((n_samples, self.n_experts))
        scores[:, np.asarray(self.gate_model.classes_, dtype=int)] = proba

        k = self.n_experts if self.top_k < 1 else min(self.top_k, self.n_experts)
        top_idx = np.argsort(scores, axis=1)[:, -k:]
        top_vals = np.take_along_axis(scores, top_idx, axis=1)
        weights = np.zeros_like(scores)
        np.put_along_axis(weights, top_idx, top_vals / (top_vals.sum(axis=1, keepdims=True) + 1e-9), axis=1)
        return weights

# ==============================================================================
# Module 4: MoE probe
# ==============================================================================
class MoEProbe:
    """Mixture-of-experts linear probe over disjoint feature groups.

    Expert ``i`` is a logistic regression trained only on the feature columns
    whose cluster id in ``feature_labels`` equals ``i``. A ``LearnedRouter``
    mixes the experts' class probabilities per sample.
    """

    def __init__(self, n_experts, feature_labels, class_weight='balanced'):
        self.n_experts = n_experts
        # asarray so `feature_labels == i` is element-wise even for list input.
        self.feature_labels = np.asarray(feature_labels)
        self.class_weight = class_weight
        self.experts = []
        self.feature_masks = []
        self.router = None

    def train(self, X_train, y_train, C_reg, top_k):
        """Fit one L2 logistic-regression expert per feature group, then the router.

        Args:
            X_train: training features, one row per sample.
            y_train: training labels.
            C_reg: inverse regularization strength for every expert.
            top_k: number of experts the router mixes per sample.
        """
        self.experts, self.feature_masks = [], []
        self.router = LearnedRouter(self.n_experts, top_k)
        for i in range(self.n_experts):
            mask = np.flatnonzero(self.feature_labels == i)
            self.feature_masks.append(mask)
            if mask.size == 0:
                # No feature was assigned to this cluster; keep a placeholder.
                self.experts.append(None)
                continue
            clf = LogisticRegression(penalty='l2', C=C_reg, solver='liblinear',
                                     class_weight=self.class_weight, random_state=42)
            clf.fit(X_train[:, mask], y_train)
            self.experts.append(clf)
        self.router.fit(X_train, y_train, self.experts, self.feature_masks)

    def predict(self, X_test):
        """Return predicted class labels (values from the experts' ``classes_``).

        Raises:
            RuntimeError: if no expert has been fitted.
        """
        weights = self.router.get_gates(X_test)
        fitted = [clf for clf in self.experts if clf is not None]
        if not fitted:
            raise RuntimeError("MoEProbe.predict called without any fitted expert")
        # Every expert was fit on the same y_train, so all share this classes_.
        classes = np.asarray(fitted[0].classes_)
        final_probs = np.zeros((X_test.shape[0], len(classes)))
        for i, clf in enumerate(self.experts):
            # Skip experts the router never selects for this batch.
            if clf is not None and np.sum(weights[:, i]) > 1e-9:
                final_probs += clf.predict_proba(X_test[:, self.feature_masks[i]]) * weights[:, i:i+1]
        # Map probability column positions back to the actual label values.
        return classes[np.argmax(final_probs, axis=1)]

# ==============================================================================
# Module 5: feature selectors (cluster feature columns into expert groups)
# ==============================================================================
class BaseSelector:
    """Interface for strategies that assign every feature column a cluster id."""

    def fit_predict(self, X, y, n_clusters, **kwargs):
        """Return one cluster id per column of ``X``; subclasses must override."""
        raise NotImplementedError

class ConstraintSpectralSelector(BaseSelector):
    """Spectral clustering of features on cosine affinity reweighted by Fisher scores.

    Two features get high affinity when their columns are positively
    correlated in direction (cosine similarity, negatives clipped to 0) and
    their per-feature Fisher scores are close (``exp(-|f_i - f_j|)``).
    """

    def fit_predict(self, X, y, n_clusters, **kwargs):
        print("[Selector] Fisher Score + Spectral...")
        grand_mean = X.mean(axis=0)
        between_terms, within_terms = [], []
        for cls in np.unique(y):
            members = X[y == cls]
            count = np.sum(y == cls)
            between_terms.append(count * (members.mean(axis=0) - grand_mean) ** 2)
            within_terms.append(count * members.var(axis=0))
        # Per-feature Fisher score: between-class over within-class scatter.
        fisher = np.sum(between_terms, axis=0) / (np.sum(within_terms, axis=0) + 1e-9)

        cosine_part = np.maximum(cosine_similarity(X.T), 0)
        closeness = np.exp(-np.abs(fisher[:, None] - fisher[None, :]))
        clusterer = SpectralClustering(n_clusters=n_clusters, affinity='precomputed',
                                       random_state=42, n_jobs=-1)
        return clusterer.fit_predict(cosine_part * closeness)

class ProxyProbeSelector(BaseSelector):
    """Cluster features by the cosine similarity of their L1-probe coefficient profiles."""

    def fit_predict(self, X, y, n_clusters, class_weight='balanced', **kwargs):
        print("[Selector] Lasso Proxy...")
        lasso = LogisticRegression(penalty='l1', C=1.0, solver='liblinear',
                                   random_state=42, class_weight=class_weight)
        lasso.fit(X, y)
        # One row per feature: its coefficients across the probe's outputs.
        # NOTE(review): for a binary target coef_ has a single row, so each
        # profile is a scalar and cosine similarity collapses to sign agreement
        # (+1 / clipped to 0); confirm this is the intended affinity.
        feature_profiles = lasso.coef_.T + 1e-9
        affinity = np.maximum(cosine_similarity(feature_profiles), 0)
        clusterer = SpectralClustering(n_clusters=n_clusters, affinity='precomputed',
                                       random_state=42, n_jobs=-1)
        return clusterer.fit_predict(affinity)

# Maps the --selector CLI value to its clustering strategy ('lasso' is handled directly in main()).
SELECTOR_MAP = dict(
    constraint=ConstraintSpectralSelector,
    proxy=ProxyProbeSelector,
)

# ==============================================================================
# Module 6: main pipeline
# ==============================================================================
def get_metrics(y_true, y_pred, dataset):
    """Return ``(headline_score, accuracy, weighted_f1)``.

    The headline score is weighted F1 for 'cola' and accuracy otherwise.
    NOTE(review): GLUE's official CoLA metric is Matthews correlation, not F1.
    """
    metrics = {
        'acc': accuracy_score(y_true, y_pred),
        'f1': f1_score(y_true, y_pred, average='weighted'),
    }
    primary = metrics['f1'] if dataset == 'cola' else metrics['acc']
    return primary, metrics['acc'], metrics['f1']

# Column order of the results CSV (shared by Lasso and MoE rows).
_RESULT_FIELDS = ['ts', 'dataset', 'method', 'n_experts', 'top_k', 'C', 'noise_type',
                  'noise_lvl', 'clean_score', 'noisy_score', 'perf_drop']


def _append_result(csv_path, row):
    """Append one row to ``csv_path``, writing the header first if the file is new or empty."""
    write_header = not os.path.exists(csv_path) or os.path.getsize(csv_path) == 0
    with open(csv_path, 'a', newline='') as f:
        writer = csv.DictWriter(f, fieldnames=_RESULT_FIELDS)
        if write_header:
            writer.writeheader()
        writer.writerow(row)


def _report(args, method, n_experts, top_k, c, s_clean, s_noisy):
    """Print clean/noisy scores for one C value and append them to ``args.output_csv``."""
    drop = s_clean - s_noisy
    print(f"C={c} | Clean: {s_clean:.4f} | Noisy: {s_noisy:.4f} | Drop: {drop:.4f}")
    _append_result(args.output_csv, {
        'ts': datetime.now().isoformat(), 'dataset': args.dataset, 'method': method,
        'n_experts': n_experts, 'top_k': top_k, 'C': c, 'noise_type': args.noise_type,
        'noise_lvl': args.noise_level, 'clean_score': s_clean, 'noisy_score': s_noisy,
        'perf_drop': drop,
    })


def main():
    """CLI entry point: encode data, fit a probe, and log clean vs. noisy test scores."""
    p = argparse.ArgumentParser()
    p.add_argument("--model", default="t5-small")
    p.add_argument("--dataset", required=True)
    p.add_argument("--selector", required=True, choices=['lasso', 'proxy', 'constraint'])
    p.add_argument("--n_experts", type=int, default=8)
    p.add_argument("--top_k", type=int, default=2)
    p.add_argument("--c_list", type=float, nargs='+', default=[0.1, 1.0])

    # Noise configuration
    p.add_argument("--noise_type", default="gaussian", choices=["gaussian", "word_swap", "char_noise"])
    p.add_argument("--noise_level", type=float, default=0.5)

    # Cap on encoded samples per split
    p.add_argument("--max_samples", type=int, default=30000, help="Max samples for train/test to avoid OOM. -1 for all.")

    p.add_argument("--output_csv", default="robustness_results.csv")
    p.add_argument("--split_from_train", action="store_true")
    p.add_argument("--seed", type=int, default=42,
                   help="Seed for the Python and NumPy RNGs used for text and Gaussian noise.")
    args = p.parse_args()

    # Seed before any noise is drawn so noisy caches and Gaussian noise are reproducible.
    random.seed(args.seed)
    np.random.seed(args.seed)

    # 1. Prepare data
    dm = DataManager(model_name=args.model)
    if args.dataset == 'mnli':
        tr_s, te_s = 'train', 'validation_matched'
    elif args.dataset == 'ag_news':
        tr_s, te_s = 'train', 'test'
    else:
        tr_s, te_s = 'train', 'validation'

    # 1.1 Clean training features (limited to max_samples)
    tr_path = dm.ensure_data(args.dataset, tr_s, 'clean', 0, max_samples=args.max_samples)

    # 1.2 Test features: clean, plus a text-noised copy when requested
    te_path_clean = dm.ensure_data(args.dataset, te_s, 'clean', 0, max_samples=args.max_samples)
    te_path_noisy = te_path_clean
    if args.noise_type in ['word_swap', 'char_noise'] and args.noise_level > 0:
        te_path_noisy = dm.ensure_data(args.dataset, te_s, args.noise_type, args.noise_level, max_samples=args.max_samples)

    # 1.3 Load and standardize
    Xt_clean, yt = load_pt(tr_path)
    Xe_clean, ye = load_pt(te_path_clean)

    use_split = args.split_from_train or args.dataset == 'sst2'
    if use_split:
        # Keeps a stratified 80% of the training set; the other 20% is
        # discarded and evaluation still uses the `te_s` split loaded above.
        # NOTE(review): the held-out 20% is never used; confirm whether it
        # was meant to replace the test split.
        Xt_clean, _, yt, _ = train_test_split(Xt_clean, yt, test_size=0.2, random_state=42, stratify=yt)

    scaler = StandardScaler().fit(Xt_clean)
    Xt_clean = scaler.transform(Xt_clean)
    Xe_clean = scaler.transform(Xe_clean)

    # Build Xe_noisy: Gaussian noise in standardized feature space, or re-encoded noisy text.
    if args.noise_type == 'gaussian':
        Xe_noisy = NoiseManager.add_gaussian_noise(Xe_clean, args.noise_level)
    elif args.noise_level == 0:
        Xe_noisy = Xe_clean
    else:
        raw_noisy, ye_noisy = load_pt(te_path_noisy)
        if not np.array_equal(ye, ye_noisy):
            # Defensive alignment: truncate both views to the shorter length.
            min_len = min(len(ye), len(ye_noisy))
            ye = ye[:min_len]
            Xe_clean = Xe_clean[:min_len]
            raw_noisy = raw_noisy[:min_len]
        Xe_noisy = scaler.transform(raw_noisy)

    cw = 'balanced' if args.dataset == 'cola' else None

    # 2. Lasso baseline
    if args.selector == 'lasso':
        print(f"\n=== Lasso | Noise: {args.noise_type} {args.noise_level} ===")
        for c in args.c_list:
            m = LogisticRegression(penalty='l1', C=c, solver='liblinear', class_weight=cw, random_state=42).fit(Xt_clean, yt)
            s_clean, _, _ = get_metrics(ye, m.predict(Xe_clean), args.dataset)
            s_noisy, _, _ = get_metrics(ye, m.predict(Xe_noisy), args.dataset)
            _report(args, 'lasso', 1, 1, c, s_clean, s_noisy)
        return

    # 3. MoE pipeline. The cluster cache key includes everything that changes
    # the clustered training features: dataset, selector, k, sample limit,
    # encoder model and whether the train split was applied.
    limit_str = f"_limit{args.max_samples}" if args.max_samples > 0 else ""
    model_tag = args.model.replace("/", "_")
    split_tag = "_split80" if use_split else ""
    cache = f"./data_cache_robust/clusters_{args.dataset}_{args.selector}_k{args.n_experts}{limit_str}_{model_tag}{split_tag}.pt"

    labels = torch.load(cache, weights_only=False) if os.path.exists(cache) else None
    # Recompute if missing or if the cached labels do not match the feature width.
    if labels is None or len(labels) != Xt_clean.shape[1]:
        print(f"[*] Clustering features ({args.selector})...")
        labels = SELECTOR_MAP[args.selector]().fit_predict(Xt_clean, yt, args.n_experts, class_weight=cw)
        torch.save(labels, cache)

    moe = MoEProbe(args.n_experts, labels, cw)

    for c in args.c_list:
        print(f"\n=== MoE ({args.selector}) | E={args.n_experts} K={args.top_k} | Noise: {args.noise_type} {args.noise_level} ===")
        moe.train(Xt_clean, yt, c, args.top_k)
        s_clean, _, _ = get_metrics(ye, moe.predict(Xe_clean), args.dataset)
        s_noisy, _, _ = get_metrics(ye, moe.predict(Xe_noisy), args.dataset)
        _report(args, f"moe_{args.selector}", args.n_experts, args.top_k, c, s_clean, s_noisy)

# Run the CLI only when executed as a script, not on import.
if __name__ == "__main__":
    main()