import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable

# from models.wideresnet import WideResNet
# from models.preactresnet import PreActResNet18
from models.simple_net import simple_net, simpler_net
from models.vgg import vgg11
# from models.simpleCNN import simpleCNN

# upper_limit, lower_limit = 1., 0.
# DEBUG = 1#False

def normalize(X):
    """Standardise *X* channel-wise with fixed CIFAR-10 mean/std values.

    The statistics are reshaped to ``(3, 1, 1)``, so ``X`` must broadcast
    against that shape (presumably ``(N, 3, H, W)`` or ``(3, H, W)`` images
    scaled to ``[0, 1]`` -- confirm against callers).

    Args:
        X: input tensor on any device.

    Returns:
        ``(X - mean) / std`` computed on the same device as ``X``.
    """
    cifar10_mean = (0.4914, 0.4822, 0.4465) # equals np.mean(train_set.train_data, axis=(0,1,2))/255
    cifar10_std = (0.2471, 0.2435, 0.2616) # equals np.std(train_set.train_data, axis=(0,1,2))/255

    # Build the statistics on the input's own device instead of forcing CUDA,
    # so the function also works on CPU-only machines and on non-default GPUs.
    mu = torch.tensor(cifar10_mean, device=X.device).view(3, 1, 1)
    std = torch.tensor(cifar10_std, device=X.device).view(3, 1, 1)
    return (X - mu) / std


class ModelWrapper(nn.Module):
    """Wrap a model and optionally standardise its input first.

    When ``normalization`` is true, ``forward`` subtracts the fixed CIFAR-10
    channel means and divides by the channel stds before calling the wrapped
    model; otherwise the input is passed through untouched.

    Args:
        model: the module to wrap; stored as ``self.model``.
        normalization: whether ``forward`` should standardise its input.
    """

    def __init__(self, model, normalization = False):
        super(ModelWrapper, self).__init__()
        self.model = model
        self.normalization = normalization
        cifar10_mean = (0.4914, 0.4822, 0.4465) # equals np.mean(train_set.train_data, axis=(0,1,2))/255
        cifar10_std = (0.2471, 0.2435, 0.2616) # equals np.std(train_set.train_data, axis=(0,1,2))/255

        # Keep the statistics on the GPU when one exists (as before), but fall
        # back to the CPU instead of failing at construction time.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.mu = torch.tensor(cifar10_mean, device=device).view(3, 1, 1)
        self.std = torch.tensor(cifar10_std, device=device).view(3, 1, 1)

    def normalize(self, x):
        """Return ``(x - mu) / std``, moving the statistics to ``x``'s device if needed."""
        if self.mu.device != x.device:
            # Cache the moved copies so the transfer happens once per device
            # change rather than on every forward pass.
            self.mu = self.mu.to(x.device)
            self.std = self.std.to(x.device)
        return (x - self.mu) / self.std

    def forward(self, x):
        """Run the wrapped model on ``x``, standardising it first when enabled."""
        if self.normalization:
            x = self.normalize(x)
        return self.model(x)
    
# def unnormalize(X):
#     return std*X + mu


def clamp(X, l, u):
    """Limit *X* element-wise to the range given by tensors *l* and *u*.

    The upper bound is applied first and the lower bound second, so where
    ``l`` exceeds ``u`` the result is ``l``.
    """
    capped_above = torch.min(X, u)
    return torch.max(capped_above, l)

def _cifar10_transform(augment, normalization):
    """Compose the CIFAR-10 pipeline: optional crop/flip, ToTensor, optional Normalize."""
    steps = []
    if augment:
        steps.append(transforms.RandomCrop(32, padding = 4))
        steps.append(transforms.RandomHorizontalFlip())
    steps.append(transforms.ToTensor())
    if normalization:
        # NOTE: these std values differ from the (0.2471, 0.2435, 0.2616) used
        # by normalize()/ModelWrapper in this file; kept unchanged on purpose
        # so existing experiments see the same inputs.
        steps.append(transforms.Normalize((0.4914, 0.4822, 0.4465),
                                          (0.2023, 0.1994, 0.2010)))
    return transforms.Compose(steps)


def _subset_requested(subset):
    """Return True when *subset* carries at least one index (safe for arrays/tensors)."""
    if subset is None or isinstance(subset, bool):
        return bool(subset)
    try:
        # len() avoids the "ambiguous truth value" error that bool() raises
        # for multi-element numpy arrays and tensors.
        return len(subset) > 0
    except TypeError:
        return bool(subset)


def _first_batch_loaders(trainset, testset, train_bs, test_bs, shuffle):
    """Build loaders restricted to the first train_bs / test_bs samples ("simple" mode)."""
    head_train = torch.utils.data.Subset(trainset, range(train_bs))
    head_test = torch.utils.data.Subset(testset, range(test_bs))
    train_loader = torch.utils.data.DataLoader(head_train, batch_size = train_bs, shuffle = shuffle, num_workers = 4, pin_memory = True)
    test_loader = torch.utils.data.DataLoader(head_test, batch_size = test_bs, shuffle = False, num_workers = 4, pin_memory = True)
    return train_loader, test_loader


def getData(name='cifar10', train_bs=128, test_bs=1000, data_augmentation = True, simple=False, shuffle = False, subset = False, normalization= False):
    """
    Get the train/test dataloaders for the requested dataset.

    Args:
        name: 'cifar10', 'cifar10_without_dataaugmentation' or 'mnist'.
        train_bs: batch size of the training loader.
        test_bs: batch size of the test loader.
        data_augmentation: for 'cifar10', enables random crop + horizontal
            flip on the training set.
        simple: when true, shuffling is disabled; for the two CIFAR variants
            the returned loaders cover only the first ``train_bs`` training
            and first ``test_bs`` test samples. It has no other effect on
            'mnist'.
        shuffle: whether the training loader shuffles ('cifar10' and 'mnist').
        subset: indices for subset. Only honoured for 'cifar10'; when given,
            the function returns ``(train_loader_over_subset, None)``.
        normalization: for 'cifar10', appends a ``Normalize`` step to both
            the train and test transforms.

    Returns:
        ``(train_loader, test_loader)``; ``test_loader`` is ``None`` in the
        subset case. The MNIST loaders additionally carry ``mean``/``std``
        attributes set to ``[0.0]`` / ``[1.0]``.

    Raises:
        ValueError: if ``name`` is not one of the supported dataset names.
    """
    known_names = ('cifar10', 'cifar10_without_dataaugmentation', 'mnist')
    if name not in known_names:
        # Previously an unknown name fell through to an UnboundLocalError at
        # the final return; fail early with an actionable message instead.
        raise ValueError("Unknown dataset name: %r (expected one of %s)"
                         % (name, ', '.join(known_names)))

    if simple:
        shuffle = False
    print('shuffle,',shuffle)

    if name == 'mnist':
        mnist_train = datasets.MNIST(root='../data', train=True, download=True, transform=transforms.ToTensor())
        mnist_test = datasets.MNIST(root='../data', train=False, download=True, transform=transforms.ToTensor())
        train_loader = torch.utils.data.DataLoader(mnist_train,
                                                   batch_size=train_bs,
                                                   shuffle=shuffle, pin_memory = True)
        test_loader = torch.utils.data.DataLoader(mnist_test,
                                                  batch_size=test_bs,
                                                  shuffle=False, pin_memory = True)
        # Identity statistics are attached to the loaders for callers that
        # read them.
        std = [1.0]
        mean = [0.0]
        train_loader.std = std
        test_loader.std = std
        train_loader.mean = mean
        test_loader.mean = mean
        return train_loader, test_loader

    is_augmentable = name == 'cifar10'
    if is_augmentable:
        # The equality test is kept deliberately: values such as 1 enable
        # augmentation while other truthy objects do not, exactly as before.
        augment = data_augmentation == True
        transform_train = _cifar10_transform(augment, normalization)
        transform_test = _cifar10_transform(False, normalization)
    else:
        # 'cifar10_without_dataaugmentation': plain tensors, never normalised.
        transform_train = transforms.Compose([transforms.ToTensor()])
        transform_test = transforms.Compose([transforms.ToTensor()])

    trainset = datasets.CIFAR10(root='../data',
                                train=True,
                                download=True,
                                transform=transform_train)
    testset = datasets.CIFAR10(root='../data',
                               train=False,
                               download=False,
                               transform=transform_test)

    # A subset request takes precedence over `simple`, as in the original flow.
    if is_augmentable and _subset_requested(subset):
        chosen = torch.utils.data.Subset(trainset, subset)
        train_loader = torch.utils.data.DataLoader(chosen,
                                                   batch_size=train_bs,
                                                   shuffle=shuffle, pin_memory = True)
        return train_loader, None

    if simple:
        return _first_batch_loaders(trainset, testset, train_bs, test_bs, shuffle)

    if is_augmentable:
        train_loader = torch.utils.data.DataLoader(trainset,
                                                   batch_size=train_bs,
                                                   shuffle=shuffle, pin_memory = True)
        test_loader = torch.utils.data.DataLoader(testset,
                                                  batch_size=test_bs,
                                                  shuffle=False, pin_memory = True)
    else:
        # NOTE(review): this variant always shuffles the training set and does
        # not pin memory, regardless of `shuffle`; preserved as-is -- confirm
        # whether that is intended.
        train_loader = torch.utils.data.DataLoader(trainset,
                                                   batch_size=train_bs,
                                                   shuffle=True)
        test_loader = torch.utils.data.DataLoader(testset,
                                                  batch_size=test_bs,
                                                  shuffle=False)
    return train_loader, test_loader
    

def test(model, test_loader, cuda=True, print_opt=True):
    """
    Get the test performance
    """
    model.eval()
    correct = 0
    total_num = 0
    test_loss = 0
    criterion = nn.CrossEntropyLoss()
    for data, target in test_loader:
        if cuda:
            data, target = data.cuda(), target.cuda()
        output = model((data))
        
        loss = criterion(output, target)
        test_loss += loss.item() * target.size()[0]
        
        pred = output.data.max(
            1, keepdim=True)[1]  # get the index of the max log-probability
        correct += pred.eq(target.data.view_as(pred)).cpu().sum().item()
        total_num += len(data)
    if print_opt:
        print('testing_correct: ', correct / total_num, '\n')
    return correct / total_num, test_loss / total_num

def count_parameters(model):
    """Return the total number of elements across the model's trainable parameters.

    Parameters whose ``requires_grad`` is false are not counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


def get_model(model_name, width_factor=None, depth=None, residual=None, batch_norm=None,dataset_type='cifar10'):
    """Build the network selected by ``model_name``.

    Supported names:
        'simple-net'  -- ``simple_net`` with ``n_fc=512``
        'simpler-net' -- ``simple_net`` with ``n_fc=100``
        'vgg'         -- ``vgg11()``
        'fc'          -- three-layer fully connected net (hidden width 200,
                         10 outputs); the input width is 784 when
                         ``dataset_type == 'mnist'`` and 3072 otherwise.

    ``width_factor``, ``depth``, ``residual`` and ``batch_norm`` are accepted
    but not used by any currently enabled architecture.

    Raises:
        ValueError: if ``model_name`` is not one of the names above.
    """
    if model_name == 'simple-net':
        return simple_net(in_channel=3, widen_factor=1, n_fc=512, num_classes=10)

    if model_name == 'simpler-net':
        return simple_net(in_channel=3, widen_factor=1, n_fc=100, num_classes=10)

    if model_name == 'vgg':
        return vgg11()

    if model_name == 'fc':
        # Flattened input size: 784 for MNIST, 3072 for anything else (CIFAR-10).
        in_features = 784 if dataset_type == 'mnist' else 3072
        return nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features, 200, bias=True),
            nn.ReLU(),
            nn.Linear(200, 200, bias=True),
            nn.ReLU(),
            nn.Linear(200, 10, bias=True),
        )

    raise ValueError("Unknown model")


def batch_power_iteration_evl(A, num_simulations, u=None):
    """Run a batched power iteration on ``A`` and return ``u^T A v`` per matrix.

    Each step computes ``v = A^T u / ||A^T u||`` and ``u = A v / ||A v||``;
    iteration stops early once the largest per-batch change in ``u`` falls
    below ``1e-5``.

    Args:
        A: batch of matrices; indexed here as ``(batch, rows, cols)``.
        num_simulations: maximum number of iterations (at least 1).
        u: optional start vectors of shape ``(batch, rows, 1)``; drawn from a
            standard normal when omitted.

    Returns:
        Tensor of shape ``(batch, 1)`` holding ``u^T A v`` for the final
        vectors (presumably an estimate of the largest singular value --
        confirm against callers).

    Raises:
        ValueError: if ``num_simulations`` is smaller than 1.
    """
    EPS = 1e-24
    if num_simulations < 1:
        # Without at least one step `v` is never defined; fail clearly instead
        # of with an UnboundLocalError further down.
        raise ValueError("num_simulations must be at least 1")

    if u is None:
        # Sample on the CPU (same random stream as before), then follow A's
        # device and dtype rather than assuming a float32 CUDA tensor.
        u = torch.randn((A.size(0), A.size(1), 1)).to(device=A.device, dtype=A.dtype)

    B = A.transpose(1, 2)

    for _ in range(num_simulations):
        u1 = B.bmm(u)
        u1_norm = u1.norm(2, dim=(1, 2))
        # EPS guards against division by zero when A^T u vanishes.
        v = u1 / (u1_norm.view(-1, 1, 1) + EPS)

        u_prev = u

        v1 = A.bmm(v)
        v1_norm = v1.norm(2, dim=(1, 2))
        u = v1 / (v1_norm.view(-1, 1, 1) + EPS)

        # Converged once every matrix in the batch changed by less than 1e-5.
        if (u - u_prev).norm(2, dim=(1, 2)).max() < 1e-5:
            break

    output = u.transpose(1, 2).bmm(A).bmm(v).reshape(-1, 1)
    return output

def batch_app_evl(p):
    """Evaluate a closed-form expression on the first two columns of ``p``.

    With ``p1 = p[:, 0]`` and ``p2 = p[:, 1]`` the result is the larger root
    of ``x**2 + 2*bb*x + cc = 0`` where ``bb = (p1**2 + p2**2 - p1 - p2) / 2``
    and ``cc = p1 * p2 * (1 - p1 - p2)``.

    Args:
        p: 2-D tensor with at least two columns (presumably per-sample class
            probabilities -- confirm against callers). Columns beyond the
            second are not read.

    Returns:
        1-D tensor with one value per row of ``p``; entries are NaN where
        ``bb**2 - cc`` is negative.
    """
    p1 = p[:, 0]
    p2 = p[:, 1]
    # The third column was previously extracted but never used; the formula
    # only needs p1 and p2.

    bb = (p1**2 + p2**2 - p1 - p2) / 2
    cc = p1 * p2 * (1 - p1 - p2)

    return -bb + torch.sqrt(bb**2 - cc)