import copy
import gc
import os
import sys

from torch import nn
sys.path.append(".")
import warnings
import copy
import random
import logging
import functools
import time
from typing import Callable, Dict, Any, List

from omegaconf import OmegaConf
import torchvision

from lib_ddif.utils import build_tensor_dataset
# Make the parent of this file's directory and the current working directory
# importable; presumably so the project packages imported below (DM, lib)
# resolve regardless of the launch directory.
sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
sys.path.append(".")
# Enumerate CUDA devices in PCI bus order (the same order nvidia-smi uses).
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# Suppress every DeprecationWarning and UserWarning for the whole process.
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)
# Raise these two project loggers to WARNING so their INFO/DEBUG output is hidden.
logging.getLogger("lib.utils").setLevel(logging.WARNING)
logging.getLogger('lib.gaussian.gaussianimage_cholesky').setLevel(logging.WARNING)

import hydra
import numpy as np
import torch
import matplotlib.pyplot as plt
from tqdm import tqdm
from torchvision.utils import save_image

from DM.utils import get_dataset, get_network, get_eval_pool, evaluate_synset, get_time, TensorDataset, epoch, DiffAugment, ParamDiffAug, set_seed, save_and_print, get_images
from DM.hyper_params import load_default
from lib.utils import get_initialized_gs_batch
from lib.gaussian.strategy_batch import DefaultStrategy
from lib.gaussian.gs_utils import create_optimizers_and_schedulers


# Module-level logger (used by check_args and main below).
logger = logging.getLogger(__name__)


def move_module_all_tensors_to_cpu(module: torch.nn.Module):
    """Move ``module`` and every tensor it references onto the CPU.

    ``nn.Module.to`` only moves registered parameters and buffers. Tensors kept
    as plain attributes -- either directly or one level deep inside a list,
    tuple or dict attribute -- are not touched by it, so they are moved here
    explicitly. This is done for ``module`` itself and for each of its
    submodules (the original implementation only covered the top level).

    Args:
        module: Module to move. It is modified in place.

    Returns:
        The same module object (as returned by ``module.to("cpu")``).
    """
    module = module.to("cpu")

    def _off_cpu(value):
        # True only for tensors that still live on a non-CPU device.
        return torch.is_tensor(value) and value.device.type != "cpu"

    for submodule in module.modules():
        # Iterate over a snapshot: setattr below may replace entries of __dict__.
        for name, value in list(vars(submodule).items()):
            if _off_cpu(value):
                setattr(submodule, name, value.cpu())
            elif isinstance(value, (list, tuple)):
                if any(_off_cpu(v) for v in value):
                    moved = [v.cpu() if _off_cpu(v) else v for v in value]
                    # Preserve the container kind (plain tuple vs. list).
                    setattr(submodule, name, tuple(moved) if isinstance(value, tuple) else moved)
            elif isinstance(value, dict):
                if any(_off_cpu(v) for v in value.values()):
                    setattr(
                        submodule,
                        name,
                        {k: (v.cpu() if _off_cpu(v) else v) for k, v in value.items()},
                    )
    return module


class MultiMethodOutputHook:
    """Apply ``transform`` to ``output[key]`` of several methods of a module.

    Each named method is replaced *on the instance* by a wrapper. When the
    wrapped call returns a dict containing ``key``, the wrapper returns a
    shallow copy of that dict with ``transform`` applied to that entry; any
    other return value is passed through unchanged. ``remove()`` restores the
    module to its original state.

    Args:
        module: Object whose methods are wrapped (modified in place).
        methods: Names of callable attributes of ``module`` to wrap.
        transform: Callable applied to ``output[key]``. If it has a ``to``
            method it is moved (best effort) to the value's device before use.
        key: Dict key whose value is transformed.

    Raises:
        AttributeError: If any name in ``methods`` is not a callable attribute
            of ``module``. In that case nothing is wrapped.
    """

    def __init__(
        self,
        module: torch.nn.Module,
        methods: List[str],
        transform: Callable,
        key: str = "render",
    ):
        self.module = module
        self.methods = methods
        self.transform = transform
        self.key = key
        # name -> original callable, restored by remove().
        self._orig = {}
        # Names whose original value lived in the instance __dict__ (must be
        # re-assigned on remove); all others came from the class and are deleted.
        self._instance_attrs = set()

        # Validate every name before wrapping any, so a bad name cannot leave
        # the module partially wrapped with no handle to undo it.
        originals = {}
        for name in methods:
            orig = getattr(module, name, None)
            if orig is None or not callable(orig):
                raise AttributeError(
                    f"{module.__class__.__name__} has no callable '{name}'"
                )
            originals[name] = orig

        own_attrs = getattr(module, "__dict__", {})
        for name, orig in originals.items():
            if name in own_attrs:
                self._instance_attrs.add(name)
            self._orig[name] = orig
            setattr(module, name, self._make_wrapper(orig))

    def _make_wrapper(self, orig_fn: Callable):
        """Return a wrapper around ``orig_fn`` that transforms ``out[self.key]``."""

        @functools.wraps(orig_fn)
        def wrapper(*args, **kwargs):
            out = orig_fn(*args, **kwargs)
            # Only dict outputs that carry the target key are transformed.
            if not isinstance(out, dict) or self.key not in out:
                return out

            x = out[self.key]
            if hasattr(self.transform, "to"):
                try:
                    self.transform.to(x.device)
                except Exception:  # best effort: the transform may still work as-is
                    logging.getLogger(__name__).debug(
                        "Could not move output transform to %s",
                        getattr(x, "device", None),
                        exc_info=True,
                    )
            y = self.transform(x)

            # Shallow copy so the dict returned by the original call is not mutated.
            new_out = dict(out)
            new_out[self.key] = y
            return new_out

        return wrapper

    def remove(self):
        """Restore the original methods. Calling it more than once is harmless."""
        for name, fn in self._orig.items():
            if name in self._instance_attrs:
                setattr(self.module, name, fn)
            else:
                # The method came from the class: drop the instance-level
                # wrapper so normal class lookup resumes (instead of pinning a
                # bound method onto the instance).
                try:
                    delattr(self.module, name)
                except AttributeError:
                    pass
        self._orig.clear()
        self._instance_attrs.clear()


def attach_output_transform_to_methods(
    module: torch.nn.Module,
    train_transform: Callable,
    methods: List[str],
    key: str = "render",
) -> MultiMethodOutputHook:
    """Wrap ``methods`` of ``module`` so ``train_transform`` is applied to ``output[key]``.

    Intended for methods returning a dict such as ``{"render": images, ...}``
    (e.g. an image batch ``[B, C, H, W]``). ``train_transform`` is typically a
    ``torchvision.transforms.Normalize`` or a ZCA whitening module.

    Args:
        module: Module whose methods are wrapped in place.
        train_transform: Callable applied to the ``key`` entry of each output.
        methods: Names of the methods to wrap.
        key: Output dict key to transform.

    Returns:
        The hook object; call its ``remove()`` to restore the original methods.
    """
    hook = MultiMethodOutputHook(module, methods, train_transform, key)
    return hook


class Tee:
    """File-like object that duplicates every write to several streams.

    ``main`` uses it to mirror ``sys.stderr`` into a log file. Attributes that
    ``Tee`` does not define itself (``isatty``, ``encoding``, ``fileno``, ...)
    are delegated to the first stream, so code that probes ``sys.stderr`` for
    them keeps working after the replacement.
    """

    def __init__(self, *files):
        self.files = files

    def write(self, obj):
        """Write *obj* to every stream, flush each, and return ``len(obj)``."""
        for f in self.files:
            f.write(obj)
            f.flush()
        # Text streams report the number of characters written.
        return len(obj)

    def flush(self):
        """Flush every underlying stream."""
        for f in self.files:
            f.flush()

    def __getattr__(self, name):
        # Only reached for attributes not found normally. Read ``files`` via
        # __dict__ to avoid infinite recursion if it is not set yet.
        files = self.__dict__.get("files")
        if not files:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )
        return getattr(files[0], name)


def setup_logging_for_distributed(is_master: bool):
    """Configure the root logger according to the process role.

    Master process: ``logging.basicConfig`` installs an INFO-level stream
    handler with a timestamped format (this is a no-op when the root logger
    already has handlers), then a confirmation message is logged.

    Other processes: the root logger's level is raised above CRITICAL and a
    ``NullHandler`` is attached. This silences loggers that inherit their
    level from root; loggers with an explicit level of their own are not
    affected.
    """
    root_logger = logging.getLogger()

    if not is_master:
        # CRITICAL + 1 is above every standard level.
        root_logger.setLevel(logging.CRITICAL + 1)
        root_logger.addHandler(logging.NullHandler())
        return

    logging.basicConfig(
        handlers=[logging.StreamHandler()],
        format="%(asctime)s [%(levelname)s] %(message)s",
        level=logging.INFO,
    )
    root_logger.info("Logging is configured for the master process.")


def check_args(args):
    """Validate required config keys and fill in optional defaults.

    Raises:
        ValueError: If ``use_clamp``, ``boundary_loss_type`` or
            ``boundary_loss_lambda`` is missing from ``args``.

    Side effects:
        Sets ``args.min_start_epoch = 0`` and ``args.forward_subset = False``
        (logging a warning for each) when those attributes are absent.
    """
    required = {
        "use_clamp",
        "boundary_loss_type",
        "boundary_loss_lambda",
    }
    missing = [k for k in required if k not in args]
    if missing:
        raise ValueError(f"Missing required argument: {missing[0]}")

    # (attribute, default, warning message) for optional settings.
    optional_defaults = (
        ("min_start_epoch", 0, "min_start_epoch not found in args, setting to default 0."),
        ("forward_subset", False, "forward_subset not found in args, setting to default False."),
    )
    for attr, default, message in optional_defaults:
        if hasattr(args, attr):
            continue
        logger.warning(message)
        setattr(args, attr, default)


class EmbedWrapper(torch.nn.Module):
    """Expose ``backbone.embed`` as this module's ``forward``.

    This makes the embedding path usable wherever a plain ``nn.Module`` call
    is expected (e.g. when wrapped in ``nn.DataParallel``).
    """

    def __init__(self, backbone):
        super().__init__()
        self.backbone = backbone

    def forward(self, x):
        embed_fn = self.backbone.embed
        return embed_fn(x)



def _make_boundary_loss(args):
    """Build the boundary penalty applied to ``gs_model.params["xy"]``.

    ``args.boundary_loss_type`` selects the penalty:
      * ``None``       -> always a zero tensor on the input's device;
      * ``"exponent"`` -> ``mean(x ** args.boundary_loss_exponent)``; the
                          exponent must be a positive even number;
      * ``"log"``      -> log barrier ``mean(-log(1 - x**2))`` with ``x``
                          clamped to ``[-1 + eps, 1 - eps]`` so the log is finite.

    Raises:
        ValueError: For an unknown loss type or an invalid exponent.
    """
    loss_type = args.boundary_loss_type
    if loss_type is None:
        def boundary_loss(x):
            return torch.tensor(0.0, device=x.device)
    elif loss_type == "exponent":
        exponent = args.boundary_loss_exponent
        # Validate once up front; an ``assert`` would be stripped under ``python -O``.
        if not (exponent > 0 and exponent % 2 == 0):
            raise ValueError("Exponent must be a positive even integer.")

        def boundary_loss(x):
            return torch.mean(x.pow(exponent))
    elif loss_type == "log":
        def boundary_loss(x, epsilon=1e-6):
            clamped_xy = torch.clamp(x, min=-1.0 + epsilon, max=1.0 - epsilon)
            log_barrier = -torch.log(1.0 - clamped_xy.pow(2))
            return torch.mean(log_barrier)
    else:
        raise ValueError(f"Unsupported boundary loss type: {loss_type}")
    return boundary_loss


def _save_gs_checkpoint(gs_model, args, path, hook_handle, train_transform, methods):
    """Save ``{"model": gs_model, "syn_lr": args.lr_net}`` to *path*.

    The output-transform hook replaces methods on ``gs_model`` with closures
    defined inside a method, which pickle cannot serialise by reference, so
    it is removed for the duration of ``torch.save`` and re-attached afterwards
    (also when saving fails).

    Returns:
        The freshly re-attached hook handle.
    """
    hook_handle.remove()
    try:
        torch.save({"model": gs_model, "syn_lr": args.lr_net}, path)
    finally:
        hook_handle = attach_output_transform_to_methods(
            gs_model, train_transform, methods, key="render"
        )
    return hook_handle


@hydra.main(config_path="configs", config_name="imagenette", version_base="1.3")
def main(args):
    """Distill a synthetic dataset parameterised by a batch of Gaussian-splatting models.

    For each experiment (``args.num_exp``) this builds the real training
    tensors, initialises ``gs_model`` (whose renders are indexed as
    ``args.gpc`` consecutive images per class) and runs ``args.Iteration + 1``
    iterations. Every iteration constructs a new network with frozen weights
    and, for each class, minimises the squared distance between the mean
    ``embed`` features of real and synthetic images plus a weighted boundary
    penalty on ``gs_model.params["xy"]``. At iterations in ``eval_it_pool`` the
    synthetic set is evaluated with ``evaluate_synset``, visualised, and
    checkpointed under ``args.save_path``.

    Args:
        args: Hydra/OmegaConf config (``configs/imagenette`` by default).
    """
    OmegaConf.set_readonly(args, False)
    OmegaConf.set_struct(args, False)

    check_args(args)

    save_dir = hydra.core.hydra_config.HydraConfig.get().runtime.output_dir
    os.makedirs(save_dir, exist_ok=True)
    args.save_path = save_dir

    logger.info("Setting up file logging for rank 0.")
    file_name = os.path.splitext(os.path.basename(__file__))[0]
    err_path = os.path.join(args.save_path, f"{file_name}.log")
    # Line-buffered and deliberately left open: sys.stderr writes to it until exit.
    err_file = open(err_path, "a", buffering=1, encoding="utf-8")
    sys.stderr = Tee(sys.__stderr__, err_file)
    args.log_path = err_path

    set_seed(args.seed)
    args = load_default(args)

    os.makedirs(args.save_path, exist_ok=True)
    os.makedirs(f"{args.save_path}/imgs", exist_ok=True)

    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = args.device
    args.dsa = args.dsa_strategy not in ['none', 'None']
    dsa_params = ParamDiffAug()
    if args.dsa:
        args.dc_aug_param = None

    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.set_float32_matmul_precision("high")

    channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test, testloader, loader_train_dict, class_map, class_map_inv, zca_trans = get_dataset(args.dataset, args.data_path, args.batch_real, args.subset, args=args)
    args.channel, args.im_size, args.num_classes, args.mean, args.std = channel, im_size, num_classes, mean, std

    # Transform applied to every render through the output hook.
    if args.zca:
        train_transform = zca_trans
    else:
        train_transform = torchvision.transforms.Normalize(mean=mean, std=std)

    model_eval_pool = get_eval_pool(args.eval_mode, args.model, args.model)
    if args.eval_mode in ('S', 'SS'):
        if not hasattr(args, "eval_it") or args.eval_it == 0:
            eval_it_pool = np.arange(0, args.Iteration + 1, 2000).tolist()
        else:
            eval_it_pool = np.arange(args.eval_it, args.Iteration + 1, args.eval_it).tolist()
    else:
        eval_it_pool = [args.Iteration]

    save_and_print(args.log_path, f"Hyper-parameters: {args.__dict__}")
    save_and_print(args.log_path, f"Evaluation model pool: {model_eval_pool}")
    save_and_print(args.log_path, OmegaConf.to_yaml(args, resolve=True))

    boundary_loss = _make_boundary_loss(args)

    # gs_model methods whose "render" output gets train_transform applied.
    methods = ["forward", "forward_subset", "crop_forward_loop", "crop_forward_padding"]

    for exp in range(args.num_exp):
        save_and_print(args.log_path, f'\n================== Exp {exp} ==================\n ')
        save_and_print(args.log_path, f'Hyper-parameters: {args.__dict__}')

        ''' organize the real dataset '''
        save_and_print(args.log_path, "BUILDING DATASET")

        images_all, labels_all = build_tensor_dataset(dst_train, batch_size=args.batch_real, workers=args.workers, class_map=class_map)

        indices_class = [[] for _ in range(num_classes)]
        for i, lab in tqdm(enumerate(labels_all)):
            indices_class[lab].append(i)

        ''' initialize the synthetic data '''
        # Labels [0]*gpc + [1]*gpc + ... matching the per-class layout of the renders.
        syn_labels = torch.arange(num_classes, dtype=torch.long, device=device).repeat_interleave(args.gpc)
        gs_model, sample_index = get_initialized_gs_batch(num_classes, args.gaussian.batch_size, args.gs_dir, args.gs_type, args.gaussian, device=device, epochs=None)
        gs_model.requires_grad_(True)
        hook_handle = attach_output_transform_to_methods(
            gs_model, train_transform, methods, key="render"
        )

        optimizers, schedulers = create_optimizers_and_schedulers(gs_model, args)
        strategy = DefaultStrategy(args.strategy, gs_model, optimizers)

        scaler = torch.amp.GradScaler(enabled=(gs_model.precision == "fp16"))

        ''' training '''
        best_acc = {m: 0 for m in model_eval_pool}
        best_std = {m: 0 for m in model_eval_pool}

        save_and_print(args.log_path, '%s training begins' % get_time())

        for it in range(args.Iteration + 1):
            save_this_it = False

            ''' Evaluate synthetic data '''
            if it in eval_it_pool:
                gc.collect()
                torch.cuda.empty_cache()
                for model_eval in model_eval_pool:
                    save_and_print(args.log_path, '-------------------------\nEvaluation\nmodel_train = %s, model_eval = %s, iteration = %d' % (args.model, model_eval, it))

                    save_and_print(args.log_path, f'DSA augmentation strategy: {args.dsa_strategy}')
                    save_and_print(args.log_path, f'DSA augmentation parameters: {dsa_params.__dict__}')

                    accs_test = []
                    for it_eval in range(args.num_eval):
                        net_eval = get_network(model_eval, channel, num_classes, im_size).to(args.device)
                        with torch.inference_mode():
                            image_syn_eval = gs_model()["render"]
                        _, _, acc_test = evaluate_synset(it_eval, net_eval, image_syn_eval, syn_labels, testloader, args, dsa_param=dsa_params)
                        accs_test.append(acc_test)
                        # Free per-run tensors here (also safe when num_eval == 0).
                        del net_eval, image_syn_eval
                    accs_test = np.array(accs_test)
                    acc_test_mean = np.mean(accs_test)
                    acc_test_std = np.std(accs_test)
                    if acc_test_mean > best_acc[model_eval]:
                        best_acc[model_eval] = acc_test_mean
                        best_std[model_eval] = acc_test_std
                        save_this_it = True
                        torch.save({"best_acc": best_acc, "best_std": best_std}, f"{args.save_path}/best_performance.pt")
                    save_and_print(args.log_path, 'Evaluate %d random %s, mean = %.4f std = %.4f\n-------------------------' % (len(accs_test), model_eval, acc_test_mean, acc_test_std))
                    save_and_print(args.log_path, f"{args.save_path}")
                    save_and_print(args.log_path, f"{it:5d} | Accuracy/{model_eval}: {acc_test_mean}")
                    save_and_print(args.log_path, f"{it:5d} | Max_Accuracy/{model_eval}: {best_acc[model_eval]}")
                    save_and_print(args.log_path, f"{it:5d} | Std/{model_eval}: {acc_test_std}")
                    save_and_print(args.log_path, f"{it:5d} | Max_Std/{model_eval}: {best_std[model_eval]}")

                ''' visualize and save '''
                save_name = os.path.join(f"{args.save_path}/imgs", 'vis_%s_%s_%s_%dipc_exp%d_iter%d.png' % (args.method, args.dataset, args.model, args.ipc, exp, it))
                # no_grad: the preview render needs no autograd graph.
                with torch.no_grad():
                    image_syn_vis = gs_model()["render"].detach()
                    # Undo Normalize(mean, std) per channel, then clip to [0, 1].
                    # NOTE(review): when args.zca is set the renders are ZCA-whitened,
                    # so this inversion is presumably not exact -- confirm.
                    for ch in range(channel):
                        image_syn_vis[:, ch] = image_syn_vis[:, ch] * std[ch] + mean[ch]
                    image_syn_vis.clamp_(0.0, 1.0)
                save_image(image_syn_vis, save_name, nrow=10)
                del image_syn_vis

                if save_this_it:
                    hook_handle = _save_gs_checkpoint(
                        gs_model, args,
                        os.path.join(args.save_path, f"GSDD_TM_{args.ipc}ipc_best.pt"),
                        hook_handle, train_transform, methods,
                    )
                hook_handle = _save_gs_checkpoint(
                    gs_model, args,
                    os.path.join(args.save_path, f"GSDD_TM_{args.ipc}ipc_iter{it}.pt"),
                    hook_handle, train_transform, methods,
                )

                gc.collect()
                torch.cuda.empty_cache()

            ''' Train synthetic data '''
            net = get_network(args.model, channel, num_classes, im_size, dist=False).to(device)
            net.train()
            net.requires_grad_(False)

            # Use the selected device (was hard-coded to "cuda", breaking the CPU fallback).
            embed = EmbedWrapper(net).to(device)
            if torch.cuda.device_count() > 1:
                embed = nn.DataParallel(embed)

            ''' update synthetic data '''
            loss = torch.zeros((), device=device)
            for c in range(num_classes):
                loss_c = torch.zeros((), device=device)

                img_real = get_images(images_all, indices_class, c, args.batch_real).to(device)

                if args.batch_syn > 0:
                    indices = np.random.permutation(range(c * args.gpc, (c + 1) * args.gpc))[:args.batch_syn]
                else:
                    indices = range(c * args.gpc, (c + 1) * args.gpc)

                img_syn = gs_model()["render"][indices]

                if args.dsa:
                    # Same seed for both so real and synthetic get the same augmentation.
                    seed = int(time.time() * 1000) % 100000
                    img_real = DiffAugment(img_real, args.dsa_strategy, seed=seed, param=dsa_params)
                    img_syn = DiffAugment(img_syn, args.dsa_strategy, seed=seed, param=dsa_params)

                output_real = embed(img_real).detach()
                output_syn = embed(img_syn)

                # Squared distance between per-class mean embeddings.
                loss_c += torch.sum((torch.mean(output_real, dim=0) - torch.mean(output_syn, dim=0)) ** 2)
                loss_c += args.boundary_loss_lambda * boundary_loss(gs_model.params["xy"])

                for name in optimizers:
                    optimizers[name].zero_grad(set_to_none=True)

                scaler.scale(loss_c).backward()
                for name in optimizers:
                    scaler.unscale_(optimizers[name])
                # NOTE(review): despite its name, step_pre_backward runs after
                # backward() here -- confirm against DefaultStrategy.
                strategy.step_pre_backward(it)

                for name in optimizers:
                    scaler.step(optimizers[name])
                    if args.scheduler_gs.type is not None:
                        schedulers[name].step()

                scaler.update()

                if args.use_clamp:
                    gs_model.clamp()

                strategy.step_post_backward(it)

                # Detach so the running total does not chain every class's autograd graph.
                loss += loss_c.detach()

            loss_avg = loss.item() / num_classes

            if it % 10 == 0:
                save_and_print(args.log_path, '%s iter = %04d, loss = %.4f' % (get_time(), it, loss_avg))


if __name__ == '__main__':
    main()
