import numpy as np
import pandas as pd

from dwls import DWLS
from rwls import RWLS
from MRCpy import CMRC
# Import the datasets
from datasets.load import *
from sklearn.preprocessing import StandardScaler    

# Default (ntr, nte) total sample sizes for the datasets used in this experiment.
_DEFAULT_SAMPLE_SIZES = {
    '20News_500features': (300, 300),
    'redwine': (100, 100),
}


def _clip_proportions(proportions, min_proportion):
    """Return a copy of *proportions* with every entry raised to at least *min_proportion*.

    Entries below the floor are set to exactly ``min_proportion`` and the other
    entries are rescaled (once) so the vector still sums to 1. Because the
    rescaling is a single pass, a rescaled entry can end up slightly below the floor.
    """
    clipped = np.array(proportions, dtype=float)
    mask = clipped < min_proportion
    clipped[mask] = min_proportion
    remaining_total = 1 - min_proportion * np.sum(mask)
    # Guard avoids a 0/0 warning when every entry was clipped.
    if np.any(~mask):
        clipped[~mask] *= remaining_total / np.sum(clipped[~mask])
    return clipped


def _take_rows(X, indices, start, count, label):
    """Return ``X[indices[start:start + count]]``; raise if the pool has too few rows."""
    selected = indices[start:start + count]
    if len(selected) < count:
        raise ValueError(
            f"class {label} has {len(indices)} rows in its pool, but rows "
            f"{start}..{start + count - 1} were requested"
        )
    return X[selected]


def dirichlet_shift(X, Y, dataName, parameter_alpha, X_test=None, Y_test=None,
                    ntr=None, nte=None, min_proportion=0.05):
    """Draw training / xPte / test samples with a Dirichlet-generated label shift.

    Training class proportions are drawn from ``Dirichlet(parameter_alpha)``;
    test proportions are uniform over the classes. Both are floored at
    ``min_proportion``. For each class ``i`` the first matching rows are taken
    in their stored order (no shuffling):

    * with a separate test pool (``X_test``/``Y_test`` given): training rows
      come from ``X``; the test rows and then a disjoint block of xPte rows
      come from ``X_test`` (using the class indices of ``Y_test``);
    * otherwise all three samples are consecutive, disjoint slices of ``X``.

    Labels are assumed to be integers ``0 .. n_classes - 1``.

    Args:
        X, Y: training (or only) pool of features and integer labels.
        dataName: dataset name; selects the default sample sizes.
        parameter_alpha: Dirichlet concentration, one entry per class.
        X_test, Y_test: optional separate pool for the xPte and test samples.
        ntr, nte: total training / test sample sizes. Each defaults to the
            per-dataset value in ``_DEFAULT_SAMPLE_SIZES``.
        min_proportion: lower bound applied to every class proportion.

    Returns:
        ``(xTr, yTr, xPte, yPte, xTe, yTe)``. ``xPte``/``yPte`` and
        ``xTe``/``yTe`` have identical class counts, and the two samples
        share no rows.

    Raises:
        ValueError: if ``dataName`` has no default sizes and ``ntr``/``nte``
            were not given, or if a class has too few rows to fill its quota.
    """
    if ntr is None or nte is None:
        try:
            default_ntr, default_nte = _DEFAULT_SAMPLE_SIZES[dataName]
        except KeyError:
            raise ValueError(
                f"no default sample sizes for dataset {dataName!r}; "
                "pass ntr and nte explicitly"
            ) from None
        ntr = default_ntr if ntr is None else ntr
        nte = default_nte if nte is None else nte

    n_classes = len(np.unique(Y))
    raw_train = np.random.dirichlet(parameter_alpha, size=1).flatten()
    proportions_train = _clip_proportions(raw_train, min_proportion)
    proportions_test = _clip_proportions(
        np.ones_like(raw_train) / n_classes, min_proportion)

    ntr_y = [int(round(ntr * p)) for p in proportions_train]
    nte_y = [int(round(nte * p)) for p in proportions_test]

    tr_parts, pte_parts, te_parts = [], [], []
    for i in range(n_classes):
        if X_test is not None:
            idx_tr = np.where(Y == i)[0]
            idx_te = np.where(Y_test == i)[0]
            tr_parts.append(_take_rows(X, idx_tr, 0, ntr_y[i], i))
            te_parts.append(_take_rows(X_test, idx_te, 0, nte_y[i], i))
            # xPte must use the *test* pool's indices and not overlap xTe.
            pte_parts.append(_take_rows(X_test, idx_te, nte_y[i], nte_y[i], i))
        else:
            idx = np.where(Y == i)[0]
            tr_parts.append(_take_rows(X, idx, 0, ntr_y[i], i))
            pte_parts.append(_take_rows(X, idx, ntr_y[i], nte_y[i], i))
            te_parts.append(_take_rows(X, idx, ntr_y[i] + nte_y[i], nte_y[i], i))

    xTr = np.concatenate(tr_parts)
    xPte = np.concatenate(pte_parts)
    xTe = np.concatenate(te_parts)

    # Rows were stacked class by class, so labels are runs of 0, 1, ..., n_classes-1.
    classes = np.arange(n_classes)
    yTr = np.repeat(classes, ntr_y)
    yPte = np.repeat(classes, nte_y)
    yTe = np.repeat(classes, nte_y)

    return xTr, yTr, xPte, yPte, xTe, yTe

# Data sets
loaders = [load_20News_500features, load_redwine]
dataName = ["20News_500features", "redwine"]

rep_max = 20

columns = ['dataset', 'iteration', 'method', 'error']
results = pd.DataFrame(columns=columns)

# ErrorK[j, rep] holds the test error of method K on dataset j in repetition rep.
Error1 = np.zeros((len(dataName), rep_max))  # No adaptation (CMRC)
Error2 = np.zeros((len(dataName), rep_max))  # TarS (RWLS, default beta_method)
Error3 = np.zeros((len(dataName), rep_max))  # BBSE (RWLS)
Error4 = np.zeros((len(dataName), rep_max))  # RLLS (RWLS)
Error5 = np.zeros((len(dataName), rep_max))  # DW-LS (DWLS)
Error6 = np.zeros((len(dataName), rep_max))  # RWLS with exact label proportions
Error7 = np.zeros((len(dataName), rep_max))  # DWLS with exact label proportions

# Order in which methods are appended to `results` and summarised after each
# dataset: (error matrix, label stored in the 'method' column, printed name).
_METHOD_REPORT = [
    (Error1, '\'No_Adapt.\'', 'No Adapt.'),
    (Error6, '\'Exact_RW\'', 'Exact Reweighted'),
    (Error7, '\'Exact_DW\'', 'Exact DW'),
    (Error2, '\'TarS\'', 'TarS'),
    (Error3, '\'BBSE\'', 'BBSE'),
    (Error4, '\'RLLS\'', 'RLLS'),
    (Error5, '\'DW-LS\'', 'DW-LS'),
]

# D values 1, 1/0.81, ..., 1/0.01 used to build the candidate constants C for
# the exact-probability DW method (loop-invariant, so computed once).
Ds = 1 / (1 - np.arange(0, 1, 0.1)) ** 2


def _class_proportions(y, n_classes):
    """Return the empirical fraction of each label 0..n_classes-1 in *y*."""
    return np.array([np.sum(y == k) for k in range(n_classes)]) / len(y)


def _exact_dw_weights(C, ptr_y, pte_y):
    """Return (alpha, beta) for the exact-probability DW method at constant *C*.

    beta = min(pte/ptr, C) and alpha = min(C * ptr/pte, 1), element-wise per class.
    """
    beta_ = np.minimum(pte_y / ptr_y, C)
    alpha_ = np.minimum(C * ptr_y / pte_y, 1.0)
    return alpha_, beta_


for j, load in enumerate(loaders):

    # Loading the dataset
    if dataName[j] == '20News_500features':
        X_train, Y_train, X_test, Y_test = load()
        n_classes = len(np.unique(Y_train))
    elif dataName[j] == 'redwine':
        X, Y = load()
        mask = (Y != 0) & (Y != 1) & (Y != 5)
        X = X[mask]
        Y_filtered = Y[mask]
        # Rearrange classes: 2 -> 0, 3 -> 1, 4 -> 2
        class_mapping = {2: 0, 3: 1, 4: 2}
        Y = np.array([class_mapping[label] for label in Y_filtered])
        n_classes = len(np.unique(Y))

    # NOTE(review): earlier versions fitted a StandardScaler on X every
    # repetition but never used the result (dead code, removed). Decide whether
    # the features were actually meant to be standardised.

    for rep in range(rep_max):

        parameter_alpha = 0.1 * np.ones(n_classes)
        if dataName[j] == '20News_500features':
            xTr, yTr, xPte, yPte, xTe, yTe = dirichlet_shift(
                X_train, Y_train, dataName[j], parameter_alpha, X_test, Y_test)
        else:
            xTr, yTr, xPte, yPte, xTe, yTe = dirichlet_shift(
                X, Y, dataName[j], parameter_alpha)

        # Exact label proportions of the drawn training and test samples.
        ptr_y = _class_proportions(yTr, n_classes)
        pte_y = _class_proportions(yTe, n_classes)

        # No Adaptation Method
        clf = CMRC(loss='0-1', phi='linear', fit_intercept=True, s=0, deterministic=True)
        clf.fit(xTr, yTr, xTe)
        Error1[j, rep] = clf.error(xTe, yTe)

        # TarS Method
        clf2 = RWLS(loss='0-1', phi='linear', deterministic=True)
        clf2.fit(xTr, yTr, xPte)
        Error2[j, rep] = clf2.error(xTe, yTe)

        # BBSE Method
        clf3 = RWLS(loss='0-1', phi='linear', beta_method='BBSE', deterministic=True)
        clf3.fit(xTr, yTr, xPte)
        Error3[j, rep] = clf3.error(xTe, yTe)

        # RLLS Method
        clf4 = RWLS(loss='0-1', phi='linear', beta_method='RLLS', deterministic=True)
        clf4.fit(xTr, yTr, xPte)
        Error4[j, rep] = clf4.error(xTe, yTe)

        # DW-LS Method
        clf5 = DWLS(loss='0-1', phi='linear', deterministic=True)
        clf5.fit(xTr, yTr, xPte)
        Error5[j, rep] = clf5.error(xTe, yTe)

        # Reweighted Method using Exact Probabilities
        clf6 = RWLS(loss='0-1', phi='linear', weights_beta=pte_y / ptr_y, deterministic=True)
        clf6.fit(xTr, yTr, xTe)
        Error6[j, rep] = clf6.error(xTe, yTe)

        # DW Method using Exact Probabilities: fit one model per candidate C
        # and keep the C whose fitted model has the largest `upper_`.
        # NOTE(review): if `upper_` is the minimax-risk upper bound, argmin may
        # have been intended here — confirm against the DWLS method description.
        Cs = np.max(pte_y / ptr_y) / np.sqrt(Ds)
        RU = np.zeros(len(Cs))
        for i, C in enumerate(Cs):
            alpha_, beta_ = _exact_dw_weights(C, ptr_y, pte_y)
            clf7 = DWLS(loss='0-1', phi='linear', weights_alpha=alpha_,
                        weights_beta=beta_, deterministic=True)
            clf7.fit(xTr, yTr, xTe)
            RU[i] = clf7.upper_
        alpha_, beta_ = _exact_dw_weights(Cs[np.argmax(RU)], ptr_y, pte_y)
        clf7 = DWLS(loss='0-1', phi='linear', weights_alpha=alpha_,
                    weights_beta=beta_, deterministic=True)
        clf7.fit(xTr, yTr, xTe)
        Error7[j, rep] = clf7.error(xTe, yTe)

        # One row per method; store this repetition's scalar error, not the
        # whole error matrix.
        for error, method_label, _ in _METHOD_REPORT:
            results.loc[len(results)] = {'dataset': dataName[j],
                                         'iteration': rep,
                                         'method': method_label,
                                         'error': error[j, rep]}

    print(dataName[j])
    for error, _, display_name in _METHOD_REPORT:
        print(f"Mean Error and Std of {display_name}: "
              f"{np.mean(error[j, :]):.2f} ± {np.std(error[j, :]):.2f}")