from sklearn.base import BaseEstimator
from sklearn.metrics import accuracy_score
from sklearn.svm import LinearSVC

import torch
import torch.utils.data as data_utils
from torch.nn.parameter import Parameter
from torch.optim import Optimizer

from utils import power_method, nuclear_projection, frobenius_projection
from sklearn.utils.extmath import randomized_svd
import numpy as np



class Quadratic(torch.nn.Module):
    """
    Quadratic model evaluated row-wise: x^T A x + b^T x + c.
    params:
    - d: input dimension, used to size the zero-initialised coefficients
    - A: optional tensor for the quadratic term (defaults to zeros of shape (d, d))
    - b: optional tensor for the linear term (defaults to zeros of shape (d,))
    - c: optional tensor for the offset (defaults to zeros of shape (1,))
    """
    def __init__(self, d, A=None, b=None, c=None):
        super(Quadratic, self).__init__()
        # Registration order (A, b, c) is the order in which parameters()
        # yields them, so it must stay exactly like this.
        self.A = Parameter(torch.zeros(d, d) if A is None else A)
        self.b = Parameter(torch.zeros(d) if b is None else b)
        self.c = Parameter(torch.zeros(1) if c is None else c)

    def forward(self, x):
        """Return a 1-D tensor with one model value per row of the 2-D input x."""
        projected = torch.mm(x, self.A.T)
        z = (x * projected).sum(dim=1)    # x_i^T A x_i for every row i
        z.add_(torch.mv(x, self.b))       # linear term
        z.add_(self.c)                    # offset, broadcast over the rows
        return z
                
class FW_Tr(Optimizer):
    """
    From "Stochastic Conditional Gradient Methods:
    From Convex Minimization to Submodular Maximization", Mokhtari et al.

    Only the first parameter group is used. The first parameter of that group
    is updated with a Frank-Wolfe step over a nuclear norm ball, using a
    running average of the stochastic gradients; every other parameter takes
    a plain gradient step.
    params:
    - lr: learning rate for the  linear part (unconstrained)
    - lam: radius of the nuclear norm ball for the quadratic part
    """
    def __init__(self, params, lr, lam):
        # Running average of the gradients of the first parameter. It starts
        # at 0 and is fully replaced on the first step when it == 0, because
        # rho is then exactly 1.
        self.d = 0
        defaults = dict(lr=lr, lam=lam)
        super(FW_Tr, self).__init__(params, defaults)

    def step(self, it):
        """
        Perform one update.
        params:
        - it: iteration counter; it drives the step size gamma and the
          averaging weight rho
        Parameters whose gradient is None are left untouched.
        """
        group = self.param_groups[0]
        lr = group['lr']
        lam = group['lam']
        for i, p in enumerate(group['params']):
            if p.grad is None:
                # Nothing was back-propagated into this parameter.
                continue
            d_p = p.grad.data
            if i == 0:
                gamma = 2 / (it + 8)
                rho = 4 / (it + 8) ** (2 / 3)
                # Go through .cpu().numpy() so that gradients living on
                # another device can still be averaged and handed to the SVD.
                grad = d_p.detach().cpu().numpy()
                self.d = (1 - rho) * self.d + rho * grad
                # Leading singular pair of the averaged gradient.
                u, s, v = randomized_svd(self.d, n_components=1)
                u = torch.as_tensor(u, dtype=p.data.dtype, device=p.data.device)
                v = torch.as_tensor(v, dtype=p.data.dtype, device=p.data.device)
                # Linear minimisation oracle over the nuclear norm ball of
                # radius lam: the rank-one matrix -lam * u v^T.
                S = -lam * torch.matmul(u.reshape(-1, 1), v.reshape(1, -1))
                # Keyword alpha replaces the deprecated add_(alpha, tensor).
                p.data.add_(S - p.data, alpha=gamma)
            else:
                p.data.add_(d_p, alpha=-lr)
                
class PGD(Optimizer):
    """
    Projected gradient descent.

    Only the first parameter group is used. Every parameter takes a gradient
    step; the first parameter of the group is then projected with
    nuclear_projection or frobenius_projection (both imported from utils).
    params:
    - lr: learning rate of the gradient step
    - lam: value passed to the projection helper as its second argument
      (presumably the ball radius -- confirm against utils)
    - norm: 'nuc' or 'fro' selects the projection; any other value applies
      no projection at all
    """
    def __init__(self, params, lr, lam, norm = 'nuc'):
        self.norm = norm
        defaults = dict(lr=lr, lam=lam)
        super(PGD, self).__init__(params, defaults)

    def step(self):
        """Perform one update; parameters whose gradient is None are skipped."""
        group = self.param_groups[0]
        lr = group['lr']
        lam = group['lam']
        for i, p in enumerate(group['params']):
            if p.grad is None:
                # Nothing was back-propagated into this parameter.
                continue
            # Keyword alpha replaces the deprecated add_(alpha, tensor).
            p.data.add_(p.grad.data, alpha=-lr)
            if i != 0:
                continue
            if self.norm == 'nuc':
                projected = nuclear_projection(p.data, lam)
            elif self.norm == 'fro':
                projected = frobenius_projection(p.data, lam)
            else:
                # Unknown norm: keep the unprojected iterate.
                continue
            # The helpers' return type is not visible from this file, so
            # convert explicitly and keep the parameter's dtype and device.
            p.data = torch.as_tensor(projected, dtype=p.data.dtype,
                                     device=p.data.device)
                
                
                


class Quadratic(torch.nn.Module):
    """
    Row-wise quadratic function x^T A x + b^T x + c with learnable A, b and c.

    NOTE: this re-declares a class named Quadratic that already appears
    earlier in the file; being the later definition, it is the one bound to
    the name at import time.
    params:
    - d: input dimension, only used to size coefficients that are not given
    - A, b, c: optional initial tensors; each defaults to zeros of shape
      (d, d), (d,) and (1,) respectively
    """
    def __init__(self, d, A=None, b=None, c=None):
        super(Quadratic, self).__init__()
        # Fill in the missing coefficients first ...
        if A is None:
            A = torch.zeros(d, d)
        if b is None:
            b = torch.zeros(d)
        if c is None:
            c = torch.zeros(1)
        # ... then register them in the fixed order A, b, c.
        self.A = Parameter(A)
        self.b = Parameter(b)
        self.c = Parameter(c)

    def forward(self, x):
        """Evaluate the model on every row of the 2-D tensor x."""
        quad_term = torch.mul(x, torch.mm(x, self.A.t()))
        z = torch.sum(quad_term, dim=1)
        z += torch.mv(x, self.b)
        z += self.c
        return z
                
                
class FW_Tr_naive(Optimizer):
    """
    Naive extension of FW to stochastic setting (has been shown to diverge in
    some cases).

    Only the first parameter group is used. The first parameter of that group
    is updated with a Frank-Wolfe step computed from the current stochastic
    gradient; every other parameter takes a plain gradient step.
    params:
    - lr: learning rate for the  linear part (unconstrained)
    - lam: radius of the nuclear norm ball for the quadratic part
    """
    def __init__(self, params, lr, lam):
        defaults = dict(lr=lr, lam=lam)
        # super() must name this class: super(FW_Tr, self) raises TypeError
        # because FW_Tr_naive does not derive from FW_Tr.
        super(FW_Tr_naive, self).__init__(params, defaults)

    def step(self, it):
        """
        Perform one update.
        params:
        - it: iteration counter; the Frank-Wolfe step size is 2 / (it + 2)
        Parameters whose gradient is None are left untouched.
        """
        group = self.param_groups[0]
        lr = group['lr']
        lam = group['lam']
        gamma = 2 / (it + 2)
        for i, p in enumerate(group['params']):
            if p.grad is None:
                # Nothing was back-propagated into this parameter.
                continue
            d_p = p.grad.data
            if i == 0:
                # Leading singular pair of the current gradient.
                u, s, v = randomized_svd(d_p.detach().cpu().numpy(),
                                         n_components=1)
                u = torch.as_tensor(u, dtype=p.data.dtype, device=p.data.device)
                v = torch.as_tensor(v, dtype=p.data.dtype, device=p.data.device)
                # Linear minimisation oracle over the nuclear norm ball of
                # radius lam: the rank-one matrix -lam * u v^T.
                S = -lam * torch.matmul(u.reshape(-1, 1), v.reshape(1, -1))
                # Keyword alpha replaces the deprecated add_(alpha, tensor).
                p.data.add_(S - p.data, alpha=gamma)
            else:
                p.data.add_(d_p, alpha=-lr)

class HingeLoss(torch.nn.Module):
    """Summed hinge loss: sum_i max(0, 1 - output_i * target_i)."""

    def __init__(self):
        super(HingeLoss, self).__init__()

    def forward(self, output, target):
        """
        Return the scalar hinge loss summed over all elements.
        params:
        - output: tensor of raw scores
        - target: tensor of labels, multiplied element-wise with output
          (presumably in {-1, +1} -- confirm against the caller)
        """
        margin = 1 - torch.mul(output, target)
        # zeros_like keeps the comparison tensor on the same device and dtype
        # as the margins instead of a freshly allocated CPU float tensor.
        hinge_loss = torch.max(margin, torch.zeros_like(margin))
        return torch.sum(hinge_loss)

class QuadraticNuclear(BaseEstimator):
    """
    Scikit-learn style classifier that fits a Quadratic model with a hinge
    loss, using one of the optimizers defined in this module or plain SGD.
    params:
    - lr: learning rate handed to the optimizer
    - lam: constraint value handed to the FW and PGD optimizers
    - optimizer: "FW", "SGD" or "PGD"
    - norm: projection used by the PGD optimizer ('nuc' or 'fro')
    """
    def __init__(self, lr, lam, optimizer="PGD", norm='nuc'):
        self.lr = lr
        self.lam = lam
        self.norm = norm
        self.optimizer = optimizer
        self.classifier = None  # set by fit()

    def _build_optimizer(self):
        """Create the optimizer named by self.optimizer for self.classifier."""
        params = self.classifier.parameters()
        if self.optimizer == "FW":
            return FW_Tr(params, lr=self.lr, lam=self.lam)
        if self.optimizer == "SGD":
            return torch.optim.SGD(params, lr=self.lr)
        if self.optimizer == "PGD":
            return PGD(params, lr=self.lr, lam=self.lam, norm=self.norm)
        # Without this, an unknown name only surfaced later as a NameError.
        raise ValueError(
            "Unknown optimizer {!r}; expected 'FW', 'SGD' or 'PGD'".format(
                self.optimizer))

    def fit(self, X, y, n_epoch=50, batch_size=10):
        """
        Train a fresh Quadratic model on (X, y) and return self.
        params:
        - X: 2-D array-like of samples; the input dimension is len(X[0])
        - y: labels, converted to a long tensor (the hinge loss and score()
          presumably expect values in {-1, +1} -- confirm with callers)
        - n_epoch: number of passes over the data
        - batch_size: mini-batch size; batches are drawn in order (no shuffle)
        Raises ValueError when self.optimizer is not a supported name.
        """
        d = len(X[0])
        train = data_utils.TensorDataset(torch.tensor(X).float(), torch.tensor(y).long())
        trainloader = data_utils.DataLoader(train, batch_size=batch_size, shuffle=False)
        self.classifier = Quadratic(d)
        criterion = HingeLoss()
        optimizer = self._build_optimizer()
        # Only the Frank-Wolfe optimizer takes the global iteration counter.
        needs_iteration = self.optimizer == "FW"

        step = 0
        for epoch in range(n_epoch):  # loop over the dataset multiple times
            for inputs, labels in trainloader:
                # zero the parameter gradients
                optimizer.zero_grad()

                # forward + backward + optimize
                outputs = self.classifier(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                if needs_iteration:
                    optimizer.step(step)
                else:
                    optimizer.step()
                step += 1
        # Returning self follows the scikit-learn estimator convention.
        return self

    def predict(self, X):
        """Return the sign of the model output for every row of X as a tensor."""
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            return self.classifier(torch.tensor(X).float()).sign()

    def score(self, X, y):
        """
        Return 1 - sum(|prediction - y|) / (2 * len(X)); this is the accuracy
        when predictions and labels both take values in {-1, +1}.
        """
        return 1 - (self.predict(X) - torch.tensor(y)).abs().sum().item()/(2*float(len(X)))
