import argparse
import os
from argparse import BooleanOptionalAction

import numpy as np
import torch
from diffusers import DDIMScheduler, StableDiffusionImg2ImgPipeline
from ned.ned import NED

from datasets.nsd import NaturalScenesDataset
from datasets.nsd_clip import CLIPExtractor, NSDCLIPFeaturesDataset
from IPA.ip_adapter import IPAdapter
from methods.subsets import SUBSETS
from methods.slerp import slerp
from methods.projection import project_modulation_vector
from methods.measurements import *
from methods.utils import resize
from methods.dino_encoder import EncoderModule
from torchvision import transforms


# Trailing (per-frame) shape of every quantity recorded for a generated image.
# The leading axes of each result array are (subset, image, frame).
_FRAME_SHAPES = {
    'clip_features': (1024,),
    'ned_preds': (),
    'dino_preds': (),
    'depths': (64, 64),
    'normals': (3, 64, 64),
    'curvatures': (64, 64),
    'brightness': (64, 64),
    'saturation': (64, 64),
    'warmth': (64, 64),
    'entropy': (64, 64),
}
# Quantities written directly into the run folder; all others are written to
# its 'measurements' sub-folder.
_TOP_LEVEL_OUTPUTS = ('clip_features', 'ned_preds', 'dino_preds')
# Modulation directions, processed in this order for every frame.
_DIRECTIONS = ('max', 'min')


def _allocate_results(num_subsets, num_images, num_steps):
    """Return ``{direction: {name: float32 array}}`` pre-filled with NaN.

    NaN (rather than uninitialised memory) marks slots that are never written,
    e.g. when a subset provides fewer than ``num_images`` samples.
    """
    lead = (num_subsets, num_images, num_steps)
    return {
        direction: {
            name: np.full(lead + shape, np.nan, dtype=np.float32)
            for name, shape in _FRAME_SHAPES.items()
        }
        for direction in _DIRECTIONS
    }


def _strength_schedule(num_frames, t1_strength):
    """Return ``num_frames + 1`` img2img strengths going from 0 to ``t1_strength``.

    The curve rises steeply over the first frames and then saturates. With
    ``num_frames == 0`` the single value is NaN (0/0); it is never used because
    frame 0 is always the untouched source image.
    """
    ts = np.linspace(0, 1, num_frames + 1)
    ts = 1 / np.exp(ts**10)[::-1]
    ts = ts - ts.min()
    return ts / ts.max() * t1_strength


def _load_modulation_data(cfg):
    """Build the unit-norm modulation vector from all subjects except ``cfg.subject``.

    Returns:
        (modulation_vector_max, features): the normalised mean of the
        per-subject modulation vectors, and the concatenation of the
        per-subject ``features`` arrays.

    Raises:
        ValueError: if ``cfg.subject`` is not one of 1..8, or if no training
            dataset could be loaded for any of the other subjects.
    """
    subjects = [1, 2, 3, 4, 5, 6, 7, 8]
    subjects.remove(cfg.subject)
    modulation_vectors = []
    features = []
    for s in subjects:
        try:
            dataset_train = NSDCLIPFeaturesDataset(
                nsd=NaturalScenesDataset(
                    root=cfg.dataset_root,
                    subject=s,
                    partition='train',
                    hemisphere='both',
                    roi=cfg.roi,
                    return_average=True
                )
            )
        except Exception as err:
            # Best effort: a subject whose data cannot be loaded is skipped,
            # but say so instead of failing silently. KeyboardInterrupt and
            # SystemExit are deliberately not caught.
            print(f'Skipping subject {s} for ROI {cfg.roi}: {err!r}')
            continue
        modulation_vectors.append(dataset_train.get_modulation_vector())
        features.append(dataset_train.features)

    if not modulation_vectors:
        raise ValueError(
            f'No training data could be loaded for ROI {cfg.roi} from subjects {subjects}'
        )

    modulation_vector_max = np.stack(modulation_vectors, axis=0).mean(axis=0)
    modulation_vector_max = modulation_vector_max / np.linalg.norm(modulation_vector_max)
    return modulation_vector_max, np.concatenate(features, axis=0)


def main(cfg):
    """Generate max/min modulated image sequences for one subject/ROI and save measurements.

    For every subset in ``SUBSETS`` and every selected test image, a sequence
    of ``cfg.num_frames + 1`` images is produced in both the "max" and the
    "min" direction (frame 0 is the resized source image itself). Each frame
    is saved as a PNG and measured (CLIP features, NED and DINO encoder
    predictions, depth, surface normals, curvature, brightness, saturation,
    warmth, entropy). All measurements are finally written as ``.npy`` files
    under ``<cfg.output_dir>/<subject>_<roi>``.

    Args:
        cfg: parsed command-line namespace. Reads ``subject``, ``roi``,
            ``output_dir``, ``dataset_root``, ``ned_dir``, ``models_dir``,
            ``ckpt_dir``, ``num_images``, ``num_frames``, ``t1_strength``,
            ``t1_slerp``, ``projected``, ``seed`` and ``device``.

    Raises:
        ValueError: if no modulation vector can be computed (see
            ``_load_modulation_data``).
    """

    print('#############################################')
    print(f'### Subject {cfg.subject} ROI {cfg.roi} ####')
    print('#############################################')

    folder_all = os.path.join(cfg.output_dir, f"{cfg.subject}_{cfg.roi}")
    measurements_dir = os.path.join(folder_all, 'measurements')
    # Created once up front so the final np.save calls succeed even when no
    # image ends up being processed.
    os.makedirs(measurements_dir, exist_ok=True)

    # Initialize data structures
    num_steps = cfg.num_frames + 1
    results = _allocate_results(len(SUBSETS), cfg.num_images, num_steps)

    # Load models
    diffusion_pipeline = StableDiffusionImg2ImgPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to(cfg.device)
    diffusion_pipeline.scheduler = DDIMScheduler.from_config(diffusion_pipeline.scheduler.config)
    diffusion_pipeline.safety_checker = None
    ip_model = IPAdapter(diffusion_pipeline, "laion/CLIP-ViT-H-14-laion2B-s32B-b79K", os.path.join(cfg.models_dir, "ip-adapter_sd15.bin"), cfg.device)
    clip_extractor = CLIPExtractor(cfg.device)
    ned_object = NED(cfg.ned_dir)
    ned_encoder = ned_object.get_encoding_model(
        modality='fmri',
        train_dataset='nsd',
        model='fwrf',
        subject=cfg.subject,
        roi=cfg.roi,
    )
    # Use the lexicographically last entry of the subject/ROI checkpoint folder.
    ckpt_path = os.path.join(cfg.ckpt_dir, 'dino_vit', f'0{cfg.subject}_{cfg.roi}_0')
    ckpt_path = os.path.join(ckpt_path, sorted(os.listdir(ckpt_path))[-1])
    dino_encoder = EncoderModule.load_from_checkpoint(ckpt_path, strict=False).to(cfg.device).eval()
    depth_estimator = DepthEstimator()
    surface_normal_estimator = SurfaceNormalEstimator(os.path.join(cfg.models_dir, 'rgb2normal_consistency.pth'))
    curvature_estimator = CurvatureEstimator()

    transform_dino = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Resize((224, 224)),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )

    def record_frame(direction, s_i, i, j, img, folder):
        """Save frame ``j`` to ``folder`` and store all of its measurements."""
        out = results[direction]
        img.save(os.path.join(folder, f"{j}.png"))
        out['clip_features'][s_i, i, j] = clip_extractor(img).detach().cpu().numpy()
        out['ned_preds'][s_i, i, j] = ned_object.encode(ned_encoder, np.array(img).transpose(2, 0, 1)[None])[0].mean()
        # Inference only: avoid building an autograd graph for the encoder pass.
        with torch.no_grad():
            dino_out = dino_encoder(transform_dino(img).to(cfg.device).unsqueeze(0))
        out['dino_preds'][s_i, i, j] = dino_out.squeeze(0).detach().cpu().numpy().mean()
        out['depths'][s_i, i, j] = resize(depth_estimator.compute(img), 64)[0]
        out['normals'][s_i, i, j] = resize(surface_normal_estimator.compute(img), 64)
        out['curvatures'][s_i, i, j] = resize(curvature_estimator.compute(img), 64)[0]
        out['brightness'][s_i, i, j] = resize(compute_brightness(img), 64)[0]
        out['saturation'][s_i, i, j] = resize(compute_saturation(img), 64)[0]
        out['warmth'][s_i, i, j] = resize(compute_warmth(img), 64)[0]
        out['entropy'][s_i, i, j] = resize(compute_entropy(img), 64)[0]

    # Get modulation vectors from rest of subjects
    modulation_vector_max, features = _load_modulation_data(cfg)
    modulation_vector_min = -modulation_vector_max

    if cfg.projected:
        modulation_vector_max = project_modulation_vector(features, modulation_vector_max)
        modulation_vector_min = project_modulation_vector(features, modulation_vector_min)

    # Set strengths and seeds
    strengths = _strength_schedule(cfg.num_frames, cfg.t1_strength)

    # One seed per image index, shared by every frame and both directions.
    seeds = torch.randint(-int(1e10), int(1e10), (cfg.num_images,), generator=torch.Generator().manual_seed(cfg.seed))

    for s_i, subset in enumerate(SUBSETS):

        # Get dataset sample
        dataset = NSDCLIPFeaturesDataset(
            nsd=NaturalScenesDataset(
                root=cfg.dataset_root,
                subject=cfg.subject,
                partition='test',
                hemisphere='both',
                roi=cfg.roi,
                return_average=True,
                subset=subset,
                tval_threshold=5,
            )
        )
        # Keep the (at most) cfg.num_images samples whose activation is
        # closest to the subset's mean activation.
        mean = dataset.nsd.activations.mean()
        dists_to_mean = np.abs(dataset.nsd.activations - mean)
        sample = np.argsort(dists_to_mean)[:cfg.num_images]

        for i, idx in enumerate(sorted(sample)):

            # Get source image
            image, source_img_embeds, _ = dataset[idx]
            _, _, coco_id = dataset.nsd[idx]
            image = image.resize((512, 512))

            folders = {
                direction: os.path.join(folder_all, subset, str(coco_id), direction)
                for direction in _DIRECTIONS
            }
            for folder in folders.values():
                os.makedirs(folder, exist_ok=True)

            # Get embeddings: the slerp endpoints are the modulation vectors
            # rescaled to the norm of the source image embedding.
            source_norm = np.linalg.norm(source_img_embeds)
            endpoint_max = modulation_vector_max * source_norm
            endpoint_min = modulation_vector_min * source_norm
            target_max = torch.from_numpy(slerp(source_img_embeds, endpoint_max, num_steps, t1=cfg.t1_slerp)).unsqueeze(1).to(cfg.device)
            target_min = torch.from_numpy(slerp(source_img_embeds, endpoint_min, num_steps, t1=cfg.t1_slerp)).unsqueeze(1).to(cfg.device)

            # Generate images
            for j, (st, tmax, tmin) in enumerate(zip(strengths, target_max, target_min)):
                for direction, target in (('max', tmax), ('min', tmin)):
                    if j == 0:
                        # Frame 0 is the unmodified source image.
                        img = image
                    else:
                        img = ip_model.generate(
                            clip_image_embeds=target,
                            image=image,
                            strength=st,
                            num_samples=1,
                            num_inference_steps=50,
                            seed=seeds[i].item()
                        )[0]
                    record_frame(direction, s_i, i, j, img, folders[direction])

    # Save everything as '<name>_<direction>.npy'.
    for direction, arrays in results.items():
        for name, array in arrays.items():
            target_dir = folder_all if name in _TOP_LEVEL_OUTPUTS else measurements_dir
            np.save(os.path.join(target_dir, f'{name}_{direction}.npy'), array)


# Command-line entry point: parse the run configuration and hand it to main().
if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # Input / output locations
    parser.add_argument("--dataset_root", type=str, default="./data/NSD")
    parser.add_argument("--ned_dir", type=str, default='./data/NED')
    parser.add_argument("--models_dir", type=str, default='./data/models')
    parser.add_argument("--ckpt_dir", type=str, default="./data/checkpoints/")
    parser.add_argument("--output_dir", type=str, default='./outputs')

    # Experiment selection and size
    parser.add_argument("--subject", type=int, default=1)
    parser.add_argument("--roi", default="PPA")
    parser.add_argument("--num_images", type=int, default=20)
    parser.add_argument("--num_frames", type=int, default=10)

    # Generation parameters
    parser.add_argument("--t1_strength", type=float, default=0.6)
    # Float default so the value has the same type whether or not the flag is given.
    parser.add_argument("--t1_slerp", type=float, default=1.0)
    parser.add_argument("--projected", action=BooleanOptionalAction, default=False)

    parser.add_argument("--seed", type=int, default=0)
    # Parsed as torch.device so cfg.device has the same type as the default
    # (argparse does not run `type` on non-string defaults).
    parser.add_argument(
        "--device",
        type=torch.device,
        default=(
            torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
        ),
    )

    cfg = parser.parse_args()
    main(cfg)