import math
import sys
from typing import Iterable, Optional

import torch
import torch.nn.functional as F
import numpy as np

from timm.utils import ModelEma

import my_utils as utils


def get_next_cond(c_pos_indices, z_pos_indices):
    """For each target point, return the next condition point along the first coordinate.

    The lookup is ``torch.searchsorted(..., right=True)`` on coordinate 0 only. For a target
    value ``v`` it finds the index ``i`` with ``c[i-1, 0] <= v < c[i, 0]``, i.e. the first
    condition point whose first coordinate is strictly greater. Targets at or beyond the last
    condition point are clamped to the last condition entry.

    Args:
        c_pos_indices: condition positions indexed as ``[batch, point, coord]``. They must be
            sorted along dim 1 by coordinate 0 (``sort`` in this file produces that order).
        z_pos_indices: target positions with the same layout and batch size.

    Returns:
        A tensor of shape ``(batch, z_points, coord)`` whose rows are gathered from
        ``c_pos_indices``. If there are no target points, a clone of ``z_pos_indices`` is
        returned instead.
    """
    # Nothing to look up: skip the search/gather work entirely.
    if z_pos_indices.shape[1] == 0:
        return z_pos_indices.clone()

    num_cond = c_pos_indices.shape[1]
    # right=True makes equal values map past their match, giving a half-open interval
    # [c[i-1], c[i]).
    next_ids = torch.searchsorted(c_pos_indices[:, :, 0].contiguous(),
                                  z_pos_indices[:, :, 0].contiguous(),
                                  right=True)
    # searchsorted returns num_cond for values past the end; clamp to the last valid entry.
    next_ids = next_ids.clamp(max=num_cond - 1)
    # Broadcast the per-point index over every coordinate so gather copies whole rows.
    next_ids = next_ids[:, :, None].expand(-1, -1, c_pos_indices.shape[-1])
    return torch.gather(c_pos_indices, dim=1, index=next_ids)


def get_extra_indices(Xct_pos, Xbd_pos):
    """Build the auxiliary position tensor for the concatenated [condition, target] sequence.

    The condition part is an unchanged copy of ``Xct_pos``. The target part replaces every
    ``Xbd_pos`` entry with its next condition point (see ``get_next_cond``). That lookup
    excludes the last condition entry; in ``train_batch`` this entry is the end sentinel that
    ``random_mask`` appends.

    Returns:
        ``Xct_pos`` and the looked-up target positions concatenated along dim 1.
    """
    condition_part = Xct_pos.clone()
    target_part = get_next_cond(Xct_pos[:, :-1, :], Xbd_pos)
    return torch.cat((condition_part, target_part), dim=1)


def random_mask(quan_centers, encodings, device, center_end_value=256, latent_end_value=1024):
    """Randomly keep a sorted subset of points and append end-of-sequence sentinels.

    The same random subset of point indices (dim 1) is kept for every batch element. When
    there are ``N >= 2`` points, between 1 and ``N - 1`` points are kept: ``np.random.randint``
    has an exclusive upper bound, so at least one point is always dropped. When ``N <= 1`` no
    proper subset exists (``randint(1, N)`` would raise), so all points are kept.

    Args:
        quan_centers: positions indexed as ``[batch, point, coord]``.
        encodings: per-point codes indexed as ``[batch, point]``.
        device: device the sentinel tensors are moved to.
        center_end_value: value used for every coordinate of the appended end position.
        latent_end_value: value of the appended end code.

    Returns:
        ``(quan_centers, encodings, selected_ind)``.
        - The first two have the kept points followed by one sentinel entry.
        - ``selected_ind`` is the sorted numpy array of kept indices, which excludes the sentinel.
    """
    center_end = torch.tensor(center_end_value).expand(
        quan_centers.shape[0], 1, quan_centers.shape[-1]).to(device, non_blocking=True)
    latent_end = torch.tensor(latent_end_value).expand(
        encodings.shape[0], 1).to(device, non_blocking=True)

    max_num = quan_centers.shape[1]
    if max_num <= 1:
        # Degenerate input: keep everything rather than crash in np.random.randint.
        select_num = max_num
    else:
        select_num = np.random.randint(1, max_num)
    # Sorting keeps the kept points in their original (already sorted) order.
    selected_ind = np.sort(np.random.choice(max_num, select_num, replace=False))

    quan_centers = torch.cat([quan_centers[:, selected_ind, ...], center_end], dim=1)
    encodings = torch.cat([encodings[:, selected_ind], latent_end], dim=1)

    return quan_centers, encodings, selected_ind


def sort(centers_quantized, encodings):
    """Sort points lexicographically by coordinates (0, 1, 2), permuting encodings alongside.

    The sort is done key by key from least to most significant (coordinate 2, then 1, then 0).
    Every pass is stable, so ties on a more significant key keep the order established by the
    less significant ones. Points with identical coordinates keep their input order, which
    makes the result deterministic.

    Args:
        centers_quantized: positions indexed as ``[batch, point, coord]``; only coords 0-2 are keys.
        encodings: per-point values indexed as ``[batch, point]``.

    Returns:
        ``(centers_quantized, encodings)`` reordered along dim 1.
    """
    num_coords = centers_quantized.shape[-1]
    for axis in (2, 1, 0):
        _, order = torch.sort(centers_quantized[:, :, axis], dim=1, stable=True)
        centers_quantized = torch.gather(centers_quantized, 1,
                                         order[:, :, None].expand(-1, -1, num_coords))
        encodings = torch.gather(encodings, 1, order)
    return centers_quantized, encodings


def train_batch(model, vqvae, Xbd, Xct, criterion, device):
    """Run one forward pass of the conditional autoregressive model and compute its losses.

    Processing steps:
    - Both point sets are encoded by ``vqvae.encode`` without gradients.
    - Each set is sorted by quantized position.
    - The condition set ``Xct`` is randomly subsampled by ``random_mask``, which also appends
      end sentinels.
    - The model receives the concatenated [condition, target] sequence. It predicts the
      x/y/z coordinates and the latent codebook index of every ``Xbd`` point.

    Returns:
        ``(loss, loss_x, loss_y, loss_z, loss_latent)``: ``loss`` is the summed loss tensor;
        the other four are Python numbers obtained via ``.item()``.
    """
    with torch.no_grad():
        _, _, ct_centers, _, _, ct_codes = vqvae.encode(Xct)
        _, _, bd_centers, _, _, bd_codes = vqvae.encode(Xbd)

    ct_centers, ct_codes = sort(ct_centers, ct_codes)
    bd_centers, bd_codes = sort(bd_centers, bd_codes)

    ct_centers, ct_codes, kept = random_mask(ct_centers, ct_codes, device)
    # Number of condition points kept (the appended end sentinel is not counted).
    num_kept = len(kept)

    all_centers = torch.cat((ct_centers, bd_centers), dim=1)
    all_codes = torch.cat((ct_codes, bd_codes), dim=1)

    # NOTE(review): the original author flagged this call as possibly problematic; verify
    # the extra indices are what the model expects.
    extra = get_extra_indices(ct_centers, bd_centers)

    x_logits, y_logits, z_logits, latent_logits = model(all_centers, extra, all_codes, num_kept)

    # Coordinate losses in x, y, z order, each against the matching target column.
    coord_losses = [criterion(logits, bd_centers[:, :, axis])
                    for axis, logits in enumerate((x_logits, y_logits, z_logits))]
    loss_latent = criterion(latent_logits, bd_codes)
    loss = coord_losses[0] + coord_losses[1] + coord_losses[2] + loss_latent

    return (loss, *(part.item() for part in coord_losses), loss_latent.item())


def _apply_schedules(optimizer, it, lr_schedule_values, wd_schedule_values):
    """Set every param group's lr (times its ``lr_scale``) and positive weight decay for step ``it``."""
    for param_group in optimizer.param_groups:
        if lr_schedule_values is not None:
            param_group["lr"] = lr_schedule_values[it] * param_group["lr_scale"]
        if wd_schedule_values is not None and param_group["weight_decay"] > 0:
            param_group["weight_decay"] = wd_schedule_values[it]


def _lr_range(optimizer):
    """Return ``(min_lr, max_lr)`` over the param groups, starting from sentinels 10.0 and 0.0."""
    min_lr = 10.
    max_lr = 0.
    for group in optimizer.param_groups:
        min_lr = min(min_lr, group["lr"])
        max_lr = max(max_lr, group["lr"])
    return min_lr, max_lr


def _last_weight_decay(optimizer):
    """Return the weight decay of the last param group whose decay is positive, else None."""
    weight_decay_value = None
    for group in optimizer.param_groups:
        if group["weight_decay"] > 0:
            weight_decay_value = group["weight_decay"]
    return weight_decay_value


def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, vqvae: torch.nn.Module,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
                    model_ema: Optional[ModelEma] = None, log_writer=None,
                    start_steps=None, lr_schedule_values=None, wd_schedule_values=None,
                    num_training_steps_per_epoch=None, update_freq=None):
    """Train ``model`` for one epoch with AMP autocast and gradient accumulation.

    Each batch from ``data_loader`` is unpacked as ``(_, _, Xbd, Xct)``, and ``train_batch``
    computes the loss. The loss is divided by ``update_freq``. ``loss_scaler`` is called with
    ``update_grad=True`` only on the last micro-batch of each window of ``update_freq``
    batches; presumably it steps the optimizer then, as timm's NativeScaler does (confirm for
    the scaler in use).

    Args:
        start_steps: global iteration index of this epoch's first optimizer step; None means 0.
        lr_schedule_values: per-iteration lr schedule indexed by global step, or None.
        wd_schedule_values: per-iteration weight-decay schedule indexed by global step, or None.
        num_training_steps_per_epoch: cap on optimizer steps this epoch. Batches beyond the cap
            are skipped; None means no cap.
        update_freq: number of micro-batches per optimizer step; None means 1.

    Returns:
        dict mapping each logged metric name to its global average over the epoch.

    Raises:
        NotImplementedError: if ``loss_scaler`` is None and at least one batch is processed.
    """
    # These defaults previously crashed in the step arithmetic below.
    if update_freq is None:
        update_freq = 1
    if start_steps is None:
        start_steps = 0

    model.train(True)
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    metric_logger.add_meter('min_lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 10

    if loss_scaler is None:
        model.zero_grad()
        model.micro_steps = 0
    else:
        optimizer.zero_grad()

    for data_iter_step, (_, _, Xbd, Xct) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
        step = data_iter_step // update_freq
        if num_training_steps_per_epoch is not None and step >= num_training_steps_per_epoch:
            continue
        it = start_steps + step  # global training iteration
        # Update LR & WD only on the first micro-batch of each accumulation window. The
        # parentheses fix the original `a or b and c` precedence, which ran this on every
        # micro-batch whenever an lr schedule was given.
        if (lr_schedule_values is not None or wd_schedule_values is not None) \
                and data_iter_step % update_freq == 0:
            _apply_schedules(optimizer, it, lr_schedule_values, wd_schedule_values)

        Xbd = Xbd.to(device, non_blocking=True)
        Xct = Xct.to(device, non_blocking=True)

        if loss_scaler is None:
            raise NotImplementedError
        with torch.cuda.amp.autocast():
            loss, loss_x, loss_y, loss_z, loss_latent = train_batch(model, vqvae, Xbd, Xct, criterion, device)

        loss_value = loss.item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        # this attribute is added by timm on one optimizer (adahessian)
        is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
        # Scale down so gradients accumulated over update_freq micro-batches are averaged.
        loss /= update_freq
        is_update_step = (data_iter_step + 1) % update_freq == 0
        grad_norm = loss_scaler(loss, optimizer, clip_grad=max_norm,
                                parameters=model.parameters(), create_graph=is_second_order,
                                update_grad=is_update_step)
        if is_update_step:
            optimizer.zero_grad()
            if model_ema is not None:
                model_ema.update(model)
        loss_scale_value = loss_scaler.state_dict()["scale"]

        torch.cuda.synchronize()

        metric_logger.update(loss=loss_value)
        metric_logger.update(loss_scale=loss_scale_value)
        metric_logger.update(loss_x=loss_x)
        metric_logger.update(loss_y=loss_y)
        metric_logger.update(loss_z=loss_z)
        metric_logger.update(loss_latent=loss_latent)

        min_lr, max_lr = _lr_range(optimizer)
        metric_logger.update(lr=max_lr)
        metric_logger.update(min_lr=min_lr)
        weight_decay_value = _last_weight_decay(optimizer)
        metric_logger.update(weight_decay=weight_decay_value)
        metric_logger.update(grad_norm=grad_norm)

        if log_writer is not None:
            log_writer.update(loss=loss_value, head="loss")
            log_writer.update(loss_scale=loss_scale_value, head="opt")
            log_writer.update(lr=max_lr, head="opt")
            log_writer.update(min_lr=min_lr, head="opt")
            log_writer.update(weight_decay=weight_decay_value, head="opt")
            log_writer.update(grad_norm=grad_norm, head="opt")

            log_writer.set_step()

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}


@torch.no_grad()
def evaluate(data_loader, model, vqvae, device):
    """Evaluate ``model`` on ``data_loader`` and return averaged loss metrics.

    Each batch is unpacked as ``(_, _, surface, categories)``. The surface is encoded by
    ``vqvae.encode`` and sorted by quantized position; coordinate and latent NLL losses are then
    computed under autocast. Gradients are disabled for the whole call by the decorator.

    Returns:
        dict mapping each logged metric name to its global average.
    """
    # NOTE(review): NLLLoss expects log-probabilities. This presumes the model emits
    # log-softmax outputs; confirm it matches the training criterion.
    criterion = torch.nn.NLLLoss()

    metric_logger = utils.MetricLogger(delimiter="  ")
    header = 'Test:'

    # switch to evaluation mode
    model.eval()

    for batch in metric_logger.log_every(data_loader, 1000, header):
        _, _, surface, categories = batch
        surface = surface.to(device, non_blocking=True)
        categories = categories.to(device, non_blocking=True)

        # compute output
        with torch.cuda.amp.autocast():
            _, _, centers_quantized, _, _, encodings = vqvae.encode(surface)
            centers_quantized, encodings = sort(centers_quantized, encodings)

            # NOTE(review): train_batch calls the model as
            # (centers, extra, encodings, selected_pos), and the training loader yields
            # (_, _, Xbd, Xct). This call and the batch layout above look like they come from a
            # different model; confirm against the current forward signature.
            x_logits, y_logits, z_logits, latent_logits = model(centers_quantized, encodings, categories)

            loss_x = criterion(x_logits, centers_quantized[:, :, 0])
            loss_y = criterion(y_logits, centers_quantized[:, :, 1])
            loss_z = criterion(z_logits, centers_quantized[:, :, 2])

            loss_latent = criterion(latent_logits, encodings)
            loss = loss_x + loss_y + loss_z + loss_latent

        metric_logger.update(loss=loss.item())
        metric_logger.update(loss_x=loss_x.item())
        metric_logger.update(loss_y=loss_y.item())
        metric_logger.update(loss_z=loss_z.item())
        metric_logger.update(loss_latent=loss_latent.item())
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print('* loss {losses.global_avg:.3f} '
          .format(losses=metric_logger.loss))
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
