from sklearn.linear_model import SGDClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.naive_bayes import GaussianNB
from sklearn.tree import DecisionTreeClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.ensemble import RandomForestClassifier
import random
import numpy as np
from time import time
from u_evaluation import evaluate
from u_savedata import *
import warnings
warnings.filterwarnings('ignore')
from skmultilearn.problem_transform import LabelPowerset
# from FALS_gpu import TorchBinaryLogisticRegression

import torch.nn as nn
import torch

class TorchBinaryLogisticRegression:
    """Binary logistic regression trained with full-batch SGD in PyTorch.

    Implements the part of the scikit-learn classifier API this file uses:
    ``fit(X, y[, sample_weight])`` and ``predict_proba(X)``. Runs on CUDA
    when it is available, otherwise on CPU.
    """

    def __init__(self, lr=0.01, max_iter=100, tol=1e-4):
        self.lr = lr              # SGD learning rate
        self.max_iter = max_iter  # maximum number of full-batch epochs
        self.tol = tol            # stop once |loss change| between epochs < tol
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # The linear layer is built lazily in fit(), once the feature count is known.
        self.model = None

    def _build_model(self, input_dim):
        """Create the linear layer, loss and optimizer for ``input_dim`` features."""
        self.model = nn.Linear(input_dim, 1).to(self.device)
        self.criterion = nn.BCEWithLogitsLoss()
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=self.lr)

    def _to_tensor(self, a):
        """Convert array-like (or tensor) ``a`` to a float32 tensor on ``self.device``."""
        # as_tensor avoids the extra copy/warning torch.tensor() gives for tensor inputs.
        return torch.as_tensor(a, dtype=torch.float32, device=self.device)

    def fit(self, X, y, sample_weight=None):
        """Train on features ``X`` and 0/1 targets ``y``.

        Args:
            X: 2-D array-like of features (rows are samples).
            y: array-like of 0/1 targets, one per row of ``X``.
            sample_weight: optional per-sample weights (third positional
                argument, matching scikit-learn's ``fit``). ``None`` or an
                empty sequence means unweighted. Weighted loss is the
                weight-normalised mean of the per-sample losses.

        Returns:
            self
        """
        X = self._to_tensor(X)
        # Flatten targets to 1-D so they always match the 1-D logits below.
        y = self._to_tensor(y).reshape(-1)
        if sample_weight is not None and len(sample_weight) > 0:
            w = self._to_tensor(sample_weight).reshape(-1)
        else:
            w = None

        if self.model is None:
            self._build_model(X.shape[1])

        prev_loss = float('inf')
        loss = None  # stays None when max_iter == 0
        for i in range(self.max_iter):
            self.model.train()
            self.optimizer.zero_grad()

            # reshape(-1) instead of squeeze(): squeeze() collapses a
            # single-sample batch to a 0-d tensor and breaks the loss shape.
            logits = self.model(X).reshape(-1)
            if w is None:
                loss = self.criterion(logits, y)
            else:
                per_sample = nn.functional.binary_cross_entropy_with_logits(
                    logits, y, reduction='none')
                loss = (per_sample * w).sum() / w.sum()

            loss.backward()
            self.optimizer.step()

            cur_loss = loss.item()
            if abs(prev_loss - cur_loss) < self.tol:
                print(f"[Converged] Epoch {i}, Loss = {cur_loss:.6f}")
                break
            prev_loss = cur_loss
        else:
            if loss is not None:
                print(f"[Max Iter Reached] Final Loss = {loss.item():.6f}")
        return self

    def predict_proba(self, X):
        """Return an ``(n_samples, 2)`` array of ``[P(y=0), P(y=1)]`` rows."""
        self.model.eval()
        X = self._to_tensor(X)
        with torch.no_grad():
            # reshape(-1) keeps shape (n,) even when n == 1.
            probs = torch.sigmoid(self.model(X).reshape(-1))
        return torch.stack([1 - probs, probs], dim=1).cpu().numpy()
    

def base_cls(mod='dt'):
    """Build a fresh, unfitted base classifier from its short name.

    Recognised names: 'svm', 'sgd', 'lr', 'bayes', 'dt', 'nn', 'forest',
    'torchlr'. Any other value of ``mod`` yields ``None``.
    """
    # Factories are lambdas so that only the selected estimator is constructed.
    factories = (
        ('svm', lambda: SVC(probability=True, tol=1e-4, cache_size=200, max_iter=1000)),
        ('sgd', lambda: SGDClassifier(loss='log_loss')),
        ('lr', lambda: LogisticRegression()),
        ('bayes', lambda: GaussianNB()),
        ('dt', lambda: DecisionTreeClassifier()),
        ('nn', lambda: MLPClassifier(tol=1e-4, max_iter=200)),
        ('forest', lambda: RandomForestClassifier()),
        ('torchlr', lambda: TorchBinaryLogisticRegression()),
    )
    for name, make in factories:
        if mod == name:
            return make()
    return None

def fill1(Y):
    """Return a copy of label matrix ``Y`` with no all-zero column.

    Every column whose entries sum to zero gets a 1 written into its first
    row. ``np.array`` copies the input, so the caller's ``Y`` is untouched.
    """
    labels = np.array(Y)
    n_cols = np.shape(labels)[1]
    empty_cols = [col for col in range(n_cols) if np.sum(labels[:, col]) == 0]
    if empty_cols:
        labels[0, empty_cols] = 1
    return labels

def filly(y):
    """Return a copy of label vector ``y``, forcing ``y[0] = 1`` if it sums to zero."""
    labels = np.array(y)
    if np.sum(labels) != 0:
        return labels
    labels[0] = 1
    return labels

def randorder(Q):
    """Return a random permutation of ``0 .. Q-1`` as a NumPy array.

    Uses the global ``random`` module, so ``random.seed`` makes it reproducible.
    """
    population = range(Q)
    shuffled = random.sample(population, Q)
    return np.array(shuffled)

def balanceorder(Y):
    """Return column indices of ``Y`` ordered by column sum, largest first.

    This is the reverse of ``np.argsort`` over the column sums, so ties come
    out in the reverse of argsort's order.
    """
    column_totals = np.sum(Y, axis=0)
    return np.flip(np.argsort(column_totals))

class Baser():
    """Single-label binary classifier that short-circuits one-class training data.

    If every training label is 1 (or every label is 0), no model is fitted and
    ``predict_proba`` returns that class with probability 1. Otherwise the
    learner created by ``base_cls(basemode)`` is fitted and queried.
    """

    def __init__(self, basemode='dt'):
        self.learner = base_cls(basemode)
        # -1: delegate to the fitted learner; 0 / 1: always predict that class.
        self.output = -1

    def fit(self, X, y, ins_weight=None):
        """Fit on features ``X`` and labels ``y``.

        Args:
            X: feature matrix, passed straight to the learner.
            y: labels of one binary target; assumed to be 0/1, since the
               all-ones check compares ``sum(y)`` with ``len(y)``.
            ins_weight: optional per-instance weights, forwarded as the
               learner's third positional ``fit`` argument. ``None`` or an
               empty sequence means unweighted.
        """
        self.output = -1
        n_pos = np.sum(y)
        if n_pos == len(y):
            self.output = 1
        elif n_pos == 0:
            self.output = 0
        elif ins_weight is None or len(ins_weight) == 0:
            self.learner.fit(X, y)
        else:
            self.learner.fit(X, y, ins_weight)

    def predict_proba(self, Xt):
        """Return ``[P(class 0), P(class 1)]`` per row of ``Xt``.

        In the constant-output case each row is a proper distribution:
        ``[0, 1]`` when trained on all ones, ``[1, 0]`` when trained on all zeros.
        """
        if self.output == -1:
            return self.learner.predict_proba(Xt)
        # .shape also works for sparse matrices, where len() raises.
        n_rows = Xt.shape[0] if hasattr(Xt, 'shape') else len(Xt)
        probs = np.zeros((n_rows, 2))
        probs[:, self.output] = 1.0
        return probs

class BR():
    """Binary Relevance multi-label model: one independent ``Baser`` per label column."""

    def __init__(self, basemode='dt'):
        self.baseLearner = []      # one fitted Baser per label column
        self.Q = 0                 # number of label columns seen in train()
        self.basemode = basemode   # short name handed to each Baser

    def train(self, X, Y, idxs=None):
        """Fit one binary learner per column of ``Y``.

        Args:
            X: feature matrix (indexable by an integer index array).
            Y: 2-D label matrix, one column per label.
            idxs: optional per-label row masks. When given, the learner for
                label ``j`` is trained only on rows where ``idxs[j]`` is
                nonzero. ``None`` or empty means use all rows.
        """
        self.Q = np.shape(Y)[1]
        # Start from scratch so that re-training does not leave stale learners
        # at the front of the list (test() reads only the first Q entries).
        self.baseLearner = []
        for j in range(self.Q):
            singleLearner = Baser(self.basemode)
            if idxs is None or len(idxs) == 0:
                singleLearner.fit(X, Y[:, j])
            else:
                idx = np.argwhere(idxs[j]).flatten()
                singleLearner.fit(X[idx], Y[:, j][idx])
            self.baseLearner.append(singleLearner)

    def test(self, Xt):
        """Return an ``(n_samples, Q)`` array of positive-class probabilities."""
        prediction = [self.test_a(Xt, j) for j in range(self.Q)]
        return np.array(np.transpose(prediction))

    def test_a(self, Xt, k):
        """Return positive-class probabilities of label ``k`` for each row of ``Xt``."""
        prediction_a = self.baseLearner[k].predict_proba(Xt)[:, 1]
        return prediction_a
