import os
import pickle
from datetime import datetime
import numpy as np
import torch
import torchvision.transforms as transforms
import datasets
import pandas as pd
import random

import torchio as tio
from utils.spatial_transforms import ToTensor

from torchvision.transforms._transforms_video import (
    NormalizeVideo,
)

from importlib import import_module
from torch.utils.data import WeightedRandomSampler


# Experiments that draw training samples through a WeightedRandomSampler
# instead of plain shuffling.
_RESAMPLING_EXPERIMENTS = ('resampling', 'GroupDRO', 'resamplingSWAD')


def _seed_worker(worker_id):
    """Seed NumPy and ``random`` inside a DataLoader worker process.

    The seed is derived from ``torch.initial_seed()``, which differs per
    worker, so workers do not replay the same NumPy/``random`` stream.
    Defined at module level (not as a closure) so it can be pickled when
    workers are started with the ``spawn`` method.
    """
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


def _build_transforms(opt):
    """Return ``(transform_train, transform_test)`` for the 2D or 3D pipeline."""
    augment = opt['data_setting']['augment']

    if not opt['is_3d']:
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        if augment:
            transform_train = transforms.Compose([
                transforms.RandomHorizontalFlip(),
                transforms.RandomCrop((224, 224)),
                transforms.ToTensor(),
                normalize,
            ])
        else:
            transform_train = transforms.Compose([
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ])
        transform_test = transforms.Compose([
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])
        return transform_train, transform_test

    # 3D volumes: no resize/crop transform is part of the pipeline here.
    # Target sizes used by the previously disabled CropOrPad step were:
    #   ADNI / ADNI3T: (192, 192, 144), OCT: (224, 224, 100),
    #   RadFusion: (192, 192, 224), RadFusion4: (224, 224, 224)
    mean_3d = [0.45, 0.45, 0.45]
    std_3d = [0.225, 0.225, 0.225]

    train_steps = [ToTensor(), NormalizeVideo(mean_3d, std_3d)]
    if augment:
        train_steps.insert(0, tio.transforms.RandomFlip())
    transform_train = transforms.Compose(train_steps)
    transform_test = transforms.Compose([
        ToTensor(),
        NormalizeVideo(mean_3d, std_3d),
    ])
    return transform_train, transform_test


def get_dataset(opt):
    """Build the train/val/test datasets and their DataLoaders.

    Args:
        opt: Mapping of run options. Keys read here: ``data_setting`` (with
            ``augment``, ``image_feature_path``, ``train_meta_path``,
            ``val_meta_path``, ``test_meta_path`` or
            ``balanced_test_meta_path``, and for 2D data
            ``pickle_train_path`` / ``pickle_val_path`` /
            ``pickle_test_path``), ``is_3d``, ``random_seed``,
            ``balanced_testing``, ``bianry_train_multi_test``,
            ``sens_classes``, ``cal_flatness``, ``experiment``,
            ``dataset_name``, ``sensitive_name``, ``no_return_index``,
            ``fine_tune``, ``finetuning_proportion`` (only when fine-tuning),
            ``resample_which`` (only for resampling experiments) and
            ``batch_size``.

            ``opt`` is modified in place: ``dataset_name`` is rewritten for
            the ``HAM100004`` -> ``HAM10000`` and ``ADNI3T`` -> ``ADNI``
            aliases, and ``sens_classes`` may be raised when ``cal_flatness``
            is set for the ``LAFTR`` / ``CFair`` experiments.

    Returns:
        Tuple ``(train_data, val_data, test_data, train_loader, val_loader,
        test_loader, val_meta, test_meta)`` where the ``*_meta`` entries are
        the DataFrames read from the validation and test CSV files.
    """
    data_setting = opt['data_setting']
    transform_train, transform_test = _build_transforms(opt)

    # One seeded generator drives sampler/shuffle order and worker base seeds.
    g = torch.Generator()
    g.manual_seed(opt['random_seed'])

    image_path = data_setting['image_feature_path']
    train_meta = pd.read_csv(data_setting['train_meta_path'])
    val_meta = pd.read_csv(data_setting['val_meta_path'])
    if not opt['balanced_testing']:
        print('reading random split test set')
        test_meta = pd.read_csv(data_setting['test_meta_path'])
    else:
        test_meta = pd.read_csv(data_setting['balanced_test_meta_path'])

    # -1 means validation/test use the same number of sensitive classes as
    # training; any other value overrides it for evaluation only.
    if opt['bianry_train_multi_test'] == -1:
        val_test_classes = opt['sens_classes']
    else:
        val_test_classes = opt['bianry_train_multi_test']

    # cal flatness
    if opt['cal_flatness']:
        if opt['experiment'] in ['LAFTR', 'CFair']:
            opt['sens_classes'] = max(opt['sens_classes'], opt['bianry_train_multi_test'])

    if not opt['is_3d']:
        # TODO: delete below for release
        if opt['dataset_name'] == 'HAM100004':
            opt['dataset_name'] = 'HAM10000'
        dataset_name = getattr(datasets, opt['dataset_name'])
        pickle_train_path = data_setting['pickle_train_path']
        pickle_val_path = data_setting['pickle_val_path']
        pickle_test_path = data_setting['pickle_test_path']
        train_data = dataset_name(train_meta, pickle_train_path, opt['sensitive_name'], opt['sens_classes'], transform_train, no_return_idx = opt['no_return_index'])
        val_data = dataset_name(val_meta, pickle_val_path, opt['sensitive_name'], val_test_classes, transform_test, False)
        test_data = dataset_name(test_meta, pickle_test_path, opt['sensitive_name'], val_test_classes, transform_test, False)
    else:
        if opt['dataset_name'] == 'ADNI3T':
            opt['dataset_name'] = 'ADNI'
        dataset_name = getattr(datasets, opt['dataset_name'])
        train_data = dataset_name(train_meta, image_path, opt['sensitive_name'], opt['sens_classes'], transform_train, opt['no_return_index'])
        val_data = dataset_name(val_meta, image_path, opt['sensitive_name'], val_test_classes, transform_test, False)
        test_data = dataset_name(test_meta, image_path, opt['sensitive_name'], val_test_classes, transform_test, False)

    print('loaded dataset ', opt['dataset_name'])

    if opt['fine_tune'] is True:
        # Keep a random fraction of the training set for fine-tuning.
        prop = opt['finetuning_proportion']
        finetune_size = int(prop * len(train_data))
        idx = random.sample(range(0, len(train_data)), finetune_size)
        train_data = torch.utils.data.Subset(train_data, idx)

    use_sampler = opt['experiment'] in _RESAMPLING_EXPERIMENTS
    if use_sampler:
        # NOTE(review): if fine_tune is also enabled, train_data is a Subset
        # here and will not expose get_weights -- confirm this combination
        # is never configured.
        weights = train_data.get_weights(resample_which = opt['resample_which'])
        sampler = WeightedRandomSampler(weights, len(weights), replacement=True, generator = g)
    else:
        sampler = None

    loader_kwargs = dict(
        batch_size=opt['batch_size'],
        num_workers=8,
        worker_init_fn=_seed_worker,
        generator=g,
        pin_memory=True,
    )
    # DataLoader forbids combining a sampler with shuffle=True.
    train_loader = torch.utils.data.DataLoader(
        train_data, sampler=sampler, shuffle=not use_sampler, **loader_kwargs)
    # NOTE(review): validation/test loaders shuffle as well; kept as-is
    # because downstream code may rely on the resulting order/RNG usage.
    val_loader = torch.utils.data.DataLoader(
        val_data, shuffle=True, **loader_kwargs)
    test_loader = torch.utils.data.DataLoader(
        test_data, shuffle=True, **loader_kwargs)

    return train_data, val_data, test_data, train_loader, val_loader, test_loader, val_meta, test_meta
