import os
from selectors import EpollSelector
from types import new_class
import torch
from torch import nn
import torchvision
from embedding import func_embedding, func_embedding1, func_embedding2, func_embedding3
from embedding import ResNetArchEncoder
from embedding import get_femb_noise
from myutils import l2norm
from set_encoder.setenc_models import SetPool
from get_dataset import NUM_CLASSES
from collections import OrderedDict

from networks.modules import MetaLinear, MetaSequential, MetaModule
from networks.resnet_lit import get_resnet
from pr_resnet import PR

## Transformer Type Setencoder
class SetEncoder(MetaModule):
  """Dataset encoder that pools pre-extracted per-class feature sets.

  On every forward pass a ``*_bylabel.pt`` file is loaded from ``dpath``,
  ``num_sample`` rows are drawn at random for each class, each class is
  pooled by ``intra_setpool`` and the resulting class prototypes are pooled
  again by ``inter_setpool`` into a single dataset encoding.

  Args:
    nz: Hidden and output width of both set-pooling modules.
    num_sample: Number of rows drawn per class on each forward pass.
    use_l2norm: If True, the encoding is passed through ``l2norm``.
    dpath: Directory containing the ``*_bylabel.pt`` files. When omitted,
      a module-level ``DATAPATH`` is used if one exists.

  Raises:
    NameError: If ``dpath`` is omitted and no module-level ``DATAPATH``
      is defined.
  """

  def __init__(self, nz, num_sample, use_l2norm=True, dpath=None):
    super(SetEncoder, self).__init__()
    self.nz = nz
    self.num_sample = num_sample
    if dpath is None:
      # DATAPATH is neither defined nor imported in this file, so only
      # honour it when something has placed it in the module namespace.
      try:
        dpath = globals()['DATAPATH']
      except KeyError:
        raise NameError(
            "SetEncoder needs a data directory: pass `dpath=` or define "
            "a module-level DATAPATH") from None
    self.dpath = dpath
    # Pools the sampled rows of one class into a single vector.
    self.intra_setpool = SetPool(dim_input=512,
                                 num_outputs=1,
                                 dim_output=self.nz,
                                 dim_hidden=self.nz,
                                 mode='sabPF')
    # Pools the per-class vectors into one dataset-level vector.
    self.inter_setpool = SetPool(dim_input=self.nz,
                                 num_outputs=1,
                                 dim_output=self.nz,
                                 dim_hidden=self.nz,
                                 mode='sabPF')
    self.use_l2norm = use_l2norm

  def get_setdata(self, D):
    """Load the per-class data for ``D`` and sample rows from every class.

    Args:
      D: Mapping with keys ``'ds_name'`` and ``'ds_split'``.

    Returns:
      Tensor stacking, for each class, ``num_sample`` randomly chosen rows.
    """
    ds_name, ds_split = D['ds_name'], D['ds_split']
    if ds_name == 'tiny_imagenet' and ds_split is not None:
      # Split variants of tiny_imagenet are treated as 40-class datasets.
      ncls = 40
      fname = f'{ds_name}_{ds_split}_bylabel.pt'
    else:
      ncls = NUM_CLASSES[ds_name]
      fname = f'{ds_name}_bylabel.pt'
    by_label = torch.load(os.path.join(self.dpath, fname))

    samples = []
    for cls in range(ncls):
      # presumably by_label[cls] is a (features, ...) pair -- TODO confirm
      cx = by_label[cls][0]
      ridx = torch.randperm(len(cx))
      # NOTE(review): a class with fewer than num_sample rows yields a
      # shorter slice here and makes torch.stack below fail.
      samples.append(cx[ridx[:self.num_sample]])
    return torch.stack(samples)

  def forward(self, D, n, params=None):
    """Encode dataset ``D`` and repeat the encoding ``n`` times.

    Args:
      D: Mapping with keys ``'ds_name'`` and ``'ds_split'``.
      n: Number of rows in the returned tensor.
      params: Optional parameter dict forwarded to the set-pooling modules.

    Returns:
      Tensor with ``n`` identical rows, passed through ``l2norm`` when
      ``use_l2norm`` is set.
    """
    X = self.get_setdata(D).cuda()
    cls_protos = self.intra_setpool(
        X.view(-1, self.num_sample, 512),
        params=self.get_subdict(params, 'intra_setpool')).squeeze(1)
    set_enc = self.inter_setpool(
        cls_protos.unsqueeze(0),
        params=self.get_subdict(params, 'inter_setpool')).view(1, -1)
    set_enc = set_enc.repeat(n, 1)

    if self.use_l2norm:
      return l2norm(set_enc)
    return set_enc


class SetEncoder1(MetaModule):
  """Dataset encoder built on torchvision ResNet-18 features.

  Images are passed through a pretrained ResNet-18 with its last child
  module removed, projected by ``fc_set`` and averaged over the batch.

  Args:
    nz: Output width of ``fc_set``.
    use_l2norm: If True, the encoding is passed through ``l2norm``.
  """

  def __init__(self, nz, use_l2norm=True):
    super(SetEncoder1, self).__init__()
    self.nz = nz
    self.use_l2norm = use_l2norm
    backbone = torchvision.models.resnet18(pretrained=True).eval()
    # Keep every child of the backbone except the last one.
    trunk = list(backbone.children())[:-1]
    self.feature_extractor = torch.nn.Sequential(*trunk).cuda()
    self.fc_set = MetaSequential(
        MetaLinear(512, 256),
        nn.Tanh(),
        MetaLinear(256, self.nz),
    )

  def get_query_embs(self, imgs):
    """Return backbone features for ``imgs`` as a 2-D tensor, without grad.

    The original comment gives the input shape as [1, 3, 64, 64].
    """
    with torch.no_grad():
      feats = self.feature_extractor(imgs.cuda())
    batch, dim, _, _ = feats.size()
    return feats.view(batch, dim)

  def forward(self, D, n, params=None):
    """Encode the images in ``D['ds_imgs']`` and repeat the result ``n`` times.

    Args:
      D: Mapping with key ``'ds_imgs'`` holding an image tensor.
      n: Number of rows in the returned tensor.
      params: Optional parameter dict forwarded to ``fc_set``.

    Returns:
      Tensor with ``n`` identical rows, passed through ``l2norm`` when
      ``use_l2norm`` is set.
    """
    imgs = D['ds_imgs'].cuda()
    # NOTE(review): unlike get_query_embs, the backbone runs with autograd
    # enabled here -- confirm that is intended.
    feats = self.feature_extractor(imgs).view(-1, 512)
    per_image = self.fc_set(feats, params=self.get_subdict(params, 'fc_set'))
    # Average over the images, then tile the single row n times.
    set_enc = torch.mean(per_image, dim=0).view(1, -1).repeat(n, 1)

    if self.use_l2norm:
      return l2norm(set_enc)
    return set_enc
  

class FuncEncoder(MetaModule):
  """Two-layer MLP over the embedding returned by ``func_embedding``.

  Args:
    inp_dim: Input width of the MLP.
    out_dim: Output width of the MLP.
    h_dim: Hidden width of the MLP.
    use_l2norm: If True, the output is passed through ``l2norm``.
  """

  def __init__(self, inp_dim, out_dim, h_dim=128, use_l2norm=True):
    super(FuncEncoder, self).__init__()
    self.inp_dim = inp_dim
    self.out_dim = out_dim
    self.use_l2norm = use_l2norm
    self.fc = MetaSequential(
        MetaLinear(self.inp_dim, h_dim),
        nn.Tanh(),
        MetaLinear(h_dim, self.out_dim),
    )

  def forward(self, F, D, n, params=None):
    """Encode the network ``F`` and return ``n`` rows.

    Args:
      F: The network to embed (the original comment calls it the "tc net").
      D: Unused; kept so all function encoders share one call signature.
      n: Number of times the embedding is repeated along dim 0.
      params: Optional parameter dict forwarded to ``fc``.
    """
    # The original comment gives the embedding shape as (1, 256).
    femb = func_embedding(F).repeat(n, 1)
    encoded = self.fc(femb, params=self.get_subdict(params, 'fc'))
    return l2norm(encoded) if self.use_l2norm else encoded

class FuncEncoder1(MetaModule):
  """Function encoder with one linear projection per ``func_embedding1`` entry.

  Each of the five entries returned by ``func_embedding1`` is projected to
  a 32-wide vector; the projections are concatenated and fed to an MLP.

  Args:
    inp_dim: Unused; kept for signature parity with the other encoders.
    out_dim: Output width of the final MLP.
    h_dim: Hidden width of the final MLP.
    use_l2norm: If True, the output is passed through ``l2norm``.
  """

  # Projection layers, listed in the order their inputs appear in the
  # sequence returned by func_embedding1.
  _LAYER_NAMES = ('fc_stem', 'fc_stg0', 'fc_stg1', 'fc_stg2', 'fc_stg3')

  def __init__(self, inp_dim, out_dim, h_dim=128, use_l2norm=True):
    super(FuncEncoder1, self).__init__()
    femb_dim = 32
    self.use_l2norm = use_l2norm
    self.tc_stage_channel_widths = [32, 64, 128, 256]
    widths = self.tc_stage_channel_widths
    # The stem projection uses the same input width as the first stage.
    for name, width in zip(self._LAYER_NAMES, [widths[0]] + widths):
      setattr(self, name, MetaLinear(width, femb_dim))
    self.fc = MetaSequential(
        MetaLinear(femb_dim * len(self._LAYER_NAMES), h_dim),
        nn.Tanh(),
        MetaLinear(h_dim, out_dim),
    )

  def forward(self, F, D, n, params=None):
    """Encode the network ``F`` and return ``n`` rows.

    Args:
      F: The network to embed (the original comment calls it the "tc net").
      D: Unused.
      n: Number of times the encoding is repeated along dim 0.
      params: Optional parameter dict forwarded to every sub-layer.
    """
    femb = func_embedding1(F)
    projected = []
    for idx, name in enumerate(self._LAYER_NAMES):
      layer = getattr(self, name)
      projected.append(layer(femb[idx], params=self.get_subdict(params, name)))
    merged = torch.cat(projected, dim=1)
    encoded = self.fc(merged, params=self.get_subdict(params, 'fc'))
    encoded = encoded.repeat(n, 1)
    return l2norm(encoded) if self.use_l2norm else encoded


class FuncEncoder2(MetaModule):
  """Function encoder over the three entries returned by ``func_embedding2``.

  Each entry gets its own linear projection; the three projections are
  concatenated along dim 1 and fed to an MLP.

  Args:
    inp_dim: Stored but not used to size any layer (inputs are 256-wide).
    out_dim: Output width of the final MLP.
    h_dim: Width of each projection and of the MLP hidden layer.
    use_l2norm: If True, the output is passed through ``l2norm``.
  """

  # Projection layers, in the order of the entries they consume.
  _BRANCH_NAMES = ('fc1', 'fc2', 'fc3')

  def __init__(self, inp_dim, out_dim, h_dim=128, use_l2norm=True):
    super(FuncEncoder2, self).__init__()
    femb_dim = 256
    self.inp_dim = inp_dim
    self.out_dim = out_dim
    self.use_l2norm = use_l2norm
    for name in self._BRANCH_NAMES:
      setattr(self, name, MetaLinear(femb_dim, h_dim))

    self.fc = MetaSequential(
        MetaLinear(len(self._BRANCH_NAMES) * h_dim, h_dim),
        nn.Tanh(),
        MetaLinear(h_dim, self.out_dim),
    )

  def forward(self, F, D, n, params=None):
    """Encode the network ``F`` and return ``n`` rows.

    Args:
      F: The network to embed (the original comment calls it the "tc net").
      D: Unused.
      n: Number of times the encoding is repeated along dim 0.
      params: Optional parameter dict forwarded to every sub-layer.
    """
    # The original comment gives the embedding shape as (3, 256).
    femb = func_embedding2(F)
    branches = [
        getattr(self, name)(femb[idx], params=self.get_subdict(params, name))
        for idx, name in enumerate(self._BRANCH_NAMES)
    ]
    merged = torch.cat(branches, dim=1)
    encoded = self.fc(merged, params=self.get_subdict(params, 'fc'))
    encoded = encoded.repeat(n, 1)
    return l2norm(encoded) if self.use_l2norm else encoded


class FuncEncoder3(MetaModule):
  """Function encoder over average-pooled features of the network itself.

  The network is run on ``D['ds_imgs']``; entry ``[0][4]`` of its
  ``get_features=True`` output is spatially average-pooled, averaged over
  the batch and fed to an MLP.

  Args:
    inp_dim: Ignored; the MLP input width is fixed at 256.
    out_dim: Output width of the MLP.
    h_dim: Hidden width of the MLP.
    use_l2norm: If True, the output is passed through ``l2norm``.
  """

  def __init__(self, inp_dim, out_dim, h_dim=128, use_l2norm=True):
    super(FuncEncoder3, self).__init__()
    self.inp_dim = 256
    self.out_dim = out_dim
    self.use_l2norm = use_l2norm
    self.fc = MetaSequential(
        MetaLinear(self.inp_dim, h_dim),
        nn.Tanh(),
        MetaLinear(h_dim, self.out_dim),
    )

  def forward(self, F, D, n, params=None):
    """Encode the network ``F`` from its response to ``D['ds_imgs']``.

    Args:
      F: The network to embed (the original comment calls it the "tc net");
        it is moved to CUDA and called with ``get_features=True``.
      D: Mapping with key ``'ds_imgs'`` holding an image tensor.
      n: Number of times the pooled feature row is repeated.
      params: Optional parameter dict forwarded to ``fc``.
    """
    # Feature extraction is gradient-free; only the MLP below is trained.
    with torch.no_grad():
      imgs = D['ds_imgs'].cuda()
      pool = torch.nn.AdaptiveAvgPool2d((1, 1))
      stage_out = F.cuda()(imgs, get_features=True)[0][4]
      pooled = pool(stage_out)
      # Average over the batch and keep a single row.
      features = torch.mean(pooled, dim=0).view(1, -1)
    tiled = features.repeat(n, 1)
    encoded = self.fc(tiled, params=self.get_subdict(params, 'fc'))
    return l2norm(encoded) if self.use_l2norm else encoded


class FuncEncoder4(MetaModule):
  """Function encoder that uses an attention map as the functional embedding.

  Args:
    args: Namespace providing ``use_noise`` and ``pr_type``.
    f_inp_dim: Ignored; the MLP input width is fixed at 64.
    f_out_dim: Output width of the MLP.
    h_dim: Hidden width of the MLP.
    use_l2norm: If True, the output is passed through ``l2norm``.
  """

  def __init__(self, args, f_inp_dim, f_out_dim, h_dim=64, use_l2norm=True):
    super(FuncEncoder4, self).__init__()
    self.args = args
    self.f_inp_dim = 64
    self.f_out_dim = f_out_dim
    self.use_noise = args.use_noise
    self.pr_type = args.pr_type
    self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    self.use_l2norm = use_l2norm
    self.fc = MetaSequential(
        MetaLinear(self.f_inp_dim, h_dim),
        nn.Tanh(),
        MetaLinear(h_dim, self.f_out_dim),
    )

  def _get_attn_map(self, model, inp):
    """Return a single-row attention map of ``model`` evaluated on ``inp``.

    ``model`` is moved to ``self.device`` and switched to eval mode; its
    mode is not restored afterwards.
    """
    model = model.to(self.device)
    inp = inp.to(self.device)
    model.eval()
    with torch.no_grad():
      features, _ = model(inp, get_features=True)
      # Average over the batch, then sum over the leading remaining dim.
      # The original comment gives features[4] as (1, 256, 8, 8).
      attn = torch.mean(features[4], dim=0).squeeze()
      attn = torch.sum(attn, dim=0).squeeze()
      # NOTE(review): nn.Flatten() keeps dim 0, so a 2-D map is unchanged
      # here and l2norm sees the unflattened map -- confirm intended.
      attn = l2norm(nn.Flatten()(attn))
      attn = attn.view(1, -1)
    return attn.detach()

  def forward(self, F, D, n, params=None):
    """Encode the network ``F`` from its attention map.

    Args:
      F: The network to embed.
      D: Mapping with key ``'ds_imgs'``; replaced by noise from
        ``get_femb_noise`` when ``use_noise`` is set.
      n: Number of times the attention-map row is repeated.
      params: Optional parameter dict forwarded to ``fc``.
    """
    probe = D['ds_imgs']
    if self.use_noise:
      probe = get_femb_noise(1, 3, 64, 64).cuda()
    attn_row = self._get_attn_map(F, probe)
    tiled = attn_row.repeat(n, 1)
    encoded = self.fc(tiled, params=self.get_subdict(params, 'fc'))
    return l2norm(encoded) if self.use_l2norm else encoded


class ArchEncoder(MetaModule):
  """Linear encoder over architecture feature vectors.

  Each architecture is described by a depth configuration and per-stage
  channel widths; widths are converted to multipliers relative to
  ``tc_stage_channel_widths`` and turned into a feature vector by
  ``ResNetArchEncoder.arch2feature``.

  Args:
    inp_dim: Input width of the linear layer.
    out_dim: Output width of the linear layer.
    args: Namespace forwarded to ``ResNetArchEncoder``.
    h_dim: Unused; kept for signature compatibility.
    imsz: Image size recorded in every architecture description.
    tc_stage_channel_widths: Reference channel width per stage. Defaults
      to ``[32, 64, 128, 256]``.
    use_l2norm: If True, the output is passed through ``l2norm``.
  """

  def __init__(self, inp_dim, out_dim, args, h_dim=64,
               imsz=64, tc_stage_channel_widths=None,
               use_l2norm=True):
    super(ArchEncoder, self).__init__()
    self.inp_dim = inp_dim
    self.out_dim = out_dim
    self.fc = MetaLinear(self.inp_dim, self.out_dim)
    self.imsz = imsz
    # A None sentinel gives each instance its own default list instead of
    # one list object shared by every instance.
    if tc_stage_channel_widths is None:
      tc_stage_channel_widths = [32, 64, 128, 256]
    self.tc_stage_channel_widths = tc_stage_channel_widths
    self.resnet_encoder = ResNetArchEncoder(args)
    self.use_l2norm = use_l2norm

  def forward(self, A, params=None):
    """Encode every architecture listed in ``A``.

    Args:
      A: Mapping with parallel sequences ``'depth_config'`` and
        ``'channel_widths'`` (the original comment calls it st_net_info).
      params: Optional parameter dict forwarded to ``fc``.

    Returns:
      One encoded row per architecture, passed through ``l2norm`` when
      ``use_l2norm`` is set.
    """
    arch_onehots = []
    for depth_config, channel_widths in zip(A['depth_config'], A['channel_widths']):
      # Express each width as a multiple of the stage's reference width.
      width_mult = [
          [float(w / tc_width) for w in width]
          for tc_width, width in zip(self.tc_stage_channel_widths, channel_widths)
      ]
      arch_dict = {'depth_list': depth_config,
                   'width_mult_list': width_mult,
                   'image_size': self.imsz}
      arch_onehot = torch.from_numpy(self.resnet_encoder.arch2feature(arch_dict)).cuda()
      arch_onehots.append(arch_onehot.type(torch.float32).detach())
    arch_onehots = torch.stack(arch_onehots)
    out = self.fc(arch_onehots, params=self.get_subdict(params, 'fc'))
    if self.use_l2norm:
      return l2norm(out)
    return out


class ModelEncoder(MetaModule):
  """Joint encoder of a model's architecture and its functional behaviour.

  The architecture encoding comes from ``ResNetArchEncoder.arch2feature``;
  the functional encoding is either an attention map or average-pooled
  features of the model evaluated on a probe input. Both are projected to
  ``h_dim``, concatenated and projected to ``m_inp_dim``.

  Args:
    args: Namespace providing ``image_size`` (and ``use_abs`` when
      attention maps are used); also forwarded to ``ResNetArchEncoder``
      and ``PR``.
    a_inp_dim: Width of the architecture feature vector.
    f_inp_dim: Ignored; the functional width is 64 with attention maps and
      256 otherwise.
    m_inp_dim: Output width of the encoder.
    pr_type: Forwarded to ``PR`` when building student models.
    use_noise: If True, probe models with noise from ``get_femb_noise``
      instead of ``D['ds_imgs']``.
    use_attnmap: If True, use attention maps; otherwise average-pool the
      last listed feature map.
    h_dim: Width of each of the two intermediate projections.
    use_l2norm: If True, ``l2norm`` is applied before and after ``fc``.
  """

  def __init__(self, args, a_inp_dim, f_inp_dim, m_inp_dim, pr_type,
               use_noise, use_attnmap, h_dim=64, use_l2norm=True):
    super(ModelEncoder, self).__init__()
    self.args = args
    self.a_inp_dim = a_inp_dim
    self.m_inp_dim = m_inp_dim
    self.use_l2norm = use_l2norm
    self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    self.pr_type = pr_type
    self.imsz = args.image_size
    self.tc_stage_channel_widths = [32, 64, 128, 256]
    self.resnet_encoder = ResNetArchEncoder(args)
    self.use_noise = use_noise
    self.use_attnmap = use_attnmap
    # The functional input width is dictated by the embedding type, not by
    # the f_inp_dim argument.
    self.f_inp_dim = 64 if self.use_attnmap else 256
    self.fc_a = MetaLinear(self.a_inp_dim, h_dim)
    self.fc_f = MetaLinear(self.f_inp_dim, h_dim)
    self.fc = MetaLinear(int(2 * h_dim), self.m_inp_dim)

  def _get_arch_encoding(self, depth_config, channel_widths):
    """Return the float32 architecture feature vector on ``self.device``."""
    # Express each width as a multiple of the stage's reference width.
    width_mult = [
        [float(w / tc_width) for w in width]
        for tc_width, width in zip(self.tc_stage_channel_widths, channel_widths)
    ]
    arch_dict = {'depth_list': depth_config,
                 'width_mult_list': width_mult,
                 'image_size': self.imsz}
    arch_encoding = torch.from_numpy(self.resnet_encoder.arch2feature(arch_dict))
    # Use self.device rather than .cuda() so the CPU fallback works.
    return arch_encoding.to(self.device).type(torch.float32)

  def _get_func_encoding(self, model, inp):
    """Return average-pooled ``features[4]`` of ``model`` on ``inp`` as one row."""
    # NOTE(review): unlike _get_attn_map, the model is not switched to
    # eval mode here -- confirm that is intended.
    with torch.no_grad():
      model = model.to(self.device)
      inp = inp.to(self.device)
      avgpool = torch.nn.AdaptiveAvgPool2d((1, 1))
      features, _ = model(inp, get_features=True)
      feature = avgpool(features[4])
      # Average over the batch and keep a single row.
      feature = torch.mean(feature, dim=0).view(1, -1)
    return feature.detach()

  def _get_attn_map(self, model, inp):
    """Return a single-row attention map of ``model`` evaluated on ``inp``.

    ``model`` is moved to ``self.device`` and switched to eval mode; its
    mode is not restored afterwards.
    """
    model = model.to(self.device)
    inp = inp.to(self.device)
    model.eval()
    with torch.no_grad():
      features, _ = model(inp, get_features=True)
      femb = features[4]
      if self.args.use_abs:
        femb = torch.abs(femb)
      # Average over the batch, then sum over the leading remaining dim.
      # The original comment gives features[4] as (1, 256, 8, 8).
      femb = torch.mean(femb, dim=0).squeeze()
      femb = torch.sum(femb, dim=0).squeeze()
      # NOTE(review): nn.Flatten() keeps dim 0, so a 2-D map is unchanged
      # here and l2norm sees the unflattened map -- confirm intended.
      femb = nn.Flatten()(femb)
      femb = l2norm(femb)
      femb = femb.view(1, -1)
    return femb.detach()

  def _get_student_model(self, F, n_classes, depth_config, channel_widths):
    """Build a student ResNet and initialise it from the teacher ``F``.

    The stem (``conv1``, ``bn1``) and ``fc`` are copied from ``F``; the
    blocks of every stage are loaded from the state dicts produced by
    ``PR.param_remapping``.
    """
    student_model = get_resnet(n_classes,
                               depth_config=depth_config,
                               channel_widths=channel_widths,
                               stage_strides=[1, 2, 2, 2],
                               tc_stage_channel_widths=list(self.tc_stage_channel_widths))
    student_model.conv1.load_state_dict(F.conv1.state_dict())
    student_model.bn1.load_state_dict(F.bn1.state_dict())
    student_model.fc.load_state_dict(F.fc.state_dict())
    n_stage = 4
    param_remapper = PR(device=self.device, n_stage=n_stage, tc_net=F,
                        st_net=student_model, st_depth_config=depth_config,
                        st_channel_widths=channel_widths, pr_type=self.pr_type,
                        args=self.args)
    st_stages = [student_model.layer1, student_model.layer2,
                 student_model.layer3, student_model.layer4]
    st_dict_lists = param_remapper.param_remapping()
    for stage, stage_dicts, depth in zip(st_stages, st_dict_lists, depth_config):
      for d in range(depth):
        stage[d].load_state_dict(stage_dicts[d])
    return student_model

  def forward(self, D, F, A, n, is_student=True, params=None):
    """Encode either the students described by ``A`` or the teacher ``F``.

    Args:
      D: Mapping with keys ``'ds_imgs'`` and ``'ds_name'``.
      F: Teacher network.
      A: Mapping with parallel sequences ``'depth_config'`` and
        ``'channel_widths'``; only read when ``is_student`` is True.
      n: Number of rows produced in teacher mode.
      is_student: If True, build and encode one student per entry of
        ``A``; otherwise encode the teacher and repeat it ``n`` times.
      params: Optional parameter dict forwarded to the linear layers.

    Returns:
      Encoded rows, passed through ``l2norm`` when ``use_l2norm`` is set.
    """
    inp = D['ds_imgs']
    if self.use_noise:
      inp = get_femb_noise(1, 3, 64, 64).to(self.device)

    n_classes = NUM_CLASSES[D['ds_name']]
    if D['ds_name'] == 'tiny_imagenet':
      n_classes = 40  # tiny_imagenet is used in 40-class splits

    encode_func = self._get_attn_map if self.use_attnmap else self._get_func_encoding

    if is_student:
      arch_encodings = []
      func_encodings = []
      for depth_config, channel_widths in zip(A['depth_config'], A['channel_widths']):
        student_model = self._get_student_model(F, n_classes, depth_config, channel_widths)
        arch_encoding = self._get_arch_encoding(depth_config, channel_widths)
        arch_encodings.append(arch_encoding.view(-1))
        func_encodings.append(encode_func(student_model, inp).view(-1))
      arch_encodings = torch.stack(arch_encodings, dim=0)
      func_encodings = torch.stack(func_encodings, dim=0)
    else:
      # The teacher is described as 5 blocks per stage at reference width.
      tc_depth_config = [5, 5, 5, 5]
      tc_channel_widths = [[width] * 5 for width in self.tc_stage_channel_widths]
      arch_encoding = self._get_arch_encoding(tc_depth_config, tc_channel_widths)
      arch_encodings = arch_encoding.repeat(n, 1)
      func_encodings = encode_func(F, inp).repeat(n, 1)

    arch_encodings = self.fc_a(arch_encodings, params=self.get_subdict(params, 'fc_a'))
    func_encodings = self.fc_f(func_encodings, params=self.get_subdict(params, 'fc_f'))
    out = torch.cat([arch_encodings, func_encodings], 1)
    if self.use_l2norm:
      out = l2norm(out)
    out = self.fc(out, params=self.get_subdict(params, 'fc'))
    if self.use_l2norm:
      return l2norm(out)
    return out


class PredictorModel(MetaModule):
  """Predictor that concatenates the selected encodings and maps them to a scalar.

  ``args.input_type`` is checked for single-letter flags that select which
  encoders are built and used:

    * ``'D'`` -- dataset encoder (``SetEncoder`` / ``SetEncoder1``)
    * ``'F'`` -- function encoder for the teacher (``FuncEncoder*``)
    * ``'A'`` -- student architecture encoder (``ArchEncoder``)
    * ``'T'`` -- ``ModelEncoder`` applied to the teacher
    * ``'S'`` -- ``ModelEncoder`` applied to the students

  Args:
    args: Namespace providing ``input_type``, ``h_dim``, ``use_l2norm``,
      ``func_type``, ``set_type``, ``use_attnmap``, ``use_noise`` and the
      dimension settings required by the selected encoders.

  Raises:
    ValueError: If ``set_type`` or ``func_type`` is unsupported for a
      selected encoder, or if no encoder contributes any input width.
  """

  def __init__(self, args):
    super(PredictorModel, self).__init__()
    self.args = args
    self.input_type = args.input_type
    self.h_dim = args.h_dim
    self.proj_inp_dim = 0
    self.use_l2norm = args.use_l2norm
    self.func_type = args.func_type
    self.set_type = args.set_type
    self.use_attnmap = args.use_attnmap
    self.use_noise = args.use_noise

    if 'D' in self.input_type:  # dataset encoder
      self.nz = args.nz
      self.num_sample = args.num_sample
      self.proj_inp_dim += self.nz
      if self.set_type == 0:
        self.set_encoder = SetEncoder(self.nz, self.num_sample,
                                      use_l2norm=self.use_l2norm)
      elif self.set_type == 1:
        self.set_encoder = SetEncoder1(self.nz,
                                       use_l2norm=self.use_l2norm)
      else:
        # Fail here instead of with an AttributeError inside forward().
        raise ValueError(f'unsupported set_type: {self.set_type!r}')

    if 'F' in self.input_type:  # function embedding for teacher
      self.f_inp_dim = args.f_inp_dim
      self.f_out_dim = args.f_out_dim
      self.proj_inp_dim += self.f_out_dim
      if self.func_type == 0:
        self.func_encoder = FuncEncoder(self.f_inp_dim, self.f_out_dim,
                                        use_l2norm=self.use_l2norm)
      elif self.func_type == 1:
        self.func_encoder = FuncEncoder1(self.f_inp_dim, self.f_out_dim,
                                         use_l2norm=self.use_l2norm)
      elif self.func_type == 2:
        self.func_encoder = FuncEncoder2(self.f_inp_dim, self.f_out_dim,
                                         use_l2norm=self.use_l2norm)
      elif self.func_type == 3:
        self.func_encoder = FuncEncoder3(self.f_inp_dim, self.f_out_dim,
                                         use_l2norm=self.use_l2norm)
      elif self.func_type == 4:
        self.func_encoder = FuncEncoder4(self.args, self.f_inp_dim, self.f_out_dim,
                                         h_dim=64, use_l2norm=self.use_l2norm)
      else:
        # Fail here instead of with an AttributeError inside forward().
        raise ValueError(f'unsupported func_type: {self.func_type!r}')

    if 'A' in self.input_type:  # student architecture encoding
      self.a_inp_dim = args.a_inp_dim
      self.a_out_dim = args.a_out_dim
      self.proj_inp_dim += self.a_out_dim
      self.arch_encoder = ArchEncoder(self.a_inp_dim, self.a_out_dim,
                                      self.args, use_l2norm=self.use_l2norm)

    if ('T' in self.input_type) or ('S' in self.input_type):  # model encoder
      self.f_inp_dim = args.f_inp_dim
      self.a_inp_dim = args.a_inp_dim
      self.m_inp_dim = args.m_inp_dim
      self.pr_type = args.pr_type
      self.model_encoder = ModelEncoder(self.args, self.a_inp_dim, self.f_inp_dim,
                                        self.m_inp_dim, self.pr_type, self.use_noise,
                                        self.use_attnmap, use_l2norm=self.use_l2norm)

    # 'T' and 'S' share one ModelEncoder but each adds its own output block.
    if 'T' in self.input_type:
      self.proj_inp_dim += self.m_inp_dim

    if 'S' in self.input_type:
      self.proj_inp_dim += self.m_inp_dim

    if self.proj_inp_dim == 0:
      raise ValueError(
          f'input_type {self.input_type!r} selects no encoder '
          f'(proj_inp_dim={self.proj_inp_dim})')

    self.proj_layers = MetaSequential(
        MetaLinear(self.proj_inp_dim, self.h_dim),
        nn.Tanh(),
        MetaLinear(self.h_dim, 1)
    )

  def forward(self, D=None, F=None, A=None, n=None, params=None):
    """Run the selected encoders and project their concatenation to a scalar.

    Args:
      D: Dataset description passed to the dataset, function and model
        encoders.
      F: Teacher network passed to the function and model encoders.
      A: Student architecture description passed to the architecture and
        model encoders.
      n: Number of rows each encoder is asked to produce.
      params: Optional parameter dict; sub-dicts are forwarded by name.

    Returns:
      Output of ``proj_layers`` applied to the encodings concatenated
      along dim 1.
    """
    input_vec = []
    if 'D' in self.input_type:
      input_vec.append(self.set_encoder(
          D, n, params=self.get_subdict(params, 'set_encoder')))
    if 'F' in self.input_type:
      input_vec.append(self.func_encoder(
          F, D, n, params=self.get_subdict(params, 'func_encoder')))
    if 'A' in self.input_type:
      input_vec.append(self.arch_encoder(
          A, params=self.get_subdict(params, 'arch_encoder')))
    if 'T' in self.input_type:
      input_vec.append(self.model_encoder(
          D, F, A, n, is_student=False,
          params=self.get_subdict(params, 'model_encoder')))
    if 'S' in self.input_type:
      input_vec.append(self.model_encoder(
          D, F, A, n, is_student=True,
          params=self.get_subdict(params, 'model_encoder')))
    input_vec = torch.cat(input_vec, dim=1)
    return self.proj_layers(input_vec, params=self.get_subdict(params, 'proj_layers'))


if __name__ == "__main__":
  # Smoke test: build a small projection head, switch it to eval mode and
  # report whether its parameters still require gradients.
  proj_inp_dim = 10
  h_dim = 10

  proj = MetaSequential(
          MetaLinear(proj_inp_dim, h_dim),
          nn.Tanh(),
          MetaLinear(h_dim, 1)
        )

  # Use `proj`, the module built above; `requires_grad` is a per-parameter
  # flag, so it is reported for each parameter rather than read off the
  # module (assumes MetaSequential exposes nn.Module.parameters()).
  proj.eval()
  print([p.requires_grad for p in proj.parameters()])

