import torch
import os
from fair_PCA import *
from base_processor import BaseProcessor

class FairPCAProcessor(BaseProcessor):
    """Embedding processor that applies fair-PCA projections and/or centroid moves.

    Which transformations run is selected by the keys present in ``usermode``.
    """

    def modify_embedding(self, pipe, prompt_embeds, pooled_prompt_embeds, usermode=None, exp_dir="."):
        """Transform the prompt embeddings according to ``usermode``.

        Args:
            pipe: Object on which the fitted projectors are cached as the
                attribute ``fpcas`` (fitted on first use, reused afterwards).
            prompt_embeds: Tensor indexed with three axes; its last axis is
                split in two, each part is transformed separately and the
                parts are concatenated back along the last axis.
            pooled_prompt_embeds: Tensor transformed as a whole. Its dtype and
                device are used for the cached projection matrices.
            usermode: Mapping of option flags. Keys read here: ``"remove"``
                (apply the fair-PCA projectors), ``"renorm"`` (restore the
                pre-projection norms afterwards), ``"cmovedbg"`` (apply
                ``CentroidMover`` for every entry of ``usermode["protect"]``).
                Defaults to an empty mapping (no transformation).
            exp_dir: Directory that contains ``extracted_features.pt``.

        Returns:
            Tuple ``(prompt_embeds, pooled_prompt_embeds)`` after processing.
        """
        if usermode is None:
            usermode = {}

        if "remove" in usermode:
            # Fit only once per pipeline. The check must use the same attribute
            # name that is assigned below ("fpcas"); checking "fpca" never
            # matched, so the projectors were refitted on every call.
            if not hasattr(pipe, "fpcas"):
                pipe.fpcas = calc_projection_matrix(exp_dir, usermode)
                for fpca in pipe.fpcas:
                    fpca.UUT = torch.tensor(fpca.UUT, dtype=pooled_prompt_embeds.dtype).to(pooled_prompt_embeds.device)
                    fpca.nzTXT = torch.tensor(fpca.nzTXT, dtype=pooled_prompt_embeds.dtype).to(pooled_prompt_embeds.device)

            # The last axis is split at the size of the projection matrix.
            # NOTE(review): presumably the two parts come from two text
            # encoders of the pipeline -- confirm against the caller.
            split_dim = len(pipe.fpcas[0].UUT)
            prompt_embeds1, prompt_embeds2 = prompt_embeds[:, :, :split_dim], prompt_embeds[:, :, split_dim:]

            if "renorm" in usermode:
                # Remember the norms before projecting so they can be restored.
                # NOTE(review): dim=1 is the last axis for a 2-D pooled tensor
                # but the middle axis for the 3-D prompt tensors -- confirm
                # that this is the intended axis.
                pooled_prompt_embeds_norm = pooled_prompt_embeds.norm(dim=1, keepdim=True)
                prompt_embeds1_norm, prompt_embeds2_norm = prompt_embeds1.norm(dim=1, keepdim=True), prompt_embeds2.norm(dim=1, keepdim=True)

            # Projectors are applied one after another, in list order.
            for fpca in pipe.fpcas:
                pooled_prompt_embeds = fpca.transform(pooled_prompt_embeds)
                prompt_embeds1 = fpca.transform(prompt_embeds1)
                prompt_embeds2 = fpca.transform(prompt_embeds2)

            if "renorm" in usermode:
                pooled_prompt_embeds = pooled_prompt_embeds / pooled_prompt_embeds.norm(dim=1, keepdim=True) * pooled_prompt_embeds_norm
                prompt_embeds1 = prompt_embeds1 / prompt_embeds1.norm(dim=1, keepdim=True) * prompt_embeds1_norm
                prompt_embeds2 = prompt_embeds2 / prompt_embeds2.norm(dim=1, keepdim=True) * prompt_embeds2_norm

            prompt_embeds = torch.cat((prompt_embeds1, prompt_embeds2), dim=-1)

        if "cmovedbg" in usermode:
            # NOTE: torch.load unpickles the file; only use trusted feature files.
            data = torch.load(f"{exp_dir}/extracted_features.pt")
            cmover = CentroidMover(data, usermode, pooled_prompt_embeds.device)
            prompt_embeds1, prompt_embeds2 = prompt_embeds[:, :, :cmover.dim], prompt_embeds[:, :, cmover.dim:]
            for demo in usermode["protect"]:
                pooled_prompt_embeds = cmover.transform(demo, pooled_prompt_embeds)
                prompt_embeds1 = cmover.transform(demo, prompt_embeds1)
                prompt_embeds2 = cmover.transform(demo, prompt_embeds2)
            prompt_embeds = torch.cat((prompt_embeds1, prompt_embeds2), dim=-1)

        return prompt_embeds, pooled_prompt_embeds


def calc_projection_matrix(exp_dir, usermode):
    """Load the extracted features and fit the fair-PCA projector(s).

    The feature file ``{exp_dir}/extracted_features.pt`` is indexed below as
    ``data[protect][group]``, i.e. a mapping from protected attribute to a
    mapping from group to a tensor of samples.

    Args:
        exp_dir: Directory that contains ``extracted_features.pt``.
        usermode: Mapping of options. Keys read here: ``"tradeoff"`` (default
            0.4), ``"hdim"`` (default 600), ``"cross"``, ``"kernel"`` and
            ``"rndsample"``.

    Returns:
        List of fitted projectors: a single one when ``"cross"`` is set,
        otherwise one per protected attribute in the feature file.
    """
    # NOTE: torch.load unpickles the file; only use trusted feature files.
    data = torch.load(f"{exp_dir}/extracted_features.pt")
    tradeoff = usermode.get("tradeoff", 0.4)  # original author note: "No use"
    hdim = usermode.get("hdim", 600)
    print(hdim)
    if "cross" in usermode:
        fpcas = [FairPCA(target_dim = hdim, standardize = False, tradeoff_param = tradeoff)]
    elif "kernel" not in usermode:
        fpcas = [FairPCA(target_dim = hdim, standardize = False, tradeoff_param = tradeoff) for _ in range(len(data))]
    else:
        fpcas = [FairKernelPCA(target_dim  = hdim, kernel = "rbf", degree_kernel = 2, gamma_kernel = "auto", standardize  = False, tradeoff_param  = tradeoff) for _ in range(len(data))]

    for fpca in fpcas:
        fpca.usermode = usermode

    if len(data) == 1:
        # Exactly one protected attribute, hence exactly one projector. Name it
        # explicitly instead of relying on the variable leaked by the loop above.
        single_fpca = fpcas[0]
        protect = next(iter(data))
        groups = data[protect]
        if "rndsample" in usermode:
            # Mutually exclusive with the fits below; running them afterwards
            # would refit the same object right after the random-split fit.
            calc_projection_matrix_rndsample(data, single_fpca)
        elif len(groups) == 2:
            calc_projection_matrix_sg(groups, single_fpca)
        else:
            calc_projection_matrix_mg(groups, single_fpca)
    elif "cross" in usermode:
        # "cross" always builds exactly one projector (see above).
        calc_projection_matrix_mgmd_cross(data, fpcas[0])
    else:
        calc_projection_matrix_mgmd(data, fpcas)

    # Pair each projector with the protected attribute at the same position.
    for fpca, protect in zip(fpcas, data):
        fpca.get_emperical(data, usermode)
        fpca.protect = protect
    return fpcas
        

def calc_projection_matrix_sg(data, fpca):
    """Fit ``fpca`` on a two-group attribute using a 0/1 label vector.

    Args:
        data: Mapping from group to a tensor of samples. All tensors are
            concatenated along dim 0; rows of the first group are labelled 0
            and rows of the second group are labelled 1.
        fpca: Object exposing ``fit(X, z)``.
    """
    group_names = list(data)
    features = torch.cat([data[name] for name in group_names], dim=0)
    first_name, second_name = group_names[0], group_names[1]
    zeros_part = torch.zeros(len(data[first_name]))
    ones_part = torch.ones(len(data[second_name]))
    # Labels take the dtype of the features before being handed to the fit.
    labels = torch.cat((zeros_part, ones_part)).to(features.dtype)
    fpca.fit(features, labels)


def calc_projection_matrix_mg(data, fpca):
    """Fit ``fpca`` on a multi-group attribute using one-hot group indicators.

    Args:
        data: Mapping from group to a tensor of samples. Tensors are
            concatenated along dim 0 in mapping order.
        fpca: Object exposing ``fit_mg(X, Z)``; ``Z`` has one column per group
            and a 1 in the column of the group each row belongs to.
    """
    blocks = list(data.values())
    features = torch.cat(blocks, dim=0)
    indicators = torch.zeros(features.shape[0], len(blocks)).type_as(features)
    row_start = 0
    for column, block in enumerate(blocks):
        row_end = row_start + block.shape[0]
        # Rows of this group occupy a contiguous slice of the stacked features.
        indicators[row_start:row_end, column] = 1.
        row_start = row_end
    fpca.fit_mg(features, indicators)


def calc_projection_matrix_mgmd_v1(data, fpca):
    """Fit one projector jointly on several protected attributes.

    Args:
        data: Mapping from protected attribute to a mapping from group to a
            tensor of samples.
        fpca: Object exposing ``fit_mgmd(Xs, Zs)``. It receives one stacked
            feature tensor and one one-hot indicator tensor per attribute.
    """
    feature_list, indicator_list = [], []
    for groups in data.values():
        blocks = list(groups.values())
        stacked = torch.cat(blocks, dim=0)
        onehot = torch.zeros(stacked.shape[0], len(blocks)).type_as(stacked)
        row_start = 0
        for column, block in enumerate(blocks):
            row_end = row_start + block.shape[0]
            onehot[row_start:row_end, column] = 1.
            row_start = row_end
        feature_list.append(stacked)
        indicator_list.append(onehot)
    fpca.fit_mgmd(feature_list, indicator_list)

def calc_projection_matrix_mgmd(data, fpcas):
    """Fit one projector per protected attribute.

    Args:
        data: Mapping from protected attribute to a mapping from group to a
            tensor of samples.
        fpcas: Sequence of objects exposing ``fit_mg(X, Z)``. They are paired
            with the attributes of ``data`` by position; pairing stops at the
            shorter of the two.
    """
    for protect, fpca in zip(data, fpcas):
        blocks = list(data[protect].values())
        X = torch.cat(blocks, dim=0)
        # One-hot group membership: one column per group of this attribute.
        Z = torch.zeros(X.shape[0], len(blocks)).type_as(X)
        st = 0
        for i, block in enumerate(blocks):
            Z[st : st + block.shape[0], i] = 1.
            st += block.shape[0]
        fpca.fit_mg(X, Z)

def calc_projection_matrix_mgmd_cross(data, fpca):
    """Fit one projector on the cross product of groups across attributes.

    Every combination that picks one group per protected attribute becomes one
    class. The samples of each group in the combination are stacked and
    labelled with the index of that combination, so a group's samples appear
    once for every combination it takes part in.

    Args:
        data: Mapping from protected attribute to a mapping from group to a
            tensor of samples.
        fpca: Object exposing ``fit_mg(X, Z)``; ``Z`` is one-hot over the
            combinations.
    """
    from itertools import product

    attributes = list(data)
    combos = list(product(*[data[attr].keys() for attr in attributes]))

    feature_chunks, label_chunks = [], []
    for combo_id, combo in enumerate(combos):
        members = [data[attr][group] for attr, group in zip(attributes, combo)]
        member_labels = [torch.ones(member.shape[0]) * combo_id for member in members]
        feature_chunks.append(torch.cat(members, dim=0))
        label_chunks.append(torch.cat(member_labels, dim=0))

    X = torch.cat(feature_chunks, dim=0)
    class_ids = torch.cat(label_chunks, dim=0).long()
    # One column per combination; set the column of each row's combination.
    Z = torch.zeros(X.shape[0], len(combos)).type_as(X)
    row_index = torch.arange(X.shape[0])
    Z[row_index, class_ids] = 1.
    fpca.fit_mg(X, Z)


def calc_projection_matrix_rndsample(data, fpca):
    """Fit ``fpca`` on random splits of every protected attribute's samples.

    For each attribute the samples and their one-hot group indicators are
    shuffled together and cut into consecutive chunks; all chunks of all
    attributes are passed to ``fpca.fit_mgmd``.

    Args:
        data: Mapping from protected attribute to a mapping from group to a
            tensor of samples.
        fpca: Object exposing ``fit_mgmd(Xs, Zs)``. The number of splits is
            read from ``fpca.usermode["rndsample"]``; a missing or falsy value
            falls back to 5.
    """
    # This is a module-level function, so there is no ``self``; the options
    # live on the projector that was handed in.
    usermode = getattr(fpca, "usermode", None) or {}
    splitcnt = int(usermode.get("rndsample") or 5)

    Xs, Zs = [], []
    for protect in data:
        blocks = list(data[protect].values())
        X = torch.cat(blocks, dim=0)
        Z = torch.zeros(X.shape[0], len(blocks)).type_as(X)
        st = 0
        for i, block in enumerate(blocks):
            Z[st : st + block.shape[0], i] = 1.
            st += block.shape[0]

        # Shuffle rows of X and Z with the same permutation so they stay aligned.
        indices = torch.randperm(X.shape[0])
        X = X[indices]
        Z = Z[indices]

        # Chunk size for the requested number of splits; never 0, which
        # torch.split rejects when there are fewer rows than splits. When the
        # row count is not divisible, torch.split yields one extra, smaller
        # trailing chunk.
        split_size = max(1, X.shape[0] // splitcnt)
        Xs.extend(torch.split(X, split_size))
        Zs.extend(torch.split(Z, split_size))
    fpca.fit_mgmd(Xs, Zs)
