import os
import torch

from models import utils
from eval.visualize import make_recon_img
from eval.ari import ari as ari_
from eval.segmentation_covering import segmentation_covering




def eval_metrics(params, model, dataloaders):
    """Score saved per-task checkpoints on the validation data of earlier tasks.

    For every pair ``target_task <= current_task < params.num_task`` the file
    ``checkpoint-task{current_task}.pth`` is loaded into ``model`` and the
    model is scored with ``run_eval`` on the validation loader of
    ``target_task``.

    Args:
        params: configuration object. Reads ``checkpoint_dir`` (used unless it
            is ``''``), ``output_dir`` (fallback directory), ``num_task`` and
            ``param_isolation`` (selects which loader method of ``model`` is
            used).
        model: object exposing ``load_isolated_checkpoint`` and
            ``reload_checkpoint``, both called with keyword arguments
            ``task_num`` and ``checkpoint_path``.
        dataloaders: indexable by task id; each entry unpacks into exactly
            three items, the second of which is the validation loader.

    Returns:
        Nested dict ``{target_task: {current_task: stats}}`` where ``stats``
        contains ``'task'``, ``'current_task'`` and every metric returned by
        ``run_eval``.

    Raises:
        AssertionError: if a required checkpoint file does not exist.
    """
    checkpoint_dir = (
        params.output_dir if params.checkpoint_dir == '' else params.checkpoint_dir
    )

    log_stats = {}
    for target_task in range(params.num_task):
        stats_by_model = {}

        # Only checkpoints saved at or after target_task are evaluated on it.
        for current_task in range(target_task, params.num_task):
            print(f'\n\n- Evaluting on task {target_task}, model from task {current_task}  ...')

            ckpt_file = os.path.join(checkpoint_dir, f'checkpoint-task{current_task}.pth')
            assert os.path.exists(ckpt_file), f'No file at {ckpt_file}'
            print(f'\nEvaluting checkpoints from {ckpt_file}...\n')

            load_fn = (
                model.load_isolated_checkpoint
                if params.param_isolation
                else model.reload_checkpoint
            )
            load_fn(task_num=current_task, checkpoint_path=ckpt_file)

            # Data always comes from the task being measured (target_task).
            _, val_loader, _ = dataloaders[target_task]

            stats = {
                'task': target_task,
                'current_task': current_task,
                **run_eval(model, val_loader),
            }
            print(stats)

            stats_by_model[current_task] = stats

        log_stats[target_task] = stats_by_model

    return log_stats

            


@torch.no_grad()
def run_eval(model, loader, num_ignored_objects=1):
    """Run ``model`` over every batch of ``loader`` and return averaged metrics.

    Per batch this records: the model loss, ARI, full-image MSE, MSE restricted
    to foreground pixels, MSE restricted to unmodified foreground pixels, and
    mean / scaled segmentation covering. When the model predicts a single mask
    channel, ARI and segmentation covering are recorded as NaN.

    Side effects: puts ``model`` into eval mode (it is not switched back), moves
    each input image to CUDA, and prints progress / final averages through
    ``utils.MetricLogger``.

    Args:
        model: called as ``model(image)``; the returned dict must provide
            ``"loss"``, ``"slot"`` and ``"mask"``.
        loader: sized iterable of batch dicts with ``"image"``, ``"mask"``,
            ``"is_foreground"`` and ``"is_modified"``.
        num_ignored_objects: forwarded unchanged to ``ari_`` and
            ``segmentation_covering``.

    Returns:
        Dict mapping each metric name to ``meter.global_avg`` after
        ``synchronize_between_processes()``.
    """
    metric_logger = utils.MetricLogger(delimiter="  ")
    for meter_name in (
        'ari',
        'mse',
        'mse_unmodified_fg',
        'mse_fg',
        'mean_segcover',
        'scaled_segcover',
    ):
        metric_logger.add_meter(
            meter_name, utils.SmoothedValue(window_size=1, fmt='{value:.6f}')
        )

    header = 'Eval Metrics:'
    # Log about twice per pass. Clamp to >= 1: len(loader) // 2 is 0 for a
    # single-batch loader, and log_every presumably uses the interval as a
    # modulus (DETR-style `i % print_freq`), which would raise
    # ZeroDivisionError.
    log_interval = max(1, len(loader) // 2)

    model.eval()
    for inputs in metric_logger.log_every(loader, log_interval, header):
        image = inputs['image'].cuda()

        # Gradient tracking is already disabled by the @torch.no_grad() decorator.
        outputs = model(image)
        loss = outputs["loss"]

        # argmax over dim 1 (presumably the object/slot axis — confirm) turns
        # per-object masks into integer segmentation maps.
        true_mask = inputs["mask"].cpu().argmax(dim=1)
        pred_mask = outputs["mask"].cpu().argmax(dim=1, keepdim=True).squeeze(2)
        reconstruction = make_recon_img(outputs["slot"], outputs["mask"]).clamp(0.0, 1.0)

        # Element-wise squared error and its per-sample mean over dims 1-3.
        mse_full = (image - reconstruction) ** 2
        mse = mse_full.mean([1, 2, 3])

        if outputs["mask"].shape[1] == 1:
            # A single predicted mask gives nothing to segment-score; report
            # NaN for every sample in the batch.
            ari = mean_segcover = scaled_segcover = torch.full(
                (true_mask.shape[0],), fill_value=torch.nan
            )
        else:
            ari = ari_(true_mask, pred_mask, num_ignored_objects)
            mean_segcover, scaled_segcover = segmentation_covering(
                true_mask, pred_mask, num_ignored_objects
            )

        # Append two singleton dims so the (presumably per-object) flags
        # broadcast against inputs["mask"].
        unsqueezed_shape = (*inputs["is_foreground"].shape, 1, 1)
        is_fg = inputs["is_foreground"].view(*unsqueezed_shape)
        is_modified = inputs["is_modified"].view(*unsqueezed_shape)

        # Summing over dim 1 merges the selected object masks into one mask.
        fg_mask = (inputs["mask"] * is_fg).sum(1)
        unmodified_fg_mask = (inputs["mask"] * is_fg * (1 - is_modified)).sum(1)

        # The masks come straight from `inputs` (presumably CPU tensors), so
        # bring the error map to CPU once and reuse it for both products.
        mse_full_cpu = mse_full.cpu()
        fg_mse = (mse_full_cpu * fg_mask).mean([1, 2, 3])
        unmodified_fg_mse = (mse_full_cpu * unmodified_fg_mask).mean([1, 2, 3])

        metric_logger.update(loss=loss.item())
        metric_logger.update(ari=ari.mean().item())
        metric_logger.update(mse=mse.mean().item())
        metric_logger.update(mse_unmodified_fg=unmodified_fg_mse.mean().item())
        metric_logger.update(mse_fg=fg_mse.mean().item())
        metric_logger.update(mean_segcover=mean_segcover.mean().item())
        metric_logger.update(scaled_segcover=scaled_segcover.mean().item())

    metric_logger.synchronize_between_processes()

    print(">>> Eval Averaged stats:", metric_logger)

    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
    


def eval_metrics_continual(params, model, dataloaders):
    """Score each saved per-task checkpoint on that same task's validation data.

    For every pair ``target_task <= current_task < params.num_task`` the file
    ``checkpoint-task{current_task}.pth`` is loaded into ``model`` and scored
    with ``run_eval`` on the validation loader of ``current_task``;
    ``target_task`` is recorded under the ``'task'`` key.

    NOTE(review): because the data comes from ``dataloaders[current_task]``,
    the metrics stored for a given ``current_task`` do not depend on
    ``target_task`` except through the ``'task'`` label, so each checkpoint is
    reloaded and re-evaluated once per ``target_task <= current_task``.
    Confirm whether that repetition (or ``dataloaders[target_task]``) is
    intended.

    Args:
        params: configuration object. Reads ``checkpoint_dir`` (used unless it
            is ``''``), ``output_dir`` (fallback directory), ``num_task`` and
            ``param_isolation`` (selects which loader method of ``model`` is
            used).
        model: object exposing ``load_isolated_checkpoint`` and
            ``reload_checkpoint``, both called with keyword arguments
            ``task_num`` and ``checkpoint_path``.
        dataloaders: indexable by task id; each entry unpacks into exactly
            three items, the second of which is the validation loader.

    Returns:
        Nested dict ``{target_task: {current_task: stats}}`` where ``stats``
        contains ``'task'``, ``'current_task'`` and every metric returned by
        ``run_eval``.

    Raises:
        AssertionError: if a required checkpoint file does not exist.
    """
    checkpoint_dir = (
        params.output_dir if params.checkpoint_dir == '' else params.checkpoint_dir
    )

    log_stats = {}
    print(f'\nEvaluting checkpoints from {checkpoint_dir}...\n')

    for target_task in range(params.num_task):
        stats_by_model = {}

        for current_task in range(target_task, params.num_task):
            print(f'\n\n- Evaluting on task {target_task}, model from task {current_task}  ...')

            ckpt_file = os.path.join(checkpoint_dir, f'checkpoint-task{current_task}.pth')
            assert os.path.exists(ckpt_file), f'No file at {ckpt_file}'
            print(f'\nEvaluting checkpoints from {ckpt_file}...\n')

            load_fn = (
                model.load_isolated_checkpoint
                if params.param_isolation
                else model.reload_checkpoint
            )
            load_fn(task_num=current_task, checkpoint_path=ckpt_file)

            # Data comes from the checkpoint's own task, not from target_task.
            _, val_loader, _ = dataloaders[current_task]

            stats = {
                'task': target_task,
                'current_task': current_task,
                **run_eval(model, val_loader),
            }
            print(stats)

            stats_by_model[current_task] = stats

        log_stats[target_task] = stats_by_model

    return log_stats