import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import sys
import numpy as np
import tqdm
import pickle
import transformers
import argparse
import torch
import pandas as pd

from typing import List
from torch.utils.data import DataLoader
import torch.optim as optim
import torch.nn.functional as F
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, AutoModelForSequenceClassification
from sentence_transformers import SentenceTransformer
sys.path.insert(0, '.')
from sklearn.model_selection import KFold
from sklearn.metrics import accuracy_score
from constants import MODEL_IDENTIFIER, HF_DATASETS_CACHE, \
    HF_MODELS_CACHE, DATASET_PATH, STOP_SEQUENCES, \
    EMBEDDING_MODEL_IDENTIFIER, INDENTIFIER2NAME, CLASS_MODEL_IDENTIFIER


def create_classification_dataloader(
    batch_size: int,
    inputs_: List[str],
    labels: List[int],
    question_ids: List[str],
    tokenizer: AutoTokenizer,
    **tokenizer_kwargs,
) -> DataLoader:
    """Tokenize every text and wrap the examples in a shuffling ``DataLoader``.

    Args:
        batch_size: Number of examples per batch.
        inputs_: Texts to tokenize, one call to ``tokenizer`` per text.
        labels: Label for each text.
        question_ids: Identifier for each text; carried through untouched.
        tokenizer: Callable returning a mapping of encoded fields for a text.
        **tokenizer_kwargs: Forwarded verbatim to every ``tokenizer`` call.

    Returns:
        A ``DataLoader`` (``shuffle=True``) over dicts holding the tokenizer's
        fields plus ``'question_id'`` and ``'label'``.

    Note:
        The three sequences are combined with ``zip``, so the result is as
        long as the shortest of them.
    """

    def _build_example(text, target, qid):
        # Start from the tokenizer output, then append the bookkeeping fields
        # (these win if the tokenizer happens to emit the same keys).
        example = {**tokenizer(text, **tokenizer_kwargs)}
        example.update({'question_id': qid, "label": target})
        return example

    examples = [
        _build_example(text, target, qid)
        for text, target, qid in zip(inputs_, labels, question_ids)
    ]
    return DataLoader(examples, batch_size=batch_size, shuffle=True)


def loop_dataloader(dataloader: DataLoader):
    """Yield batches from ``dataloader`` endlessly, restarting after each pass.

    Args:
        dataloader: Any re-iterable source of batches.

    Yields:
        Each batch, in the order produced by iterating ``dataloader``; once a
        pass finishes, iteration starts over from the beginning.

    Note:
        If a full pass produces no batch at all the generator stops instead
        of spinning forever, so a consumer of an empty loader cannot hang.
    """
    while True:
        produced_any = False
        for batch in dataloader:
            produced_any = True
            yield batch
        if not produced_any:
            # Nothing to cycle over: end the generator rather than busy-loop.
            return


def flatten(nested_list):
    """Return a flat list of all non-list items in ``nested_list``.

    Nested ``list`` objects are expanded depth-first, left to right. Only
    ``list`` instances are descended into; tuples, dicts, strings and every
    other object are kept as single items.

    The traversal is iterative (an explicit stack of iterators), so nesting
    depth is not bounded by the interpreter's recursion limit.

    Args:
        nested_list: An iterable whose items may themselves be lists.

    Returns:
        A new list with the leaf items in their original order.
    """
    flat_list = []
    pending = [iter(nested_list)]
    while pending:
        for item in pending[-1]:
            if isinstance(item, list):
                # Descend now; the parent iterator resumes after the child
                # is exhausted, which preserves left-to-right order.
                pending.append(iter(item))
                break
            flat_list.append(item)
        else:
            # Current level fully consumed.
            pending.pop()
    return flat_list


def _to_hashable_id(qid):
    """Return a plain Python value for an id that may be a tensor/NumPy scalar.

    Objects exposing ``item()`` are unwrapped; anything else (e.g. ``str``)
    is returned unchanged. Tensors hash by identity, so they must be
    unwrapped before being used as dictionary keys.
    """
    return qid.item() if hasattr(qid, "item") else qid


def _balanced_indices(labels, seed=42):
    """Pick an equal number of label-1 and label-0 indices, shuffled.

    The larger class is down-sampled without replacement. Returns the
    shuffled index array and the per-class count.
    """
    idx_pos = [i for i, y in enumerate(labels) if y == 1]
    idx_neg = [i for i, y in enumerate(labels) if y == 0]
    n = min(len(idx_pos), len(idx_neg))
    if n == 0:
        raise ValueError(
            "Cannot balance the training set: "
            f"{len(idx_pos)} positive and {len(idx_neg)} negative examples found."
        )
    rng = np.random.RandomState(seed)
    idx_pos_bal = rng.choice(idx_pos, n, replace=False)
    idx_neg_bal = rng.choice(idx_neg, n, replace=False)
    keep_idx = np.concatenate([idx_pos_bal, idx_neg_bal])
    rng.shuffle(keep_idx)
    return keep_idx, n


def _assign_bin_targets(df, num_bins=10):
    """Bin ``y_pred`` into equal-width bins and attach each bin's accuracy.

    Returns ``(df_with_bin_and_acc, bin_stats, ece)`` where ``bin_stats`` has
    one row per non-empty bin (``acc``, ``conf``, ``count``) and ``ece`` is
    the count-weighted mean of ``|acc - conf|``.
    """
    binned = df.copy()
    edges = np.linspace(0, 1, num_bins + 1)
    # Left-closed bins [lo, hi); a prediction of exactly 1.0 falls outside
    # the last interval and is assigned to the top bin explicitly.
    binned['bin'] = pd.cut(
        binned['y_pred'],
        bins=edges,
        labels=False,
        include_lowest=True,
        right=False,
    )
    binned.loc[binned['y_pred'] == 1.0, 'bin'] = num_bins - 1

    bin_stats = binned.groupby('bin').agg(
        acc=('y', 'mean'),
        conf=('y_pred', 'mean'),
        count=('y', 'size'),
    ).reset_index()

    weights = bin_stats['count'] / len(binned)
    ece = float(((bin_stats['acc'] - bin_stats['conf']).abs() * weights).sum())

    binned = binned.merge(bin_stats[['bin', 'acc']], on='bin', how='left')
    return binned, bin_stats, ece


def em_training(args, num_em_steps=3, lambda_calib=3.0):
    """Run EM-style rounds of K-fold classifier training and bin-wise relabelling.

    Each round trains a fresh sequence classifier on every fold of a
    class-balanced subset of the training data, collects out-of-fold (OOF)
    positive-class confidences, bins them into 10 equal-width bins and uses
    each bin's empirical accuracy as a soft target. From the second round on,
    an MSE term between the predicted confidence and the previous round's
    soft target is added to the cross-entropy loss, weighted by
    ``lambda_calib``. The records produced by the last round are pickled to
    ``hd_data_em_balanced.pkl`` next to the input file.

    Args:
        args: Namespace providing ``model``, ``classifier``, ``dataset_path``,
            ``dataset``, ``batch_size``, ``warmup_fraction``, ``epoch``,
            ``lr``, ``weight_decay`` and ``max_grad_norm``.
        num_em_steps: Number of rounds to run; must be at least 1.
        lambda_calib: Weight of the calibration (MSE) loss term.

    Raises:
        ValueError: If ``num_em_steps`` is below 1, if one of the two classes
            is absent from the training data, or if the classifier produces a
            non-finite confidence.

    Note:
        Each record is read as ``d['question'][0]``, ``d['correctness'][0]``
        and ``d['id'][0]`` while ``d['model_answer']`` is used as-is; this
        presumably reflects how the pickle was produced -- verify against the
        script that writes ``data_with_answer.pkl``.
    """
    if num_em_steps < 1:
        # Without at least one round there is nothing to save at the end.
        raise ValueError(f"num_em_steps must be >= 1, got {num_em_steps}")

    model_name = INDENTIFIER2NAME[args.model]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dataset_path = os.path.join(args.dataset_path, args.dataset, model_name, 'data_with_answer.pkl')

    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with open(dataset_path, 'rb') as f:
        dataset = pickle.load(f)

    # Only the training split takes part in EM.
    train_data_full = flatten(dataset['train'])

    # ===== Step 1: balance 0/1 labels by down-sampling the majority class =====
    full_labels = [int(d['correctness'][0]) for d in train_data_full]
    keep_idx, n = _balanced_indices(full_labels, seed=42)

    # Every later E/M step works on this balanced subset only.
    train_data = [train_data_full[i] for i in keep_idx]

    print(f"[Balance] Balanced train set size: {len(train_data)} | pos: {n} | neg: {n}")

    # id -> record map for writing results back. Keys are normalised the same
    # way as the ids coming out of the dataloaders so that lookups agree.
    id2data = {_to_hashable_id(d['id'][0]): d for d in train_data}

    # ===== Build text inputs / labels from the balanced subset =====
    calibration_tokenizer = AutoTokenizer.from_pretrained(args.classifier)
    calibration_config = AutoConfig.from_pretrained(args.classifier)

    # Never pad/truncate beyond what the tokenizer declares as usable: for
    # some architectures max_position_embeddings includes reserved positions
    # and is larger than the longest sequence the model accepts.
    max_length = min(
        calibration_tokenizer.model_max_length,
        calibration_config.max_position_embeddings,
    )

    inputs, labels, question_ids = [], [], []
    for d in train_data:
        inputs.append(d['question'][0] + f' [SEP] {d["model_answer"]}')
        labels.append(int(d['correctness'][0]))
        question_ids.append(_to_hashable_id(d['id'][0]))

    k = 10
    kf = KFold(n_splits=k, shuffle=True, random_state=42)

    # Soft confidence targets (id -> bin accuracy); unused in the first round.
    soft_conf_labels = None

    for em_step in range(num_em_steps):
        print(f"\n======= EM Step {em_step+1}/{num_em_steps} =======")

        # Filled exclusively with K-fold OOF predictions on the balanced set.
        all_conf, all_correctness, all_ids = [], [], []

        for split_idx, (train_ids, test_ids) in enumerate(kf.split(train_data)):
            print(f"\nSplit {split_idx+1}/{k}")

            # A fresh classifier for every fold.
            classifier_model = AutoModelForSequenceClassification.from_pretrained(args.classifier).to(device)

            fold_indices = {'train': train_ids, 'test': test_ids}
            dataloaders = {
                split: create_classification_dataloader(
                    batch_size=args.batch_size,
                    inputs_=[inputs[i] for i in ids],
                    labels=[labels[i] for i in ids],
                    question_ids=[question_ids[i] for i in ids],
                    tokenizer=calibration_tokenizer,
                    padding="max_length",
                    truncation=True,
                    max_length=max_length,
                    return_tensors="pt",
                )
                for split, ids in fold_indices.items()
            }

            # Optimizer and LR schedule.
            iterations = args.epoch * len(dataloaders['train'])
            optimizer = optim.AdamW(classifier_model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
            lr_scheduler = transformers.get_cosine_schedule_with_warmup(
                optimizer,
                num_warmup_steps=min(int(args.warmup_fraction * iterations), 100),
                # The schedule is stretched 10% past the steps actually run.
                num_training_steps=int(iterations * 1.1),
            )

            # ===== Train on this fold's training part =====
            classifier_model.train()
            for i, batch in enumerate(loop_dataloader(dataloaders['train'])):
                if i >= iterations:
                    break
                # return_tensors="pt" on single texts adds a leading dim of 1
                # per example, which is squeezed away after batching.
                input_ids = batch["input_ids"].squeeze(1).to(device)
                attention_mask = batch["attention_mask"].squeeze(1).to(device)
                gt = batch["label"].to(device)

                outputs = classifier_model(input_ids, attention_mask=attention_mask)
                logits = outputs.logits

                ce_loss = F.cross_entropy(logits, gt)

                # From the second round on, pull the positive-class confidence
                # towards the previous round's OOF bin accuracy.
                if soft_conf_labels is not None:
                    preds = torch.softmax(logits, dim=-1)[:, 1]
                    calib_targets = torch.tensor(
                        [soft_conf_labels[_to_hashable_id(qid)] for qid in batch['question_id']],
                        dtype=torch.float32, device=device
                    )
                    calib_loss = F.mse_loss(preds, calib_targets)
                    loss = ce_loss + lambda_calib * calib_loss
                else:
                    loss = ce_loss

                loss.backward()
                clip_grad_norm_(classifier_model.parameters(), max_norm=args.max_grad_norm)

                if (i + 1) % 10 == 0:
                    print(f"[Step {i+1}/{iterations}] CE: {ce_loss.item():.4f}, Loss: {loss.item():.4f}")

                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=True)

            # ===== Collect this fold's OOF predictions =====
            classifier_model.eval()
            with torch.no_grad():
                for batch in dataloaders['test']:
                    input_ids = batch["input_ids"].squeeze(1).to(device)
                    attention_mask = batch["attention_mask"].squeeze(1).to(device)
                    gt = batch["label"].cpu().numpy()

                    outputs = classifier_model(input_ids, attention_mask=attention_mask)
                    probs = torch.softmax(outputs.logits, dim=-1).cpu().numpy()  # [B, 2]

                    all_conf.extend(probs[:, 1].tolist())
                    all_correctness.extend(gt.tolist())
                    # Store plain ids so they can key soft_conf_labels/id2data.
                    all_ids.extend(_to_hashable_id(qid) for qid in batch["question_id"])

            # Release this fold's model and optimizer state before the next
            # fold allocates a fresh model.
            del classifier_model, optimizer, lr_scheduler
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        # ===== M step: bin-wise targets from the balanced set's OOF predictions =====
        df = pd.DataFrame({
            "id": all_ids,
            "y_pred": all_conf,
            "y": all_correctness
        })

        if not np.isfinite(df['y_pred']).all():
            # A NaN/inf confidence cannot be binned and would poison the
            # next round's targets; fail with an explicit message.
            raise ValueError(
                f"EM step {em_step + 1}: classifier produced non-finite confidences."
            )

        df, bin_stats, ece = _assign_bin_targets(df, num_bins=10)
        print("\n[Balanced-Train Only] Bin-wise results:")
        print(bin_stats)
        print(f"[Balanced-Train Only] ECE: {ece:.4f}")

        # Soft target of each example = accuracy of the bin it landed in.
        # Columns are iterated directly so ids keep their own type.
        soft_conf_labels = dict(zip(df['id'], df['acc']))

        # ===== Write the round's results back onto the balanced records =====
        new_dataset = []
        for qid, conf, bin_id, acc_val in zip(df['id'], df['y_pred'], df['bin'], df['acc']):
            d = id2data[qid]
            new_dataset.append({
                'id': d['id'],
                'question': d['question'],
                'model_answer': d['model_answer'],
                'correctness': d['correctness'],
                'hd_label': float(conf),        # OOF positive-class confidence
                'hd_target': float(acc_val),    # mean accuracy of the example's bin
                'hd_bin': int(bin_id)
            })

        # Defensive: discard any record that ended up without a target.
        new_dataset = pd.DataFrame(new_dataset).dropna(subset=['hd_target'])
        # List of dicts: the form that gets pickled below.
        dataset_balanced_records = new_dataset.to_dict(orient='records')

    # ===== Save the final round's balanced-subset records =====
    save_path = os.path.join(args.dataset_path, args.dataset, model_name, f'hd_data_em_balanced.pkl')
    with open(save_path, 'wb') as f:
        pickle.dump(dataset_balanced_records, f)
    print(f"Saved balanced-train EM dataset to: {save_path}")
    print(f"Num records: {len(dataset_balanced_records)}")


if __name__ == '__main__':
    # Command-line entry point: parse the run configuration and start EM training.
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, default=MODEL_IDENTIFIER)
    parser.add_argument('--classifier', type=str, default=CLASS_MODEL_IDENTIFIER)
    parser.add_argument('--dataset_path', type=str, default=DATASET_PATH)
    parser.add_argument('--dataset', type=str, default='sciq_brief')
    # hyperparameters
    parser.add_argument('--batch_size', type=int, default=8)
    parser.add_argument('--warmup_fraction', type=float, default=0.1)
    parser.add_argument('--epoch', type=int, default=1)
    parser.add_argument('--lr', type=float, default=1e-5)
    # Must be float: with type=int any fractional value given on the command
    # line (e.g. 0.01) is rejected by argparse.
    parser.add_argument('--weight_decay', type=float, default=1e-4)
    parser.add_argument('--max_grad_norm', type=float, default=10.0)
    # EM schedule; the defaults match the values previously hard-coded here.
    parser.add_argument('--num_em_steps', type=int, default=3,
                        help='number of EM rounds to run')
    parser.add_argument('--lambda_calib', type=float, default=2.0,
                        help='weight of the calibration loss term')
    args = parser.parse_args()

    # Original note: lambda_calib=2.0 was the value used for TriviaQA.
    em_training(args, num_em_steps=args.num_em_steps, lambda_calib=args.lambda_calib)
