import os
import sys
import torch
# from model.net.wideresnet import Wide_ResNet
from data.dl_getter import normalization_infos
from model.models import *
from model.model_io import load_model
from model.models.pc import PC, AE
from model.models.pc_dense import PCD
from model.models.pc_lbyl import PCL
from model.models.classifier import Classifier


def get_model(args):
    """Build the model described by ``args`` and move it to ``args.device``.

    When ``args.eval`` is truthy, ``load_model(args, args.load_path, model=...)``
    is called on the freshly built model before the device transfer.
    Presumably it restores weights in place; its return value is not used.

    Args:
        args: Run configuration. Reads ``eval``, ``load_path`` and ``device``
            here, plus whatever ``_init_model`` reads.

    Returns:
        The model after ``.to(args.device)``.
    """
    net = _init_model(args)
    if args.eval:
        load_model(args, args.load_path, model=net)
    # Multi-GPU wrapping (torch.nn.DataParallel) is intentionally disabled;
    # the model is placed on the single configured device.
    return net.to(args.device)


def _init_model(args):
    """Construct the model selected by ``args``.

    Model-family flags are checked in priority order ``args.ae``,
    ``args.pcd``, ``args.pcl``, ``args.pc``. The first truthy one wins, and
    its class is constructed from ``args``. If none is set, ``args.arch``
    selects the architecture.

    Note: ``args`` is mutated in place. Every ``"deit"`` substring in
    ``args.arch`` is rewritten to ``"vit"`` before dispatch.

    Args:
        args: Run configuration (presumably an ``argparse.Namespace``; only
            attribute access is used).

    Returns:
        The newly constructed model.

    Raises:
        NotImplementedError: if ``args.arch == 'resnet'`` (not supported yet).
        SystemExit: with status 1 for any other unrecognised architecture.
    """
    # ============ building network ... ============
    # Presumably DeiT variants share the ViT code path.
    args.arch = args.arch.replace("deit", "vit")

    if args.ae:
        model = AE(args)
    elif args.pcd:
        model = PCD(args)
    elif args.pcl:
        model = PCL(args)
    elif args.pc:
        model = PC(args)
    elif args.arch == 'fc':
        model = Classifier(args)
    elif args.arch == 'resnet':
        # `NotImplemented` is a sentinel value, not an exception; raising it
        # produces an unrelated TypeError. NotImplementedError is the
        # proper exception.
        raise NotImplementedError(f"Architecture '{args.arch}' is not implemented yet")
    else:
        print(f"Unknown architecture: {args.arch}")
        sys.exit(1)
    return model


class NormalizeInput(torch.nn.Module):  # TODO: move this into the data transforms.
    """Per-channel mean/std normalization, expressed for inputs mapped by ``2*x - 1``.

    The stored statistics are ``r_mean = 2 * mean - 1`` and ``r_std = 2 * std``.
    For an input ``x = 2 * x01 - 1``, ``forward(x)`` therefore equals
    ``(x01 - mean) / std``. That is, it applies the usual normalization to the
    un-shifted values. (Presumably the pipeline feeds inputs already in
    [-1, 1]; confirm at the call site.)

    Statistics are registered as non-trainable buffers of shape
    ``(1, C, 1, 1)``, so they broadcast along dim 1 of a 4-D input and move
    with the module on ``.to(...)``.

    Args:
        mean: Per-channel means (sequence or tensor). The defaults look like
            the commonly used CIFAR-10 statistics; confirm.
        std: Per-channel standard deviations (sequence or tensor).
    """

    def __init__(self, mean=(0.4914, 0.4822, 0.4465),
                 std=(0.2023, 0.1994, 0.2010)):
        super().__init__()
        # Convert to tensors *before* the arithmetic. On plain tuples,
        # `2 * mean` repeats the tuple and `- 1.` raises TypeError.
        dtype = torch.get_default_dtype()
        mean_t = torch.as_tensor(mean, dtype=dtype)
        std_t = torch.as_tensor(std, dtype=dtype)
        r_mean = 2 * mean_t - 1.
        r_std = 2 * std_t
        self.register_buffer('r_mean', r_mean.reshape(1, -1, 1, 1))
        self.register_buffer('r_std', r_std.reshape(1, -1, 1, 1))

    def forward(self, x):
        """Return ``(x - r_mean) / r_std`` with per-channel broadcasting."""
        return (x - self.r_mean) / self.r_std


class ReNormalizeInput(torch.nn.Module):
    """Apply the fixed per-channel affine map ``(x - 0.5) / 0.5`` (i.e. ``2*x - 1``).

    This sends values in [0, 1] to [-1, 1]. The constants live in
    non-trainable buffers ``r_mean`` and ``r_std`` of shape ``(1, 3, 1, 1)``.
    They broadcast along dim 1 of a 4-D input and follow the module on
    ``.to(...)``.
    """

    def __init__(self):
        super().__init__()
        half = (.5, .5, .5)
        # Both statistics are identical 0.5 vectors, registered in the order
        # r_mean, r_std.
        for buf_name in ('r_mean', 'r_std'):
            self.register_buffer(buf_name, torch.tensor(half).view(1, -1, 1, 1))

    def forward(self, x):
        """Return ``(x - r_mean) / r_std`` with per-channel broadcasting."""
        centered = x - self.r_mean
        return centered / self.r_std
