import torch
from torch import nn
from collections import OrderedDict

from utils.utils import NestedTensor, nested_tensor_from_tensor_list
from models.backbone import build_backbone
from models.transformer import build_transformer
import os

# Set TORCH_HUB_OFFLINE=1 in this process's environment.
# NOTE(review): presumably intended to keep hub-based model loading from
# touching the network -- confirm which library actually reads this variable.
# It is assigned after the imports above, so code that inspects the
# environment at import time will not see it.
os.environ["TORCH_HUB_OFFLINE"] = "1"


class brain_encoder(nn.Module):
    """Predict per-voxel brain responses from images.

    Pipeline, as implemented in ``forward``: a feature backbone produces
    feature maps and positional embeddings, a transformer decodes one token
    per parcel query, a linear readout maps every token to all valid voxels,
    and a parcel mask keeps, for each voxel, only the contribution of the
    parcels that contain it.

    Besides the registered sub-modules, the instance keeps three plain tensor
    attributes (``valid_voxel_mask``, ``parcel_mask``, ``voxel_map``). They
    are not registered buffers, so ``nn.Module.to`` alone does not move them;
    use ``to_device`` to relocate the whole model.
    """

    def __init__(self, args, dataset):
        """Build the backbone, transformer, parcel queries and voxel readout.

        Args:
            args: configuration object. Reads ``lr_backbone``, ``device``,
                ``backbone_arch``, ``return_interm``, ``encoder_arch``,
                ``enc_layers``, ``dec_layers``, ``lh_vs``, ``rh_vs`` and
                ``readout_res``; it is also passed to ``build_backbone`` and
                ``build_transformer``.
            dataset: provides ``num_parcels`` (int or list of ints),
                ``parcels``, ``valid_voxel_mask`` (used as a boolean mask over
                voxels) and ``num_hemi_voxels``.
        """
        super().__init__()

        self.lr_backbone = args.lr_backbone
        self.device = args.device

        self.backbone_arch = args.backbone_arch
        self.return_interm = args.return_interm
        self.encoder_arch = args.encoder_arch

        ### Brain encoding model
        self.backbone_model = build_backbone(args)
        self.transformer = build_transformer(args)

        # One learned query per parcel; a list of counts is summed into a
        # single total.
        self.num_queries = (
            sum(dataset.num_parcels)
            if isinstance(dataset.num_parcels, list)
            else dataset.num_parcels
        )
        self.hidden_dim = self.transformer.d_model
        self.linear_feature_dim = self.hidden_dim

        self.enc_layers = args.enc_layers
        self.dec_layers = args.dec_layers

        self.lh_vs = args.lh_vs
        self.rh_vs = args.rh_vs

        self.query_embed = nn.Embedding(self.num_queries, self.hidden_dim)

        ### backbone_arch for feature extraction
        # Always True here; forward() also has a multi-backbone path that
        # expects ``self.backbones``, which this constructor never creates.
        self.single_backbone = True

        # NOTE(review): ``input_proj`` is created for resnet backbones but is
        # never applied in forward() -- confirm whether the transformer does
        # the channel projection itself or whether a call is missing.
        if ("resnet" in self.backbone_arch) and ("transformer" in self.encoder_arch):
            self.input_proj = nn.Conv2d(
                self.backbone_model.num_channels, self.hidden_dim, kernel_size=1
            )
        elif ("resnet" in self.backbone_arch) and ("linear" in self.encoder_arch):
            self.input_proj = nn.AdaptiveAvgPool2d(1)
            self.linear_feature_dim = self.backbone_model.num_channels

        # linear readout layers to the neural data
        self.readout_res = args.readout_res

        self.valid_voxel_mask = dataset.valid_voxel_mask.to(args.device)

        # parcel_mask has shape (num_hemi_voxels, num_queries):
        # parcel_mask[v, p] == 1 iff valid voxel ``v`` belongs to parcel ``p``.
        # Allocated directly on the target device instead of CPU-then-copy.
        self.parcel_mask = torch.zeros(
            dataset.num_hemi_voxels, self.num_queries, device=args.device
        )

        # Translates from a possibly-invalid voxel index to its position among
        # the valid voxels. Entries for invalid voxels stay 0.
        self.voxel_map = torch.zeros_like(
            dataset.valid_voxel_mask, dtype=torch.int64
        )
        # Build the index values on the mask's own device so the masked
        # assignment never mixes devices.
        self.voxel_map[dataset.valid_voxel_mask] = torch.arange(
            dataset.num_hemi_voxels, device=dataset.valid_voxel_mask.device
        )

        # NOTE(review): a length-2 ``parcels`` is taken to be two per-group
        # collections that are joined with ``+``; confirm this can never be a
        # flat list that merely happens to hold two parcels.
        all_parcels = (
            dataset.parcels[0] + dataset.parcels[1]
            if len(dataset.parcels) == 2
            else dataset.parcels
        )
        for i, parcel in enumerate(all_parcels):
            valid_idx = self.voxel_map[parcel]
            self.parcel_mask[valid_idx, i] = 1

        # Kept as a Sequential so the parameter names (``embed.0.*``) stay
        # compatible with existing checkpoints.
        self.embed = nn.Sequential(nn.Linear(self.hidden_dim, dataset.num_hemi_voxels))

    def to_device(self, device):
        """Move the whole model, including plain tensor attributes, to ``device``.

        Sub-modules, parameters and buffers are stored in ``nn.Module``'s
        internal registries rather than directly in ``__dict__``, so they are
        moved with ``self.to(device)``. Plain tensor attributes (the masks and
        the voxel map) are not covered by ``Module.to`` and are moved
        explicitly afterwards.

        Args:
            device: anything accepted by ``torch.Tensor.to`` as a device.
        """
        self.device = device
        self.to(device)
        # Iterate over a snapshot: setattr writes back into the same dict.
        for attr_name, attr_value in list(vars(self).items()):
            if isinstance(attr_value, torch.Tensor):
                setattr(self, attr_name, attr_value.to(device))
            elif isinstance(attr_value, torch.nn.Module):
                # Only reached for modules placed in __dict__ while bypassing
                # nn.Module.__setattr__; kept for backward compatibility.
                attr_value.to(device)

    def forward(self, samples: "NestedTensor"):
        """Run the encoder and scatter the predictions over all voxels.

        Args:
            samples: a ``NestedTensor``; a list of tensors or a single tensor
                is first converted with ``nested_tensor_from_tensor_list``.

        Returns:
            dict with
              * ``"pred"``: float32 tensor with one row per sample and one
                column per entry of ``valid_voxel_mask``; columns of invalid
                voxels are zero.
              * ``"output_tokens"``: the last element of the transformer
                output, one token per parcel query.
        """
        if isinstance(samples, (list, torch.Tensor)):
            samples = nested_tensor_from_tensor_list(samples)

        if self.single_backbone:
            # A backbone learning rate of 0 means the backbone is frozen, so
            # skip building its autograd graph.
            if self.lr_backbone == 0:
                with torch.no_grad():
                    features, pos = self.backbone_model(samples)
            else:
                features, pos = self.backbone_model(samples)

            # Only the last feature level and its positional embedding are used.
            input_proj_src, mask = features[-1].decompose()
            pos_embed = pos[-1]
        else:
            # NOTE(review): this path reads ``self.backbones``, which is not
            # defined in __init__; it is unreachable unless a caller sets
            # ``single_backbone = False`` and provides ``backbones``.
            features = []
            pos = []
            for backbone in self.backbones:
                with torch.no_grad():
                    fs, ps = backbone(samples)
                    features.append(fs)
                    pos.append(ps)

            input_proj_srcs = []
            masks = []
            DIM_TO_CONCAT = 2
            for feature in features:
                input_proj_src, mask = feature[-1].decompose()
                input_proj_srcs.append(input_proj_src.flatten(2))
                masks.append(mask.flatten(2))
            pos_embeds = [p[-1].flatten(2) for p in pos]
            input_proj_src = torch.cat(input_proj_srcs, dim=DIM_TO_CONCAT).unsqueeze(-1)
            mask = torch.cat(masks, dim=DIM_TO_CONCAT).unsqueeze(-1)
            pos_embed = torch.cat(pos_embeds, dim=DIM_TO_CONCAT).unsqueeze(-1)

        hs = self.transformer(
            input_proj_src,
            mask,
            self.query_embed.weight,
            pos_embed,
            self.return_interm,
        )
        # Expected layout: (batch_size, num_queries, hidden_dim), as required
        # by the readout and the parcel-mask broadcast below.
        output_tokens = hs[-1]

        # Every parcel token predicts every valid voxel ...
        pred = self.embed(output_tokens)
        # ... then move the parcel axis last so it lines up with parcel_mask
        # (num_hemi_voxels, num_queries), zero out parcels that do not contain
        # the voxel, and sum over parcels.
        pred = torch.movedim(pred, 1, -1)
        pred = (pred * self.parcel_mask).sum(dim=-1)

        # Scatter the valid-voxel predictions into the full voxel space.
        # Allocate on pred's device directly rather than on CPU followed by a
        # copy to ``self.device``.
        full_pred = torch.zeros(
            (pred.shape[0], self.valid_voxel_mask.shape[0]),
            dtype=torch.float32,
            device=pred.device,
        )
        full_pred[:, self.valid_voxel_mask] = pred

        return {
            "pred": full_pred,
            "output_tokens": output_tokens,
        }
