import argparse, os, json
import random
import datetime, time
import numpy as np
from pathlib import Path
from collections import defaultdict

import torch
from torch import nn
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import Subset

from utils.config import setup
from utils.loss_fn import FocalLoss
from utils.datasets import MMFlattenDataset, FeatureDataset
from utils.utils import NativeScalerWithGradNormCount as NativeScaler
from utils.utils import save_model, LARS, trunc_normal_, load_matching_state_dict
from utils.retrieval import retrieve_features

from models.trace import TRACE
from models.slimp import SLIMP
import models.vision_transformer as vits

from engine_linprobe import train_one_epoch, evaluate, extract_features, knn_classifier

def get_args_parser():
    """Create the command-line parser for this evaluation script.

    Only the flags below are parsed from the command line; the remaining
    hyper-parameters used by ``main`` are attached later by ``setup``.

    Returns:
        argparse.ArgumentParser: parser with the config/output/checkpoint
        paths, the evaluation-mode switches and the k-NN temperature.
    """
    parser = argparse.ArgumentParser('SLIMP pre-training', add_help=True)

    option_specs = (
        # Paths
        ('--config_file', dict(default='./configs/eval/eval_padufes20.yaml',
                               help='config file path')),
        ('--output_dir', dict(default='./results',
                              help='path to save the output model')),
        ('--checkpoint', dict(default=None,
                              help='load model from checkpoint')),
        # [Optional] Perform evaluation only with a pre-trained linear head
        ('--eval', dict(action='store_true',
                        help='Perform linear evaluation only')),
        # [Optional] Perform kNN evaluation only
        ('--knn_eval', dict(action='store_true',
                            help='Perform kNN evaluation only')),
        ('--T', dict(type=float, default=0.07,
                     help='Temperature for kNN evaluation')),
        # [Optional] Pseudo-modalities retrieval
        # 1. Extract features (run once, then run again with [--retrieve] instead)
        ('--extract_features', dict(
            action='store_true',
            help='Save extracted features to disk for later use. '
                 'It will extract features from the dataset specified in the config file. '
                 'For retrieval we propose to use SLICE-3D/ISIC dataset.')),
        # 2. Retrieve metadata features from the SLICE-3D dataset
        ('--retrieve', dict(
            action='store_true',
            help='Perform pseudo-modalities retrieval. '
                 'In this step run with the target dataset config file.')),
    )
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    return parser

class LinearClassifier(nn.Module):
    """A single fully connected layer producing ``num_labels`` logits.

    Args:
        dim: size of the flattened input feature vector.
        num_labels: number of output logits (default 1).
    """

    def __init__(self, dim, num_labels=1):
        super().__init__()
        self.num_labels = num_labels
        # The submodule name ``linear`` fixes the state-dict keys
        # (``head.linear.weight`` / ``head.linear.bias`` once attached as ``model.head``).
        self.linear = nn.Linear(dim, num_labels)
        # reset_params() is intentionally not called: nn.Linear's default init is used.

    def reset_params(self):
        """Re-initialise the layer with Kaiming-normal weights (ReLU gain) and a zero bias."""
        nn.init.kaiming_normal_(self.linear.weight, nonlinearity='relu')
        nn.init.zeros_(self.linear.bias)

    def forward(self, x):
        """Flatten all dimensions after the batch axis and apply the linear layer."""
        batch_size = x.size(0)
        flat = x.view(batch_size, -1)
        return self.linear(flat)

def set_seed(seed):
    """Seed NumPy's global RNG, Python's ``random`` module and PyTorch with ``seed``."""
    for seeder in (np.random.seed, random.seed, torch.manual_seed):
        seeder(seed)

# DataLoader worker processes used for every loader built in this script.
_NUM_WORKERS = 24
# Neighbourhood sizes evaluated by the k-NN classifier.
_KNN_KS = (5, 10, 15, 20, 50, 100, 200)


def _make_loader(dataset, batch_size, shuffle):
    """Wrap ``dataset`` in a DataLoader with the script's fixed pinning/worker settings."""
    return torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        pin_memory=True,
        num_workers=_NUM_WORKERS,
        drop_last=False)


def _get_percentage_subset(dataset, percentage):
    """Return a class-stratified random ``Subset`` keeping ``percentage`` of each class.

    The label of each sample is read from the last element of ``dataset[idx]``,
    so every sample is loaded once. Per-class counts are printed. Sampling uses
    Python's ``random`` module.
    """
    label_to_indices = defaultdict(list)
    # Group indices by class (index loop: only __len__/__getitem__ are assumed).
    for idx in range(len(dataset)):
        label_to_indices[int(dataset[idx][-1])].append(idx)

    for label, indices in label_to_indices.items():
        print(f"Class {label}: {len(indices)} samples")

    # Sample from each class to preserve the class ratio.
    selected_indices = []
    for indices in label_to_indices.values():
        n_select = round(len(indices) * percentage)
        selected_indices.extend(random.sample(indices, n_select))
    return Subset(dataset, selected_indices)


def _build_trace(cfg, num_labels):
    """Build a TRACE tabular feature extractor from config section ``cfg``.

    Returns None when ``cfg.metadata_dir`` is unset (``None`` or the string
    ``'None'``), i.e. that tabular branch is disabled.
    """
    if cfg.metadata_dir in (None, 'None'):
        return None
    with open(cfg.metadata_dir, 'r') as f:
        feature_metadata = json.load(f)
    return TRACE(hidden_size=cfg.hidden_size,
                 feature_metadata=feature_metadata,
                 num_indices=feature_metadata["num_indices"],
                 feature_extractor=True,
                 num_mode=cfg.num_mode,
                 num_labels=num_labels,
                 dropout_p=cfg.dropout,
                 cls_token=cfg.cls_token,
                 tran_layers=cfg.tran_layers,
                 heads=cfg.heads,
                 mlp_ratio=cfg.mlp_ratio,
                 use_num_norm=cfg.use_num_norm,
                 use_cat_norm=cfg.use_cat_norm,
                 checkbox_mode=cfg.checkbox_mode)


def _num_feature_chunks(args):
    """Return how many modality embeddings the model concatenates per sample.

    1 for image only, 2 for lesion + image (``inner_only``), otherwise 3
    (patient + lesion + image). ``image_only`` takes precedence.
    """
    if args.image_only:
        return 1
    if args.inner_only:
        return 2
    return 3


def _build_optimizer(params, name, lr):
    """Create the optimizer called ``name`` over ``params``.

    Raises:
        NotImplementedError: for an unknown optimizer name.
    """
    if name == 'adam':
        return torch.optim.Adam(params, lr=lr)
    if name == 'adamw':
        return torch.optim.AdamW(params, lr=lr)
    if name == 'rmsprop':
        return torch.optim.RMSprop(params, lr=lr, momentum=0.9, weight_decay=1e-3)
    if name == 'sgd':
        return torch.optim.SGD(params, lr=lr, momentum=0.9, weight_decay=1e-3)
    if name == 'lars':
        return LARS(params, lr=lr, weight_decay=0)
    raise NotImplementedError(f"Optimizer '{name}' is not implemented.")


def _build_loss(args):
    """Create the loss selected by ``args.loss`` ('bce', 'focal' or 'ce').

    ``args.focal.alpha`` is only read for the focal loss.

    Raises:
        NotImplementedError: for an unknown loss name.
    """
    if args.loss == 'bce':
        return nn.BCEWithLogitsLoss()
    if args.loss == 'focal':
        return FocalLoss(alpha=args.focal.alpha)
    if args.loss == 'ce':
        return nn.CrossEntropyLoss()
    raise NotImplementedError(f"Loss function '{args.loss}' is not implemented.")


def _split_features(features: torch.Tensor, args, dim_per_feature: int):
    """Split concatenated embeddings along dim 1 into (patient, lesion, image).

    With ``args.inner_only`` there is no patient chunk and None is returned in its place.
    """
    if args.inner_only:
        # lesion + image
        z2, z3 = torch.split(features, dim_per_feature, dim=1)
        return None, z2, z3
    # patient + lesion + image
    z1, z2, z3 = torch.split(features, dim_per_feature, dim=1)
    return z1, z2, z3


def _run_knn_eval(model, data_loader_train, data_loader_val, device, args):
    """Extract train/val features and print k-NN results for every k in ``_KNN_KS``."""
    print(f"\nStart kNN evaluation with T={args.T}")
    train_stats = extract_features(data_loader_train, model, device, args, return_targets_and_preds=True)
    test_stats = extract_features(data_loader_val, model, device, args, return_targets_and_preds=True)
    print(f"Train features shape: {train_stats['features'].shape}")
    print(f"Train targets shape: {train_stats['targets'].shape}")
    print(f"Test features shape: {test_stats['features'].shape}")
    print(f"Test targets shape: {test_stats['targets'].shape}")

    print("Features are ready!\nStart the k-NN classification.")
    # Use the selected device (not a hard-coded .cuda()) so the CPU fallback works.
    train_features = nn.functional.normalize(train_stats['features'].to(device), dim=1, p=2)
    test_features = nn.functional.normalize(test_stats['features'].to(device), dim=1, p=2)
    train_labels = train_stats['targets'].to(device)
    test_labels = test_stats['targets'].to(device)

    # A single-logit (binary) head still has two classes for k-NN voting.
    num_labels = args.data.num_labels
    num_classes = 2 if num_labels == 1 else num_labels
    for k in _KNN_KS:
        if num_classes == 2:
            top1, f1, bal_acc, auc = knn_classifier(
                train_features, train_labels, test_features, test_labels, k, T=args.T, num_classes=num_classes)
            print(f"{k}-NN classifier result: Acc: {top1}, F1: {f1}, Balanced Acc: {bal_acc}, AUC: {auc}")
        else:
            top1 = knn_classifier(
                train_features, train_labels, test_features, test_labels, k, T=args.T, num_classes=num_classes)
            print(f"{k}-NN classifier result: Acc: {top1}")


def _extract_and_save_features(model, data_loader_train, data_loader_val, device, args):
    """Extract train/val embeddings and save them per modality under ``<output_dir>/features``."""
    print("Extracting features...")
    train_stats = extract_features(data_loader_train, model, device, args, return_targets_and_preds=True)
    test_stats = extract_features(data_loader_val, model, device, args, return_targets_and_preds=True)

    train_features = train_stats['features'].to(device)
    test_features = test_stats['features'].to(device)
    print(f"Train features shape: {train_features.shape}")
    print(f"Test features shape: {test_features.shape}")

    # Do not normalize here. It will be done in the retrieval step.
    dim_per_feature = test_features.shape[1] // _num_feature_chunks(args)

    feature_dir = os.path.join(args.output_dir, "features")
    os.makedirs(feature_dir, exist_ok=True)
    if args.image_only:
        torch.save(train_features, os.path.join(feature_dir, "train_features_image.pt"))
        torch.save(test_features, os.path.join(feature_dir, "test_features_image.pt"))
    else:
        z1_test, z2_test, _ = _split_features(test_features, args, dim_per_feature)
        z1_train, z2_train, _ = _split_features(train_features, args, dim_per_feature)
        for filename, tensor in (("train_features_patient.pt", z1_train),
                                 ("train_features_lesion.pt", z2_train),
                                 ("test_features_patient.pt", z1_test),
                                 ("test_features_lesion.pt", z2_test)):
            if tensor is not None:
                torch.save(tensor, os.path.join(feature_dir, filename))
    print("Features are saved to disk! Run again with [--retrieve] to perform pseudo-modalities retrieval.")


def _retrieve_pseudo_modalities(model, data_loader_train, data_loader_val, device, args):
    """Build feature-level loaders whose inputs come from ``retrieve_features``.

    Loads the lesion/patient feature databases written by ``--extract_features``
    to ``<output_dir>/features``.

    Returns:
        (dataset_val, data_loader_train, data_loader_val) backed by ``FeatureDataset``.

    Raises:
        FileNotFoundError: if either feature database file is missing.
    """
    print("\nStart pseudo-modalities retrieval")
    # Load pre-extracted features from SLICE-3D dataset
    feature_dir = os.path.join(args.output_dir, "features")
    lesion_path = os.path.join(feature_dir, "train_features_lesion.pt")
    patient_path = os.path.join(feature_dir, "train_features_patient.pt")

    if not os.path.exists(lesion_path):
        raise FileNotFoundError(
            f"Lesion feature file not found at {lesion_path}. "
            f"Please make sure you have extracted features from the SLICE-3D dataset."
        )
    if not os.path.exists(patient_path):
        raise FileNotFoundError(
            f"Patient feature file not found at {patient_path}. "
            f"Please make sure you have extracted features from the SLICE-3D dataset."
        )

    # map_location lets databases saved from GPU tensors load on the CPU fallback.
    lesion_db = torch.load(lesion_path, map_location=device)
    patient_db = torch.load(patient_path, map_location=device)

    image_train, lesion_train, patient_train, targets_train = retrieve_features(
        model, data_loader_train, lesion_db, patient_db, device
    )
    image_val, lesion_val, patient_val, targets_val = retrieve_features(
        model, data_loader_val, lesion_db, patient_db, device
    )

    dataset_train = FeatureDataset(lesion_train, patient_train, image_train, targets_train)
    dataset_val = FeatureDataset(lesion_val, patient_val, image_val, targets_val)
    # NOTE(review): if retrieve_features returns CUDA tensors, pin_memory/worker
    # processes in these loaders may fail — confirm the device it returns.
    return (dataset_val,
            _make_loader(dataset_train, args.batch_size, shuffle=True),
            _make_loader(dataset_val, args.batch_size, shuffle=False))


def main(args):
    """Run linear probing (or one of the alternative evaluation modes) for SLIMP.

    Mode flags are checked in this order: ``--eval`` evaluates a saved model +
    linear head and returns; ``--knn_eval`` runs k-NN evaluation and returns;
    ``--extract_features`` saves embeddings to disk and returns; ``--retrieve``
    replaces the inputs with retrieved features and then trains only the head.
    Without a mode flag the linear head (the whole model if ``args.finetune``)
    is trained for ``args.epochs`` epochs, checkpointing and logging to
    ``args.output_dir``.

    Raises:
        FileNotFoundError: if ``args.checkpoint`` is given but does not exist.
        ValueError: if ``--eval`` is requested without ``--checkpoint``.
        NotImplementedError: for an unknown optimizer or loss name.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"\nUsing {device} device\n")

    # fix the seed for reproducibility
    set_seed(args.seed)

    # The checkpoint is optional, except for --eval which needs a trained head.
    if args.checkpoint and not Path(args.checkpoint).exists():
        raise FileNotFoundError(f"Checkpoint path does not exist: {args.checkpoint}")
    if args.eval and not args.checkpoint:
        raise ValueError("--eval requires --checkpoint pointing to a trained model with a linear head.")

    if args.output_dir is not None:
        os.makedirs(args.output_dir, exist_ok=True)
        log_writer = SummaryWriter(log_dir=args.output_dir)
    else:
        log_writer = None

    val_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor()])
    # k-NN evaluation and feature extraction must embed the training set exactly
    # like the validation set, so they reuse the validation transform.
    if args.knn_eval or args.extract_features:
        train_transform = val_transform
    else:
        train_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor()])

    ##################################
    # Optional train transformations #
    ##################################
    # train_transform = transforms.Compose([
    #     transforms.Resize((224, 224)),
    #     transforms.ToTensor(),
    #     transforms.Normalize(mean=[0.698, 0.522, 0.419],
    #                         std=[0.048, 0.054, 0.052])
    #     ])
    # val_transform = transforms.Compose([
    #     transforms.Resize((224, 224)),
    #     transforms.ToTensor(),
    #     transforms.Normalize(mean=[0.698, 0.522, 0.419],
    #                         std=[0.048, 0.054, 0.052])
    #     ])

    dataset_train = MMFlattenDataset(
        patient_csv=args.data.patient_tab_dir,
        lesion_csv=args.data.lesion_tab_dir,
        patient_target_csv=args.data.patient_target_dir,
        lesion_target_csv=args.data.lesion_target_dir,
        image_dir=args.data.img_dir,
        transform=train_transform,
        inner_only=args.inner_only,
        image_only=args.image_only,
        val_split=args.data.val_ratio,
        random_state=args.data_seed,
        split='train',
        stratify=False)
    # NOTE(review): unlike the train split, inner_only/image_only are not passed
    # here, so the val split uses MMFlattenDataset's defaults — confirm intended.
    dataset_val = MMFlattenDataset(
        patient_csv=args.data.patient_tab_dir,
        lesion_csv=args.data.lesion_tab_dir,
        patient_target_csv=args.data.patient_target_dir,
        lesion_target_csv=args.data.lesion_target_dir,
        image_dir=args.data.img_dir,
        transform=val_transform,
        random_state=args.data_seed,
        val_split=args.data.val_ratio,
        split='val',
        stratify=False)
    print(f'# training samples: {len(dataset_train)}')
    print(f'# validation samples: {len(dataset_val)}')

    # low-shot
    if args.data.train_subset_percentage > 0.0:
        dataset_train = _get_percentage_subset(dataset_train, args.data.train_subset_percentage)

    data_loader_train = _make_loader(dataset_train, args.batch_size, shuffle=True)
    data_loader_val = _make_loader(dataset_val, args.batch_size, shuffle=False)
    print(f'\n# training batches: {len(data_loader_train)}')
    print(f'# validation batches: {len(data_loader_val)}')

    # Patient-level then lesion-level tabular encoders (None when disabled).
    trace_pp = _build_trace(args.trace_pp, args.data.num_labels)
    trace_pl = _build_trace(args.trace_pl, args.data.num_labels)

    vit = vits.__dict__[args.vit.arch](patch_size=args.vit.patch_size, num_classes=0)

    model = SLIMP(
        tabular_model_patient=trace_pp,
        tabular_model_lesion=trace_pl,
        vit_model=vit,
        d_model=vit.embed_dim)

    # --eval loads its (head-included) checkpoint strictly further below.
    if args.checkpoint and not args.eval and not args.retrieve:
        # map_location: checkpoints saved on GPU must also load on the CPU fallback.
        checkpoint = torch.load(args.checkpoint, map_location='cpu')
        print(f"\nLoad pre-trained checkpoint from: {args.checkpoint}")
        msg = model.load_state_dict(checkpoint['model'], strict=False)
        print(msg)
    elif args.checkpoint and args.retrieve:
        print(f"\nLoad pre-trained checkpoint from: {args.checkpoint}")
        load_matching_state_dict(model, args.checkpoint)

    # Initialize Linear Classifier for linear probing over the concatenated embeddings.
    model.head = LinearClassifier(model.d_model * _num_feature_chunks(args), args.data.num_labels)

    # freeze all but the head
    if not args.finetune:
        for p in model.parameters():
            p.requires_grad = False
    for p in model.head.parameters():
        p.requires_grad = True

    model.to(device)
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('Number of trainable params:', n_parameters)

    lr = args.lr
    optimizer = _build_optimizer(model.parameters(), args.optimizer, lr)
    print(f'\nLR: {lr}')

    loss_fn = _build_loss(args)
    loss_scaler = NativeScaler()

    if args.eval:
        checkpoint = torch.load(args.checkpoint, map_location='cpu')
        print(f"\nLoad pre-trained checkpoint from: {args.checkpoint}")
        msg = model.load_state_dict(checkpoint['model'], strict=True)
        print(msg)
        test_stats = evaluate(data_loader_val, model, loss_fn, device, args)
        print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
        return

    if args.knn_eval:
        _run_knn_eval(model, data_loader_train, data_loader_val, device, args)
        return

    if args.extract_features:
        _extract_and_save_features(model, data_loader_train, data_loader_val, device, args)
        return

    if args.retrieve:
        dataset_val, data_loader_train, data_loader_val = _retrieve_pseudo_modalities(
            model, data_loader_train, data_loader_val, device, args)
        # Retrieved data always carries all three modalities.
        args.inner_only = False
        args.image_only = False
        print("Pseudo-modalities retrieval completed!")
        # Inputs are now features, so only the linear head is trained/evaluated.
        model = model.head

    print(f"\nStart training for {args.epochs} epochs")
    start_time = time.time()
    max_accuracy = 0.0
    for epoch in range(args.start_epoch, args.epochs):

        train_stats = train_one_epoch(
            model, loss_fn, data_loader_train,
            optimizer, device, epoch, loss_scaler,
            log_writer=log_writer, args=args)

        if args.output_dir and ((epoch + 1) % args.save_per_epochs == 0 or epoch + 1 == args.epochs):
            save_model(args=args,
                       model=model,
                       model_without_ddp=model,
                       optimizer=optimizer,
                       loss_scaler=loss_scaler,
                       epoch=epoch)

        test_stats = evaluate(data_loader_val, model, loss_fn, device, args)
        print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
        max_accuracy = max(max_accuracy, test_stats["acc1"])
        print(f'Max accuracy: {max_accuracy:.2f}%')

        if log_writer is not None:
            for key, value in test_stats.items():
                log_writer.add_scalar(f'perf/test_{key}', value, epoch)

        log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                     **{f'test_{k}': v for k, v in test_stats.items()},
                     'epoch': epoch,
                     'n_parameters': n_parameters}

        if args.output_dir:
            if log_writer is not None:
                log_writer.flush()
            with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
                f.write(json.dumps(log_stats) + "\n")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))


if __name__ == '__main__':
    # Parse the CLI flags and create the output directory before handing the
    # flags to ``setup`` (utils.config), whose result is what ``main`` consumes.
    args = get_args_parser().parse_args()
    output_dir = args.output_dir
    if output_dir:
        Path(output_dir).mkdir(parents=True, exist_ok=True)
    args = setup(args)
    main(args)