classdef HMOFSFIAK < ALGORITHM
% <multi> <integer/binary> 
% Multi-objective feature selection: fast initialization and adaptive knn
%
% Outline of main():
%   1. Population1 is created by InitializePopulation and evaluated with
%      Problem.Evaluation. Population2 starts from the same masks but is
%      evaluated with JSEMO_MainCode (data file '<class(Problem)>1.mat').
%   2. In the first generation both populations are evolved once.
%   3. Afterwards, the population with the smaller minimum first objective
%      (ties go to Population1) is kept and evolved alone, still using the
%      evaluator it started with.
%   4. Once Problem.FE reaches Problem.maxFE, the final population is
%      re-evaluated and reduced to its first non-dominated front.

    methods
        function main(Algorithm,Problem)
            %% Population initialization
            DataName = [class(Problem) '1.mat'];

            % Population1: evaluated through Problem.Evaluation
            Population1 = InitializePopulation(Problem);
            [FrontNo1,CrowdDis1,Mask1] = RankPopulation(Population1);

            % Population2: same starting masks, evaluated through JSEMO_MainCode.
            % Mask2 is taken from Population2 itself so that it stays
            % row-aligned with FrontNo2/CrowdDis2 (mating pools index all three).
            Population2 = EvaluateByJSEMO(DataName,Problem,Population1.decs,Problem.N);
            [FrontNo2,CrowdDis2,Mask2] = RankPopulation(Population2);

            firstGenDone = false;  % true once the dual-population generation has run
            mark = 0;              % 0 = undecided, 1 = keep Population1, 2 = keep Population2

            %% Optimization
            while Algorithm.NotTerminated(Population1)
                if ~firstGenDone
                    % Evolve Population1 (Problem.Evaluation) once
                    OffMask1 = Reproduce(Population1,Mask1,FrontNo1,CrowdDis1,Problem.N);
                    Offspring1 = Problem.Evaluation(OffMask1);
                    Population1 = EnvironmentalSelectionDAEA([Population1, Offspring1],Problem.N);
                    [FrontNo1,CrowdDis1,Mask1] = RankPopulation(Population1);

                    % Evolve Population2 (JSEMO_MainCode) once
                    OffMask2 = Reproduce(Population2,Mask2,FrontNo2,CrowdDis2,Problem.N);
                    Offspring2 = EvaluateByJSEMO(DataName,Problem,OffMask2,2*Problem.N);
                    Population2 = EnvironmentalSelectionDAEA([Population2, Offspring2],Problem.N);
                    [FrontNo2,CrowdDis2,Mask2] = RankPopulation(Population2);

                    % JSEMO evaluations are not counted by Problem.Evaluation,
                    % so the budget is charged manually.
                    Problem.FE = Problem.FE + length(Offspring2);

                    firstGenDone = true;
                else
                    if mark == 0
                        % Keep the population whose best first objective is
                        % smaller; ties keep Population1.
                        Objs1 = Population1.objs;
                        Objs2 = Population2.objs;
                        if min(Objs1(:,1)) <= min(Objs2(:,1))
                            mark = 1;
                        else
                            Population1 = Population2;
                            FrontNo1  = FrontNo2;
                            CrowdDis1 = CrowdDis2;
                            Mask1     = Mask2;
                            mark = 2;
                        end
                    end

                    OffMask1 = Reproduce(Population1,Mask1,FrontNo1,CrowdDis1,Problem.N);

                    if mark == 1
                        % Continue with Problem.Evaluation
                        Offspring1 = Problem.Evaluation(OffMask1);
                        Population1 = EnvironmentalSelectionDAEA([Population1, Offspring1],Problem.N);
                        [FrontNo1,CrowdDis1,Mask1] = RankPopulation(Population1);

                        if Problem.FE >= Problem.maxFE
                            % NOTE(review): boolTrain = 0 presumably switches
                            % Problem.Evaluation away from the training data - confirm.
                            Problem.boolTrain = 0;
                            Population1 = Problem.Evaluation(Population1.decs);
                            Population1 = FirstFront(Population1);
                        end
                    else
                        % Continue with JSEMO_MainCode
                        Offspring1 = EvaluateByJSEMO(DataName,Problem,OffMask1,2*Problem.N);
                        Problem.FE = Problem.FE + length(Offspring1);
                        Population1 = EnvironmentalSelectionDAEA([Population1, Offspring1],Problem.N);
                        [FrontNo1,CrowdDis1,Mask1] = RankPopulation(Population1);

                        if Problem.FE >= Problem.maxFE
                            Population1 = EvaluateByJSEMO(DataName,Problem,Mask1,Problem.N);
                            Population1 = FirstFront(Population1);
                        end
                    end
                end
            end
        end
    end
end

function OffMask = Reproduce(Population,Mask,FrontNo,CrowdDis,N)
% Produce de-duplicated offspring masks from two generators: Operator1 on a
% binary-tournament mating pool, and OperatorRBM using an RBM trained on Mask.
% The call order matches the original inline code, so random-number
% consumption is unchanged.
    nPool = ceil(N/2)*2;
    MatingPoolA = TournamentSelection(2,nPool,FrontNo,-CrowdDis);
    OffMaskA = Operator1(Population(MatingPoolA),Mask(MatingPoolA,:));
    [rbm,allZero,allOne] = ModelTraining(Mask);
    MatingPoolB = TournamentSelection(2,nPool,FrontNo,-CrowdDis);
    OffMaskB = OperatorRBM(Mask(MatingPoolB,:),rbm,allZero,allOne);
    OffMask = unique([OffMaskA;OffMaskB],'rows');
end

function Population = EvaluateByJSEMO(DataName,Problem,Mask,nEval)
% Evaluate feature masks with JSEMO_MainCode and wrap the result as SOLUTIONs.
% Objective 1 is 1 minus JSEMO's first output column (presumably accuracy ->
% error; TODO confirm). Objective 2 is the fraction of selected features.
% The (all-zero) constraint matrix is sized from the returned PS, because the
% number of rows can be smaller than nEval after unique(...,'rows').
    [PF,PS] = JSEMO_MainCode(DataName,1,nEval,Mask);
    PF(:,1) = 1 - PF(:,1);
    PF(:,2) = sum(PS,2)./Problem.D;
    Population = SOLUTION(PS,PF,zeros(size(PS,1),1));
end

function [FrontNo,CrowdDis,Mask] = RankPopulation(Population)
% Non-dominated front numbers, crowding distances and decision masks of Population.
    FrontNo  = NDSort(Population.objs,inf);
    CrowdDis = CrowdingDistance(Population.objs,FrontNo);
    Mask     = Population.decs;
end

function Population = FirstFront(Population)
% Keep only the members of the first non-dominated front (constraints considered).
    FrontNo = NDSort(Population.objs,Population.cons,inf);
    Population = Population(FrontNo==1);
end

function Population = InitializePopulation(Problem)
% Build the initial population of feature masks guided by mutual information.
%   For every one of the Problem.N individuals:
%     - a budget k is drawn uniformly from 1..3*Problem.N (randperm(.,1));
%     - TournamentSelection picks k feature indices with tournament size
%       ceil(600/k), using the negated mutual information of each feature
%       with the training labels as the selection fitness;
%     - the picked indices are set to 1 in that individual's mask.
%   NOTE(review): TournamentSelection may return repeated indices, so the
%   number of distinct selected features can be below k - confirm intent.
%   The resulting masks are evaluated with Problem.Evaluation.
    maxBudget = Problem.N * 3;
    X = Problem.oTrainX;
    Y = Problem.oTrainY;

    % Mutual information between each feature column and the labels
    relevance = zeros(1, Problem.D);
    for f = 1 : Problem.D
        relevance(f) = mutualinfo(X(:, f), Y);
    end

    masks = zeros(Problem.N, Problem.D);
    for row = 1 : Problem.N
        budget = randperm(maxBudget, 1);
        % NOTE(review): 600 is hard-coded and equals 3*Problem.N only when
        % Problem.N = 200. Remapping 600 -> 599 keeps the tournament size at
        % 2 rather than 1 for that single draw.
        if budget == 600
            budget = 599;
        end
        chosen = TournamentSelection(ceil(600 / budget), budget, -relevance);
        masks(row, chosen) = 1;
    end

    Population = Problem.Evaluation(masks);
end

function mi = mutualinfo(x, y)
% MUTUALINFO  Histogram-based mutual information estimate between two samples.
%   mi = mutualinfo(x, y) bins x and y jointly with histcounts2 (automatic
%   binning, 'probability' normalization) and returns
%       sum over non-empty cells of p(x,y) * log( p(x,y) / (p(x) * p(y)) ),
%   using the natural logarithm (result in nats).
%
%   x, y : vectors of equal length; both are reshaped to column vectors.
    x = x(:);
    y = y(:);

    % Joint probability table (bin edges are not needed)
    Pxy = histcounts2(x, y, 'Normalization', 'probability');

    % Marginals: column vector over x-bins, row vector over y-bins
    Px = sum(Pxy, 2);
    Py = sum(Pxy, 1);

    % Only non-empty cells contribute (0*log 0 is taken as 0)
    nz   = Pxy > 0;
    PxPy = Px * Py;          % outer product, same size as Pxy
    mi   = sum(Pxy(nz) .* log(Pxy(nz) ./ PxPy(nz)));
end