import argparse, os, json
import random
import numpy as np
from pathlib import Path

import torch
from torchvision import transforms

from utils.config import setup
from utils.datasets import MMFlattenDataset

from models.trace import TRACE
from models.slimp import SLIMP
import models.vision_transformer as vits

from PIL import Image
import matplotlib

def get_args_parser():
    """Build the command-line parser for the attention-extraction script.

    The options are declared in a table and registered in order, so the
    generated ``--help`` output lists them in the same sequence as the table.

    Returns:
        argparse.ArgumentParser: parser exposing the config/output/checkpoint
        options plus the distributed-launch options.
    """
    parser = argparse.ArgumentParser(prog='Extract attention maps', add_help=True)

    option_table = (
        # I/O and model selection
        ('--config_file', dict(default='./configs/pretrain/pretrain_isic.yaml',
                               help='config file path (same as pretraining)')),
        ('--output_dir', dict(default='./results',
                              help='path to save the attention maps')),
        ('--checkpoint', dict(default=None,
                              help='load model from checkpoint')),
        # distributed-launch options
        ('--world_size', dict(default=1, type=int,
                              help='number of distributed processes')),
        ('--local-rank', dict(default=-1, type=int)),
        ('--dist_on_itp', dict(action='store_true')),
        ('--dist_url', dict(default='env://',
                            help='url used to set up distributed training')),
    )
    for flag, options in option_table:
        parser.add_argument(flag, **options)

    return parser

def incoming_importance(tab_attn):
    """
    Score each tabular feature by how much attention it receives from the others.

    Steps:
      1. ``tab_attn`` is unpacked as ``[N, T, T]``; index 0 along both token
         axes is discarded (treated as the CLS token), leaving ``[N, T-1, T-1]``.
      2. Self-attention (the diagonal) is zeroed and every row is rescaled to
         sum to 1 (the row sum is floored at 1e-9 to avoid dividing by zero).
      3. The column-wise mean over query rows gives per-sample incoming
         attention ``[N, T-1]``.
      4. That is averaged over the N samples and rescaled to sum to 1
         (the total is floored at 1e-12).

    Returns:
        Tensor of shape ``[T-1]`` with one importance value per feature.

    Raises:
        AssertionError: if ``T < 2`` (no feature token besides CLS).
    """
    _, num_tokens, _ = tab_attn.shape
    assert num_tokens >= 2, "Expected at least CLS + 1 feature"

    # Keep feature-to-feature attention only.
    feat_attn = tab_attn[:, 1:, 1:]

    # Boolean mask that is True exactly on the diagonal (a token attending to itself).
    positions = torch.arange(feat_attn.size(-1), device=feat_attn.device)
    on_diagonal = positions.unsqueeze(1) == positions.unsqueeze(0)

    off_diag = feat_attn.masked_fill(on_diagonal, 0.0)
    row_totals = off_diag.sum(dim=-1, keepdim=True).clamp_min(1e-9)
    row_stochastic = off_diag / row_totals

    # dim=1 indexes the query rows ("who is looking"); averaging it away
    # leaves how much each key column is looked at, per sample.
    received_per_sample = row_stochastic.mean(dim=1)

    # Pool across samples once, then express as a distribution over features.
    received = received_per_sample.mean(dim=0)
    return received / received.sum().clamp_min(1e-12)

def set_seed(seed):
    """Seed NumPy's global RNG, Python's ``random`` module and PyTorch with ``seed``."""
    for seed_fn in (np.random.seed, random.seed, torch.manual_seed):
        seed_fn(seed)

def _build_trace(trace_cfg, num_labels):
    """Build a TRACE tabular encoder from one config node, or return None when disabled.

    The encoder is considered disabled when ``trace_cfg.metadata_dir`` is
    ``None`` or the literal string ``'None'``. Otherwise the JSON file at that
    path is loaded and passed to TRACE as its feature metadata.
    """
    if trace_cfg.metadata_dir in (None, 'None'):
        return None

    with open(trace_cfg.metadata_dir, 'r') as f:
        feature_metadata = json.load(f)

    return TRACE(hidden_size=trace_cfg.hidden_size,
                 feature_metadata=feature_metadata,
                 num_indices=feature_metadata["num_indices"],
                 feature_extractor=True,
                 num_mode=trace_cfg.num_mode,
                 num_labels=num_labels,
                 dropout_p=trace_cfg.dropout,
                 cls_token=trace_cfg.cls_token,
                 tran_layers=trace_cfg.tran_layers,
                 heads=trace_cfg.heads,
                 mlp_ratio=trace_cfg.mlp_ratio,
                 use_num_norm=trace_cfg.use_num_norm,
                 use_cat_norm=trace_cfg.use_cat_norm,
                 checkbox_mode=trace_cfg.checkbox_mode)


def _unpack_attentions(attentions):
    """Split the model's attention output into (patient, lesion, image) parts.

    A bare tensor is checked for first: ``len(tensor)`` is its first dimension
    (the batch size), so dispatching on ``len`` alone would mis-unpack a batch
    of 2 or 3 samples as a tuple of attention maps.
    NOTE(review): the exact return type of SLIMP with ``return_attn=True`` is
    not visible from this file -- confirm which of these cases actually occur.
    """
    if isinstance(attentions, torch.Tensor):
        return None, None, attentions
    if len(attentions) == 3:
        patient_attn, lesion_attn, image_attn = attentions
        return patient_attn, lesion_attn, image_attn
    if len(attentions) == 2:
        lesion_attn, image_attn = attentions
        return None, lesion_attn, image_attn
    # Same fallback as before: treat whatever was returned as the image attention.
    return None, None, attentions


def _concat_or_none(tlist):
    """Concatenate per-batch tensors along dim 0, or return None for an empty list."""
    if not tlist:
        return None
    return torch.cat(tlist, dim=0)  # [N, T, T]


def _feature_names(dataset_name):
    """Return (patient_feature_names, lesion_feature_names) for a known dataset.

    The names are only used as print labels and must be listed in the same
    order as the corresponding importance vector (CLS excluded).

    Raises:
        ValueError: if ``dataset_name`` is not one of the supported datasets.
    """
    if dataset_name == 'pad-ufes-20':
        patient_feature_names = ['age', 'lesion_per_patient', 'smoke', 'drink', 'background_father', 'background_mother', 'pesticide', 'gender', 'skin_cancer_history', 'cancer_history', 'has_piped_water', 'has_sewage_system', 'fitspatrick']
        lesion_feature_names = ['diameter_1', 'diameter_2', 'region', 'itch', 'grew', 'hurt', 'changed', 'bleed', 'elevation', 'biopsed']
    elif dataset_name == 'HAM_10000':
        patient_feature_names = ['age', 'lesion_per_patient', 'sex']
        lesion_feature_names = ['dx_type', 'localization']
    elif dataset_name == 'HIBA':
        patient_feature_names = ['age_approx', 'lesions_per_patient', 'family_hx_mm', 'fitzpatrick_skin_type', 'personal_hx_mm', 'sex']
        lesion_feature_names = ['anatom_site_general', 'concomitant_biopsy', 'dermoscopic_type', 'diagnosis_confirm_type', 'image_type']
    elif dataset_name == 'isic':
        patient_feature_names = ['age_approx', 'lesion_per_patient', 'sex']
        lesion_feature_names = ['clin_size_long_diam_mm', 'tbp_lv_A', 'tbp_lv_Aext', 'tbp_lv_B', 'tbp_lv_Bext', 'tbp_lv_C', 'tbp_lv_Cext',
                                'tbp_lv_H', 'tbp_lv_Hext', 'tbp_lv_L', 'tbp_lv_Lext', 'tbp_lv_areaMM2', 'tbp_lv_area_perim_ratio',
                                'tbp_lv_color_std_mean', 'tbp_lv_deltaA', 'tbp_lv_deltaB', 'tbp_lv_deltaL', 'tbp_lv_deltaLB',
                                'tbp_lv_deltaLBnorm', 'tbp_lv_eccentricity', 'tbp_lv_minorAxisMM', 'tbp_lv_nevi_confidence',
                                'tbp_lv_norm_border', 'tbp_lv_norm_color', 'tbp_lv_perimeterMM', 'tbp_lv_radial_color_std_max',
                                'tbp_lv_stdL', 'tbp_lv_stdLExt', 'tbp_lv_symm_2axis', 'tbp_lv_symm_2axis_angle', 'tbp_lv_x',
                                'tbp_lv_y', 'tbp_lv_z', 'anatom_site_general', 'tbp_tile_type', 'tbp_lv_location', 'tbp_lv_location_simple']
    else:
        raise ValueError(f"Unknown dataset_name: {dataset_name}")
    return patient_feature_names, lesion_feature_names


def _print_importances(name, imp_vec, feature_names, sort=True, topk=None):
    """
    Print one importance value per feature.

    name: label to print (e.g., 'PATIENT' or 'LESION')
    imp_vec: torch.Tensor of shape [T], or None when no attention was collected
    feature_names: list[str] of length T, in the SAME order as imp_vec (no CLS)
    sort: if True, print in descending importance; if False, keep original order
    topk: optionally only print the first k rows
    Returns: (vals, idx, names) in printed order, or (None, None, None) when
             imp_vec is None.
    """
    if imp_vec is None:
        print(f"\n[{name}] No attention available; skipping.")
        return None, None, None

    T = imp_vec.numel()
    assert len(feature_names) == T, f"feature_names len {len(feature_names)} != imp_vec len {T}"

    if sort:
        idx = torch.argsort(imp_vec, descending=True)
    else:
        idx = torch.arange(T, device=imp_vec.device)

    vals = imp_vec[idx]
    names_sorted = [feature_names[i] for i in idx.tolist()]

    print(f"\n[{name}] Feature importances (sum=1){' [sorted]' if sort else ''}:")
    for r, (fname, v) in enumerate(zip(names_sorted, vals.tolist()), start=1):
        if topk is not None and r > topk:
            break
        if sort:
            print(f"  Rank {r:02d} | {fname}: {v:.6f}")
        else:
            print(f"  {fname}: {v:.6f}")

    return vals, idx, names_sorted


def main(args):
    """Extract image attention maps and tabular feature importances on the 'val' split.

    Pipeline:
      1. Build the validation dataset/loader and the SLIMP model (optional
         patient/lesion TRACE encoders + ViT), then load ``args.checkpoint``
         when one is given.
      2. For every sample, save the input image (``<idx>_rgb.png``) and a
         viridis-coloured attention map (``<idx>_attmap.png``) into
         ``args.output_dir``. ``<idx>`` is ``batch_idx * batch_size + i``.
      3. Collect the tabular attention tensors over all batches and print a
         per-feature incoming-attention importance for each tabular modality.

    Args:
        args: configuration namespace as returned by ``setup``; this function
            reads ``seed``, ``batch_size``, ``inner_only``, ``image_only``,
            ``checkpoint``, ``output_dir`` and the ``data``, ``vit``,
            ``trace_pp`` and ``trace_pl`` sub-nodes.

    Raises:
        FileNotFoundError: if ``args.checkpoint`` is set but does not exist.
        ValueError: if ``args.data.dataset_name`` has no feature-name table.
    """
    device = ("cuda" if torch.cuda.is_available() else "cpu")
    print(f"\nUsing {device} device\n")

    # fix the seed for reproducibility
    set_seed(args.seed)

    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor()])
    dataset = MMFlattenDataset(
        patient_csv=args.data.patient_tab_dir,
        lesion_csv=args.data.lesion_tab_dir,
        patient_target_csv=args.data.patient_target_dir,
        lesion_target_csv=args.data.lesion_target_dir,
        image_dir=args.data.img_dir,
        transform=transform,
        inner_only=args.inner_only,
        image_only=args.image_only,
        val_split=args.data.val_ratio,
        split='val',
        stratify=False)
    print(f'# total samples: {len(dataset)}')

    # shuffle=False and drop_last=False keep the loader order aligned with the
    # running index used for the output file names below.
    data_loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=args.batch_size,
        shuffle=False,
        drop_last=False,
        pin_memory=True,
        num_workers=24)
    print(f'\n# total batches: {len(data_loader)}')

    # Either tabular encoder may be disabled through its metadata_dir setting.
    trace_pp = _build_trace(args.trace_pp, args.data.num_labels)
    trace_pl = _build_trace(args.trace_pl, args.data.num_labels)

    vit = vits.__dict__[args.vit.arch](patch_size=args.vit.patch_size, num_classes=0)

    model = SLIMP(
        tabular_model_patient=trace_pp,
        tabular_model_lesion=trace_pl,
        vit_model=vit,
        d_model=vit.embed_dim)

    if args.checkpoint:
        # Validate inside the guard: --checkpoint defaults to None, and
        # Path(None) would raise a TypeError before any useful message.
        if not Path(args.checkpoint).exists():
            raise FileNotFoundError(f"Checkpoint path does not exist: {args.checkpoint}")
        #TODO: support attention maps from initial pretraining (SLICE-3D checkpoints)
        # map_location='cpu' lets a checkpoint saved on GPU load on a CPU-only
        # host; the model is moved to `device` afterwards.
        checkpoint = torch.load(args.checkpoint, map_location='cpu')
        print(f"\nLoad pre-trained checkpoint from: {args.checkpoint}")
        msg = model.load_state_dict(checkpoint['model'], strict=True)
        print(msg)
    else:
        print("\nNo checkpoint given; the model keeps its initial weights.")

    for _, p in model.named_parameters():
        p.requires_grad = False

    model.to(device)
    model.eval()

    # Loop invariants.
    grid_size = args.vit.img_size // args.vit.patch_size  # patches per image side
    colormap = matplotlib.colormaps['viridis']
    batch_size = data_loader.batch_size

    patient_attn_batches = []
    lesion_attn_batches = []
    for batch_idx, batch in enumerate(data_loader):

        batch = [b.to(device, non_blocking=True) for b in batch]
        images = batch[-2]

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            _, attentions = model(batch, args, return_attn=True)

        patient_attn, lesion_attn, image_attn = _unpack_attentions(attentions)

        if patient_attn is not None:
            patient_attn_batches.append(patient_attn.cpu())  # store on CPU to save VRAM

        if lesion_attn is not None:
            lesion_attn_batches.append(lesion_attn.cpu())

        #TODO: visualize tabular attention maps
        for i in range(image_attn.shape[0]):
            global_idx = batch_idx * batch_size + i
            filename_original = f"{global_idx:06d}_rgb.png"
            filename_attmap = f"{global_idx:06d}_attmap.png"

            # Row 0 / columns 1: of the last two axes, averaged over axis 1 of
            # the batch tensor.
            # NOTE(review): presumably [B, heads, tokens, tokens] with CLS at
            # token 0, i.e. CLS-to-patch attention averaged over heads -- confirm.
            image_attention = image_attn[i, :, 0, 1:].mean(0).detach().cpu()
            image_attention = image_attention.reshape(grid_size, grid_size)
            image_attention = image_attention / image_attention.sum()

            # Min-max scale to [0, 1] for the colormap. A constant map has zero
            # range; emit zeros instead of dividing by zero (NaNs).
            lowest = image_attention.min()
            value_range = image_attention.max() - lowest
            if value_range > 0:
                image_attention = (image_attention - lowest) / value_range
            else:
                image_attention = torch.zeros_like(image_attention)

            image_attention = Image.fromarray(np.uint8(colormap(image_attention.numpy())*255)).convert('RGB')
            # Nearest-neighbour upsampling keeps the patch grid visible.
            image_attention = image_attention.resize((224, 224), resample=Image.NEAREST)
            image_attention.save(os.path.join(args.output_dir, filename_attmap))

            # Channel-first tensor -> HWC uint8 image (ToTensor output scaled by 255).
            image_rgb = Image.fromarray(np.uint8(images[i].cpu().numpy().transpose(1, 2, 0)*255)).convert('RGB')
            image_rgb.save(os.path.join(args.output_dir, filename_original))

    patient_all = _concat_or_none(patient_attn_batches)
    lesion_all = _concat_or_none(lesion_attn_batches)

    # Compute incoming importance per modality
    patient_importance = incoming_importance(patient_all) if patient_all is not None else None
    lesion_importance = incoming_importance(lesion_all) if lesion_all is not None else None

    patient_feature_names, lesion_feature_names = _feature_names(args.data.dataset_name)

    _print_importances("PATIENT", patient_importance, patient_feature_names, sort=False)
    _print_importances("LESION", lesion_importance, lesion_feature_names, sort=False)


if __name__ == '__main__':
    # Script entry point: parse the CLI, make sure the output directory
    # exists, pass the parsed flags through the project's `setup`, then run.
    cli_parser = get_args_parser()
    cli_args = cli_parser.parse_args()

    output_dir = cli_args.output_dir
    if output_dir:
        Path(output_dir).mkdir(parents=True, exist_ok=True)

    run_config = setup(cli_args)
    main(run_config)