import torch
import os
from scipy.stats import truncnorm
import numpy as np
import random


def get_femb_noise(noise_batch, noise_channel, noise_height, noise_width):
    """Return the cached noise tensor for the given shape, creating it if absent.

    The tensor is stored as ``noise.pt`` inside
    ``../data/noise/<batch>-<channel>-<height>-<width>``; the path is relative
    to the current working directory.

    Args:
        noise_batch: Size of the first (batch) dimension.
        noise_channel: Size of the second (channel) dimension.
        noise_height: Size of the third (height) dimension.
        noise_width: Size of the fourth (width) dimension.

    Returns:
        The tensor loaded from disk (mapped to CPU) when the file exists,
        otherwise a freshly sampled float32 tensor of the requested shape,
        which is also written to disk before being returned.
    """
    noise_path = f'../data/noise/{noise_batch}-{noise_channel}-{noise_height}-{noise_width}'
    noise_file = os.path.join(noise_path, 'noise.pt')

    # Reuse a previously generated tensor so every caller sees the same noise.
    if os.path.exists(noise_file):
        return torch.load(noise_file, map_location=torch.device('cpu'))

    # Truncated normal centred on 125 with std 125, clipped to [0, 255];
    # truncnorm expects the bounds expressed in standard deviations from loc.
    mu, std, lower, upper = 125, 125, 0, 255
    clip_lo = (lower - mu) / std
    clip_hi = (upper - mu) / std
    shape = (noise_batch, noise_channel, noise_height, noise_width)
    samples = truncnorm(clip_lo, clip_hi, loc=mu, scale=std).rvs(shape)

    # scale() converts to float32 and divides by 255.
    noise = torch.from_numpy(scale(samples))

    os.makedirs(noise_path, exist_ok=True)
    torch.save(noise, noise_file)
    return noise


def scale(x):
    """Cast ``x`` to float32 and divide it by 255.

    Args:
        x: Object exposing ``astype`` (e.g. a NumPy array).

    Returns:
        The float32 copy of ``x`` divided element-wise by 255.
    """
    as_float32 = x.astype(np.float32)
    return as_float32 / 255.


def func_embedding(teacher):
    """Embed ``teacher`` by pooling one of its feature maps on fixed noise.

    The teacher is moved to CUDA, switched to eval mode and run once on the
    cached ``(1, 3, 64, 64)`` noise tensor. The entry at index 4 of the
    returned features is average-pooled to 1x1 and flattened.

    Args:
        teacher: Model callable as ``teacher(x, get_features=True)`` that
            returns a ``(features, soft_targets)`` pair.

    Returns:
        A detached tensor reshaped to ``(1, -1)``.
    """
    probe = get_femb_noise(1, 3, 64, 64).cuda()
    teacher.cuda()
    teacher.eval()
    pool = torch.nn.AdaptiveAvgPool2d((1, 1))
    with torch.no_grad():
        feats, _ = teacher(probe, get_features=True)
        pooled = pool(feats[4])
        embedding = pooled.view(1, -1)
    return embedding.detach()


def func_embedding1(teacher):
    """Embed ``teacher`` with one pooled vector per feature level.

    Same procedure as a single-level embedding, but the entries at indices
    0 through 4 of the returned features are each average-pooled to 1x1 and
    flattened.

    Args:
        teacher: Model callable as ``teacher(x, get_features=True)`` that
            returns a ``(features, soft_targets)`` pair; ``features`` must be
            indexable with 0..4.

    Returns:
        A list of five detached tensors, each reshaped to ``(1, -1)``.
    """
    probe = get_femb_noise(1, 3, 64, 64).cuda()
    teacher.cuda()
    teacher.eval()
    pool = torch.nn.AdaptiveAvgPool2d((1, 1))
    with torch.no_grad():
        feats, _ = teacher(probe, get_features=True)
        return [
            pool(feats[level]).view(1, -1).detach()
            for level in range(5)
        ]


def func_embedding2(teacher):
    """Embed ``teacher`` with one pooled vector per noise sample.

    The teacher is moved to CUDA and run once on the cached
    ``(3, 3, 64, 64)`` noise tensor. The entry at index 4 of the returned
    features is average-pooled to 1x1 and then split along its first
    dimension.

    NOTE(review): unlike ``func_embedding``, this function does not call
    ``teacher.eval()`` -- confirm whether that is intentional.

    Args:
        teacher: Model callable as ``teacher(x, get_features=True)`` that
            returns a ``(features, soft_targets)`` pair.

    Returns:
        A list of tensors, one per slice along the first dimension of the
        pooled feature map, each reshaped to ``(1, -1)``.
    """
    probe = get_femb_noise(3, 3, 64, 64).cuda()
    teacher.cuda()
    pool = torch.nn.AdaptiveAvgPool2d((1, 1))
    with torch.no_grad():
        feats, _ = teacher(probe, get_features=True)
        # presumably (3, C, 1, 1), one row per noise sample -- TODO confirm
        pooled = pool(feats[4])
        return [sample.view(1, -1) for sample in pooled]

def func_embedding3(teacher, inp):
    """Embed ``teacher`` by averaging a pooled feature map over a real batch.

    The teacher is moved to CUDA and run once on ``inp``. The entry at index
    4 of the returned features is average-pooled to 1x1, averaged over its
    first dimension and flattened.

    NOTE(review): ``inp`` is used as given (it is not moved to any device
    here) and ``teacher.eval()`` is not called -- confirm both against the
    callers.

    Args:
        teacher: Model callable as ``teacher(x, get_features=True)`` that
            returns a ``(features, soft_targets)`` pair.
        inp: Input passed straight to ``teacher``.

    Returns:
        A detached tensor reshaped to ``(1, -1)``.
    """
    avgpool = torch.nn.AdaptiveAvgPool2d((1, 1))
    teacher.cuda()
    with torch.no_grad():
        features, _ = teacher(inp, get_features=True)
        pooled = avgpool(features[4])
        # The previous out-of-place ``squeeze()`` call discarded its result
        # and the leftover ``breakpoint()`` halted every call in the
        # debugger; both are removed. ``view(1, -1)`` does the flattening.
        embedding = torch.mean(pooled, dim=0).view(1, -1)
    return embedding.detach()


class ResNetArchEncoder:
    """Encode a ResNet architecture description as a 0/1 feature vector.

    The vector has one slot per candidate image size, followed by
    ``max_n_blocks`` groups that each hold one slot per candidate width
    multiplier. ``r_info`` and ``width_mult_info`` store the value<->index
    lookups together with the ``L`` (first index) and ``R`` (one past the
    last index) bounds of every group.
    """

    def __init__(self, args, image_size_list=None, depth_list=None, width_mult_list=None):
        """Set up the candidate lists and allocate the feature layout.

        Args:
            args: Object providing ``tc_stage_depth`` and ``tc_stage_num``.
                ``tc_stage_depth`` is stored but not read inside this class.
            image_size_list: Candidate image sizes; defaults to ``[64]``.
            depth_list: Candidate per-stage depths; defaults to ``[1..5]``.
            width_mult_list: Candidate width multipliers; defaults to
                eighths from 0.125 to 1.
        """
        if image_size_list is None:
            image_size_list = [64]
        if depth_list is None:
            depth_list = [1, 2, 3, 4, 5]
        if width_mult_list is None:
            width_mult_list = [0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.]

        self.image_size_list = image_size_list
        self.depth_list = depth_list
        self.width_mult_list = width_mult_list

        self.tc_stage_depth = args.tc_stage_depth
        self.tc_stage_num = args.tc_stage_num

        # Build the info dicts; n_dim grows as feature slots are handed out.
        self.n_dim = 0
        # resolution: flat value<->index dicts
        self.r_info = {'id2val': {}, 'val2id': {}, 'L': [], 'R': []}
        self._build_info_dict(target='r')
        # width multiplier: one value<->index dict per block position
        self.width_mult_info = {'id2val': [], 'val2id': [], 'L': [], 'R': []}
        self._build_info_dict(target='width_mult')

    @property
    def n_stage(self):
        """Number of stages, taken from ``args.tc_stage_num``."""
        return self.tc_stage_num

    @property
    def max_n_blocks(self):
        """Number of block positions: ``n_stage`` times the largest depth."""
        deepest = max(self.depth_list)
        return self.n_stage * deepest

    def _build_info_dict(self, target):
        """Allocate feature indices for ``target`` and advance ``n_dim``.

        Args:
            target: ``'r'`` to register the image sizes, ``'width_mult'`` to
                register the width multipliers of every block position. Any
                other value does nothing.
        """
        if target == 'r':
            info = self.r_info
            info['L'].append(self.n_dim)
            for size in self.image_size_list:
                slot = self.n_dim
                info['val2id'][size] = slot
                info['id2val'][slot] = size
                self.n_dim = slot + 1
            info['R'].append(self.n_dim)
            return

        if target != 'width_mult':
            return

        info = self.width_mult_info
        candidates = self.width_mult_list
        for block in range(self.max_n_blocks):
            info['val2id'].append({})
            info['id2val'].append({})
            info['L'].append(self.n_dim)
            # Look the dicts up by block index rather than taking the ones
            # just appended, so the indexing semantics stay unchanged.
            to_id = info['val2id'][block]
            to_val = info['id2val'][block]
            for mult in candidates:
                slot = self.n_dim
                to_id[mult] = slot
                to_val[slot] = mult
                self.n_dim = slot + 1
            info['R'].append(self.n_dim)

    def arch2feature(self, arch_dict):
        """Turn an architecture dict into its feature vector.

        Args:
            arch_dict: Mapping with keys ``'depth_list'`` (depth per stage),
                ``'width_mult_list'`` (per stage, the width multiplier of
                each block) and ``'image_size'``.

        Returns:
            A NumPy array of length ``n_dim`` holding 1 at the image-size
            slot and at the chosen width-multiplier slot of every active
            block, 0 elsewhere.
        """
        depths = arch_dict['depth_list']
        widths = arch_dict['width_mult_list']
        resolution = arch_dict['image_size']

        encoding = np.zeros(self.n_dim)
        encoding[self.r_info['val2id'][resolution]] = 1

        width_ids = self.width_mult_info['val2id']
        stage_base = 0
        for stage, depth in enumerate(depths):
            for offset in range(depth):
                chosen = widths[stage][offset]
                encoding[width_ids[stage_base + offset][chosen]] = 1
            # Every stage reserves max(depth_list) block positions, whether
            # or not this architecture uses all of them.
            stage_base += max(self.depth_list)

        return encoding



    
