import argparse, os, json
import random
import datetime, time
import numpy as np
from pathlib import Path

import torch
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter

from utils.config import setup
from utils.loss_fn import ContrastiveLoss
from utils.datasets import MMFlattenDataset, MMNestedDataset, mmnested_collate_fn
from utils.utils import NativeScalerWithGradNormCount as NativeScaler
from utils.utils import save_model, summarize_model, load_matching_state_dict
from utils import utils

from models.trace import TRACE
from models.slimp import SLIMP
import models.vision_transformer as vits

from engine_pretrain import train_one_epoch

def get_args_parser():
    """Return the ``argparse`` parser for the SLIMP pre-training entry point.

    Only run-level flags are declared here: config path, output and
    checkpoint locations, and distributed-launch settings. ``main`` also
    reads settings (``data``, ``vit``, ``lr``, ...) that are not declared
    here; they are expected to be supplied by ``setup`` from the config file.
    """
    # (flag, add_argument keyword arguments), registered in this order.
    cli_options = (
        ('--config_file', {'default': './configs/pretrain/pretrain_isic.yaml',
                           'help': 'config file path'}),
        ('--output_dir', {'default': './results',
                          'help': 'path to save the output model'}),
        ('--continual', {'action': 'store_true',
                         'help': 'continual training on a target dataset'}),
        ('--checkpoint', {'default': None,
                          'help': 'load model from checkpoint'}),
        # Distributed-launch options.
        ('--world_size', {'default': 1, 'type': int,
                          'help': 'number of distributed processes'}),
        ('--local-rank', {'default': -1, 'type': int}),  # dest: local_rank
        ('--dist_on_itp', {'action': 'store_true'}),
        ('--dist_url', {'default': 'env://',
                        'help': 'url used to set up distributed training'}),
    )

    parser = argparse.ArgumentParser('SLIMP pre-training', add_help=True)
    for flag, options in cli_options:
        parser.add_argument(flag, **options)
    return parser

def set_seed(seed):
    """Seed the NumPy, Python ``random`` and PyTorch global RNGs with ``seed``."""
    for seed_fn in (np.random.seed, random.seed, torch.manual_seed):
        seed_fn(seed)

def _has_metadata(trace_cfg):
    """Return True when a TRACE config section points at a metadata JSON file.

    "No file" may be spelled as a real ``None`` or as the string ``'None'``.
    """
    return trace_cfg.metadata_dir not in (None, 'None')


def _build_train_dataset(args, transform, nested):
    """Build the training dataset and print its size.

    ``nested=False`` builds an ``MMFlattenDataset`` from the lesion tables
    only; its length is reported as a number of lesions. ``nested=True``
    builds an ``MMNestedDataset`` that also uses the patient tables and
    passes ``args.data.max_num_lesions`` as ``N``; its length is reported as a
    number of patients.
    """
    data = args.data
    if not nested:
        dataset = MMFlattenDataset(
            patient_csv=None,
            lesion_csv=data.lesion_tab_dir,
            patient_target_csv=None,
            lesion_target_csv=data.lesion_target_dir,
            image_dir=data.img_dir,
            transform=transform,
            random_state=args.data_seed,
            val_split=data.val_ratio,
            split=data.split)
        print(f'Number of lesions: {len(dataset)}')
        return dataset

    dataset = MMNestedDataset(
        patient_csv=data.patient_tab_dir,
        lesion_csv=data.lesion_tab_dir,
        patient_target_csv=data.patient_target_dir,
        lesion_target_csv=data.lesion_target_dir,
        image_dir=data.img_dir,
        transform=transform,
        N=data.max_num_lesions,
        random_state=args.data_seed,
        val_split=data.val_ratio,
        split=data.split)
    print(f'Number of patients: {len(dataset)}')
    print(f'Number of max lesions per patient: {data.max_num_lesions}')
    return dataset


def _build_trace(trace_cfg, num_labels):
    """Build a TRACE tabular feature extractor from one config section.

    Returns None when ``trace_cfg`` has no metadata file. Otherwise it loads
    the feature metadata JSON and passes it, together with the section's
    hyper-parameters, to ``TRACE``.
    """
    if not _has_metadata(trace_cfg):
        return None
    with open(trace_cfg.metadata_dir, 'r') as f:
        feature_metadata = json.load(f)
    return TRACE(hidden_size=trace_cfg.hidden_size,
                 feature_metadata=feature_metadata,
                 num_indices=feature_metadata["num_indices"],
                 feature_extractor=True,
                 num_mode=trace_cfg.num_mode,
                 num_labels=num_labels,
                 dropout_p=trace_cfg.dropout,
                 cls_token=trace_cfg.cls_token,
                 tran_layers=trace_cfg.tran_layers,
                 heads=trace_cfg.heads,
                 mlp_ratio=trace_cfg.mlp_ratio,
                 use_num_norm=trace_cfg.use_num_norm,
                 use_cat_norm=trace_cfg.use_cat_norm,
                 checkbox_mode=trace_cfg.checkbox_mode)


def _build_optimizer(name, params, lr):
    """Create the optimizer called ``name`` over ``params`` with learning rate ``lr``.

    Raises:
        NotImplementedError: if ``name`` is not one of
            'adam', 'adamw', 'rmsprop' or 'sgd'.
    """
    if name == 'adam':
        return torch.optim.Adam(params, lr=lr)
    if name == 'adamw':
        return torch.optim.AdamW(params, lr=lr)
    if name == 'rmsprop':
        return torch.optim.RMSprop(params, lr=lr, momentum=0.9, weight_decay=1e-3)
    if name == 'sgd':
        return torch.optim.SGD(params, lr=lr, momentum=0.9, weight_decay=1e-3)
    raise NotImplementedError(f"Optimizer '{name}' is not implemented.")


def _apply_freezing(model, args, trace_pp, trace_pl):
    """Set ``requires_grad`` on the parameters of ``model`` from the config flags.

    Does nothing unless ``args.vit.freeze`` is true. In that case:
    - the ViT is frozen except for its patch embedding;
    - ``trace_patient`` / ``trace_lesion`` are frozen when the corresponding
      TRACE model exists and ``args.trace_pp.freeze`` / ``args.trace_pl.freeze``
      is set;
    - any ``embeddings`` / ``num_mlp`` submodules of the TRACE models are made
      trainable again.
    """
    if not args.vit.freeze:
        # NOTE(review): the TRACE freeze flags are only honoured when the ViT
        # is frozen as well; confirm this coupling is intended.
        return

    for p in model.vit.parameters():
        p.requires_grad = False
    # Unfreeze the params of patch embedding
    for p in model.vit.patch_embed.parameters():
        p.requires_grad = True

    # Freeze entire trace_patient / trace_lesion when requested
    if trace_pp is not None and args.trace_pp.freeze:
        for p in model.trace_patient.parameters():
            p.requires_grad = False
    if trace_pl is not None and args.trace_pl.freeze:
        for p in model.trace_lesion.parameters():
            p.requires_grad = False

    # Unfreeze specific submodules. getattr() falls back to None both when the
    # submodule is missing and when the TRACE model itself is None.
    for module in (
        getattr(model.trace_patient, "embeddings", None),
        getattr(model.trace_lesion, "embeddings", None),
        getattr(model.trace_patient, "num_mlp", None),
        getattr(model.trace_lesion, "num_mlp", None),
    ):
        if module is not None:
            for p in module.parameters():
                p.requires_grad = True


def main(args):
    """Run SLIMP multimodal pre-training.

    Initialises (optionally distributed) execution and seeds the RNGs. It then
    builds the training data loader, the TRACE tabular encoders and the ViT,
    wraps them in ``SLIMP``, and optionally loads weights from
    ``args.checkpoint``. Training runs for epochs ``args.start_epoch`` to
    ``args.epochs - 1``. A checkpoint is saved every ``args.save_per_epochs``
    epochs and at the final epoch, and per-epoch stats are appended to
    ``<output_dir>/log.txt`` by the main process.

    Args:
        args: namespace returned by ``setup``. Besides the CLI flags it must
            provide the config sections used below (``data``, ``vit``,
            ``trace_pp``, ``trace_pl``) and training settings such as
            ``seed``, ``data_seed``, ``batch_size``, ``accum_iter``, ``lr``,
            ``optimizer``, ``loss``, ``lambda_outer``, ``epochs``,
            ``start_epoch`` and ``save_per_epochs``. An optional
            ``num_workers`` overrides the default of 24 DataLoader workers.

    Raises:
        NotImplementedError: for an unsupported ``args.loss`` or
            ``args.optimizer``.
    """
    utils.init_distributed_mode(args)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"\nUsing {device} device\n")

    # fix the seed for reproducibility (offset by rank so ranks differ)
    set_seed(args.seed + utils.get_rank())

    # TODO: CONSIDER VALIDATION SET IN PRETRAINING
    train_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor()])

    ##################################
    # Optional train transformations #
    ##################################

    # train_transform = transforms.Compose([
    #     transforms.Resize((224, 224)),
    #     transforms.RandomHorizontalFlip(p=0.3),
    #     transforms.RandomVerticalFlip(p=0.3),
    #     transforms.RandomApply([
    #         transforms.RandomAffine(degrees=0, translate=(0.05, 0.05), scale=(0.95, 1.15))
    #     ], p=0.2),
    #     transforms.RandomApply([
    #         transforms.RandomRotation(degrees=15)
    #     ], p=0.2),
    #     # transforms.RandomApply([
    #     #     transforms.GaussianBlur(kernel_size=3)
    #     # ], p=0.1),
    #     transforms.ToTensor(),
    #     transforms.Normalize(mean=[0.698, 0.522, 0.419],
    #                         std=[0.048, 0.054, 0.052])
    # ])

    # Patient-level metadata switches to the nested (patient -> lesions) dataset.
    nested = _has_metadata(args.trace_pp)
    dataset_train = _build_train_dataset(args, train_transform, nested)

    if args.distributed:
        num_tasks = utils.get_world_size()
        global_rank = utils.get_rank()
        sampler_train = torch.utils.data.DistributedSampler(
            dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True)
        print("Sampler_train = %s" % str(sampler_train))
        # TODO: update if validation set is used
    else:
        global_rank = 0
        sampler_train = torch.utils.data.RandomSampler(dataset_train)

    num_workers = getattr(args, 'num_workers', None)
    if num_workers is None:
        num_workers = 24
    data_loader_train = torch.utils.data.DataLoader(
        dataset=dataset_train,
        sampler=sampler_train,
        batch_size=args.batch_size,
        # Nested samples need the custom collate; None selects PyTorch's
        # default collation for the flat dataset.
        collate_fn=mmnested_collate_fn if nested else None,
        pin_memory=True,
        num_workers=num_workers)

    # TensorBoard logging happens on the main process only.
    if args.output_dir is not None and global_rank == 0:
        log_writer = SummaryWriter(log_dir=args.output_dir)
    else:
        log_writer = None

    # Patient-level (trace_pp) and lesion-level (trace_pl) tabular encoders;
    # either is None when its config section has no metadata file.
    trace_pp = _build_trace(args.trace_pp, args.data.num_labels)
    trace_pl = _build_trace(args.trace_pl, args.data.num_labels)

    vit = vits.__dict__[args.vit.arch](patch_size=args.vit.patch_size, num_classes=0)

    # Loss function
    if args.loss == 'contrastive':
        loss_fn = ContrastiveLoss()
    else:
        raise NotImplementedError(f"Loss function '{args.loss}' is not implemented.")

    model = SLIMP(
        tabular_model_patient=trace_pp,
        tabular_model_lesion=trace_pl,
        vit_model=vit,
        d_model=vit.embed_dim,
        loss_fn=loss_fn,
        lambda_outer=args.lambda_outer)

    model.to(device)
    model_without_ddp = model

    # Decide which parameters are trainable BEFORE wrapping in DDP: DDP sets
    # up gradient reduction for the model's parameters when it is constructed.
    _apply_freezing(model_without_ddp, args, trace_pp, trace_pl)

    eff_batch_size = args.batch_size * args.accum_iter * utils.get_world_size()

    # args.lr = args.lr * eff_batch_size / 256
    print(f"\naccumulate grad iterations: {args.accum_iter}")
    print(f"effective batch size: {eff_batch_size}")
    print(f'lr: {args.lr}')

    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu], find_unused_parameters=True)
        model_without_ddp = model.module

    optimizer = _build_optimizer(args.optimizer, model_without_ddp.parameters(), args.lr)
    loss_scaler = NativeScaler()

    if args.checkpoint and args.continual:
        # Weights only: optimizer state and start epoch are not restored here.
        print(f"Continual pre-training from checkpoint {args.checkpoint}")
        load_matching_state_dict(model_without_ddp, args.checkpoint)

    if args.checkpoint and not args.continual:
        # Load onto CPU so ranks don't all allocate on the GPU the checkpoint
        # was saved from (and CPU-only resume works); load_state_dict copies
        # the tensors onto the model's / optimizer params' devices.
        checkpoint = torch.load(args.checkpoint, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        print("Resume checkpoint %s" % args.checkpoint)
        if 'optimizer' in checkpoint and 'epoch' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            args.start_epoch = checkpoint['epoch'] + 1
            if 'scaler' in checkpoint:
                loss_scaler.load_state_dict(checkpoint['scaler'])

    summarize_model(model_without_ddp)

    print(f"\nStart training for {args.epochs} epochs")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            # Give the DistributedSampler a different shuffle each epoch.
            data_loader_train.sampler.set_epoch(epoch)
        train_stats = train_one_epoch(
            model, data_loader_train, optimizer,
            device, epoch, loss_scaler=loss_scaler,
            log_writer=log_writer, args=args)

        if args.output_dir and ((epoch + 1) % args.save_per_epochs == 0 or epoch + 1 == args.epochs):
            save_model(args=args,
                       model=model,
                       model_without_ddp=model_without_ddp,
                       optimizer=optimizer,
                       loss_scaler=loss_scaler,
                       epoch=epoch)

        log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 'epoch': epoch}

        # Only the main process writes log.txt; otherwise every rank would
        # append its own line to the same file.
        if args.output_dir and global_rank == 0:
            if log_writer is not None:
                log_writer.flush()
            with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
                f.write(json.dumps(log_stats) + "\n")

    if log_writer is not None:
        log_writer.close()

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))


if __name__ == '__main__':
    parser = get_args_parser()
    cli_args = parser.parse_args()
    # Make sure the output directory exists before configuration is merged
    # and training starts writing into it.
    if cli_args.output_dir:
        Path(cli_args.output_dir).mkdir(parents=True, exist_ok=True)
    args = setup(cli_args)
    main(args)