import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import numpy as np
from data import FairDataModule
from dbm import DBMModule
import numpy as np
import warnings
from absl import app, flags
import torch.nn.functional as F
import torchvision.models as torch_models
import pytorch_lightning as pl
from architecture import FC, Net, FC_3, FC_5, ResNetFeatureExtractor, CLIPViTB16FeatureExtractor,CLIPRN50FeatureExtractor, GenderCNN
warnings.filterwarnings("ignore")





# Command-line flags for this script; main() reads them through FLAGS.
FLAGS = flags.FLAGS
# Dataset selector. main() has branches for "adult", "compas", "lfwa_w" and "celeba".
flags.DEFINE_string("dataset", "adult", "dataset name")
# Forwarded to FairDataModule as add_bias.
flags.DEFINE_bool("bias_flag", False, "bias or not")
# NOTE(review): "c" is defined but not referenced anywhere else in this file.
flags.DEFINE_float("c", 0.0, "bias ratio")
# Forwarded to FairDataModule as bias_amount.
flags.DEFINE_float("bias_amount",0.2,"bias level")
# Width passed to Net/FC for adult/compas, and to DBMModule for every dataset.
flags.DEFINE_integer("hidden_dim", 50, "hidden dimension")
# Output size passed to Net and FC.
flags.DEFINE_integer("out_dim", 2, "output dimension")
# Forwarded to pl.Trainer as max_epochs.
flags.DEFINE_integer("epochs", 10, "no. of epochs")
# Forwarded to DBMModule as learning_rate.
# NOTE(review): the help text below contains a typo ("leraning").
flags.DEFINE_float("lr", 5e-2, "leraning rate")
# Only printed by main(); it does not select any code path in this file.
flags.DEFINE_string("method", "dbm", "method to use")
# Forwarded to DBMModule as regularization_weight.
flags.DEFINE_float("reg",0.1,"reg")
# Number of independent fit/test repetitions aggregated by main().
flags.DEFINE_integer("num_run", 10, "# of runs")
# Forwarded to FairDataModule as batch_size.
flags.DEFINE_integer("batch_size",128,"batch size")




# input_size chosen per dataset. Only Net (adult/compas) consumes it; the
# lfwa_w/celeba branch builds its models without using this value.
_INPUT_SIZES = {"adult": 34, "compas": 10, "lfwa_w": 160, "celeba": 160}

# Checkpoint loaded into Net for each dataset whose first model is a Net.
_NET_CHECKPOINTS = {
    "adult": "pretrained_models/mlp_biased.pt",
    "compas": "pretrained_models/mlp_compas.pt",
}

# Datasets whose first model is a ResNetFeatureExtractor followed by FC(512, 512, out_dim).
_RESNET_DATASETS = ("lfwa_w", "celeba")


def _create_model(input_size):
    """Build a fresh DBMModule for the dataset selected by ``FLAGS.dataset``.

    Args:
        input_size: First argument given to ``Net`` for the adult/compas
            datasets; ignored for lfwa_w/celeba.

    Returns:
        A ``DBMModule`` wrapping the first model and the target ``FC`` head.

    Raises:
        ValueError: If ``FLAGS.dataset`` is not one of the supported names.
    """
    dataset = FLAGS.dataset
    if dataset in _RESNET_DATASETS:
        model = ResNetFeatureExtractor()
        target_model = FC(512, 512, FLAGS.out_dim)
    elif dataset in _NET_CHECKPOINTS:
        model = Net(input_size, FLAGS.hidden_dim, FLAGS.out_dim)
        # Load onto CPU so a checkpoint saved from a GPU run still loads on a
        # CPU-only host; the Lightning trainer moves the module to its device.
        state_dict = torch.load(_NET_CHECKPOINTS[dataset], map_location="cpu")
        model.load_state_dict(state_dict)
        target_model = FC(FLAGS.hidden_dim, FLAGS.hidden_dim, FLAGS.out_dim)
    else:
        raise ValueError(
            f"Unsupported dataset {dataset!r}; expected one of {sorted(_INPUT_SIZES)}"
        )

    # NOTE(review): FLAGS.hidden_dim is passed here even for lfwa_w/celeba,
    # where the FC head is built with 512 -- confirm against DBMModule.
    return DBMModule(
        model,
        target_model,
        FLAGS.hidden_dim,
        learning_rate=FLAGS.lr,
        regularization_weight=FLAGS.reg,
    )


def _compute_mean_std(metric_list):
    """Return ``(mean, std)`` of a list of per-run metric values as floats.

    With a single value the sample standard deviation is undefined (torch
    returns nan), so 0.0 is reported instead.
    """
    metric_tensor = torch.tensor(metric_list, dtype=torch.float32)
    mean = metric_tensor.mean().item()
    if metric_tensor.numel() == 1:
        return mean, 0.0
    return mean, metric_tensor.std().item()


def main(argv):
    """Train and test a DBMModule ``FLAGS.num_run`` times and print a summary.

    Each run builds a new model and FairDataModule, fits for ``FLAGS.epochs``
    epochs, and records the test error (``1 - test_acc``), ``DDP`` and ``EO``
    from the first test result. The mean and standard deviation of each
    metric over all runs are printed at the end.

    Args:
        argv: Positional command-line arguments supplied by ``app.run``; unused.

    Raises:
        ValueError: If ``FLAGS.dataset`` is not a supported dataset name.
    """
    data_location = "./dataset"

    # Fail early with a clear message instead of hitting an unbound local later.
    if FLAGS.dataset not in _INPUT_SIZES:
        raise ValueError(
            f"Unsupported dataset {FLAGS.dataset!r}; expected one of {sorted(_INPUT_SIZES)}"
        )
    input_size = _INPUT_SIZES[FLAGS.dataset]

    print("METHOD USED: ", FLAGS.method)

    # Only metrics that are actually collected below are tracked, so the
    # summary never aggregates an empty list.
    results = {'err': [], 'DDP': [], 'EO': []}

    for i in range(FLAGS.num_run):
        print(f"Run {i+1}/{FLAGS.num_run}")

        model = _create_model(input_size)
        datamodule = FairDataModule(
            data_location,
            FLAGS.dataset,
            add_bias=FLAGS.bias_flag,
            bias_amount=FLAGS.bias_amount,
            batch_size=FLAGS.batch_size,
        )
        trainer = pl.Trainer(
            max_epochs=FLAGS.epochs,
            enable_progress_bar=True,
            accelerator='auto',
            devices='auto',
        )
        trainer.fit(model, datamodule=datamodule)
        test_result = trainer.test(model, datamodule=datamodule)[0]

        results['err'].append(1 - test_result['test_acc'])
        results['DDP'].append(test_result['DDP'])
        results['EO'].append(test_result['EO'])

    mean_std_results = {
        metric: _compute_mean_std(values) for metric, values in results.items()
    }

    err_mean, err_std = mean_std_results['err']
    ddp_mean, ddp_std = mean_std_results['DDP']
    eo_mean, eo_std = mean_std_results['EO']

    # The tracked quantity is 1 - test_acc, so it is labelled as an error rate.
    print(f"Mean Test Error: {err_mean*100:.3f} ± {err_std*100:.3f}")
    print(f"Mean DDP: {ddp_mean:.3f} ± {ddp_std:.3f}")
    print(f"Mean EO: {eo_mean:.3f} ± {eo_std:.3f}")



# Script entry point: hand main() to absl's app runner when executed directly.
if __name__ == "__main__":
    app.run(main)
