# -*- encoding: utf-8 -*-
# here put the import lib
import os
import sys
import math
import random

import numpy as np
import torch
import torch.nn.functional as F
from SwissArmyTransformer.resources import auto_create
from .direct_sr import DirectSuperResolution
from .iterative_sr import IterativeSuperResolution

class SRGroup:
    """Pair of super-resolution stages run back to back.

    Holds a ``DirectSuperResolution`` model (``self.dsr``) and an
    ``IterativeSuperResolution`` model (``self.itersr``). The iterative
    stage is built on top of the direct stage's transformer, which is
    passed to it as ``shared_transformer``.
    """

    def __init__(self, args, home_path=None):
        """Resolve both checkpoints and build the two stages.

        Args:
            args: Passed through unchanged to both model constructors.
            home_path: Forwarded to ``auto_create`` as ``path`` for both
                the ``'cogview2-dsr'`` and ``'cogview2-itersr'`` resources.
        """
        # Resolve both resources before any model is constructed.
        direct_ckpt = auto_create('cogview2-dsr', path=home_path)
        iterative_ckpt = auto_create('cogview2-itersr', path=home_path)

        self.dsr = DirectSuperResolution(args, direct_ckpt)
        # The iterative stage shares the direct stage's transformer.
        self.itersr = IterativeSuperResolution(
            args,
            iterative_ckpt,
            shared_transformer=self.dsr.model.transformer,
        )

    def sr_base(self, img_tokens, txt_tokens):
        """Run the direct stage, then the iterative stage, on image tokens.

        Args:
            img_tokens: 2-D tensor whose last dimension is exactly 400
                (enforced by an ``assert``). The first dimension is used
                as the batch size.
            txt_tokens: Text tokens. A 1-D tensor is broadcast to one row
                per image; any other rank is passed through as given.

        Returns:
            The last ``batch_size`` rows of the iterative stage's output.
            NOTE(review): presumably earlier rows hold intermediate
            results -- confirm against ``IterativeSuperResolution``.
        """
        img_shape = img_tokens.shape
        assert img_shape[-1] == 400 and len(img_shape) == 2

        num_images = img_shape[0]
        text_width = txt_tokens.shape[-1]
        if len(txt_tokens.shape) == 1:
            # A single text sequence: repeat it (as a view) for every image.
            batched_text = txt_tokens.unsqueeze(0)
            txt_tokens = batched_text.expand(num_images, text_width)

        direct_out = self.dsr(txt_tokens, img_tokens)
        # Only the trailing 3600 tokens of each row go to the iterative
        # stage; clone so it receives its own copy rather than a view.
        upscaled = direct_out[:, -3600:].clone()
        refined = self.itersr(txt_tokens, upscaled)
        return refined[-num_images:]
    
    # def sr_patch(self, img_tokens, txt_tokens):
    #     assert img_tokens.shape[-1] == 3600 and len(img_tokens.shape) == 2
    #     batch_size = img_tokens.shape[0] * 9
    #     txt_len = txt_tokens.shape[-1]
    #     if len(txt_tokens.shape) == 1:
    #         txt_tokens = txt_tokens.unsqueeze(0).expand(batch_size, txt_len)
    #     img_tokens = img_tokens.view(img_tokens.shape[0], 3, 20, 3, 20).permute(0, 1, 3, 2, 4).reshape(batch_size, 400)
    #     iter_tokens = self.sr_base(img_tokens, txt_tokens)
    #     return iter_tokens