from modules import *
from old_models import *
from torchvision.models import resnet18

# Ordered model registry: (pattern, match mode, factory, dataset flags).
#
# * Entries are tried top to bottom and the FIRST match wins.
# * mode "exact" requires ``name == pattern``; mode "contains" requires
#   ``pattern in name``.  Because of substring matching, a longer pattern must
#   come before any shorter "contains" pattern it includes (e.g. "lbpgruqa081"
#   before "lbpgruqa08", "gruqa21" before "gruqa", "salbpiqa" before "lbpiqa").
# * Factories are lambdas so each model class is looked up only when that
#   entry is selected (an unavailable class does not break module import).
# * Every flag is set to True in ``dataset_args``; the "saliency" flag also
#   adds a ``'saliency_model'`` entry (see ``get_model``).
MODEL_REGISTRY = (
    ("SimpleEmbeddings", "contains", lambda: SimpleEmbeddings(embedding=False), ("simple",)),
    ("LayeredStreams2dec", "exact", lambda: LayeredStreams2dec(), ()),
    ("lbpgruqa09", "contains", lambda: LBPGRUQA09(), ("lbp",)),
    ("lbpgruqa081", "contains", lambda: LBPGRUQA081(), ("lbp",)),
    ("lbpgruqa08", "contains", lambda: LBPGRUQA08(), ("lbp",)),
    ("lbpgruqa07", "contains", lambda: LBPGRUQA07(), ("lbp",)),
    ("lbpgruqa06", "contains", lambda: LBPGRUQA06(), ("lbp",)),
    ("lbpgruqa05", "contains", lambda: LBPGRUQA05(), ("lbp",)),
    ("lbpgruqa04", "contains", lambda: LBPGRUQA04(), ("lbp",)),
    ("lbpgruqa03", "contains", lambda: LBPGRUQA03(), ("lbp",)),
    ("lbpgruqa02", "contains", lambda: LBPGRUQA02(), ("lbp",)),
    ("lbpgruqa01", "contains", lambda: LBPGRUQA01(), ("lbp",)),
    ("gruqa21", "contains", lambda: GRUQA21(), ()),
    ("gruqa20", "contains", lambda: GRUQA20(), ()),
    ("gruqa10", "contains", lambda: GRUQA10(), ()),
    ("gruqa01", "contains", lambda: GRUQA01(), ()),
    ("gruqa02", "contains", lambda: GRUQA02(), ()),
    ("gruqa", "contains", lambda: GRUQA(), ()),
    ("LayeredStreamsSecond", "exact", lambda: LayeredStreamsSecond(), ()),
    ("two_streams_big2", "exact", lambda: TwoStreamIQABIG2(), ()),
    ("cbam", "exact", lambda: CBAMBlock(), ()),
    ("cbamcolor", "exact", lambda: CBAMColorBlock(), ()),
    ("cbamcolor2", "exact", lambda: CBAMColorBlock2(), ()),
    ("cbamcolor3", "exact", lambda: CBAMColorBlock3(), ()),
    ("cbamcolor4", "exact", lambda: CBAMColorBlock4(), ()),
    ("cbamcolor31", "exact", lambda: CBAMColorBlock31(), ()),
    ("fourfeatures0", "exact", lambda: FourFeatures0(), ("vif", "dlm")),
    ("cbamcolor22", "exact", lambda: CBAMColorBlock22(), ()),
    ("cbamcolor2reduced2", "contains", lambda: CBAMColorBlock2Reduced2(), ()),
    ("cbamcolor2reduced", "contains", lambda: CBAMColorBlock2Reduced(), ()),
    ("cbamcolor24hpa", "contains", lambda: CBAMColorBlock24HPA(), ()),
    ("cbamcolor24", "contains", lambda: CBAMColorBlock24(), ()),
    ("cbamcolor221", "contains", lambda: CBAMColorBlock221(), ()),
    ("cbamcolor231", "contains", lambda: CBAMColorBlock231(), ()),
    ("cbamcolor232", "contains", lambda: CBAMColorBlock232(), ()),
    ("cbamcolor233_complex", "contains", lambda: CBAMColorBlock233(out_simple=False), ()),
    ("cbamcolor233", "contains", lambda: CBAMColorBlock233(), ()),
    ("cbamcolor23", "contains", lambda: CBAMColorBlock23(), ()),
    ("cbamcolor25", "contains", lambda: CBAMColorBlock25(), ()),
    ("cbamcolor26", "contains", lambda: CBAMColorBlock26(), ()),
    ("sinresblock2", "contains", lambda: SinResBlock2(), ()),
    ("sinresblock", "contains", lambda: SinResBlock(), ()),
    ("dctcolorblock0", "contains", lambda: DCTColorBlock0(), ("dct",)),
    ("dctcolorblock1", "contains", lambda: DCTColorBlock1(), ("dct",)),
    ("gatedlbpiqa2fusion", "contains", lambda: GatedLBPIQA2Fusion(), ("lbp",)),
    ("gatedlbpiqa4", "contains", lambda: GatedLBPIQA4(), ("lbp",)),
    ("gatedlbpiqa3", "contains", lambda: GatedLBPIQA3(), ("lbp",)),
    ("gatedlbpiqa2", "contains", lambda: GatedLBPIQA2(), ("lbp",)),
    ("gatedlbpiqa", "contains", lambda: GatedLBPIQA(), ("lbp",)),
    ("gatedsalbpiqa", "contains", lambda: GatedSaLBPIQA(), ("saliency", "lbp")),
    ("saliqa", "contains", lambda: SalIQA(), ("saliency",)),
    ("salbpiqadown02", "contains", lambda: SaLBPIQAdown02(), ("saliency", "lbp")),
    ("salbpiqadown2", "contains", lambda: SaLBPIQAdown2(), ("saliency", "lbp")),
    ("salbpiqadown", "contains", lambda: SaLBPIQAdown(), ("saliency", "lbp")),
    ("salbpiqa", "contains", lambda: SaLBPIQA(), ("saliency", "lbp")),
    ("lbpiqa", "contains", lambda: LBPIQA(), ("lbp",)),
)


def _matches(name, pattern, mode):
    """Return True if *name* selects *pattern* under match *mode* ("exact"/"contains")."""
    if mode == "exact":
        return name == pattern
    return pattern in name


def _load_saliency_model(device):
    """Return an ImageNet-pretrained torchvision ResNet-18 moved to *device*."""
    # torchvision >= 0.13 maps the legacy ``weights=True`` to
    # ``ResNet18_Weights.IMAGENET1K_V1`` with a deprecation warning; naming
    # the weights explicitly loads the same checkpoint without the warning.
    saliency_model = resnet18(weights="IMAGENET1K_V1")
    # NOTE(review): the model is left in its default (training) mode, exactly
    # as before; if the dataset only runs it for inference, .eval() may be
    # intended - confirm.
    saliency_model.to(device)
    return saliency_model


def get_model(name, dataset_args, device):
    """Instantiate the model selected by *name* and record its data needs.

    *name* is matched against ``MODEL_REGISTRY`` in order and the first entry
    wins: "exact" entries must equal *name*, "contains" entries only need to
    occur inside it (so e.g. ``"gruqa21_v2"`` selects the ``"gruqa21"`` entry).

    Args:
        name: Model identifier, or a longer name containing one.
        dataset_args: Dict of dataset options.  It is updated IN PLACE: each
            flag of the selected entry (e.g. ``'lbp'``, ``'dct'``,
            ``'saliency'``) is set to True, and saliency models also get a
            ``'saliency_model'`` entry.  The same dict is returned.
        device: Passed to ``.to()`` on the saliency backbone; not used by
            models without the ``'saliency'`` flag.

    Returns:
        Tuple ``(model, dataset_args)``.

    Raises:
        ValueError: If *name* matches no registered pattern.  In that case
            ``dataset_args`` is not modified.
    """
    for pattern, mode, factory, flags in MODEL_REGISTRY:
        if _matches(name, pattern, mode):
            break
    else:
        # Previously this fell through to an UnboundLocalError on ``model``.
        raise ValueError(f"Unknown model name: {name!r}")

    model = factory()
    if "saliency" in flags:
        # Saliency-aware models need a pretrained backbone handed to the dataset.
        dataset_args['saliency_model'] = _load_saliency_model(device)
    for flag in flags:
        dataset_args[flag] = True
    return model, dataset_args