import os
import json
import torch
import logging
import collections

from utils.registry_class import PRETRAIN
from ..modules.unet import UNetSD_T2VBase4Cross, UNetSD_T2VBase

import json


def _prefix_rules(entries):
    """Turn key-file entries into ``(prefix, replacement)`` pairs.

    Each entry is indexed as ``entry[0]`` (source prefix) and ``entry[1]``
    (replacement prefix). When the same source prefix is listed more than
    once, every occurrence resolves to the replacement of its first listing.
    """
    first_replacement = {}
    for entry in entries:
        first_replacement.setdefault(entry[0], entry[1])
    return [(entry[0], first_replacement[entry[0]]) for entry in entries]


def _swap_prefix(key, prefix, replacement):
    """Return ``key`` with its leading ``prefix`` replaced by ``replacement``.

    Only the leading occurrence is rewritten; any later occurrence of
    ``prefix`` inside ``key`` is left untouched.
    """
    return replacement + key[len(prefix):]


def _remap_modelscope_keys(state_dict, key_file):
    """Rename and duplicate checkpoint keys as described by ``key_file``.

    ``key_file`` is a JSON document holding two lists, ``"exchange"`` and
    ``"copy"``, whose entries are ``[source_prefix, replacement_prefix]``.

    The remap runs in two passes:

    1. *exchange* -- a key starting with an exchange prefix is renamed. The
       first matching rule (in file order) is applied.
    2. *copy* -- a key (after pass 1) starting with a copy prefix is kept and
       its value is additionally stored under the rewritten name. Every
       matching rule is checked for a collision; the last matching rule
       decides the name of the copy.

    Args:
        state_dict: Mapping of parameter names to values.
        key_file: Path of the JSON key-mapping file.

    Returns:
        A new dict with the renamed keys plus the copied entries.

    Raises:
        ValueError: If a copy target already exists among the keys produced
            by the exchange pass.
    """
    with open(key_file, encoding="utf-8") as f:
        key_map = json.load(f)
    exchange_rules = _prefix_rules(key_map['exchange'])
    copy_rules = _prefix_rules(key_map['copy'])

    renamed = {}
    for key, value in state_dict.items():
        new_key = key
        for prefix, replacement in exchange_rules:
            if key.startswith(prefix):
                new_key = _swap_prefix(key, prefix, replacement)
                # print(f"exchange {key} to {new_key}")
                break
        renamed[new_key] = value

    remapped = {}
    for key, value in renamed.items():
        copy_key = key
        # No early exit on purpose: every matching rule is validated and the
        # last one wins.
        for prefix, replacement in copy_rules:
            if key.startswith(prefix):
                copy_key = _swap_prefix(key, prefix, replacement)
                if copy_key in renamed:
                    raise ValueError(f"key {copy_key} already exists")
        remapped[key] = value
        remapped[copy_key] = value
    return remapped


def _report_load_result(loaded, resume_checkpoint):
    """Print whether any weights were actually applied from the checkpoint."""
    if loaded:
        print(f'Successfully load model from {resume_checkpoint}')
    else:
        print(f'No weights were loaded into unet from {resume_checkpoint}')


@PRETRAIN.register_function()
def load_model(
        unet=None,
        resume_checkpoint=None,
        from_modelscope=False,
        inference=False,
        key_file="key_file.json",
        **kwargs
    ):
    """Load checkpoint weights into ``unet`` and return it.

    Args:
        unet: Model exposing ``load_state_dict``.
        resume_checkpoint: Path passed to ``torch.load``.
        from_modelscope: If true, remap the checkpoint keys using ``key_file``
            and load them non-strictly. Takes precedence over ``inference``.
        inference: If true (and ``from_modelscope`` is false), load strictly.
            A ``UNetSD_T2VBase`` receives the checkpoint as-is; a
            ``UNetSD_T2VBase4Cross`` receives ``state_dict['unet_state_dict']``
            and load errors are printed instead of raised.
        key_file: JSON key-mapping file, only read when ``from_modelscope``.
        **kwargs: Accepted and ignored.

    Returns:
        The same ``unet`` object. If neither flag is set, or no branch
        matches, the model is returned unchanged.

    Raises:
        ValueError: If the key remap produces a colliding copy target.
    """
    # NOTE: torch.load unpickles the file; only open checkpoints from a
    # trusted source.
    state_dict = torch.load(resume_checkpoint, map_location='cpu')
    loaded = False

    if from_modelscope:
        state_dict = _remap_modelscope_keys(state_dict, key_file)
        mkey, ukey = unet.load_state_dict(state_dict, strict=False)
        print(f'load unet with missing key: {mkey} and unexpected key: {ukey}')
        loaded = True

    elif inference:
        ## this is for inference loading
        # NOTE(review): if UNetSD_T2VBase4Cross subclasses UNetSD_T2VBase the
        # first branch shadows the second -- confirm the class hierarchy.
        if isinstance(unet, UNetSD_T2VBase):
            mkey, ukey = unet.load_state_dict(state_dict, strict=True)
            print(f'load unet with missing key: {mkey} and unexpected key: {ukey}')
            loaded = True
        elif isinstance(unet, UNetSD_T2VBase4Cross):
            unet_state_dict = state_dict['unet_state_dict']
            # Best-effort: a failed strict load is reported, not raised.
            try:
                mkey, ukey = unet.load_state_dict(unet_state_dict, strict=True)
            except Exception as e:
                print(f'load unet with excpetion: {e}')
            else:
                print(f'load unet with missing key: {mkey} and unexpected key: {ukey}')
                loaded = True
            # NOTE(review): earlier commented-out code also loaded a
            # 'text_reweight_state_dict' entry; it is ignored here.

    _report_load_result(loaded, resume_checkpoint)
    return unet


@PRETRAIN.register_function()
def load_model_new(
        unet=None,
        resume_checkpoint=None,
        from_modelscope=False,
        inference=False,
        key_file="key_file.json",
        **kwargs
    ):
    """Load checkpoint weights into ``unet``, dispatching on the model class.

    Args:
        unet: Model exposing ``load_state_dict``.
        resume_checkpoint: Path passed to ``torch.load``.
        from_modelscope: Only used for ``UNetSD_T2VBase4Cross``; if true the
            checkpoint keys are remapped using ``key_file`` before loading.
        inference: Accepted for signature parity; not used.
        key_file: JSON key-mapping file, only read when remapping.
        **kwargs: Accepted and ignored.

    Returns:
        The same ``unet`` object. A ``UNetSD_T2VBase4Cross`` is loaded
        non-strictly, a ``UNetSD_T2VBase`` strictly; any other type is
        returned unchanged.

    Raises:
        ValueError: If the key remap produces a colliding copy target.
    """
    # NOTE: torch.load unpickles the file; only open checkpoints from a
    # trusted source.
    state_dict = torch.load(resume_checkpoint, map_location='cpu')
    loaded = False

    if isinstance(unet, UNetSD_T2VBase4Cross):
        if from_modelscope:
            state_dict = _remap_modelscope_keys(state_dict, key_file)
        mkey, ukey = unet.load_state_dict(state_dict, strict=False)
        print(f'load unet with missing key: {mkey} and unexpected key: {ukey}')
        loaded = True

    elif isinstance(unet, UNetSD_T2VBase):
        mkey, ukey = unet.load_state_dict(state_dict, strict=True)
        print(f'load unet with missing key: {mkey} and unexpected key: {ukey}')
        loaded = True

    _report_load_result(loaded, resume_checkpoint)
    return unet



@PRETRAIN.register_function()
def pretrain_from_sd():
    """Placeholder registered with ``PRETRAIN``; takes no arguments and does nothing."""
    return None


@PRETRAIN.register_function()
def pretrain_ema_model():
    """Placeholder registered with ``PRETRAIN``; takes no arguments and does nothing."""
    return None
