from typing import Optional ,Dict ,List ,Union ,Any 
from dataclasses import dataclass 
from contextlib import contextmanager 

import torch 
from torch import nn 

import lightning as L 
from transformers import get_scheduler ,PreTrainedModel 
from ..modules .ema import EMAModel 
from ..modules .lightning .parallelism import apply_tensor_parallelism 
from ..utils import (
instantiate_from_config ,
instantiate_model_from_config ,
instantiate_optimizer_from_config ,
)


@dataclass
class SampleOutput:
    """
    Container for the result of a sampling call.

    Holds the generated token ids, their decoded text form and, optionally,
    the logits that accompany them.
    """

    # Generated token ids, either as a tensor or as a plain list of ints.
    output_ids: Union[torch.Tensor, List[int]]
    # Text strings associated with the sample(s).
    texts: List[str]
    # Optional logits; left as None when not provided.
    logits: Optional[torch.Tensor] = None


def get_parameter_names(model, forbidden_layer_types):
    """
    Collect the dotted names of the parameters of ``model`` that do not sit
    inside a child module whose type is listed in ``forbidden_layer_types``.

    Children are visited recursively in ``named_children()`` order; parameters
    registered directly on ``model`` itself are appended last and are always
    included.
    """
    excluded_types = tuple(forbidden_layer_types)
    names = []
    for child_name, child in model.named_children():
        nested_names = get_parameter_names(child, forbidden_layer_types)
        if isinstance(child, excluded_types):
            # Everything below a forbidden layer is dropped.
            continue
        for nested_name in nested_names:
            names.append(f"{child_name}.{nested_name}")

    # Parameters owned by the module itself are not reachable through any child.
    names.extend(model._parameters.keys())
    return names


def get_decay_parameter_names(model) -> List[str]:
    """
    Return the names of the parameters that weight decay should be applied to.

    Parameters located inside an ``nn.LayerNorm`` module are excluded, and so
    is every parameter whose name contains ``"bias"``.

    Note that only real ``nn.LayerNorm`` instances are filtered out: a model
    shipping its own layer-norm implementation will still have those weights
    reported as decayable.
    """
    excluded_fragments = ("bias",)
    candidate_names = get_parameter_names(model, [nn.LayerNorm])

    kept_names = []
    for param_name in candidate_names:
        if any(fragment in param_name for fragment in excluded_fragments):
            continue
        kept_names.append(param_name)
    return kept_names


def get_optimizer_params(
    model: nn.Module,
    loss_fn: Optional[nn.Module] = None,
    weight_decay: float = 0.0,
    ignore_parameters: Optional[List[str]] = None,
):
    """
    Build the two optimizer parameter groups (with and without weight decay).

    Args:
        model: Module whose trainable parameters are optimized.
        loss_fn: Optional loss module. When given, its trainable parameters are
            added to the groups as well (after the model's parameters).
        weight_decay: Weight decay of the first group. The second group always
            uses ``0.0``.
        ignore_parameters: Optional list of substrings. A parameter whose name
            contains ANY of them is left out of both groups.

    Returns:
        A list of two dicts, each with the keys ``"params"`` and
        ``"weight_decay"``: first the decayed parameters, then the non-decayed
        ones. Only parameters with ``requires_grad=True`` are included.
    """
    ignored_fragments = tuple(ignore_parameters or ())

    def is_kept(name):
        # A single matching fragment is enough to exclude the parameter.
        return not any(fragment in name for fragment in ignored_fragments)

    modules = [model] if loss_fn is None else [model, loss_fn]

    decay_params = []
    no_decay_params = []
    # A parameter reachable from both modules must be registered only once.
    seen_ids = set()

    for module in modules:
        # Resolve decay names per module so that identical names in ``model``
        # and ``loss_fn`` cannot be confused with each other.
        decay_names = set(get_decay_parameter_names(module))
        for name, param in module.named_parameters():
            if not param.requires_grad or not is_kept(name):
                continue
            if id(param) in seen_ids:
                continue
            seen_ids.add(id(param))
            if name in decay_names:
                decay_params.append(param)
            else:
                no_decay_params.append(param)

    return [
        {"params": decay_params, "weight_decay": weight_decay},
        {"params": no_decay_params, "weight_decay": 0.0},
    ]


class LM(L.LightningModule):
    """
    Lightning module wrapping a language model, its loss and its optimizer.

    The model and loss are built lazily in ``configure_model`` from the configs
    stored in ``self.hparams``. The loss object is called with the model and a
    batch and must return a dict containing at least the key ``"loss"``.
    Optional EMA weights are tracked during training and swapped in for
    validation.
    """

    # Forwarded to ``apply_tensor_parallelism`` as ``loss_parallel``.
    support_loss_parallel: bool = False

    def __init__(
        self,
        model_config,
        loss_config,
        optimizer_config,
        scheduler_config: Optional[dict] = None,
        gradient_checkpointing: bool = False,
        use_ema: bool = False,
        ema_decay: float = 0.999,
        torch_compile: bool = False,
        not_hf_pretrained: bool = False,
    ):
        """
        Store all constructor arguments as hyperparameters.

        Nothing is instantiated here; see ``configure_model`` and
        ``configure_optimizers`` for how each argument is consumed.
        """
        super().__init__()

        self.save_hyperparameters()

        self.tokenizer = None

    def on_train_start(self) -> None:
        """Call ``init_weights()`` on the model when it provides that method."""
        # NOTE(review): this runs at every training start - confirm it cannot
        # overwrite weights restored from a checkpoint or a pretrained model.
        if hasattr(self.model, "init_weights"):
            self.model.init_weights()

    def configure_model(self):
        """
        Build the model (and EMA / compiled wrapper if enabled) and the loss.

        Lightning may invoke this hook several times in the same process (once
        per fit/validate/test/predict stage), so an already built model is kept
        as is instead of being re-instantiated.
        """
        if getattr(self, "model", None) is None:
            self.model: PreTrainedModel = instantiate_model_from_config(
                self.hparams["model_config"],
                not_hf_pretrained=self.hparams["not_hf_pretrained"],
            ).train()
            if self.hparams["gradient_checkpointing"]:
                self.model.gradient_checkpointing_enable(
                    gradient_checkpointing_kwargs={"use_reentrant": False}
                )
                self.print("Gradient checkpointing enabled.")
            apply_tensor_parallelism(
                self.model,
                device_mesh=self.device_mesh,
                loss_parallel=self.support_loss_parallel,
            )
            if self.hparams["use_ema"]:
                self.model_ema = EMAModel(
                    self.model.parameters(), decay=self.hparams["ema_decay"]
                )
                self.print("EMA enabled.")
            if self.hparams["torch_compile"]:
                self.print("Compiling the model... (takes a ~minute)")
                self.model = torch.compile(model=self.model)
        # A loss_fn that is already set (e.g. by a subclass) is left untouched.
        if not hasattr(self, "loss_fn"):
            self.loss_fn = instantiate_from_config(self.hparams["loss_config"])

    def get_batch_size(self, batch: Dict[str, Any]) -> int:
        """Return the size of the first dimension of ``batch["input_ids"]``."""
        return batch["input_ids"].size(0)

    def forward(self, batch: Dict[str, Any], **kwargs) -> dict:
        """Run the loss function on the model and the batch; return its output."""
        return self.loss_fn(model=self.model, batch=batch, **kwargs)

    def shared_step(self, batch, step="train", prefix="", **kwargs):
        """
        Compute the loss and a dict of loggable values for one batch.

        Returns:
            ``(loss, loss_dict)`` where ``loss`` is ``outputs["loss"]`` and
            ``loss_dict`` maps ``"{step}/{prefix}{key}"`` to every output value
            that is not None.
        """
        outputs = self(batch, **kwargs)
        loss = outputs["loss"]
        loss_dict = {
            f"{step}/{prefix}{k}": v for k, v in outputs.items() if v is not None
        }
        return loss, loss_dict

    def training_step(self, batch: Dict[str, Any], batch_idx):
        """Compute and log the training loss values and the global step."""
        batch_size = self.get_batch_size(batch)
        loss, loss_dict = self.shared_step(batch)
        self.log_dict(loss_dict, batch_size=batch_size, prog_bar=True)

        self.log(
            "global_step",
            self.global_step,
            prog_bar=True,
            logger=True,
            on_step=True,
            on_epoch=False,
        )
        return loss

    def on_train_batch_end(self, outputs, batch, batch_idx):
        """Update the EMA weights from the current model parameters."""
        if self.hparams["use_ema"]:
            self.model_ema.step(self.model.parameters())

    @contextmanager
    def ema_scope(self, context=None):
        """
        Temporarily run the model with its EMA weights.

        When EMA is enabled, the current parameters are stored, the EMA weights
        are copied into the model, and the stored parameters are restored on
        exit (also when the body raises). Without EMA this is a no-op.

        Args:
            context: Optional label; when given, the switch and the restore are
                printed with it as a prefix.
        """
        if self.hparams["use_ema"]:
            self.model_ema.store(self.model.parameters())
            self.model_ema.copy_to(self.model.parameters())
            if context is not None:
                self.print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.hparams["use_ema"]:
                self.model_ema.restore(self.model.parameters())
                if context is not None:
                    self.print(f"{context}: Restored training weights")

    def validation_step(self, batch: Dict[str, Any], batch_idx):
        """
        Compute the validation loss and metrics (under EMA weights if enabled)
        and log them; metrics are logged under a ``val/`` prefix.
        """
        batch_size = self.get_batch_size(batch)
        with self.ema_scope(context="Validation"):
            loss, loss_dict = self.shared_step(batch, step="val")
            metrics = self.compute_metrics(batch, loss_dict)
            if metrics:
                metrics = {f"val/{k}": v for k, v in metrics.items()}
                loss_dict.update(metrics)
        self.log_dict(
            loss_dict, batch_size=batch_size, prog_bar=True, logger=True, sync_dist=True
        )
        return loss

    def configure_optimizers(self):
        """
        Create the optimizer and, when configured, a per-step LR scheduler.

        The weight decay is read from ``optimizer_config["params"]`` (0.0 when
        absent) and used to build the parameter groups. ``loss_fn`` is handed
        to ``get_optimizer_params`` only when it is an ``nn.Module``.

        Returns:
            The optimizer alone when no ``scheduler_config`` was given,
            otherwise ``([optimizer], [scheduler_dict])`` with the scheduler
            stepped every training step.
        """
        optimizer_config = self.hparams["optimizer_config"]
        if "params" in optimizer_config and optimizer_config["params"].get(
            "weight_decay"
        ):
            weight_decay = optimizer_config["params"]["weight_decay"]
        else:
            weight_decay = 0.0
        loss_fn = self.loss_fn if isinstance(self.loss_fn, nn.Module) else None
        params = get_optimizer_params(self.model, loss_fn, weight_decay=weight_decay)
        optimizer = instantiate_optimizer_from_config(optimizer_config, params)

        scheduler_config = self.hparams.get("scheduler_config")
        if scheduler_config is None:
            # No scheduler requested: train with the optimizer's own LR.
            return optimizer

        # Work on a copy so the saved hyperparameters keep ``num_warmup_steps``.
        scheduler_kwargs = dict(scheduler_config)
        num_warmup_steps = scheduler_kwargs.pop("num_warmup_steps", 0)
        num_training_steps = self.trainer.estimated_stepping_batches
        scheduler = get_scheduler(
            optimizer=optimizer,
            num_training_steps=num_training_steps,
            num_warmup_steps=num_warmup_steps,
            **scheduler_kwargs,
        )
        self.print(
            f"Setting up scheduler (estimated_stepping_batches: {num_training_steps})..."
        )
        scheduler = [
            {
                "scheduler": scheduler,
                "interval": "step",
                "frequency": 1,
            }
        ]
        return [optimizer], scheduler

    def sample(self, num_samples: int, **kwargs) -> SampleOutput:
        """
        Generate samples from the model
        Ema weights will be used if enabled

        Not implemented in this base class; always raises NotImplementedError.
        """
        raise NotImplementedError

    def compute_metrics(
        self, batch: Dict[str, Any], loss_dict: dict
    ) -> Optional[Dict[str, float]]:
        """
        Compute metrics against generated samples
        Ema weights will be used if enabled

        This base implementation computes nothing and returns None.
        """
        return None
