from torch import nn
import torchmetrics
import torchvision
import peft
import transformers
import sys
import pytorch_lightning as PL
import tqdm
from ModifiedModel import MPCmodel

############## Run Model Wrapper ############
class MyTQDMProgressBar(PL.callbacks.TQDMProgressBar):
    """Lightning TQDM progress bar with a customised validation bar.

    The validation bar is pinned to terminal position 0, kept on screen after it
    finishes (``leave=True``) and written to ``sys.stdout``.
    """

    def init_validation_tqdm(self):
        """Return a fresh ``tqdm`` bar used to track validation progress."""
        return tqdm.tqdm(
            desc=self.validation_description,
            file=sys.stdout,
            dynamic_ncols=True,
            disable=self.is_disabled,
            position=0,
            leave=True,
        )

def one_run_wrapper(create_model_func):
    """
    Decorator that turns a model factory into a "build and train once" function.

    Parameters:
    - create_model_func: callable invoked as
      ``create_model_func(config=..., dataconfig=..., Modeltype=..., **kwargs)``
      that returns the LightningModule to train.

    Returns:
    - run_model(name, dataconfig, config, Modeltype=MPCmodel, **kwargs): builds the
      model, fits it with a ``PL.Trainer`` on ``dataconfig['train_dataloader']``
      (validating on ``dataconfig['test_dataloader']``), logs CSV metrics under
      ``./logs/<name>`` and returns the ``PL.Trainer`` that was used.

    Keys read from ``config``: ``gpu`` (device index), ``epochs``, and optionally
    ``progress_bar`` (default True).

    Keys read from ``kwargs`` (the whole ``kwargs`` dict is also forwarded to
    ``create_model_func``):
    - modify_batch: forwarded to the factory; defaults to 40 when not supplied.
    - callbacks: extra Lightning callbacks appended after the progress bar.
    - enable_checkpoint: if truthy, a ``ModelCheckpoint`` saving to
      ``./lightning_logs/<name>/`` is added and checkpointing is enabled.
    - checkpoint: optional checkpoint path that ``trainer.fit`` resumes from.
    """
    def run_model(name,dataconfig,config,Modeltype=MPCmodel,**kwargs):
        # The caller's value wins; 40 is only the default.
        kwargs.setdefault('modify_batch', 40)
        model = create_model_func(config=config, dataconfig=dataconfig,
                                  Modeltype=Modeltype, **kwargs)
        logger = PL.loggers.CSVLogger('./logs', name=name)

        enable_checkpoint = kwargs.get('enable_checkpoint', False)
        callbacks = ([MyTQDMProgressBar(refresh_rate=10)]
                     if config.get('progress_bar', True) else [])
        # `or []` also tolerates an explicit callbacks=None.
        callbacks += list(kwargs.get('callbacks') or [])
        if enable_checkpoint:
            callbacks.append(PL.callbacks.ModelCheckpoint(
                dirpath=f'./lightning_logs/{name}/',
                filename='last_epoch_model',
                save_last=True,
            ))

        trainer = PL.Trainer(
            log_every_n_steps=1,
            num_sanity_val_steps=2,
            accelerator='gpu',
            devices=[config['gpu']],
            max_epochs=config['epochs'],
            callbacks=callbacks,
            # Must agree with the ModelCheckpoint added above: Lightning rejects a
            # ModelCheckpoint callback when checkpointing is disabled.
            enable_checkpointing=enable_checkpoint,
            logger=logger,
        )
        trainer.fit(model=model,
                    train_dataloaders=dataconfig['train_dataloader'],
                    val_dataloaders=dataconfig['test_dataloader'],
                    ckpt_path=kwargs.get('checkpoint', None))
        return trainer
    return run_model

############### ResNet50 ################
class CNNlossblock(nn.Module):
    """Linear classifier on top of a convolutional feature map.

    The input is average-pooled to a 1x1 spatial size, flattened to
    ``(batch, channel)`` and projected to ``classes`` logits.

    Args:
        channel: number of input channels (``C`` of an ``(N, C, H, W)`` map).
        classes: number of output logits.
    """

    def __init__(self, channel, classes):
        super().__init__()
        self.pooling = nn.AdaptiveAvgPool2d([1, 1])
        self.fc = nn.Linear(channel, classes)

    def forward(self, x):
        pooled = self.pooling(x)
        # (N, C, 1, 1) -> (N, C)
        features = pooled.flatten(1)
        return self.fc(features)
    
class ResNet50(nn.Module):
    """ImageNet-pretrained torchvision ResNet-50 with every residual block exposed.

    Attributes:
        layer0: stem (conv1, bn1, relu, maxpool).
        layer1_blocks .. layer4_blocks: ``nn.ModuleList`` holding the individual
            bottleneck blocks of each stage, so they can be addressed one by one.
        head: ``CNNlossblock`` mapping the final feature map to ``classes`` logits
            (replaces the original ``fc``).
    """

    def __init__(self,classes,):
        super().__init__()
        backbone = torchvision.models.resnet50(
            weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V1)
        self.layer0 = nn.Sequential(backbone.conv1, backbone.bn1,
                                    backbone.relu, backbone.maxpool)
        self.layer1_blocks = nn.ModuleList(backbone.layer1.children())
        self.layer2_blocks = nn.ModuleList(backbone.layer2.children())
        self.layer3_blocks = nn.ModuleList(backbone.layer3.children())
        self.layer4_blocks = nn.ModuleList(backbone.layer4.children())
        self.head = CNNlossblock(backbone.fc.in_features, classes)

    def forward(self, x):
        features = self.layer0(x)
        stages = (self.layer1_blocks, self.layer2_blocks,
                  self.layer3_blocks, self.layer4_blocks)
        for stage in stages:
            for block in stage:
                features = block(features)
        return self.head(features)
    
def Create_ResNet50(config,dataconfig,Modeltype=MPCmodel,**kwargs):
    """
    Build a ResNet-50 and wrap it in ``Modeltype`` for training.

    One auxiliary ``CNNlossblock`` is attached after the stem and after every
    residual block except the very last one, whose output already feeds the
    network's own head.

    Args:
        config (dict): needs ``horizon``, ``stride``, ``learning_rate``; optional
            ``optimizer`` (default ``'sgd'``), ``momentum`` (default None) and
            ``lambda_modify`` (default 0., passed as ``lambda_g``).
        dataconfig (dict): needs ``training_data`` (with a ``classes`` attribute
            that sets the number of classes) and ``loss_fn``.
        Modeltype (class, optional): wrapper class to instantiate. Defaults to MPCmodel.
        **kwargs: forwarded unchanged to ``Modeltype``.

    Returns:
        An instance of ``Modeltype`` tracking top-1 (``acc``) and top-5 (``acc5``)
        multiclass accuracy.
    """
    num_classes = len(dataconfig['training_data'].classes)
    backbone = ResNet50(num_classes)
    horizon, stride = config['horizon'], config['stride']

    aux_heads = [CNNlossblock(backbone.layer0[1].num_features, classes=num_classes)]
    for stage in (backbone.layer1_blocks, backbone.layer2_blocks,
                  backbone.layer3_blocks, backbone.layer4_blocks[:-1]):
        aux_heads.extend(CNNlossblock(block.bn3.num_features, classes=num_classes)
                         for block in stage)

    loss_fn = dataconfig['loss_fn']
    metric_set = torchmetrics.MetricCollection({
        'acc': torchmetrics.Accuracy(task='multiclass', num_classes=num_classes),
        'acc5': torchmetrics.Accuracy(task='multiclass', num_classes=num_classes, top_k=5),
    })
    return Modeltype(backbone, horizon, stride, aux_heads, loss_fn,
                     metrics=metric_set,
                     lr=config['learning_rate'],
                     optimizer=config.get('optimizer', 'sgd'),
                     momentum=config.get('momentum'),
                     lambda_g=config.get('lambda_modify', 0.),
                     **kwargs)

@one_run_wrapper
def ResNet50_one_run(config,dataconfig,Modeltype=MPCmodel,**kwargs):
    # Factory for one_run_wrapper; after decoration the public name is the
    # wrapper's run_model(name, dataconfig, config, Modeltype, **kwargs).
    built = Create_ResNet50(config, dataconfig, Modeltype=Modeltype, **kwargs)
    return built
############### ResNet50 ################

############### LoKr ResNet50 ################
class Lokr_CNN_block(nn.Module):
    """Transparent wrapper around a (LoKr-adapted) CNN block.

    ``forward`` simply delegates to the wrapped block. ``requires_grad_`` is
    overridden to toggle only parameters whose name contains ``'lokr'``; every
    other parameter keeps its current ``requires_grad`` flag.
    """

    def __init__(self,block):
        super().__init__()
        self.block = block

    def forward(self, x):
        return self.block(x)

    def requires_grad_(self, requires_grad=True):
        adapter_params = (param for pname, param in self.named_parameters()
                          if 'lokr' in pname)
        for param in adapter_params:
            param.requires_grad_(requires_grad)
        return self
    
class LoKr_ResNet50(nn.Module):
    """
    ImageNet-pretrained ResNet-50 fine-tuned with LoKr adapters (via ``peft``).

    LoKr adapters are injected into modules matching ``conv1``, ``conv2``,
    ``conv3`` and ``downsample.0``. The stem and every residual block are wrapped
    in ``Lokr_CNN_block`` so that their ``requires_grad_`` only toggles the
    adapter weights. A new trainable ``CNNlossblock`` head replaces ``fc``.
    """

    def __init__(self, classes, r=1, alpha=4):
        """
        Args:
            classes (int): number of output classes.
            r (int): LoKr rank ``r``.
            alpha (int): LoKr scaling ``alpha``.
        """
        super().__init__()
        backbone = torchvision.models.resnet50(
            weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V1)
        lokr_cfg = peft.LoKrConfig(
            target_modules=['conv1', 'conv2', 'conv3', 'downsample.0'],
            r=r,
            alpha=alpha,
        )
        adapted = peft.get_peft_model(backbone, lokr_cfg)

        stem = nn.Sequential(adapted.conv1, adapted.bn1, adapted.relu, adapted.maxpool)
        self.layer0 = Lokr_CNN_block(stem)
        self.layer1_blocks = nn.ModuleList(Lokr_CNN_block(blk) for blk in adapted.layer1.children())
        self.layer2_blocks = nn.ModuleList(Lokr_CNN_block(blk) for blk in adapted.layer2.children())
        self.layer3_blocks = nn.ModuleList(Lokr_CNN_block(blk) for blk in adapted.layer3.children())
        self.layer4_blocks = nn.ModuleList(Lokr_CNN_block(blk) for blk in adapted.layer4.children())
        self.head = CNNlossblock(backbone.fc.in_features, classes)
        # Explicitly mark the new head trainable (peft is expected to freeze the
        # non-adapter weights of the backbone it wraps).
        self.head.requires_grad_(True)

    def forward(self, x):
        features = self.layer0(x)
        stages = (self.layer1_blocks, self.layer2_blocks,
                  self.layer3_blocks, self.layer4_blocks)
        for stage in stages:
            for block in stage:
                features = block(features)
        return self.head(features)
    
def Create_LoKr_ResNet50(config,dataconfig,Modeltype=MPCmodel,**kwargs):
    """
    Build a LoKr-adapted ResNet-50 and wrap it in ``Modeltype`` for training.

    One auxiliary ``CNNlossblock`` is attached after the stem and after every
    residual block except the last one, whose output feeds the model's head.

    Args:
        config (dict): needs ``horizon``, ``stride``, ``learning_rate``; optional
            ``r`` (default 1), ``alpha`` (default 4), ``optimizer`` (default
            ``'sgd'``), ``momentum`` (default None), ``lambda_modify`` (default 0.).
        dataconfig (dict): needs ``training_data`` (its ``classes`` attribute sets
            the number of classes) and ``loss_fn``.
        Modeltype (class, optional): wrapper class to instantiate. Defaults to MPCmodel.
        **kwargs: forwarded unchanged to ``Modeltype``.

    Returns:
        An instance of ``Modeltype`` tracking top-1 (``acc``) and top-5 (``acc5``) accuracy.
    """
    num_classes = len(dataconfig['training_data'].classes)
    backbone = LoKr_ResNet50(num_classes, config.get('r', 1), config.get('alpha', 4))
    horizon, stride = config['horizon'], config['stride']

    aux_heads = [CNNlossblock(backbone.layer0.block[1].num_features, classes=num_classes)]
    for stage in (backbone.layer1_blocks, backbone.layer2_blocks,
                  backbone.layer3_blocks, backbone.layer4_blocks[:-1]):
        aux_heads.extend(CNNlossblock(wrapper.block.bn3.num_features, classes=num_classes)
                         for wrapper in stage)

    loss_fn = dataconfig['loss_fn']
    metric_set = torchmetrics.MetricCollection({
        'acc': torchmetrics.Accuracy(task='multiclass', num_classes=num_classes),
        'acc5': torchmetrics.Accuracy(task='multiclass', num_classes=num_classes, top_k=5),
    })
    return Modeltype(backbone, horizon, stride, aux_heads, loss_fn,
                     metrics=metric_set,
                     lr=config['learning_rate'],
                     optimizer=config.get('optimizer', 'sgd'),
                     momentum=config.get('momentum'),
                     lambda_g=config.get('lambda_modify', 0.),
                     **kwargs)

@one_run_wrapper
def LoKr_ResNet50_one_run(config,dataconfig,Modeltype=MPCmodel,**kwargs):
    # Factory for one_run_wrapper; after decoration the public name is the
    # wrapper's run_model(name, dataconfig, config, Modeltype, **kwargs).
    built = Create_LoKr_ResNet50(config, dataconfig, Modeltype=Modeltype, **kwargs)
    return built
############### LoKr ResNet50 ################

############### LoRA ViT-b16 ################
class Vit_block(nn.Module):
    """Wrap a Hugging Face ViT encoder layer so it maps hidden states to hidden states.

    Args:
        block: the encoder layer to wrap.
        lora: if True, ``requires_grad_`` only toggles parameters whose name
            contains ``'lora'``; otherwise it behaves like ``nn.Module.requires_grad_``.
    """

    def __init__(self,block,lora=True):
        super().__init__()
        self.block=block
        self.lora=lora

    def forward(self, x):
        out = self.block(x)
        # Depending on the transformers version, the layer returns either a tuple
        # whose first item is the hidden states, or the hidden-state tensor itself.
        # Indexing a tensor with [0] would silently select the first sample and
        # drop the batch dimension, so only unwrap tuples.
        if isinstance(out, tuple):
            return out[0]
        return out

    def requires_grad_(self, requires_grad=True):
        if self.lora:
            for n,p in self.named_parameters():
                if 'lora' in n:
                    p.requires_grad_(requires_grad)
            return self
        else:
            return super().requires_grad_(requires_grad)
    
class SelectFirstStep(nn.Module):
    """Keep only the first position along dim 1.

    For an input of shape ``(batch_size, sequence_length, features)`` this
    returns ``(batch_size, features)``.
    """

    def forward(self, x):
        return x.select(1, 0)
    
class Vit_B16(nn.Module):
    """ViT-B/16 (``google/vit-base-patch16-224-in21k``) fine-tuned with LoRA.

    LoRA adapters are injected into the ``query`` and ``value`` projections. The
    model is split into a patch-embedding ``stem``, a ``ModuleList`` of encoder
    layers (each wrapped in ``Vit_block``) and a ``head`` made of the final
    layernorm, first-token selection and the classifier (sized to ``classes``).

    Args:
        classes: number of output labels.
        r: LoRA rank.
        alpha: LoRA ``lora_alpha`` scaling.
    """

    def __init__(self,classes,r=1,alpha=4):
        super().__init__()
        checkpoint_id = "google/vit-base-patch16-224-in21k"
        hf_model = transformers.AutoModelForImageClassification.from_pretrained(
            checkpoint_id, num_labels=classes)
        lora_cfg = peft.LoraConfig(target_modules=['query', 'value'], r=r, lora_alpha=alpha)
        inner = peft.get_peft_model(hf_model, lora_cfg).base_model.model
        encoder_model = inner.vit

        self.stem = encoder_model.embeddings
        self.layer1_blocks = nn.ModuleList(
            [Vit_block(layer) for layer in encoder_model.encoder.layer])
        self.head = nn.Sequential(encoder_model.layernorm, SelectFirstStep(), inner.classifier)
        # Explicitly make the head (layernorm + classifier) trainable; peft is
        # expected to freeze non-LoRA weights of the wrapped model.
        self.head.requires_grad_(True)

    def forward(self, x):
        hidden = self.stem(x)
        for layer in self.layer1_blocks:
            hidden = layer(hidden)
        return self.head(hidden)

    
class ViTlossblock(nn.Module):
    """Auxiliary classifier for transformer hidden states.

    Applies ``LayerNorm`` (eps=1e-12) over the last dimension, keeps the first
    position along dim 1 (presumably the [CLS] token for ViT hidden states) and
    projects it to ``classes`` logits. The linear bias starts at zero.

    Args:
        channel_in: hidden size (last dimension of the input).
        classes: number of output logits.
    """

    def __init__(self, channel_in, classes):
        super().__init__()
        self.layernorm = nn.LayerNorm([channel_in], eps=1e-12)
        self.fc = nn.Linear(channel_in, classes)
        nn.init.zeros_(self.fc.bias)

    def forward(self, x):
        first_token = self.layernorm(x)[:, 0]
        return self.fc(first_token)
    
def Create_ViTb16(config,dataconfig,Modeltype=MPCmodel,**kwargs):
    """
    Build a LoRA ViT-B/16 and wrap it in ``Modeltype`` for training.

    One auxiliary ``ViTlossblock`` is attached after every encoder layer except
    the last one, whose output feeds the model's own head.

    Args:
        config (dict): needs ``horizon``, ``stride``, ``learning_rate``; optional
            ``r`` (default 1), ``alpha`` (default 4), ``optimizer`` (default
            ``'sgd'``), ``momentum`` (default None), ``lambda_modify`` (default 0.).
        dataconfig (dict): needs ``training_data`` (its ``classes`` attribute sets
            the number of classes) and ``loss_fn``.
        Modeltype (class, optional): wrapper class to instantiate. Defaults to MPCmodel.
        **kwargs: forwarded unchanged to ``Modeltype``.

    Returns:
        An instance of ``Modeltype`` tracking top-1 (``acc``) and top-5 (``acc5``) accuracy.
    """
    num_classes = len(dataconfig['training_data'].classes)
    backbone = Vit_B16(num_classes, config.get('r', 1), config.get('alpha', 4))
    horizon, stride = config['horizon'], config['stride']

    aux_heads = []
    for layer in backbone.layer1_blocks[:-1]:
        hidden_size = layer.block.output.dense.out_features
        aux_heads.append(ViTlossblock(hidden_size, classes=num_classes))

    loss_fn = dataconfig['loss_fn']
    metric_set = torchmetrics.MetricCollection({
        'acc': torchmetrics.Accuracy(task='multiclass', num_classes=num_classes),
        'acc5': torchmetrics.Accuracy(task='multiclass', num_classes=num_classes, top_k=5),
    })
    return Modeltype(backbone, horizon, stride, aux_heads, loss_fn,
                     metrics=metric_set,
                     lr=config['learning_rate'],
                     optimizer=config.get('optimizer', 'sgd'),
                     momentum=config.get('momentum'),
                     lambda_g=config.get('lambda_modify', 0.),
                     **kwargs)

@one_run_wrapper
def ViTb16_one_run(config,dataconfig,Modeltype=MPCmodel,**kwargs):
    # Factory for one_run_wrapper; after decoration the public name is the
    # wrapper's run_model(name, dataconfig, config, Modeltype, **kwargs).
    built = Create_ViTb16(config, dataconfig, Modeltype=Modeltype, **kwargs)
    return built
############### LoRA ViT-b16 ################
