import torch
from torchvision.models import resnet50
from collections import OrderedDict


'''
Model definitions: a single-linear-layer head (FCNet) and a ResNet-50-based
image classifier (ImageClassifier).
'''


class FCNet(torch.nn.Module):
    """A head consisting of one affine layer mapping features to class scores.

    Args:
        num_feats: width of the incoming feature vectors.
        num_classes: number of output scores produced per input.
    """

    def __init__(self, num_feats, num_classes):
        super().__init__()
        # Attribute name `fc` is kept so state_dict keys stay "fc.weight"/"fc.bias".
        self.fc = torch.nn.Linear(num_feats, num_classes)

    def forward(self, x):
        """Return ``self.fc(x)``, the unnormalized class scores for ``x``."""
        return self.fc(x)


class ImageClassifier(torch.nn.Module):
    """ResNet-50 backbone with a projection head, followed by a linear classifier.

    The backbone's own ``fc`` layer is replaced by a new
    ``Linear(2048, P['feat_dim'])`` projection; ``forward`` applies ``tanh``
    to that projection's output before the linear classifier.

    Args:
        P: configuration mapping. Keys read here: ``'arch'`` (only
            ``'resnet50'`` is supported), ``'use_pretrained'``,
            ``'resnet50_weights'`` (read only when ``use_pretrained`` is
            truthy), ``'feat_dim'``, ``'freeze_feature_extractor'``,
            ``'fine_tune'`` (read only when not frozen; ``'last_layer'`` or
            ``'last_block'``) and ``'num_classes'``.

    Raises:
        ValueError: if ``P['arch']`` or ``P['fine_tune']`` is unsupported.
    """

    def __init__(self, P):
        super().__init__()
        print('initializing image classifier')

        self.arch = P['arch']
        if self.arch != 'resnet50':
            raise ValueError('Architecture not implemented.')

        # --- backbone ---------------------------------------------------
        if P['use_pretrained']:
            print('feature extractor: imagenet pretrained')
            backbone = resnet50(weights=P['resnet50_weights'])
        else:
            print('feature extractor: randomly initialized')
            backbone = resnet50()

        # Replace the backbone's own head with a 2048 -> feat_dim projection.
        backbone.fc = torch.nn.Linear(2048, P['feat_dim'], bias=True)

        # Start with every backbone parameter frozen; re-enable selectively below.
        backbone.requires_grad_(False)

        if P['freeze_feature_extractor']:
            print('feature extractor frozen')
            # NOTE(review): the new projection `fc` also stays frozen on this
            # path, i.e. it keeps its random initialization — confirm intended.
        else:
            print('feature extractor partially trainable')
            mode = P['fine_tune']
            print(f'fine-tuning mode: {mode}')

            if mode == 'last_layer':
                # only block index 2 of layer4
                trainable = backbone.layer4[2]
            elif mode == 'last_block':
                # the whole layer4 stage
                trainable = backbone.layer4
            else:
                raise ValueError('Fine-tune mode not implemented.')

            trainable.requires_grad_(True)
            # The projection head is always trainable when fine-tuning.
            backbone.fc.requires_grad_(True)

        # Registered before the classifier so submodule/state_dict order is unchanged.
        self.feature_extractor = backbone

        # --- classifier -------------------------------------------------
        print('linear classifier layer: randomly initialized')
        self.linear_classifier = torch.nn.Linear(
            P['feat_dim'], P['num_classes'], bias=True
        )

    def forward(self, x, act_flag):
        """Classify a batch of images.

        Args:
            x: input batch, passed unchanged to the feature extractor.
            act_flag: if truthy, the second returned element is the
                tanh-activated features; otherwise it is the raw features.

        Returns:
            ``(logits, features)``, where the logits are the linear classifier
            applied to ``tanh`` of the feature extractor output.
        """
        raw = self.feature_extractor(x)
        activated = torch.tanh(raw)
        logits = self.linear_classifier(activated)
        if act_flag:
            return logits, activated
        return logits, raw

