import torch
from utils import genDataLinear

class linearModel:
    """Ordinary least-squares linear model with bootstrap estimates of beta.

    The coefficients are obtained by solving the normal equations
    ``(X^T X) beta = X^T Y``. Besides the plain fit, the class offers a
    classic bootstrap (one refit per resample) and a "centroid" bootstrap in
    which a set of candidate coefficient vectors is refined by gradient
    steps on resampled data and weighted by how often each one was the best
    fit.

    Attributes:
        p: Number of features expected in the second dimension of ``X``.
        beta: Result of the last successful ``fit`` call, or ``None``.
        betaBootstrap: ``(p, numBootStrapSample)`` tensor from the last
            ``bootstrap`` call, or ``None``. Reset to ``None`` by
            ``centoridBootstrap``.
        betaBootstrapCentroid: Centroids from the last ``centoridBootstrap``
            call, one per column, or ``None``.
        centroidWeight: Normalised weights matching the centroid columns, or
            ``None``.
    """

    def __init__(self, p):
        self.p = p
        self.beta = None
        self.betaBootstrap = None
        self.betaBootstrapCentroid = None
        self.centroidWeight = None

    @staticmethod
    def _solveNormalEquations(X, Y):
        """Return the solution of ``(X^T X) beta = X^T Y``."""
        # torch.solve(B, A) was deprecated and later removed from PyTorch;
        # torch.linalg.solve(A, B) is its replacement (note the argument
        # order) and returns only the solution tensor.
        return torch.linalg.solve(torch.mm(X.T, X), torch.mm(X.T, Y))

    @staticmethod
    def _resample(X, Y, uniformWeights):
        """Draw rows of ``X`` and ``Y`` with replacement, as many as there are weights."""
        numRows = uniformWeights.shape[0]
        rowIdx = torch.multinomial(uniformWeights, numRows, replacement=True)
        return X[rowIdx, :], Y[rowIdx, :]

    def fit(self, X, Y):
        """Fit the regression on ``(X, Y)`` and store the result in ``self.beta``.

        Args:
            X: 2-D design matrix whose second dimension must equal ``self.p``.
            Y: 2-D response tensor with the same number of rows as ``X``.

        Returns:
            The fitted coefficients, or the string ``"wrong dim"`` when
            ``X.shape[1] != self.p`` (kept for backward compatibility).
        """
        if X.shape[1] != self.p:
            return "wrong dim"
        self.beta = self._solveNormalEquations(X, Y)
        return self.beta

    def predictGivenBeta(self, X, beta):
        """Return ``X @ beta``, or ``"wrong dim"`` if the inner dimensions differ."""
        if X.shape[1] != beta.shape[0]:
            return "wrong dim"
        return torch.mm(X, beta)

    def bootstrap(self, X, Y, numBootStrapSample):
        """Refit the model on ``numBootStrapSample`` bootstrap resamples.

        Each resample draws ``X.shape[0]`` rows with replacement. The
        flattened solution of every refit is stored as one column, so ``Y``
        is expected to have a single column.

        Returns:
            ``self.betaBootstrap``, a ``(p, numBootStrapSample)`` tensor.
        """
        self.betaBootstrap = torch.zeros(self.p, numBootStrapSample)
        n = X.shape[0]
        uniformWeights = torch.ones(n) / (1. * n)
        for i in range(numBootStrapSample):
            X_bs, Y_bs = self._resample(X, Y, uniformWeights)
            self.betaBootstrap[:, i] = self._solveNormalEquations(X_bs, Y_bs).flatten()
        return self.betaBootstrap

    def centoridBootstrap(self, X, Y, numCentroidBootStrapSample, epochs=2000, lr=0.01):
        """Refine a set of centroid coefficient vectors on bootstrap resamples.

        Initialisation:
            * If ``self.betaBootstrap`` is set (``bootstrap`` was called
              before), its columns are cloned and used as the initial
              centroids; the number of centroids is then the number of those
              columns, regardless of ``numCentroidBootStrapSample``.
            * Otherwise ``numCentroidBootStrapSample`` centroids are created
              by refitting on fresh bootstrap resamples.

        Update (repeated ``epochs`` times): draw a resample, pick the centroid
        with the smallest mean squared error on it, increment that centroid's
        weight, and move it one gradient step. The step size is ``lr`` for
        epochs 0-299, ``lr / 10`` for epochs 300-699 and ``lr / 100`` from
        epoch 700 on.

        Side effects: sets ``self.betaBootstrapCentroid`` and
        ``self.centroidWeight`` and resets ``self.betaBootstrap`` to ``None``.

        Returns:
            Tuple ``(betaBootstrapCentroid, centroidWeight)`` where the
            weights are normalised to sum to one.
        """
        n = X.shape[0]
        uniformWeights = torch.ones(n) / (1. * n)

        if self.betaBootstrap is not None:
            # warm start from the existing bootstrap coefficients
            self.betaBootstrapCentroid = self.betaBootstrap.clone()
        else:
            # init the centroid
            self.betaBootstrapCentroid = torch.zeros(self.p, numCentroidBootStrapSample)
            for i in range(numCentroidBootStrapSample):
                X_bs, Y_bs = self._resample(X, Y, uniformWeights)
                self.betaBootstrapCentroid[:, i] = self._solveNormalEquations(X_bs, Y_bs).flatten()

        # One weight per centroid actually present. In the warm-start case
        # the column count comes from the earlier bootstrap() call and may
        # differ from numCentroidBootStrapSample; sizing by the argument
        # there would let the argmin index run past the weight vector.
        self.centroidWeight = torch.ones(self.betaBootstrapCentroid.shape[1])

        # update
        for epoch in range(epochs):
            if epoch >= 700:
                lr_ = lr / 100
            elif epoch >= 300:
                lr_ = lr / 10
            else:
                lr_ = lr

            X_bs, Y_bs = self._resample(X, Y, uniformWeights)
            pred = self.predictGivenBeta(X_bs, self.betaBootstrapCentroid)
            # per-centroid mean squared error on this resample
            loss = torch.mean((pred - Y_bs) ** 2, dim=0)
            centerIdx = int(loss.argmin())

            self.centroidWeight[centerIdx] += 1
            beta = self.betaBootstrapCentroid[:, [centerIdx]]
            # (unnormalised) squared-error gradient direction X^T (X beta - Y)
            grad = torch.mm(X_bs.T, (torch.mm(X_bs, beta) - Y_bs)).flatten()
            self.betaBootstrapCentroid[:, centerIdx] -= lr_ * grad

        # normalize centroidWeight
        self.centroidWeight /= torch.sum(self.centroidWeight)
        self.betaBootstrap = None
        return self.betaBootstrapCentroid, self.centroidWeight
