import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np

# Seed PyTorch's default RNG with a fixed value as soon as the module is
# imported, before any model or dataloader is created below.
torch.random.manual_seed(0)


from data import load_data
from functools import partial
import torchvision
from tqdm import tqdm

import argparse
import clip 

class LinearProbe(nn.Module):
    """Single linear classification layer on top of a frozen image encoder.

    Args:
        base_model: Object exposing ``encode_image(x)``; its output is fed
            to the linear layer. It is only ever called under
            ``torch.no_grad()``, so no gradients reach it.
        output_size: Number of output logits (classes).
        feature_dim: Width of the features returned by
            ``base_model.encode_image``. Defaults to 768, the value that was
            previously hard-coded for the ViT-L/14 model loaded in this file.
    """

    def __init__(self, base_model, output_size, feature_dim=768):
        super().__init__()
        self.base_model = base_model
        self.linear = nn.Linear(feature_dim, output_size)

    def forward(self, x):
        """Return class logits for the batch ``x``."""
        # The encoder is used purely as a feature extractor: disabling
        # autograd here keeps it frozen and avoids storing its activations.
        with torch.no_grad():
            feats = self.base_model.encode_image(x)
        return self.linear(feats)

def load_data(dataset="cifar10", subset_size=20, preprocess=None):
    """Build CIFAR-10 / CIFAR-100 dataloaders plus a small per-class subset.

    Args:
        dataset: Either ``"cifar10"`` or ``"cifar100"``.
        subset_size: Maximum number of training images kept per class in the
            subset loader (the first ``subset_size`` occurrences of each
            label, in dataset order).
        preprocess: Passed to the torchvision dataset as ``transform``.

    Returns:
        Tuple ``(trainloader, trainsubloader, testloader, num_classes)``:
        loaders over the full training set, the per-class subset of the
        training set, and the test set, followed by the number of classes.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    print("loading dataset", dataset)

    # Validate before touching the datasets; previously an unknown name fell
    # through both branches and failed later with an UnboundLocalError.
    class_counts = {"cifar10": 10, "cifar100": 100}
    if dataset not in class_counts:
        raise ValueError(
            f"unsupported dataset {dataset!r}; expected one of {sorted(class_counts)}"
        )
    num_classes = class_counts[dataset]

    if dataset == "cifar10":
        dataset_cls = torchvision.datasets.CIFAR10
    else:
        dataset_cls = torchvision.datasets.CIFAR100
    trainset = dataset_cls(root='./data', train=True, download=True, transform=preprocess)
    testset = dataset_cls(root='./data', train=False, download=True, transform=preprocess)

    # take subset of train dataset of subset_size images per class
    labels = np.array(trainset.targets)
    indices = []
    for class_id in range(num_classes):
        indices.extend(np.where(labels == class_id)[0][:subset_size].tolist())

    trainset_subset = torch.utils.data.Subset(trainset, indices)

    # convert to dataloaders
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
    trainsubloader = torch.utils.data.DataLoader(trainset_subset, batch_size=64, shuffle=True)
    testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)
    return trainloader, trainsubloader, testloader, num_classes

def load_data_random_labels(dataset="cifar10", preprocess=None):
    """Build a CIFAR training dataloader whose labels are replaced at random.

    Every training target is overwritten with a value drawn from
    ``np.random.randint(num_classes)``, so the labels depend on NumPy's
    global RNG state at call time.

    Args:
        dataset: Either ``"cifar10"`` or ``"cifar100"``.
        preprocess: Passed to the torchvision dataset as ``transform``.

    Returns:
        Tuple ``(trainloader, num_classes)``.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    print("loading dataset", dataset)

    # Validate before touching the datasets; previously an unknown name fell
    # through both branches and failed later with an UnboundLocalError.
    class_counts = {"cifar10": 10, "cifar100": 100}
    if dataset not in class_counts:
        raise ValueError(
            f"unsupported dataset {dataset!r}; expected one of {sorted(class_counts)}"
        )
    num_classes = class_counts[dataset]

    if dataset == "cifar10":
        dataset_cls = torchvision.datasets.CIFAR10
    else:
        dataset_cls = torchvision.datasets.CIFAR100
    trainset = dataset_cls(root='./data', train=True, download=True, transform=preprocess)

    # Overwrite each label in place with an independent random class id.
    # One draw per sample is kept on purpose so the label sequence obtained
    # under a given NumPy seed stays the same as before.
    for idx in range(len(trainset)):
        trainset.targets[idx] = np.random.randint(num_classes)

    # convert to dataloader
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=256, shuffle=True)

    return trainloader, num_classes

def _count_correct(model, loader, device):
    """Return ``(num_correct, num_seen)`` for ``model`` over all of ``loader``."""
    num_correct = 0
    num_seen = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            predicted = model(inputs).argmax(dim=1)
            num_seen += labels.size(0)
            num_correct += (predicted == labels).sum().item()
    return num_correct, num_seen


def _min_bound_term(weight_diff, num_samples, delta, variances):
    """Return the smallest bound term over the candidate ``variances``."""
    # NOTE(review): the 4 * v**2 denominator and the additive constant 10 are
    # carried over unchanged from the original formula; their derivation is
    # not visible in this file.
    complexity = weight_diff / (4 * variances ** 2) + np.log(num_samples / delta) + 10
    return np.min(np.sqrt(complexity / (num_samples - 1)))


def compute_pac_bayes_linear_fsl(args):
    """Train a few-shot linear probe on CLIP features and print a PAC-Bayes bound.

    Loads CLIP ViT-L/14 on CUDA, freezes it, trains a ``LinearProbe`` for 10
    epochs on the per-class training subset returned by ``load_data``, prints
    train and test accuracy, and finally prints a bound built from the squared
    distance between the trained and the initial linear weights.

    Args:
        args: Namespace with a ``data`` attribute naming the dataset
            (forwarded to ``load_data``).

    Returns:
        None. All results are printed.
    """
    device = "cuda"

    # load L-14 model
    clip_model, preprocess = clip.load("ViT-L/14", device=device)
    clip_model.eval()
    clip_model = clip_model.float()

    print(preprocess)

    _, trainsubloader, testloader, num_classes = load_data(dataset=args.data, preprocess=preprocess)

    for param in clip_model.parameters():
        param.requires_grad = False

    # load linear probe
    linear_probe = LinearProbe(clip_model, num_classes).to(device)

    # Explicit copy: the snapshot must not alias the live parameter, otherwise
    # the weight difference computed after training would always be zero.
    initial_weights = linear_probe.linear.weight.detach().cpu().numpy().copy()

    # train linear probe
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(linear_probe.linear.parameters(), lr=0.01)

    # train for 10 epochs
    epochs = 10
    for _ in tqdm(range(epochs), total=epochs):

        running_loss = 0.0

        for inputs, labels in trainsubloader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()

            outputs = linear_probe(inputs)
            loss = criterion(outputs, labels)

            loss.backward()
            running_loss += loss.item()
            optimizer.step()

        # Average over the batches actually iterated (the subset loader); the
        # original divided by the length of the full training loader.
        print("Average Loss", running_loss / len(trainsubloader))

    # eval and get train / test performance
    tcorrect, ttotal = _count_correct(linear_probe, trainsubloader, device)
    print("Train Accuracy", tcorrect / ttotal)

    # Evaluate on the whole test set; the original stopped after the first
    # batch because of a stray `break`.
    correct, total = _count_correct(linear_probe, testloader, device)
    print("Test Accuracy", correct / total)

    # get weights of model
    weights = linear_probe.linear.weight.detach().cpu().numpy()

    # compute pac bayes bound
    # define prior as gaussian over weights
    # compute complexity term - negative log likelihood of prior
    weight_diff = np.sum(np.square(weights - initial_weights))

    # 20000 candidate values in [0.1, 10]. `np.arange(0.1, 10, 20000)` treated
    # 20000 as the step and therefore produced only the single value 0.1.
    variances = np.linspace(0.1, 10, 20000)

    # Number of training samples the probe was fit on; this equals the
    # previously hard-coded 20 * num_classes when every class contributes
    # 20 images to the subset.
    num_samples = ttotal
    delta = 0.01

    bound_term = _min_bound_term(weight_diff, num_samples, delta, variances)
    print("Bound", bound_term)
    print("PAC Bayes Bound", (1 - tcorrect / ttotal) + bound_term)

def compute_train_acc_random(args):
    """Fit a linear probe on CLIP features to randomly labelled training data.

    Seeds NumPy and PyTorch with 0, loads CLIP ViT-L/14 on CUDA, freezes it,
    and trains a ``LinearProbe`` for 5 epochs on the loader returned by
    ``load_data_random_labels``. After every epoch the average training loss
    and the accuracy on that same (randomly labelled) training data are
    printed.

    Args:
        args: Namespace with a ``data`` attribute naming the dataset.

    Returns:
        None. All results are printed.
    """
    device = "cuda"

    # fix both random number generators
    torch.manual_seed(0)
    np.random.seed(0)

    # CLIP ViT-L/14 in eval mode, cast to float32
    clip_model, preprocess = clip.load("ViT-L/14", device=device)
    clip_model.eval()
    clip_model = clip_model.float()

    trainloader, num_classes = load_data_random_labels(dataset=args.data, preprocess=preprocess)

    # the backbone stays frozen; only the linear layer is optimised
    for backbone_param in clip_model.parameters():
        backbone_param.requires_grad = False

    linear_probe = LinearProbe(clip_model, num_classes).to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(linear_probe.linear.parameters(), lr=0.001)

    # train for 5 epochs
    num_epochs = 5
    num_batches = len(trainloader)
    for _ in tqdm(range(num_epochs), total=num_epochs):

        loss_sum = 0.0
        for images, targets in tqdm(trainloader, total=num_batches):
            images = images.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()
            loss = criterion(linear_probe(images), targets)
            loss.backward()
            loss_sum += loss.item()
            optimizer.step()

        print("Average Loss", loss_sum / num_batches)

        # accuracy on the training data after this epoch
        hits = 0
        seen = 0
        with torch.no_grad():
            for images, targets in trainloader:
                images = images.to(device)
                targets = targets.to(device)
                scores = linear_probe(images)
                guesses = torch.max(scores.data, 1)[1]
                seen += targets.size(0)
                hits += (guesses == targets).sum().item()
        print("Train Accuracy", hits / seen)

if __name__ == '__main__':

    # Command line: pick the CIFAR variant. Restricting the choices makes a
    # mistyped name fail at argument parsing instead of after model loading.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-d', '--data',
        default='cifar10',
        choices=['cifar10', 'cifar100'],
        help='dataset to run on (default: %(default)s)',
    )

    args = parser.parse_args()
    print(args)

    # To run the few-shot PAC-Bayes experiment instead, call
    # compute_pac_bayes_linear_fsl(args) here.
    compute_train_acc_random(args)



