import copy
import functools
import os

import blobfile as bf
import torch as th
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel.distributed import DistributedDataParallel as DDP
from torch.optim import AdamW
import numpy as np
import wandb

from . import dist_util, logger
from .fp16_util import MixedPrecisionTrainer
from .nn import update_ema
from .resample import LossAwareSampler, UniformSampler

# Initial (logarithmic) fp16 loss scale. This worked well as a default for
# ImageNet experiments, where the lg_loss_scale was observed to climb quickly
# to about 20-21 within roughly the first 1K training steps.
INITIAL_LOG_LOSS_SCALE: float = 20.0


class TrainLoop:
    """Distributed training loop for the DDDM gated diffusion model.

    Operates in one of two modes, selected by ``if_finetune``:

    * ``if_finetune=False`` (gater training): only ``model.gater_head`` is
      handed to the optimizer. The diffusion loss is blended with the
      ``"size"`` term returned by ``diffusion.training_losses``, using
      ``loss_a_list[0]`` as the diffusion-loss weight.
    * ``if_finetune=True`` (fine-tuning): a separate gater network (``FNN``)
      is loaded from ``gater_path`` and kept in eval mode outside the
      optimizer. It computes per-sample gates while the full ``model`` is
      trained on the diffusion loss alone.
    """

    def __init__(
        self,
        *,
        model,
        diffusion,
        data,
        batch_size,
        microbatch,
        lr,
        ema_rate,
        log_interval,
        save_interval,
        resume_checkpoint,
        use_fp16=False,
        fp16_scale_growth=1e-3,
        schedule_sampler=None,
        weight_decay=0.0,
        lr_anneal_steps=0,
        loss_a_list=None,
        use_wandb="False",
        gater_net_threshold=0.49999,
        if_finetune=False,
        gater_path="",
    ):
        """Set up parameters, optimizer, EMA copies and (when CUDA exists) DDP.

        Args:
            model: the network to train. Gater training reads
                ``model.gater_head`` and ``model.expand_scale_list``.
                Fine-tuning reads ``model.t_spli_num``,
                ``model.expand_scale_list``, ``model.threshold`` and
                ``model.steps``.
            diffusion: object providing ``training_losses``; it is also passed
                to ``log_loss_dict``.
            data: iterator yielding ``(batch, cond)`` pairs.
            batch_size: per-rank batch size.
            microbatch: microbatch size for gradient accumulation. Values
                ``<= 0`` mean "use ``batch_size``".
            lr: base learning rate for AdamW.
            ema_rate: a float, or a comma-separated string of EMA decay rates.
            log_interval: dump logged key/values every this many steps.
            save_interval: checkpoint every this many steps.
            resume_checkpoint: checkpoint path to resume from, or empty. The
                resume step is parsed from ``modelNNNNNN`` in its name.
            use_fp16: enable mixed-precision training.
            fp16_scale_growth: forwarded to ``MixedPrecisionTrainer``.
            schedule_sampler: timestep sampler; defaults to ``UniformSampler``.
            weight_decay: AdamW weight decay.
            lr_anneal_steps: total steps for linear LR decay. 0 disables both
                annealing and the step limit, so the loop runs forever.
            loss_a_list: sequence whose first element ``a`` weights the
                diffusion loss in gater training; the size loss gets
                ``1 - a``. ``None`` (the default) means ``[0.99]``.
            use_wandb: ``"True"`` (string) or ``True`` enables wandb logging
                on rank 0.
            gater_net_threshold: forwarded to ``diffusion.training_losses``.
            if_finetune: select fine-tuning mode (see the class docstring).
            gater_path: state-dict path of the gater network. Required when
                ``if_finetune`` is true.
        """
        self.model = model
        self.diffusion = diffusion
        self.data = data
        self.batch_size = batch_size
        self.microbatch = microbatch if microbatch > 0 else batch_size
        self.lr = lr
        self.ema_rate = (
            [ema_rate]
            if isinstance(ema_rate, float)
            else [float(x) for x in ema_rate.split(",")]
        )
        self.log_interval = log_interval
        self.save_interval = save_interval
        self.resume_checkpoint = resume_checkpoint
        self.use_fp16 = use_fp16
        self.fp16_scale_growth = fp16_scale_growth
        self.schedule_sampler = schedule_sampler or UniformSampler(diffusion)
        self.weight_decay = weight_decay
        self.lr_anneal_steps = lr_anneal_steps

        # ---- DDDM-specific settings ----
        # Build the default per instance rather than sharing one mutable list
        # object between every TrainLoop created in the process.
        self.loss_a_list = [0.99] if loss_a_list is None else loss_a_list
        self.use_wandb = use_wandb
        self.gater_net_threshold = gater_net_threshold
        self.if_finetune = if_finetune
        if self.if_finetune:
            from .unet import FNN

            assert gater_path != "", "gater_path is required when if_finetune is set"
            # The gater network lives on the trainer, not inside self.model, so
            # the optimizer built below never updates it. It is only evaluated.
            self.gater_head = FNN(
                dims=[self.model.t_spli_num, 64, len(self.model.expand_scale_list)],
                threshold=self.model.threshold,
            )
            self.gater_head.load_state_dict(th.load(gater_path, map_location="cpu"))
            self.gater_head.to(dist_util.dev())
            self.gater_head.eval()

        self.step = 0
        self.resume_step = 0
        self.global_batch = self.batch_size * dist.get_world_size()

        self.sync_cuda = th.cuda.is_available()

        self._load_and_sync_parameters()
        # Fine-tuning optimizes the whole network; gater training only
        # optimizes the model's own gater head.
        trainable = self.model if self.if_finetune else self.model.gater_head
        self.mp_trainer = MixedPrecisionTrainer(
            model=trainable,
            use_fp16=self.use_fp16,
            fp16_scale_growth=fp16_scale_growth,
        )

        self.opt = AdamW(
            self.mp_trainer.master_params, lr=self.lr, weight_decay=self.weight_decay
        )
        if self.resume_step:
            self._load_optimizer_state()
            # Model was resumed, either due to a restart or a checkpoint
            # being specified at the command line.
            self.ema_params = [
                self._load_ema_parameters(rate) for rate in self.ema_rate
            ]
        else:
            self.ema_params = [
                copy.deepcopy(self.mp_trainer.master_params)
                for _ in range(len(self.ema_rate))
            ]

        if th.cuda.is_available():
            self.use_ddp = True
            self.ddp_model = DDP(
                self.model,
                device_ids=[dist_util.dev()],
                output_device=dist_util.dev(),
                broadcast_buffers=False,
                bucket_cap_mb=128,
                find_unused_parameters=False,
            )
        else:
            if dist.get_world_size() > 1:
                logger.warn(
                    "Distributed training requires CUDA. "
                    "Gradients will not be synchronized properly!"
                )
            self.use_ddp = False
            self.ddp_model = self.model

    def _load_and_sync_parameters(self):
        """Load model weights from the resume checkpoint (if any), then sync ranks.

        Loading is strict only when fine-tuning. Gater training tolerates
        missing/unexpected keys, presumably because its starting checkpoint
        may not contain gater-head weights (TODO confirm).
        """
        resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint

        if resume_checkpoint:
            strict_load = True if self.if_finetune else False
            self.resume_step = parse_resume_step_from_filename(resume_checkpoint)
            if dist.get_rank() == 0:
                logger.log(f"loading model from checkpoint: {resume_checkpoint}...")
            self.model.load_state_dict(
                dist_util.load_state_dict(
                    resume_checkpoint, map_location=dist_util.dev()
                ),
                strict=strict_load,
            )

        dist_util.sync_params(self.model.parameters())

    def _load_ema_parameters(self, rate):
        """Return EMA parameters for ``rate``, loaded from disk when available.

        Starts from a copy of the current master params. If
        ``find_ema_checkpoint`` locates a file, only rank 0 reads it. The
        result is then passed to ``dist_util.sync_params`` on every rank
        (presumably broadcasting rank 0's values).
        """
        ema_params = copy.deepcopy(self.mp_trainer.master_params)

        main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
        ema_checkpoint = find_ema_checkpoint(main_checkpoint, self.resume_step, rate)
        if ema_checkpoint:
            if dist.get_rank() == 0:
                logger.log(f"loading EMA from checkpoint: {ema_checkpoint}...")
                state_dict = dist_util.load_state_dict(
                    ema_checkpoint, map_location=dist_util.dev()
                )
                ema_params = self.mp_trainer.state_dict_to_master_params(state_dict)

        dist_util.sync_params(ema_params)
        return ema_params

    def _load_optimizer_state(self):
        """Restore AdamW state from ``optNNNNNN.pt`` next to the resume checkpoint.

        NOTE(review): ``save()`` writes ``opt_newest.pt``, so checkpoints
        produced by this class are never matched here and the optimizer
        restarts from scratch on resume.
        """
        main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
        opt_checkpoint = bf.join(
            bf.dirname(main_checkpoint), f"opt{self.resume_step:06}.pt"
        )
        if bf.exists(opt_checkpoint):
            logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}")
            state_dict = dist_util.load_state_dict(
                opt_checkpoint, map_location=dist_util.dev()
            )
            self.opt.load_state_dict(state_dict)

    def run_loop(self, single_label=-1):
        """Train until ``lr_anneal_steps`` is reached (forever when it is 0).

        Args:
            single_label: when not -1, every label in ``cond["y"]`` is
                replaced with this value before each step.
        """
        while (
            not self.lr_anneal_steps
            or self.step + self.resume_step < self.lr_anneal_steps
        ):
            batch, cond = next(self.data)
            if single_label != -1:
                # Same shape, dtype and device as the original labels.
                cond["y"] = th.full_like(cond["y"], single_label)
            self.run_step(batch, cond)
            if self.step % self.log_interval == 0:
                logger.dumpkvs()
            if (self.step % self.save_interval == 0) and (self.step != 0):
                self.save()
                # Run for a finite amount of time in integration tests.
                if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.step > 0:
                    return
            self.step += 1
        # Save the last checkpoint if it wasn't already saved.
        if (self.step - 1) % self.save_interval != 0:
            self.save()

    def run_step(self, batch, cond):
        """Run one optimization step: gradients, optimizer, EMA, LR, logging."""
        self.forward_backward(batch, cond)
        took_step = self.mp_trainer.optimize(self.opt)
        if took_step:
            self._update_ema()
        self._anneal_lr()
        self.log_step()

    def _wandb_enabled(self):
        """Return True when wandb logging was requested.

        Accepts the historical string flag ``"True"`` as well as a real bool.
        """
        return self.use_wandb is True or self.use_wandb == "True"

    def forward_backward(self, batch, cond):
        """Accumulate gradients for ``batch`` over microbatches.

        DDP gradient all-reduce is skipped (``no_sync``) for every microbatch
        except the last. In gater training the per-sample diffusion loss is
        blended with the ``"size"`` loss. In fine-tuning, gates come from the
        frozen gater network instead.
        """
        # Unconditional models have no "y" key; only log labels when present.
        if "y" in cond:
            logger.log("The first 5 labels used in this step is", cond["y"][:5])
        self.mp_trainer.zero_grad()
        for i in range(0, batch.shape[0], self.microbatch):
            micro = batch[i : i + self.microbatch].to(dist_util.dev())
            micro_cond = {
                k: v[i : i + self.microbatch].to(dist_util.dev())
                for k, v in cond.items()
            }
            last_batch = (i + self.microbatch) >= batch.shape[0]
            t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev())

            if self.if_finetune:
                gates = get_gate(
                    self.gater_head, t, self.model.t_spli_num, self.model.steps
                ).detach()
                gates = expand_gates(
                    gates,
                    self.model.expand_scale_list,
                    t,
                    self.model.t_spli_num,
                    self.model.steps,
                )
                expand_scale_list = None
            else:
                gates = None
                expand_scale_list = self.model.expand_scale_list

            compute_losses = functools.partial(
                self.diffusion.training_losses,
                self.ddp_model,
                micro,
                t,
                model_kwargs=micro_cond,
                gater_net_threshold=self.gater_net_threshold,
                expand_scale_list=expand_scale_list,
                if_finetune=self.if_finetune,
                gates=gates,
            )

            if last_batch or not self.use_ddp:
                losses = compute_losses()
            else:
                with self.ddp_model.no_sync():
                    losses = compute_losses()

            if isinstance(self.schedule_sampler, LossAwareSampler):
                self.schedule_sampler.update_with_local_losses(
                    t, losses["loss"].detach()
                )

            loss = (losses["loss"] * weights).mean()
            log_to_wandb = self._wandb_enabled() and dist.get_rank() == 0
            if not self.if_finetune:
                # Blend diffusion loss and gate-size loss with weights a and 1-a.
                diffusion_a = self.loss_a_list[0]
                size_a = 1 - diffusion_a
                sum_loss_a = diffusion_a + size_a
                size_loss = (losses["size"] * weights).mean()
                loss = (diffusion_a / sum_loss_a) * loss + (size_a / sum_loss_a) * size_loss

                if log_to_wandb:
                    # One call per microbatch: each wandb.log call advances
                    # wandb's step counter, so separate calls would scatter
                    # these related metrics across different steps.
                    wandb.log(
                        {
                            "Max gate size": losses["size"].max().item(),
                            "Min gate size": losses["size"].min().item(),
                            "Diffusion loss": losses["loss"].mean().item(),
                            "MSE loss": losses["mse"].mean().item(),
                            "VB loss": losses["vb"].mean().item(),
                        }
                    )
            elif log_to_wandb:
                wandb.log({"Diffusion loss": losses["loss"].mean().item()})

            self.mp_trainer.backward(loss)
            log_loss_dict(
                self.diffusion, t, {k: v * weights for k, v in losses.items()}
            )

    def _update_ema(self):
        """Update every EMA copy of the master parameters with its decay rate."""
        for rate, params in zip(self.ema_rate, self.ema_params):
            update_ema(params, self.mp_trainer.master_params, rate=rate)

    def _anneal_lr(self):
        """Linearly decay the LR to 0 over ``lr_anneal_steps`` (no-op when 0)."""
        if not self.lr_anneal_steps:
            return
        frac_done = (self.step + self.resume_step) / self.lr_anneal_steps
        lr = self.lr * (1 - frac_done)
        for param_group in self.opt.param_groups:
            param_group["lr"] = lr

    def log_step(self):
        """Record the global step and the number of samples seen so far."""
        logger.logkv("step", self.step + self.resume_step)
        logger.logkv("samples", (self.step + self.resume_step + 1) * self.global_batch)

    def save(self):
        """Write model, EMA and optimizer checkpoints to ``get_blob_logdir()``.

        Rank 0 writes:

        * ``modelNNNNNN.pt``: master parameters at the current global step.
        * ``ema_snewest.pt``: EMA parameters, overwritten on every save.
        * ``opt_newest.pt``: optimizer state, overwritten on every save.

        All ranks then wait on ``dist.barrier()``.
        """

        def save_checkpoint(rate, params):
            state_dict = self.mp_trainer.master_params_to_state_dict(params)
            if dist.get_rank() == 0:
                logger.log(f"saving model {rate}...")
                if not rate:
                    filename = f"model{(self.step+self.resume_step):06d}.pt"
                else:
                    # NOTE(review): every EMA rate is written to this same file,
                    # so with several rates only the last survives. The name
                    # also differs from the ema_{rate}_{step}.pt pattern that
                    # find_ema_checkpoint() searches for on resume.
                    filename = "ema_snewest.pt"
                with bf.BlobFile(bf.join(get_blob_logdir(), filename), "wb") as f:
                    th.save(state_dict, f)

        save_checkpoint(0, self.mp_trainer.master_params)
        for rate, params in zip(self.ema_rate, self.ema_params):
            save_checkpoint(rate, params)

        if dist.get_rank() == 0:
            with bf.BlobFile(
                bf.join(get_blob_logdir(), "opt_newest.pt"),
                "wb",
            ) as f:
                th.save(self.opt.state_dict(), f)

        dist.barrier()


def parse_resume_step_from_filename(filename):
    """
    Parse filenames of the form path/to/modelNNNNNN.pt, where NNNNNN is the
    checkpoint's number of steps.

    Only the final path component is inspected. Directory names that happen
    to contain "model" (e.g. ``runs/model2.5/ema.pt``) therefore cannot be
    mistaken for a step count. Returns 0 when no step number can be parsed.
    """
    basename = os.path.basename(filename)
    split = basename.split("model")
    if len(split) < 2:
        return 0
    split1 = split[-1].split(".")[0]
    try:
        return int(split1)
    except ValueError:
        return 0


def get_blob_logdir():
    """Return the directory that checkpoints are written to.

    By default this is the active logger's output directory. Override it to
    redirect checkpoints to a blob store or another external drive.
    """
    log_dir = logger.get_dir()
    return log_dir


def find_resume_checkpoint():
    """Hook for automatically discovering a checkpoint to resume from.

    The default performs no discovery and returns ``None``. Callers in this
    file then fall back to the explicitly configured ``resume_checkpoint``.
    Override it to, e.g., locate the newest checkpoint on blob storage.
    """
    discovered = None
    return discovered


def find_ema_checkpoint(main_checkpoint, step, rate):
    """Locate the EMA checkpoint stored next to ``main_checkpoint``.

    Looks for ``ema_{rate}_{step:06d}.pt`` in the main checkpoint's directory.

    Returns:
        The path if that file exists. Otherwise ``None``, which is also
        returned when ``main_checkpoint`` is ``None``.
    """
    if main_checkpoint is not None:
        candidate = bf.join(bf.dirname(main_checkpoint), f"ema_{rate}_{(step):06d}.pt")
        if bf.exists(candidate):
            return candidate
    return None


def log_loss_dict(diffusion, ts, losses):
    """Record each loss term's mean plus per-timestep-bucket means.

    For every key, ``logger.logkv_mean(key, ...)`` receives the batch mean.
    Each sample's value is additionally averaged into ``{key}_q{k}``, where
    ``k = int(4 * t / diffusion.num_timesteps)`` for that sample's timestep
    ``t``.
    """
    timesteps = ts.cpu().numpy()
    for name, per_sample in losses.items():
        logger.logkv_mean(name, per_sample.mean().item())
        sample_values = per_sample.detach().cpu().numpy()
        for step_t, value in zip(timesteps, sample_values):
            bucket = int(4 * step_t / diffusion.num_timesteps)
            logger.logkv_mean(f"{name}_q{bucket}", value)


def get_scales(gater_head, t_spli_num):
    """Sum the gater's outputs over every timestep group.

    Each group ``0 .. t_spli_num - 1`` is one-hot encoded and passed through
    ``gater_head`` on ``dist_util.dev()`` under ``th.no_grad()``. As a side
    effect, the head is switched to eval mode. Outputs are summed along
    dim 0, i.e. across groups. If the gates are binary, this counts how many
    groups enable each column.

    Returns:
        A Python list built from the summed tensor (one entry per gate
        column, assuming the head returns ``(groups, gates)`` — TODO confirm).
    """
    group_ids = th.tensor([*range(t_spli_num)])
    encoded = nn.functional.one_hot(group_ids, num_classes=t_spli_num)

    with th.no_grad():
        gater_head.eval()
        gate_matrix = gater_head(encoded.float().to(dist_util.dev()))
        column_totals = gate_matrix.sum(dim=0)
        return list(column_totals.cpu().numpy())


def get_gate(gater_head, timesteps, t_spli_num, steps):
    """Evaluate ``gater_head`` on the timestep group of each sample.

    Each timestep is mapped to group ``timesteps * t_spli_num / steps``,
    truncated to int64. The group is one-hot encoded over ``t_spli_num``
    classes and fed to ``gater_head`` on ``dist_util.dev()`` under
    ``th.no_grad()``. The head is switched to eval mode as a side effect.
    """
    group_ids = (timesteps * t_spli_num / steps).type(th.int64)
    encoded = nn.functional.one_hot(group_ids, num_classes=t_spli_num)
    with th.no_grad():
        gater_head.eval()
        return gater_head(encoded.float().to(dist_util.dev()))


def expand_gates(gates, expand_scale_list, t, t_spli_num, total_steps):
    """Scatter per-group gate values into a flattened, timestep-positioned layout.

    Column ``j`` of ``gates`` corresponds to ``expand_scale_list[j]``. The
    output reserves ``expand_scale_list[j]`` consecutive slots for group
    ``j``. For sample ``i``, only the slot at offset
    ``int(t[i] / total_steps * expand_scale_list[j])`` inside that group
    receives ``gates[i, j]``; every other slot stays 0.

    Args:
        gates: 2-D tensor of shape ``(batch, len(expand_scale_list))``.
        expand_scale_list: number of slots reserved for each gate column.
        t: per-sample timesteps (tensor or sequence with ``batch`` entries).
        t_spli_num: unused; kept for interface compatibility.
        total_steps: timestep range used to position each gate in its group.

    Returns:
        Float tensor of shape ``(batch, sum(expand_scale_list))`` on
        ``gates.device``.
    """
    assert len(expand_scale_list) == gates.shape[1]

    device = gates.device
    batch = gates.shape[0]
    # Builtin sum keeps an int even for an empty list (np.sum([]) is 0.0,
    # which th.zeros rejects as a dimension).
    target_gates = th.zeros((batch, int(sum(expand_scale_list))), device=device)
    if batch == 0 or len(expand_scale_list) == 0:
        return target_gates

    scales = th.as_tensor(list(expand_scale_list), device=device)
    # Start column of each group inside the flattened output.
    offsets = th.cumsum(scales, dim=0) - scales
    # Same float32 arithmetic and truncation as int(t[i] / total_steps * w),
    # done for the whole batch at once instead of one host sync per element.
    frac = th.as_tensor(t, device=device).reshape(-1)[:batch] / total_steps
    locations = (frac[:, None] * scales[None, :]).long()
    # A timestep at (or beyond) total_steps would otherwise land in the next
    # group's slots or past the end of the tensor; keep it in its own group.
    locations = th.minimum(locations, scales - 1).clamp(min=0)

    rows = th.arange(batch, device=device)[:, None]
    target_gates[rows, offsets[None, :] + locations] = gates.to(target_gates.dtype)
    return target_gates


def gates2index(gates, t, expand_scale_list, total_steps, se=True):
    """Convert binary per-group gates into per-sample slot indices.

    For sample ``i`` and gate column ``j``, the slot index is
    ``int(t[i] / total_steps * expand_scale_list[j])``. Column ``j``
    contributes ``[index]`` when ``gates[i, j] == 1`` and ``[]`` otherwise.

    Args:
        gates: 2-D array/tensor of shape ``(batch, len(expand_scale_list))``.
        t: per-sample timesteps, indexable by sample position.
        expand_scale_list: number of slots per gate column.
        total_steps: timestep range used to position each index.
        se: when true (the default), return the per-sample lists in reverse
            sample order; otherwise keep the input order.

    Returns:
        A list with one entry per sample, each a list with one entry per gate
        column.
    """
    assert len(expand_scale_list) == gates.shape[1]

    per_sample = []
    for row in range(gates.shape[0]):
        frac = t[row] / total_steps
        slots = []
        for col, width in enumerate(expand_scale_list):
            slot = int(frac * width)
            slots.append([slot] if gates[row, col] == 1 else [])
        per_sample.append(slots)

    # The reversed order is presumably the sampling order (largest t first)
    # when rows are consecutive timesteps. TODO confirm with callers.
    return per_sample[::-1] if se else per_sample


# NOTE: Disabled legacy implementation, kept only for reference. It is wrapped
# in a module-level string literal, so nothing below is executed or importable.
# It thresholded the gater's outputs group by group (per entry of
# expand_scale_list) against ``threshold``, cast the result to int, and then
# forced at least one active slot in the first and in the last
# ``expand_scale_list[0]`` columns, positioned by the sample's timestep.
'''
def get_gate_(gater_head, timesteps, t_spli_num, steps, expand_scale_list, threshold):
        import torch.nn as nn
        feature_t = (timesteps * t_spli_num / steps).type(th.int64)
        one_hot_t = nn.functional.one_hot(feature_t, num_classes=t_spli_num)
        with th.no_grad():
            gater_head.eval()
            gates = gater_head(one_hot_t.float())
            len_gate = gates.shape[1]
            p = 0
            for exp_scale in expand_scale_list:
                start_id = p
                end_id = start_id + exp_scale
                p = end_id

                for i in range(gates.shape[0]):
                    scale_ = th.max(gates[i, start_id: end_id])

                    if scale_ > threshold:
                        gates[i, start_id: end_id] = gates[i, start_id: end_id] // scale_
                    else:
                        gates[i, start_id: end_id] = gates[i, start_id: end_id] * 0
            
            assert p == len_gate

        gates = gates.int()
        io_expand_scale = expand_scale_list[0]
        for i in range(gates.shape[0]):
            if gates[i, 0:io_expand_scale].sum() == 0:
                t_scale = timesteps[i] / steps
                location = int(t_scale * io_expand_scale)
                gates[i, location] = 1
            if gates[i, -io_expand_scale:].sum() == 0:
                t_scale = timesteps[i] / steps
                location = int(t_scale * io_expand_scale)
                gates[i, -(io_expand_scale-location)] = 1

        return gates
'''