import os
import h5py
import logging
import numpy as np
from time import time

import torch
import torch.optim as optim
from torch.optim import optimizer
from torch.utils.tensorboard import SummaryWriter

from syn_lib.model import sequenceNetwork,HOINetwork
from syn_lib.model.pointbert.pointbert import PointTransformer
from syn_lib.data import ArtIMageDataset
from syn_lib import utils
from syn_lib.utils import AvgRecorder, NetworkType
from tools.utils import io
import torch.nn as nn

def custom_collate_fn(batch):
    """Collate ``(input, gt, index)`` samples into one batch.

    Inputs and indices go through PyTorch's ``default_collate``. The
    ground-truth dicts are collated key by key; when the values of a key
    cannot be collated (for example tensors whose shapes differ between
    samples) that key keeps a plain Python list with one entry per sample.

    Args:
        batch: Sequence of ``(input, gt_dict, index)`` tuples.

    Returns:
        Tuple ``(input_batch, gt_dict_batch, index_batch)``.
    """
    from torch.utils.data._utils.collate import default_collate

    input_batch, gt_batch, index = zip(*batch)
    input_data_batch = default_collate(input_batch)
    index = default_collate(index)

    gt_dict_batch = {}
    # The keys of the first sample define which ground-truth fields are batched.
    for key in gt_batch[0]:
        values = [gt[key] for gt in gt_batch]
        try:
            gt_dict_batch[key] = default_collate(values)
        except Exception:
            # Deliberate best-effort fallback: keep the raw per-sample values.
            # ``Exception`` (not a bare ``except``) so that KeyboardInterrupt
            # and SystemExit still propagate.
            gt_dict_batch[key] = values

    return input_data_batch, gt_dict_batch, index

class artimageTrainer:
    def __init__(self, cfg, dataPath, network_type):
        """Build the model, optimizer and data loaders for a run.

        Args:
            cfg: Run configuration. This method reads ``cfg.device``,
                ``cfg.eval_only`` and ``cfg.network.{max_epochs, lr}``;
                other methods read further entries.
            dataPath: Mapping whose ``"train"`` / ``"test"`` entries are
                handed to ``ArtIMageDataset`` by ``init_data_loader``.
            network_type: Value compared against ``NetworkType.VQVAE`` in
                ``build_model`` to choose the network class.
        """
        self.cfg = cfg
        self.device = torch.device(cfg.device)
        self.dataPath = dataPath
        self.max_epochs = cfg.network.max_epochs
        self.network_type = network_type

        self.model = self.build_model()
        # Wrap for multi-GPU execution whenever more than one CUDA device is
        # visible; train_epoch / eval_epoch reach the wrapped network via
        # ``.module`` in that case.
        if torch.cuda.device_count() > 1:
            self.model = nn.DataParallel(self.model)
        self.model.to(self.device)

        # The logger must exist before init_data_loader, which logs loader sizes.
        self.log = logging.getLogger("Network")
        self.log.info(f"Using device {self.device}")

        self.optimizer = optim.Adam(
            self.model.parameters(), lr=cfg.network.lr, betas=(0.9, 0.99)
        )

        # Populated later: writer by train(), loaders by init_data_loader(),
        # test_result by eval_epoch(save_results=True).
        self.writer = None
        self.train_loader = None
        self.test_loader = None
        self.test_result = None
        self.init_data_loader(self.cfg.eval_only)
    
    def build_model(self):
        """Instantiate the network selected by ``self.network_type``.

        Returns:
            A ``sequenceNetwork`` when the type is ``NetworkType.VQVAE``,
            otherwise a ``HOINetwork``; both are constructed with
            ``cfg=self.cfg``.
        """
        use_sequence_net = self.network_type == NetworkType.VQVAE
        network_cls = sequenceNetwork if use_sequence_net else HOINetwork
        return network_cls(cfg=self.cfg)
    
    def init_data_loader(self, eval_only):
        """Create the test loader and, unless ``eval_only``, the train loader.

        Both loaders wrap an ``ArtIMageDataset`` and use PyTorch's default
        collation (``custom_collate_fn`` is not wired in).

        Args:
            eval_only: When truthy, ``self.train_loader`` is left untouched
                and only ``self.test_loader`` is built.
        """
        net_cfg = self.cfg.network
        num_workers = net_cfg.num_workers
        shared_kwargs = dict(
            batch_size=net_cfg.batch_size,
            num_workers=num_workers,
            pin_memory=True,
            # DataLoader raises ValueError for persistent_workers=True when
            # num_workers is 0, so only request it when workers exist.
            persistent_workers=num_workers > 0,
        )

        if not eval_only:
            self.train_loader = torch.utils.data.DataLoader(
                ArtIMageDataset(
                    self.dataPath["train"], numPoints=net_cfg.num_points
                ),
                shuffle=True,
                **shared_kwargs,
            )
            self.log.info(f'Num {len(self.train_loader)} batches in train loader')

        self.test_loader = torch.utils.data.DataLoader(
            ArtIMageDataset(
                self.dataPath["test"], numPoints=net_cfg.num_points
            ),
            shuffle=False,
            **shared_kwargs,
        )
        self.log.info(f'Num {len(self.test_loader)} batches in test loader')
        
    def train_epoch(self, epoch):
        """Run one optimisation pass over ``self.train_loader``.

        For every batch: move inputs and ground truth to ``self.device``, run
        the model, combine the entries of the model's loss dict using
        ``cfg.network.loss_weight``, and take one optimizer step. Per-batch
        and per-epoch scalars go to ``self.writer`` when one is attached.

        Args:
            epoch: Index of the current epoch (used for logging, TensorBoard
                steps and the remaining-time estimate).

        Raises:
            ValueError: If the loss dict contains a key that has no entry in
                ``cfg.network.loss_weight``.
        """
        self.log.info(f'>>>>>>>>>>>>>>>> Train Epoch {epoch} >>>>>>>>>>>>>>>>')
        self.model.train()

        # losses() lives on the underlying network; unwrap DataParallel by
        # checking the actual wrapper rather than the visible GPU count.
        core_model = self.model.module if isinstance(self.model, nn.DataParallel) else self.model
        loss_weight = self.cfg.network.loss_weight
        num_batches = len(self.train_loader)

        iter_time = AvgRecorder()
        io_time = AvgRecorder()
        to_gpu_time = AvgRecorder()
        network_time = AvgRecorder()
        start_time = time()
        end_time = time()
        remain_time = ''

        epoch_loss = {
            'total_loss': AvgRecorder()
        }

        for i, (inputDict, gt_dict, _) in enumerate(self.train_loader):
            io_time.update(time() - end_time)

            # Move the tensors to the device; non-tensor ground truth entries
            # are passed through unchanged.
            s_time = time()
            inputDict = {key: value.to(self.device, non_blocking=True) for key, value in inputDict.items()}
            gt = {
                k: v.to(self.device, non_blocking=True) if torch.is_tensor(v) else v
                for k, v in gt_dict.items()
            }
            to_gpu_time.update(time() - s_time)

            s_time = time()
            pred = self.model(inputDict, gt)
            loss_dict = core_model.losses(pred, gt)
            network_time.update(time() - s_time)

            loss = torch.tensor(0.0, device=self.device)
            for k, v in loss_dict.items():
                if k not in loss_weight:
                    raise ValueError(f"No loss weight for {k}")
                loss += loss_weight[k] * v

            # Record plain floats. Feeding the live tensors into the recorders
            # could keep each iteration's autograd history reachable through
            # their running sums (AvgRecorder's internals are not visible here).
            for k, v in loss_dict.items():
                if k not in epoch_loss:
                    epoch_loss[k] = AvgRecorder()
                epoch_loss[k].update(float(loss_weight[k] * v))
            epoch_loss['total_loss'].update(float(loss))

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            if self.writer is not None:
                # Step counts batches seen so far, not epoch * batch_size.
                global_step = epoch * num_batches + i
                for k, v in loss_dict.items():
                    self.writer.add_scalar(f"batchLoss_{epoch}/{k}", float(v), global_step)

            # train() iterates epochs 0..max_epochs inclusive, hence the +1.
            current_iter = epoch * num_batches + i + 1
            max_iter = (self.max_epochs + 1) * num_batches
            remain_iter = max_iter - current_iter

            iter_time.update(time() - end_time)
            end_time = time()

            remain_time = utils.duration_in_hours(remain_iter * iter_time.avg)

        if self.writer is not None:
            self.writer.add_scalar("lr", self.optimizer.param_groups[0]["lr"], epoch)
            # Epoch-level averages: the total under its own tag, the weighted
            # individual terms under "loss/".
            for k, recorder in epoch_loss.items():
                tag = f"{k}" if k == "total_loss" else f"loss/{k}"
                self.writer.add_scalar(tag, recorder.avg, epoch)

        if epoch % self.cfg.train.log_frequency == 0:
            loss_log = ''.join(
                '{}: {:.5f}  '.format(k, recorder.avg) for k, recorder in epoch_loss.items()
            )
            # Adjacent literals instead of a backslash continuation, which
            # embedded the source indentation in the logged message.
            self.log.info(
                'Epoch: {}/{} Loss: {} io_time: {:.2f}({:.4f}) to_gpu_time: {:.2f}({:.4f}) '
                'network_time: {:.2f}({:.4f}) duration: {:.2f} remain_time: {}'
                    .format(epoch, self.max_epochs, loss_log, io_time.sum, io_time.avg, to_gpu_time.sum,
                            to_gpu_time.avg, network_time.sum, network_time.avg, time() - start_time, remain_time))
            
    def eval_epoch(self, epoch, save_results=False):
        """Evaluate the model on ``self.test_loader`` without gradients.

        Args:
            epoch: Epoch index used for TensorBoard steps and the log line.
            save_results: When true, an HDF5 file is opened under
                ``cfg.paths.network.test.output_dir`` and every batch is
                written to it through ``save_results``.

        Returns:
            Dict mapping ``'total_loss'`` and each loss-dict key to an
            ``AvgRecorder``. The per-key recorders hold the unweighted loss
            terms; ``'total_loss'`` holds the weighted sum.

        Raises:
            ValueError: If the loss dict contains a key that has no entry in
                ``cfg.network.loss_weight``.
        """
        val_error = {
            'total_loss': AvgRecorder()
        }
        avg_acc = {}

        # losses()/accuracy() live on the underlying network; unwrap
        # DataParallel by checking the actual wrapper.
        core_model = self.model.module if isinstance(self.model, nn.DataParallel) else self.model
        loss_weight = self.cfg.network.loss_weight

        if save_results:
            io.ensure_dir_exists(self.cfg.paths.network.test.output_dir)
            inference_path = os.path.join(self.cfg.paths.network.test.output_dir,
                                          "ART_" + self.cfg.paths.network.test.inference_result)
            self.test_result = h5py.File(inference_path, "w")
            print(inference_path)

        self.model.eval()
        start_time = time()
        try:
            with torch.no_grad():
                for inputDict, gt_dict, sample_ids in self.test_loader:
                    # Move the tensors to the device; non-tensor ground truth
                    # entries are passed through unchanged.
                    inputDict = {key: value.to(self.device, non_blocking=True) for key, value in inputDict.items()}
                    gt = {
                        k: v.to(self.device, non_blocking=True) if torch.is_tensor(v) else v
                        for k, v in gt_dict.items()
                    }

                    pred = self.model(inputDict, gt)
                    if save_results:
                        # Hand over a shallow copy: save_results reshapes
                        # gt["obj_degree"] for storage, and the losses and
                        # accuracy below must see the original ground truth.
                        self.save_results(pred, inputDict["PointCloud"], dict(gt), sample_ids)

                    loss_dict = core_model.losses(pred, gt)
                    acc_dict = core_model.accuracy(pred, gt)

                    # Weighted sum of the individual loss terms.
                    loss = torch.tensor(0.0, device=self.device)
                    for k, v in loss_dict.items():
                        if k not in loss_weight:
                            raise ValueError(f"No loss weight for {k}")
                        loss += loss_weight[k] * v

                    # Running averages of each loss term and of the total.
                    for k, v in loss_dict.items():
                        if k not in val_error:
                            val_error[k] = AvgRecorder()
                        val_error[k].update(v)
                    val_error['total_loss'].update(loss)

                    for k, v in acc_dict.items():
                        if k not in avg_acc:
                            avg_acc[k] = AvgRecorder()
                        avg_acc[k].update(v)
        finally:
            # Close the HDF5 file even when evaluation fails part-way.
            if save_results:
                self.test_result.close()

        if self.writer is not None:
            for k, recorder in val_error.items():
                self.writer.add_scalar(f"val_error/{k}", recorder.avg, epoch)
            for k, recorder in avg_acc.items():
                self.writer.add_scalar(f"avg_acc/{k}", recorder.avg, epoch)

        loss_log = ''.join(
            '{}: {:.5f}  '.format(k, recorder.avg) for k, recorder in val_error.items()
        )
        self.log.info(
            'Eval Epoch: {}/{} Loss: {} duration: {:.2f}'
                .format(epoch, self.max_epochs, loss_log, time() - start_time))
        return val_error
    
    def train(self, start_epoch=0):
        """Train from ``start_epoch`` through ``self.max_epochs`` inclusive.

        Every ``cfg.train.save_frequency`` epochs, and on the final epoch, a
        checkpoint is written and the model is evaluated. The checkpoint with
        the lowest average ``total_loss`` seen during this call is also
        written to ``best_model_filename``.

        Args:
            start_epoch: First epoch index to run.
        """
        self.model.train()
        train_paths = self.cfg.paths.network.train
        output_dir = train_paths.output_dir

        io.ensure_dir_exists(output_dir)
        self.writer = SummaryWriter(output_dir)

        best_model = None
        best_result = np.inf
        try:
            for epoch in range(start_epoch, self.max_epochs + 1):
                self.train_epoch(epoch)

                is_save_epoch = epoch % self.cfg.train.save_frequency == 0
                if not is_save_epoch and epoch != self.max_epochs:
                    continue

                # The parameters do not change between here and the best-model
                # save below, so one checkpoint dict serves both files.
                checkpoint = {
                    "epoch": epoch,
                    "model_state_dict": self.model.state_dict(),
                    "optimizer_state_dict": self.optimizer.state_dict(),
                }
                torch.save(
                    checkpoint,
                    os.path.join(output_dir, train_paths.model_filename % epoch),
                )

                val_error = self.eval_epoch(epoch)

                if best_model is None or val_error["total_loss"].avg < best_result:
                    best_model = checkpoint
                    best_result = val_error["total_loss"].avg
                    torch.save(
                        best_model,
                        os.path.join(output_dir, train_paths.best_model_filename),
                    )
        finally:
            # Flush and release the TensorBoard writer even if an epoch raised.
            self.writer.close()
        
    def get_latest_model_path(self, with_best=False):
        """Locate a checkpoint inside the most recent training run folder.

        The search starts in the parent directory of
        ``cfg.paths.network.train.output_dir``.

        Args:
            with_best: When true, return the path of ``best_model_filename``
                inside the located folder instead of the latest ``.pth`` file.

        Returns:
            Path to the selected checkpoint file.
        """
        runs_root = os.path.dirname(self.cfg.paths.network.train.output_dir)
        # NOTE(review): the "VQVAE_" prefix is fixed regardless of
        # self.network_type -- confirm this is intended for other networks.
        run_folder, latest_name = utils.get_latest_file_with_datetime(
            runs_root, "VQVAE_", ext='.pth'
        )
        latest_path = os.path.join(runs_root, run_folder, latest_name)
        if not with_best:
            return latest_path
        return os.path.join(
            runs_root, run_folder, self.cfg.paths.network.train.best_model_filename
        )
    
    def test(self, inference_model=None):
        """Load a checkpoint and run one evaluation pass that saves results.

        Args:
            inference_model: Path to a checkpoint file. When it is empty or
                does not exist, the best model of the latest training run is
                used instead.

        Raises:
            IOError: If no checkpoint file can be found at the resolved path.
        """
        have_usable_path = inference_model and io.file_exist(inference_model)
        if not have_usable_path:
            inference_model = self.get_latest_model_path(with_best=True)
        if not io.file_exist(inference_model):
            raise IOError(f'Cannot open inference model {inference_model}')

        # NOTE (original author): this is where the error was being raised.
        self.log.info(f"Load model from {inference_model}")
        state = torch.load(inference_model, map_location=self.device)
        saved_epoch = state["epoch"]

        # Restore the weights, then evaluate at the epoch stored alongside them.
        self.model.load_state_dict(state["model_state_dict"])
        self.model.to(self.device)

        self.eval_epoch(saved_epoch, save_results=True)
        
    def save_results(self, pred, PointCloud, gt, id):
        # Save the results and gt into hdf5 for further optimization
        batch_size = pred["pred_pose"].shape[0]
        gt["obj_degree"] = gt["obj_degree"].unsqueeze(1)
        for b in range(batch_size):
            group = self.test_result.create_group(f"{id[b]}")
            group.create_dataset(
                "PointCloud",
                data=PointCloud[b].detach().cpu().numpy(),
                compression="gzip",
            )
            tokenList = pred['tokenList'][b].detach().cpu().numpy()
            heatmapFeature = pred['heatmapFeature'][b].detach().cpu().numpy()
            handPcFeature = pred['handPcFeature'][b].detach().cpu().numpy() 
            helpinfo = pred['helpinfo'][b].detach().cpu().numpy()
            pred_mesh_f3d = pred['pred_mesh_f3d'][b].detach().cpu().numpy()
            
            raw_pred_pose = pred['pred_pose'][b].detach().cpu().numpy()
            raw_pred_trans_delta = pred['trans_pred'][b].detach().cpu().numpy()
            raw_pred_rot_delta = pred['pred_rot_angle'][b].detach().cpu().numpy()
            hottest_points = pred['hottest_points'][b].detach().cpu().numpy()
            dirs_normalized = pred['dirs_normalized'][b].detach().cpu().numpy()
            pred_mesh_v3d = pred['pred_mesh_v3d'][b].detach().cpu().numpy()
            
            group.create_dataset('tokenList', data=tokenList, compression="gzip")
            group.create_dataset('pred_mesh_f3d', data=pred_mesh_f3d, compression="gzip")
            group.create_dataset('heatmapFeature', data=heatmapFeature, compression="gzip")
            group.create_dataset('handPcFeature', data=handPcFeature, compression="gzip")
            group.create_dataset('helpinfo', data=helpinfo, compression="gzip")
            group.create_dataset('raw_pred_pose', data=raw_pred_pose, compression="gzip")
            group.create_dataset('raw_pred_trans_delta', data=raw_pred_trans_delta, compression="gzip")
            group.create_dataset('raw_pred_rot_delta', data=raw_pred_rot_delta, compression="gzip")
            group.create_dataset('hottest_points', data=hottest_points, compression="gzip")
            group.create_dataset('dirs_normalized', data=dirs_normalized, compression="gzip")    
            group.create_dataset('pred_mesh_v3d', data=pred_mesh_v3d, compression="gzip")    

            # Save the gt
            for k, v in gt.items():
                
                if isinstance(v, torch.Tensor):
                    v = v.detach().cpu().numpy()
                
                group.create_dataset(
                    f"gt_{k}", data=v[b], compression="gzip"
                )
                
    def resume_train(self, model_path=None):
        """Restore model and optimizer state from a checkpoint and keep training.

        Args:
            model_path: Path to a checkpoint file. When it is empty or does
                not exist, the latest checkpoint of the latest training run
                is used. If the resolved file is missing or empty, training
                starts from epoch 0 with the current weights.
        """
        if not model_path or not io.file_exist(model_path):
            model_path = self.get_latest_model_path()

        start_epoch = 0
        if io.is_non_zero_file(model_path):
            checkpoint = torch.load(model_path, map_location=self.device)
            epoch = checkpoint["epoch"]
            self.log.info(f"Continue training with model from {model_path} at epoch {epoch}")
            self.model.load_state_dict(checkpoint["model_state_dict"])
            self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
            self.model.to(self.device)
            # train() writes a checkpoint only after train_epoch(epoch) has
            # finished, so the stored epoch is already done: resume with the
            # next one instead of repeating it.
            start_epoch = epoch + 1

        self.train(start_epoch)