import argparse
import matplotlib
import matplotlib.lines as mlines
import matplotlib.pyplot as plt
import numpy as np
import pickle
import random
import torch
import torch.nn as nn
import torch.optim as optim

from data_generate import data_generation
from data_loading import lawDataset
from pathlib import Path
from scipy.stats import gaussian_kde
from torch.utils.data import DataLoader

#from pylab import *

def GPA(S, R, K, dataset):
    return torch.mm(S, torch.Tensor(dataset.wGS).unsqueeze(1).cuda()) + torch.mm(R, torch.Tensor(dataset.wGR).unsqueeze(1).cuda()) + torch.tensor(dataset.wGK).cuda() * K + dataset.bG

def SAT(S, R, K, dataset):
    return torch.exp(torch.mm(S, torch.Tensor(dataset.wLS).unsqueeze(1).cuda()) + torch.mm(R, torch.Tensor(dataset.wLR).unsqueeze(1).cuda()) + torch.tensor(dataset.wLK).cuda() * K + dataset.bL)

def FYA(S, R, K, dataset):
    return torch.mm(S, torch.Tensor(dataset.wFS).unsqueeze(1).cuda()) + torch.mm(R, torch.Tensor(dataset.wFR).unsqueeze(1).cuda()) + torch.tensor(dataset.wFK).cuda() * K

class UFdecision(nn.Module):
    def __init__(self, input_dim):
        super(UFdecision, self).__init__()
        self.layer = nn.Linear(input_dim, 1)
    
    def forward(self, x):
        out = self.layer(x)
        return out

def UFtrain(loader, vloader, model, loss_fn, optimizer, total_epochs, eta):
    losses = []
    valid_losses = []
    
    for __ in range(total_epochs):
        model.train()
        running_loss = 0
        for __, data in enumerate(loader):
            optimizer.zero_grad()
            R = data[:, 2 : 10].detach().clone().cuda()
            G = data[:, 511].unsqueeze(1).detach().clone().cuda()
            L = data[:, 512].unsqueeze(1).detach().clone().cuda()
            F = data[:, 513].unsqueeze(1).detach().clone().cuda()
            
            out = model(torch.cat([R, G, L], dim=1))
            loss = loss_fn(out, F)
            loss.backward()
            running_loss += loss.item()
            optimizer.step()
        losses.append(running_loss / len(loader))
    
    plt.plot(losses)
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.savefig("images/UF_train_loss.png")
    plt.close()

def UFtest(loader, model, loss_fn, eta):
    model.eval()
    
    original_unfairs = []
    unfairs = []
    running_loss = 0
    
    Y_current = []
    cY_current = []
    
    Y_prime = []
    cY_prime = []
    
    for i, data in enumerate(loader):
        S = data[:, 0 : 2].detach().clone().cuda()
        R = data[:, 2 : 10].detach().clone().cuda()
        G = data[:, 511].unsqueeze(1).detach().clone().cuda()
        L = data[:, 512].unsqueeze(1).detach().clone().cuda()
        target = data[:, 513].unsqueeze(1).detach().clone().cuda()
        
        with torch.no_grad():
            out = model(torch.cat([R, G, L], dim=1))
            loss = loss_fn(out, target)
            running_loss += loss.item()
        
        o_fair = torch.zeros(size=(S.size()[0], 1)).cuda()
        fair = torch.zeros(size=(S.size()[0], 1)).cuda()
        for k_index in range(500):
            K = data[:, 10 + k_index].unsqueeze(1).detach().clone().cuda()
            cK = K.detach().clone().cuda()
            
            #F_check = torch.normal(FYA(1 - S, R, K, loader.dataset), 1)
            F_check = FYA(1 - S, R, K, loader.dataset)
            #F = torch.normal(FYA(S, R, K, loader.dataset), 1)
            F = FYA(S, R, K, loader.dataset)
            o_fair += torch.abs(F - F_check).detach().clone()
            
            if i == 0:
                Y_current.append(F[0].item())
                cY_current.append(F_check[0].item())
            
            #G = torch.normal(GPA(S, R, K, loader.dataset), loader.dataset.sigma)
            G = GPA(S, R, K, loader.dataset)
            #L = torch.poisson(SAT(S, R, K, loader.dataset))
            L = SAT(S, R, K, loader.dataset)
            G.requires_grad = True
            L.requires_grad = True
            out = model(torch.cat([R, G, L], dim=1))
            gradient_dummy = torch.ones_like(out)
            out.backward(gradient=gradient_dummy)
            grad_K = loader.dataset.wGK * G.grad + loader.dataset.wLK * L.grad
            f_K = K + eta * grad_K
            #f_F = torch.normal(FYA(S, R, f_K, loader.dataset), 1)
            f_F = FYA(S, R, f_K, loader.dataset)
            
            #cG = torch.normal(GPA(1 - S, R, cK, loader.dataset), loader.dataset.sigma)
            cG = GPA(1 - S, R, cK, loader.dataset)
            #cL = torch.poisson(SAT(1 - S, R, cK, loader.dataset))
            cL = SAT(1 - S, R, cK, loader.dataset)
            cG.requires_grad = True
            cL.requires_grad = True
            cout = model(torch.cat([R, cG, cL], dim=1))
            cgradient_dummy = torch.ones_like(cout)
            cout.backward(gradient=cgradient_dummy)
            grad_cK = loader.dataset.wGK * cG.grad + loader.dataset.wLK * L.grad
            f_cK = cK + eta * grad_cK
            #f_cF = torch.normal(FYA(1 - S, R, f_cK, loader.dataset), 1)
            f_cF = FYA(1 - S, R, f_cK, loader.dataset)
            
            fair += torch.abs(f_F - f_cF).detach().clone()
            
            if i == 0:
                Y_prime.append(f_F[0].item())
                cY_prime.append(f_cF[0].item())
        
        o_fair = o_fair / 500
        fair = fair / 500
        
        original_unfairs.extend([o_fair[i].item() for i in range(o_fair.size()[0])])
        unfairs.extend([fair[i].item() for i in range(fair.size()[0])])
        
    original_unfairness = np.mean(original_unfairs)
    unfairness = np.mean(unfairs)
    loss = running_loss / len(loader)
    
    with open("UF_Y.pkl", "wb") as f:
        pickle.dump(Y_prime, f)
    with open("UF_cY.pkl", "wb") as f:
        pickle.dump(cY_prime, f)
    with open("current_Y.pkl", "wb") as f:
        pickle.dump(Y_current, f)
    with open("current_cY.pkl", "wb") as f:
        pickle.dump(cY_current, f)
    
    return original_unfairness, unfairness, loss

def UFexperiment(loaders, total_epochs, eta):
    train_loader, valid_loader, test_loader = loaders[0], loaders[1], loaders[2]
    model = UFdecision(input_dim=10).cuda()
    loss_fn = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.99))
    UFtrain(train_loader, valid_loader, model, loss_fn, optimizer, total_epochs, eta)
    original_unfairness, unfairness, test_loss = UFtest(test_loader, model, loss_fn, eta)
    print("original_unfairnss = {}, unfairness = {}, test_loss = {}".format(original_unfairness, unfairness, test_loss))

class CFdecision(nn.Module):
    def __init__(self, input_dim):
        super(CFdecision, self).__init__()
        self.layer = nn.Linear(input_dim, 1)
    
    def forward(self, x):
        out = self.layer(x)
        return out

def CFtrain(loader, vloader, model, loss_fn, optimizer, total_epochs, eta):
    losses = []
    
    for __ in range(total_epochs):
        model.train()
        running_loss = 0
        for __, data in enumerate(loader):
            optimizer.zero_grad()
            K = data[:, 10 + random.randint(0, 500)].unsqueeze(1).detach().clone().cuda()
            S = data[:, 0 : 2].detach().clone().cuda()
            R = data[:, 2 : 10].detach().clone().cuda()
            target = data[:, 513].unsqueeze(1).detach().clone().cuda()
            
            out = model(torch.cat([R, K], dim=1))
            loss = loss_fn(out, target)
            loss.backward()
            running_loss += loss.item()
            optimizer.step()
        losses.append(running_loss / len(loader))
    
    plt.plot(losses)
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.savefig("images/CF_train_loss.png")
    plt.close()

def CFtest(loader, model, loss_fn, eta):
    model.eval()
    
    original_unfairs = []
    unfairs = []
    running_loss = 0
    
    Y_prime = []
    cY_prime = []
    
    for i, data in enumerate(loader):
        S = data[:, 0 : 2].detach().clone().cuda()
        R = data[:, 2 : 10].detach().clone().cuda()
        target = data[:, 513].unsqueeze(1).detach().clone().cuda()
        
        o_fair = torch.zeros(size=(S.size()[0], 1)).cuda()
        fair = torch.zeros(size=(S.size()[0], 1)).cuda()
        for k_index in range(500):
            K = data[:, k_index].unsqueeze(1).detach().clone().cuda()
            cK = K.detach().clone().cuda()
            
            with torch.no_grad():
                out = model(torch.cat([R, K], dim=1))
                loss = loss_fn(out, target)
                running_loss += loss.item()
                
                #F = torch.normal(FYA(S, R, K, loader.dataset), 1)
                F = FYA(S, R, K, loader.dataset)
                #F_check = torch.normal(FYA(1 - S, R, K, loader.dataset), 1)
                F_check = FYA(1 - S, R, cK, loader.dataset)
                
                o_fair += torch.abs(F - F_check).detach().clone()
        
            K.requires_grad = True
            out = model(torch.cat([R, K], dim=1))
            gradient_dummy = torch.ones_like(out)
            out.backward(gradient_dummy)
            f_K = K + eta * K.grad
            #f_F = torch.normal(FYA(S, R, f_K, loader.dataset), 1)
            f_F = FYA(S, R, f_K, loader.dataset)
            
            cK.requires_grad = True
            cout = model(torch.cat([R, cK], dim=1))
            cgradient_dummy = torch.ones_like(cout)
            cout.backward(cgradient_dummy)
            f_cK = cK + eta * cK.grad
            #f_cF = torch.normal(FYA(1 - S, R, f_cK, loader.dataset), 1)
            f_cF = FYA(1 - S, R, f_cK, loader.dataset)
            
            fair += torch.abs(f_F - f_cF).detach().clone()
            
            if i == 0:
                Y_prime.append(f_F[0].item())
                cY_prime.append(f_cF[0].item())
        
        o_fair = o_fair / 500
        fair = fair / 500
        
        original_unfairs.extend([o_fair[i].item() for i in range(o_fair.size()[0])])
        unfairs.extend([fair[i].item() for i in range(fair.size()[0])])
    
    original_unfairness = np.mean(original_unfairs)
    unfairness = np.mean(unfairs)
    loss = running_loss / (500 * len(loader))
    
    with open("CF_Y.pkl", "wb") as f:
        pickle.dump(Y_prime, f)
    with open("CF_cY.pkl", "wb") as f:
        pickle.dump(cY_prime, f)
    
    return original_unfairness, unfairness, loss

def CFexperiment(loaders, total_epochs, eta):
    train_loader, valid_loader, test_loader = loaders[0], loaders[1], loaders[2]
    model = CFdecision(input_dim=9).cuda()
    loss_fn = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.99))
    CFtrain(train_loader, valid_loader, model, loss_fn, optimizer, total_epochs, eta)
    original_unfairness, unfairness, test_loss = CFtest(test_loader, model, loss_fn, eta)
    print("original_unfairnss = {}, unfairness = {}, test_loss = {}".format(original_unfairness, unfairness, test_loss))

class DFdecision(nn.Module):
    def __init__(self, input_dim):
        super(DFdecision, self).__init__()
        self.w1 = torch.Tensor([0]).cuda()
        self.w2 = nn.Parameter(torch.randn(1))
        self.w3 = nn.Parameter(torch.randn(1))
        
        self.layer = nn.Linear(input_dim, 1)
    
    def forward(self, cy, u):
        h1 = self.w1 * cy ** 2 + self.w2 * cy + self.w3
        h2 = self.layer(u)
        out = h1 + h2
        return out

def DFtrain(loader, vloader, model, loss_fn, optimizer, total_epoch, eta):
    losses = []
    
    for __ in range(total_epoch):
        model.train()
        running_loss = 0
        for __, data in enumerate(loader):
            optimizer.zero_grad()
            K = data[:, 10 + random.randint(0, 500)].unsqueeze(1).detach().clone().cuda()
            #K = torch.mean(data[:, 10 + K_index], dim=1).unsqueeze(1).detach().clone().cuda()
            S = data[:, 0 : 2].detach().clone().cuda()
            R = data[:, 2 : 10].detach().clone().cuda()
            F = data[:, 513].unsqueeze(1).detach().cuda()
            
            #F_check = torch.normal(FYA(1 - S, R, K, loader.dataset), 1)
            F_check = FYA(1 - S, R, K, loader.dataset)
            #F = torch.normal(FYA(S, R, K, loader.dataset), 1)
            #F = FYA(S, R, K, loader.dataset)
            
            out = model(F_check, torch.cat([R, K], dim=1))
            loss = loss_fn(out, F)
            loss.backward()
            running_loss += loss.item()
            optimizer.step()
        losses.append(running_loss / len(loader))
    
    plt.plot(losses)
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.savefig("images/DF_train_loss.png")
    plt.close()

def DFtest(loader, model, loss_fn, eta, ratio):
    model.eval()
    
    original_unfairs = []
    unfairs = []
    running_loss = 0
    
    Y_prime = []
    cY_prime = []
    
    for i, data in enumerate(loader):
        S = data[:, 0 : 2].detach().clone().cuda()
        R = data[:, 2 : 10].detach().clone().cuda()
        target = data[:, 513].unsqueeze(1).detach().clone().cuda()
        
        o_fair = torch.zeros(size=(S.size()[0], 1)).cuda()
        fair = torch.zeros(size=(S.size()[0], 1)).cuda()
        for K_index in range(500):
            K = data[:, 10 + K_index].unsqueeze(1).detach().clone().cuda()
            #K = torch.mean(data[:, 10 : 510], dim=1).unsqueeze(1).detach().clone().cuda()
            cK = K.detach().clone().cuda()
            
            with torch.no_grad():
                #F_check = torch.normal(FYA(1 - S, R, K, loader.dataset), 1)
                F_check = FYA(1 - S, R, K, loader.dataset)
                out = model(F_check, torch.cat([R, K], dim=1))
                loss = loss_fn(out, target)
                running_loss += loss.item()
                
                #F = torch.normal(FYA(S, R, K, loader.dataset), 1)
                F = FYA(S, R, K, loader.dataset)
                
                o_fair += torch.abs(F - F_check).detach().clone()
            
            K.requires_grad = True
            eplision = torch.randn(1).cuda()
            #F_check = FYA(1 - S, R, K, loader.dataset) + eplision
            F_check = FYA(1 - S, R, K, loader.dataset)
            out = model(F_check, torch.cat([R, K], dim=1))
            gradient_dummy = torch.ones_like(out)
            out.backward(gradient=gradient_dummy)
            f_K = K + eta * K.grad
            #f_F = torch.normal(FYA(S, R, f_K, loader.dataset), 1)
            f_F = FYA(S, R, f_K, loader.dataset)
            
            cK.requires_grad = True
            ceplision = torch.randn(1).cuda()
            #cF_check = FYA(S, R, cK, loader.dataset) + ceplision
            cF_check = FYA(S, R, cK, loader.dataset)
            cout = model(cF_check, torch.cat([R, cK], dim=1))
            cgradient_dummy = torch.ones_like(cout)
            cout.backward(gradient=cgradient_dummy)
            f_cK = cK + eta * cK.grad
            #f_cF = torch.normal(FYA(1 - S, R, f_cK, loader.dataset), 1)
            f_cF = FYA(1 - S, R, f_cK, loader.dataset)
            
            fair += torch.abs(f_F - f_cF).detach().clone()
            
            if i == 0:
                Y_prime.append(f_F[0].item())
                cY_prime.append(f_cF[0].item())
        
        o_fair = o_fair / 500
        fair = fair / 500
        
        original_unfairs.extend([o_fair[i].item() for i in range(o_fair.size()[0])])
        unfairs.extend([fair[i].item() for i in range(fair.size()[0])])
    
    original_unfairness = np.mean(original_unfairs)
    unfairness = np.mean(unfairs)
    loss = running_loss / (500 * len(loader))
    
    with open("DF_Y_{}.pkl".format(ratio), "wb") as f:
        pickle.dump(Y_prime, f)
    with open("DF_cY_{}.pkl".format(ratio), "wb") as f:
        pickle.dump(cY_prime, f)
    
    return original_unfairness, unfairness, loss

def DFexperiment(loaders, total_epochs, eta, ratio):
    train_loader, valid_loader, test_loader = loaders[0], loaders[1], loaders[2]
    model = DFdecision(input_dim=9)
    model.cuda()
    model.w1 = torch.tensor(1 / (2 * ratio * eta * train_loader.dataset.wFK ** 2))
    loss_fn = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.99))
    DFtrain(train_loader, valid_loader, model, loss_fn, optimizer, total_epochs, eta)
    original_unfairness, unfairness, test_loss = DFtest(test_loader, model, loss_fn, eta, ratio)
    print("original_unfairness = {}, unfairness = {}, test_loss = {}".format(original_unfairness, unfairness, test_loss))
    return original_unfairness, unfairness, test_loss

def compute_table():
    seeds = [42, 43, 44, 45, 46]
    for seed in seeds:
        with open("./datas/data_{}.pkl".format(seed), "rb") as f:
            law_data = pickle.load(f)
        
        train_set = lawDataset(law_data, type="train")
        valid_set = lawDataset(law_data, type="valid")
        test_set = lawDataset(law_data, type="test")
        train_loader = DataLoader(train_set, batch_size=128, shuffle=True)
        valid_loader = DataLoader(valid_set, batch_size=128, shuffle=False)
        test_loader = DataLoader(test_set, batch_size=128, shuffle=False)
        loaders = [train_loader, valid_loader, test_loader]
        
        total_epochs = 100
        eta = 10
        
        for i in range(2, 3):
            if i == 0:
                UFexperiment(loaders, total_epochs, eta)
            elif i == 1:
                CFexperiment(loaders, total_epochs, eta)
            elif i == 2:
                DFexperiment(loaders, total_epochs, eta, ratio=1)
        print("\n")

def drawDensity():
    with open("./datas/data_42.pkl", "rb") as f:
        law_data = pickle.load(f)
    
    train_set = lawDataset(law_data, type="train")
    valid_set = lawDataset(law_data, type="valid")
    test_set = lawDataset(law_data, type="test")
    train_loader = DataLoader(train_set, batch_size=128, shuffle=True)
    valid_loader = DataLoader(valid_set, batch_size=128, shuffle=False)
    test_loader = DataLoader(test_set, batch_size=128, shuffle=False)
    loaders = [train_loader, valid_loader, test_loader]
    
    total_epochs = 100
    eta = 10
    
    #for i in range(4):
    #    if i == 0:
    #        UFexperiment(loaders, total_epochs, eta)
    #    elif i == 1:
    #        CFexperiment(loaders, total_epochs, eta)
    #    elif i == 2:
    #        DFexperiment(loaders, total_epochs, eta, 1)
    #    else:
    #        DFexperiment(loaders, total_epochs, eta, 2)
                
    with open("UF_Y.pkl", "rb") as f:
        UF_Y = pickle.load(f)
    with open("UF_cY.pkl", "rb") as f:
        UF_cY = pickle.load(f)
    with open("CF_Y.pkl", "rb") as f:
        CF_Y = pickle.load(f)
    with open("CF_cY.pkl", "rb") as f:
        CF_cY = pickle.load(f)
    with open("DF_Y_1.pkl", "rb") as f:
        DF_Y = pickle.load(f)
    with open("DF_cY_1.pkl", "rb") as f:
        DF_cY = pickle.load(f)
    with open("current_Y.pkl", "rb") as f:
        current_Y = pickle.load(f)
    with open("current_cY.pkl", "rb") as f:
        current_cY = pickle.load(f)
    with open("DF_Y_2.pkl", "rb") as f:
        DF_Y_2 = pickle.load(f)
    with open("DF_cY_2.pkl", "rb") as f:
        DF_cY_2 = pickle.load(f)
    
    factuals = [current_Y, UF_Y, CF_Y, DF_Y_2, DF_Y]
    counters = [current_cY, UF_cY, CF_cY, DF_cY_2, DF_cY]
    
    print("current {} ~ {}".format(min(current_Y), max(current_cY)))
    print("UF {} ~ {}".format(min(UF_Y), max(UF_Y)))
    print("CF {} ~ {}".format(min(CF_Y), max(CF_Y)))
    print("DF_2 {} ~ {}".format(min(DF_Y_2), max(DF_cY_2)))
    print("DF {} ~ {}".format(min(DF_Y), max(DF_Y)))

    plt.rcParams['font.size'] = '45'
    plt.rcParams["font.family"] = "normal"
    plt.rcParams['text.usetex'] = True

    def plot_single_graph(i, factual_data, counter_data, ax, title):
        kde1 = gaussian_kde(factual_data)
        kde2 = gaussian_kde(counter_data)
        
        if i == 0:
            x = np.linspace(-0.15, 0.36, 100)
        elif i == 1:
            x = np.linspace(0.1, 0.55, 100)
        elif i == 2:
            x = np.linspace(0.85, 1.34, 100)
        elif i == 3:
            x = np.linspace(0.80, 1.45, 100)
        elif i == 4:
            x = np.linspace(0.74, 1.58, 100)
        density1 = kde1(x)
        density2 = kde2(x)

        ax.plot(x, density1, color="b", label="factual")
        ax.plot(x, density2, color="r", label="counterfactual")

        ax.fill_between(x, density1, color='b', alpha=0.2)
        ax.fill_between(x, density2, color='r', alpha=0.2)
        
        #ax.legend()
        if i != 0:
            ax.set_xlabel("$Y'$")
        else:
            ax.set_xlabel("$Y$")
        ax.set_ylabel("density")
        ax.set_title(title)

    fig, axs = plt.subplots(1, 5, figsize=(50, 10.5))


    #titles = ["Baseline", "CA", "ICA", "CE", "CR", "Ours"]
    titles = ["Current", "UF", "CF", r"DF($p_{1}=\frac{T}{4}$)", r"DF($p_{1}=\frac{T}{2}$)"]
    for i, title in enumerate(titles):
        plot_single_graph(i, factuals[i], counters[i], axs[i], title)
    #plt.subplots_adjust(left=None, bottom=0.15, right=None, top=0.85, wspace=0.3, hspace=0.8)
    plt.subplots_adjust(left=None, bottom=None, right=None, top=None, wspace=0.3, hspace=0.5)
    lineA = mlines.Line2D([], [], color='b', label='factual')
    lineB = mlines.Line2D([], [], color='r', label='counter\nfactual')
    # Create a legend for the whole figure
    fig.legend(handles=[lineA, lineB], loc='center right', ncol=1, framealpha=1)

    plt.show()
    plt.savefig("density.pdf")
    #plt.savefig("density_cvae.pdf")

def plot_trade_off():
    unfairs = np.zeros((3, 6))
    losses = np.zeros((3, 6))

    with open("./datas/data_42.pkl", "rb") as f:
        law_data = pickle.load(f)
    
    train_set = lawDataset(law_data, type="train")
    valid_set = lawDataset(law_data, type="valid")
    test_set = lawDataset(law_data, type="test")
    train_loader = DataLoader(train_set, batch_size=128, shuffle=True)
    valid_loader = DataLoader(valid_set, batch_size=128, shuffle=False)
    test_loader = DataLoader(test_set, batch_size=128, shuffle=False)
    loaders = [train_loader, valid_loader, test_loader]
    
    
    etas = [15, 10, 5]
    ratios = [1, 2, 4, 16, 32, 256]
    
    for i, eta in enumerate(etas):
        for j, ratio in enumerate(ratios):
            total_epochs = 50
            __, unfairness, loss = DFexperiment(loaders, total_epochs, eta, ratio)
            unfairs[i, j] = unfairness
            losses[i, j] = loss
    
    font = {"family": "normal", "size": 13}
    matplotlib.rc("font", **font)

    fig = plt.figure(figsize=(5, 5))
    p1, = plt.plot(losses[0], unfairs[0], "r*-", label="$\eta=15$", linewidth=2)
    p2, = plt.plot(losses[1], unfairs[1], "c>:", label="$\eta=10$", linewidth=2)
    p3, = plt.plot(losses[2], unfairs[2], "y^-", label="$\eta=5$", linewidth=2)
    #p4, = plt.plot(losses[3], unfairs[3], "bx--", label="$\eta=0.5$", linewidth=2)
    plt.legend()

    plt.xlabel("MSE")
    plt.ylabel("AFCE")
    plt.grid()
    plt.tight_layout()
    plt.show()
    fig.savefig("trade_off_curves.png")

def main():
    seeds = [42, 43, 44, 45, 46]
    for seed in seeds:
        check_data = Path("./datas/data_{}.pkl".format(seed))
        if check_data.is_file():
            pass
        else:
            data_generation(seed)
    
    #compute_table()
    #drawDensity()
    plot_trade_off()

if __name__ == "__main__":
    main()


"""class UFDecision(nn.Module):
    def __init__(self, input_dim):
        super(UFDecision, self).__init__()
        self.layer = nn.Linear(input_dim, 1)
    
    def forward(self, x):
        out = self.layer(x)
        return out

def UFTrain(train_loader, valid_loader, model, loss_fn, optimizer, total_epochs, eta):
    train_losses = []
    valid_losses = []
    valid_unfairs = []
    
    for epoch in range(total_epochs):
        model.train()
        running_loss = 0
        
        for i, data in enumerate(train_loader):
            optimizer.zero_grad()
            S = data[:, 0 : 2].cuda()
            R = data[:, 2 : 10].cuda()
            G = data[:, 11].unsqueeze(1).cuda()
            L = data[:, 12].unsqueeze(1).cuda()
            F = data[:, 13].unsqueeze(1).cuda()
            input_x = torch.cat([S, R, G, L], dim=1)
            out = model(input_x)
            loss = loss_fn(out, F)
            loss.backward()
            running_loss += loss.item()
            optimizer.step()
        
        train_losses.append(running_loss / len(train_loader))
        __, valid_unfairness, valid_loss = UFTest(valid_loader, model, loss_fn, eta)
        valid_losses.append(valid_loss)
        valid_unfairs.append(valid_unfairness)
    
    plt.plot(train_losses)
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.savefig("images/UF_train_loss.png")
    plt.close()
    plt.plot(valid_losses)
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.savefig("images/UF_valid_loss.png")
    plt.close()
    plt.plot(valid_unfairs)
    plt.xlabel("Epoch")
    plt.ylabel("Unfairness")
    plt.savefig("images/UF_valid_unfairness.png")
    plt.close()
    
def UFTest(loader, model, loss_fn, eta):
    model.eval()
    original_unfairs = []
    unfairs = []
    running_loss = 0
    store_F = []
    store_cF = []
    
    for i, data in enumerate(loader):
        S = data[:, 0 : 2].cuda()
        R = data[:, 2 : 10].cuda()
        K = data[:, 10].unsqueeze(1).cuda()
        G = data[:, 11].unsqueeze(1).cuda()
        L = data[:, 12].unsqueeze(1).cuda()
        F = data[:, 13].unsqueeze(1).cuda()
        
        cS = data[:, 14 : 16].cuda()
        cR = data[:, 16 : 24].cuda()
        cK = data[:, 10].unsqueeze(1).cuda()
        cF = data[:, 26].unsqueeze(1).cuda()
        
        original_unfairs.extend([(F[i] - cF[i]).item() for i in range(F.size(0))])
        
        with torch.no_grad():
            input_x = torch.cat([S, R, G, L], dim=1)
            out = model(input_x)
            loss = loss_fn(out, F)
            running_loss += loss.item()
        
        K.requires_grad = True
        G = GPA(S, R, K, loader.dataset)
        L = SAT(S, R, K, loader.dataset)
        input_x = torch.cat([S, R, G, L], dim=1)
        out = model(input_x)
        gradient_dummy = torch.ones_like(out)
        out.backward(gradient=gradient_dummy)
        grad_K = K.grad
        
        cK.requires_grad = True
        cG = GPA(cS, cR, cK, loader.dataset)
        cL = SAT(cS, cR, cK, loader.dataset)
        input_cx = torch.cat([cS, cR, cG, cL], dim=1)
        cout = model(input_cx)
        cgradient_dummy = torch.ones_like(cout)
        cout.backward(gradient=cgradient_dummy)
        grad_cK = cK.grad
        
        K = K + eta * grad_K
        F = FYA(S, R, K, loader.dataset)
        cK = cK + eta * grad_cK
        cF = FYA(cS, cR, cK, loader.dataset)
        
        unfairs.extend([(F[i] - cF[i]).item() for i in range(F.size(0))])
        
    original_unfairness = np.mean(np.abs(original_unfairs))
    unfairness = np.mean(np.abs(unfairs))
    loss = running_loss / len(loader)
    return original_unfairness, unfairness, loss

def UFExperiment(loaders, total_epochs, eta):
    train_loader, valid_loader, test_loader = loaders[0], loaders[1], loaders[2]
    model = UFDecision(input_dim=12).cuda()
    loss_fn = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.99))
    UFTrain(train_loader, valid_loader, model, loss_fn, optimizer, total_epochs, eta)
    original_unfairness, unfairness, test_loss = UFTest(test_loader, model, loss_fn, eta)
    print("original_unfairnss = {}, unfairness = {}, test_loss = {}".format(original_unfairness, unfairness, test_loss))

class CFDecision(nn.Module):
    def __init__(self, input_dim):
        super(CFDecision, self).__init__()
        self.layer = nn.Linear(input_dim, 1)
    
    def forward(self, x):
        out = self.layer(x)
        return out

def CFTrain(train_loader, valid_loader, model, loss_fn, optimizer, total_epochs, eta):
    train_losses = []
    valid_losses = []
    valid_unfairs = []
    
    for epoch in range(total_epochs):
        model.train()
        running_loss = 0
        
        for i, data in enumerate(train_loader):
            optimizer.zero_grad()
            S = data[:, 0 : 2].cuda()
            R = data[:, 2 : 10].cuda()
            K = data[:, 10].unsqueeze(1).cuda()
            G = data[:, 11].unsqueeze(1).cuda()
            L = data[:, 12].unsqueeze(1).cuda()
            F = data[:, 13].unsqueeze(1).cuda()
            
            cG = data[:, 24].unsqueeze(1).cuda()
            cL = data[:, 25].unsqueeze(1).cuda()
            
            #input_x = torch.cat([R, K, (G + cG) / 2, (L + cL) / 2], dim=1)
            input_x = torch.cat([R, K], dim=1)
            out = model(input_x)
            loss = loss_fn(out, F)
            loss.backward()
            running_loss += loss.item()
            optimizer.step()
            
        train_losses.append(running_loss / len(train_loader))
        __, valid_unfairness, valid_loss = CFTest(valid_loader, model, loss_fn, eta)
        valid_losses.append(valid_loss)
        valid_unfairs.append(valid_unfairness)
    
    plt.plot(train_losses)
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.savefig("images/CF_train_loss.png")
    plt.close()
    plt.plot(valid_losses)
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.savefig("images/CF_valid_loss.png")
    plt.close()
    plt.plot(valid_unfairs)
    plt.xlabel("Epoch")
    plt.ylabel("Unfairness")
    plt.savefig("images/CF_valid_unfairness.png")
    plt.close()
    
    
def CFTest(loader, model, loss_fn, eta):
    model.eval()
    original_unfairs = []
    unfairs = []
    running_loss = 0
    
    for i, data in enumerate(loader):
        S = data[:, 0 : 2].cuda()
        R = data[:, 2 : 10].cuda()
        K = data[:, 10].unsqueeze(1).cuda()
        G = data[:, 11].unsqueeze(1).cuda()
        L = data[:, 12].unsqueeze(1).cuda()
        F = data[:, 13].unsqueeze(1).cuda()
        
        cS = data[:, 14 : 16].cuda()
        cR = data[:, 16 : 24].cuda()
        cK = data[:, 10].unsqueeze(1).cuda()
        cG = data[:, 24].unsqueeze(1).cuda()
        cL = data[:, 25].unsqueeze(1).cuda()
        cF = data[:, 26].unsqueeze(1).cuda()
        
        original_unfairs.extend([(F[i] - cF[i]).item() for i in range(F.size(0))])
        
        with torch.no_grad():
            #input_x = torch.cat([R, K, (G + cG) / 2, (L + cL) / 2], dim=1)
            input_x = torch.cat([R, K], dim=1)
            out = model(input_x)
            loss = loss_fn(out, F)
            running_loss += loss.item()
        
        K.requires_grad = True
        G = GPA(S, R, K, loader.dataset)
        G_check = GPA(cS, cR, cK, loader.dataset)
        L = SAT(S, R, K, loader.dataset)
        L_check = SAT(cS, cR, cK, loader.dataset)
        #input_x = torch.cat([R, K, (G + G_check) / 2, (L + L_check) / 2], dim=1)
        input_x = torch.cat([R, K], dim=1)
        out = model(input_x)
        gradient_dummy = torch.ones_like(out)
        out.backward(gradient=gradient_dummy)
        grad_K = K.grad
        
        cK.requires_grad = True
        G = GPA(S, R, K, loader.dataset)
        G = GPA(S, R, K, loader.dataset)
        G_check = GPA(cS, cR, cK, loader.dataset)
        L = SAT(S, R, K, loader.dataset)
        L_check = SAT(cS, cR, cK, loader.dataset)
        #input_cx = torch.cat([R, cK, (G + G_check) / 2, (L + L_check) / 2], dim=1)
        input_cx = torch.cat([R, cK], dim=1)
        cout = model(input_cx)
        cgradient_dummy = torch.ones_like(cout)
        cout.backward(gradient=cgradient_dummy)
        grad_cK = cK.grad
        
        K = K + eta * grad_K
        F = FYA(S, R, K, loader.dataset)
        cK = cK + eta * grad_cK
        cF = FYA(cS, cR, cK, loader.dataset)
        
        unfairs.extend([(F[i] - cF[i]).item() for i in range(F.size(0))])
        
    original_unfairness = np.mean(np.abs(original_unfairs))
    unfairness = np.mean(np.abs(unfairs))
    loss = running_loss / len(loader)
    return original_unfairness, unfairness, loss

def CFExperiment(loaders, total_epochs, eta):
    train_loader, valid_loader, test_loader = loaders[0], loaders[1], loaders[2]
    #model = CFDecision(input_dim=11).cuda()
    model = CFDecision(input_dim=9).cuda()
    loss_fn = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.99))
    CFTrain(train_loader, valid_loader, model, loss_fn, optimizer, total_epochs, eta)
    original_unfairness, unfairness, test_loss = CFTest(test_loader, model, loss_fn, eta)
    print("original_unfairnss = {}, unfairness = {}, test_loss = {}".format(original_unfairness, unfairness, test_loss))

class CRDecision(nn.Module):
    def __init__(self, input_dim):
        super(CRDecision, self).__init__()
        self.layer = nn.Linear(input_dim, 1)
    
    def forward(self, x):
        out = self.layer(x)
        return out

def CRTrain(train_loader, valid_loader, model, loss_fn, optimizer, total_epochs, eta):
    train_losses = []
    valid_losses = []
    valid_unfairs = []
    
    for epoch in range(total_epochs):
        model.train()
        running_loss = 0
        
        for i, data in enumerate(train_loader):
            optimizer.zero_grad()
            S = data[:, 0 : 2].cuda()
            R = data[:, 2 : 10].cuda()
            K = data[:, 10].unsqueeze(1).cuda()
            G = data[:, 11].unsqueeze(1).cuda()
            L = data[:, 12].unsqueeze(1).cuda()
            F = data[:, 13].unsqueeze(1).cuda()
            
            cS = data[:, 14 : 16].cuda()
            cR = data[:, 16 : 24].cuda()
            cK = data[:, 10].unsqueeze(1).cuda()
            cG = data[:, 24].unsqueeze(1).cuda()
            cL = data[:, 25].unsqueeze(1).cuda()
            cF = data[:, 26].unsqueeze(1).cuda()
            
            input_x = torch.cat([S, R, G, L], dim=1)
            out = model(input_x)
            
            K_input = K.clone().detach().requires_grad_(True)
            G_input = GPA(S, R, K_input, train_loader.dataset)
            L_input = SAT(S, R, K_input, train_loader.dataset)
            out_intermediate = model(torch.cat([S, R, G_input, L_input], dim=1))
            grad_K = torch.autograd.grad(out_intermediate, K_input, grad_outputs=torch.ones_like(out_intermediate), retain_graph=True)[0]
            new_K = K_input + eta * grad_K
            new_F = FYA(S, R, new_K, train_loader.dataset)
            
            cK_input = cK.clone().detach().requires_grad_(True)
            cG_input = GPA(cS, cR, cK_input, train_loader.dataset)
            cL_input = SAT(cS, cR, cK_input, train_loader.dataset)
            cout_intermediate = model(torch.cat([cS, cR, cG_input, cL_input], dim=1))
            grad_cK = torch.autograd.grad(cout_intermediate, cK_input, grad_outputs=torch.ones_like(cout_intermediate), retain_graph=True)[0]
            new_cK = cK_input + eta * grad_cK
            new_CF = FYA(cS, cR, new_cK, train_loader.dataset)
            
            loss = loss_fn(out, F) + loss_fn(new_F, new_CF)
            loss.backward()
            running_loss += loss.item()
            optimizer.step()
        train_losses.append(running_loss / len(train_loader))
        __, valid_unfairness, valid_loss = CRTest(valid_loader, model, loss_fn, eta)
        valid_losses.append(valid_loss)
        valid_unfairs.append(valid_unfairness)
    
    plt.plot(train_losses)
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.savefig("images/CR_train_loss.png")
    plt.close()
    plt.plot(valid_losses)
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.savefig("images/CR_valid_loss.png")
    plt.close()
    plt.plot(valid_unfairs)
    plt.xlabel("Epoch")
    plt.ylabel("Unfairness")
    plt.savefig("images/CR_valid_unfairness.png")
    plt.close()
    
def CRTest(loader, model, loss_fn, eta):
    model.eval()
    original_unfairs = []
    unfairs = []
    running_loss = 0
    
    for i, data in enumerate(loader):
        S = data[:, 0 : 2].cuda()
        R = data[:, 2 : 10].cuda()
        K = data[:, 10].unsqueeze(1).cuda()
        G = data[:, 11].unsqueeze(1).cuda()
        L = data[:, 12].unsqueeze(1).cuda()
        F = data[:, 13].unsqueeze(1).cuda()
        
        cS = data[:, 14 : 16].cuda()
        cR = data[:, 16 : 24].cuda()
        cK = data[:, 10].unsqueeze(1).cuda()
        cG = data[:, 24].unsqueeze(1).cuda()
        cL = data[:, 25].unsqueeze(1).cuda()
        cF = data[:, 26].unsqueeze(1).cuda()
        
        original_unfairs.extend([(F[i] - cF[i]).item() for i in range(F.size(0))])
        
        with torch.no_grad():
            input_x = torch.cat([S, R, G, L], dim=1)
            out = model(input_x)
            loss = loss_fn(out, F)
            running_loss += loss.item()
        
        K.requires_grad = True
        G = GPA(S, R, K, loader.dataset)
        L = SAT(S, R, K, loader.dataset)
        input_x = torch.cat([S, R, G, L], dim=1)
        out = model(input_x)
        gradient_dummy = torch.ones_like(out)
        out.backward(gradient=gradient_dummy)
        grad_K = K.grad
        
        cK.requires_grad = True
        cG = GPA(cS, cR, cK, loader.dataset)
        cL = SAT(cS, cR, cK, loader.dataset)
        input_cx = torch.cat([cS, cR, cG, cL], dim=1)
        cout = model(input_cx)
        cgradient_dummy = torch.ones_like(cout)
        cout.backward(gradient=cgradient_dummy)
        grad_cK = cK.grad
        
        new_K = K + eta * grad_K
        new_F = FYA(S, R, new_K, loader.dataset)
        new_cK = cK + eta * grad_cK
        new_cF = FYA(cS, cR, new_cK, loader.dataset)
        
        unfairs.extend([(new_F[i] - new_cF[i]).item() for i in range(new_F.size(0))])
    
    original_unfairness = np.mean(np.abs(original_unfairs))
    unfairness = np.mean(np.abs(unfairs))
    loss = running_loss / len(loader)
    return original_unfairness, unfairness, loss
 
def CRExperiment(loaders, total_epochs, eta):
    train_loader, valid_loader, test_loader = loaders[0], loaders[1], loaders[2]
    model = CRDecision(input_dim=12).cuda()
    loss_fn = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.99))
    CRTrain(train_loader, valid_loader, model, loss_fn, optimizer, total_epochs, eta)
    original_unfairness, unfairness, test_loss = CRTest(test_loader, model, loss_fn, eta)
    print("original_unfairnss = {}, unfairness = {}, test_loss = {}".format(original_unfairness, unfairness, test_loss))        

class DFDecision(nn.Module):
    def __init__(self, input_dim):
        super(DFDecision, self).__init__()
        self.w1 = torch.Tensor([0]).cuda()
        self.w2 = nn.Parameter(torch.randn(1))
        self.w3 = nn.Parameter(torch.randn(1))
        
        self.layer = nn.Linear(input_dim, 1)
    
    def forward(self, cy, u):
        h1 = self.w1 * cy * cy + self.w2 * cy + self.w3
        h2 = self.layer(u)
        out = h1 + h2
        return out

def DFTrain(train_loader, valid_loader, model, loss_fn, optimizer, total_epochs, eta):
    train_losses = []
    valid_losses = []
    valid_unfairs = []
    
    for epoch in range(total_epochs):
        model.train()
        running_loss = 0
        
        for i, data in enumerate(train_loader):
            optimizer.zero_grad()
            S = data[:, 0 : 2].cuda()
            R = data[:, 2 : 10].cuda()
            K = data[:, 10].unsqueeze(1).cuda()
            F = data[:, 13].unsqueeze(1).cuda()
            cF = data[:, 26].unsqueeze(1).cuda()
            
            u = torch.cat([R, K], dim=1)
            out = model(cF, u)
            loss = loss_fn(out, F)
            loss.backward()
            running_loss += loss.item()
            optimizer.step()
            
        train_losses.append(running_loss / len(train_loader))
        __, valid_unfairness, valid_loss = DFTest(valid_loader, model, loss_fn, eta)
        valid_losses.append(valid_loss)
        valid_unfairs.append(valid_unfairness)
    
    plt.plot(train_losses)
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.savefig("images/DF_train_loss.png")
    plt.close()
    plt.plot(valid_losses)
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.savefig("images/DF_valid_loss.png")
    plt.close()
    plt.plot(valid_unfairs)
    plt.xlabel("Epoch")
    plt.ylabel("Unfairness")
    plt.savefig("images/DF_valid_unfairness.png")
    plt.close()    

def DFTest(loader, model, loss_fn, eta):
    model.eval()
    original_unfairs = []
    unfairs = []
    running_loss = 0
    K_changes = []
    K_origins = []
    
    for i, data in enumerate(loader):
        S = data[:, 0 : 2].cuda()
        R = data[:, 2 : 10].cuda()
        K = data[:, 10].unsqueeze(1).cuda()
        G = data[:, 11].unsqueeze(1).cuda()
        L = data[:, 12].unsqueeze(1).cuda()
        F = data[:, 13].unsqueeze(1).cuda()
        
        cS = data[:, 14 : 16].cuda()
        cR = data[:, 16 : 24].cuda()
        cK = data[:, 10].unsqueeze(1).cuda()
        cG = data[:, 24].unsqueeze(1).cuda()
        cL = data[:, 25].unsqueeze(1).cuda()
        cF = data[:, 26].unsqueeze(1).cuda()
        
        original_unfairs.extend([(F[i] - cF[i]).item() for i in range(F.size(0))])
        
        with torch.no_grad():
            u = torch.cat([R, K], dim=1)
            out = model(cF, u)
            loss = loss_fn(out, F)
            running_loss += loss.item()
        
        K.requires_grad = True
        F_check = FYA(cS, cR, K, loader.dataset)
        U = torch.cat([R, K], dim=1)
        out = model(F_check, U)
        gradient_dummy = torch.ones_like(out)
        out.backward(gradient=gradient_dummy)
        grad_K = K.grad
        K_origins.extend([K[i].item() for i in range(K.size(0))])
        K_changes.extend([(eta * grad_K[i]).item() for i in range(grad_K.size(0))])

        
        cK.requires_grad = True
        F_cc = FYA(S, R, cK, loader.dataset)
        cU = torch.cat([cR, cK], dim=1)
        cout = model(F_cc, cU)
        cgradient_dummy = torch.ones_like(cout)
        cout.backward(gradient=cgradient_dummy)
        grad_cK = cK.grad
        
        K = K + eta * grad_K
        F = FYA(S, R, K, loader.dataset)
        cK = cK + eta * grad_cK
        cF = FYA(cS, cR, cK, loader.dataset)
        
        unfairs.extend([(F[i] - cF[i]).item() for i in range(F.size(0))])

    change_ratio = np.sum(np.abs(K_changes)) / np.sum(np.abs(K_origins))
    #print("change ratio = {}".format(change_ratio))    
    original_unfairness = np.mean(np.abs(original_unfairs))
    unfairness = np.mean(np.abs(unfairs))
    loss = running_loss / len(loader)
    return original_unfairness, unfairness, loss

def DFExperiment(loaders, total_epochs, eta, args):
    train_loader, valid_loader, test_loader = loaders[0], loaders[1], loaders[2]
    model = DFDecision(input_dim=9)
    model.cuda()
    #model.w1 = torch.tensor(1 / (2 * 16 * eta * (np.sum(simdata.wxy * simdata.wxy * simdata.wux * simdata.wux) + np.sum(simdata.wuy * simdata.wuy)))).cuda()
    model.w1 = torch.tensor(1 / (2 * args.ratio * eta * train_loader.dataset.wFK ** 2))
    loss_fn = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.99))
    DFTrain(train_loader, valid_loader, model, loss_fn, optimizer, total_epochs, eta)
    original_unfairness, unfairness, test_loss = DFTest(test_loader, model, loss_fn, eta)
    print("original_unfairness = {}, unfairness = {}, test_loss = {}".format(original_unfairness, unfairness, test_loss))

def main(args):
    check_data = Path("./datas/data_{}.pkl".format(args.seed))
    if check_data.is_file():
        pass
    else:
        data_generation(args.seed)
    for i in range(4):
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)
        torch.backends.cudnn.deterministic = True
        
        with open("./datas/data_{}.pkl".format(args.seed), "rb") as f:
            law_data = pickle.load(f)
        train_set = lawDataset(law_data, type="train")
        valid_set = lawDataset(law_data, type="valid")
        test_set = lawDataset(law_data, type="test")
        train_loader = DataLoader(train_set, batch_size=128, shuffle=True)
        valid_loader = DataLoader(valid_set, batch_size=128, shuffle=False)
        test_loader = DataLoader(test_set, batch_size=128, shuffle=False)
        loaders = [train_loader, valid_loader, test_loader]
        
        total_epochs = 200
        eta = args.eta
        if i == 0:
            UFExperiment(loaders, total_epochs, eta)
        elif i == 1:
            CFExperiment(loaders, total_epochs, eta)
        elif i == 2:
            CRExperiment(loaders, total_epochs, eta)
        else:
            DFExperiment(loaders, total_epochs, eta, args)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="law school experiment")
    parser.add_argument("--seed", type=int, default=42, help="experiment seed")
    parser.add_argument("--eta", type=float, default=1, help="eta")
    parser.add_argument("--ratio", type=int, default=2, help="ratio of improvement")
    args = parser.parse_args()
    main(args)"""