import sys
sys.path.append("/linkhome/rech/genini01/uvp29is/Code/metanal_v2/dida")
sys.path.append("/linkhome/rech/genini01/uvp29is/Code/metanal_v2/dataloader")
sys.path.append("/linkhome/rech/genini01/uvp29is/Code/metanal_v2/models")
sys.path.append("/linkhome/rech/genini01/uvp29is/Code/metanal_v2/configurations")
#########################################################################################
import os
import torch
import logging
import numpy as np
from numpy.linalg import matrix_rank
np.random.seed(42)

from tqdm import tqdm
from omegaconf import DictConfig, OmegaConf, listconfig

import hydra
import global_variables
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter

import hparams
from dss import DSS
from dida_network import DIDA
from deep_sets import DeepSets
from d2v_ours import D2V_ours
from handcrafted import Handcrafted
from network import NN_patch_identification
from  utils_main import dump_all_python_files, count_parameters

from utils_dataloader import get_d2v_dataloader

# Ask cuDNN for deterministic kernels to make runs reproducible.
torch.backends.cudnn.deterministic = True

# Logger for this file. It is also published on ``global_variables``
# (presumably so other modules can log through the same object).
log = global_variables.log = logging.getLogger(__name__)

# Global step counter; ``training`` increments it once per training batch.
global_variables.batch_idx = 0

def _prepare_batch(batch, training_params):
    """Move one dataloader batch to the configured device and unwrap it.

    The loader yields ``(list_X_1, list_Y_1, list_X_2, list_Y_2, I)``. Each
    tensor is moved to ``training_params["device"]`` and its leading
    dimension is removed with ``squeeze(0)``. Presumably the loader uses
    ``batch_size=1`` and each item already holds a set of pairs; TODO confirm
    in ``get_d2v_dataloader``.

    For the ``"dida"`` extractor, both label tensors are replaced by their
    argmax along dim 2, cast to float. This presumably turns one-hot labels
    into class indices; confirm.

    Returns:
        The five prepared tensors, in the same order as the loader yields them.
    """
    device = training_params["device"]
    list_X_1, list_Y_1, list_X_2, list_Y_2, I = (t.to(device).squeeze(0) for t in batch)
    if training_params.extractor == "dida":
        list_Y_1 = list_Y_1.argmax(2).float()
        list_Y_2 = list_Y_2.argmax(2).float()
    return list_X_1, list_Y_1, list_X_2, list_Y_2, I


def _batch_step(model, criterion, list_X_1, list_Y_1, list_X_2, list_Y_2, I):
    """Run ``model`` on one prepared batch.

    Returns:
        ``(loss, accuracy)``. ``loss`` is the criterion tensor, still attached
        to the graph when grads are enabled. ``accuracy`` is the fraction of
        rows whose argmax over dim 1 equals ``I``, normalised by
        ``list_X_1.size(0)`` as in the original code.
    """
    target = I.long()
    output = model(list_X_1, list_Y_1, list_X_2, list_Y_2)
    loss = criterion(output, target)
    pred = output.max(1, keepdim=True)[1]
    correct = pred.eq(target.view_as(pred)).sum().item()
    return loss, correct / list_X_1.size(0)


def training(dataloader_train, model, training_params, dataloader_test, patience=100):
    """Train ``model`` with Adam and cross-entropy, evaluating after every epoch.

    Args:
        dataloader_train: loader yielding
            ``(list_X_1, list_Y_1, list_X_2, list_Y_2, I)`` batches.
        model: module called as ``model(list_X_1, list_Y_1, list_X_2, list_Y_2)``.
            Its output is scored against ``I`` with cross-entropy.
        training_params: config node providing ``lr``, ``device``, ``epoch``
            (number of epochs) and ``extractor``.
        dataloader_test: loader with the same batch layout, passed to
            ``evaluate`` at the end of each epoch.
        patience: training stops once the epoch's mean *training* accuracy
            has failed to improve for more than ``patience`` consecutive
            epochs.

    Side effects:
        - Increments ``global_variables.batch_idx`` once per batch.
        - Logs loss and accuracy through ``global_variables.add_scalar``.
        - Appends one line per epoch to the ``results`` file via ``evaluate``.

    Returns:
        The best mean training accuracy reached over all epochs.
    """
    device = training_params["device"]
    optimizer = optim.Adam(model.parameters(), lr=training_params["lr"])
    criterion = torch.nn.CrossEntropyLoss().to(device)

    early_stopping = 0
    best_accuracy = 0

    for epoch in range(training_params.epoch):
        # ``evaluate`` leaves the model in its previous mode, but set train
        # mode explicitly so dropout etc. are active for this epoch.
        model.train()
        list_scores = []
        for batch_i, batch in tqdm(enumerate(dataloader_train), total=len(dataloader_train)):
            global_variables.batch_idx += 1
            optimizer.zero_grad()

            tensors = _prepare_batch(batch, training_params)
            loss, accuracy = _batch_step(model, criterion, *tensors)
            list_scores.append(accuracy)

            loss.backward()
            optimizer.step()

            loss_value = loss.item()
            global_variables.add_scalar("Loss/training", v_dida=loss_value)
            global_variables.add_scalar("Accuracy/training", v_dida=accuracy)

            # ``accuracy`` is a fraction; scale it so the "%" label is correct.
            log.info("Epoch {} - Batch {} \t {} {:.2f} ({:.2f} %)".format(
                epoch, batch_i, training_params.extractor, loss_value, 100 * accuracy))

        # Guard against an empty loader: np.mean([]) would warn and give nan.
        epoch_accuracy = np.mean(list_scores) if list_scores else 0.0

        # NOTE(review): early stopping tracks *training* accuracy, not test
        # accuracy. TODO: save the best model here
        # (os.path.join(global_variables.working_dir, "model")) if needed.
        if epoch_accuracy > best_accuracy:
            best_accuracy = epoch_accuracy
            early_stopping = 0
        else:
            early_stopping += 1

        # Evaluate once per epoch, including the epoch that triggers the stop.
        evaluate(dataloader_test, model, training_params, epoch=epoch)
        if early_stopping > patience:
            break

    return best_accuracy



def evaluate(dataloader_test, model, training_params, epoch=None):
    """Evaluate ``model`` on ``dataloader_test`` and append the scores to ``results``.

    The model is put in eval mode (so dropout is disabled) and gradients are
    disabled for the duration. The model's previous train/eval mode is
    restored afterwards.

    Args:
        dataloader_test: loader yielding
            ``(list_X_1, list_Y_1, list_X_2, list_Y_2, I)`` batches.
        model: module with the same call signature as in ``training``.
        training_params: config node providing ``device``, ``extractor`` and
            ``epoch``.
        epoch: epoch index written to the ``results`` file. If ``None``,
            ``training_params.epoch`` is written instead, which was the
            previous behaviour.

    Returns:
        ``(mean_loss, mean_accuracy)`` over the test batches. Both are nan if
        the loader is empty.
    """
    criterion = torch.nn.CrossEntropyLoss().to(training_params["device"])
    list_loss, list_scores = [], []

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for batch in tqdm(dataloader_test, total=len(dataloader_test)):
                tensors = _prepare_batch(batch, training_params)
                loss, accuracy = _batch_step(model, criterion, *tensors)
                list_scores.append(accuracy)
                list_loss.append(loss.item())
    finally:
        model.train(was_training)

    mean_loss = np.mean(list_loss) if list_loss else float("nan")
    mean_accuracy = np.mean(list_scores) if list_scores else float("nan")

    log.info("Test scores loss: {}, accuracy {}".format(mean_loss, mean_accuracy))
    epoch_label = training_params.epoch if epoch is None else epoch
    with open(os.path.join(global_variables.working_dir, "results"), "a") as file:
        file.write("{0},{1},{2}\n".format(epoch_label, mean_loss, mean_accuracy))
    return mean_loss, mean_accuracy



def _build_extractor(cfg):
    """Instantiate the meta-feature extractor named by ``cfg.training.extractor``.

    Supported names: ``"dida"``, ``"dss"``, ``"handcrafted"``, ``"deep_sets"``
    and ``"d2v"``. Each reads its hyper-parameters from the matching config
    section, except ``"d2v"`` (see the note below).

    Raises:
        ValueError: if the extractor name is not supported. This is a
            subclass of ``Exception``, so existing handlers still catch it.
    """
    name = cfg.training.extractor
    if name == "dida":
        return DIDA(d_Mfeat=cfg.dida.d_Mfeat,
                    d_Mlab=cfg.dida.d_Mlab, N=30,
                    nmoments=cfg.dida.nmoments,
                    d_out=cfg.dida.d_out,
                    fc_metafeatures=cfg.dida.fc_metafeatures,
                    tensorizations=cfg.dida.tensorizations,
                    dropout_fc=cfg.dida.dropout_fc,
                    writer=None)
    if name == "dss":
        return DSS(list_dim_output_phi=cfg.dss.list_dim_output_phi,
                   list_dim_output_rho=cfg.dss.list_dim_output_rho,
                   fc_metafeatures=cfg.dss.fc_metafeatures,
                   dropout_fc=cfg.dss.dropout_fc,
                   version_dss_block=cfg.dss.version_dss_block)
    if name == "handcrafted":
        return Handcrafted(fc_metafeatures=cfg.handcrafted.fc_metafeatures,
                           dropout_fc=cfg.handcrafted.dropout_fc)
    if name == "deep_sets":
        return DeepSets(list_dim_output_phi=cfg.deep_sets.list_dim_output_phi,
                        list_dim_output_rho=cfg.deep_sets.list_dim_output_rho,
                        fc_metafeatures=cfg.deep_sets.fc_metafeatures,
                        dropout_fc=cfg.deep_sets.dropout_fc)
    if name == "d2v":
        # NOTE(review): D2V reads its settings from the ``deep_sets`` section.
        # That looks like a possible copy-paste slip; confirm against the
        # Hydra config before changing it.
        return D2V_ours(fc_metafeatures=cfg.deep_sets.fc_metafeatures,
                        dropout_fc=cfg.deep_sets.dropout_fc)
    raise ValueError("MF extractor not found: {}".format(name))


@hydra.main(config_path="conf", config_name="patch")
def main(cfg: DictConfig) -> None:
    """Hydra entry point: seed the RNGs, build model and data, then train.

    Args:
        cfg: composed Hydra config. It must provide ``training`` (with
            ``seed``, ``device`` and ``extractor``) plus the section for the
            selected extractor.
    """
    # Send the config through the logger rather than stdout so it also
    # lands in the run's log file.
    log.info("Run with config: \n{}".format(OmegaConf.to_yaml(cfg)))
    torch.manual_seed(cfg.training.seed)
    torch.cuda.manual_seed_all(cfg.training.seed)
    # Hydra typically chdirs into the run's output directory; TODO confirm
    # for the Hydra version in use.
    global_variables.working_dir = os.getcwd()
    log.info("--------------------------------------------------------------------")
    training_params = cfg["training"]
    device = training_params["device"]

    log.info("Working dir {}".format(global_variables.working_dir))
    global_variables.writer = SummaryWriter(global_variables.working_dir)

    log.info("Init model and baseline ...")
    extractor = _build_extractor(cfg)
    net = NN_patch_identification(metafeatures_extractor=extractor).to(device)

    log.info("NB Parameters {}".format(count_parameters(net.metafeatures_extractor)))

    # Prepare dataloader
    log.info("Load dataloader ...")
    dataloader_train, dataloader_test = get_d2v_dataloader(cfg)

    log.info("Start training ...")
    training(dataloader_train, net, training_params, dataloader_test)



if __name__ == "__main__":
    # No arguments: the ``hydra.main`` decorator builds ``cfg`` itself.
    main()
