import os
import sys
import json
import numpy as np
import random
import argparse

import torch.nn.functional
from tqdm import tqdm

# Seed Python's `random` module (it is not otherwise used in the code shown here).
random.seed(42)


def _build_arg_parser():
    """Return the command-line parser for this preprocessing script.

    Options:
      --dataset        dataset name; `main` uses it to build every input/output path.
      --subjects       one or more subject directory names to process.
      --calc-mean-std  flag; when set, `main` also accumulates and saves mean/std.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--dataset", type=str, default='nsd')
    cli.add_argument(
        "--subjects",
        nargs='+',
        type=str,
        default=['subj01', 'subj02', 'subj05', 'subj07'],
    )
    cli.add_argument('--calc-mean-std', action='store_true')
    return cli


# Parsed at import time; `main` reads this module-level namespace.
args = _build_arg_parser().parse_args()


class OnlineStats:
    """Streaming mean and sample variance using Welford's online algorithm.

    Samples are folded in one at a time via ``update``. Only ``+ - * /`` are
    used, so Python numbers, NumPy arrays and torch tensors all work; feeding
    same-shaped arrays yields element-wise statistics.
    """

    def __init__(self):
        # Number of samples folded in so far.
        self.n = 0
        # Running mean (starts at 0.).
        self.mean = 0.
        # Running sum of squared deviations from the mean (Welford's M2).
        self.M2 = 0.

    def update(self, x):
        """Fold sample ``x`` into the running mean and M2."""
        self.n += 1
        shift_old = x - self.mean        # deviation from the previous mean
        self.mean += shift_old / self.n
        shift_new = x - self.mean        # deviation from the updated mean
        self.M2 += shift_old * shift_new

    @property
    def variance(self):
        """Unbiased sample variance ``M2 / (n - 1)``; ``float('nan')`` while n < 2."""
        return float('nan') if self.n < 2 else self.M2 / (self.n - 1)


def _load_fmri_volume(path, shape, dataset):
    """Load one fMRI volume from ``path`` as a float32 torch tensor of ``shape``.

    For datasets whose name contains 'bold' the array axes are reversed first
    (``transpose(2, 1, 0)``). Volumes whose shape still differs from ``shape``
    are resampled with trilinear interpolation.
    """
    volume = np.load(path).astype(np.float32)
    if 'bold' in dataset:
        volume = volume.transpose(2, 1, 0)
    tensor = torch.tensor(volume).float()
    if volume.shape != shape:
        # interpolate() wants (N, C, D, H, W): add batch/channel dims, then drop them.
        tensor = torch.nn.functional.interpolate(
            tensor.unsqueeze(0).unsqueeze(0),
            size=shape,
            mode='trilinear',
            align_corners=False,
        ).squeeze(0).squeeze(0)
    return tensor


def _collect_split(image_ids, split_tag, desc, subject, root_dir, dataset,
                   captions, shape, calc=None):
    """Group one split's fMRI files by image id and return the sample records.

    ``image_ids[i]`` is the image for the i-th fMRI file of the split, stored at
    ``{root_dir}/fmris/{subject}/whole/{dataset}_betas_{split_tag}_{i:06}.npy``.
    Repeats of the same image are merged into one record whose ``'fmri'`` list
    holds every matching path in order. ``'ids'`` numbers the distinct images of
    this subject/split from 0 in first-seen order.

    If ``calc`` is given, every volume (repeats included) is loaded, resampled
    to ``shape`` and passed to ``calc.update``.
    """
    records = {}
    for idx, image_id in enumerate(tqdm(image_ids, desc=desc)):
        fmri_path = f'{root_dir}/fmris/{subject}/whole/{dataset}_betas_{split_tag}_{idx:06}.npy'
        if image_id not in records:
            records[image_id] = {
                'ids': len(records),
                'subject': subject,
                'image': f'{root_dir}/images/{dataset}_image_{image_id:06}.png',
                'fmri': [fmri_path],
                'vision_embeds': f'{root_dir}/vision_embeds/vision_{image_id:06}.npy',
                'caption': captions[image_id]["captions"]
            }
        else:
            records[image_id]['fmri'].append(fmri_path)

        if calc is not None:
            calc.update(_load_fmri_volume(fmri_path, shape, dataset))
    return list(records.values())


def main():
    """Build ``{root_dir}/fmris/pretrain_new.json`` for the selected subjects.

    For each subject, reads ``{dataset}_fmri2image.json`` and builds train
    ('tr' files) and val ('te' files) records. It also records the subject's
    ``atlas.json`` path, or ``None`` if that file is missing. With
    ``--calc-mean-std``, the running mean/std over all loaded volumes is saved
    to ``fmris/mean.npy`` and ``fmris/std.npy``. The output JSON always
    references those two paths.

    Raises:
        ValueError: ``--calc-mean-std`` was given but fewer than two volumes
            were seen, so the sample standard deviation is undefined.
    """
    dataset = args.dataset
    root_dir = f'/mnt/NSD_dataset/datasets/{dataset}'
    shape = (83, 104, 81)

    train_list = []
    test_list = []
    atlas = {}
    stats = OnlineStats() if args.calc_mean_std else None

    # The caption file does not depend on the subject, so load it only once.
    with open(f'{root_dir}/{dataset}_captions.json', 'r') as f:
        coco_caption = json.load(f)

    for subject in args.subjects:
        print(f'Processing {subject}...')
        subject_dir = f'{root_dir}/fmris/{subject}'
        with open(f'{subject_dir}/{dataset}_fmri2image.json', 'r') as f:
            source = json.load(f)
        atlas_path = f'{subject_dir}/atlas.json'
        atlas[subject] = atlas_path if os.path.exists(atlas_path) else None

        common = dict(subject=subject, root_dir=root_dir, dataset=dataset,
                      captions=coco_caption, shape=shape, calc=stats)
        train_list.extend(_collect_split(source['train'], 'tr', 'train', **common))
        test_list.extend(_collect_split(source['val'], 'te', 'val', **common))

    if stats is not None:
        # With n < 2, variance is a plain float nan (no .sqrt()), so fail clearly here.
        if stats.n < 2:
            raise ValueError(
                f'--calc-mean-std needs at least 2 fMRI volumes, got {stats.n}'
            )
        mean = stats.mean
        std = stats.variance.sqrt()

        np.save(f'{root_dir}/fmris/mean.npy', np.array(mean))
        np.save(f'{root_dir}/fmris/std.npy', np.array(std))

    with open(f'{root_dir}/fmris/pretrain_new.json', 'w') as f:
        json.dump({
            'mean': f'{root_dir}/fmris/mean.npy',
            'std': f'{root_dir}/fmris/std.npy',
            'atlas': atlas,
            'train': train_list,
            'val': test_list
        }, f, indent=4)


if __name__ == '__main__':
    # Run the preprocessing only when executed as a script, not when imported.
    main()
