import numpy as np
import torch
import os
import clip
from torchvision.models import resnet34, ResNet34_Weights
from torch import nn, optim
from torch.nn import functional as F


def load_model(args):
    """Build the backbone wrapper selected by ``args.model``.

    Args:
        args: Namespace-like object. ``args.model`` must be ``"clip"`` or
            ``"resnet"``; the whole object is forwarded to the wrapper's
            constructor.

    Returns:
        A ``CLIPModel`` or ``ResNetModel`` instance.

    Raises:
        ValueError: if ``args.model`` is not a supported model name.
    """
    if args.model == "clip":
        return CLIPModel(args)
    if args.model == "resnet":
        return ResNetModel(args)
    raise ValueError(
        f"Unknown model {args.model!r}; expected 'clip' or 'resnet'"
    )


class SelectiveRecalNet(nn.Module):
    """Post-hoc calibrator on classifier logits plus an MLP selector head.

    ``forward`` returns a calibrated output computed from ``logits`` (Platt
    scaling or temperature scaling, chosen by ``scaling``) and, when a
    feature tensor ``phi`` is given, a flat selector output from an MLP.

    Args:
        input_dim: Size of the last dimension of ``phi`` fed to the selector.
        sel_loss: Selector loss name; stored on the module, not used here.
        cov_loss: Coverage loss name; stored on the module, not used here.
        scaling: ``"platt"`` or ``"temp"``.
        hidden_dim: Width of the selector MLP's hidden layers.
        num_layers: Number of extra hidden->hidden layers after the input
            layer.
        dropout: Dropout probability used before each hidden->hidden layer
            and before the output layer.
        fixed_w: If True, freeze the Platt slope ``w`` (or the temperature
            ``T`` when ``scaling == "temp"``).
        fixed_b: If True, freeze the Platt bias ``b`` (Platt scaling only).
        seed_w: Initial value of ``w`` (Platt) or ``T`` (temperature).
        seed_b: Initial value of ``b`` (Platt only).

    Raises:
        NotImplementedError: if ``scaling`` is not ``"platt"`` or ``"temp"``.
    """

    def __init__(
        self,
        input_dim,
        sel_loss="bce",
        cov_loss="bce",
        scaling="temp",
        hidden_dim=64,
        num_layers=1,
        dropout=0.0,
        fixed_w=False,
        fixed_b=False,
        seed_w=0.75,
        seed_b=0.75
    ):
        super().__init__()

        if scaling == "platt":
            self.init_platt_scaler(fixed_w, fixed_b, seed_w, seed_b)
            self.calibrator = self.platt_calibrator
        elif scaling == "temp":
            self.init_temp_scaler(fixed_w, seed_w)
            self.calibrator = self.temp_calibrator
        else:
            raise NotImplementedError

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.sel_loss = sel_loss
        self.cov_loss = cov_loss
        self.scaling = scaling

        # input -> hidden, then `num_layers` hidden -> hidden blocks, then a
        # single output unit.
        layers = [nn.Linear(input_dim, hidden_dim), nn.ReLU()]
        for _ in range(num_layers):
            layers.append(nn.Dropout(dropout))
            layers.append(nn.Linear(hidden_dim, hidden_dim))
            layers.append(nn.ReLU())
        layers.extend([nn.Dropout(dropout), nn.Linear(hidden_dim, 1)])
        self.net = nn.Sequential(*layers)

    def init_platt_scaler(self, fixed_w, fixed_b, seed_w, seed_b):
        """Create the Platt-scaling slope ``w`` and bias ``b`` (shape ``(1,)``)."""
        self.w = nn.Parameter(torch.ones(1) * seed_w, requires_grad=not fixed_w)
        # The bias has its own freeze flag, independent of ``fixed_w``.
        self.b = nn.Parameter(torch.ones(1) * seed_b, requires_grad=not fixed_b)

    def init_temp_scaler(self, fixed_w, seed_w):
        """Create the temperature parameter ``T`` (shape ``(1,)``)."""
        self.T = nn.Parameter(torch.ones(1) * seed_w, requires_grad=not fixed_w)

    def selector(self, phi):
        """Run the selector MLP on ``phi`` and flatten the output to 1-D."""
        return self.net(phi).view(-1)

    def platt_calibrator(self, logits, eps):
        """Platt-scale the top-class softmax probability.

        Softmaxes ``logits`` over the last dimension, takes the maximum
        probability, converts it to log-odds, applies ``w * z + b`` and
        squashes the result with a sigmoid. The last dimension of ``logits``
        is reduced away.
        """
        # Follow the parameters' device rather than hard-coding .cuda(), so
        # the module also works on CPU.
        probs = F.softmax(logits, -1).to(self.w.device)
        conf = probs.max(dim=-1).values
        # Clamp away from 0 and 1 so the log-odds stay finite.
        conf = conf.clamp(min=eps, max=1.0 - eps)
        log_odds = torch.log(conf / (1 - conf))
        return torch.sigmoid(log_odds * self.w + self.b)

    def temp_calibrator(self, logits, eps):
        """Return ``softmax(logits / T)`` over the last dimension.

        ``eps`` is unused; it is accepted so both calibrators share one
        signature.
        """
        # ``T`` has shape (1,) and broadcasts over logits of any shape.
        return F.softmax(logits / self.T, -1)

    def forward(self, logits, phi=None, eps=1e-6):
        """Return ``(calibrated, g)``.

        ``calibrated`` is the configured calibrator's output for ``logits``.
        ``g`` is the flattened selector output for ``phi``, or None when
        ``phi`` is None.
        """
        calibrated = self.calibrator(logits, eps)
        g = self.selector(phi) if phi is not None else None
        return calibrated, g


class ResNetModel():
    """Pretrained ImageNet ResNet-34 split into a feature extractor and fc head.

    ``forward`` returns ``(logits, features)`` from a single backbone pass.
    """

    def __init__(self, args):
        """Load pretrained ResNet-34 and put it on the GPU in eval mode.

        Args:
            args: Unused here; accepted so every backbone wrapper shares the
                same constructor signature.
        """
        # One weights enum drives both the preprocessing transforms and the
        # loaded checkpoint, so the two cannot drift apart.
        self.weights = ResNet34_Weights.IMAGENET1K_V1
        self.preprocess = self.weights.transforms()
        self.model = resnet34(weights=self.weights)
        children = list(self.model.children())
        # Every child except the final fc layer (ends with the global pool).
        self.feature_extractor = torch.nn.Sequential(*children[:-1])
        # Just the final fc layer.
        self.classifier = torch.nn.Sequential(*children[-1:])
        # The Sequentials wrap the same submodule objects as self.model, so
        # moving/eval-ing self.model also affects them.
        self.model.cuda().eval()

    def forward(self, images):
        """Return ``(logits, features)`` for a batch of preprocessed images.

        ``features`` is the feature extractor's output flattened from dim 1
        onward. ``logits`` is the classifier head applied to those features.
        """
        # One backbone pass only: the classifier applied to the extracted
        # features already gives the logits, so running self.model too would
        # double the cost.
        features = torch.flatten(self.feature_extractor(images), 1)
        logits = self.classifier(features)
        return logits, features
    
    
class CLIPModel():
    """Zero-shot CLIP ViT-B/32 classifier using precomputed class weights.

    ``forward`` returns ``(logits, image_features)``. The features are image
    embeddings normalised to unit L2 norm along the last dim. The logits are
    100 times their product with the stored zero-shot weights.
    """

    # Precomputed zero-shot weight files per supported dataset (paths are
    # relative to the current working directory).
    _ZEROSHOT_WEIGHT_PATHS = {
        "cifar-100": "../models/clip/cifar-100/zeroshot_weights.pt",
        "imagenet-v2": "../models/clip/imagenet-v2/zeroshot_weights.pt",
    }

    def __init__(self, args):
        """Load CLIP ViT-B/32 and the zero-shot weights for ``args.dataset``.

        Raises:
            ValueError: if ``args.dataset`` has no zero-shot weights file.
        """
        # Validate before the comparatively slow CLIP model load.
        if args.dataset not in self._ZEROSHOT_WEIGHT_PATHS:
            raise ValueError(
                f"No CLIP zero-shot weights for dataset {args.dataset!r}; "
                f"expected one of {sorted(self._ZEROSHOT_WEIGHT_PATHS)}"
            )
        self.model, self.preprocess = clip.load("ViT-B/32")
        self.model.cuda().eval()
        self.input_resolution = self.model.visual.input_resolution
        self.context_length = self.model.context_length
        self.vocab_size = self.model.vocab_size

        # Tensor.cuda() returns a new tensor; it must be assigned back or the
        # weights stay on their original device.
        self.weights = torch.load(self._ZEROSHOT_WEIGHT_PATHS[args.dataset]).cuda()

    def forward(self, images):
        """Return ``(logits, image_features)`` for a batch of images."""
        image_features = self.model.encode_image(images)
        # Out-of-place normalisation. An in-place `/=` would mutate the
        # tensor returned by encode_image and can invalidate values saved for
        # backward when autograd is recording.
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)

        # Zero-shot logits: fixed scale of 100 on the feature/weight product
        # (presumably mirroring CLIP's zero-shot recipe; confirm).
        logits = 100. * image_features @ self.weights

        return logits, image_features