import os
import torch
from torch.nn import functional as F
import warnings
from utils import setup_logger, set_random_seed, save_dict, load_dict, Notes, read_json, RSRLogger
from xtransfer.core import MatchingNet
from xtransfer.config import get_cfg_default
from target_datasets import create_dataloader
from utils.ResourceProfile import RP, stop_RP
from xtransfer.tools import replace_relu
from xtransfer.trans import Finetuner

# PyTorch's reproducibility notes require a fixed cuBLAS workspace size when
# deterministic algorithms are enabled on CUDA; use the documented ":4096:8".
os.environ.update({"CUBLAS_WORKSPACE_CONFIG": ":4096:8"})
# Silence every warning for the whole run.
warnings.filterwarnings(action="ignore")
# Preferred compute device: the default CUDA device when one is available, else CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def extract_backbone_feature(backbone, dataloader):
    """Run ``backbone`` over every batch of ``dataloader`` and return pooled features.

    Each batch's output is global-average-pooled to 1x1 with
    ``F.adaptive_avg_pool2d`` and flattened to ``(batch, -1)``; the rows of all
    batches are concatenated in iteration order.

    Args:
        backbone: callable mapping a batch of inputs to a feature map accepted by
            ``F.adaptive_avg_pool2d`` (presumably ``(N, C, H, W)`` -- confirm).
        dataloader: iterable yielding ``(data, label)`` pairs.

    Returns:
        ``(features, labels)``: ``features`` is a NumPy array with one row per
        sample. ``labels`` is the single batch's label object unchanged when
        there is exactly one batch; otherwise it is the per-batch label tensors
        joined with ``torch.cat``.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    features, labels = [], []
    # Feature extraction only: no autograd graph is needed.
    with torch.no_grad():
        for data, label in dataloader:
            x = backbone(data)
            x = F.adaptive_avg_pool2d(x, (1, 1))
            features.append(x.view(x.size(0), -1).detach().cpu())
            labels.append(label)
    if not features:
        raise ValueError('dataloader yielded no batches')
    out = torch.cat(features, dim=0).numpy()
    # A single batch keeps its labels as-is; several batches are concatenated.
    label = labels[0] if len(labels) == 1 else torch.cat(labels, dim=0)
    return out, label


def load_data_to_memory(dataloader):
    """Materialize every batch of ``dataloader`` as one in-memory NumPy array.

    Args:
        dataloader: iterable yielding ``(data, label)`` pairs where ``data`` is a
            tensor (``.detach().cpu()`` is called on it).

    Returns:
        ``(data, labels)``: ``data`` is the concatenation of all batches along
        dim 0 as a NumPy array. ``labels`` is the single batch's label object
        unchanged when there is exactly one batch; otherwise it is the
        per-batch label tensors joined with ``torch.cat``.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    batches, labels = [], []
    for data, label in dataloader:
        batches.append(data.detach().cpu())
        labels.append(label)
    if not batches:
        raise ValueError('dataloader yielded no batches')
    out = torch.cat(batches, dim=0).numpy()
    # A single batch keeps its labels as-is; several batches are concatenated.
    label = labels[0] if len(labels) == 1 else torch.cat(labels, dim=0)
    return out, label


def get_fit_schema(fit_mode):
    """Return the fitting schema dict for ``fit_mode``.

    The schema names the component used at each stage (``'prehead'``,
    ``'afterhead'``, ``'trans'`` and, for the repair modes, ``'pruner'``);
    ``None`` means the stage is skipped.

    Args:
        fit_mode: one of ``'bi'``, ``'before'``, ``'repair'``, ``'repair_noP'``,
            ``'native'``, ``'og'``.

    Returns:
        A new dict on every call, so callers may mutate it freely.

    Raises:
        ValueError: if ``fit_mode`` is not a known mode.
    """
    schemas = {
        'bi': {'prehead': 'Bi_Contrast', 'afterhead': 'Bi_Contrast', 'trans': None},
        'before': {'prehead': 'Contrast', 'afterhead': None, 'trans': None},
        'repair': {'prehead': 'Repair', 'pruner': 'Pruner', 'afterhead': None, 'trans': None},
        'repair_noP': {'prehead': 'Repair', 'pruner': None, 'afterhead': None, 'trans': None},
        'native': {'prehead': 'Native', 'afterhead': None, 'trans': None},
        'og': {'prehead': None, 'afterhead': None, 'trans': None},
    }
    try:
        return schemas[fit_mode]
    except KeyError:
        raise ValueError(
            f"Unknown fit_mode {fit_mode!r}; expected one of {sorted(schemas)}"
        ) from None


def get_currect_size(model_name):
    """Look up the input resolution and backbone architecture for a dataset name.

    Despite its name, ``model_name`` is matched against dataset names; the first
    family (in the order listed below) containing it wins.

    Returns:
        ``(resize, backbone_name)``, e.g. ``(224, 'resnet18')``.

    Raises:
        ValueError: if the name belongs to no known family.
    """
    families = (
        (224, 'resnet18', ('miniImageNet', 'miniDomainNet', 'caltech', 'office31', 'officeHome', 'VoxCeleb')),
        (84, 'resnet10', ('CIFAR', 'CUB', 'DTD', 'Omniglot', 'QuickDraw')),
        (32, 'conv4', ('mnist', 'mnist_m', 'svhn', 'syn', 'usps')),
        (100, 'conv41d', ('MHEALTH', 'OPPORTUNITY', 'PAMAP2', 'sEMG', 'UniMiB')),
        (100, 'resnet181d', ('News',)),
        (512, 'bert', ('News_bert',)),
    )
    for resize, backbone_name, dataset_names in families:
        if model_name in dataset_names:
            return resize, backbone_name
    raise ValueError('This model is not supported yet!')


def _build_output_dir(cfg, multi_or_single, backbone_model):
    """Return the run directory under ``cfg.OUTPUT_DIR`` whose name encodes the run settings."""
    run_name = (
        f"{cfg.DATALOADER.DATA_NAME}-{cfg.DATALOADER.TRANS_METHOD}_"
        f"{cfg.xtransfer.PCA_COMPONENT}PCA_{cfg.DATALOADER.MODE}Mode_"
        f"{cfg.DATALOADER.NUM_SHOTS}Shots_{cfg.DATALOADER.EPO_ID}_"
        f"{multi_or_single}_{backbone_model}"
    )
    return os.path.join(cfg.OUTPUT_DIR, run_name)


def main():
    """Train a MatchingNet from the default config, optionally finetune, and save the backbone.

    Steps:
    1. Read the config and derive the run's output directory.
    2. Start the loggers and the resource profiler.
    3. Build the data loaders and fit the MatchingNet.
    4. If ``cfg.xtransfer.FINETUNE`` is set, fit a classifier on the first
       validation batch and save it to ``classifier.pt``.
    5. Save the backbone (after ``replace_relu``) to ``ft_model.pt``.
    """
    # load config from config file
    cfg = get_cfg_default()
    model_names = cfg.MODEL_POOL.NAMES
    multi_or_single = 'multi' if len(model_names) > 1 else ('single-' + model_names[0])
    # Input resolution and backbone architecture are derived from the first pool entry.
    cfg.DATALOADER.RESIZE, backbone_model = get_currect_size(model_names[0])

    # dir setup
    output_dir = _build_output_dir(cfg, multi_or_single, backbone_model)

    # setup logger and start the resource profiler
    setup_logger(output_dir)
    tlogger = RSRLogger(output_dir)
    print('Data Configuration:')
    print(cfg.DATALOADER)
    RP(useGPU=True, filename=output_dir, interval=1)
    # Close the logger and stop the profiler even if training raises.
    try:
        # setup seed
        set_random_seed(seed=cfg.SEED)
        # schema
        fit_schema = get_fit_schema("repair")

        # dataloader
        train, val, test, _users = create_dataloader(
            data_name=cfg.DATALOADER.DATA_NAME, resize=cfg.DATALOADER.RESIZE,
            n_shot=cfg.DATALOADER.NUM_SHOTS,
            trans_method=cfg.DATALOADER.TRANS_METHOD,
            mode=cfg.DATALOADER.MODE, num_workers=cfg.DATALOADER.NUM_WORKERS,
            seed=cfg.DATALOADER.SEED, return_validation=True,
            # Same episode index that is encoded in output_dir.
            epo_idx=cfg.DATALOADER.EPO_ID,
            regression=cfg.xtransfer.REGRESSION, sbp=cfg.xtransfer.SBP, users=None)

        # train
        mNet = MatchingNet(cfg, trainloader=train, testloader=test, valloader=val,
                           fit_schema=fit_schema, logger=tlogger)
        mNet.fit()

        # finetune a classifier on the first validation batch
        if cfg.xtransfer.FINETUNE:
            torch.use_deterministic_algorithms(False)
            tr_x, tr_y = next(iter(val))
            # NOTE(review): assumes the finetuner should adapt the fitted backbone -- confirm.
            finetuner = Finetuner(x=tr_x, y=tr_y, model=mNet.backbone, step=100, backnet=mNet,
                                  logger=tlogger, regression=cfg.xtransfer.REGRESSION,
                                  is_1d=cfg.xtransfer.Conv1D)
            finetuner.optimize_params()
            linear = finetuner.get_linear()
            torch.save(linear, os.path.join(output_dir, 'classifier.pt'))
    finally:
        # close loggers
        tlogger.close()
        stop_RP()

    # save model: eval mode, ReLUs swapped by replace_relu, whole module saved
    mNet.backbone.eval()
    model = mNet.backbone
    replace_relu(model)
    torch.save(model, os.path.join(output_dir, 'ft_model.pt'))


if __name__ == "__main__":
    main()