import torch
import sys
sys.path.append("..")
import argparse

from guided_diffusion.train_util import get_scales
from guided_diffusion.unet import FNN
from guided_diffusion import dist_util, logger


def create_argparser():
	"""Build the command-line parser for the checkpoint conversion script.

	Every option is a plain ``--name`` flag with a type and a default.
	List-like options (``--expand_interval``, ``--skip_1024layers``) are kept
	as comma-separated strings, and ``--if_finetune`` is a string flag that
	callers compare against ``"True"``.

	:return: an ``argparse.ArgumentParser`` with all options registered.
	"""
	option_table = [
		# (flag, type, default)
		("--model_path", str, "models/static_skip_256x256_diffusion.pt"),
		("--expand_scale", float, 0.5),
		("--expand_scale_io", int, 2),
		("--expand_interval", str, "3, 8, 11, 15"),
		("--skip_1024layers", str, "11, 7"),
		("--t_spli_num", int, 10),
		("--gater_path", str, ""),
		("--if_finetune", str, "True"),
		("--threshold", float, 0.49999),
	]
	parser = argparse.ArgumentParser()
	for flag, arg_type, default in option_table:
		parser.add_argument(flag, type=arg_type, default=default)
	return parser


class Pointer:
	"""Small mutable cursor record.

	``p`` is an index (into a list of trained scales in this script), and
	``num1`` / ``num2`` remember the pair of block indices last associated
	with that index. All three start at 0.
	"""

	def __init__(self):
		# Chained assignment binds left to right: p, then num1, then num2.
		self.p = self.num1 = self.num2 = 0


def _load_trained_scales(args):
	"""Compute per-block expansion factors from a trained gater (FNN) checkpoint.

	The FNN state dict at ``args.gater_path`` is loaded onto the CPU. The
	network's output width is read from the last tensor in that state dict,
	the restored network is moved to ``dist_util.dev()`` in eval mode, and it
	is passed to ``get_scales``. Every returned value is multiplied by
	``args.expand_scale``, truncated with ``int`` and clamped to at least 1.

	:param args: parsed command-line arguments (see ``create_argparser``).
	:return: list of ints >= 1, in the order produced by ``get_scales``.
	:raises ValueError: if ``--gater_path`` was not given.
	"""
	if not args.gater_path:
		# The default is "", which would otherwise surface as an opaque
		# FileNotFoundError from torch.load.
		raise ValueError("--gater_path is required when --if_finetune is 'True'")

	fnn_ckpt = torch.load(args.gater_path, map_location="cpu")
	# The output width is taken from the last tensor of the state dict
	# (presumably the final layer's bias -- TODO confirm against FNN).
	out_size = 0
	for tensor in fnn_ckpt.values():
		out_size = int(tensor.shape[0])

	fnn = FNN(dims=[args.t_spli_num, 64, out_size], threshold=args.threshold)
	fnn.load_state_dict(fnn_ckpt)
	fnn.to(dist_util.dev())
	fnn.eval()
	layer_used_num_list = get_scales(fnn, args.t_spli_num)

	# A block whose scaled count truncates to 0 still keeps one copy.
	return [max(1, int(num * args.expand_scale)) for num in layer_used_num_list]


def main():
	"""Convert a diffusion checkpoint to the expanded parameter layout.

	Reads the checkpoint at ``--model_path``. It renames the parameters to the
	expanded module names (``unit_res_blocks``, ``layers_before_upd``,
	``unit_conv_blocks``, ``att_blocks``) and replicates the expandable ones.

	With ``--if_finetune True``, the per-block replication factors come from
	the gater at ``--gater_path``, and the result is written to
	``models/256x256_dddm_step2.pt``. Otherwise every factor is 1 and the
	result is written to ``models/256x256_dddm_step1.pt``.

	NOTE: ``--expand_interval``, ``--skip_1024layers`` and ``--expand_scale_io``
	are parsed but not used here. Layer skipping is presumably already applied
	to the checkpoint at ``--model_path``.

	:raises ValueError: if fine-tuning is requested without ``--gater_path``.
	:raises RuntimeError: if the checkpoint's blocks do not line up with the
		number of trained scales.
	"""
	args = create_argparser().parse_args()
	dist_util.setup_dist()

	finetune = args.if_finetune == "True"
	# Cursor into trained_scales, plus the (ks[1], ks[2]) block indices that the
	# current entry was assigned to.
	pointer = Pointer()
	trained_scales = _load_trained_scales(args) if finetune else []

	# map_location="cpu" lets a GPU-saved checkpoint be converted on any machine,
	# matching how the gater checkpoint is loaded.
	skip_ckpt = torch.load(args.model_path, map_location="cpu")
	expand_ckpt = {}

	def get_exp_scale(num1, num2):
		"""Return the replication factor for block ``(num1, num2)``.

		Keys must be visited in checkpoint order: whenever the block indices
		change, the cursor advances to the next trained scale.
		NOTE(review): blocks are identified only by these two integers, not by
		their top-level module name. Adjacent blocks from different modules
		with equal indices would therefore share one scale.
		"""
		if not finetune:
			return 1
		if (num1, num2) != (pointer.num1, pointer.num2):
			pointer.p += 1
			pointer.num1 = num1
			pointer.num2 = num2
		return trained_scales[pointer.p]

	def add_copies(prefix, suffix, value, count):
		"""Store ``value`` under ``prefix + str(i) + suffix`` for each i in range(count)."""
		for i in range(count):
			expand_ckpt[prefix + str(i) + suffix] = value

	def check_cursor():
		# The last trained scale is reserved for the "out." parameters, so the
		# cursor must sit exactly on the second-to-last entry.
		expected = len(trained_scales) - 2
		if pointer.p != expected:
			raise RuntimeError(
				f"trained-scale cursor is at {pointer.p}, expected {expected}; "
				"checkpoint blocks do not match the gater output"
			)

	# Expansion step: rename / replicate every tensor. Branch order matters,
	# because several substring tests overlap and get_exp_scale advances its
	# cursor as keys stream by.
	for k, v in skip_ckpt.items():
		ks = k.split(".")
		block_prefix = ".".join(ks[:3])

		if "in_layers.2" in k:
			# <block>.in_layers.2.<param> -> <block>.unit_res_blocks.<i>.in_layers.0.<param>
			scale = get_exp_scale(int(ks[1]), int(ks[2]))
			add_copies(block_prefix + ".unit_res_blocks.", ".in_layers.0." + ks[-1], v, scale)
			continue

		if ("emb_layers" in k) or ("out_layers" in k):
			# <block>.<layers>.<idx>.<param> -> <block>.unit_res_blocks.<i>.<layers>.<idx>.<param>
			scale = get_exp_scale(int(ks[1]), int(ks[2]))
			add_copies(block_prefix + ".unit_res_blocks.", "." + ".".join(ks[-3:]), v, scale)
			continue

		if "in_layers.0" in k:
			# Not replicated: <block>.in_layers.0.<param> -> <block>.layers_before_upd.0.<param>
			expand_ckpt[block_prefix + ".layers_before_upd.0." + ks[-1]] = v
			continue

		if k.startswith("out."):
			# out.<idx>.<param> -> out.unit_conv_blocks.<i>.<idx>.<param>, using the
			# trained scale just after the cursor (the cursor itself is not moved).
			if finetune:
				check_cursor()
				scale = trained_scales[pointer.p + 1]
			else:
				scale = 1
			add_copies("out.unit_conv_blocks.", "." + ".".join(ks[-2:]), v, scale)
			continue

		if "input_blocks.0" in k:
			# input_blocks.0.<idx>.<param> -> input_blocks.0.unit_conv_blocks.<i>.<param>
			scale = get_exp_scale(int(ks[1]), int(ks[2]))
			add_copies("input_blocks.0.unit_conv_blocks.", "." + ks[-1], v, scale)
			continue

		if ("1.norm" in k) or ("1.qkv" in k) or ("1.proj_out" in k):
			# <block>.<layer>.<param> -> <block>.att_blocks.<i>.<layer>.<param>
			scale = get_exp_scale(int(ks[1]), int(ks[2]))
			add_copies(block_prefix + ".att_blocks.", "." + ".".join(ks[-2:]), v, scale)
			continue

		# Everything else (including skip_connection) is copied unchanged.
		expand_ckpt[k] = v

	if finetune:
		check_cursor()
		torch.save(expand_ckpt, "models/256x256_dddm_step2.pt")
	else:
		torch.save(expand_ckpt, "models/256x256_dddm_step1.pt")


if __name__ == "__main__":
	# Run the checkpoint conversion only when executed as a script, not on import.
	main()
