import os, json, random
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm
from typing import List
import argparse
from scipy.stats import spearmanr, pearsonr
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F

from datasets import load_dataset
import tempfile
from model import FactualityChecker
from collections import defaultdict

import sys

import wandb

sys.path.append("./qags/")
from datasets import concatenate_datasets
from sklearn.metrics import f1_score, accuracy_score, roc_auc_score
# Module-wide default device string: "cuda" when a GPU is visible, else "cpu".
# Read by eval_factcc and passed by main() to train_and_eval_multi.
device = "cuda" if torch.cuda.is_available() else "cpu"

def eval_bertscore(df, batch_size=16, device="cuda", threshold=0.5):
    """Evaluate BERTScore F1 (summary vs. source document) against human labels.

    Args:
        df: DataFrame with ``summary``, ``document`` and ``human`` columns.
        batch_size: batch size forwarded to ``bert_score.score``.
        device: device string forwarded to ``bert_score.score``.
        threshold: F1 value at or above which a summary is predicted as
            consistent when the human labels are binary.
            NOTE(review): raw (non-rescaled) BERTScore F1 is not calibrated so
            that 0.5 is a natural cut-off — TODO confirm the threshold.

    Returns:
        dict with ``spearman`` and ``pearson``. For binary (0/1) human labels it
        also has ``acc`` and ``f1``, plus ``auc`` when both classes are present.
    """
    from bert_score import score as bert_score

    _, _, F1 = bert_score(
        df["summary"].tolist(),
        df["document"].tolist(),
        lang="en",
        batch_size=batch_size,
        device=device,
        verbose=False,
    )
    scores = np.asarray(F1.cpu(), dtype=float)
    human = df["human"].to_numpy(dtype=float)

    # Index the result objects instead of using ``.correlation``; this works on
    # both old (plain tuple) and new (result object) SciPy versions.
    metric = {
        "spearman": spearmanr(scores, human)[0],
        "pearson": pearsonr(scores, human)[0],
    }

    labels = np.unique(human)
    if set(labels) <= {0, 1}:
        y_pred = (scores >= threshold).astype(int)
        metric["acc"] = accuracy_score(human, y_pred)
        metric["f1"] = f1_score(human, y_pred)
        # roc_auc_score raises ValueError when y_true holds a single class.
        if len(labels) == 2:
            metric["auc"] = roc_auc_score(human, scores)

    return metric

def eval_factcc(df):
    """Evaluate the pretrained FactCC classifier against human labels.

    Each (document, summary) pair is classified independently. The argmax
    class is mapped so that 1 means "consistent", matching the ``human`` column
    convention used by the loaders in this file.

    Args:
        df: DataFrame with ``document``, ``summary`` and ``human`` columns.

    Returns:
        dict with ``spearman`` and ``pearson`` between the binary predictions
        and ``human``; ``f1`` and ``acc`` are added when ``human`` is binary.
    """
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    model_name = "manueldeprada/FactCC"
    model = AutoModelForSequenceClassification.from_pretrained(model_name).to(device)
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    preds = []
    # Inference only: no_grad avoids building an autograd graph per example.
    with torch.no_grad():
        for doc, summ in zip(df["document"], df["summary"]):
            inputs = tokenizer(doc, summ, return_tensors="pt", truncation=True).to(device)
            logits = model(**inputs).logits
            preds.append(torch.argmax(logits).item())  # 0 = SUPPORTED, 1 = NOT_SUPPORTED

    # Flip so that 1 = consistent, like the human labels.
    y_pred = [1 if p == 0 else 0 for p in preds]
    y_true = df["human"].tolist()
    is_binary = set(y_true) <= {0, 1}

    metrics = {
        "spearman": spearmanr(y_pred, y_true)[0],
        "pearson": pearsonr(y_pred, y_true)[0],
    }
    if is_binary:
        metrics["f1"] = f1_score(y_true, y_pred)
        metrics["acc"] = accuracy_score(y_true, y_pred)

    return metrics

def eval_qags(df):
    """Correlate QAGS scores with the human judgements in ``df``.

    Documents and summaries are dumped to two temporary JSON files, each mapping
    the row index (as a string) to its text. Their paths are passed to
    ``qags.qa_utils.get_qags_scores``. The temporary files are removed even if
    scoring raises.

    Args:
        df: DataFrame with ``document``, ``summary`` and ``human`` columns.

    Returns:
        Tuple ``(spearman, pearson)`` between the QAGS scores and ``df['human']``.
    """
    from qags.qa_utils import get_qags_scores

    src_data = {str(i): doc for i, doc in enumerate(df['document'])}
    trg_data = {str(i): summary for i, summary in enumerate(df['summary'])}

    tmp_paths = []
    try:
        for payload in (src_data, trg_data):
            with tempfile.NamedTemporaryFile(mode='w', delete=False) as tmp:
                tmp_paths.append(tmp.name)
                json.dump(payload, tmp)
        src_path, trg_path = tmp_paths
        qags_scores = get_qags_scores(src_path, trg_path, n_qsts_per_doc=1)
        print(qags_scores)
    finally:
        # Always clean up, even when scoring fails part-way.
        for path in tmp_paths:
            if os.path.exists(path):
                os.remove(path)

    return spearmanr(qags_scores, df['human'])[0], pearsonr(qags_scores, df['human'])[0]
def load_qags(jsonl_paths: List[str]):
    """Load QAGS MTurk annotations as one row per summary sentence.

    Every non-blank line of each JSONL file is parsed as an object with an
    ``article`` and a list of ``summary_sentences``. Each sentence has a
    ``sentence`` text and a list of annotator ``responses``.

    Args:
        jsonl_paths: JSONL files to read, in order. Each path is printed as it
            is opened.

    Returns:
        DataFrame with columns ``document`` (article), ``summary`` (stripped
        sentence) and ``human``. ``human`` is 1 when at least two responses are
        "yes" (case- and whitespace-insensitive), else 0. The columns are
        present even when no sentences were found.
    """
    records = []
    for path in jsonl_paths:
        print(path)
        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue  # tolerate blank / trailing newline-only lines
                obj = json.loads(line)
                article = obj.get("article", "")
                # Default to a list: iterating a string default would yield
                # characters, and ``.get`` on them would fail.
                for s in obj.get("summary_sentences", []):
                    responses = s.get("responses", [])
                    yes_votes = sum(
                        1 for r in responses if r["response"].strip().lower() == "yes"
                    )
                    records.append({
                        "document": article,
                        "summary": s.get("sentence", "").strip(),
                        # 2 "yes" votes = majority, assuming three annotators
                        # per sentence — TODO confirm for every input file.
                        "human": 1 if yes_votes >= 2 else 0,
                    })
    return pd.DataFrame(records, columns=["document", "summary", "human"])

def load_factcc():
    """Return the annotated FactCC validation and test splits as a single DataFrame.

    Rows keep split order (validation first, then test). Columns are
    ``document`` (the source ``text``), ``summary`` (the ``claim``) and
    ``human`` (1 when ``label == "CORRECT"``, otherwise 0).
    """
    dataset_id = "mtc/factcc_annotated_eval_data"
    parts = [load_dataset(dataset_id, split=name) for name in ("validation", "test")]
    merged = concatenate_datasets(parts)

    rows = [
        {
            'document': example['text'],
            'summary': example['claim'],
            'human': int(example['label'] == "CORRECT"),
        }
        for example in merged
    ]
    return pd.DataFrame(rows)


def load_summeval():
    """Flatten the SummEval test split into one row per machine summary.

    Each row pairs the source ``text`` (``document``) with one of its
    ``machine_summaries`` (``summary``). The matching ``consistency`` score,
    cast to float, is stored as ``human``.
    """
    rows = []
    for example in load_dataset("mteb/summeval", split="test"):
        document = example['text']
        pairs = zip(example['machine_summaries'], example['consistency'])
        rows.extend(
            {'document': document, "summary": machine_summary, "human": float(consistency)}
            for machine_summary, consistency in pairs
        )
    return pd.DataFrame(rows)

def encode_sliding(model, tokenizer, text, device, chunk_size=512, stride=256):
    """Embed ``text`` token-by-token using overlapping encoder windows.

    The text is tokenized without truncation and fed to ``model`` in windows of
    ``chunk_size`` tokens whose starts are ``stride`` tokens apart. A short
    final window is padded and masked. For every token, the second-to-last
    hidden layer is averaged over all windows containing it. Windowing stops
    once a window reaches the last token, since later windows would only
    re-encode tokens that are already covered. Everything runs under
    ``torch.no_grad()``, so no gradient flows back into ``model``.

    Args:
        model: encoder called with ``input_ids``/``attention_mask``,
            ``output_hidden_states=True`` and ``return_dict=True``. It must
            expose ``config.hidden_size``.
        tokenizer: callable returning ``input_ids`` and ``attention_mask``.
        text: string to encode.
        device: device the window tensors are moved to for the forward pass.
        chunk_size: window length in tokens.
        stride: distance in tokens between consecutive window starts.

    Returns:
        CPU tensor of shape ``(covered_tokens, hidden)``. Tokens no window
        covers (only possible when ``stride > chunk_size``) are omitted. When
        the text yields no tokens, a single zero row of width
        ``model.config.hidden_size`` is returned instead.
    """
    toks = tokenizer(text, return_tensors='pt', truncation=False)
    ids, mask = toks['input_ids'][0], toks['attention_mask'][0]
    n_tokens = len(ids)
    if n_tokens == 0:
        # Same (CPU) placement as the normal return path below.
        return torch.zeros((1, model.config.hidden_size))

    # Tokenizers without a pad token report None; padded slots are masked anyway.
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0

    summed = counts = None
    with torch.no_grad():
        for start in range(0, n_tokens, stride):
            end = start + chunk_size
            chunk_ids, chunk_mask = ids[start:end], mask[start:end]
            n_real = len(chunk_ids)

            if n_real < chunk_size:
                pad = chunk_size - n_real
                chunk_ids = F.pad(chunk_ids, (0, pad), value=pad_id)
                chunk_mask = F.pad(chunk_mask, (0, pad), value=0)

            out = model(
                input_ids=chunk_ids.unsqueeze(0).to(device),
                attention_mask=chunk_mask.unsqueeze(0).to(device),
                output_hidden_states=True,
                return_dict=True,
            )
            hidden = out.hidden_states[-2].squeeze(0).detach().cpu()

            if summed is None:
                summed = torch.zeros((n_tokens, hidden.shape[-1]), dtype=hidden.dtype)
                counts = torch.zeros(n_tokens, dtype=hidden.dtype)
            # Accumulate the real (non-padded) positions of this window.
            summed[start:start + n_real] += hidden[:n_real]
            counts[start:start + n_real] += 1

            if end >= n_tokens:
                break

    covered = counts > 0
    return summed[covered] / counts[covered].unsqueeze(1)

class TextPairDS(Dataset):
    """Map-style dataset yielding aligned ``(document, summary)`` pairs.

    ``docs`` and ``sums`` are indexed in lockstep; the length is taken from
    ``docs``.
    """

    def __init__(self, docs, sums):
        self.docs = docs
        self.sums = sums

    def __len__(self):
        return len(self.docs)

    def __getitem__(self, i):
        document = self.docs[i]
        summary = self.sums[i]
        return (document, summary)

def train_and_eval_multi(
    args,
    eval_dict,                       # {"QAGS": qags_df, "SummEval": sum_df}
    device="cuda",
    encoder_name="bert-base-uncased",
    epochs=100, batch_size=64,
    chunk_size=512, stride=256,
    eval_every=200
):
    """Train a FactualityChecker with in-batch negatives and evaluate it periodically.

    Each (article, highlights) training pair is a positive. Pairing each article
    with the next summary in the batch (a roll by one) gives the negatives.
    Pairs are scored as ``model.cls(cosine(doc_vec, summary_vec))`` and
    optimised with BCE-with-logits. Every frame in ``eval_dict`` is evaluated
    before training and then every ``eval_every`` optimiser steps. Metrics are
    printed and logged to wandb.

    Args:
        args: parsed CLI namespace. It is passed to ``FactualityChecker`` and
            logged as the wandb config.
        eval_dict: mapping tag -> DataFrame with ``document``, ``summary`` and
            ``human`` columns.
        device: device for the model and encoder inputs.
        encoder_name: unused here; kept for call compatibility.
        epochs: number of passes over the training set.
        batch_size: training batch size. Batches with fewer than two pairs are
            skipped because they cannot form a distinct negative.
        chunk_size, stride: sliding-window settings for ``encode_sliding``.
        eval_every: evaluate every this many optimiser steps.
    """
    model     = FactualityChecker(args, device=device).to(device)
    optim     = torch.optim.Adam(model.parameters(), lr=1e-4)

    train_ds  = load_dataset("therapara/summary-of-news-articles_new", split="train")
    loader    = DataLoader(TextPairDS(train_ds['article'], train_ds['highlights']),
                           batch_size=batch_size, shuffle=True)

    wandb.init(project="FactualConsistency", config=vars(args), name="SATPool-FC")

    def embed(text):
        """Token embeddings of ``text`` from the (no-grad) text encoder."""
        return encode_sliding(model.text_encoder, model.tokenizer,
                              text, device, chunk_size, stride)

    def doc_vector(text):
        """Pooled document vector for ``text``."""
        return model.encode_documents({0: {"embedding": embed(text)}}, [0])[0]

    def summary_vector(text):
        """Pooled summary vector for ``text``."""
        return model.encode_summary({0: {"embedding": embed(text)}}, [0])[0]

    def pair_logit(d_vec, s_vec):
        """Scalar consistency logit: ``model.cls`` applied to the cosine similarity."""
        cosine = F.cosine_similarity(d_vec, s_vec, dim=0, eps=1e-8)
        return model.cls(cosine.unsqueeze(0)).squeeze(0)

    def eval_all(step: int):
        """Score every eval frame, then log correlations (and F1 for binary labels)."""
        model.eval()

        with torch.no_grad():
            for tag, df in eval_dict.items():
                probs = [
                    torch.sigmoid(pair_logit(doc_vector(d_txt), summary_vector(s_txt))).item()
                    for d_txt, s_txt in zip(df["document"], df["summary"])
                ]

                y_true = df["human"].tolist()
                y_pred = [int(p >= 0.5) for p in probs]
                is_binary = set(y_true) <= {0, 1}

                metrics = {}
                if is_binary:
                    f1 = f1_score(y_true, y_pred)
                    metrics[f"{tag}/f1"] = f1

                # Correlations of the raw probabilities and of the thresholded
                # predictions. Indexing [0] works on old and new SciPy.
                rho = spearmanr(probs, y_true)[0]
                r   = pearsonr(probs, y_true)[0]
                metrics[f"{tag}/spearman"] = rho
                metrics[f"{tag}/pearson"] = r

                rho_binary = spearmanr(y_pred, y_true)[0]
                r_binary   = pearsonr(y_pred, y_true)[0]
                metrics[f"{tag}/spearman_binary"] = rho_binary
                metrics[f"{tag}/pearson_binary"] = r_binary
                metrics["step"] = step

                wandb.log(metrics)
                log_str = f"[step {step}] {tag}: ρ={rho:.4f}  r={r:.4f}, binary  ρ={rho_binary:.4f}  r={r_binary:.4f}, "
                if is_binary:
                    log_str += f"  F1={f1:.4f}"
                print(log_str)

        model.train()

    criterion = nn.BCEWithLogitsLoss()

    global_step = 0
    try:
        eval_all(global_step)
        for ep in range(1, epochs + 1):
            model.train()
            epoch_loss, n_steps = 0.0, 0
            for docs, sums in tqdm(loader, desc=f"Epoch {ep}"):
                if len(docs) < 2:
                    # Rolling a 1-pair batch leaves it unchanged, so its
                    # "negative" would be the positive pair labelled 0.
                    continue

                # Encode every text once; the vectors are reused for both the
                # positive and the rolled negative pairs.
                d_vecs = [doc_vector(d_txt) for d_txt in docs]
                s_vecs = [summary_vector(s_txt) for s_txt in sums]
                rolled_s_vecs = s_vecs[1:] + s_vecs[:1]   # batch roll -> negatives

                pos_logits = [pair_logit(d, s) for d, s in zip(d_vecs, s_vecs)]
                neg_logits = [pair_logit(d, s) for d, s in zip(d_vecs, rolled_s_vecs)]

                logits = torch.cat([torch.stack(pos_logits),
                                    torch.stack(neg_logits)])            # [2B]
                labels = torch.cat([torch.ones(len(pos_logits), device=device),
                                    torch.zeros(len(neg_logits), device=device)])

                loss = criterion(logits, labels)

                optim.zero_grad()
                loss.backward()
                optim.step()
                epoch_loss += loss.item()
                n_steps += 1
                global_step += 1
                if global_step % eval_every == 0:
                    eval_all(global_step)

            avg = epoch_loss / max(n_steps, 1)
            wandb.log({"epoch": ep, "train_loss": avg, "step": global_step})
            print(f"Epoch {ep}  avg-loss = {avg:.4f}")
    finally:
        wandb.finish()

def main():
    """Script entry point: parse CLI flags, seed RNGs, load eval sets and train.

    The evaluation sets are SummEval (continuous consistency scores) plus FactCC
    and QAGS (binary labels). They are handed to ``train_and_eval_multi`` on the
    module-level ``device``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--mode", type=str, default="sat")
    parser.add_argument("--gate", type=str, default="none")
    parser.add_argument('--K', type=int, default=4, help='Number of frequency bins for STC')
    parser.add_argument("--proj_dim", type=int, default=512)
    # parse_known_args ignores unrecognised flags instead of exiting with an error.
    args, _ = parser.parse_known_args()

    def set_seed(seed: int = 42):
        """Seed Python, NumPy and torch (CPU + all GPUs) RNGs; make cuDNN deterministic."""
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        # Trade cuDNN autotuning speed for run-to-run reproducibility.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    set_seed(42)

    eval_dict = {
        "SummEval": load_summeval(),                                                           # continuous
        "FactCC":   load_factcc(),                                                             # binary
        "QAGS":     load_qags(["qags/data/mturk_cnndm.jsonl", "qags/data/mturk_xsum.jsonl"]),  # binary
    }
    print({k: v.shape for k, v in eval_dict.items()})

    train_and_eval_multi(args, eval_dict, device=device)


if __name__ == "__main__":
    main()

