import torch
import torch.nn as nn
from transformers import AutoTokenizer, GPT2LMHeadModel, AutoConfig, AutoModelForCausalLM
from transformers import BloomTokenizerFast, BloomForCausalLM, BloomConfig, BloomModel
from transformers import LlamaForCausalLM, LlamaConfig
from peft import LoraConfig, get_peft_model, TaskType
from transformers import DebertaV2Model, DebertaV2Config


# Imports for Qwen2 and Qwen3
from transformers import Qwen2ForCausalLM, Qwen3ForCausalLM


# models.py
import torch
import torch.nn as nn


# --- CLASSES FOR FINE-TUNING TOP LAYERS ---

class FinetuneLlama3Coder(nn.Module):
    """Pretrained Llama 3 causal LM repurposed as a single-logit classifier.

    The checkpoint is loaded in bfloat16 with ``device_map="auto"``, optionally
    frozen except for its top decoder layers, and its LM head is replaced by a
    fresh ``hidden_size -> 1`` linear layer that is always trainable.

    Args:
        model_name: Hugging Face model id or local path to load.
        num_layers_to_finetune: Number of top decoder layers left trainable.
            ``0`` skips freezing entirely, so every parameter stays trainable.
            Values larger than the number of layers are clamped to it.
    """

    def __init__(self, model_name="meta-llama/Llama-3.2-1B", num_layers_to_finetune=8):
        super().__init__()
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True,
        )
        # Reentrant checkpointing only back-propagates through a checkpointed
        # segment when one of its *inputs* requires grad. With the embeddings
        # and lower layers frozen that is never the case, so the unfrozen top
        # layers would silently receive no gradients. The non-reentrant
        # implementation does not have this restriction.
        self.model.gradient_checkpointing_enable(
            gradient_checkpointing_kwargs={"use_reentrant": False}
        )
        # Only the logits of a single forward pass are consumed, so a KV cache
        # is never needed.
        self.model.config.use_cache = False

        if num_layers_to_finetune != 0:
            # 1. Freeze everything.
            for param in self.model.parameters():
                param.requires_grad = False

            # 2. Unfreeze the top N decoder layers (stored in `model.layers`).
            #    Clamp the start index so an oversized N cannot wrap around
            #    through negative indices or run out of range.
            decoder_layers = self.model.model.layers
            first_trainable = max(0, len(decoder_layers) - num_layers_to_finetune)
            for layer in decoder_layers[first_trainable:]:
                for param in layer.parameters():
                    param.requires_grad = True

        # 3. Replace the head and make sure it is trainable.
        hidden_size = self.model.config.hidden_size
        head = nn.Linear(hidden_size, 1, bias=False)

        # `device_map="auto"` may have dispatched the backbone to an
        # accelerator, whereas a fresh nn.Linear lives on the CPU. Put the head
        # where the final hidden states are produced.
        # NOTE(review): the head keeps nn.Linear's default float32 dtype while
        # the backbone is bfloat16 -- presumably the training loop runs under
        # autocast; confirm.
        backbone = self.model.model
        anchor = getattr(backbone, "norm", None)
        if anchor is None:
            anchor = backbone.layers[-1]
        anchor_param = next(anchor.parameters(), None)
        if anchor_param is not None and anchor_param.device.type != "meta":
            head = head.to(anchor_param.device)

        self.model.lm_head = head
        for param in self.model.lm_head.parameters():
            param.requires_grad = True

    def forward(self, input_ids, attention_mask=None):
        """Run the backbone and return the head output at the last position.

        Args:
            input_ids: Token ids forwarded to the wrapped model.
            attention_mask: Optional attention mask forwarded unchanged.

        Returns:
            The model's ``logits``; when they are 3-D, only the slice at the
            final index of dimension 1 is returned.
        """
        logits = self.model(input_ids=input_ids, attention_mask=attention_mask).logits
        # NOTE(review): with right-padded batches the last position may be a
        # padding token -- confirm the padding side used by the caller.
        if logits.dim() == 3:
            logits = logits[:, -1, :]
        return logits


class FinetuneQwen3Coder(nn.Module):
    """Pretrained Qwen3 causal LM repurposed as a single-logit classifier.

    The checkpoint is loaded in bfloat16 with ``device_map="auto"``, optionally
    frozen except for its top decoder layers, and its LM head is replaced by a
    fresh ``hidden_size -> 1`` linear layer that is always trainable.

    Args:
        model_name: Hugging Face model id or local path to load.
        num_layers_to_finetune: Number of top decoder layers left trainable.
            ``0`` skips freezing entirely, so every parameter stays trainable.
            Values larger than the number of layers are clamped to it.
    """

    def __init__(self, model_name="Qwen/Qwen3-1.7B", num_layers_to_finetune=8):
        super().__init__()
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True,
        )
        # Reentrant checkpointing only back-propagates through a checkpointed
        # segment when one of its *inputs* requires grad. With the embeddings
        # and lower layers frozen that is never the case, so the unfrozen top
        # layers would silently receive no gradients. The non-reentrant
        # implementation does not have this restriction.
        self.model.gradient_checkpointing_enable(
            gradient_checkpointing_kwargs={"use_reentrant": False}
        )
        # Only the logits of a single forward pass are consumed, so a KV cache
        # is never needed.
        self.model.config.use_cache = False

        if num_layers_to_finetune != 0:
            # 1. Freeze everything.
            for param in self.model.parameters():
                param.requires_grad = False

            # 2. Unfreeze the top N decoder layers (stored in `model.layers`).
            #    Clamp the start index so an oversized N cannot wrap around
            #    through negative indices or run out of range.
            decoder_layers = self.model.model.layers
            first_trainable = max(0, len(decoder_layers) - num_layers_to_finetune)
            for layer in decoder_layers[first_trainable:]:
                for param in layer.parameters():
                    param.requires_grad = True

        # 3. Replace the head and make sure it is trainable.
        hidden_size = self.model.config.hidden_size
        head = nn.Linear(hidden_size, 1, bias=False)

        # `device_map="auto"` may have dispatched the backbone to an
        # accelerator, whereas a fresh nn.Linear lives on the CPU. Put the head
        # where the final hidden states are produced.
        # NOTE(review): the head keeps nn.Linear's default float32 dtype while
        # the backbone is bfloat16 -- presumably the training loop runs under
        # autocast; confirm.
        backbone = self.model.model
        anchor = getattr(backbone, "norm", None)
        if anchor is None:
            anchor = backbone.layers[-1]
        anchor_param = next(anchor.parameters(), None)
        if anchor_param is not None and anchor_param.device.type != "meta":
            head = head.to(anchor_param.device)

        self.model.lm_head = head
        for param in self.model.lm_head.parameters():
            param.requires_grad = True

    def forward(self, input_ids, attention_mask=None):
        """Run the backbone and return the head output at the last position.

        Args:
            input_ids: Token ids forwarded to the wrapped model.
            attention_mask: Optional attention mask forwarded unchanged.

        Returns:
            The model's ``logits``; when they are 3-D, only the slice at the
            final index of dimension 1 is returned.
        """
        logits = self.model(input_ids=input_ids, attention_mask=attention_mask).logits
        # NOTE(review): with right-padded batches the last position may be a
        # padding token -- confirm the padding side used by the caller.
        if logits.dim() == 3:
            logits = logits[:, -1, :]
        return logits


class FinetuneDeepSeekCoder(nn.Module):
    """Pretrained DeepSeek Coder causal LM repurposed as a single-logit classifier.

    The checkpoint is loaded in bfloat16 with ``device_map="auto"``, optionally
    frozen except for its top decoder layers, and its LM head is replaced by a
    fresh ``hidden_size -> 1`` linear layer that is always trainable.

    Args:
        model_name: Hugging Face model id or local path to load.
        num_layers_to_finetune: Number of top decoder layers left trainable.
            ``0`` skips freezing entirely, so every parameter stays trainable.
            Values larger than the number of layers are clamped to it.
    """

    def __init__(self, model_name="deepseek-ai/deepseek-coder-1.3b-base", num_layers_to_finetune=2):
        super().__init__()
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True,
        )
        # Reentrant checkpointing only back-propagates through a checkpointed
        # segment when one of its *inputs* requires grad. With the embeddings
        # and lower layers frozen that is never the case, so the unfrozen top
        # layers would silently receive no gradients. The non-reentrant
        # implementation does not have this restriction.
        self.model.gradient_checkpointing_enable(
            gradient_checkpointing_kwargs={"use_reentrant": False}
        )
        # Only the logits of a single forward pass are consumed, so a KV cache
        # is never needed.
        self.model.config.use_cache = False

        if num_layers_to_finetune != 0:
            # 1. Freeze everything.
            for param in self.model.parameters():
                param.requires_grad = False

            # 2. Unfreeze the top N decoder layers (stored in `model.layers`).
            #    Clamp the start index so an oversized N cannot wrap around
            #    through negative indices or run out of range.
            decoder_layers = self.model.model.layers
            first_trainable = max(0, len(decoder_layers) - num_layers_to_finetune)
            for layer in decoder_layers[first_trainable:]:
                for param in layer.parameters():
                    param.requires_grad = True

        # 3. Replace the head and make sure it is trainable.
        hidden_size = self.model.config.hidden_size
        head = nn.Linear(hidden_size, 1, bias=False)

        # `device_map="auto"` may have dispatched the backbone to an
        # accelerator, whereas a fresh nn.Linear lives on the CPU. Put the head
        # where the final hidden states are produced.
        # NOTE(review): the head keeps nn.Linear's default float32 dtype while
        # the backbone is bfloat16 -- presumably the training loop runs under
        # autocast; confirm.
        backbone = self.model.model
        anchor = getattr(backbone, "norm", None)
        if anchor is None:
            anchor = backbone.layers[-1]
        anchor_param = next(anchor.parameters(), None)
        if anchor_param is not None and anchor_param.device.type != "meta":
            head = head.to(anchor_param.device)

        self.model.lm_head = head
        for param in self.model.lm_head.parameters():
            param.requires_grad = True

    def forward(self, input_ids, attention_mask=None):
        """Run the backbone and return the head output at the last position.

        Args:
            input_ids: Token ids forwarded to the wrapped model.
            attention_mask: Optional attention mask forwarded unchanged.

        Returns:
            The model's ``logits``; when they are 3-D, only the slice at the
            final index of dimension 1 is returned.
        """
        logits = self.model(input_ids=input_ids, attention_mask=attention_mask).logits
        # NOTE(review): with right-padded batches the last position may be a
        # padding token -- confirm the padding side used by the caller.
        if logits.dim() == 3:
            logits = logits[:, -1, :]
        return logits


# -----------------------------------------------------------------
# Training from scratch
# -----------------------------------------------------------------

class Qwen3Coder(nn.Module):
    """Qwen3 causal LM instantiated from a config, with a one-output head.

    The configuration published under ``model_name`` is used as a template;
    ``vocab_size``, ``max_position_embeddings`` and ``bos_token_id`` are taken
    from ``config_dict`` and caching is turned off. The model is built from
    that config (no pretrained weights are loaded here) and its LM head is
    swapped for ``nn.Linear(hidden_size, 1, bias=True)``.

    Constructing an instance also writes a debug report to
    ``debug_Qwen3Coder_output.txt`` in the current working directory.
    """

    def __init__(self, model_name, config_dict):
        super().__init__()

        # Template config, then the caller-supplied overrides.
        qwen_config = AutoConfig.from_pretrained(model_name)
        for field in ("vocab_size", "max_position_embeddings", "bos_token_id"):
            setattr(qwen_config, field, config_dict[field])
        qwen_config.use_cache = False

        # Build the network from the config alone.
        self.model = Qwen3ForCausalLM(qwen_config)
        self.model.gradient_checkpointing_enable()

        # One-output head in place of the vocabulary projection.
        self.model.lm_head = nn.Linear(qwen_config.hidden_size, 1, bias=True)

        self._write_debug_report(qwen_config)

    def _write_debug_report(self, config):
        """Dump the config, module tree and trainable parameter names to a file."""
        report_path = f"debug_{self.__class__.__name__}_output.txt"
        print(f"Writing debug info to {report_path}")
        with open(report_path, "w") as report:
            report.write("--- Model Config ---\n")
            report.write(str(config))
            report.write("\n\n--- Named Modules ---\n")
            report.writelines(
                f"{name}: {module.__class__.__name__}\n"
                for name, module in self.model.named_modules()
            )
            report.write("\n\n--- Trainable Parameters ---\n")
            report.writelines(
                f"{name}\n"
                for name, param in self.model.named_parameters()
                if param.requires_grad
            )

    def forward(self, x, attention_mask=None):
        """Return the head output, reduced to the last position for 3-D logits.

        Args:
            x: Token ids, passed to the wrapped model as ``input_ids``.
            attention_mask: Optional mask forwarded unchanged.
        """
        scores = self.model(input_ids=x, attention_mask=attention_mask).logits
        if scores.dim() != 3:
            return scores
        # Keep only the final index along dimension 1.
        return scores[:, -1, :]


class Llama3Coder(nn.Module):
    """Llama causal LM instantiated from a config, with a one-output head.

    The configuration published under ``model_name`` is used as a template;
    ``vocab_size``, ``max_position_embeddings`` and ``bos_token_id`` are taken
    from ``config_dict``, caching is turned off, and every key/value pair in
    ``model_config_overrides`` (if truthy) is then set on the config. The model
    is built from that config (no pretrained weights are loaded here) and its
    LM head is swapped for ``nn.Linear(hidden_size, 1, bias=False)``.

    Constructing an instance also writes a debug report to
    ``debug_Llama3Coder_output.txt`` in the current working directory.
    """

    def __init__(self, model_name, config_dict, model_config_overrides):
        super().__init__()

        # Template config, then the caller-supplied overrides.
        llama_config = AutoConfig.from_pretrained(model_name)
        for field in ("vocab_size", "max_position_embeddings", "bos_token_id"):
            setattr(llama_config, field, config_dict[field])
        llama_config.use_cache = False

        # Optional architectural overrides, applied last so they win.
        for key, value in (model_config_overrides or {}).items():
            setattr(llama_config, key, value)

        # Build the network from the config alone.
        self.model = LlamaForCausalLM(llama_config)
        self.model.gradient_checkpointing_enable()

        # One-output head in place of the vocabulary projection.
        self.model.lm_head = nn.Linear(llama_config.hidden_size, 1, bias=False)

        self._write_debug_report(llama_config)

    def _write_debug_report(self, config):
        """Dump the config, module tree and trainable parameter names to a file."""
        report_path = f"debug_{self.__class__.__name__}_output.txt"
        print(f"Writing debug info to {report_path}")
        with open(report_path, "w") as report:
            report.write("--- Model Config ---\n")
            report.write(str(config))
            report.write("\n\n--- Named Modules ---\n")
            report.writelines(
                f"{name}: {module.__class__.__name__}\n"
                for name, module in self.model.named_modules()
            )
            report.write("\n\n--- Trainable Parameters ---\n")
            report.writelines(
                f"{name}\n"
                for name, param in self.model.named_parameters()
                if param.requires_grad
            )

    def forward(self, x, attention_mask=None):
        """Return the head output, reduced to the last position for 3-D logits.

        Args:
            x: Token ids, passed to the wrapped model as ``input_ids``.
            attention_mask: Optional mask forwarded unchanged.
        """
        scores = self.model(input_ids=x, attention_mask=attention_mask).logits
        if scores.dim() != 3:
            return scores
        # Keep only the final index along dimension 1.
        return scores[:, -1, :]


class DeepSeekCoder(nn.Module):
    """DeepSeek causal LM instantiated from a config, with a one-output head.

    The configuration published under ``model_name`` is used as a template;
    ``vocab_size``, ``max_position_embeddings`` and ``bos_token_id`` are taken
    from ``config_dict`` and caching is turned off. The model is built with
    ``AutoModelForCausalLM.from_config`` (no pretrained weights are loaded
    here) and its LM head is swapped for ``nn.Linear(hidden_size, 1, bias=False)``.

    Constructing an instance also writes a debug report to
    ``debug_DeepSeekCoder_output.txt`` in the current working directory.
    """

    def __init__(self, model_name, config_dict):
        super().__init__()

        # Template config, then the caller-supplied overrides.
        ds_config = AutoConfig.from_pretrained(model_name)
        for field in ("vocab_size", "max_position_embeddings", "bos_token_id"):
            setattr(ds_config, field, config_dict[field])
        ds_config.use_cache = False

        # The Auto class picks the architecture that matches the config and
        # builds it from the config alone.
        self.model = AutoModelForCausalLM.from_config(ds_config)
        self.model.gradient_checkpointing_enable()

        # One-output head in place of the vocabulary projection.
        self.model.lm_head = nn.Linear(ds_config.hidden_size, 1, bias=False)

        self._write_debug_report(ds_config)

    def _write_debug_report(self, config):
        """Dump the config, module tree and trainable parameter names to a file."""
        report_path = f"debug_{self.__class__.__name__}_output.txt"
        print(f"Writing debug info to {report_path}")
        with open(report_path, "w") as report:
            report.write("--- Model Config ---\n")
            report.write(str(config))
            report.write("\n\n--- Named Modules ---\n")
            report.writelines(
                f"{name}: {module.__class__.__name__}\n"
                for name, module in self.model.named_modules()
            )
            report.write("\n\n--- Trainable Parameters ---\n")
            report.writelines(
                f"{name}\n"
                for name, param in self.model.named_parameters()
                if param.requires_grad
            )

    def forward(self, x, attention_mask=None):
        """Return the head output, reduced to the last position for 3-D logits.

        Args:
            x: Token ids, passed to the wrapped model as ``input_ids``.
            attention_mask: Optional mask forwarded unchanged.
        """
        scores = self.model(input_ids=x, attention_mask=attention_mask).logits
        if scores.dim() != 3:
            return scores
        # Keep only the final index along dimension 1.
        return scores[:, -1, :]


class EnsembleGPT2(nn.Module):
    """Ensemble of separately constructed GPT-2 LM-head models.

    Args:
        config: Configuration passed to every ``GPT2LMHeadModel`` member.
        num_models: Number of ensemble members to create.
        ensemble_func: Name of the reduction applied to the members' logits.
            Only ``'mean'`` is implemented.
    """

    def __init__(self, config, num_models, ensemble_func):
        super().__init__()
        self.models = nn.ModuleList([GPT2LMHeadModel(config) for _ in range(num_models)])
        self.ensemble_func = ensemble_func

    def forward(self, x, attention_mask=None):
        """Run every member on ``x`` and combine their logits.

        Args:
            x: Input passed positionally to each member model.
            attention_mask: Optional mask forwarded to each member.

        Returns:
            The element-wise mean of the members' ``logits`` tensors.

        Raises:
            ValueError: If ``ensemble_func`` is not ``'mean'``. Previously this
                surfaced as an ``UnboundLocalError`` at the ``return``.
        """
        # Fail before doing any forward passes if the reduction is unknown.
        if self.ensemble_func != 'mean':
            raise ValueError(
                f"Unsupported ensemble_func {self.ensemble_func!r}; only 'mean' is implemented."
            )

        member_logits = []
        # NOTE: torch.backends.cuda.sdp_kernel is deprecated in recent PyTorch
        # releases in favour of torch.nn.attention.sdpa_kernel; kept here so
        # the module still runs on the versions it was written against.
        with torch.backends.cuda.sdp_kernel(enable_flash=True):
            for model in self.models:
                output = model(x, attention_mask=attention_mask)
                member_logits.append(output.logits)

        return torch.mean(torch.stack(member_logits), dim=0)


class EnsembleBLOOM(nn.Module):
    """Ensemble of ``BloomModel`` backbones sharing one linear classifier.

    Each member's hidden state at the last sequence index is taken, the
    members are averaged, and the result is passed through a single
    ``hidden_size -> 1`` linear layer.

    Args:
        config: Configuration passed to every ``BloomModel`` member; its
            ``hidden_size`` sizes the classifier input.
        num_models: Number of ensemble members to create.
        ensemble_func: Stored on the instance; ``forward`` always averages.
    """

    def __init__(self, config, num_models, ensemble_func):
        super().__init__()
        # Headless backbones (BloomModel rather than BloomForCausalLM).
        members = [BloomModel(config) for _ in range(num_models)]
        self.models = nn.ModuleList(members)
        self.ensemble_func = ensemble_func
        # Single-output classification head.
        self.classifier = nn.Linear(config.hidden_size, 1)
        self.print_model_details()

    def print_model_details(self):
        """Print the member model type and the ensemble size."""
        first_member = self.models[0]
        print(f"Model Type: {type(first_member)}")
        print(f"Number of Models in Ensemble: {len(self.models)}")

    def forward(self, x, attention_mask=None):
        """Return the classifier output for the member-averaged final state.

        Args:
            x: Token ids, passed to each member as ``input_ids``.
            attention_mask: Optional mask forwarded to each member.
        """
        # One tensor per member: the hidden state at the final index of dim 1.
        final_states = [
            member(
                input_ids=x, attention_mask=attention_mask, return_dict=True
            ).last_hidden_state[:, -1, :]
            for member in self.models
        ]
        # Average across members, then project to one logit.
        pooled = torch.stack(final_states, dim=0).mean(dim=0)
        return self.classifier(pooled)


def get_model(settings):
    """Build the model selected by ``settings.model``.

    Supported values of ``settings.model``:

    * ``"gpt2"`` / ``"bloom"`` -- ensembles built from a fresh config.
    * ``"llama3"``, ``"deepseek"``, ``"qwen1.7B"``, ``"qwen0.6B"`` -- models
      instantiated from a published config with overridden vocabulary size,
      context length and BOS token.
    * ``"llama3_finetune"``, ``"qwen3_finetune"``, ``"deepseek_finetune"`` --
      pretrained checkpoints with only the top layers left trainable.

    Args:
        settings: Object exposing ``model`` plus, depending on the branch,
            ``vocab_size``, ``context_length``, ``BOS_TOKEN``, ``n_layer``,
            ``n_head``, ``n_embd``, ``num_models``, ``ensemble_func``,
            ``num_layers_to_finetune`` and ``device``.

    Returns:
        The constructed model. Models whose name does not contain
        ``"finetune"`` are moved to ``settings.device``; the fine-tuning
        wrappers are left where they were loaded.

    Raises:
        NotImplementedError: For ``"qwen2"``.
        Exception: For ``"qwen1.5B"`` or an unrecognised model name.
    """

    def scratch_overrides():
        # Built lazily so branches that do not need it never touch these
        # settings attributes.
        return {
            "vocab_size": settings.vocab_size,
            "max_position_embeddings": settings.context_length,
            "bos_token_id": settings.BOS_TOKEN,
        }

    if settings.model == "gpt2":
        config = AutoConfig.from_pretrained(
            "gpt2",
            vocab_size=settings.vocab_size,
            # `n_positions` is what sizes GPT-2's position embeddings; `n_ctx`
            # alone is not read by current GPT-2 implementations, which left
            # the context length at the checkpoint default. `n_ctx` is still
            # passed for older library versions.
            # NOTE: this changes the position-embedding shape whenever
            # context_length differs from the "gpt2" default.
            n_positions=settings.context_length,
            n_ctx=settings.context_length,
            bos_token_id=settings.BOS_TOKEN,
            n_layer=settings.n_layer,
            n_head=settings.n_head,
            n_embd=settings.n_embd,
        )
        model = EnsembleGPT2(config, settings.num_models, settings.ensemble_func)
    elif settings.model == "bloom":
        config = BloomConfig(
            vocab_size=settings.vocab_size,
            n_ctx=settings.context_length,
            bos_token_id=settings.BOS_TOKEN,
            n_layer=settings.n_layer,
            n_head=settings.n_head,
            hidden_size=settings.n_embd,
            attention_dropout=0.0,
        )
        model = EnsembleBLOOM(config, settings.num_models, settings.ensemble_func)
    elif settings.model == "llama3":
        # A published Llama 3 config serves as the blueprint; the model itself
        # is built from the config rather than from pretrained weights.
        model = Llama3Coder(
            model_name="meta-llama/Llama-3.2-1B",
            config_dict=scratch_overrides(),
            model_config_overrides=None,
        )
    elif settings.model == "deepseek":
        # The official 1.3B coder config serves as the blueprint.
        model = DeepSeekCoder(
            model_name="deepseek-ai/deepseek-coder-1.3b-base",
            config_dict=scratch_overrides(),
        )
    elif settings.model == "qwen2":
        raise NotImplementedError("qwen 2 Coder implementation using text strategy not shown here.")
    elif settings.model in ["qwen1.7B", "qwen1.5B", 'qwen0.6B']:
        if settings.model == "qwen1.5B":
            raise Exception("No longer support Qwen2")
        elif settings.model == "qwen0.6B":
            model_name = "Qwen/Qwen3-0.6B"
        else:
            model_name = "Qwen/Qwen3-1.7B"  # Using base model for config
        model = Qwen3Coder(model_name=model_name, config_dict=scratch_overrides())
    elif settings.model == "llama3_finetune":
        model = FinetuneLlama3Coder(
            model_name="meta-llama/Llama-3.2-1B",
            num_layers_to_finetune=settings.num_layers_to_finetune,
        )
    elif settings.model == "qwen3_finetune":
        model = FinetuneQwen3Coder(
            model_name="Qwen/Qwen3-1.7B",
            num_layers_to_finetune=settings.num_layers_to_finetune,
        )
    elif settings.model == "deepseek_finetune":
        model = FinetuneDeepSeekCoder(
            model_name="deepseek-ai/deepseek-coder-1.3b-base",
            num_layers_to_finetune=settings.num_layers_to_finetune,
        )
    else:
        raise Exception( "No model implementation found." )

    # The fine-tuning wrappers load with device_map="auto", so they are
    # already placed and must not be moved again here.
    if "finetune" not in settings.model:
        model.to(settings.device)
    return model