import torch
from torch.nn import functional as F
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from peft import get_peft_model, LoraConfig, TaskType
import lightning.pytorch as pl
from typing import List
from tqdm import tqdm
from itertools import batched
import os
import re
import spacy

from .extractor import MainObjectExtractor

class LLMRefiner(object):
    """Mixin that batches calls to ``self.generate`` and checkpoints results to CSV.

    Classes mixing this in must provide ``generate(batch, **kwargs)`` returning
    one refined string per element of ``batch``.
    """

    def generate_and_save_to_csv(self, plain_texts: List[str], csv_file: str, batch_size: int = 32, **kwargs):
        """Refine ``plain_texts`` in batches, appending each batch to ``csv_file``.

        Rows are written after every batch so partial progress survives an
        interruption. A missing ``csv_file`` is created (together with any
        missing parent directories) and receives an ``input_text,refined_text``
        header. An existing file is appended to without a new header.

        NOTE: ``kwargs`` are forwarded unchanged to every ``generate`` call.
        Per-item arguments (e.g. a list aligned with ``plain_texts``) are NOT
        split per batch.

        Args:
            plain_texts: texts to refine; must support ``len()``.
            csv_file: destination CSV path (a bare filename is allowed).
            batch_size: number of texts per ``generate`` call; must be >= 1.
            **kwargs: extra keyword arguments passed to ``generate``.

        Returns:
            All refined texts, in input order.

        Raises:
            ValueError: if ``batch_size`` is less than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        num_batches = (len(plain_texts) + batch_size - 1) // batch_size
        results = []
        for batch in tqdm(batched(plain_texts, batch_size), desc="Generating texts", ncols=80, total=num_batches):
            batch_results = self.generate(batch, **kwargs)
            results.extend(batch_results)
            self._append_rows_to_csv(csv_file, zip(batch, batch_results))
        return results

    @staticmethod
    def _append_rows_to_csv(csv_file: str, rows) -> None:
        """Append ``rows`` to ``csv_file``, creating it with a header if it does not exist."""
        import csv

        parent_dir = os.path.dirname(csv_file)
        # dirname() is '' for a bare filename, and os.makedirs('') would raise.
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        write_header = not os.path.exists(csv_file)
        # Opening a missing file in 'a' mode creates it, so one mode covers both cases.
        with open(csv_file, 'a', newline='', encoding='utf-8') as f:
            csv_writer = csv.writer(f)
            if write_header:
                csv_writer.writerow(['input_text', 'refined_text'])
            csv_writer.writerows(rows)


class MistralRefiner(pl.LightningModule, LLMRefiner):
    """Causal-LM refiner that rewrites input text into a short description.

    The base model is loaded in 8-bit. When ``use_lora`` is set, it is wrapped
    with a LoRA adapter on ``q_proj``/``v_proj``. Training and generation share
    an ``### Instruction`` / ``### Input`` / ``### Response`` prompt template.

    Args:
        model_name: Hugging Face model id or path, used for tokenizer and model.
        lr: AdamW learning rate.
        max_new_tokens: maximum number of tokens sampled per response.
        use_lora: whether to wrap the model with a LoRA adapter.
    """

    # Lower-cased line prefixes that mark a line as commentary rather than the description.
    _EXPLANATORY_PREFIXES = (
        "here's", "here is", "this is", "the image shows", "i can see",
        "based on", "according to", "in this", "this shows", "looking at",
        "i would describe", "i can describe", "the description is",
        "a short description", "a brief description", "description:",
    )
    # Hedging phrases removed (case-sensitively) from lines that are kept.
    _HEDGE_PHRASES = ("I think", "I believe", "It appears", "It seems", "Looks like")

    def __init__(self, model_name, lr=1e-5, max_new_tokens=128, use_lora=True):
        super().__init__()
        self.save_hyperparameters()
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        # Reuse EOS as the padding token so prompts can be batched.
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.max_new_tokens = max_new_tokens

        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            load_in_8bit=True,
            device_map="auto",
            trust_remote_code=True
        )
        if use_lora:
            peft_config = LoraConfig(
                r=8,
                lora_alpha=16,
                lora_dropout=0.1,
                bias="none",
                task_type=TaskType.CAUSAL_LM,
                target_modules=["q_proj", "v_proj"]
            )
            self.model = get_peft_model(self.model, peft_config)

    def forward(self, input_ids, attention_mask, labels=None):
        """Run the wrapped causal LM and return its output object (``training_step`` reads ``.loss``)."""
        return self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )

    @staticmethod
    def _format_prompt(instruction, inp, response=""):
        """Render the shared prompt template; ``response`` is empty at inference time."""
        return f"### Instruction:\n{instruction}\n\n### Input:\n{inp}\n\n### Response:\n{response}"

    def training_step(self, batch, batch_idx):
        """Return the causal-LM loss for one batch.

        ``batch`` must map ``'instruction'``, ``'input'`` and ``'output'`` to
        parallel sequences of strings.
        """
        eos = self.tokenizer.eos_token
        # An explicit EOS teaches the model where a response ends.
        prompts = [
            self._format_prompt(inst, inp, out) + eos
            for inst, inp, out in zip(batch['instruction'], batch['input'], batch['output'])
        ]
        tokenized = self.tokenizer(prompts, return_tensors="pt", padding="max_length", truncation=True, max_length=1024)
        tokenized = {k: v.to(self.device) for k, v in tokenized.items()}
        labels = tokenized['input_ids'].clone()
        # pad_token == eos_token, so padding is identified via the attention mask
        # (not the token id); -100 is the label value ignored by the HF LM loss.
        labels[tokenized['attention_mask'] == 0] = -100
        tokenized['labels'] = labels
        loss = self(**tokenized).loss
        self.log("train_loss", loss, prog_bar=True, logger=True, rank_zero_only=True)
        return loss

    def configure_optimizers(self):
        """Return an AdamW optimizer over the wrapped model's parameters."""
        return torch.optim.AdamW(self.model.parameters(), lr=self.hparams.lr)

    def state_dict(self, *args, **kwargs):
        """Return only the tensors of parameters that require gradients.

        ``Module.state_dict`` detaches its values unless ``keep_vars=True``, so
        ``requires_grad`` cannot be read from them (it is always False).
        Trainability is therefore looked up on the live parameters instead.
        """
        full_state = super().state_dict(*args, **kwargs)
        prefix = kwargs.get("prefix", "")
        trainable = {prefix + name for name, param in self.named_parameters() if param.requires_grad}
        return {k: v for k, v in full_state.items() if k in trainable}

    def load_checkpoint(self, checkpoint_path):
        """Load the ``'state_dict'`` entry of a checkpoint file into this module.

        Loading is non-strict because ``state_dict`` stores only trainable
        tensors. Unexpected keys are reported so a mismatched checkpoint does
        not load silently.
        """
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        result = self.load_state_dict(checkpoint['state_dict'], strict=False)
        if result.unexpected_keys:
            print(f"Warning: ignored {len(result.unexpected_keys)} unexpected keys from {checkpoint_path}")
        print(f"Checkpoint loaded from {checkpoint_path}")

    @torch.no_grad()
    def _sample_responses(self, prompts, temperature, num_return_sequences=1):
        """Sample continuations for ``prompts`` and return the post-processed responses.

        Only the newly generated tokens are decoded. This assumes (HF
        decoder-only behaviour) that ``generate`` returns the padded prompt
        followed by the continuation, so slicing at the prompt length removes
        the prompt without requiring the decoded text to reproduce it exactly.

        Returns:
            A flat list of ``len(prompts) * num_return_sequences`` strings.
        """
        tokenized = self.tokenizer(prompts, return_tensors="pt", padding=True, truncation=True)
        tokenized = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v for k, v in tokenized.items()}
        output_ids = self.model.generate(
            input_ids=tokenized["input_ids"],
            attention_mask=tokenized["attention_mask"],
            max_new_tokens=self.max_new_tokens,
            temperature=temperature,
            do_sample=True,
            num_return_sequences=num_return_sequences
        )
        prompt_len = tokenized["input_ids"].shape[1]
        decoded = self.tokenizer.batch_decode(output_ids[:, prompt_len:], skip_special_tokens=True)
        return [self._postprocess_response(text.strip()) for text in decoded]

    @torch.no_grad()
    def generate_multi_sequence(self, inputs, temperature=0.7, instruction="Convert to short description for text-to-image generation", num_return_sequences=2, **kwargs):
        """Sample several refinements of a single input text.

        Args:
            inputs: the text to refine (one string, interpolated into the prompt as-is).
            temperature: sampling temperature.
            instruction: text placed under ``### Instruction``.
            num_return_sequences: number of independent samples to draw.
            **kwargs: accepted for interface compatibility; ignored.

        Returns:
            List[str] of ``num_return_sequences`` post-processed responses.
        """
        prompt = self._format_prompt(instruction, inputs)
        return self._sample_responses([prompt], temperature, num_return_sequences)

    @torch.no_grad()
    def generate(self, inputs, temperature=0.7, instruction="Convert to short description for text-to-image generation", **kwargs):
        """
        Args:
            inputs: str or List[str]
            temperature: sampling temperature
            instruction: text placed under ``### Instruction``
            **kwargs: accepted for interface compatibility; ignored
        Returns:
            List[str] or str (same type as input); an empty list for empty input
        """
        single_input = isinstance(inputs, str)
        texts = [inputs] if single_input else list(inputs)
        if not texts:
            return []
        prompts = [self._format_prompt(instruction, text) for text in texts]
        responses = self._sample_responses(prompts, temperature)
        return responses[0] if single_input else responses

    def _postprocess_response(self, text, max_chars=200):
        """Reduce a raw model response to a single short description line.

        Picks the first non-empty line that does not start with an explanatory
        prefix (case-insensitive), after removing hedging phrases and trimming
        surrounding punctuation. If no line qualifies, it falls back to the
        first non-empty line (or ``""``). The result is cut after the first
        period within the first ``max_chars`` characters, or hard-truncated to
        ``max_chars``.
        """
        non_empty = [line.strip() for line in text.split('\n') if line.strip()]
        for line in non_empty:
            if line.lower().startswith(self._EXPLANATORY_PREFIXES):
                continue
            for phrase in self._HEDGE_PHRASES:
                line = line.replace(phrase, "").strip()
            line = line.strip('.,!?; ')
            if line:
                result = line
                break
        else:
            # No usable line: fall back to the first non-empty one.
            result = non_empty[0] if non_empty else ""

        if len(result) > max_chars:
            period = result.find('.', 0, max_chars)
            result = result[:period + 1] if period != -1 else result[:max_chars].strip()
        return result
    
class MistralRefinerAugmentByMainObject(MistralRefiner):
    """Refiner whose prompts carry an extra ``### Main Objects`` section.

    Subclasses supply the objects by implementing
    ``_extract_main_objects_with_model`` (one entry per input text).
    """

    def _extract_main_objects_with_model(self, prompts):
        """Return one main-object description per prompt; must be overridden."""
        raise NotImplementedError("This method should be implemented in subclasses")

    @staticmethod
    def _format_object_prompt(instruction, inp, objects, response=""):
        """Render the main-object prompt template; ``response`` is empty at inference time."""
        return (
            f"### Instruction:\n{instruction}\n\n### Input:\n{inp}\n\n"
            f"### Main Objects:\n{objects}\n\n### Response:\n{response}"
        )

    def training_step(self, batch, batch_idx):
        """Return the causal-LM loss for one batch, with extracted main objects in the prompt.

        ``batch`` must map ``'instruction'``, ``'input'`` and ``'output'`` to
        parallel sequences. The prompt layout is the same one ``generate`` uses,
        including the blank line before ``### Main Objects:``.
        """
        inputs = batch['input']
        main_objects = self._extract_main_objects_with_model(inputs)
        eos = self.tokenizer.eos_token
        # An explicit EOS teaches the model where a response ends.
        prompts = [
            self._format_object_prompt(inst, inp, obj, out) + eos
            for inst, inp, obj, out in zip(batch['instruction'], inputs, main_objects, batch['output'])
        ]
        tokenized = self.tokenizer(prompts, return_tensors="pt", padding="max_length", truncation=True, max_length=1024)
        tokenized = {k: v.to(self.device) for k, v in tokenized.items()}
        labels = tokenized['input_ids'].clone()
        # pad_token == eos_token, so padding is identified via the attention mask;
        # -100 is the label value ignored by the HF LM loss.
        labels[tokenized['attention_mask'] == 0] = -100
        tokenized['labels'] = labels
        loss = self(**tokenized).loss
        self.log("train_loss", loss, prog_bar=True, logger=True, rank_zero_only=True)
        return loss

    @torch.no_grad()
    def generate(self, inputs, temperature=0.7, instruction="Convert to short description for text-to-image generation", **kwargs):
        """
        Args:
            inputs: str or List[str]
            temperature: sampling temperature
            instruction: text placed under ``### Instruction``
            **kwargs: accepted for interface compatibility; ignored
        Returns:
            List[str] or str (same type as input); an empty list for empty input
        """
        single_input = isinstance(inputs, str)
        texts = [inputs] if single_input else list(inputs)
        if not texts:
            return []

        main_objects = self._extract_main_objects_with_model(texts)
        prompts = [
            self._format_object_prompt(instruction, text, obj)
            for text, obj in zip(texts, main_objects)
        ]

        tokenized = self.tokenizer(prompts, return_tensors="pt", padding=True, truncation=True)
        tokenized = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v for k, v in tokenized.items()}
        output_ids = self.model.generate(
            input_ids=tokenized["input_ids"],
            attention_mask=tokenized["attention_mask"],
            max_new_tokens=self.max_new_tokens,
            temperature=temperature,
            do_sample=True
        )

        # Decode only the generated tokens: slicing at the padded prompt length
        # drops the prompt even if decoding would not reproduce it verbatim.
        prompt_len = tokenized["input_ids"].shape[1]
        decoded = self.tokenizer.batch_decode(output_ids[:, prompt_len:], skip_special_tokens=True)
        responses = [self._postprocess_response(text.strip()) for text in decoded]
        return responses[0] if single_input else responses

class MistralRefinerwithClassname(MistralRefinerAugmentByMainObject):
    """Main-object refiner whose objects are given explicitly as class names.

    No extractor model is used. Training batches must carry a ``'main_object'``
    entry, and ``generate`` takes the objects through the ``classnames``
    keyword argument.
    """

    def __init__(
            self,
            model_name,
            extra_model_name="OpenGVLab/InternVL3-2B",
            extractor_dtype=torch.bfloat16,
            lr=0.00001,
            max_new_tokens=128,
            use_lora=True):
        # extra_model_name / extractor_dtype are accepted but unused (no extractor
        # is built); they keep the signature aligned with MistralRefinerwithMLLM.
        super().__init__(model_name, lr, max_new_tokens, use_lora)

    def training_step(self, batch, batch_idx):
        """Return the causal-LM loss for one batch, using ``batch['main_object']`` as the objects.

        ``batch`` must map ``'instruction'``, ``'input'``, ``'output'`` and
        ``'main_object'`` to parallel sequences. The prompt layout is the same
        one ``generate`` uses, including the blank line before ``### Main Objects:``.
        """
        eos = self.tokenizer.eos_token
        # An explicit EOS teaches the model where a response ends.
        prompts = [
            f"### Instruction:\n{inst}\n\n### Input:\n{inp}\n\n"
            f"### Main Objects:\n{obj}\n\n### Response:\n{out}{eos}"
            for inst, inp, obj, out in zip(batch['instruction'], batch['input'], batch['main_object'], batch['output'])
        ]
        tokenized = self.tokenizer(prompts, return_tensors="pt", padding="max_length", truncation=True, max_length=1024)
        tokenized = {k: v.to(self.device) for k, v in tokenized.items()}
        labels = tokenized['input_ids'].clone()
        # pad_token == eos_token, so padding is identified via the attention mask;
        # -100 is the label value ignored by the HF LM loss.
        labels[tokenized['attention_mask'] == 0] = -100
        tokenized['labels'] = labels
        loss = self(**tokenized).loss
        self.log("train_loss", loss, prog_bar=True, logger=True, rank_zero_only=True)
        return loss

    @torch.no_grad()
    def generate(self, inputs, temperature=0.7, instruction="Convert to short description for text-to-image generation", **kwargs):
        """
        Args:
            inputs: str or List[str]
            temperature: sampling temperature
            instruction: text placed under ``### Instruction``
            classnames (keyword, required): one main-object string per input,
                or a single string applied to every input
        Returns:
            List[str] or str (same type as input); an empty list for empty input
        Raises:
            ValueError: if ``classnames`` is missing or its length differs from the inputs
        """
        main_objects = kwargs.get("classnames", None)
        if main_objects is None:
            raise ValueError("Please provide 'classnames' argument for main objects.")

        single_input = isinstance(inputs, str)
        texts = [inputs] if single_input else list(inputs)
        if isinstance(main_objects, str):
            # A bare string would otherwise be zipped character by character.
            main_objects = [main_objects] * len(texts)
        else:
            main_objects = list(main_objects)
        if len(main_objects) != len(texts):
            raise ValueError(
                f"Expected one classname per input ({len(texts)}), got {len(main_objects)}."
            )
        if not texts:
            return []

        prompts = [
            f"### Instruction:\n{instruction}\n\n### Input:\n{text}\n\n### Main Objects:\n{obj}\n\n### Response:\n"
            for text, obj in zip(texts, main_objects)
        ]

        tokenized = self.tokenizer(prompts, return_tensors="pt", padding=True, truncation=True)
        tokenized = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v for k, v in tokenized.items()}
        output_ids = self.model.generate(
            input_ids=tokenized["input_ids"],
            attention_mask=tokenized["attention_mask"],
            max_new_tokens=self.max_new_tokens,
            temperature=temperature,
            do_sample=True
        )

        # Decode only the generated tokens: slicing at the padded prompt length
        # drops the prompt even if decoding would not reproduce it verbatim.
        prompt_len = tokenized["input_ids"].shape[1]
        decoded = self.tokenizer.batch_decode(output_ids[:, prompt_len:], skip_special_tokens=True)
        responses = [self._postprocess_response(text.strip()) for text in decoded]
        return responses[0] if single_input else responses

class MistralRefinerwithMLLM(MistralRefinerAugmentByMainObject):
    """Main-object refiner that obtains the objects from a ``MainObjectExtractor``.

    The extractor is constructed from ``extra_model_name``, ``rf_model_name``
    and ``extractor_dtype``. It is moved alongside this module in ``setup``
    and ``to``.
    """

    def __init__(
            self,
            model_name,
            extra_model_name="OpenGVLab/InternVL3-2B",
            rf_model_name="XCLIU/instaflow_0_9B_from_sd_1_5",
            extractor_dtype=torch.bfloat16,
            lr=0.00001,
            max_new_tokens=128,
            use_lora=True):
        super().__init__(model_name, lr=lr, max_new_tokens=max_new_tokens, use_lora=use_lora)
        extractor = MainObjectExtractor(extra_model_name, rf_model_name, extractor_dtype)
        self.extractor = extractor

    def _extract_main_objects_with_model(self, prompts):
        """Delegate main-object extraction for ``prompts`` to the extractor."""
        extractor = self.extractor
        return extractor(prompts)

    def setup(self, stage=None):
        """Lightning hook: silence transformers logging below ERROR and move the extractor to this module's device."""
        transformers.logging.set_verbosity_error()
        target_device = self.device
        self.extractor.to(target_device)

    def to(self, *args, **kwargs):
        """Move this module, then move the extractor with the same arguments.

        NOTE(review): if MainObjectExtractor is an nn.Module, the parent call
        already moves it; the explicit call covers the case where it is not.
        Confirm against the extractor's definition.
        """
        super().to(*args, **kwargs)
        extractor = self.extractor
        extractor.to(*args, **kwargs)
        return self

