import pickle
import sys
from torch.utils.data import Dataset
from sklearn.metrics import f1_score, roc_auc_score
import matplotlib.pyplot as plt

from src.models.utils import *
from src.config import *
from src.models.exist_models import *

IS_ON_CLUSTER = False

def eval_res(y_preds, y_trues, task='reg'):
    """Compute the evaluation score for one run's predictions.

    Args:
        y_preds: Model outputs (array-like). Regression: one prediction per
            target, in any shape with the same number of elements as
            ``y_trues``. Classification: an (N, n_classes) score matrix; for
            the binary case a 1-D (or (N, 1)) array of positive-class scores
            is also accepted.
        y_trues: Ground-truth targets (regression) or class labels.
        task: ``"reg"`` for regression; any other value is treated as
            classification.

    Returns:
        Regression: ``1 - MAPE``, i.e. ``1 - mean(|y_true - y_pred| / |y_true|)``.
        Classification with <= 2 distinct labels in ``y_trues``: ROC-AUC of
        the positive-class score (column 1 of a score matrix).
        Otherwise: macro-averaged one-vs-one multiclass ROC-AUC.

    Raises:
        ValueError: for regression, if the inputs have different numbers of
            elements.
    """
    y_preds = np.asarray(y_preds)
    y_trues = np.asarray(y_trues)

    if task == "reg":
        if y_preds.size != y_trues.size:
            raise ValueError(
                "y_preds has {} elements but y_trues has {}".format(y_preds.size, y_trues.size)
            )
        # Flatten both sides: an (N, 1) prediction column minus (N,) targets
        # would otherwise broadcast to an (N, N) matrix and silently give a
        # wrong score.
        y_preds = y_preds.ravel()
        y_trues = y_trues.ravel()
        # |y_true| in the denominator (standard MAPE) keeps each relative error
        # non-negative for negative targets. Zero targets still yield inf/nan.
        return 1 - np.mean(np.absolute(y_trues - y_preds) / np.absolute(y_trues))

    # roc_auc_score expects 1-D labels; ravel also makes the set below work for
    # an (N, 1) label column (a set of ndarray rows would be unhashable).
    labels = y_trues.ravel()
    classes = set(labels.tolist())
    print("Classes in Test:", classes)
    if len(classes) <= 2:
        # Accept either one positive-class score per sample or a per-class
        # score matrix whose column 1 is the positive class.
        if y_preds.ndim == 1 or y_preds.shape[1] == 1:
            pos_scores = y_preds.ravel()
        else:
            pos_scores = y_preds[:, 1]
        return roc_auc_score(labels, pos_scores)
    return roc_auc_score(labels, y_preds, multi_class="ovo", average="macro")

def load_dataset(ds_name="wesad"):
    """Read the pickled train/test split for ``ds_name``.

    The split file lives at ``data/<ds_name>/splits`` (prefixed with
    ``../../`` when ``IS_ON_CLUSTER`` is set) and is expected to be a dict
    with ``"train_fnames"`` and ``"test_fnames"`` entries.

    Returns:
        tuple: ``(train_fnames, test_fnames)`` exactly as stored in the split.
    """
    prefix = "../../" if IS_ON_CLUSTER else ""
    split_path = prefix + "data/{}/splits".format(ds_name)

    with open(split_path, 'rb') as handle:
        split = pickle.load(handle)

    train_fnames = split["train_fnames"]
    test_fnames = split["test_fnames"]
    print("Train size:", len(train_fnames))
    print("Test size:", len(test_fnames))
    return train_fnames, test_fnames

def load_model(
    model_name="convtran",
    num_channels=14,
    num_classes=1,
    task="reg",
    is_preproc=False
):
    """Build the downstream model named ``model_name``, cast to bfloat16, on DEVICE.

    Args:
        model_name: One of "cv", "vit", "chronos", "chronos_last",
            "chronos_msitf", "chronos_base", "pretrain", "fusion",
            "fusion_crossvit", "crossvit", "mae_msitf", "mae".
        num_channels: Number of input channels (not used by every model).
        num_classes: Output size of the prediction head.
        task: Task type forwarded to the model ("reg" or classification).
        is_preproc: Forwarded to the model constructor.

    Returns:
        The instantiated model, converted to ``torch.bfloat16`` and moved to
        ``DEVICE``.

    Raises:
        NotImplementedError: for the reserved but unimplemented names
            "uni2ts" and "nlp".
        ValueError: for any other unrecognised name (note the default
            "convtran" currently has no branch here).
    """
    # Previously these branches were `pass`, leaving `model` unbound and
    # surfacing as an opaque UnboundLocalError at the return statement.
    if model_name in ("uni2ts", "nlp"):
        raise NotImplementedError("model '{}' is not implemented yet".format(model_name))

    if model_name == "cv":
        model = CVModel(num_channels=num_channels, num_classes=num_classes, task=task, is_preproc=is_preproc, vanilla=True)
    elif model_name == "vit":
        model = CVModel(num_channels=num_channels, num_classes=num_classes, task=task, is_preproc=is_preproc, all_attn=True)
    elif model_name in ["chronos", 'chronos_last', 'chronos_msitf', 'chronos_base']:
        model = Chronos(num_channels, num_classes, task=task, is_preproc=is_preproc)
    elif model_name == "pretrain":
        model = PretrainAPI(num_channels=num_channels, num_classes=num_classes, task=task, is_preproc=is_preproc, is_on_cluster=IS_ON_CLUSTER)
    elif model_name in ['fusion', 'fusion_crossvit']:
        # The fusion model takes no num_channels argument.
        model = FusionModel(num_classes=num_classes, task=task, is_preproc=is_preproc)
    elif model_name == "crossvit":
        model = CrossVitAPI(num_channels=num_channels, num_classes=num_classes, task=task, is_preproc=is_preproc)
    elif model_name == 'mae_msitf':
        model = MAE_API(num_classes, task=task, is_preproc=is_preproc)
    elif model_name == 'mae':
        model = MAE_LP_API(num_channels, num_classes, task=task, is_preproc=is_preproc)
    else:
        raise ValueError("Unknown model_name {!r}".format(model_name))

    return model.to(torch.bfloat16).to(DEVICE)

class Downstream_Dataset(Dataset):
    """Map-style dataset that loads one pickled downstream sample per index.

    Each item reads three pickles:
      * the sample record at ``fnames[idx]`` (source of the label and, in
        preprocessing mode, of the raw input);
      * the matching semantic-embedding record, found by replacing
        "samples" with "nlp_embed" in the path (source of ``"query"``);
      * in non-preprocessing mode only, the model-specific input found by
        replacing "samples" with ``model_name`` in the path.

    Items are dicts with keys "samples", "labels", "fnames" and "query";
    tensors are moved to ``DEVICE`` and all but integer class labels are cast
    to bfloat16.
    """

    def __init__(self, fnames, ts=True, cwt=False, cwtp=False, nlp=False, task='reg', model_name="", is_preproc=False):
        # Sample file paths; may contain Windows "\\" separators.
        self.fnames = fnames
        self.task = task
        self.model_name = model_name
        self.is_preproc = is_preproc

        # Input-modality flags, consulted only when is_preproc is True and
        # checked in this order (the first True flag wins).
        self.ts = ts
        self.cwt = cwt
        self.cwtp = cwtp
        self.nlp = nlp

    def __len__(self):
        return len(self.fnames)

    def _resolve_path(self, path):
        """Normalise path separators to "/" and apply the cluster prefix."""
        path = path.replace("\\", "/")
        if IS_ON_CLUSTER:
            path = "../../" + path
        return path

    def _select_preproc_sample(self, curr_sample):
        """Pick the raw input selected by the modality flags from a sample record.

        numpy arrays are converted with ``torch.from_numpy``; any other value
        (e.g. an existing tensor) is returned unchanged.

        Raises:
            ValueError: if none of ts / cwt / cwtp / nlp is set.
        """
        if self.ts:
            sample = curr_sample["tss"]  # C, L
        elif self.cwt:
            sample = curr_sample["cwt"][:, :, :, 0]  # C, L, 65
        elif self.cwtp:
            sample = curr_sample["cwt"]  # C, L, 65, 3
        elif self.nlp:
            sample = curr_sample["nlp_embed"]  # Le, 768 (Le=num embeddings)
        else:
            raise ValueError("is_preproc=True requires one of ts, cwt, cwtp or nlp to be set")
        try:
            return torch.from_numpy(sample)
        except TypeError:
            # Not a numpy array (from_numpy rejects e.g. tensors); use as-is.
            return sample

    def __getitem__(self, idx):
        f_path = self._resolve_path(self.fnames[idx])

        # Sample record: holds the label, and the raw input in preproc mode.
        with open(f_path, 'rb') as f:
            curr_sample = pickle.load(f)

        # Semantic embeddings stored in a parallel "nlp_embed" tree.
        with open(f_path.replace("samples", "nlp_embed"), 'rb') as f:
            curr_nlp_embed = pickle.load(f)

        if self.is_preproc:
            sample = self._select_preproc_sample(curr_sample)
        else:
            # Model-specific input stored in a parallel tree named after the model.
            input_path = self._resolve_path(self.fnames[idx].replace("samples", self.model_name))
            with open(input_path, 'rb') as f:
                sample = pickle.load(f)  # C, E
            sample = torch.from_numpy(sample)

        label = curr_sample["label"]
        if self.task == "reg":
            labels = torch.tensor(label).to(torch.bfloat16)
        else:
            labels = torch.tensor(label).long()

        return {
            "samples": sample.to(torch.bfloat16).to(DEVICE),
            "labels": labels.to(DEVICE),
            "fnames": self.fnames[idx],
            "query": torch.from_numpy(curr_nlp_embed['question']).to(torch.bfloat16).to(DEVICE)
        }

def main(
    is_preproc=False,
    model_name="convtran",
    ds_name="wesad",
    num_channels=14,
    num_classes=1,
    task="class",
    ts=False,
    cwt=False,
    cwtp=False,
    nlp=False,
    n_runs=5,
    batch_size=16,
    epochs=100,
    lr=5e-4,
    step_size=10,
):
    """Run one downstream experiment for ``ds_name`` with ``model_name``.

    Two modes:

    * ``is_preproc=True``: pass the train and test datasets through
      ``get_and_save_embed`` with the loaded model, then return.
    * ``is_preproc=False``: train a freshly initialised model ``n_runs``
      times with ``fit_model``, score each run's last predictions with
      ``eval_res``, print mean +- std (x100), and save the ``linear_prob``
      weights of the final run's model.

    Args:
        is_preproc, model_name, ds_name, num_channels, num_classes, task:
            Forwarded to the dataset / model constructors.
        ts, cwt, cwtp, nlp: Input-modality flags for ``Downstream_Dataset``.
        n_runs: Number of independent training runs to average over.
        batch_size: Batch size for both embedding extraction and training.
        epochs, lr, step_size: Forwarded to ``fit_model``. Defaults keep the
            previously hard-coded values (earlier tuning ranges tried:
            epochs 30/50/150, lr 1e-2/5e-3, step_size 5/10).

    Raises:
        ValueError: if ``n_runs < 1``.
    """
    if n_runs < 1:
        raise ValueError("n_runs must be >= 1, got {}".format(n_runs))

    # initialize dataset
    train_fnames, test_fnames = load_dataset(ds_name=ds_name)
    dataset_kwargs = dict(ts=ts, cwt=cwt, cwtp=cwtp, nlp=nlp, task=task, model_name=model_name, is_preproc=is_preproc)
    train_dataset = Downstream_Dataset(train_fnames, **dataset_kwargs)
    test_dataset = Downstream_Dataset(test_fnames, **dataset_kwargs)

    # initialize model (also used to report the parameter count)
    model_kwargs = dict(model_name=model_name, num_classes=num_classes, task=task, num_channels=num_channels, is_preproc=is_preproc)
    model = load_model(**model_kwargs)
    num_params(model)

    if is_preproc:
        for dataset in (train_dataset, test_dataset):
            get_and_save_embed(model, dataset, ds_name, model_name, batch_size=batch_size, is_on_cluster=IS_ON_CLUSTER)
        print("Processing Complete.")
        return

    # fit model: each run starts from a freshly initialised model
    scores = []
    for _ in range(n_runs):
        model = load_model(**model_kwargs)
        record = fit_model(
            model_name,
            model,
            train_dataset,
            test_dataset,
            batch_size=batch_size,
            epochs=epochs,
            lr=lr,
            step_size=step_size,
        )
        last_pred = record["last_pred"]
        scores.append(eval_res(np.array(last_pred["y_preds"]), np.array(last_pred["y_trues"]), task=task))
        print("Last score achieved:", scores[-1])

    print("{}, {}, score: {} +- {}".format(ds_name, model_name, round(np.mean(scores)*100, 3), round(np.std(scores)*100, 3)))

    # save the prediction-head weights of the final run's model
    # NOTE(review): "../data" is relative to the CWD and, unlike the other
    # paths in this file, gets no IS_ON_CLUSTER prefix — confirm intended.
    saved_weight_path = "../data/{}/linear_prob_weight.pt".format(ds_name)
    linear_prob = getattr(model, "linear_prob", None)
    if linear_prob is None:
        print("Model {} has no linear_prob head; skipping weight save.".format(model_name))
    else:
        torch.save(linear_prob.to(torch.device('cpu')).state_dict(), saved_weight_path)

if __name__ == "__main__":
    # python3 -m src.downstream_eval uci_har vit False True
    # python3 -m src.downstream_eval wesad pretrain True True
    # python3 -m src.downstream_eval ecg_heart_cat vit False True
    # python3 -m src.downstream_eval ppg_hgb crossvit True False
    # python3 -m src.downstream_eval non_invasive_bp vit True False
    # python3 -m src.downstream_eval drive_fatigue fusion False False

    '''
    # preprocess command
    python3 -m src.downstream_eval ppg_hgb chronos_base False True
    python3 -m src.downstream_eval non_invasive_bp mae False True
    python3 -m src.downstream_eval indian-fPCG vit True True
    python3 -m src.downstream_eval PPG_HTN chronos_base False True
    python3 -m src.downstream_eval PPG_DM chronos_base False True
    python3 -m src.downstream_eval PPG_CVA mae True True
    python3 -m src.downstream_eval PPG_CVD chronos_base False True

    python3 -m src.downstream_eval ecg_heart_cat mae False False
python3 -m src.downstream_eval drive_fatigue mae True True
python3 -m src.downstream_eval gameemo mae True True
python3 -m src.downstream_eval uci_har mae True True
python3 -m src.downstream_eval wesad mae True True

    python3 -m src.downstream_eval ppg_hgb chronos_msitf False
    python3 -m src.downstream_eval non_invasive_bp chronos_msitf False
    python3 -m src.downstream_eval PPG_HTN chronos_msitf False
    python3 -m src.downstream_eval PPG_DM mae_msitf True True
    python3 -m src.downstream_eval PPG_CVA chronos_msitf False
    python3 -m src.downstream_eval PPG_CVD chronos_msitf False
    python3 -m src.downstream_eval ecg_heart_cat mae True
    python3 -m src.downstream_eval ecg_heart_cat mae_msitf True
    python3 -m src.downstream_eval drive_fatigue mae_msitf True True
    python3 -m src.downstream_eval gameemo chronos_msitf False
    python3 -m src.downstream_eval uci_har chronos_msitf False
    python3 -m src.downstream_eval wesad chronos_msitf False

    python3 -m src.downstream_eval indian-fPCG chronos False
    python3 -m src.downstream_eval indian-fPCG cv False
    python3 -m src.downstream_eval indian-fPCG vit False
    python3 -m src.downstream_eval gameemo fusion False
    python3 -m src.downstream_eval indian-fPCG crossvit False
    '''
    ds_name = sys.argv[1]
    model_name = sys.argv[2]
    is_preproc = True if sys.argv[3] == "True" else False

    try:
        is_on_cluster = True if sys.argv[4] == "True" else False
    except:
        is_on_cluster = False
    IS_ON_CLUSTER = is_on_cluster

    # print(is_preproc, IS_ON_CLUSTER)
    # exit()

    # ds_name = 'wesad' # non_invasive_bp, ppg_hgb, ecg_heart_cat
    # model_name = 'cv' # cv, chronos, moirai, pretrain, nlp
    # is_preproc = True # True for process and save embed

    print("Processing {}...".format(ds_name))
    main(
        is_preproc=is_preproc,
        model_name=model_name,
        ds_name=ds_name,
        num_channels=DATASET_CONFIG[ds_name]["n_ch"], 
        num_classes=DATASET_CONFIG[ds_name]["n_cl"],
        task=DATASET_CONFIG[ds_name]["task"], # class, reg
        ts=MODEL_CONFIG[model_name]['ts'], 
        cwt=MODEL_CONFIG[model_name]['cwt'], 
        cwtp=MODEL_CONFIG[model_name]['cwtp'], 
        nlp=MODEL_CONFIG[model_name]['nlp']
    )