import os
import glob
import torch
import traceback
from lib.train.admin import multigpu
from torch.utils.data.distributed import DistributedSampler


class BaseTrainer:
    """Base trainer class. Contains functions for training and saving/loading checkpoints.
    Trainer classes should inherit from this one and overload the train_epoch function."""

    def __init__(self, actor, loaders, optimizer, settings, lr_scheduler=None):
        """Store the training components and move the actor to the training device.

        args:
            actor - The actor for training the network
            loaders - list of dataset loaders, e.g. [train_loader, val_loader]. In each epoch, the trainer runs one
                        epoch for each loader.
            optimizer - The optimizer used for training, e.g. Adam
            settings - Training settings
            lr_scheduler - Learning rate scheduler
        """
        self.actor, self.optimizer = actor, optimizer
        self.lr_scheduler, self.loaders = lr_scheduler, loaders

        # Records the settings and derives self._checkpoint_dir from them.
        self.update_settings(settings)

        self.epoch, self.stats = 0, {}

        # An explicit settings.device wins; otherwise use the first GPU when one
        # is available and settings.use_gpu is set, else the CPU.
        device = getattr(settings, 'device', None)
        if device is None:
            gpu_requested = torch.cuda.is_available() and settings.use_gpu
            device = torch.device("cuda:0" if gpu_requested else "cpu")
        self.device = device

        self.actor.to(self.device)
        self.settings = settings

    def update_settings(self, settings=None):
        """Updates the trainer settings. Must be called to update internal settings.

        Stores `settings` (when given) and derives self._checkpoint_dir:
            - None when settings.env.workspace_dir is None;
            - <save_dir>/checkpoints when settings.save_dir is set;
            - <workspace_dir>/checkpoints otherwise.
        '~' is expanded in both directories, and the expanded workspace_dir is
        written back to settings.env.workspace_dir. The checkpoint directory is
        created only when settings.local_rank is -1 or 0.

        args:
            settings - Training settings. If None, the previously stored settings are re-processed.
        """
        if settings is not None:
            self.settings = settings

        workspace_dir = self.settings.env.workspace_dir
        if workspace_dir is None:
            self._checkpoint_dir = None
            return

        workspace_dir = os.path.expanduser(workspace_dir)
        self.settings.env.workspace_dir = workspace_dir

        # 2021.1.4 New function: specify checkpoint dir.
        # A missing save_dir attribute is treated like save_dir=None, and '~' is
        # expanded exactly as it is for workspace_dir.
        save_dir = getattr(self.settings, 'save_dir', None)
        base_dir = workspace_dir if save_dir is None else os.path.expanduser(save_dir)
        self._checkpoint_dir = os.path.join(base_dir, 'checkpoints')
        print("checkpoints will be saved to %s" % self._checkpoint_dir)

        # A missing local_rank is treated as -1.
        # NOTE(review): -1 presumably means "not running distributed" - confirm against the launcher.
        if getattr(self.settings, 'local_rank', -1) in (-1, 0):
            if not os.path.exists(self._checkpoint_dir):
                print("Training with multiple GPUs. checkpoints directory doesn't exist. "
                      "Create checkpoints directory")
                # exist_ok guards against another process creating the directory
                # between the existence check and this call.
                os.makedirs(self._checkpoint_dir, exist_ok=True)

    def train(self, max_epochs, load_latest=False, fail_safe=True, load_previous_ckpt=False, distill=False):
        """Do training for the given number of epochs.

        Epochs run from self.epoch + 1 up to and including max_epochs. After each epoch the
        learning-rate scheduler (if any) is stepped and, when a checkpoint directory is configured
        and settings.local_rank is -1 or 0, a checkpoint is saved.

        args:
            max_epochs - Max number of training epochs,
            load_latest - Bool indicating whether to call load_checkpoint() before training.
            fail_safe - Bool. If True, an exception raised during training is printed and swallowed;
                        if False it is re-raised. KeyboardInterrupt and SystemExit always propagate.
            load_previous_ckpt - Bool. If True, load network weights from
                        <checkpoint_dir>/<settings.project_path_prv> via load_state_dict().
            distill - Bool. If True, load the teacher network weights from
                        <checkpoint_dir>/<settings.project_path_teacher> via load_state_dict().
        """

        epoch = -1
        # NOTE(review): with a single attempt nothing is retried after a crash, even though the
        # fail-safe branch below announces a restart. Raise this value to re-enable retries.
        num_tries = 1
        # Epochs whose checkpoint is always written.
        save_epochs = [79, 159, 239]
        # Checkpoint period in epochs. The value 1 means every epoch is saved, which makes the
        # other save conditions below redundant until this is raised.
        save_interval = 1
        for _ in range(num_tries):
            # Tracks whether the crash (if any) happened inside an epoch, so that only a
            # genuinely interrupted epoch is rolled back.
            epoch_started = False
            try:
                if load_latest:
                    self.load_checkpoint()
                if load_previous_ckpt:
                    if hasattr(self.settings, 'project_path_prv'):
                        directory = '{}/{}'.format(self._checkpoint_dir, self.settings.project_path_prv)
                        self.load_state_dict(directory)
                    else:
                        print("警告: 设置中缺少project_path_prv属性，跳过加载先前检查点")
                if distill:
                    if hasattr(self.settings, 'project_path_teacher'):
                        directory_teacher = '{}/{}'.format(self._checkpoint_dir, self.settings.project_path_teacher)
                        self.load_state_dict(directory_teacher, distill=True)
                    else:
                        print("警告: 设置中缺少project_path_teacher属性，跳过加载教师模型")
                for epoch in range(self.epoch+1, max_epochs+1):
                    epoch_started = True
                    self.epoch = epoch

                    self.train_epoch()

                    if self.lr_scheduler is not None:
                        if self.settings.scheduler_type == 'cosine':
                            self.lr_scheduler.step(epoch - 1)
                        elif type(self.lr_scheduler).__name__ != 'ReduceLROnPlateau':
                            # ReduceLROnPlateau needs a metric value that is not available
                            # here, so stepping it is left to the subclass.
                            self.lr_scheduler.step()

                    save_every_epoch = getattr(self.settings, "save_every_epoch", False)
                    should_save = (save_every_epoch or epoch % save_interval == 0
                                   or epoch in save_epochs or epoch > (max_epochs - 3))
                    # Only the main process (local_rank -1 or 0) writes checkpoints.
                    if should_save and self._checkpoint_dir and self.settings.local_rank in [-1, 0]:
                        self.save_checkpoint()
            except Exception:
                # Deliberately not a bare except: Ctrl-C and interpreter shutdown must not be
                # swallowed by the fail-safe path.
                print('Training crashed at epoch {}'.format(epoch))
                if not fail_safe:
                    raise
                if epoch_started:
                    # Roll back so that the interrupted epoch is run again on a restart.
                    self.epoch -= 1
                load_latest = True
                print('Traceback for the error!')
                print(traceback.format_exc())
                print('Restarting training from last epoch ...')

        print('Finished training!')

    def train_epoch(self):
        raise NotImplementedError

    def save_checkpoint(self):
        """Saves a checkpoint of the network and other variables.

        The file is written to <checkpoint_dir>/<settings.project_path>/<NetType>_ep<NNNN>.pth.tar
        and holds the epoch, actor/network type names, network and optimizer state dicts, the
        LR-scheduler state dict (when a scheduler with a state_dict method is set), the training
        stats and the settings object.

        If settings.project_path is missing it is set to 'default' on the settings object.
        """

        # Unwrap the multi-GPU wrapper so the saved keys carry no wrapper prefix.
        net = self.actor.net.module if multigpu.is_multi_gpu(self.actor.net) else self.actor.net

        actor_type = type(self.actor).__name__
        net_type = type(net).__name__
        state = {
            'epoch': self.epoch,
            'actor_type': actor_type,
            'net_type': net_type,
            'net': net.state_dict(),
            'net_info': getattr(net, 'info', None),
            'constructor': getattr(net, 'constructor', None),
            'optimizer': self.optimizer.state_dict(),
            'stats': self.stats,
            'settings': self.settings
        }
        # load_checkpoint() restores a 'scheduler' entry when one is present, so store it here.
        if self.lr_scheduler is not None and hasattr(self.lr_scheduler, 'state_dict'):
            state['scheduler'] = self.lr_scheduler.state_dict()

        # Fall back to a default sub-directory when project_path is not configured.
        if not hasattr(self.settings, 'project_path'):
            print("警告: 设置中缺少project_path属性，使用默认值'default'")
            self.settings.project_path = 'default'

        directory = '{}/{}'.format(self._checkpoint_dir, self.settings.project_path)
        print(directory)
        if not os.path.exists(directory):
            print("directory doesn't exist. creating...")
            # exist_ok guards against the directory appearing between the check and this call.
            os.makedirs(directory, exist_ok=True)

        # First save as a tmp file so an interrupted write never leaves a truncated checkpoint
        # under the final name.
        tmp_file_path = '{}/{}_ep{:04d}.tmp'.format(directory, net_type, self.epoch)
        torch.save(state, tmp_file_path)

        file_path = '{}/{}_ep{:04d}.pth.tar'.format(directory, net_type, self.epoch)

        # os.replace overwrites an existing checkpoint for the same epoch on every platform,
        # whereas os.rename raises on Windows when the target already exists.
        os.replace(tmp_file_path, file_path)

    def load_checkpoint(self):
        """Restore the training state from the checkpoint named by settings.checkpoint.

        settings.checkpoint may be a checkpoint file, or a directory, in which case the
        lexicographically last '*.pth.tar' file in it is used. Nothing is loaded when the
        attribute is missing or None, or when the directory holds no checkpoint.

        Two weight layouts are understood:
            - the layout written by save_checkpoint(): one state dict under the 'net' key, loaded
              into actor.net (unwrapped from its multi-GPU wrapper);
            - one state dict per network name, when actor.net is a mapping of named networks.
        A network whose parameters are not all present in the checkpoint is loaded non-strictly.

        Also restores, when present: actor state, optimizer state, scheduler state, stats, the
        epoch counter and the settings object.

        raises:
            RuntimeError - if settings.checkpoint is a list.
        """
        checkpoint = getattr(self.settings, 'checkpoint', None)
        if checkpoint is None:
            print("未指定检查点路径，跳过加载")
            return

        if isinstance(checkpoint, list):
            raise RuntimeError("Multiple checkpoints not implemented")

        if os.path.isdir(checkpoint):
            checkpoint_list = sorted(glob.glob(os.path.join(checkpoint, "*.pth.tar")))
            if not checkpoint_list:
                print("No checkpoint found")
                return
            checkpoint_path = checkpoint_list[-1]
        else:
            checkpoint_path = checkpoint

        # Load checkpoint. A missing settings.device is treated like a non-'cuda' device.
        if getattr(self.settings, 'device', None) == 'cuda':
            checkpoint_dict = torch.load(checkpoint_path)
        else:
            checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')

        # Decide which weight layout applies (see docstring).
        actor_net = self.actor.net
        if isinstance(actor_net, torch.nn.Module) and ('net' in checkpoint_dict or not hasattr(actor_net, 'items')):
            net = actor_net.module if multigpu.is_multi_gpu(actor_net) else actor_net
            named_nets = [('net', net)]
        else:
            named_nets = list(actor_net.items())

        for key, net in named_nets:
            if key not in checkpoint_dict:
                print("No checkpoint found for", key)
                continue
            saved_state = checkpoint_dict[key]
            # Parameters the model has but the checkpoint lacks (e.g. layers added after the
            # checkpoint was written) force a non-strict load.
            missing = [k for k in net.state_dict() if k not in saved_state]
            if missing:
                print(f"检测到模型{key}的参数与检查点不完全匹配，将使用非严格模式加载")
                net.load_state_dict(saved_state, strict=False)
                print(f"模型{key}缺失的键: {len(missing)}个, 例如: {missing[:5]}")
            else:
                net.load_state_dict(saved_state)

        # Restore the remaining component states.
        if 'actor_state_dict' in checkpoint_dict and callable(getattr(self.actor, 'load_state_dict', None)):
            self.actor.load_state_dict(checkpoint_dict['actor_state_dict'])
        if 'optimizer' in checkpoint_dict and self.optimizer is not None:
            self.optimizer.load_state_dict(checkpoint_dict['optimizer'])
        if 'scheduler' in checkpoint_dict and self.lr_scheduler is not None:
            self.lr_scheduler.load_state_dict(checkpoint_dict['scheduler'])

        # The stored epoch is the last completed one and train() starts at self.epoch + 1, so
        # the counter is restored as-is; adding 1 here would skip an epoch on resume.
        if 'epoch' in checkpoint_dict:
            self.epoch = checkpoint_dict['epoch']
        if 'stats' in checkpoint_dict:
            self.stats = checkpoint_dict['stats']

        if 'settings' in checkpoint_dict:
            saved_settings = checkpoint_dict['settings']
            # Point the restored settings at the current workspace. Both attribute names are
            # handled: 'env' is what update_settings() reads, 'env_settings' is the older name.
            for env_attr in ('env_settings', 'env'):
                saved_env = getattr(saved_settings, env_attr, None)
                current_env = getattr(self.settings, env_attr, None)
                if hasattr(saved_env, 'workspace_dir') and hasattr(current_env, 'workspace_dir'):
                    saved_env.workspace_dir = current_env.workspace_dir
            # NOTE(review): this replaces the live settings wholesale (including fields such as
            # local_rank and checkpoint) - confirm that is intended.
            self.settings = saved_settings

        print("加载检查点完成，从epoch {}恢复训练".format(self.epoch + 1))

    def load_state_dict(self, checkpoint=None, distill=False):
        """Loads network weights (and nothing else) from a checkpoint file.

        The weights are read from the 'net' entry of the checkpoint and loaded non-strictly;
        missing and unexpected keys are printed. The checkpoint's 'net_type' must equal the class
        name of the target network.

        args:
            checkpoint - Path (str) to a checkpoint file, or to a directory, in which case the
                        lexicographically last '*_ep*.pth.tar' file in it is used. '~' is expanded.
            distill - If True the weights are loaded into actor.net_teacher instead of actor.net.

        returns:
            True once the weights have been loaded.

        raises:
            TypeError - if checkpoint is not a string.
            Exception - if checkpoint is a directory holding no matching checkpoint file.
            AssertionError - if the checkpoint was saved from a different network class.
        """
        # Validate and resolve the path first so that bad arguments fail before any model access.
        if not isinstance(checkpoint, str):
            raise TypeError('checkpoint must be a path string, got {}'.format(type(checkpoint).__name__))

        # Expand '~' before the directory test; otherwise a '~/...' directory is never recognised
        # as one.
        checkpoint = os.path.expanduser(checkpoint)
        if os.path.isdir(checkpoint):
            checkpoint_list = sorted(glob.glob('{}/*_ep*.pth.tar'.format(checkpoint)))
            if not checkpoint_list:
                raise Exception('No checkpoint found')
            checkpoint_path = checkpoint_list[-1]
        else:
            checkpoint_path = checkpoint

        # Pick the target network and unwrap the multi-GPU wrapper if there is one.
        if distill:
            net = self.actor.net_teacher.module if multigpu.is_multi_gpu(self.actor.net_teacher) \
                else self.actor.net_teacher
        else:
            net = self.actor.net.module if multigpu.is_multi_gpu(self.actor.net) else self.actor.net

        net_type = type(net).__name__

        # Load network
        print("Loading pretrained model from ", checkpoint_path)
        checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')

        assert net_type == checkpoint_dict['net_type'], 'Network is not of correct type.'

        missing_k, unexpected_k = net.load_state_dict(checkpoint_dict["net"], strict=False)
        print("previous checkpoint is loaded.")
        print("missing keys: ", missing_k)
        print("unexpected keys:", unexpected_k)

        return True
