import init_path
import os
import numpy as np
import pandas as pd
from pycocotools.coco import COCO
import skimage.io as io
import nibabel as nib
import pickle
import seaborn as sns
import os
import re
import warnings
from tqdm import tqdm
import json
import random

from scipy import signal, stats
from scipy.interpolate import interp1d
from sklearn.model_selection import KFold

import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import pytorch_lightning as pl

from settings import settings
from utils import *
from filters import filters
from funcs import *

import timm.optim.optim_factory as optim_factory
from datasets_pretrain import HCP3TPretrainDataset, HCP7TPretrainDataset
from model_utils import CosineWarmupScheduler
from models_autoencoder import fMRIAutoEncoder

def default_model_args(replace=False):
    """Write the default pretraining configuration to ``CONFIG_DIR/HCP3T_pretrain.json``.

    The file maps each pretrain id ('pretrain1', 'pretrain2') to an identical
    config with a ``model`` section and a ``training`` section.

    Args:
        replace: when False and the file already exists, print a notice and
            leave the file untouched; when True, overwrite it.

    Raises:
        TypeError: if the config contains a value json cannot encode and that
            is not a ``DataContainer``.
    """
    def _serialize(obj):
        # json calls this hook only for objects it cannot encode natively.
        # Anything other than a DataContainer is an error; returning None here
        # would silently write ``null`` into the config.
        if isinstance(obj, DataContainer):
            return obj.__dict__
        raise TypeError('Object of type {:s} is not JSON serializable'.format(type(obj).__name__))

    fName = 'HCP3T_pretrain.json'
    fPath = settings.CONFIG_DIR / fName
    if fPath.is_file() and (not replace):
        print('file exist, no overwrite')
        return

    default_args = DataContainer()
    pretrain_ids = ['pretrain1', 'pretrain2']
    for pretrain_id in pretrain_ids:
        args = DataContainer()
        # ----- model configs -----
        args.model = DataContainer()
        args.model.num_rois = 100
        args.model.fmri_seg_size = 15
        args.model.fmri_num_segs = 20
        args.model.roi_embed_dim = 256  # roi embedding dimension
        args.model.fmri_embed_dim = 256 # fmri time-series embedding dimension
        args.model.mdl_embed_dim = 256  # model (general) embedding dimension (state embedding dimension)
        # TSN
        args.model.TSN_variants = 'MaskedTSN'
        args.model.TSN_num_heads = 8
        args.model.TSN_mlp_ratio = 4.
        args.model.TSN_norm_layer = 'LayerNorm'  # name resolved via getattr(nn, ...) at load time
        args.model.TSN_depth = 3
        args.model.TSN_self_inclusive = False
        args.model.TSN_attn_type = 'embedding'
        args.model.TSN_qk_proj = False

        # AutoEncoder
        args.model.AE_mlp_ratio = 4.
        args.model.AE_norm_layer = 'LayerNorm'   # name resolved via getattr(nn, ...) at load time
        args.model.AE_num_heads = 8
        args.model.AE_depth = 2
        args.model.AE_dec_depth = 2
        args.model.AE_dec_embed_dim = 256

        # ----- training settings -----
        # NOTE(review): the training code in this file also reads
        # training.max_mask_ratio, training.mask_loss_ratio,
        # training.task_type and training.overlapping_segments, which are not
        # written here; they must be added to the JSON by hand before training.
        args.training = DataContainer()
        args.training.max_epoch = 4000
        args.training.save_epochs_interval = 500
        args.training.learning_rate = 1e-3
        args.training.min_learning_rate = 1e-7
        args.training.warmup_epochs = 50
        args.training.weight_decay = 1e-3
        args.training.batch_size = 32
        args.training.accum_iter = 4
        args.training.device = 'cuda'
        args.training.fmri_type = 'minimal_processed'

        setattr(default_args, pretrain_id, args)

    with open(str(fPath), "w") as outfile:
        json.dump(default_args, outfile, default=_serialize, indent=4)

def load_model_args(pretrain_id, verbose=True):
    """Return the configuration stored under ``pretrain_id`` in ``HCP3T_pretrain.json``.

    Nested JSON objects are turned into ``DataContainer`` instances so fields
    can be read with attribute access (``args.model.num_rois``). Values that are
    not dicts (numbers, strings, lists, ...) are kept unchanged.

    Args:
        pretrain_id: top-level key in the config file.
        verbose: if True, pretty-print the raw configuration before returning.

    Raises:
        KeyError: if ``pretrain_id`` is not present in the file.
    """
    def _wrap(node):
        # Leaves pass through untouched; dicts become containers, recursively.
        if type(node) is not dict:
            return node
        container = DataContainer()
        for name, value in node.items():
            setattr(container, name, _wrap(value))
        return container

    config_path = settings.CONFIG_DIR / 'HCP3T_pretrain.json'
    with open(config_path, "r") as file:
        raw_config = json.load(file)[pretrain_id]

    if verbose:
        print('***** Model configuration *****')
        print(json.dumps(raw_config, indent=4))

    return _wrap(raw_config)

def initilization():
    """Create an empty ``PretrainedStateModels`` container and pickle it to disk.

    The target path is ``settings.projectData.dir.general_HCP3T`` joined with the
    filename configured under ``settings.projectData.files.general_HCP3T``.
    ``pretrain_state_model`` later loads this file and adds one attribute per
    pretrain id.

    WARNING: any existing file at that path is overwritten, discarding every
    previously recorded pretraining result.
    """
    PretrainedStateModels = DataContainer()

    # Explicit name -> object mapping; the name selects the configured filename.
    vars2save = {'PretrainedStateModels': PretrainedStateModels}
    for varName, value in vars2save.items():
        fName = getattr(settings.projectData.files.general_HCP3T, varName)
        fPath = settings.projectData.dir.general_HCP3T / fName

        with open(fPath, 'wb') as handle:
            print('save {:s} to {:s} ...'.format(varName, str(fPath)))
            pickle.dump(value, handle, protocol=pickle.HIGHEST_PROTOCOL)
            print('... done')

def pretrain_model_one_epoch(model, train_loader, val_loader, 
                             args, optimizer, scheduler):
    """Run one training epoch with gradient accumulation, then one validation pass.

    Args:
        model: module called as
            ``model(fmri_segs, mask_ratio=..., mask_loss_ratio=...)`` that returns
            ``(loss, pred, mask, latent, (loss_masked, loss_unmasked))``.
        train_loader: iterable of ``(sldwins, fmri_segs)`` batches supporting ``len()``.
        val_loader: same as ``train_loader``; evaluated under ``torch.no_grad()``.
        args: config container. Reads ``args.training.epoch``, ``max_epoch``
            (progress-bar labels), ``device``, ``max_mask_ratio``,
            ``mask_loss_ratio`` and ``accum_iter``.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: LR scheduler, stepped once per optimizer step.

    Returns:
        ``(train_losses, test_losses, train_lrs)`` as numpy arrays.
        Each loss row is ``[loss, loss_masked, loss_unmasked]`` for one batch,
        divided by ``accum_iter``. Validation losses are divided the same way so
        the two are comparable. ``train_lrs`` holds ``scheduler.get_last_lr()``
        recorded after every training batch.
    """
    accum_iter = args.training.accum_iter
    device = args.training.device

    def _forward(fmri_segs):
        """Run the model on one batch; return the scaled loss and its logged values."""
        # A fresh random mask ratio in [0, max_mask_ratio) is drawn for every batch.
        mask_ratio = np.random.random() * args.training.max_mask_ratio
        loss, _pred, _mask, _latent, (loss_masked, loss_unmasked) = model(
            fmri_segs.to(device), mask_ratio=mask_ratio,
            mask_loss_ratio=args.training.mask_loss_ratio)
        # Scale out-of-place: in-place ops on graph tensors before backward()
        # can invalidate values autograd saved for the backward pass.
        loss = loss / accum_iter
        logged = [loss.item(),
                  (loss_masked.detach() / accum_iter).item(),
                  (loss_unmasked.detach() / accum_iter).item()]
        return loss, logged

    ###### Training ######
    model.train()
    optimizer.zero_grad()
    train_losses = []
    train_lrs = []
    num_train_iters = len(train_loader)
    train_desc = 'train epoch [{:d}|{:d}]'.format(args.training.epoch, args.training.max_epoch)

    for i_iter, (_sldwins, fmri_segs) in enumerate(
            tqdm(train_loader, total=num_train_iters, desc=train_desc)):
        loss, logged = _forward(fmri_segs)
        loss.backward()

        # Update once every accum_iter batches; also flush a trailing partial
        # accumulation window at the end of the epoch.
        if ((i_iter + 1) % accum_iter == 0) or ((i_iter + 1) == num_train_iters):
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

        train_losses.append(logged)
        train_lrs.append(scheduler.get_last_lr())

    ###### Validation ######
    model.eval()
    test_losses = []
    test_desc = 'test epoch [{:d}|{:d}]'.format(args.training.epoch, args.training.max_epoch)
    with torch.no_grad():
        for _sldwins, fmri_segs in tqdm(val_loader, total=len(val_loader), desc=test_desc):
            _, logged = _forward(fmri_segs)
            test_losses.append(logged)

    return np.array(train_losses), np.array(test_losses), np.array(train_lrs)

def pretrain_state_model(pretrain_id, random_seed=42):
    """Pretrain one fMRI autoencoder and record the result in the PretrainedStateModels pickle.

    The existing ``PretrainedStateModels`` container (created by ``initilization``)
    is loaded first. Training then uses the config ``pretrain_id`` from
    ``HCP3T_pretrain.json``, with the HCP-3T dataset for training and the HCP-7T
    dataset for validation. The returned training record is stored as attribute
    ``pretrain_id`` of the container, and the pickle is rewritten.

    Args:
        pretrain_id: config key; also the checkpoint filename prefix and the
            attribute name under which the record is stored.
        random_seed: seed for ``pl.seed_everything`` and for both datasets.
    """

    def _process_model_training(random_seed=42):
        """Build datasets, model, optimizer and scheduler; train; return a record.

        The returned ``DataContainer`` contains:
            * ``args``: the resolved configuration.
            * ``train_losses`` and ``test_losses``: per-epoch mean ``[loss, masked, unmasked]``.
            * ``learning_rates``: the per-batch learning rates for each epoch.
            * ``model_files``: a DataFrame mapping checkpoint name to file path.
            * ``used_random_seed``: the seed used for training.
        """
        def _save_model(model, savePath):
            # Save CPU-resident weights, then return the model to the training device.
            print('save model to {:s} ...'.format(str(savePath)))
            torch.save(model.cpu().state_dict(), str(savePath))
            print('... done')
            model.to(args.training.device)

            return str(savePath)

        pl.seed_everything(random_seed)
        # Ensure cuDNN uses deterministic kernels for reproducibility. The
        # attribute name must be exact: a misspelt name is just stored as a new
        # attribute and has no effect.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

        # ----- load model arguments -----
        args = load_model_args(pretrain_id)
        # Norm layers are stored as names in the JSON; resolve them to torch.nn classes.
        args.model.TSN_norm_layer = getattr(nn, args.model.TSN_norm_layer)
        args.model.AE_norm_layer = getattr(nn, args.model.AE_norm_layer)
        args.training.eff_batch_size = args.training.batch_size * args.training.accum_iter

        # ----- datasets -----
        train_dataset = HCP3TPretrainDataset(settings, region_roi='Yeo100Parc',
                                            fmri_type=args.training.fmri_type, 
                                            task_type=args.training.task_type,
                                            overlapping_segments=args.training.overlapping_segments,
                                            random_seed=random_seed)
        train_dataloader = DataLoader(train_dataset, batch_size=args.training.batch_size, shuffle=True, num_workers=4)

        test_dataset = HCP7TPretrainDataset(settings, region_roi='Yeo100Parc', use_all_data=True,
                                    fmri_type=args.training.fmri_type,
                                    overlapping_segments=args.training.overlapping_segments,
                                    random_seed=random_seed)
        # NOTE(review): the validation dataset is put in train() mode here —
        # presumably this selects which split/sampling it serves; confirm intended.
        test_dataset.train()
        test_dataloader = DataLoader(test_dataset, batch_size=args.training.batch_size, shuffle=False, num_workers=4)

        print('process HCP-3T pretraining: training samples {:d}, test samples {:d}, num rois {:d}'.format(
              len(train_dataset), len(test_dataset), train_dataset.num_rois))

        # ----- construct model -----
        model = fMRIAutoEncoder(args.model)
        model = model.to(args.training.device)

        param_groups = optim_factory.param_groups_weight_decay(model, args.training.weight_decay)
        optimizer = torch.optim.AdamW(param_groups, lr=args.training.learning_rate, betas=(0.9, 0.95))

        # The scheduler steps once per optimizer step, i.e. once every accum_iter batches.
        max_iters = args.training.max_epoch * len(train_dataloader)
        args.training.warmup_iters = args.training.warmup_epochs * len(train_dataloader)
        scheduler = CosineWarmupScheduler(optimizer, 
                                        warmup=int(args.training.warmup_iters / args.training.accum_iter), 
                                        max_iters=np.ceil(max_iters / args.training.accum_iter), 
                                        base_lr=args.training.learning_rate, 
                                        min_lr=args.training.min_learning_rate)

        # ----- model training -----
        train_loss_all = []
        test_loss_all = []
        train_lrs_all = []
        best_loss = np.inf

        saved_model_files = {}
        mdl_saveDir = settings.projectData.dir.general_HCP3T
        mdl_saveDir = mdl_saveDir / settings.projectData.rel_dir.general_HCP3T.Models
        for epoch in range(1, args.training.max_epoch+1):
            args.training.epoch = epoch

            train_losses, test_losses, train_lrs = pretrain_model_one_epoch(model, 
                                                            train_dataloader, 
                                                            test_dataloader, 
                                                            args,
                                                            optimizer, 
                                                            scheduler)
            train_losses = np.nanmean(train_losses, axis=0)
            test_losses = np.nanmean(test_losses, axis=0)
            s = '[HCP-3T pretrain] train loss = {:.3f}, test loss = {:.3f}, train lr = {:.3e}\n'.format(
                    train_losses[0], test_losses[0], train_lrs.mean())
            s+= '\t[Train]: Masked_Loss = {:.4f}, Unmasked_Loss = {:.4f}\n'.format(
                    train_losses[1], train_losses[2])
            s+= '\t [Test]: Masked_Loss = {:.4f}, Unmasked_Loss = {:.4f}\n'.format(
                    test_losses[1], test_losses[2])
            print(s)

            train_loss_all.append(train_losses)
            test_loss_all.append(test_losses)
            train_lrs_all.append(train_lrs)

            # Best-checkpoint tracking starts only after warmup. Otherwise a low
            # loss reached during warmup (when nothing is saved) could block the
            # 'best' checkpoint from ever being written.
            if epoch > args.training.warmup_epochs and test_losses[0] < best_loss:
                print('update best validation loss from {:.3f} to {:.3f}'.format(best_loss, test_losses[0]))
                best_loss = test_losses[0]

                mdl_fName = '{:s}_best.pt'.format(pretrain_id)
                mdl_fPath = mdl_saveDir / mdl_fName
                _save_model(model, mdl_fPath)
                saved_model_files['best'] = str(mdl_fPath)

            if (epoch % args.training.save_epochs_interval == 0) or (epoch == args.training.max_epoch):
                mdl_fName = '{:s}_epoch_{:d}.pt'.format(pretrain_id, epoch)
                mdl_fPath = mdl_saveDir / mdl_fName
                _save_model(model, mdl_fPath)
                saved_model_files['epoch_{:d}'.format(epoch)] = str(mdl_fPath)

        saved_model_files = pd.DataFrame.from_dict(saved_model_files, 
                                                orient='index', columns=['file_path']).reset_index(names='checkpoint')

        modelInfo = DataContainer()
        modelInfo.args = args
        modelInfo.train_losses = np.array(train_loss_all)
        modelInfo.test_losses = np.array(test_loss_all)
        modelInfo.learning_rates = np.array(train_lrs_all)
        modelInfo.model_files = saved_model_files
        modelInfo.used_random_seed = random_seed

        torch.cuda.empty_cache()

        return modelInfo

    varName = 'PretrainedStateModels'
    fName = getattr(settings.projectData.files.general_HCP3T, varName)
    fPath = settings.projectData.dir.general_HCP3T / fName
    # Loaded before training so a missing file fails fast. The file is
    # project-generated; pickle must never be pointed at untrusted data.
    with open(fPath, 'rb') as handle:
        PretrainedStateModels = pickle.load(handle)

    pretrainInfo = _process_model_training(random_seed=random_seed)
    setattr(PretrainedStateModels, pretrain_id, pretrainInfo)

    with open(fPath, 'wb') as handle:
        print('save {:s} to {:s} ...'.format(varName, str(fPath)))
        pickle.dump(PretrainedStateModels, handle, protocol=pickle.HIGHEST_PROTOCOL)
        print('... done')

def proc_pretrain_state_models(resume=False):
    """Run ``pretrain_state_model`` for each enabled pretrain id, one after another.

    Args:
        resume: accepted for interface compatibility; currently unused.
    """
    # Other pretrain ids referenced here but currently disabled:
    #   AE_MaskTSN_PRETRAINv1, AE_MaskTSN_PRETRAINv3 .. AE_MaskTSN_PRETRAINv6,
    #   AE_DynamicMaskTSN_PRETRAINv1, AE_DynamicMaskTSN_PRETRAINv3,
    #   AE_VanillaTSN_PRETRAINv1 .. AE_VanillaTSN_PRETRAINv4
    enabled_ids = (
        'AE_MaskTSN_PRETRAINv2',
        'AE_DynamicMaskTSN_PRETRAINv2',
    )
    for current_id in enabled_ids:
        pretrain_state_model(pretrain_id=current_id)



def main():
    """Script entry point: run the enabled pretraining jobs.

    Two one-time setup steps are intentionally not called here; run them
    manually before the first training run:
      * ``initilization()`` writes the empty PretrainedStateModels pickle that
        ``pretrain_state_model`` loads and updates (and overwrites any existing one).
      * ``default_model_args(replace=True)`` (re)writes the default JSON config.
    """
    proc_pretrain_state_models()


if __name__ == "__main__":
    main()