import os
from typing import Optional
import argparse
import time
import json
import yaml
import copy
from pathlib import Path
from collections import OrderedDict

import torch
from torch import nn
from torch import optim
from torchvision.utils import save_image

from models import utils

from core import get_engine

from dataset import get_dataset, get_dataloader
from dataset.utils import load_task_config
from models import get_model, get_continual_model


try:
    import wandb
    has_wandb = True
except ImportError:
    has_wandb = False
    

def get_parser(argv=None):
    """Declare the command-line options and return the parsed arguments.

    Despite its name, this function returns the parsed ``argparse.Namespace``
    and not the ``ArgumentParser`` object itself.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) the process command line (``sys.argv[1:]``) is
            parsed, which is the original behaviour.

    Returns:
        argparse.Namespace with one attribute per option declared below.
    """
    parser = argparse.ArgumentParser(description='Slot attention')

    # Dataset parameters
    parser.add_argument('--data_root', metavar='DIR', default='',
                        help='path to dataset (root dir)')
    parser.add_argument('--dataset', metavar='NAME', default='CLEVR',
                        help='dataset type + name ("<type>/<name>") (default: ImageFolder or ImageTar if empty)')
    parser.add_argument('--num_train_images', default=None, type=int,
                        metavar='N', help='Manually specify num samples in train split, for IterableDatasets.')
    parser.add_argument('--num_val_images', default=None, type=int,
                        metavar='N', help='Manually specify num samples in validation split, for IterableDatasets.')
    # Number of reconstruction samples to visualise; main() only samples
    # images when this is > 0.
    parser.add_argument('--n_samples', type=int, default=0, metavar='N',
                        help='')
    # main() visualises whenever ``epoch % sample_interval == 0``.
    parser.add_argument('--sample_interval', type=int, default=5, metavar='N',
                        help='')
    # When set, main() only evaluates on the first and the last two epochs.
    parser.add_argument('--eff_eval', action='store_true', default=False,
                        help='')

    # Device & distributed
    parser.add_argument('--device', default='cuda', type=str,
                        help="Device (accelerator) to use.")
    parser.add_argument('--amp', action='store_true', default=False,
                        help='use NVIDIA Apex AMP or Native AMP for mixed precision training')
    parser.add_argument('--use_fp16', action='store_true', default=False,
                        help="Whether or not to use half precision for training.")
    parser.add_argument('--synchronize_step', action='store_true', default=False,
                        help='torch.cuda.synchronize() end of each step')
    parser.add_argument("--local_rank", default=0, type=int)
    parser.add_argument('--device_modules', default=None, type=str, nargs='+',
                        help="Python imports for device backend modules.")

    # Model parameters
    parser.add_argument('--arch', type=str, default='base', metavar='N',
                        help='model architecture')
    parser.add_argument('-b', '--batch_size', type=int, default=64, metavar='N',
                        help='Input batch size for training (default: 64)')
    parser.add_argument('-vb', '--val_batch_size', type=int, default=64, metavar='N',
                        help='Validation batch size override (default: 64)')
    # NOTE: the default is a tuple, but values given on the command line
    # arrive as a list because of ``nargs='+'``.
    parser.add_argument('--resolution', type=int, default=(128, 128), metavar='N', nargs='+',
                        help='Image size (default: 128 128)')
    parser.add_argument('--num_slots', type=int, default=5, metavar='N',
                        help='Number of slots')
    parser.add_argument('--num_iterations', type=int, default=3, metavar='N',
                        help='Number of iterations in slot attention')
    parser.add_argument('--empty_cache', action='store_true', default=False,
                        help='')
    parser.add_argument('--div_lambda', type=float, default=0.001,
                        help='diversity loss balancing parameter')
    parser.add_argument('--local_alpha', type=float, default=0.5,
                        help='local diversity loss balancing parameter')
    # When set, main() writes a checkpoint file for every task.
    parser.add_argument('--save_weights', action='store_true', default=False,
                        help='')

    # Optimizer parameters
    parser.add_argument('--weight_decay', type=float, default=0.0,
                        help='weight decay (default: 0.0)')

    # Learning rate schedule parameters
    parser.add_argument('--lr', type=float, default=0.0004, metavar='LR',
                        help='learning rate (default: 0.0004)')
    parser.add_argument('--num_epochs', type=int, default=100, metavar='N',
                        help='number of epochs to train (default: 100)')
    parser.add_argument('--steps', type=int, default=-1, metavar='N',
                        help='number of steps to train (default: -1)')
    parser.add_argument('--end_epoch', type=int, default=0, metavar='N',
                        help='if > 0, stop each task early after this many epochs (default: 0)')
    parser.add_argument('--num_sanity_val_steps', type=int, default=1, metavar='N',
                        help='number of sanity validation steps')
    parser.add_argument('--scheduler_gamma', type=float, default=0.5, metavar='RATE',
                        help='LR decay rate (default: 0.5)')
    # Fractions of the total number of training steps used by the
    # warm-up / decay schedule built in main().
    parser.add_argument('--warmup_steps_pct', type=float, default=0.02, metavar='N',
                        help='')
    parser.add_argument('--decay_steps_pct', type=float, default=0.2, metavar='N',
                        help='')

    # Misc
    parser.add_argument('--notes', default='',
                        help='notes of current experiments')
    parser.add_argument('--seed', type=int, default=42, metavar='S',
                        help='random seed (default: 42)')
    parser.add_argument('--log_interval', type=int, default=50, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('-j', '--num_workers', type=int, default=4, metavar='N',
                        help='how many training processes to use (default: 4)')
    parser.add_argument('--pin_mem', action='store_true', default=False,
                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    parser.add_argument('--output', default="./results", type=str, metavar='PATH',
                        help='path to output folder (default: ./results)')
    parser.add_argument('--experiment', default='', type=str, metavar='NAME',
                        help='name of train experiment, name of sub_folder for output')
    parser.add_argument('--project', default='Slot Attention symmetry', type=str, metavar='NAME',
                        help='name of the project')
    # NOTE(review): the next two flags are ``store_true`` with a default of
    # True, so they are always True and cannot be switched off from the
    # command line. Left unchanged to keep the CLI stable.
    parser.add_argument('--is_logger_enabled', action='store_true', default=True,
                        help='')
    parser.add_argument('--is_verbose', action='store_true', default=True,
                        help='')
    parser.add_argument('--log_wandb', action='store_true', default=False,
                        help='log training and validation metrics to wandb')

    # Evaluation
    parser.add_argument('--eval_only', action='store_true', default=False,
                        help='')
    parser.add_argument('--eval_metrics', action='store_true', default=False,
                        help='')
    parser.add_argument('--checkpoint_dir', type=str, default='', metavar='N',
                        help='checkpoint directory')

    # Continual
    parser.add_argument('--continual_arch', type=str, default='base', metavar='N',
                        help='continual model architecture',)
    parser.add_argument('--task_config', metavar='DIR', default='',
                        help='path to task config')
    parser.add_argument('--num_task', type=int, default=5, metavar='N',
                        help='number of tasks to train (default: 5)')

    # isolation
    parser.add_argument('--param_isolation', action='store_true', default=False,
                        help='sharing parameters across tasks')
    parser.add_argument('--isol_params', type=str, default=['decoder', ], metavar='N', nargs='+',
                        help='')

    # freeze
    parser.add_argument('--param_freeze', action='store_true', default=False,
                        help='freezing parameters across tasks')
    parser.add_argument('--freeze_params', type=str, default=['decoder', ], metavar='N', nargs='+',
                        help='')

    # Replay
    parser.add_argument('--replay_size', type=int, default=-1, metavar='N',
                        help='number of samples for replay')
    parser.add_argument('--replay_epochs', type=int, default=50, metavar='N',
                        help='number of epochs for replay')

    # resume
    parser.add_argument('--resume', action='store_true', default=False,
                        help='')
    parser.add_argument('--resume_checkpoint', type=str, default='',
                        help='')

    # ``argv=None`` makes argparse fall back to ``sys.argv[1:]``.
    return parser.parse_args(argv)
    

def main(params):
    """Train and evaluate a continual-learning model over all tasks.

    For every task the function builds a fresh optimizer and LR schedule,
    trains for the task's epoch budget, checkpoints the model, evaluates on
    the validation loader of every task, and appends one JSON line per
    evaluated epoch to ``log-task{task_idx}.txt`` in ``params.output_dir``.

    Args:
        params: Parsed command-line namespace. ``params.output_dir`` must
            already be set and the directory must exist (the ``__main__``
            block creates it). ``params.log_wandb`` is reset to ``False``
            when wandb logging was requested but the package is missing.
    """
    # assert params.num_slots > 1, "Must have at least 2 slots."

    # Called for its side effects only; the model is placed with ``.cuda()``.
    # NOTE(review): presumably this populates ``params.rank`` / ``params.gpu``,
    # which are read below but never set in this file -- confirm in
    # models.utils.
    utils.init_distributed_device(params)
    utils.random_seed(params.seed, params.rank)

    # log wandb
    if params.log_wandb and not has_wandb:
        if utils.is_main_process():
            print(
                "You've requested to log metrics to wandb but package not found. "
                "Metrics not being logged to wandb, try `pip install wandb`")
        # Without this, the later ``wandb.save`` / ``wandb.log`` calls would
        # raise NameError because the import failed.
        params.log_wandb = False

    if params.log_wandb and utils.is_main_process():
        wandb.init(project=params.project, config=params, entity="", name=params.experiment)

    # save argparser hyper-parameters
    if utils.is_main_process():
        params_text = yaml.safe_dump(params.__dict__, default_flow_style=False)
        with open(os.path.join(params.output_dir, 'args.yaml'), 'w') as f:
            f.write(params_text)

    if params.is_verbose:
        print(f"INFO: limiting the dataset to only images with `num_slots - 1` ({params.num_slots - 1}) objects.")
        if params.num_train_images:
            print(f"INFO: restricting the train dataset size to `num_train_images`: {params.num_train_images}")
        if params.num_val_images:
            print(f"INFO: restricting the validation dataset size to `num_val_images`: {params.num_val_images}")

    # load dataset (for class incremental learning)
    params, config = load_task_config(params)
    datasets, params = get_dataset(params, config)
    dataloaders = get_dataloader(params=params, dataset=datasets, config=None)

    # save dataset config file
    if utils.is_main_process():
        with open(os.path.join(params.output_dir, 'task_config.yaml'), 'w') as f:
            yaml.dump(config, f, default_flow_style=False, allow_unicode=True)

    # load models
    find_unused_parameters = 'freeze' in params.continual_arch.lower()
    model = get_continual_model(params)
    model = model.cuda()
    model = nn.parallel.DistributedDataParallel(model, device_ids=[params.gpu], find_unused_parameters=find_unused_parameters)
    model_without_ddp = model.module
    print(model_without_ddp)
    print(f'- Setting [--find_unused_parameters] to {find_unused_parameters}...')

    if utils.is_main_process() and params.log_wandb:
        wandb.save(os.path.join(params.output_dir, 'task_config.yaml'))
        wandb.save(os.path.join(params.output_dir, 'args.yaml'))

    # eval only
    if params.eval_metrics and params.eval_only:
        get_eval_metrics(params, model_without_ddp, dataloaders)
        # Same effect as the ``exit()`` builtin, without relying on ``site``.
        raise SystemExit

    # A "joint" task config is trained as a single task: stop after the
    # first pass of the task loop. Computed up front so it is defined even
    # if the epoch loop below runs zero times.
    task_break_true = 'joint' in params.task_config

    # train
    for task_idx in range(params.num_task):

        start_epoch = 0

        # Epoch budget of this task. It is always needed (LR schedule,
        # eff_eval, sampling), including when --end_epoch cuts the run short.
        # NOTE(review): presumably load_task_config turns ``num_epochs``
        # into a per-task sequence; a plain int is used as-is.
        if isinstance(params.num_epochs, int):
            num_epochs = params.num_epochs
        else:
            num_epochs = params.num_epochs[task_idx]

        if params.end_epoch > 0:
            end_epoch = params.end_epoch
            print(f'Early stopping this run at {params.end_epoch} / {params.num_epochs}')
        else:
            end_epoch = num_epochs

        train_dataset, val_dataset, _ = datasets[task_idx]
        train_dataloader, _, _ = dataloaders[task_idx]

        print(f"Training set size (images must have {params.num_slots - 1} objects):", len(train_dataset))
        print(f"Validation set size (images must have {params.num_slots - 1} objects):", len(val_dataset))

        # 'drwt' only trains on the first task; later tasks run a single
        # evaluation-only epoch.
        no_train = False
        if task_idx > 0 and params.continual_arch == 'drwt':
            end_epoch = 1
            no_train = True

        # Architectures whose name contains 'pr' (this includes 'dpr')
        # receive the task's training set when the task begins.
        begin_kargs = {}
        if 'pr' in params.continual_arch:
            begin_kargs.update({
                'dataset': train_dataset,
                'is_main': utils.is_main_process(),
                'end_epoch': end_epoch,
            })
        model.module.begin_task(**begin_kargs)

        # optimizer: only parameters that still require gradients are trained
        parameters = []
        for name, param in model.named_parameters():
            if param.requires_grad:
                parameters.append(param)
            else:
                print('*** Freezing: ', name, param.size(), '***')

        if params.arch.lower() in ['monet', ]:
            optimizer = optim.RMSprop(parameters, lr=params.lr, weight_decay=params.weight_decay)
        else:
            optimizer = optim.Adam(parameters, lr=params.lr, weight_decay=params.weight_decay)
        print('- Optimizer: ', optimizer)

        # lr scheduler: linear warm-up followed by exponential decay
        total_steps = num_epochs * len(train_dataloader)
        warmup_steps = params.warmup_steps_pct * total_steps
        decay_steps = params.decay_steps_pct * total_steps

        def warm_and_decay_lr_scheduler(step: int):
            assert step < total_steps+1
            if step < warmup_steps:
                factor = step / warmup_steps
            else:
                factor = 1
            # No decay horizon (e.g. zero epochs): skip the decay term
            # instead of dividing by zero.
            if decay_steps != 0:
                factor *= params.scheduler_gamma ** (step / decay_steps)
            return factor

        scheduler = optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=warm_and_decay_lr_scheduler)

        # amp scaler (only its state is stored in the checkpoint here)
        fp16_scaler = None
        if params.use_fp16:
            fp16_scaler = torch.cuda.amp.GradScaler()

        # Fixed permutation passed to ``sample_images`` for this task.
        if params.n_samples > 0:
            perm = torch.randperm(params.batch_size)

        # train, valid, visualize engine
        train_one_epoch, valid_one_epoch, sample_images = get_engine(dataset=params.dataset.lower(),
                                                                     arch=params.arch.lower(),
                                                                     cont=params.continual_arch.lower())

        # Create / truncate this task's log file.
        with (Path(params.output_dir) / f"log-task{task_idx}.txt").open("w") as f:
            pass

        # train task
        for epoch in range(start_epoch, end_epoch):
            train_dataloader.sampler.set_epoch(epoch)

            torch.cuda.synchronize()

            print(f'\n- Current task {task_idx}')

            # train
            if not no_train:
                train_stats = train_one_epoch(
                    epoch=epoch,
                    model=model,
                    loader=train_dataloader,
                    optimizer=optimizer,
                    scheduler=scheduler,
                    use_amp=params.amp,
                )
            else:
                train_stats = {}

            # Resuming replaces task 0's weights with the stored checkpoint
            # and ends the task after this epoch.
            break_true = False
            if params.resume and params.resume_checkpoint != '' and task_idx == 0:
                checkpoint_path = os.path.join(params.resume_checkpoint, 'checkpoint-task0.pth')
                assert os.path.exists(checkpoint_path), f'No file at {checkpoint_path}'
                print(f'\nResume checkpoints from {checkpoint_path}...\n')
                model_without_ddp.reload_checkpoint(task_num=task_idx, checkpoint_path=checkpoint_path)
                break_true = True

            model_without_ddp.update_checkpoint()

            model_without_ddp.inter_task()

            # save
            save_dict = {
                'model': model_without_ddp.state_dict(),
                'optimizer': optimizer.state_dict(),
                'epoch': epoch + 1,
                'params': params,
            }
            if fp16_scaler is not None:
                save_dict['fp16_scaler'] = fp16_scaler.state_dict()

            if params.save_weights:
                utils.save_on_master(save_dict, os.path.join(params.output_dir, f'checkpoint-task{task_idx}.pth'))

            torch.distributed.barrier()
            torch.cuda.synchronize()

            # eval: with --eff_eval only the first and the last two epochs
            # are evaluated and logged.
            if params.eff_eval and epoch not in (0, num_epochs - 2, num_epochs - 1):
                continue

            do_sample = params.n_samples > 0 and (
                epoch % params.sample_interval == 0
                or epoch == num_epochs - 1
                or (epoch == num_epochs - 2 and params.eff_eval)
            )

            task_valid_stats = []
            task_images = []
            for idx in range(params.num_task):
                if utils.is_main_process():
                    print(f'\n- Evaluate task {idx} / Current task {task_idx}')

                if params.param_isolation:
                    # NOTE(review): ``task_num`` is the current task while
                    # the checkpoint path belongs to task ``idx`` -- confirm
                    # this is intended.
                    if idx <= task_idx:
                        checkpoint_path = os.path.join(params.output_dir, f'checkpoint-task{idx}.pth')
                        model_without_ddp.load_isolated_checkpoint(task_num=task_idx, checkpoint_path=checkpoint_path)
                    else:
                        checkpoint_path = os.path.join(params.output_dir, f'checkpoint-task{task_idx}.pth')
                        model_without_ddp.reload_checkpoint(task_num=task_idx, checkpoint_path=checkpoint_path)

                valid_stats = valid_one_epoch(epoch, model_without_ddp, dataloaders[idx][1])
                task_valid_stats.append(valid_stats)

                # visualize
                if do_sample:
                    print(f'- Reconstruction sample {idx} / Current task {task_idx}')
                    images = sample_images(model_without_ddp, dataloaders[idx][-1], params.batch_size, params.n_samples, perm)
                    task_images.append(images)
                    if utils.is_main_process():
                        save_image(images, os.path.join(params.output_dir, f'images-task{task_idx}', f"img{epoch}-task{idx}.png"))
                else:
                    task_images = None

            # Restore the current task's weights after the per-task swaps.
            if params.param_isolation:
                checkpoint_path = os.path.join(params.output_dir, f'checkpoint-task{task_idx}.pth')
                model_without_ddp.reload_checkpoint(task_num=task_idx, checkpoint_path=checkpoint_path)
            torch.distributed.barrier()

            # log
            log_stats = {f'task{task_idx}-train_{k}': v for k, v in train_stats.items()}
            log_stats['epoch'] = epoch
            for idx, valid_stats in enumerate(task_valid_stats):
                log_stats.update({f'task{idx}-valid_{k}': v for k, v in valid_stats.items()})

            if utils.is_main_process():
                with (Path(params.output_dir) / f"log-task{task_idx}.txt").open("a") as f:
                    f.write(json.dumps(log_stats) + "\n")

            if utils.is_main_process() and params.log_wandb:
                update_summary(
                    epoch=epoch,
                    train_metrics=train_stats,
                    eval_metrics=task_valid_stats,
                    lr=None,
                    log_wandb=params.log_wandb,
                    images=task_images,
                    current_task=task_idx,
                )

            torch.distributed.barrier()
            torch.cuda.synchronize()

            if break_true:
                break

        end_task_kargs = {}
        if 'replay' in params.continual_arch:
            end_task_kargs.update({
                'dataset': train_dataset,
                'is_main': utils.is_main_process(),
            })
        model_without_ddp.end_task(**end_task_kargs)
        torch.distributed.barrier()

        if task_break_true:
            break

    # eval_metrics
    if params.eval_metrics:
        get_eval_metrics(params, model_without_ddp, dataloaders)




def get_eval_metrics(params, model_without_ddp, dataloaders):
    """Compute the final evaluation metrics, store them and log them to wandb.

    The metrics returned by ``eval.eval_metrics.eval_metrics`` are written
    to ``eval_metrics_res.yaml`` in ``params.output_dir`` by the main
    process. When wandb logging is enabled the file is uploaded and a single
    summary table is logged.

    Args:
        params: Parsed arguments; reads ``output_dir``, ``log_wandb`` and
            ``num_task``.
        model_without_ddp: The unwrapped model, forwarded to ``eval_metrics``.
        dataloaders: Per-task dataloaders, forwarded to ``eval_metrics``.
    """
    # Imported lazily, as in the original code.
    from eval.eval_metrics import eval_metrics

    eval_metrics_stats = eval_metrics(
        params, model_without_ddp, dataloaders
    )

    result_path = Path(params.output_dir) / "eval_metrics_res.yaml"
    if utils.is_main_process():
        with result_path.open("w") as f:
            yaml.dump(eval_metrics_stats, f, default_flow_style=False, allow_unicode=True)

    # ``has_wandb`` guards against --log_wandb being set without the package.
    if not (params.log_wandb and has_wandb and utils.is_main_process()):
        return

    # Pass a plain string, consistent with the other ``wandb.save`` calls.
    wandb.save(str(result_path))

    metrs = ['ari', 'mse', 'mean_segcover', 'scaled_segcover']
    total_res = []
    # Indexed as stats[target_task][current_task]; only current_task values
    # from target_task onwards are read.
    for target_task in range(params.num_task):
        for current_task in range(target_task, params.num_task):
            res = eval_metrics_stats[target_task][current_task]
            total_res.append([target_task, current_task] + [res[metr] for metr in metrs])

    # Log the complete table once, after every row has been collected
    # (previously a partial table was logged on every outer iteration).
    my_table = wandb.Table(
        columns=['eval_task', 'train_task'] + metrs,
        data=total_res
    )
    wandb.log({'eval_matrics': my_table})




def update_summary(
    epoch,
    train_metrics,
    eval_metrics,
    lr=None,
    log_wandb=False,
    images=None,
    current_task=0,
):
    """Assemble one row of per-epoch metrics and log it to wandb.

    Args:
        epoch: Epoch index stored under ``'epoch'``.
        train_metrics: Mapping of training metric names to values; keys are
            prefixed with ``'train_'``.
        eval_metrics: Sequence with one metrics mapping per task; keys are
            prefixed with ``'eval_task{i}_'``. Empty entries are skipped.
        lr: Optional learning rate stored under ``'lr'``.
        log_wandb: When False nothing is sent to wandb and wandb is never
            touched, so the function is safe without the package installed.
        images: Optional sequence with one image per task, logged as
            ``'task{i}_reconstruction'``.
        current_task: Index of the task being trained, stored under ``'task'``.
    """
    rowd = OrderedDict(epoch=epoch)
    rowd['task'] = current_task
    rowd.update([('train_' + k, v) for k, v in train_metrics.items()])
    for task_idx, task_metrics in enumerate(eval_metrics):
        if task_metrics:
            rowd.update([(f'eval_task{task_idx}_' + k, v) for k, v in task_metrics.items()])
    if lr is not None:
        rowd['lr'] = lr

    if not log_wandb:
        # Building wandb.Image objects would fail when wandb is not
        # installed, and the row would be discarded anyway.
        return

    if images is not None:
        for task_idx, image in enumerate(images):
            rowd[f'task{task_idx}_reconstruction'] = wandb.Image(
                image, caption=f'task{task_idx}_img{epoch}')
    wandb.log(rowd)
    


def get_experiment_name(params):
    """Derive the output root folder and experiment name from the arguments.

    Args:
        params: Parsed arguments; reads ``task_config``, ``arch``,
            ``continual_arch``, ``dataset``, ``param_isolation``,
            ``num_slots``, ``seed`` and ``notes``.

    Returns:
        Tuple ``(root, experiment)``. ``root`` joins arch, continual arch,
        dataset and the task-config name with underscores; ``experiment``
        extends it with the isolation flag, slot count and seed, plus
        ``notes`` when that is non-empty.
    """
    # File name of the task config without its extension. ``stem`` also
    # handles ``.yml`` or extension-less names, which the previous fixed
    # five-character slice mangled.
    task_config = Path(params.task_config).stem
    parts = [
        # 'slot_attention',
        params.arch,
        params.continual_arch,
        params.dataset,
        task_config,
        f'Isol{params.param_isolation}',
        f'S{params.num_slots}',
        f'SD{params.seed}',
    ]
    root = '_'.join(parts[:4])
    experiment = '_'.join(parts)

    if params.notes != '':
        experiment += f'_{params.notes}'

    return root, experiment


if __name__ == "__main__":

    # Parse the command line and echo every option, sorted by name.
    params = get_parser()
    print("\n".join("-%s: %s" % (k, str(v)) for k, v in sorted(vars(params).items())))

    # ``params.root`` is needed for the output path even when --experiment
    # is given explicitly; previously it was only set for an empty
    # experiment name, so passing --experiment raised AttributeError.
    root, default_experiment = get_experiment_name(params=params)
    params.root = root
    if params.experiment == '':
        params.experiment = default_experiment

    params.output_dir = os.path.join(params.output, params.root, params.experiment)

    # One image folder per task; this also creates the output directory.
    for task_idx in range(params.num_task):
        Path(f"{params.output_dir}/images-task{task_idx}").mkdir(parents=True, exist_ok=True)

    main(params=params)

