import sys
sys.path.append("/linkhome/rech/genini01/uvp29is/Code/metanal_v2/dida")
sys.path.append("/linkhome/rech/genini01/uvp29is/Code/metanal_v2/dataloader")
sys.path.append("/linkhome/rech/genini01/uvp29is/Code/metanal_v2/models")
sys.path.append("/linkhome/rech/genini01/uvp29is/Code/metanal_v2/configurations")
#########################################################################################
import os
import torch
import logging
import numpy as np
from numpy.linalg import matrix_rank
np.random.seed(42)

from tqdm import tqdm
from omegaconf import DictConfig, OmegaConf, listconfig

from dss import DSS
from dida_network import DIDA

import hydra
import global_variables
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter

import hparams
from d2v_ours import D2V_ours
from network import NetRanking
from evaluate import evaluate_ranking
from  utils_main import dump_all_python_files
from baseline_network import Baseline_linear_ranking
from dataloader_train import get_list_dataset_cc18
from utils_dataloader import get_dataloader_ranking, get_BO_dataloader

# Module-level state shared with the rest of the project through `global_variables`:
# the logger for this file (also bound to the local name `log`) and a running
# counter of training batches, which `training` increments once per batch.
log = logging.getLogger(__name__)
global_variables.log = log
global_variables.batch_idx = 0

# Request deterministic cuDNN kernels so that seeded runs are repeatable.
torch.backends.cudnn.deterministic = True

def training(dataloader_train, model, baseline_model, training_params, dataloader_test, writer, dataloader_surrogate):
    """Train the ranking model and the baseline side by side, evaluating after each epoch.

    Both networks see exactly the same batches. Each has its own Adam optimizer
    and its own cross-entropy criterion, so the two trainings are independent.

    Args:
        dataloader_train: Sized iterable yielding 7-tuples
            ``(X, y, params_1, params_2, perf, mf_hc, list_score)`` of tensors.
        model: Network called as ``model(X, y, params_1, params_2)``; must return
            a pair whose first element holds the class scores.
        baseline_model: Network called as ``baseline_model(mf_hc, params_1, params_2)``
            returning class scores.
        training_params: Config node providing ``"device"``, ``"lr"``,
            ``baseline["lr"]``, ``epoch`` and ``extractor``.
        dataloader_test: Forwarded untouched to ``evaluate_ranking``.
        writer: Unused here; kept for interface compatibility.
        dataloader_surrogate: Unused here; kept for interface compatibility.

    Returns:
        None.

    Side effects:
        * increments ``global_variables.batch_idx`` once per batch;
        * reports loss/accuracy through ``global_variables.add_scalars`` and ``log``;
        * appends one CSV line per epoch
          (``epoch,loss_test,loss_test_baseline,accuracy_dida,accuracy_baseline``)
          to the file ``results`` inside ``global_variables.working_dir``.

    Raises:
        AssertionError: if ``X``, ``y``, ``params_1``, ``params_2`` or ``perf``
            contains NaN or infinite values.
    """
    device = training_params["device"]

    optimizer = optim.Adam(model.parameters(), lr=training_params["lr"])
    optimizer_baseline = optim.Adam(baseline_model.parameters(), lr=training_params.baseline["lr"])
    criterion = torch.nn.CrossEntropyLoss().to(device)
    criterion_baseline = torch.nn.CrossEntropyLoss().to(device)

    for epoch in range(training_params.epoch):
        progress = tqdm(enumerate(dataloader_train), total=len(dataloader_train))
        for batch_i, batch in progress:
            global_variables.batch_idx += 1
            ############################### TRAIN ####################################
            # Re-enabled on every batch; presumably evaluate_ranking switches the
            # networks to eval mode at the end of an epoch -- TODO confirm.
            model.train()
            baseline_model.train()

            optimizer.zero_grad()
            optimizer_baseline.zero_grad()

            # Move every tensor to the target device and drop a leading dimension
            # of size 1 (presumably added by the DataLoader -- TODO confirm).
            X, y, params_1, params_2, perf, mf_hc, list_score = (
                tensor.to(device).squeeze(0) for tensor in batch
            )

            size_data = X.size(0)
            if training_params.extractor == "dida":
                # The "dida" extractor is fed class indices rather than the raw
                # label tensor: reduce dimension 2 with argmax.
                y = y.argmax(2).float()

            # Fail fast on corrupted inputs instead of silently training on NaN/inf.
            checked = (("X", X), ("y", y), ("params_1", params_1),
                       ("params_2", params_2), ("perf", perf))
            for tensor_name, tensor in checked:
                assert not (torch.isnan(tensor).any() or torch.isinf(tensor).any()), \
                    "{} contains NaN or infinite values".format(tensor_name)

            pred, _ = model(X, y, params_1, params_2)
            pred_baseline = baseline_model(mf_hc, params_1, params_2)

            target = perf.long()
            loss = criterion(pred, target)
            loss_baseline = criterion_baseline(pred_baseline, target)

            # Index of the highest score along dimension 1 is the predicted class.
            pred_dida = pred.max(1, keepdim=True)[1]
            correct_dida = pred_dida.eq(target.view_as(pred_dida)).sum().item()

            prediction_baseline = pred_baseline.max(1, keepdim=True)[1]
            correct_baseline = prediction_baseline.eq(target.view_as(prediction_baseline)).sum().item()

            loss.backward()
            loss_baseline.backward()

            optimizer.step()
            optimizer_baseline.step()

            loss_value = loss.item()
            loss_baseline_value = loss_baseline.item()
            accuracy_train = correct_dida / size_data
            accuracy_train_baseline = correct_baseline / size_data

            # Scalars are recorded as fractions, exactly as before.
            global_variables.add_scalars("Loss/training", v_dida=loss_value, v_baseline=loss_baseline_value)
            global_variables.add_scalars("Accuracy/training", v_dida=accuracy_train, v_baseline=accuracy_train_baseline)

            # The message labels the accuracies with "%", so scale the fractions
            # by 100 (previously 85 % accuracy was printed as "0.85 %").
            log.info("Epoch {} - Batch {} # {} {:.2f} ({:.2f} %) ! baseline {:.2f} ({:.2f} %)".format(
                epoch, batch_i, global_variables.name_model,
                loss_value, 100.0 * accuracy_train,
                loss_baseline_value, 100.0 * accuracy_train_baseline))

        ############################### TEST ####################################
        global_variables.log.info("Evaluate on X_test")
        loss_test, loss_test_baseline, accuracy_dida, accuracy_baseline = evaluate_ranking(
            dataloader_test=dataloader_test,
            model=model,
            model_baseline=baseline_model,
            training_params=training_params)

        # NOTE: logging of the test scalars is currently disabled.
        # global_variables.add_scalar("Loss/test", v_dida=loss_test.item())
        # global_variables.add_scalar("Accuracy/test", v_dida=accuracy_dida)
        log.info("[TESTING] \t Accuracy {} {:.2f} <> baseline {}".format(global_variables.name_model,
                                                                            accuracy_dida, accuracy_baseline))

        # Append mode: every epoch adds one line, and earlier content is kept.
        with open(os.path.join(global_variables.working_dir, "results"), "a") as file:
            file.write("{},{},{},{},{}\n".format(epoch, loss_test, loss_test_baseline,
                                                accuracy_dida, accuracy_baseline))

        # NOTE: checkpointing of the best model is currently disabled.
        # if global_variables.current_best > loss_test:
        #     torch.save(model.state_dict(), os.path.join(global_variables.working_dir, "model"))
        #     global_variables.current_best = loss_test



@hydra.main(config_path="conf", config_name="ranking")
def main(cfg: DictConfig) -> None:
    """Build the models and dataloaders described by ``cfg`` and run the training.

    Steps: seed torch, record the current working directory in
    ``global_variables.working_dir``, create the TensorBoard writer there, dump
    the python sources, instantiate the meta-feature extractor selected by
    ``cfg.training.extractor`` (``"dida"``, ``"dss"`` or ``"d2v"``), wrap it in
    ``NetRanking``, create the linear baseline, load the data and call
    ``training``.

    Args:
        cfg: Hydra/OmegaConf configuration. The sections read here are
            ``training``, ``dida``, ``dss`` and ``handcrafted_mf``.

    Raises:
        ValueError: if ``cfg.training.extractor`` is not one of the supported names.
    """
    log.info("Run with config: ")
    print(cfg)
    torch.manual_seed(cfg.training.seed)
    torch.cuda.manual_seed_all(cfg.training.seed)
    # Presumably Hydra has already switched to the per-run output directory, so
    # everything written below lands there -- TODO confirm against the Hydra config.
    global_variables.working_dir = os.getcwd()
    log.info("--------------------------------------------------------------------")
    network_params = cfg["dida"]
    training_params = cfg["training"]
    device = training_params["device"]

    log.info("Working dir {}".format(global_variables.working_dir))
    global_variables.writer = SummaryWriter(global_variables.working_dir)
    dump_all_python_files(global_variables.working_dir)

    log.info("Init model and baseline ...")

    extractor_name = cfg.training.extractor
    if extractor_name == "dida":
        global_variables.name_model = extractor_name

        extractor = DIDA(d_Mfeat=cfg.dida.d_Mfeat,
                         d_Mlab=cfg.dida.d_Mlab, N=30,
                         nmoments=cfg.dida.nmoments,
                         d_out=cfg.dida.d_out,
                         fc_metafeatures=cfg.dida.fc_metafeatures,
                         tensorizations=cfg.dida.tensorizations,
                         dropout_fc=cfg.dida.dropout_fc,
                         writer=global_variables.writer)
    elif extractor_name == "dss":
        global_variables.name_model = "dss_{}".format(cfg.dss.version_dss_block)

        extractor = DSS(list_dim_output_phi=cfg.dss.list_dim_output_phi,
                        list_dim_output_rho=cfg.dss.list_dim_output_rho,
                        fc_metafeatures=cfg.dss.fc_metafeatures,
                        dropout_fc=cfg.dss.dropout_fc,
                        version_dss_block=cfg.dss.version_dss_block)
    elif extractor_name == "d2v":
        # The training loop reads global_variables.name_model for every log
        # line; this branch used to leave it unset.
        global_variables.name_model = extractor_name

        # NOTE(review): the d2v extractor takes its head settings from the
        # `dida` config section -- confirm this is intentional.
        extractor = D2V_ours(fc_metafeatures=cfg.dida.fc_metafeatures,
                             dropout_fc=cfg.dida.dropout_fc)
    else:
        raise ValueError("MF extractor not found")

    net = NetRanking(metafeatures_extractor=extractor,
                     parameters=network_params,
                     size_hparams=cfg.training.classifier.size,
                     writer=global_variables.writer).to(device)
    # global_variables.model_cpu = NetRanking(parameters=network_params,
    #             size_hparams=cfg.training.classifier.size,
    #             writer=global_variables.writer)
    # global_variables.model_cpu.share_memory()
    baseline_linear = Baseline_linear_ranking(size_mf_hc=cfg.handcrafted_mf.size,
                                              size_hparams=cfg.training.classifier.size).to(device)

    # Prepare dataloader
    log.info("Load dataloader ...")
    dataloader_train, dataloader_test, dataloader_surrogate = prepare_data(cfg)
    # torch.save(dataloader_train, os.path.join(global_variables.working_dir, "dataloader_train"))
    # torch.save(dataloader_test, os.path.join(global_variables.working_dir, "dataloader_test"))

    log.info("Start training ...")
    training(dataloader_train, net, baseline_linear, training_params, dataloader_test, global_variables.writer, dataloader_surrogate)

def prepare_data(cfg):
    """Load the datasets, split them into train/test and build the dataloaders.

    The datasets and the train/test index lists come from
    ``get_list_dataset_cc18``, called with ``cfg.training.dataloader.test_size``
    and ``cfg.training.seed``. The selected subsets are handed, together with
    ``cfg``, to ``get_dataloader_ranking``, whose result is returned unchanged.
    """
    list_X, list_y, train_idx, test_idx = get_list_dataset_cc18(
        test_size=cfg.training.dataloader.test_size,
        seed=cfg.training.seed,
    )
    log.info("Load {} datasets: train {} - test {}".format(len(list_X), len(train_idx), len(test_idx)))

    def select(items, indices):
        # Pick the entries of `items` at the given positions, keeping their order.
        return [items[position] for position in indices]

    return get_dataloader_ranking(
        select(list_X, train_idx),
        select(list_y, train_idx),
        select(list_X, test_idx),
        select(list_y, test_idx),
        cfg,
    )


# Script entry point. `main` is wrapped by `@hydra.main`, which is why it is
# called here without passing the `cfg` argument explicitly.
if __name__ == "__main__":
    main()
