from typing import Union, List
import torch
import os
from itertools import batched
from tqdm import tqdm
from torch import nn
from lightning.pytorch import LightningModule
from transformers import GPT2LMHeadModel, GPT2TokenizerFast

# Default path of the refiner cache file read by load_cache (tab-separated
# key/value lines), resolved relative to the directory containing this module.
CACHE_FILE = os.path.join(os.path.dirname(__file__), "../../../data/refiner_cache/cache.tsv")

def load_prompt_refiner(pretrained_model_name: str = "microsoft/Promptist", tokenizer_name: str = "gpt2"):
    """Load the prompt-refiner language model and a batch-ready tokenizer.

    Args:
        pretrained_model_name: Hub id or local path of the GPT-2 LM-head model.
        tokenizer_name: Hub id or local path of the tokenizer to pair with the
            model. Defaults to the stock ``gpt2`` tokenizer (previously the
            only option).

    Returns:
        ``(model, tokenizer)`` where the tokenizer pads with its EOS token on
        the left.
    """
    prompter_model = GPT2LMHeadModel.from_pretrained(pretrained_model_name)
    tokenizer = GPT2TokenizerFast.from_pretrained(tokenizer_name)
    # GPT-2's tokenizer defines no pad token; reuse EOS so batches can be padded.
    tokenizer.pad_token = tokenizer.eos_token
    # Decoder-only generation continues from the last position, so prompts must
    # be right-aligned (padding goes on the left).
    tokenizer.padding_side = "left"
    return prompter_model, tokenizer

def load_cache(cache_file: str = CACHE_FILE):
    """Read a tab-separated ``key<TAB>value`` cache file into a dict.

    Args:
        cache_file: Path of the TSV cache. Defaults to ``CACHE_FILE``.

    Returns:
        A dict mapping each key to its value; a later line with the same key
        overwrites an earlier one. Returns an empty dict if the file does not
        exist.

    Raises:
        ValueError: if a non-blank line contains no tab separator.
    """
    if not os.path.exists(cache_file):
        return {}

    cache = {}
    with open(cache_file, 'r', encoding='utf-8') as f:
        for line in f:
            # Skip blank lines instead of failing to unpack them.
            if not line.strip():
                continue
            # Remove only the line terminator. A full strip() would also eat the
            # tab before an empty value, so "key\t" used to crash.
            # maxsplit=1 keeps values that themselves contain tabs intact.
            key, value = line.rstrip('\r\n').split('\t', 1)
            cache[key] = value
    return cache

class RefinerModel(nn.Module):
    def __init__(self, pretrained_model_name: str):
        super(RefinerModel, self).__init__()
        self.model, self.tokenizer = load_prompt_refiner(pretrained_model_name)

    def generate(self, plain_text: Union[str, List[str]], **kwargs) -> Union[str, List[str]]:
        device = self.model.device
        # Handle single text or list of texts
        if isinstance(plain_text, str):
            plain_texts = [plain_text]
            single_input = True
        else:
            plain_texts = plain_text
            single_input = False
        
        # Add "Rephrase:" suffix to all texts
        input_texts = []
        for text in plain_texts:
            if len(text.strip()) >= 2000:
                if '. ' in text:
                    text = text.split('. ')[0]
                    text = text + '.'
                elif '! ' in text:
                    text = text.split('! ')[0]
                    text = text + '!'
                elif ', ' in text:
                    text = text.split(', ')[0]
                    text = text + '.'
                else:
                    text = text[:500]
            if len(text.strip()) >= 2000:
                text = text[:500]
            input_texts.append(text.strip() + " Rephrase:")

        # Batch tokenize
        input_ids = self.tokenizer(input_texts, return_tensors="pt", padding=True, truncation=True).input_ids
        eos_id = self.tokenizer.eos_token_id

        # Move input data to GPU
        input_ids = input_ids.to(device)

        # Batch generation
        outputs = self.model.generate(
            input_ids, 
            do_sample=False, 
            max_new_tokens=75, 
            num_beams=8, 
            num_return_sequences=8, 
            eos_token_id=eos_id, 
            pad_token_id=eos_id, 
            length_penalty=-1.0
        )
        
        # Decode output
        output_texts = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
        # Process generated results
        results = []
        for i, original_text in enumerate(plain_texts):
            # Each input corresponds to 8 outputs (num_return_sequences=8)
            batch_outputs = output_texts[i*8:(i+1)*8]
            # Take the first result and clean it
            res = ''
            j = 0
            while res == '':
                res = batch_outputs[j].replace(original_text + " Rephrase:", "").strip()
                j += 1
                if j >= len(batch_outputs):
                    raise ValueError(f"No valid output found for input: {original_text}")
            
            results.append(res)
        
        # If input is a single string, return a single result
        return results[0] if single_input else results
    
    def generate_and_save_to_csv(self, plain_texts: List[str], csv_file: str, batch_size: int = 32):
        import csv
        results = []
        for batch in tqdm(batched(plain_texts, batch_size), desc="Generating texts", ncols=80, total=(len(plain_texts) + batch_size - 1) // batch_size):
            batch_results = self.generate(batch)
            results.extend(batch_results)
            if os.path.exists(csv_file):
                mode = 'a'
            else:
                mode = 'w'
            if not os.path.exists(os.path.dirname(csv_file)):
                os.makedirs(os.path.dirname(csv_file))
            with open(csv_file, mode, newline='', encoding='utf-8') as f:
                csv_writer = csv.writer(f)
                if mode == 'w':
                    csv_writer.writerow(['input_text', 'refined_text'])
                for input_text, refined_text in zip(batch, batch_results):
                    csv_writer.writerow([input_text, refined_text])
        return results
        

    
