import os
import torch
import numpy as np
import pandas as pd
from typing import Dict
from argparse import Namespace
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, precision_score, recall_score, confusion_matrix

from large_rl.commons.args import HISTORY_SIZE
from large_rl.embedding.base import BaseEmbedding
from large_rl.envs.recsim_data.user_model.reward_model import launch_RewardModel, RewardModel
from large_rl.commons.seeds import set_randomSeed
from large_rl.commons.utils import logging



def compute_metrics(pred: np.ndarray, y_true: np.ndarray):
    """Score binary click predictions against ground-truth labels.

    Args:
        pred: model scores for each row. They are flattened and thresholded at
            0.5, so presumably probabilities in [0, 1] — TODO confirm for every
            reward-model type.
        y_true: binary (0/1) labels for the same rows as ``pred``.

    Returns:
        A ``(metrics, conf_mat)`` tuple:
        - ``metrics`` is a dict with ``accuracy`` and the macro-averaged
          ``f1`` / ``precision`` / ``recall``.
        - ``conf_mat`` is always a 2x2 confusion matrix (rows = true label,
          columns = predicted label, ordered 0 then 1).
    """
    y_pred = (np.asarray(pred).astype(np.float32).ravel() > 0.5).astype(np.float32)
    # Flatten the labels too. An (n, 1) label array would otherwise broadcast
    # against the (n,) predictions into an (n, n) comparison.
    y_true = np.asarray(y_true).ravel()

    acc = float(np.sum(y_pred == y_true)) / y_true.shape[0]
    f1 = f1_score(y_true=y_true, y_pred=y_pred, average='macro')
    prec = precision_score(y_true=y_true, y_pred=y_pred, average='macro')
    recall = recall_score(y_true=y_true, y_pred=y_pred, average='macro')
    metrics = {"accuracy": acc, "f1": f1, "precision": prec, "recall": recall}

    # Pin the label set so a batch containing a single class still yields a 2x2
    # matrix. Otherwise callers summing matrices over batches would hit shape
    # errors or silent broadcasting.
    conf_mat = confusion_matrix(y_true=y_true, y_pred=y_pred, labels=[0.0, 1.0])
    return metrics, conf_mat


def prep_input(data: dict, dict_embedding: Dict[str, BaseEmbedding], _name: str, args: Namespace):
    """Assemble reward-model inputs for the mini-batch rows selected by ``args._idx``.

    ``_name`` is the split prefix used to look up ``data`` keys (``"train"`` or
    ``"test"`` in this script). The keys read are ``{_name}_history_seq``,
    ``{_name}_user_feat``, ``{_name}_item`` and ``{_name}_label``.

    Returns:
        * ``"normal"`` model type with ``_name == "all"``: only the embeddings of
          every item, from ``dict_embedding["item"].get_all(if_np=False)``.
        * ``"normal"``: ``(obs, emb, y)``.
          - ``obs`` is a torch tensor of user features concatenated with the
            history-item embeddings along the last axis.
          - ``emb`` is the candidate-item embeddings with axis 1 dropped via
            ``[:, 0, :]``.
          - ``y`` is the labels.
        * ``"simple"``: ``([user_feat, user_history_feat], emb, y)``, without
          concatenation or squeezing.

    Raises:
        ValueError: if ``args.recsim_reward_model_type`` is neither ``"normal"``
            nor ``"simple"``.
    """
    model_type = args.recsim_reward_model_type
    if model_type not in ("normal", "simple"):
        # Fail here rather than returning None, which would only blow up later
        # at the caller's tuple unpacking.
        raise ValueError(f"prep_input>> unknown recsim_reward_model_type: {model_type!r}")

    item_emb = dict_embedding["item"]
    if model_type == "normal" and _name == "all":
        return item_emb.get_all(if_np=False)

    idx = args._idx  # row indices of the current mini-batch, set by the caller
    user_history_feat = item_emb.get(index=data[f"{_name}_history_seq"][idx], if_np=True)
    user_feat = data[f"{_name}_user_feat"][idx]
    y = data[f"{_name}_label"][idx]

    if model_type == "normal":
        obs = np.concatenate([user_feat, user_history_feat], axis=-1)
        obs = torch.tensor(obs, device=args.device)
        emb = item_emb.get(index=data[f"{_name}_item"][idx], if_np=False)[:, 0, :]
        return obs, emb, y

    emb = item_emb.get(index=data[f"{_name}_item"][idx], if_np=False)
    return [user_feat, user_history_feat], emb, y


def get_data(df_log: pd.DataFrame, dict_embedding: Dict[str, BaseEmbedding], args: Namespace = None):
    """Turn an interaction log into the arrays used to train the reward model.

    Args:
        df_log: log with ``userId``, ``itemId``, ``click`` and ``hist_seq`` columns.
            Each ``hist_seq`` cell is a string evaluated into a sequence
            (presumably a list of item ids — confirm against the log writer).
        dict_embedding: only ``dict_embedding["user"]`` is used here, queried
            with the ``userId`` column.
        args: run configuration; reads ``device`` and ``recsim_reward_model_type``.
            It is optional for backward compatibility. When omitted, the
            module-level ``args`` created by this script's ``__main__`` block is used.

    Returns:
        ``(user_feat, history_seq, item, label)`` as numpy arrays.
        - ``item`` has a trailing singleton axis.
        - For the ``"normal"`` model type, ``user_feat`` is repeated
          ``HISTORY_SIZE`` times along a new axis 1.

    Raises:
        ValueError: if ``args`` is not given and no module-level ``args`` exists.
    """
    if args is None:
        args = globals().get("args")
        if args is None:
            raise ValueError("get_data>> `args` was not passed and no module-level `args` is defined")

    # NOTE(review): `hist_seq` is parsed with eval(). Only run this on trusted,
    # self-generated logs. ast.literal_eval would be safer if the column only
    # ever holds plain literals.
    history_seq = np.asarray([eval(seq) for seq in df_log["hist_seq"]])
    label = df_log["click"].astype(float).to_numpy()
    item = df_log["itemId"].astype(int).to_numpy()[:, None]

    # get user attributes
    user_id = torch.tensor(df_log["userId"].values, dtype=torch.int64, device=args.device)
    user_feat = dict_embedding["user"].get(index=user_id, if_np=True)

    if args.recsim_reward_model_type == "normal":
        # Repeat the user features once per history step so they can be
        # concatenated with the per-step history-item embeddings in prep_input.
        user_feat = np.tile(A=user_feat[:, None, :], reps=(1, HISTORY_SIZE, 1))

    click_rate = sum(label) / len(label) if len(label) else float("nan")
    logging(f"get_data>> Rate of click event: {click_rate}")
    logging(f"get_data>> user_feat: {user_feat.shape} user_history_feat: {history_seq.shape}, "
            f"item: {item.shape}, label: {label.shape}")
    return user_feat, history_seq, item, label


def even_negative_positive_event(df_log, *, balance=False, random_state=None):
    """Optionally rebalance clicked vs. non-clicked events in an interaction log.

    By default this is a no-op and returns ``df_log`` itself.

    Args:
        df_log: interaction log with a boolean-like ``click`` column.
        balance: when True, randomly down-sample the majority class so that
            clicked and non-clicked rows appear equally often, then shuffle.
        random_state: seed forwarded to ``DataFrame.sample`` for reproducibility.

    Returns:
        ``df_log`` unchanged when ``balance`` is False. Otherwise, a new shuffled
        DataFrame with a fresh ``RangeIndex``.
    """
    if not balance:
        return df_log

    clicked = df_log["click"].astype(bool)
    negatives = df_log[~clicked]
    positives = df_log[clicked]
    # Sample both classes down to the minority size. Sampling more rows than a
    # class has (without replacement) would raise.
    n_keep = min(len(negatives), len(positives))
    balanced = pd.concat(
        [negatives.sample(n=n_keep, random_state=random_state),
         positives.sample(n=n_keep, random_state=random_state)],
        ignore_index=True,
    )
    return balanced.sample(frac=1, random_state=random_state).reset_index(drop=True)


def _make_batches(n_rows: int, batch_size: int):
    """Split ``range(n_rows)`` into index arrays of roughly ``batch_size`` rows.

    At least one section is always requested. A split smaller than
    ``batch_size`` therefore becomes a single batch instead of making
    ``np.array_split`` raise on zero sections.
    """
    return np.array_split(ary=np.arange(n_rows), indices_or_sections=max(1, n_rows // batch_size))


def _evaluate_epoch(model, data: dict, dict_embedding: Dict[str, BaseEmbedding], batches: dict, epoch: int,
                    args: Namespace):
    """Log the test-split confusion matrix and mean per-batch metrics for each split.

    The model is put in eval mode under ``torch.no_grad()`` and returned to
    train mode before exiting.
    """
    from large_rl.commons.utils import mean_dict

    model.eval()
    metrics_per_split = {flg: list() for flg in batches}
    test_true, test_pred = list(), list()
    with torch.no_grad():
        for flg, idx_list in batches.items():
            for _idx in idx_list:
                args._idx = _idx
                inputs = prep_input(data=data, dict_embedding=dict_embedding, _name=flg, args=args)
                pred = model.compute_score(*inputs[:-1]).cpu().numpy()
                _metrics, _ = compute_metrics(pred=pred, y_true=inputs[-1])
                metrics_per_split[flg].append(_metrics)
                if flg == "test":
                    test_true.append(np.asarray(inputs[-1]).ravel())
                    test_pred.append((pred.astype(np.float32).ravel() > 0.5).astype(np.float32))

    # Build the confusion matrix over the whole test split only. The fixed label
    # set keeps it 2x2 even if every test row shares one class.
    conf_mat = confusion_matrix(y_true=np.concatenate(test_true), y_pred=np.concatenate(test_pred),
                                labels=[0.0, 1.0])
    logging("=== CONFUSION MATRIX ===")
    print(conf_mat)

    logging(f"[Eval: Train] epoch: {epoch} | {mean_dict(_list_dict=metrics_per_split['train'])}")
    logging(f"[Eval: Test] epoch: {epoch} | {mean_dict(_list_dict=metrics_per_split['test'])}")

    model.train()


def _save_checkpoint(model, epoch: int, args: Namespace):
    """Write the model's ``state_dict`` plus the epoch to ``{args.rm_weight_path}.pkl``."""
    state = model.state_dict()
    state["epoch"] = epoch
    rm_weight_path = f"{args.rm_weight_path}.pkl"
    torch.save(state, rm_weight_path)
    logging("Model is saved in {}".format(rm_weight_path))


def train(model, loss_fn, opt, lr_scheduler, data: dict, dict_embedding: Dict[str, BaseEmbedding], args: Namespace):
    """Train the reward model with mini-batch gradient descent.

    Every 10 epochs (starting at epoch 0), metrics are evaluated on both splits.
    A checkpoint is written every ``max(1, int(0.2 * num_epochs))`` epochs and
    always after the final epoch, overwriting the same file each time.

    Args:
        model: reward model exposing ``compute_score``, ``train``/``eval`` and ``state_dict``.
        loss_fn: loss applied to the flattened scores and float labels.
        opt: optimizer stepped once per mini-batch.
        lr_scheduler: optional scheduler stepped once per mini-batch; may be ``None``.
        data: arrays produced by ``get_data``, keyed ``{train,test}_{user_feat,history_seq,item,label}``.
        dict_embedding: embeddings forwarded to ``prep_input``.
        args: reads ``batch_size``, ``num_epochs``, ``device`` and ``rm_weight_path``.
            ``args._idx`` is overwritten with the current batch indices.
    """
    logging("=== Train Model ===")
    model.train()
    batches = {
        "train": _make_batches(data["train_history_seq"].shape[0], args.batch_size),
        "test": _make_batches(data["test_history_seq"].shape[0], args.batch_size),
    }
    # Guard against a zero modulus when num_epochs < 5.
    save_every = max(1, int(args.num_epochs * 0.2))

    for epoch in range(args.num_epochs):
        # === Evaluation ===
        if epoch % 10 == 0:
            _evaluate_epoch(model=model, data=data, dict_embedding=dict_embedding, batches=batches, epoch=epoch,
                            args=args)

        # === One Training Epoch ===
        loss = list()
        for _idx in batches["train"]:
            args._idx = _idx
            obs, emb, y = prep_input(data=data, dict_embedding=dict_embedding, _name="train", args=args)
            pred = model.compute_score(obs, emb)  # batch_size x 1
            _loss = loss_fn(pred.view(-1), torch.tensor(y, device=args.device).to(pred.dtype))
            opt.zero_grad()
            _loss.backward()
            opt.step()
            if lr_scheduler is not None:
                lr_scheduler.step()
            loss.append(_loss.item())

        # === After one epoch ===
        logging(f"[Train] epoch: {epoch} loss: {np.mean(loss):.3f}")

        # === Save Model ===
        # Also save after the last epoch. Otherwise the final epochs of training
        # would be lost whenever (num_epochs - 1) is not a multiple of save_every.
        if epoch % save_every == 0 or epoch == args.num_epochs - 1:
            _save_checkpoint(model=model, epoch=epoch, args=args)


def main(args):
    """Train the RecSim reward model from a logged-interaction CSV.

    Steps:
    1. Read ``{args.recsim_pre_offline_or_online}_log.csv`` from ``args.recsim_data_dir``.
    2. Split it into train/test with ``train_test_split`` (seeded by ``args.seed``).
    3. Build embeddings via ``prep_emb``.
    4. Train with Adam and ``BCELoss``, checkpointing to
       ``{args.save_dir}/{args.recsim_pre_offline_or_online}.pkl``.

    Side effects: creates ``args.save_dir`` if missing, and sets
    ``args.rm_weight_path`` and ``args.if_offline`` on the namespace.
    """
    # Set the random seed
    set_randomSeed(seed=args.seed)

    # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
    os.makedirs(args.save_dir, exist_ok=True)

    df_log = pd.read_csv(os.path.join(args.recsim_data_dir, f"{args.recsim_pre_offline_or_online}_log.csv"))
    df_log = even_negative_positive_event(df_log=df_log)
    df_train, df_test = train_test_split(df_log, random_state=args.seed)
    from large_rl.envs.recsim_data.env import prep_emb
    dict_embedding = prep_emb(args=vars(args))

    # NOTE(review): get_data is called without `args`. It relies on the
    # module-level `args` created by this script's __main__ block.
    train_user_feat, train_history_seq, train_item, train_label = get_data(df_log=df_train,
                                                                           dict_embedding=dict_embedding)
    test_user_feat, test_history_seq, test_item, test_label = get_data(df_log=df_test,
                                                                       dict_embedding=dict_embedding)
    data = {
        "train_user_feat": train_user_feat, "train_history_seq": train_history_seq, "train_label": train_label,
        "train_item": train_item,
        "test_user_feat": test_user_feat, "test_history_seq": test_history_seq, "test_label": test_label,
        "test_item": test_item,
    }

    # Blank the weight path while building the model. Presumably this stops
    # launch_RewardModel from loading existing weights — confirm.
    args.rm_weight_path = ""
    model = launch_RewardModel(args=vars(args))
    loss_fn = torch.nn.BCELoss()
    params = [
        *model.model.parameters(),
        *model.obs_encoder.parameters(),
    ]
    import torch.optim as optim
    opt = optim.Adam(params, lr=args.rm_lr)
    lr_scheduler = None
    args.rm_weight_path = os.path.join(args.save_dir, args.recsim_pre_offline_or_online)
    args.if_offline = args.recsim_pre_offline_or_online == "offline"
    train(model=model, loss_fn=loss_fn, opt=opt, lr_scheduler=lr_scheduler, data=data, dict_embedding=dict_embedding,
          args=args)


if __name__ == '__main__':
    from large_rl.commons.args import get_all_args, add_args

    args = get_all_args()  # Get the hyper-params

    # =========== DEBUG =======================
    args.env_name = "recsim"
    args = add_args(args=args)
    args.num_epochs = 100
    args.batch_size = 256
    args.recsim_reward_model_type = "normal"
    args.recsim_emb_type = "pretrained"
    args.rm_lr = 0.0005
    # args.recsim_data_dir = "./data/movielens/ml_100k/ml-100k"
    # args.save_dir = f"{args.recsim_data_dir}/trained_weight/"
    # args.recsim_rm_obs_enc_type = "deepset"
    # =========== DEBUG =======================

    # The previous code referenced a DATASET_PATH name that is never defined or
    # imported in this file (a NameError at startup). Take the dataset root from
    # the DATASET_PATH environment variable instead. If it is unset, the empty
    # root leaves args.recsim_data_dir unchanged.
    dataset_root = os.environ.get("DATASET_PATH", "")
    args.recsim_data_dir = os.path.join(dataset_root, args.recsim_data_dir)
    # recsim_data_dir already includes the dataset root. Joining the root again
    # would double it whenever the root is a relative path.
    args.user_embedding_path = os.path.join(args.recsim_data_dir, "user_attr.npy")
    args.item_embedding_path = os.path.join(args.recsim_data_dir, "trained_weight/item.npy")
    logging(args)

    main(args=args)
