import os
import torch
from arguments import get_args
from models import get_model
from tools import laps_update_wandb
from datasets import get_dataset
from datasets.utils.continual_dataset import ContinualDataset
from models.utils.continual_model import ContinualModel
import wandb
from glob import glob
import time, datetime

from utils.simmim_logger import create_logger
import torch.distributed as dist
import torch.backends.cudnn as cudnn
from timm.utils import accuracy, AverageMeter

#----------
from logger import create_logger
import numpy as np

def reduce_tensor(tensor):
    """Return the mean of ``tensor`` over all ranks without modifying the input.

    A copy of ``tensor`` is summed across the process group with
    ``dist.all_reduce`` and then divided in place by the world size.
    Requires an initialised ``torch.distributed`` process group.
    """
    averaged = tensor.clone()
    dist.all_reduce(averaged, op=dist.ReduceOp.SUM)
    averaged.div_(dist.get_world_size())
    return averaged

@torch.no_grad()
def validate(past_task_id, task_id, n_tasks, n_classes_per_task, data_loader, model):
    """Evaluate ``model.net`` on the data of task ``past_task_id``.

    The cross-entropy loss is computed on the raw logits. Before accuracy is
    measured, every logit column outside
    ``[past_task_id * n_classes_per_task, (past_task_id + 1) * n_classes_per_task)``
    is overwritten with ``-10e10`` so top-k only ranks that task's classes.

    Args:
        past_task_id: Index of the task whose classes are evaluated.
        task_id: Index of the task currently being trained (used for logging only).
        n_tasks: Unused; kept for interface compatibility.
        n_classes_per_task: Number of logit columns assigned to each task.
        data_loader: Iterable yielding ``(images, target)`` batches; ``target``
            may be a tensor or a list of values convertible with ``int``.
        model: Object exposing ``eval()`` and a callable ``net`` attribute.

    Returns:
        Tuple ``(acc1_avg, acc5_avg, loss_avg)`` of sample-weighted averages.

    Note:
        Relies on the module-level ``logger`` and on a CUDA device being
        available (inputs are moved with ``.cuda()``).
    """
    criterion = torch.nn.CrossEntropyLoss()
    model.eval()

    batch_time = AverageMeter()
    loss_meter = AverageMeter()
    acc1_meter = AverageMeter()
    acc5_meter = AverageMeter()

    # Logit columns that belong to the evaluated task.
    first_class = past_task_id * n_classes_per_task
    end_class = (past_task_id + 1) * n_classes_per_task

    end = time.time()
    for images, target in data_loader:
        images = images.cuda(non_blocking=True)
        if isinstance(target, list):
            # Build the integer label tensor directly instead of going through
            # a float32 tensor first.
            target = torch.tensor([int(label) for label in target], dtype=torch.long)
        target = target.cuda(non_blocking=True)
        output = model.net(images)

        # measure accuracy and record loss
        loss = criterion(output, target)
        # Mask logits of all other tasks so they cannot be predicted.
        output[:, :first_class].data.fill_(-10e10)
        output[:, end_class:].data.fill_(-10e10)
        acc1, acc5 = accuracy(output, target, topk=(1, 5))

        loss_meter.update(loss.item(), target.size(0))
        acc1_meter.update(acc1.item(), target.size(0))
        acc5_meter.update(acc5.item(), target.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

    logger.info(f' * [T{past_task_id} Validation during T{task_id}] Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}')
    return acc1_meter.avg, acc5_meter.avg, loss_meter.avg


def auto_resume_helper(output_dir, logger, ckpt_marker='6.pth'):
    """Find the most recently modified checkpoint among sibling run directories.

    The last path component of ``output_dir`` is expected to look like
    ``<time>_<rest>``. Every directory next to it whose name ends in
    ``_<rest>`` (any time prefix) is scanned for checkpoint files. If the
    name contains no underscore, all sibling directories are scanned.

    Only the prefix of the last path component is wildcarded; the same text
    appearing elsewhere in the path is left untouched, and a trailing slash
    on ``output_dir`` is ignored.

    Args:
        output_dir: Checkpoint directory of the current run.
        logger: Object with an ``info`` method used for progress messages.
        ckpt_marker: Substring a file name must contain (in addition to
            ending in ``pth``) to count as a resumable checkpoint. The default
            keeps the historical hard-coded filter; presumably it selects the
            checkpoint of one specific task index -- TODO confirm.

    Returns:
        Path of the newest matching checkpoint (by modification time), or
        ``None`` when no checkpoint is found.
    """
    # Local import: this module only imports the ``glob`` function itself.
    from glob import escape as glob_escape

    # Strip trailing slashes so the run name is never the empty string.
    trimmed = output_dir.rstrip('/') or output_dir
    parent, run_name = os.path.split(trimmed)
    _time_info, sep, remainder = run_name.partition('_')
    # Escape the literal parts so characters such as '[' are not treated as
    # glob syntax; only the time prefix becomes a wildcard.
    pattern = os.path.join(glob_escape(parent), '*' + glob_escape(sep + remainder))

    checkpoints = []
    for run_dir in glob(pattern):
        if not os.path.isdir(run_dir):
            # A plain file can match the pattern; it holds no checkpoints.
            continue
        ckpts = [
            os.path.join(run_dir, ckpt)
            for ckpt in os.listdir(run_dir)
            if ckpt.endswith('pth') and ckpt_marker in ckpt
        ]
        logger.info(f"All checkpoints founded in {run_dir}: {ckpts}")
        checkpoints += ckpts

    if not checkpoints:
        return None

    latest_checkpoint = max(checkpoints, key=os.path.getmtime)
    logger.info(f"The latest checkpoint founded: {latest_checkpoint}")
    return latest_checkpoint

def load_checkpoint(config, model, optimizer, lr_scheduler, logger):
    """Restore training state from ``config.model.resume``.

    Loads the checkpoint (from a URL when the path starts with ``https``,
    otherwise from disk), renames every ``module.`` in the state-dict keys to
    ``net.`` and loads the result into ``model`` non-strictly. Optional
    entries are restored only when both the checkpoint and the receiving
    object provide them.

    Args:
        config: Run configuration; reads ``model.resume`` and ``num_epochs``
            and writes ``train.start_epoch`` / ``train.start_task``.
        model: Module receiving the weights; its ``buffer`` attribute is
            replaced when present in both the model and the checkpoint.
        optimizer: Optimizer to restore, or ``None`` to skip.
        lr_scheduler: Scheduler to restore, or ``None`` to skip.
        logger: Object with an ``info`` method.
    """
    logger.info(f">>>>>>>>>> Resuming from {config.model.resume} ..........")
    if config.model.resume.startswith('https'):
        checkpoint = torch.hub.load_state_dict_from_url(
            config.model.resume, map_location='cpu', check_hash=True)
    else:
        # NOTE: torch.load unpickles the file -- only resume from trusted checkpoints.
        checkpoint = torch.load(config.model.resume, map_location='cpu')

    checkpoint_model = {k.replace('module.', 'net.'): v for k, v in checkpoint['state_dict'].items()}
    msg = model.load_state_dict(checkpoint_model, strict=False)
    logger.info(msg)

    # Keep the model's current buffer when the checkpoint carries none.
    if hasattr(model, 'buffer') and 'buffer' in checkpoint:
        model.buffer = checkpoint['buffer']

    # Optimizer state is restored only when the checkpoint has both entries;
    # callers may pass None for objects the model does not own.
    if 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint:
        if optimizer is not None:
            optimizer.load_state_dict(checkpoint['optimizer'])
        if lr_scheduler is not None:
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])

    if 'epoch' in checkpoint:
        # Write to the config that was passed in rather than a module global.
        config.train.start_epoch = checkpoint['epoch'] % config.num_epochs
        if 'task_id' in checkpoint:
            config.train.start_task = checkpoint['task_id'] + 1
    logger.info(f"=> loaded successfully '{config.model.resume}' (epoch {checkpoint.get('epoch', 'unknown')})")

    del checkpoint
    torch.cuda.empty_cache()


def load_pretrained_checkpoint(config, model, logger):
    """Initialise ``model`` from the checkpoint at ``config.pretrained``.

    Every ``module.`` in the keys of the checkpoint's ``state_dict`` is
    renamed to ``net.`` and the result is loaded non-strictly. When the
    checkpoint records an ``epoch`` (and ``task_id``), the starting epoch and
    task are written to ``config.train``.

    Args:
        config: Run configuration; reads ``pretrained`` and ``num_epochs`` and
            writes ``train.start_epoch`` / ``train.start_task``.
        model: Module receiving the weights.
        logger: Object with an ``info`` method.
    """
    logger.info(f">>>>>>>>>> Initializing from {config.pretrained} ..........")
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(config.pretrained, map_location='cpu')

    checkpoint_model = {k.replace('module.', 'net.'): v for k, v in checkpoint['state_dict'].items()}
    msg = model.load_state_dict(checkpoint_model, strict=False)
    logger.info(msg)

    if 'epoch' in checkpoint:
        # Write to the config that was passed in rather than a module global.
        config.train.start_epoch = checkpoint['epoch'] % config.num_epochs
        if 'task_id' in checkpoint:
            config.train.start_task = checkpoint['task_id'] + 1
    logger.info(f"=> loaded successfully '{config.pretrained}' (epoch {checkpoint.get('epoch', 'unknown')})")
    del checkpoint
    torch.cuda.empty_cache()

def load_pretrained(config, model, logger):
    """Load the ``model`` entry of the checkpoint at ``config.pretrained``.

    Each key of ``checkpoint['model']`` is prefixed with ``net.`` and the
    resulting state dict is loaded into ``model`` non-strictly, so keys that
    do not match are ignored.
    """
    logger.info(f">>>>>>>>>> Initializing from {config.pretrained} ..........")
    checkpoint = torch.load(config.pretrained, map_location='cpu')

    prefixed_state = {}
    for name, value in checkpoint['model'].items():
        prefixed_state['net.' + name] = value

    model.load_state_dict(prefixed_state, strict=False)
    logger.info(f"=> loaded successfully '{config.pretrained}'")
    del checkpoint
    torch.cuda.empty_cache()

def evaluate(model: ContinualModel, dataset: ContinualDataset, device):
    """Placeholder for model evaluation; performs no work and returns ``None``."""
    return None

def _run_training_epoch(model, train_loader, args, task_id, epoch, num_steps):
    """Feed one epoch of batches to the observe method selected by ``args``.

    Returns the dict returned by the last observe call of the epoch, or
    ``None`` when no observe call produced one (GPU-keeper mode, an
    unrecognised backbone, or an empty loader).
    """
    data_dict = None
    if args.gpukeeper:
        for images, mask, _ in train_loader:
            model.gkeep_observe(images, mask)
    elif args.model.backbone == 'simmim':
        for idx, (images, mask, image_paths) in enumerate(train_loader):
            data_dict = model.masked_observe(images, mask, image_paths, task_id, batch_idx=idx, epoch_idx=epoch, num_steps=num_steps)
    elif args.model.backbone == 'mae':
        for idx, (images, image_paths) in enumerate(train_loader):
            data_dict = model.mae_observe(images, image_paths, task_id, batch_idx=idx, epoch_idx=epoch, num_steps=num_steps)
    elif args.model.backbone == 'simsiam':
        for idx, (images1, images2, image_paths) in enumerate(train_loader):
            data_dict = model.siamese_observe(images1, images2, image_paths, task_id, batch_idx=idx, epoch_idx=epoch, num_steps=num_steps)
    elif args.model.backbone == 'supervised':
        for idx, (images, image_paths_and_labels) in enumerate(train_loader):
            labels = torch.Tensor(np.array(image_paths_and_labels, dtype=np.int16)).to(torch.int32).type(torch.LongTensor)
            data_dict = model.supervised_observe(images, labels, task_id, batch_idx=idx, epoch_idx=epoch, num_steps=num_steps)
    return data_dict


def main(device, args, wandb=None):
    """Train the continual model task by task and checkpoint after each task.

    Steps:
      1. Build the dataset and model, then optionally initialise from
         ``args.pretrained`` or resume (``args.train.auto_resume`` /
         ``args.model.resume``).
      2. For each task index below ``args.train.last_task``, fetch the data
         loaders; tasks below ``args.train.start_task`` are skipped after the
         loaders have been fetched.
      3. Train for ``args.num_epochs`` epochs, logging to wandb from rank 0
         when enabled and when the epoch produced a result dict.
      4. Save ``model.net``'s state dict (plus ``model.buffer`` if present)
         to ``args.ckpt_dir`` and record the path in ``checkpoint_path.txt``.

    Args:
        device: Passed through to ``get_model``.
        args: Run configuration namespace.
        wandb: wandb handle forwarded to ``laps_update_wandb``, or ``None``.

    Note:
        Relies on the module-level ``logger`` and on an initialised
        ``torch.distributed`` process group.
    """
    logger.info(f'main')
    dataset = get_dataset(args)
    model = get_model(args, device, dataset.get_transform(args), logger)

    if args.pretrained:
        load_pretrained_checkpoint(args, model, logger)

    elif args.train.auto_resume:
        resume_file = auto_resume_helper(args.ckpt_dir, logger)
        if resume_file:
            if args.model.resume:
                logger.warning(f"auto-resume changing resume file from {args.model.resume} to {resume_file}")
            args.model.resume = resume_file
            logger.info(f'auto resuming from {resume_file}')
        elif args.model.resume:
            logger.info(f'target resuming from {args.model.resume}')
        else:
            logger.info(f'no checkpoint found in {args.ckpt_dir}, ignoring auto resume')

        if args.model.resume:
            lr_scheduler = model.lr_scheduler if hasattr(model, 'lr_scheduler') else None
            optimizer = model.opt if hasattr(model, 'opt') else None
            load_checkpoint(args, model, optimizer, lr_scheduler, logger)

    logger.info(f'initialize task {args.train.start_task}')
    # Test loaders are collected per task; nothing in this function reads them yet.
    test_loaders = []

    for t in range(args.train.last_task):
        # Loaders are fetched even for tasks that are skipped below.
        train_loader, _, tel = dataset.get_data_loaders()
        test_loaders.append(tel)
        len_train_loader = len(train_loader)
        if args.train.start_task > t:
            logger.info(f'{args.logger_name}: continue task {t}')
            continue
        logger.info(f'{args.logger_name}: start task {t}, len_train_loader: {len_train_loader}')

        if hasattr(model, 'set_task'):
            logger.info(f'set task {t}')
            model.set_task(t, len_train_loader)

        start_time = time.time()
        for epoch in range(0, args.num_epochs):
            start = time.time()
            model.train()

            # NOTE(review): the sampler epoch is only advanced when AMP is
            # enabled -- confirm this coupling is intended.
            if args.amp_opt_level != "O0":
                train_loader.sampler.set_epoch(epoch)

            # training phase
            data_dict = _run_training_epoch(model, train_loader, args, t, epoch, len_train_loader)

            epoch_time = time.time() - start
            logger.info(f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}")

            # logger phase: skipped when the epoch produced no result dict,
            # which previously raised because data_dict was never assigned.
            if dist.get_rank() == 0 and args.logger.wandb and data_dict is not None:
                laps_update_wandb(args, wandb, data_dict, epoch, t)

        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        logger.info(f'TASK {t}, TOTAL TRAINING TIME {total_time_str}')

        # save phase
        torch.cuda.synchronize()
        # Computed on every rank so the rank-0 logging below never depends on
        # the local-rank branch having run.
        model_path = os.path.join(args.ckpt_dir, f"{args.name}_" + str(t) + '.pth')
        if args.local_rank == 0:
            save_dict = {
                'epoch': epoch + 1,
                'state_dict': model.net.state_dict(),
                'task_id': t,
            }
            if hasattr(model, 'buffer'):
                save_dict['buffer'] = model.buffer
            torch.save(save_dict, model_path)

        if dist.get_rank() == 0:
            logger.info(f"Task Model saved to {model_path}")
            with open(os.path.join(args.log_dir, f"checkpoint_path.txt"), 'w+') as f:
                f.write(f'{model_path}')

        if hasattr(model, 'end_task'):
            model.end_task(logger)

        logger.info(f'task proceeding {t}/{args.train.last_task}')

# Script entry point: parse arguments, join the NCCL process group, create the
# logger, rescale learning rates, dump the config, run training and finally
# rename the log directory. ``args``, ``logger`` and ``wandb`` assigned here are
# module-level names; functions in this file read ``logger`` as a global.
if __name__ == "__main__":
    args = get_args()
    # Take rank/world size from the environment when present; otherwise pass
    # -1 for both to init_process_group.
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        rank = int(os.environ["RANK"])
        world_size = int(os.environ['WORLD_SIZE'])
        print(f"RANK and WORLD_SIZE in environ: {rank}/{world_size}")
    else:
        rank = -1
        world_size = -1

    torch.cuda.set_device(args.local_rank)
    dist.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank)
    dist.barrier()
    cudnn.benchmark = True
    # Build a compact run label from args.name: strip fixed substrings and
    # separators, rewrite 'e-' as 'em', then drop the last four characters.
    # NOTE(review): the reason for the [:-4] cut is not visible here -- confirm.
    logger_name = args.name.replace('SeqImageNet','').replace('notag','').replace('_','').replace('e-','em').replace('-','')[:-4]
    logger_name = logger_name.replace('patchselectionorignal','')
    logger = create_logger(output_dir=args.log_dir, dist_rank=dist.get_rank(), name=f"{args.cl_model}_{args.tag}")
    # Rank 0 logs the arguments and keeps the imported wandb module; every
    # other rank replaces ``wandb`` with None before it is passed to main().
    if dist.get_rank() == 0:
        logger.info(f'{args}')
        if args.logger.wandb:
            pass
    else:
        wandb = None
    
    args.logger_name=logger_name

    # Scale each learning rate by lr_multiplier * world_size * batch_size / 512.
    args.init_lr = args.init_lr * args.lr_multiplier * dist.get_world_size() * args.batch_size / 512
    args.train.warmup_lr = args.train.warmup_lr * args.lr_multiplier * dist.get_world_size() * args.batch_size / 512
    args.train.min_lr = args.train.min_lr * args.lr_multiplier * dist.get_world_size() * args.batch_size / 512

    # Rank 0 writes str(vars(args)) to config.json. This is the Python string
    # form of the attribute dict, not output from a JSON encoder.
    if dist.get_rank() == 0:
        path = os.path.join(args.log_dir, "config.json")
        with open(path, "w") as f:
            f.write(str(vars(args)))
        logger.info(f"Full config saved to {path}")

    main(device=args.device, args=args, wandb=wandb)
    # Wait for all ranks to finish before rank 0 renames the log directory.
    dist.barrier()
    if dist.get_rank() == 0:
        # Replace 'in-progress' in the log path with 'debug' or 'completed'.
        completed_log_dir = args.log_dir.replace('in-progress', 'debug' if args.debug else 'completed')
        os.rename(args.log_dir, completed_log_dir)
        logger.info(f'Log file has been saved to {completed_log_dir}')
