import pickle
import networkx as nx
import matplotlib.pyplot as plt
import pandas as pd
import networkx as nx
import random

from importlib import import_module
from numpy import mean
from torch import cuda, device
from argparse import ArgumentParser

from reward_trees.reward_trees import RewardTree, RewardNet
from reward_trees.reward_trees.reward_learner import preference_tuple
from src.feature_funcs import (
    parse_recourse_df,
    parse_recourse_df_gc,
    make_split_features_and_thresholds,
)

# Results table layout: one row per trained model, holding the run
# configuration plus its train and test accuracy.
columns = [
    "dataset",
    "seed",
    "model",
    "train",
    "test",
    "label_type",
]
# Starts empty; the sweep further down adds a row after each training run.
data_df = pd.DataFrame(columns=columns)


def get_train_test_accuracy(the_model):
    """Print and return the preference accuracy of ``the_model``.

    The model is called on the module-level ``states``, ``actions`` and
    ``next_states`` to obtain one cost per transition. Each preference
    (an object with ``i``, ``j`` and ``y`` attributes) is then scored:

    * ``y > 0.5`` counts as correct when the cost at ``i`` exceeds the cost at ``j``;
    * ``y < 0.5`` counts as correct when the cost at ``i`` is below the cost at ``j``;
    * anything else (a tie label) counts as correct only when both costs are
      exactly equal.

    Args:
        the_model: Callable as ``the_model(states, actions, next_states)``
            and exposing its training preferences as ``.preferences``.

    Returns:
        ``[train_accuracy, test_accuracy]``, where the train value is computed
        over ``the_model.preferences`` and the test value over the
        module-level ``test_preferences``.
    """
    costs = the_model(states, actions, next_states)

    def _hit(pref):
        # 1 when the predicted cost ordering agrees with the label, else 0.
        cost_i = costs[pref.i].item()
        cost_j = costs[pref.j].item()
        if pref.y > 0.5:
            return int(cost_i > cost_j)
        if pref.y < 0.5:
            return int(cost_i < cost_j)
        return int(cost_i == cost_j)

    splits = (("TRAIN", the_model.preferences), ("TEST", test_preferences))
    accs = []
    for split_name, split_prefs in splits:
        accuracy = mean([_hit(pref) for pref in split_prefs])
        print(f"{split_name} ACC =", accuracy)
        accs.append(accuracy)
    return accs


# Command-line options for the training sweep.
parser = ArgumentParser()
# --train_frac must be parsed as a float: argparse does not run a non-string
# default through `type`, so an int type only "worked" while the option was
# left at its default and rejected any fractional value given on the command
# line.
parser.add_argument(
    "--train_frac",
    type=float,
    default=0.8,
    help="fraction of the preferences used for training; the rest are held out for testing",
)
parser.add_argument("--num_batches", type=int, default=50000)
parser.add_argument("--batch_size", type=int, default=32)
parser.add_argument("--loss_func", type=str, default="bce")
parser.add_argument("--max_num_leaves", type=int, default=50)
parser.add_argument("--seed", type=int, default=0)
args = parser.parse_args()

                    
# Sweep over every (seed, dataset, model class, label type) combination:
# train a cost model on the pairwise preferences, pickle it, and record its
# train/test accuracy in data_df (written to CSV after every run).

# The compute device does not depend on the sweep variables, so pick it once.
dev = device("cuda:0" if cuda.is_available() else "cpu")

for seed in [0, 1, 2]:
    for dataset in ['heloc', 'adult', 'german_credit']:
        for model_class in ['mlp', 'tree']:
            for label_type in ['normal', 'custom']:

                # The sweep values override whatever was parsed from the command line.
                args.dataset = dataset
                args.model_class = model_class
                args.seed = seed

                # NOTE: pickle can execute arbitrary code on load; these files
                # must come from a trusted source.
                with open(f"data/{args.dataset}/df_rec_{label_type}.pkl", "rb") as f:
                    recourse_df = pickle.load(f)

                with open(f"data/{args.dataset}/df_idx_{label_type}.pkl", "rb") as f:
                    preference_df = pickle.load(f)

                # german_credit has its own parser; every other dataset shares one.
                parse_fn = (
                    parse_recourse_df_gc
                    if args.dataset == "german_credit"
                    else parse_recourse_df
                )
                # states / actions / next_states (and test_preferences below) are
                # deliberately module-level names: get_train_test_accuracy reads them.
                states, actions, next_states, ep_nums = parse_fn(recourse_df, device=dev)

                # Strip any "__<variant>" suffix to find the feature module.
                ds = args.dataset.split("__")[0]
                split_features_and_thresholds = make_split_features_and_thresholds(
                    states, actions, next_states, import_module(f"src.features.{ds}").features
                )

                training_args = dict(
                    num_batches=args.num_batches,
                    batch_size=args.batch_size,
                )

                if args.model_class == "tree":
                    model = RewardTree(
                        features_and_thresholds=split_features_and_thresholds,
                        max_num_eps=len(ep_nums),
                        max_ep_length=1,
                        seed=args.seed,
                    )
                    training_args["max_num_leaves"] = args.max_num_leaves
                    training_args["loss_func"] = args.loss_func
                    training_args["callbacks"] = [get_train_test_accuracy]

                elif args.model_class == "mlp":
                    model = RewardNet(features=list(split_features_and_thresholds), seed=args.seed)
                else:
                    raise NotImplementedError

                model.add_transitions(states, actions, next_states, ep_nums)

                print("Preferance df shape and num links:", preference_df.shape)

                num_preferences = len(preference_df)
                # May be fractional; it is only used as a threshold on the row counter.
                num_train_preferences = num_preferences * args.train_frac

                # NOTE: do not shuffle sparse datasets to ensure that train graph is connected
                if "__sparse" in args.dataset:
                    shuffled_preference_df = preference_df
                else:
                    shuffled_preference_df = preference_df.sample(frac=1, random_state=args.seed)

                # Directed graph over the training preferences; only its node and
                # edge counts are reported below.
                graph = nx.DiGraph()
                test_preferences = []
                for n, (_, preference) in enumerate(shuffled_preference_df.iterrows()):
                    rating = preference["rating"]
                    id1, id2 = preference["id1"], preference["id2"]

                    # Map the raw rating to a preference label. A rating of 0 is
                    # handled here directly (together with None/NaN) rather than
                    # by rewriting the column with Series.replace(0, None), whose
                    # result depends on the pandas version.
                    if rating == 1:  # "rating" is the one with higher cost
                        y = 1.0
                    elif rating == 2:
                        y = 0.0
                    elif rating == 0 or pd.isna(rating):
                        print("TODO: check validity of None <-> equal preference")
                        y = 0.5
                    else:
                        # Previously an unmatched rating silently reused the label
                        # of the previous row.
                        raise ValueError(
                            f"Unexpected rating {rating!r} for pair ({id1!r}, {id2!r})"
                        )

                    if n < num_train_preferences:
                        model.add_preference(id1, id2, y)
                        if y == 1.0:
                            graph.add_edge(id2, id1)
                        elif y == 0.0:
                            graph.add_edge(id1, id2)
                        else:  # tie: link the pair in both directions
                            graph.add_edge(id1, id2)
                            graph.add_edge(id2, id1)
                    else:
                        test_preferences.append(
                            preference_tuple(id1, id2, y, 1.0, {})
                        )

                # assert nx.is_weakly_connected(graph)
                print(len(graph.nodes), len(graph.edges))
                print(len(test_preferences))

                model.train(**training_args)
                train_acc, test_acc = get_train_test_accuracy(model)

                # Keep the held-out preferences with the pickled model.
                model.test_preferences = test_preferences

                with open(f"models/{args.dataset}.{args.model_class}.{seed}.{label_type}", "wb") as f:
                    pickle.dump(model, f)

                # DataFrame.append was removed in pandas 2.0; add the row by
                # label-based enlargement instead (values follow `columns` order).
                data_df.loc[len(data_df)] = [
                    args.dataset,
                    args.seed,
                    args.model_class,
                    train_acc,
                    test_acc,
                    label_type,
                ]

                # Rewritten after every run so partial results survive an interruption.
                data_df.to_csv('data/cost_func_results.csv')



