import os, sys
import numpy as np
import torch
import argparse
import datetime
import wandb
import torchvision.transforms as transforms
from einops import rearrange
from PIL import Image
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
import copy
import os 
# own code
from config_cross import Config_Generative_Model
from dataset import create_Kamitani_dataset, create_BOLD5000_dataset_classify, create_BOLD5000_dataset_crosssub
from dc_ldm.ldm_for_fmri import fLDM
from dc_ldm.modules.encoders.adapter import Adapter 
from eval_metrics import get_similarity_metric
from sc_mbm.mae_for_fmri_cross import MAEforFMRIVQGAN

# import dill, multiprocessing
# dill.Pickler.dumps, dill.Pickler.loads = dill.dumps, dill.loads
# multiprocessing.reduction.ForkingPickler = dill.Pickler
# multiprocessing.reduction.dump = dill.dump
# os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
# os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2"

def wandb_init(config, output_path):
    """Start a wandb run for this experiment and dump the config to disk.

    Args:
        config: experiment configuration. ``group_name``, ``exp_name`` and
            ``phaseA_name`` are read to label the run, and the object itself
            is handed to wandb as the run config.
        output_path: directory passed on to ``create_readme``.
    """
    run_name = f"{config.exp_name}-{config.phaseA_name}"
    init_kwargs = dict(
        project='wandn-project',
        group=config.group_name,
        name=run_name,
        anonymous="allow",
        config=config,
        reinit=True,
    )
    wandb.init(**init_kwargs)
    # Keep a plain-text copy of the configuration next to the results.
    create_readme(config, output_path)

def wandb_finish():
    """Finish the current wandb run by delegating to ``wandb.finish()``."""
    wandb.finish()

def to_image(img):
    """Convert an image array into a ``PIL.Image``.

    When the last axis is not of size 3 the array is rearranged from
    ``c h w`` to ``h w c``. The values are then multiplied by 255 and cast to
    ``uint8`` without clipping.
    NOTE(review): presumably the input is expected to lie in [0, 1] - confirm.
    """
    needs_transpose = img.shape[-1] != 3
    pixels = rearrange(img, 'c h w -> h w c') if needs_transpose else img
    pixels = (255. * pixels).astype(np.uint8)
    return Image.fromarray(pixels)

def channel_last(img):
    """Return ``img`` with its channels on the last axis.

    An input whose last axis already has size 3 is returned untouched (the
    very same object); anything else is treated as ``c h w`` and rearranged to
    ``h w c``.
    """
    already_channel_last = img.shape[-1] == 3
    return img if already_channel_last else rearrange(img, 'c h w -> h w c')

def get_eval_metric(samples, avg=True):
    """Score generated images against their ground-truth images.

    Args:
        samples: sequence with one entry per test item. Index 0 of each entry
            is the ground-truth image and the following indices are generated
            images, every image laid out as ``c h w``.
        avg: when True, every generated image (indices ``1 .. len-1``) is
            scored and the scores are averaged; otherwise only index 1 is used.

    Returns:
        ``(res_list, metric_list)``: two parallel lists. The first four
        entries are the pair-wise metrics ``mse``, ``pcc``, ``ssim`` and
        ``psm``; the last two are the mean and the maximum of the 50-way
        top-1 classification score over the scored generations.

    Raises:
        ValueError: if ``samples`` is empty.
    """
    if len(samples) == 0:
        raise ValueError('get_eval_metric needs at least one sample')

    pair_metrics = ['mse', 'pcc', 'ssim', 'psm']
    # Follow the rest of this script and fall back to CPU when CUDA is absent
    # instead of unconditionally requesting 'cuda'.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    gt_images = rearrange(np.stack([img[0] for img in samples]), 'n c h w -> n h w c')
    samples_to_run = np.arange(1, len(samples[0])) if avg else [1]

    # The stacked predictions depend only on the sample index, so build them
    # once and reuse them for every metric instead of re-stacking per metric.
    pred_stacks = [
        rearrange(np.stack([img[s] for img in samples]), 'n c h w -> n h w c')
        for s in samples_to_run
    ]

    res_list = []
    for metric_name in pair_metrics:
        per_generation = [
            np.mean(get_similarity_metric(pred_images, gt_images,
                                          method='pair-wise', metric_name=metric_name))
            for pred_images in pred_stacks
        ]
        res_list.append(np.mean(per_generation))

    class_scores = [
        np.mean(get_similarity_metric(pred_images, gt_images, 'class', None,
                                      n_way=50, num_trials=50, top_k=1, device=device))
        for pred_images in pred_stacks
    ]
    res_list.append(np.mean(class_scores))
    res_list.append(np.max(class_scores))

    metric_list = pair_metrics + ['top-1-class', 'top-1-class (max)']
    return res_list, metric_list
               
def generate_images(generative_model, fmri_latents_dataset_train, fmri_latents_dataset_test, config):
    """Generate images for both splits, save them, and log results to wandb.

    Args:
        generative_model: object exposing ``generate(dataset, num_samples,
            ddim_steps, HW[, limit])`` that returns ``(grid, samples)``.
        fmri_latents_dataset_train: training dataset; only its grid is kept.
        fmri_latents_dataset_test: test dataset; its grid and every individual
            generation are saved, and its samples are scored.
        config: provides ``num_samples``, ``ddim_steps``, ``HW``,
            ``output_path`` and ``eval_avg``.
    """
    out_dir = config.output_path

    # Train split: the trailing 10 is the original "generate 10 instances" limit.
    train_grid, _ = generative_model.generate(
        fmri_latents_dataset_train, config.num_samples, config.ddim_steps, config.HW, 10)
    train_grid_img = Image.fromarray(train_grid.astype(np.uint8))
    train_grid_img.save(os.path.join(out_dir, 'samples_train.png'))
    wandb.log({'summary/samples_train': wandb.Image(train_grid_img)})

    # Test split: keep the per-item samples for saving and for the metrics.
    test_grid, samples = generative_model.generate(
        fmri_latents_dataset_test, config.num_samples, config.ddim_steps, config.HW)
    test_grid_img = Image.fromarray(test_grid.astype(np.uint8))
    test_grid_img.save(os.path.join(out_dir, './samples_test.png'))

    for sp_idx, imgs in enumerate(samples):
        # Index 0 is skipped: only the entries after it are written per item.
        for copy_idx, img in enumerate(imgs[1:]):
            hwc = rearrange(img, 'c h w -> h w c')
            target = os.path.join(out_dir, f'./test{sp_idx}-{copy_idx}.png')
            Image.fromarray(hwc).save(target)

    wandb.log({'summary/samples_test': wandb.Image(test_grid_img)})

    metric, metric_list = get_eval_metric(samples, avg=config.eval_avg)
    metric_dict = {}
    # All but the last two metrics are logged under the pair-wise prefix.
    for name, value in zip(metric_list[:-2], metric[:-2]):
        metric_dict[f'summary/pair-wise_{name}'] = value
    for pos in (-2, -1):
        metric_dict[f'summary/{metric_list[pos]}'] = metric[pos]
    wandb.log(metric_dict)

def normalize(img):
    """Scale an image from the 0..255 range to a tensor in the -1..1 range.

    The input is divided by 255, rearranged from ``h w c`` to ``c h w`` when
    its last axis has size 3, wrapped with ``torch.tensor`` and finally mapped
    through ``x * 2 - 1``.
    """
    scaled = img / 255.
    if scaled.shape[-1] == 3:
        scaled = rearrange(scaled, 'h w c -> c h w')
    return torch.tensor(scaled) * 2.0 - 1.0  # to -1 ~ 1

class random_crop:
    """Transform that applies a square random crop with probability ``p``."""

    def __init__(self, size, p):
        # size: side length of the square crop; p: probability of cropping.
        self.size, self.p = size, p

    def __call__(self, img):
        """Return a random ``size`` x ``size`` crop of ``img``, or ``img`` itself."""
        if not (torch.rand(1) < self.p):
            return img
        cropper = transforms.RandomCrop(size=(self.size, self.size))
        return cropper(img)

def fmri_transform(x, sparse_rate=0.2):
    """Randomly zero out a fraction of the voxels of an fMRI sample.

    The voxel axis is taken to be the last axis, so both the documented
    ``(1, num_voxels)`` layout and a flat ``(num_voxels,)`` vector are
    handled. Sampling indices over ``shape[0]`` (as was done before) selects
    ``int(1 * sparse_rate) == 0`` voxels for a ``(1, num_voxels)`` input, which
    silently disables the augmentation.

    Args:
        x: fMRI sample whose last axis indexes voxels. It is not modified.
        sparse_rate: fraction of voxels to set to zero.

    Returns:
        A ``torch.FloatTensor`` copy of ``x`` with
        ``int(num_voxels * sparse_rate)`` distinct voxels zeroed.
    """
    x_aug = copy.deepcopy(x)
    num_voxels = x.shape[-1]
    idx = np.random.choice(num_voxels, int(num_voxels * sparse_rate), replace=False)
    # Ellipsis keeps any leading axes intact and masks along the voxel axis.
    x_aug[..., idx] = 0
    return torch.FloatTensor(x_aug)

def _fit_fmri_to_voxel_count(dataset, num_voxels):
    """Resize ``dataset.fmri`` in place so its last axis has ``num_voxels`` entries.

    Shorter recordings are extended with ``np.pad(..., 'wrap')``; recordings of
    equal or greater length are sliced to their first ``num_voxels`` columns.
    """
    width = dataset.fmri.shape[-1]
    if width < num_voxels:
        dataset.fmri = np.pad(dataset.fmri, ((0, 0), (0, num_voxels - width)), 'wrap')
    else:
        dataset.fmri = dataset.fmri[:, :num_voxels]


def main(config):
    """Fine-tune the fMRI-conditioned generative model and generate images.

    Steps: seed the RNGs, build the image transforms and the train/test
    datasets selected by ``config.dataset``, resize the fMRI data to the voxel
    count implied by the pretrained MBM checkpoint, optionally load the fMRI
    VQGAN model and adapter, build the ``fLDM`` model (optionally resuming
    from ``config.checkpoint_path``), fine-tune it and finally call
    ``generate_images``.

    Args:
        config: experiment configuration. ``config.logger`` must already be
            set; it is the logger handed to the trainer.

    Raises:
        NotImplementedError: if ``config.dataset`` is neither ``'GOD'`` nor
            ``'BOLD5000'``.
        ValueError: if cross-subject training is requested without
            ``train_sub``/``test_sub``, or the VQGAN feature is enabled without
            a checkpoint path.
    """
    # project setup
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    torch.manual_seed(config.seed)
    np.random.seed(config.seed)

    crop_pix = int(config.crop_ratio * config.img_size)

    img_transform_train = transforms.Compose([
        normalize,
        random_crop(config.img_size - crop_pix, p=0.5),
        transforms.Resize((256, 256)),
        channel_last,
    ])
    img_transform_test = transforms.Compose([
        normalize,
        transforms.Resize((256, 256)),
        channel_last,
    ])
    image_transforms = [img_transform_train, img_transform_test]

    if config.dataset == 'GOD':
        fmri_latents_dataset_train, fmri_latents_dataset_test = create_Kamitani_dataset(
            config.kam_path, config.roi, config.patch_size,
            fmri_transform=fmri_transform, image_transform=image_transforms,
            subjects=[config.kam_subs])
    elif config.dataset == 'BOLD5000':
        if not config.cross_sub:
            fmri_latents_dataset_train, fmri_latents_dataset_test = create_BOLD5000_dataset_classify(
                config.bold5000_path, config.patch_size,
                fmri_transform=torch.FloatTensor, image_transform=image_transforms,
                subjects=[config.bold5000_subs],
                target_sub_train_proportion=config.target_sub_train_proportion)
        else:
            if config.train_sub is None or config.test_sub is None:
                raise ValueError('cross_sub requires both train_sub and test_sub to be set')
            train_subs = config.train_sub.split(",")
            test_subs = config.test_sub.split(",")
            fmri_latents_dataset_train, fmri_latents_dataset_test = create_BOLD5000_dataset_crosssub(
                config.bold5000_path, config.patch_size,
                fmri_transform=torch.FloatTensor, image_transform=image_transforms,
                train_subs=train_subs, test_subs=test_subs,
                target_sub_train_proportion=config.target_sub_train_proportion,
                cross_sub=config.cross_sub)
    else:
        raise NotImplementedError

    # prepare pretrained mbm
    # NOTE: torch.load unpickles arbitrary objects; only load trusted checkpoints.
    pretrain_mbm_metafile = torch.load(config.pretrain_mbm_path, map_location='cpu')
    # The voxel count is dictated by the pretrained checkpoint (number of
    # positional embeddings minus one, times the patch size), not by the dataset.
    num_voxels = (pretrain_mbm_metafile['model']['pos_embed'].shape[1] - 1) * pretrain_mbm_metafile['config'].patch_size

    _fit_fmri_to_voxel_count(fmri_latents_dataset_train, num_voxels)
    _fit_fmri_to_voxel_count(fmri_latents_dataset_test, num_voxels)

    # create generative model
    fmri_vqgan_model = None
    adapter = None
    if config.use_dis_vqgan_feature:
        # An explicit check also rejects the empty-string CLI default and is
        # not stripped when Python runs with -O.
        if not config.fmri_vqgan_ckpt:
            raise ValueError('use_dis_vqgan_feature requires fmri_vqgan_ckpt to be set')
        print("-->Loading dis vqgan from ", config.fmri_vqgan_ckpt)
        sd = torch.load(config.fmri_vqgan_ckpt, map_location="cpu")
        config_merge = sd['config_merge']
        fmri_vqgan_model = MAEforFMRIVQGAN(
            num_voxels=config_merge['num_voxels'], patch_size=config_merge['patch_size'],
            embed_dim=config_merge['embed_dim'],
            decoder_embed_dim=config_merge['decoder_embed_dim'], depth=config_merge['depth'],
            num_heads=config_merge['num_heads'], decoder_num_heads=config_merge['decoder_num_heads'],
            mlp_ratio=config_merge['mlp_ratio'], focus_range=None,
            use_vqgan_recon_loss=config_merge['use_vqgan_recon_loss'], do_dec_mae=False,
            dropout_rate=config_merge['dropout_rate'])
        fmri_vqgan_model.load_state_dict(sd['model'], strict=False)

        if config.use_adapter:
            adapter = Adapter(nums_rb=config.nums_rb,
                              cin=config.cin,
                              ksize=config.ksize,
                              sk=config.sk,
                              use_conv=config.use_conv)

    generative_model = fLDM(pretrain_mbm_metafile, num_voxels,
                device=device, pretrain_root=config.pretrain_gm_path, logger=config.logger,
                ddim_steps=config.ddim_steps, global_pool=config.global_pool, use_time_cond=config.use_time_cond,
                is_cross_mae=config.is_cross_mae, fmri_vqgan_model=fmri_vqgan_model,
                feature_adapter=adapter)

    # resume training if applicable
    if config.checkpoint_path is not None:
        model_meta = torch.load(config.checkpoint_path, map_location='cpu')
        generative_model.model.load_state_dict(model_meta['model_state_dict'])
        print('model resumed')

    # finetune the model
    # Use the logger carried by the config rather than a module-level global,
    # so main() also works when it is imported and called from other code.
    trainer = create_trainer(config.num_epoch, config.precision, config.accumulate_grad,
                             config.logger, check_val_every_n_epoch=5)
    generative_model.finetune(trainer, fmri_latents_dataset_train, fmri_latents_dataset_test,
                config.batch_size, config.lr, config.output_path, config=config)

    # generate images
    # generate limited train images and generate images for subjects separately
    generate_images(generative_model, fmri_latents_dataset_train, fmri_latents_dataset_test, config)

def _str2bool(value):
    """Parse a command-line boolean such as ``true``/``false``/``1``/``0``.

    ``type=bool`` cannot be used with argparse because ``bool('False')`` is
    ``True``: every non-empty string would enable the flag.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    # The empty string is kept as False, the one input bool() mapped to False.
    if text in ('false', 'f', 'no', 'n', '0', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


def get_args_parser():
    """Build the command-line parser for the fine-tuning script.

    Options without an explicit default stay ``None`` when omitted, which lets
    ``update_config`` leave the corresponding config value untouched.

    Returns:
        argparse.ArgumentParser: the configured parser (created with
        ``add_help=False``).
    """
    parser = argparse.ArgumentParser('Double Conditioning LDM Finetuning', add_help=False)
    # project parameters
    parser.add_argument('--seed', type=int)
    parser.add_argument('--root_path', type=str)
    parser.add_argument('--kam_path', type=str)
    parser.add_argument('--bold5000_path', type=str)
    parser.add_argument('--pretrain_mbm_path', type=str)
    parser.add_argument('--crop_ratio', type=float)
    parser.add_argument('--dataset', type=str)
    parser.add_argument('--bold5000_subs', type=str)
    parser.add_argument('--kam_subs', type=str)
    # finetune parameters
    parser.add_argument('--batch_size', type=int)
    parser.add_argument('--lr', type=float)
    parser.add_argument('--num_epoch', type=int)
    parser.add_argument('--precision', type=int)
    parser.add_argument('--accumulate_grad', type=int)
    parser.add_argument('--global_pool', type=_str2bool)

    # diffusion sampling parameters
    parser.add_argument('--pretrain_gm_path', type=str)
    parser.add_argument('--num_samples', type=int)
    parser.add_argument('--ddim_steps', type=int)
    parser.add_argument('--use_time_cond', type=_str2bool)
    parser.add_argument('--eval_avg', type=_str2bool)

    parser.add_argument('--phaseA_name', type=str)
    parser.add_argument('--group_name', type=str, default="tune_ldm")
    parser.add_argument('--exp_name', type=str, default="experiment")

    # store_true flags: True when provided, otherwise False
    parser.add_argument('--offline_wandb', action='store_true')
    parser.add_argument('--is_cross_mae', action='store_true')
    parser.add_argument('--cross_sub', action='store_true')
    parser.add_argument('--train_sub', type=str, default=None)
    parser.add_argument('--test_sub', type=str, default=None)
    parser.add_argument('--target_sub_train_proportion', type=float, default=1.0)
    parser.add_argument('--use_dis_vqgan_feature', action='store_true')
    parser.add_argument('--use_adapter', action='store_true')
    parser.add_argument('--fmri_vqgan_ckpt', type=str, default="")
    # # distributed training parameters
    # parser.add_argument('--local_rank', type=int)

    return parser

def update_config(args, config):
    """Copy every explicitly provided argument onto ``config``.

    Only attributes that already exist on ``config`` are considered, and an
    argument whose value is ``None`` (i.e. not given) never overrides the
    config value. The check is an identity test so that values with a custom
    ``__ne__`` (for example arrays) are handled correctly.

    Args:
        args: object carrying parsed options as attributes.
        config: configuration object updated in place.

    Returns:
        The same ``config`` object.
    """
    # Snapshot the names first; only existing attributes are reassigned.
    for attr in list(vars(config)):
        value = getattr(args, attr, None)
        if value is not None:
            setattr(config, attr, value)
    return config

def create_readme(config, path):
    """Print the config attributes and write them to ``README.md`` under ``path``.

    Args:
        config: object whose ``__dict__`` is dumped.
        path: existing directory that receives the ``README.md`` file.
    """
    snapshot = config.__dict__
    print(snapshot)
    readme_file = os.path.join(path, 'README.md')
    with open(readme_file, 'w+') as handle:
        print(snapshot, file=handle)


def create_trainer(num_epoch, precision=32, accumulate_grad_batches=2,logger=None,check_val_every_n_epoch=0):
    """Build the ``pl.Trainer`` used for fine-tuning.

    Args:
        num_epoch: forwarded as ``max_epochs``.
        precision: forwarded as ``precision``.
        accumulate_grad_batches: forwarded as ``accumulate_grad_batches``.
        logger: forwarded as ``logger``.
        check_val_every_n_epoch: forwarded as ``check_val_every_n_epoch``.

    Returns:
        A ``pl.Trainer`` on the GPU accelerator when CUDA is available
        (otherwise CPU), using the ``"ddp"`` strategy, with checkpointing and
        the model summary disabled and gradients clipped at 0.5.
    """
    accelerator = 'cpu'
    if torch.cuda.is_available():
        accelerator = 'gpu'
    trainer_options = {
        'accelerator': accelerator,
        'max_epochs': num_epoch,
        'logger': logger,
        'precision': precision,
        'accumulate_grad_batches': accumulate_grad_batches,
        'enable_checkpointing': False,
        'enable_model_summary': False,
        'gradient_clip_val': 0.5,
        'check_val_every_n_epoch': check_val_every_n_epoch,
        'strategy': "ddp",
    }
    return pl.Trainer(**trainer_options)
  
if __name__ == '__main__':
    # Parse the command line and pick the wandb mode before any run starts.
    args = get_args_parser().parse_args()
    os.environ['WANDB_MODE'] = "offline" if args.offline_wandb else "online"

    # Start from the default config and overlay the provided arguments.
    config = update_config(args, Config_Generative_Model())

    if config.checkpoint_path is not None:
        # Resuming: the config stored in the checkpoint replaces the current
        # one, keeping only the checkpoint path itself.
        model_meta = torch.load(config.checkpoint_path, map_location='cpu')
        ckp = config.checkpoint_path
        config = model_meta['config']
        config.checkpoint_path = ckp
        print(f'Resuming from checkpoint: {config.checkpoint_path}')

    # Each run gets its own timestamped results directory.
    timestamp = datetime.datetime.now().strftime("%d-%m-%Y-%H-%M-%S")
    output_path = os.path.join(config.root_path, 'results', 'generation', timestamp)
    config.output_path = output_path
    os.makedirs(output_path, exist_ok=True)

    wandb_init(config, output_path)

    logger = WandbLogger()
    config.logger = logger
    main(config)
    wandb_finish()
