import torch
import torch.nn.functional as F
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from sentence_transformers import SentenceTransformer
import numpy as np
from typing import List, Union
import warnings
import os
import json
import argparse
# Suppress every Python warning for the whole process (no category or module
# filter is given). NOTE(review): this also hides any warnings emitted by
# transformers / torch / sentence_transformers — consider narrowing the filter.
warnings.filterwarnings("ignore")

def load_data(data_dir):
    """Collect ``introduction_results.json`` payloads from run folders under *data_dir*.

    Only sub-entries whose name starts with ``'20'`` are considered (presumably
    date-stamped run folders — confirm against the producer). A folder is kept
    when it contains an ``introduction_results.json`` file whose top-level JSON
    object has both a ``'sections'`` and an ``'introduction'`` key.

    Args:
        data_dir: Directory to scan. Works with or without a trailing separator.

    Returns:
        Dict mapping folder name to its parsed JSON object, in sorted folder order.
    """
    all_data = {}
    # Sorted so the result (and anything serialized from it) is deterministic.
    for folder in sorted(os.listdir(data_dir)):
        if not folder.startswith('20'):
            continue
        # os.path.join tolerates data_dir with or without a trailing slash.
        path = os.path.join(data_dir, folder, 'introduction_results.json')
        if not os.path.isfile(path):
            continue
        with open(path, 'r', encoding='utf-8') as fp:
            data = json.load(fp)
        # Guard against non-object JSON (e.g. a list), which has no keys to check.
        if isinstance(data, dict) and 'sections' in data and 'introduction' in data:
            all_data[folder] = data
    return all_data

# (model_name, device) -> (model, tokenizer). Lets repeated perplexity calls
# reuse already-loaded GPT-2 weights instead of reloading them from disk.
_GPT2_CACHE = {}


def _load_gpt2(model_name, device):
    """Return a cached, eval-mode ``(model, tokenizer)`` pair for *model_name* on *device*."""
    key = (model_name, str(device))
    if key not in _GPT2_CACHE:
        model = GPT2LMHeadModel.from_pretrained(model_name).to(device)
        model.eval()
        tokenizer = GPT2Tokenizer.from_pretrained(model_name)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        _GPT2_CACHE[key] = (model, tokenizer)
    return _GPT2_CACHE[key]


def calculate_gpt2_perplexity(text: str, model_name: str = "gpt2", max_length: int = 2048) -> float:
    """Compute the GPT-2 perplexity of *text* (exp of the mean next-token loss).

    Runs on CUDA when available. The model/tokenizer pair is loaded once per
    ``(model_name, device)`` and reused across calls.

    Args:
        text: Text to score.
        model_name: Hugging Face model id or local path for a GPT-2 checkpoint.
        max_length: Maximum number of tokens to score. Further capped at the
            model's ``n_positions``; longer inputs are truncated, so only their
            leading tokens are scored (see calculate_long_text_perplexity).

    Returns:
        The perplexity, or ``float('inf')`` when the text has fewer than two
        tokens (no next-token prediction to score) or when scoring fails.

    Raises:
        RuntimeError: Non-CUDA runtime errors are re-raised. CUDA-related
            RuntimeErrors fall back to _calculate_perplexity_cpu.
    """
    try:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model, tokenizer = _load_gpt2(model_name, device)

        # Position embeddings only cover n_positions tokens; never exceed them.
        model_max_length = getattr(model.config, 'n_positions', 2048)
        effective_max_length = min(max_length, model_max_length)

        input_ids = tokenizer(
            text,
            return_tensors="pt",
            max_length=effective_max_length,
            truncation=True,
            padding=False,
        ).input_ids.to(device)

        # labels are shifted by one inside the model, so a single token leaves
        # nothing to predict and the mean loss would be NaN.
        if input_ids.size(1) < 2:
            return float('inf')

        with torch.no_grad():
            loss = model(input_ids, labels=input_ids).loss
        return torch.exp(loss).item()

    except RuntimeError as e:
        if "CUDA" in str(e):
            return _calculate_perplexity_cpu(text, model_name, max_length)
        raise
    except Exception:
        # Best-effort scoring: callers treat inf as "could not score".
        return float('inf')


def _calculate_perplexity_cpu(text: str, model_name: str = "/home/mczhang/zmc-dl/LLM/gpt2", max_length: int = 2048) -> float:
    """Compute GPT-2 perplexity of *text* with everything pinned to the CPU.

    This is the fallback calculate_gpt2_perplexity uses when its device path
    raises a CUDA-related RuntimeError. The model and tokenizer are loaded
    fresh on every call.

    Args:
        text: Text to score.
        model_name: Hugging Face model id or local path for a GPT-2 checkpoint.
        max_length: Token cap, further limited by the model's ``n_positions``.

    Returns:
        ``exp(mean next-token loss)``, or ``float('inf')`` when the text
        tokenizes to nothing or when any step raises.
    """
    try:
        cpu = torch.device("cpu")
        lm = GPT2LMHeadModel.from_pretrained(model_name).to(cpu)
        tok = GPT2Tokenizer.from_pretrained(model_name)

        if tok.pad_token is None:
            tok.pad_token = tok.eos_token

        # Stay within the positional capacity reported by the model config.
        limit = min(max_length, getattr(lm.config, 'n_positions', 2048))
        batch = tok(text, return_tensors="pt", max_length=limit, truncation=True, padding=False)
        ids = batch.input_ids.to(cpu)

        if ids.size(1) == 0:
            return float('inf')

        lm.eval()
        with torch.no_grad():
            nll = lm(ids, labels=ids).loss
            ppl = torch.exp(nll)
        return ppl.item()

    except Exception:
        return float('inf')


def calculate_sentence_similarity(sentence1: str, sentence2: str,
                                model_name: str = "/home/mczhang/zmc-dl/LLM/NTP/SBERT") -> float:
    """Return the cosine similarity of two sentences' SentenceTransformer embeddings.

    A new SentenceTransformer is constructed from *model_name* on every call.
    Both sentences are encoded in a single ``encode`` call.
    """
    encoder = SentenceTransformer(model_name)
    vec_a, vec_b = encoder.encode([sentence1, sentence2])

    numerator = np.dot(vec_a, vec_b)
    denominator = np.linalg.norm(vec_a) * np.linalg.norm(vec_b)
    return float(numerator / denominator)


def batch_calculate_perplexity(texts: List[str], model_name: str = "/home/mczhang/zmc-dl/LLM/gpt2", max_length: int = 2048) -> List[float]:
    """Score every text with calculate_gpt2_perplexity, preserving input order.

    Args:
        texts: Texts to score.
        model_name: Forwarded to calculate_gpt2_perplexity.
        max_length: Forwarded to calculate_gpt2_perplexity.

    Returns:
        One perplexity per input text (``float('inf')`` where scoring failed).
    """
    return [calculate_gpt2_perplexity(txt, model_name, max_length) for txt in texts]


def batch_calculate_similarity(sentence_pairs: List[tuple],
                             model_name: str = "/home/mczhang/zmc-dl/LLM/NTP/SBERT") -> List[float]:
    """Return the cosine similarity for each ``(sentence1, sentence2)`` pair.

    The SentenceTransformer is loaded once and all sentences are encoded in a
    single batched ``encode`` call. The per-pair approach reloaded the model
    for every pair.

    Args:
        sentence_pairs: Iterable of 2-item pairs of sentences.
        model_name: Model id or local path passed to SentenceTransformer.

    Returns:
        One cosine similarity per pair, in input order. An empty input returns
        ``[]`` without loading any model.
    """
    pairs = list(sentence_pairs)  # allow any iterable; we traverse it twice
    if not pairs:
        return []

    model = SentenceTransformer(model_name)
    lefts = [first for first, _ in pairs]
    rights = [second for _, second in pairs]

    embeddings = np.asarray(model.encode(lefts + rights))
    left_emb = embeddings[:len(pairs)]
    right_emb = embeddings[len(pairs):]

    # Row-wise cosine similarity between matching left/right embeddings.
    dots = (left_emb * right_emb).sum(axis=1)
    norms = np.linalg.norm(left_emb, axis=1) * np.linalg.norm(right_emb, axis=1)
    return [float(sim) for sim in dots / norms]


def calculate_long_text_perplexity(text: str, model_name: str = "/home/mczhang/zmc-dl/LLM/gpt2",
                                 chunk_size: int = 2048, overlap: int = 50) -> float:
    """Estimate the perplexity of a text that may exceed GPT-2's context window.

    The text is tokenized and split into overlapping windows. Each window has
    at most ``chunk_size`` tokens, further capped at the tokenizer's
    ``model_max_length``. Each window is scored with calculate_gpt2_perplexity.
    The result is the unweighted mean of the finite per-window perplexities.

    Args:
        text: Text to score.
        model_name: Hugging Face model id or local path for a GPT-2 checkpoint.
        chunk_size: Maximum window length in tokens (must be positive).
        overlap: Tokens shared by consecutive windows. Must satisfy
            ``0 <= overlap < chunk_size``.

    Returns:
        Mean per-window perplexity, or ``float('inf')`` if no window could be
        scored or loading/tokenization failed.

    Raises:
        ValueError: If ``chunk_size`` or ``overlap`` is out of range.
    """
    import math

    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if not 0 <= overlap < chunk_size:
        raise ValueError(
            f"overlap must satisfy 0 <= overlap < chunk_size ({chunk_size}), got {overlap}"
        )

    try:
        tokenizer = GPT2Tokenizer.from_pretrained(model_name)

        # A window longer than the model's context would be truncated inside
        # calculate_gpt2_perplexity, silently dropping its tail; cap it here.
        window = min(chunk_size, tokenizer.model_max_length)
        # Cap the overlap at half a window so the step stays positive even when
        # the window was shrunk below the requested overlap.
        step = window - min(overlap, window // 2)

        tokens = tokenizer.encode(text)
        if len(tokens) <= window:
            return calculate_gpt2_perplexity(text, model_name, window)

        perplexities = []
        for start in range(0, len(tokens), step):
            chunk_tokens = tokens[start:start + window]
            # Keep the original 10-token floor; shorter windows are skipped,
            # presumably because their scores are not meaningful.
            if len(chunk_tokens) >= 10:
                chunk_text = tokenizer.decode(chunk_tokens, skip_special_tokens=True)
                chunk_ppl = calculate_gpt2_perplexity(chunk_text, model_name, window)
                # Exclude both inf (scoring failed) and NaN from the average.
                if math.isfinite(chunk_ppl):
                    perplexities.append(chunk_ppl)
            # This window already reached the end of the text; any later window
            # would only rescore tokens that are already covered.
            if start + window >= len(tokens):
                break

        if not perplexities:
            return float('inf')
        return sum(perplexities) / len(perplexities)

    except Exception:
        # Best-effort: callers treat inf as "could not score".
        return float('inf')



if __name__ == "__main__":
    # Score the GPT-2 perplexity of each generated introduction for one
    # generation type, then write {folder: perplexity} to
    # <results_dir>/<generate_type>_ppl.json.
    parser = argparse.ArgumentParser(
        description="Compute GPT-2 perplexity of generated introductions.")
    parser.add_argument("--generate_type", default="stage",
                        help="Sub-directory of --results_dir holding the runs to score.")
    parser.add_argument("--results_dir", default="../writing_agents_results",
                        help="Root directory containing one folder per generation type.")
    parser.add_argument("--model_name", default="gpt2",
                        help="GPT-2 model id or local path used for scoring.")
    args = parser.parse_args()
    print(args.generate_type)

    # Trailing separator kept so this also works with a load_data that builds
    # paths by plain string concatenation.
    ft = load_data(os.path.join(args.results_dir, args.generate_type, ""))
    print(len(ft))

    ppl_results = {}
    for item, record in ft.items():
        ppl = calculate_gpt2_perplexity(record['introduction'], args.model_name)
        print(f"{item} ppl: {ppl}")
        ppl_results[item] = ppl

    # NOTE: failed scores (float('inf')) serialize as the non-standard JSON
    # token `Infinity` (json.dump's default allow_nan=True).
    out_path = os.path.join(args.results_dir, f"{args.generate_type}_ppl.json")
    with open(out_path, 'w', encoding='utf-8') as fw:
        json.dump(ppl_results, fw, ensure_ascii=False, indent=4)



