import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import sys
import numpy as np
import tqdm
import os
import pickle
import transformers
import argparse
import torch
import pandas as pd

from typing import List
from torch.utils.data import DataLoader
import torch.optim as optim
import torch.nn.functional as F
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, AutoModelForSequenceClassification
from sentence_transformers import SentenceTransformer
sys.path.insert(0, '.')
# from model_util.base_model import BaseModel, STOP_SEQUENCES
from sklearn.model_selection import KFold
from sklearn.metrics import accuracy_score
from constants import MODEL_IDENTIFIER, HF_DATASETS_CACHE,\
      HF_MODELS_CACHE, DATASET_PATH, STOP_SEQUENCES,\
          EMBEDDING_MODEL_IDENTIFIER, INDENTIFIER2NAME, CLASS_MODEL_IDENTIFIER

def create_classification_dataloader(
    batch_size: int,
    inputs_: List[str],
    labels: List[int],
    question_ids: List[str],
    tokenizer: "AutoTokenizer",
    *,
    shuffle: bool = True,
    **tokenizer_kwargs,
) -> DataLoader:
    """
    Create a dataloader of tokenized inputs for the classification model.

    Each example is a dictionary holding every field returned by the tokenizer
    plus the keys ``'question_id'`` and ``'label'``.

    Parameters
    ----------
    batch_size: int
        Number of examples per batch.
    inputs_: List[str]
        Input texts, tokenized one at a time.
    labels: List[int]
        Label of each input, in the same order as ``inputs_``.
    question_ids: List[str]
        Question id of each input, in the same order as ``inputs_``.
    tokenizer: AutoTokenizer
        Callable used to tokenize every input; it must return a mapping.
    shuffle: bool
        Whether the dataloader reshuffles the examples at every pass. Defaults to True,
        which is the behavior this function always had.
    **tokenizer_kwargs
        Extra keyword arguments forwarded to every ``tokenizer`` call.

    Returns
    -------
    DataLoader
        Created dataloader.

    Raises
    ------
    ValueError
        If ``inputs_``, ``labels`` and ``question_ids`` differ in length.
    """
    # zip() would silently drop the surplus entries of the longer lists and
    # hide a misalignment between texts, labels and ids.
    if not (len(inputs_) == len(labels) == len(question_ids)):
        raise ValueError(
            "inputs_, labels and question_ids must have the same length, got "
            f"{len(inputs_)}, {len(labels)} and {len(question_ids)}"
        )

    data = [
        {
            **tokenizer(input_, **tokenizer_kwargs),
            'question_id': question_id,
            "label": label,
        }
        for input_, label, question_id in zip(inputs_, labels, question_ids)
    ]

    return DataLoader(data, batch_size=batch_size, shuffle=shuffle)

def loop_dataloader(dataloader: DataLoader):
    """
    Loop through a dataloader repeatedly, restarting it each time it is exhausted.

    Parameters
    ----------
    dataloader: DataLoader
        Dataloader to be looped through.

    Yields
    ------
    batch: Dict[str, Any]
        Batch from dataloader.

    Notes
    -----
    The generator only ends when a full pass over ``dataloader`` produces no
    batch at all; otherwise it never terminates on its own.
    """
    while True:
        produced_batch = False
        for batch in dataloader:
            produced_batch = True
            yield batch
        # An empty dataloader would otherwise make this loop spin forever
        # without ever yielding, hanging the consumer.
        if not produced_batch:
            return

def flatten(nested_list):
    """
    Flatten arbitrarily nested lists into a single list, preserving order.

    Only ``list`` instances are expanded (recursively); every other element,
    such as a dictionary or a tuple, is copied over unchanged.
    """
    collected = []
    for element in nested_list:
        # Sub-lists are expanded recursively, anything else is kept as a single entry.
        collected += flatten(element) if isinstance(element, list) else [element]
    return collected

def k_fold(args):
    """
    Train a correctness classifier with K-fold cross-validation and save binned hardness targets.

    For every fold a fresh sequence classifier is fine-tuned on the training part and used to
    predict the held-out part. The predicted probability of class 1 is stored on the matching
    training item as ``'hd_label'``. Afterwards the predictions are grouped into bins
    (``np.digitize`` with edges 0, 0.1, ..., 0.9), the per-bin accuracy and the expected
    calibration error are printed, and every item receives ``'hd_target'`` (mean correctness
    of its bin) and ``'hd_bin'`` (its bin index). The result is pickled to
    ``hd_data_{args.epoch}.pkl`` next to the input file.

    Parameters
    ----------
    args: argparse.Namespace
        Must provide ``model``, ``classifier``, ``dataset_path``, ``dataset``, ``batch_size``,
        ``epoch``, ``lr``, ``weight_decay``, ``warmup_fraction`` and ``max_grad_norm``.
    """

    def hashable_id(raw_id):
        # Numeric ids come back from the DataLoader's default collation as tensors, which
        # hash by identity; unwrap them so that equal ids map to the same dictionary key.
        return raw_id.item() if torch.is_tensor(raw_id) else raw_id

    model_name = INDENTIFIER2NAME[args.model]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dataset_path = os.path.join(args.dataset_path, args.dataset, model_name, 'data_with_answer.pkl')

    # NOTE: unpickling executes arbitrary code; only point this at trusted files.
    with open(dataset_path, 'rb') as f:
        dataset = pickle.load(f)
    train_data = flatten(dataset['train'])

    calibration_tokenizer = AutoTokenizer.from_pretrained(args.classifier)
    calibration_config = AutoConfig.from_pretrained(args.classifier)

    inputs = []
    labels = []
    question_ids = []
    for data in train_data:
        inputs.append(data['question'][0] + f' [SEP] {data["model_answer"]}')
        labels.append(int(data['correctness'][0]))
        question_ids.append(data['id'][0])

    # Predictions are attached to the FIRST item carrying a given id. Indexing those items
    # once replaces a linear scan over train_data for every single prediction.
    first_item_by_id = {}
    for item in train_data:
        first_item_by_id.setdefault(hashable_id(item['id'][0]), item)

    tokenizer_kwargs = dict(
        padding="max_length",
        truncation=True,
        max_length=calibration_config.max_position_embeddings,
        return_tensors="pt",
    )

    n_splits = 2
    kf = KFold(n_splits=n_splits, shuffle=True, random_state=42)
    # Accumulated over all folds so the printed total really is a total.
    test_id_list = []

    for fold, (train_ids, test_ids) in enumerate(kf.split(train_data), start=1):
        print(f'split: {fold}')
        classifier_model = AutoModelForSequenceClassification.from_pretrained(args.classifier).to(device)

        y_train = [labels[idx] for idx in train_ids]
        train_loader = create_classification_dataloader(
            batch_size=args.batch_size,
            inputs_=[inputs[idx] for idx in train_ids],
            labels=y_train,
            question_ids=[question_ids[idx] for idx in train_ids],
            tokenizer=calibration_tokenizer,
            **tokenizer_kwargs,
        )
        test_loader = create_classification_dataloader(
            batch_size=args.batch_size,
            inputs_=[inputs[idx] for idx in test_ids],
            labels=[labels[idx] for idx in test_ids],
            question_ids=[question_ids[idx] for idx in test_ids],
            tokenizer=calibration_tokenizer,
            **tokenizer_kwargs,
        )

        pos = sum(y_train)
        neg = len(y_train) - pos
        print("pos", pos, neg)
        eps = 1e-8
        # With one-hot binary targets, BCEWithLogitsLoss's pos_weight sets the weight of the
        # positive term per column. Column 0 is "positive" for y=0 samples (about `neg` of
        # them); column 1 is "positive" for y=1 samples (about `pos` of them).
        pos_weight = torch.tensor([pos / (neg + eps),  # column 0
                                   neg / (pos + eps)],  # column 1
                                  dtype=torch.float32, device=device)
        bce_criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

        ###### Train ######
        iterations = args.epoch * len(train_loader)
        optimizer = optim.AdamW(
            list(classifier_model.parameters()),
            lr=args.lr,
            weight_decay=args.weight_decay,
        )
        lr_scheduler = transformers.get_cosine_schedule_with_warmup(
            optimizer,
            num_warmup_steps=min(int(args.warmup_fraction * iterations), 100),
            num_training_steps=int(
                iterations * 1.1
            ),  # This makes sure that the final LR isn't just 0
        )
        # from_pretrained() hands the model back in evaluation mode; enable dropout for training.
        classifier_model.train()
        # zip() stops as soon as range() is exhausted, so no surplus batch is drawn from the
        # endless loader (and nothing is drawn at all when there are zero iterations).
        for step, batch in zip(range(iterations), loop_dataloader(train_loader)):
            # The tokenizer returns tensors with a leading dimension of 1 per example.
            input_ids = batch["input_ids"].squeeze(1).to(device)
            attention_mask = batch["attention_mask"].squeeze(1).to(device)
            gt = batch["label"].to(device)

            logits = classifier_model(input_ids, attention_mask=attention_mask).logits  # [B, 2]
            target = F.one_hot(gt, num_classes=2).float()  # [B, 2]
            loss = bce_criterion(logits, target)
            loss.backward()
            clip_grad_norm_(classifier_model.parameters(), max_norm=args.max_grad_norm)
            print(
                f"[Step {step+1}/{iterations}] Loss: {loss.detach().cpu().item():.4f}"
                )

            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad(set_to_none=True)

        print('Finished training!!!')
        ##########################

        print('Get hardness labels')
        ###### Test ######
        classifier_model.eval()
        all_preds = []
        all_gt = []
        with torch.no_grad():
            for batch in tqdm.tqdm(test_loader):
                input_ids = batch["input_ids"].squeeze(1).to(device)
                attention_mask = batch["attention_mask"].squeeze(1).to(device)

                logits = classifier_model(input_ids, attention_mask=attention_mask).logits
                # Move to the CPU so the pickled dataset can be loaded without a GPU.
                probs = F.softmax(logits, dim=-1).cpu()
                all_gt += batch["label"].cpu().tolist()
                all_preds += torch.argmax(logits, dim=-1).cpu().tolist()
                for qid, prob in zip(batch['question_id'], probs):
                    item = first_item_by_id.get(hashable_id(qid))
                    if item is not None:
                        # Probability assigned to class 1, kept as a 0-dim CPU tensor.
                        item['hd_label'] = prob[1].clone()

        acc = accuracy_score(all_gt, all_preds)
        print(f'Finished split {fold} \n Split Accuracy: {acc}')
        test_id_list.extend(test_ids.tolist())

        # The optimizer and scheduler keep the parameters alive, so drop them with the model.
        del classifier_model, optimizer, lr_scheduler
        print(f'total number of test ids: {len(test_id_list)}')

    ############## Group the results ################
    # Build a filtered list instead of removing from train_data while iterating over it,
    # which skips the element following every removal.
    valid_items, all_hd_labels, all_correctness, all_ids = [], [], [], []
    count = 0
    for item in train_data:
        try:
            # Read every field before appending anything so the parallel lists stay aligned.
            hd_label = item['hd_label'].to('cpu').item()
            correctness = int(item['correctness'][0])
            item_id = item['id'][0]
        except (KeyError, IndexError, TypeError, ValueError, AttributeError, RuntimeError):
            print(item)
            count += 1
            continue
        valid_items.append(item)
        all_hd_labels.append(hd_label)
        all_correctness.append(correctness)
        all_ids.append(item_id)
    train_data = valid_items
    print(f"number of invalid ids: {count}")

    bins = np.arange(0, 1, 0.1)
    bins_per_prediction = np.digitize(all_hd_labels, bins)
    df = pd.DataFrame(
        {
            "y_pred": all_hd_labels,
            "y": all_correctness,
            'id': all_ids,
            "pred_bins": bins_per_prediction,

        }
    )

    # Mean accuracy and mean predicted confidence of every bin
    bin_stats = df.groupby('pred_bins').agg(
        acc=('y', 'mean'),
        conf=('y_pred', 'mean'),
        count=('y', 'size')
    ).reset_index()

    print("\nBin-wise results:")
    print(bin_stats)

    # ====== Expected calibration error ======
    ece = 0.0
    n = len(df)
    for _, row in bin_stats.iterrows():
        bin_size = row['count']
        if bin_size > 0:
            ece += abs(row['acc'] - row['conf']) * (bin_size / n)

    print(f"\nExpected Calibration Error (ECE): {ece:.4f}")

    acc = df.groupby('pred_bins')['y'].mean().reset_index()
    acc.rename(columns={'y': 'acc'}, inplace=True)
    df = df.merge(acc, on='pred_bins', how='left')

    #### Do this because there is duplicated data #############
    # Every row is applied to ALL remaining items sharing its id; grouping the items by id
    # first yields the same result as rescanning train_data for each row.
    items_by_id = {}
    for d in train_data:
        items_by_id.setdefault(hashable_id(d['id'][0]), []).append(d)

    new_dataset = []
    for row in tqdm.tqdm(df.to_dict(orient='records')):
        for d in items_by_id.get(hashable_id(row['id']), []):
            d.update({'hd_target': row['acc']})
            d.update({'hd_bin': row['pred_bins']})
            new_dataset.append(d)

    new_dataset = pd.DataFrame(new_dataset)
    new_dataset = new_dataset.dropna(subset=['hd_target'])
    dataset = new_dataset.to_dict(orient='records')
    print(f'number of items in dataset: {len(dataset)}')  
    save_path = os.path.join(args.dataset_path, args.dataset, model_name, f'hd_data_{args.epoch}.pkl')
    with open(save_path, 'wb') as f:
        pickle.dump(dataset, f)

    print('Done')

def map2acc(args):
    """
    Recompute the binned accuracy targets of an already labelled hardness dataset.

    Loads ``hd_data.pkl``, groups the stored ``'hd_label'`` predictions into bins
    (``np.digitize`` with edges 0, 0.1, ..., 0.9), writes the mean correctness of each bin
    back to the items as ``'hd_target'`` together with the bin index as ``'hd_bin'``, and
    pickles the result to ``hd_data_updated.pkl`` in the same directory.

    Parameters
    ----------
    args: argparse.Namespace
        Must provide ``model``, ``dataset_path`` and ``dataset``.
    """

    def hashable_id(raw_id):
        # Tensors hash by identity; unwrap them so that equal ids map to the same key.
        return raw_id.item() if torch.is_tensor(raw_id) else raw_id

    model_name = INDENTIFIER2NAME[args.model]
    # NOTE(review): k_fold in this file writes 'hd_data_{epoch}.pkl', not 'hd_data.pkl' —
    # confirm which file is expected here.
    path = os.path.join(args.dataset_path, args.dataset, model_name, 'hd_data.pkl')
    # NOTE: unpickling executes arbitrary code; only point this at trusted files.
    with open(path, 'rb') as f:
        dataset = pickle.load(f)

    # Build a filtered list instead of removing from the list while iterating over it,
    # which skips the element following every removal.
    valid_items, all_hd_labels, all_correctness, all_ids = [], [], [], []
    count = 0
    for item in dataset:
        try:
            # Read every field before appending anything so the parallel lists stay aligned.
            hd_label = item['hd_label'].to('cpu').item()
            correctness = int(item['correctness'][0])
            item_id = item['id'][0]
        except (KeyError, IndexError, TypeError, ValueError, AttributeError, RuntimeError):
            print(item)
            count += 1
            continue
        valid_items.append(item)
        all_hd_labels.append(hd_label)
        all_correctness.append(correctness)
        all_ids.append(item_id)
    dataset = valid_items
    print(f"number of invalid ids: {count}")

    bins = np.arange(0, 1, 0.1)
    bins_per_prediction = np.digitize(all_hd_labels, bins)
    df = pd.DataFrame(
        {
            "y_pred": all_hd_labels,
            "y": all_correctness,
            'id': all_ids,
            "pred_bins": bins_per_prediction,

        }
    )
    acc = df.groupby('pred_bins')['y'].mean().reset_index()
    acc.rename(columns={'y': 'acc'}, inplace=True)
    df = df.merge(acc, on='pred_bins', how='left')

    #### Do this because there is duplicated data #############
    # Every row is applied to ALL remaining items sharing its id; grouping the items by id
    # first yields the same result as rescanning the dataset for each row.
    items_by_id = {}
    for d in dataset:
        items_by_id.setdefault(hashable_id(d['id'][0]), []).append(d)

    new_dataset = []
    for row in tqdm.tqdm(df.to_dict(orient='records')):
        for d in items_by_id.get(hashable_id(row['id']), []):
            d.update({'hd_target': row['acc']})
            d.update({'hd_bin': row['pred_bins']})
            new_dataset.append(d)

    new_dataset = pd.DataFrame(new_dataset)
    # dropna(inplace=True) returns None, so its result must not be assigned back.
    new_dataset = new_dataset.dropna(subset=['hd_target'])
    dataset = new_dataset.to_dict(orient='records')
    print(f'number of items in dataset: {len(dataset)}')  
    save_path = os.path.join(args.dataset_path, args.dataset, model_name, 'hd_data_updated.pkl')
    with open(save_path, 'wb') as f:
        pickle.dump(dataset, f)

    print("done")

if __name__ == '__main__':
    # Command-line entry point: parse the hyperparameters and run the K-fold labelling.
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, default=MODEL_IDENTIFIER)
    parser.add_argument('--classifier', type=str, default=CLASS_MODEL_IDENTIFIER)
    parser.add_argument('--dataset_path', type=str, default=DATASET_PATH)
    # parser.add_argument('--cali_dataset_path', type=str, default='/data/DERI-Gong/jh015/grace_codes/data/sciq_brief/llama2-7b/data_with_answer.pkl')
    parser.add_argument('--dataset', type=str, default='triviaqa_brief')
    #hyperparameters
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--warmup_fraction', type=float, default=0.1)
    parser.add_argument('--epoch', type=int, default=16)
    parser.add_argument('--lr', type=float, default=1e-5)
    # Both values are real-valued: with type=int a fractional value such as
    # "--weight_decay 0.01" is rejected by argparse.
    parser.add_argument('--weight_decay', type=float, default=1e-4)
    parser.add_argument('--max_grad_norm', type=float, default=10.0)
    args = parser.parse_args()

    k_fold(args)
    # map2acc(args)


