import math
import os
from typing import Dict, List, Optional, Tuple, Type, Union
from transformers import CLIPTextModelWithProjection, T5EncoderModel
import numpy as np
import torch
from library.utils import setup_logging
from library import sd3_models

setup_logging()
import logging

logger = logging.getLogger(__name__)

class FullTrainingModule(torch.nn.Module):
    """Trainable copy of one layer's weight (and bias) for full-parameter SD3 training.

    The constructor clones the wrapped layer's ``weight`` (and ``bias`` when it
    has one) into new ``torch.nn.Parameter`` objects owned by this module.
    ``apply_to`` then rebinds the wrapped layer to those parameters, so the
    layer and this module share the same tensors from that point on.

    Args:
        module_name: Identifier stored on the instance as ``module_name``.
        org_module: Layer whose ``weight`` / ``bias`` are cloned. It must expose
            a ``weight`` attribute; ``bias`` may be absent or ``None``.
        multiplier: Stored as ``multiplier``; not read anywhere in this class.
    """

    def __init__(
        self,
        module_name: str,
        org_module: torch.nn.Module,
        multiplier: float = 1.0,
    ) -> None:
        super().__init__()
        self.module_name = module_name
        # Keep a plain reference to the wrapped layer WITHOUT registering it as a
        # submodule. A normal attribute assignment would register it, which makes
        # state_dict() emit duplicate "org_module.*" entries that alias
        # "weight"/"bias" once apply_to() has run, and makes parameters() yield
        # the stale originals before apply_to().
        object.__setattr__(self, "org_module", org_module)
        self.multiplier = multiplier
        self.enabled = True

        # Independent trainable copies of the wrapped layer's parameters.
        self.weight = torch.nn.Parameter(org_module.weight.detach().clone())
        if getattr(org_module, "bias", None) is not None:
            self.bias = torch.nn.Parameter(org_module.bias.detach().clone())
        else:
            # Registers the name so ``self.bias`` reads as None and is left out
            # of state_dict()/parameters().
            self.register_parameter("bias", None)

    def apply_to(self):
        """Point the wrapped layer at this module's trainable parameters."""
        self.org_module.weight = self.weight
        if self.bias is not None:
            self.org_module.bias = self.bias

class FullTrainingNetwork(torch.nn.Module):
    """Collection of ``FullTrainingModule`` wrappers for text encoders and the MMDiT.

    On construction, every ``torch.nn.Linear`` / ``torch.nn.Conv2d`` found under
    a module whose class name is listed in the ``*_TARGET_REPLACE_MODULE``
    constants is wrapped in a ``FullTrainingModule`` and registered on this
    network as a submodule.

    Args:
        text_encoders: Text encoder models to scan. A single module is also
            accepted and treated as a one-element list; ``None`` entries in the
            list are skipped.
        unet: The MMDiT model to scan.
        multiplier: Stored on the network and forwarded to every wrapper.
        train_text_encoder: When true, wrappers are created for the text encoders.
        train_unet: When true, wrappers are created for the MMDiT.

    Raises:
        ValueError: If training is requested for a component that is ``None``.
    """

    # Class names (matched against ``module.__class__.__name__``) whose
    # Linear/Conv2d descendants are wrapped.
    TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP", "T5Attention", "T5DenseGatedActDense"]
    SD3_TARGET_REPLACE_MODULE = ["SingleDiTBlock"]

    def __init__(
        self,
        text_encoders: List[Union[CLIPTextModelWithProjection, T5EncoderModel]],
        unet: sd3_models.MMDiT,
        multiplier: float = 1.0,
        train_text_encoder: bool = True,
        train_unet: bool = True,
    ) -> None:
        super().__init__()
        self.multiplier = multiplier
        self.train_text_encoder = train_text_encoder
        self.train_unet = train_unet

        # Plain Python lists (not ModuleList): registration on this network
        # happens through add_module() in _setup_modules().
        self.text_encoder_modules: List[FullTrainingModule] = []
        self.unet_modules: List[FullTrainingModule] = []

        if train_text_encoder:
            if text_encoders is None:
                raise ValueError("train_text_encoder=True requires text_encoders, but None was given")
            if isinstance(text_encoders, torch.nn.Module):
                text_encoders = [text_encoders]
            logger.info("Setting up full training for Text Encoders")
            for i, text_encoder in enumerate(text_encoders):
                if text_encoder is None:
                    # Keep the index so names of the remaining encoders are stable.
                    logger.warning(f"text_encoder_{i} is None; skipping")
                    continue
                self._setup_modules(
                    text_encoder,
                    self.TEXT_ENCODER_TARGET_REPLACE_MODULE,
                    f"text_encoder_{i}",
                    self.text_encoder_modules,
                )

        if train_unet:
            if unet is None:
                raise ValueError("train_unet=True requires the MMDiT model, but None was given")
            logger.info("Setting up full training for MMDiT")
            self._setup_modules(
                unet,
                self.SD3_TARGET_REPLACE_MODULE,
                "mmdit",
                self.unet_modules,
            )

    def _setup_modules(self, root_module: torch.nn.Module, target_modules: List[str], prefix: str, module_list: List[FullTrainingModule]):
        """Wrap every Linear/Conv2d found under a target-class module of ``root_module``.

        Each wrapper is appended to ``module_list`` and registered on this
        network under ``"{prefix}.{target path}.{layer path}"`` with dots
        replaced by underscores.
        """
        # A target-class module nested inside another target-class module would
        # otherwise yield the same layer twice (under the same generated name),
        # leaving an orphaned wrapper in module_list.
        wrapped_ids = set()
        for name, module in root_module.named_modules():
            if module.__class__.__name__ not in target_modules:
                continue
            for child_name, child_module in module.named_modules():
                if not isinstance(child_module, (torch.nn.Linear, torch.nn.Conv2d)):
                    continue
                if id(child_module) in wrapped_ids:
                    continue
                wrapped_ids.add(id(child_module))

                module_name = f"{prefix}.{name}.{child_name}".replace(".", "_")
                trainable_module = FullTrainingModule(
                    module_name,
                    child_module,
                    self.multiplier,
                )
                module_list.append(trainable_module)
                self.add_module(module_name, trainable_module)

    def apply_to(self, text_encoders, unet):
        """Rebind every wrapped layer to its trainable parameters.

        ``text_encoders`` and ``unet`` are accepted but not read: the layers to
        update were captured by the wrappers at construction time.
        """
        if self.train_text_encoder:
            logger.info(f"Applying full training to text encoders: {len(self.text_encoder_modules)} modules")
            for module in self.text_encoder_modules:
                module.apply_to()

        if self.train_unet:
            logger.info(f"Applying full training to MMDiT: {len(self.unet_modules)} modules")
            for module in self.unet_modules:
                module.apply_to()

    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
        """Build optimizer parameter groups and mark all parameters trainable.

        Args:
            text_encoder_lr: Learning rate for text encoder wrappers, or ``None``.
            unet_lr: Learning rate for MMDiT wrappers, or ``None``.
            default_lr: Fallback used when the specific rate is ``None``.

        Returns:
            A list of ``{"params": [...], "lr": ...}`` dicts, one per non-empty
            wrapper list. ``"lr"`` is omitted when both the specific rate and
            ``default_lr`` are ``None`` so the optimizer's own default applies.
        """
        self.requires_grad_(True)
        param_groups = []

        for modules, specific_lr in (
            (self.text_encoder_modules, text_encoder_lr),
            (self.unet_modules, unet_lr),
        ):
            if not modules:
                continue
            group = {"params": [p for m in modules for p in m.parameters()]}
            # A missing specific rate must not drop the parameters from training.
            lr = specific_lr if specific_lr is not None else default_lr
            if lr is not None:
                group["lr"] = lr
            param_groups.append(group)

        return param_groups

    def load_weights(self, file: str):
        """Load network weights from ``file`` into this network.

        ``.safetensors`` files are read with safetensors; anything else goes
        through ``torch.load``.

        Returns:
            The result of ``load_state_dict`` (missing / unexpected key report).
        """
        if os.path.splitext(file)[1] == '.safetensors':
            from safetensors.torch import load_file
            state_dict = load_file(file)
        else:
            # NOTE(security): torch.load unpickles the file, which can execute
            # arbitrary code. Only load checkpoints from trusted sources.
            state_dict = torch.load(file, map_location='cpu')

        return self.load_state_dict(state_dict)

    def save_weights(self, file: str, dtype=None, metadata=None):
        """Save network weights to ``file``.

        Args:
            file: Destination path; a ``.safetensors`` extension selects the
                safetensors writer, anything else uses ``torch.save``.
            dtype: If given, every tensor is converted to this dtype first.
            metadata: Passed to the safetensors writer; ``None`` becomes ``{}``.
                Not used on the ``torch.save`` path.
        """
        state_dict = {}
        for key, tensor in self.state_dict().items():
            # Snapshot onto the CPU: the serializer then never sees tensors that
            # alias each other or live parameters.
            tensor = tensor.detach().clone().to("cpu")
            if dtype is not None:
                tensor = tensor.to(dtype)
            state_dict[key] = tensor

        if metadata is None:
            metadata = {}

        if os.path.splitext(file)[1] == '.safetensors':
            from safetensors.torch import save_file
            save_file(state_dict, file, metadata)
        else:
            torch.save(state_dict, file)

    def enable_gradient_checkpointing(self):
        """Call ``enable_grad_checkpointing`` if the instance has one.

        Nothing in this class defines that attribute, so this does nothing
        unless a subclass or caller provides it.
        """
        if hasattr(self, "enable_grad_checkpointing"):
            self.enable_grad_checkpointing()

    def prepare_grad_etc(self, text_encoder, unet):
        """Mark every parameter of this network as requiring gradients.

        ``text_encoder`` and ``unet`` are accepted but not read.
        """
        self.requires_grad_(True)

    def on_epoch_start(self, text_encoder, unet):
        """Put this network into training mode.

        ``text_encoder`` and ``unet`` are accepted but not read.
        """
        self.train()

def _coerce_flag(value) -> bool:
    """Interpret ``value`` as a boolean, understanding textual forms like "False"."""
    if isinstance(value, str):
        return value.strip().lower() not in ("", "false", "0", "no", "off")
    return bool(value)


def create_network(
    multiplier: float,
    network_dim: Optional[int] = None,  # Not used in full training
    network_alpha: Optional[float] = None,  # Not used in full training
    vae=None,  # Not used in full training
    text_encoders=None,
    mmdit=None,
    train_text_encoder: bool = True,
    train_unet: bool = True,
    **kwargs  # Additional kwargs are ignored for full training
) -> FullTrainingNetwork:
    """Build a ``FullTrainingNetwork`` for the given text encoders and MMDiT.

    Args:
        multiplier: Forwarded to the network.
        network_dim: Accepted and ignored.
        network_alpha: Accepted and ignored.
        vae: Accepted and ignored.
        text_encoders: Forwarded as the network's ``text_encoders``.
        mmdit: Forwarded as the network's ``unet``.
        train_text_encoder: Whether to train the text encoders. Strings such as
            ``"False"``, ``"0"``, ``"no"`` and ``"off"`` are read as false.
        train_unet: Whether to train the MMDiT; same string handling.
        **kwargs: Accepted and ignored.

    Returns:
        The constructed ``FullTrainingNetwork``.
    """
    # NOTE(review): presumably the trainer may forward these flags as strings
    # (e.g. from command-line style key=value args) - confirm against caller.
    # Without coercion the string "False" would be truthy and enable training.
    return FullTrainingNetwork(
        text_encoders=text_encoders,
        unet=mmdit,
        multiplier=multiplier,
        train_text_encoder=_coerce_flag(train_text_encoder),
        train_unet=_coerce_flag(train_unet),
    )