import argparse
import copy
import gc
import json
import os

import numpy as np
import torch
import yaml

from dataset.carpk_dataset import CARPKDataset
from dataset.dataset_util import clip_transform
from dataset.fsc_dataset import FSCDataset
from dataset.image_dataset import ImageDataset
from dataset.shanghai_tech_dataset import ShanghaiTechDataset
from dataset.countbench_dataset import CountBenchDataset
from dreamsim.dreamsim.feature_extraction.extractor import ViTExtractor
from training.train import LightningPerceptualModel
from models.backbones import *
from models.knn import run_knn
from models.lgfbs_gpu import lfbgs_gpu, prep_features
from models.lgfbs_cpu import run_lgfbs_cpu
# from models.lgfbs_gpu import prep_features, lfbgs_gpu
from scripts import pidfile
from scripts.util import feature_extract_and_save, set_all_seeds

from dreamsim.util.train_utils import seed_worker

# Make torch.multiprocessing share tensors through the file system instead of
# passing file descriptors between processes. NOTE(review): presumably this
# avoids exhausting the open-file limit when many DataLoader workers are used
# (see --workers); confirm intent.
torch.multiprocessing.set_sharing_strategy('file_system')

def load_features(loader, m, embed_path, train=False):
    """Compute embeddings for everything yielded by ``loader`` and save them.

    No cache lookup is performed: features are always recomputed. The work is
    delegated to ``feature_extract_and_save``, with image paths disabled.

    Args:
        loader: iterable of batches to embed (``build_loaders`` supplies an
            iterator over a DataLoader).
        m: model or callable used to embed each batch.
        embed_path: file the embeddings are saved to.
        train: forwarded unchanged to ``feature_extract_and_save``.

    Returns:
        The value returned by ``feature_extract_and_save``; callers in this
        file unpack it as ``(embeds, labels)``.
    """
    announcement = f"Computing features and saving to {embed_path}"
    print(announcement)
    result = feature_extract_and_save(loader, m, embed_path, im_paths=False, train=train)
    return result


def get_class_to_id_mapping(train_root, train_ds, class_info_path=None):
    """Return a mapping from class name to integer class ID.

    If every entry of ``train_root`` has an integer name, each entry maps to
    its own integer value. Otherwise the mapping is read from
    ``<class_info_path>/<train_ds>_classes.json``.

    Args:
        train_root: directory whose entries are the per-class subfolders.
        train_ds: dataset name used to select the JSON mapping file.
        class_info_path: directory containing ``<train_ds>_classes.json``.
            When ``None`` (default), the script-level ``args.class_info_path``
            is used, as before (this needs the module-global ``args``).

    Returns:
        dict mapping class name (str) to class ID.

    Raises:
        FileNotFoundError: if the folders are not all numeric and the JSON
            mapping file does not exist.
    """
    subfolders = os.listdir(train_root)
    try:
        # Numeric folder names are taken to be the class IDs themselves.
        return {subfolder: int(subfolder) for subfolder in subfolders}
    except ValueError:
        # At least one non-numeric name: fall back to the JSON mapping.
        pass

    if class_info_path is None:
        class_info_path = args.class_info_path
    # Use the train_ds parameter (the original ignored it and read args.train_ds).
    with open(os.path.join(class_info_path, f"{train_ds}_classes.json"), "r") as f:
        return json.load(f)
    
def build_loaders(task, preprocess, args, max_samples=10000):
    """Create the training loader and one loader per test dataset for ``task``.

    If ``args.test_datasets`` is a comma-separated string, it is split into a
    list in place on ``args``.

    Args:
        task: one of 'vtab', 'fsc147', 'carpk', 'shanghai_tech', 'countbench'.
        preprocess: transform applied to every image.
        args: namespace providing train_root, test_root, train_ds,
            test_datasets, batch_size and workers.
        max_samples: for 'vtab', datasets with more items than this are
            randomly subsampled (without replacement, using NumPy's global
            RNG) down to this many items. ``None`` disables subsampling.

    Returns:
        ``(train_loader, test_loaders)``.
        - ``train_loader`` is an iterator over a non-shuffled DataLoader with
          batch size ``args.batch_size``.
        - ``test_loaders`` maps each test dataset name to an iterator over a
          non-shuffled DataLoader with batch size 1.

    Raises:
        ValueError: if ``task`` is not one of the supported names.
    """

    def _subsample(dataset):
        # Keep at most ``max_samples`` random items of ``dataset``.
        if max_samples is not None and len(dataset) > max_samples:
            keep = np.random.choice(len(dataset), max_samples, replace=False)
            return torch.utils.data.Subset(dataset, keep)
        return dataset

    if isinstance(args.test_datasets, str):
        args.test_datasets = args.test_datasets.split(",")

    test_datasets = {}
    if task == 'vtab':
        class_to_idx = get_class_to_id_mapping(args.train_root, args.train_ds)
        # Train is subsampled before the test sets so the RNG draw order is unchanged.
        train_dataset = _subsample(ImageDataset(class_to_idx=class_to_idx, root=args.train_root,
                                                transform=preprocess, ret_path=False))
        for ds_name in args.test_datasets:
            # Keyword arguments, matching the training-set call, so the order of
            # ImageDataset's positional parameters cannot swap root/class_to_idx.
            test_datasets[ds_name] = _subsample(ImageDataset(root=os.path.join(args.test_root, ds_name),
                                                             class_to_idx=class_to_idx,
                                                             transform=preprocess, ret_path=False))
    elif task == 'fsc147':
        train_dataset = FSCDataset(root=args.train_root, split='train', transform=preprocess, ret_path=False)
        for ds_name in args.test_datasets:
            test_datasets[ds_name] = FSCDataset(root=args.test_root, split='val', transform=preprocess,
                                                ret_path=False)
    elif task == 'carpk':
        train_dataset = CARPKDataset(root=args.train_root, split='train', transform=preprocess)
        for ds_name in args.test_datasets:
            test_datasets[ds_name] = CARPKDataset(root=args.test_root, split='test', transform=preprocess)
    elif task == 'shanghai_tech':
        train_dataset = ShanghaiTechDataset(root=args.train_root, part='B', split='train', transform=preprocess)
        for ds_name in args.test_datasets:
            test_datasets[ds_name] = ShanghaiTechDataset(root=args.test_root, part='A', split='test',
                                                         transform=preprocess)
    elif task == 'countbench':
        train_dataset = CountBenchDataset(root=args.train_root, split='train', transform=preprocess)
        for ds_name in args.test_datasets:
            # NOTE(review): unlike every other task, the "test" sets are built from
            # train_root with split='train', i.e. the training data itself. Confirm intended.
            test_datasets[ds_name] = CountBenchDataset(root=args.train_root, split='train', transform=preprocess)
    else:
        raise ValueError(f"Unknown task {task!r}")

    train_loader = iter(torch.utils.data.DataLoader(train_dataset,
                                                    batch_size=args.batch_size,
                                                    shuffle=False,
                                                    num_workers=args.workers,
                                                    pin_memory=True))
    test_loaders = {
        ds_name: iter(torch.utils.data.DataLoader(ds,
                                                  batch_size=1,
                                                  shuffle=False,
                                                  num_workers=args.workers,
                                                  pin_memory=True))
        for ds_name, ds in test_datasets.items()
    }
    return train_loader, test_loaders
def run(args, m, model_name, preprocess, dist, suffix, device):
    """Embed the train/test sets with ``m`` and evaluate a classifier on them.

    Steps:
    1. Build loaders with ``build_loaders``.
    2. Embed every test set and the training set with ``load_features``, casting to float32.
    3. Hold out a random validation split from the training embeddings.
    4. Run the classifier selected by ``args.classifier``.

    Args:
        args: parsed CLI namespace. It must provide train_ds, embed_path, seed,
            val_ratio, classifier and confusion_matrix, plus the fields
            ``build_loaders`` reads. An optional ``args.task`` selects the
            dataset family (default 'vtab').
        m: embedding model/callable passed to ``load_features``.
        model_name: used in the embedding filenames and passed to ``run_knn``.
        preprocess: image transform for the datasets.
        dist: distance function for the kNN classifier.
        suffix: extra tag inserted into the embedding filenames.
        device: device the kNN tensors are moved to.

    Returns:
        dict with keys "val", "test" and "best_param". The values are the
        corresponding values returned by the selected classifier.

    Raises:
        ValueError: if ``args.classifier`` is not knn, lgfbs or lgfbs_cpu.
    """
    # Validate before spending time on feature extraction.
    if args.classifier not in ("knn", "lgfbs", "lgfbs_cpu"):
        raise ValueError(f"No classifier {args.classifier}")

    torch.cuda.empty_cache()

    task = getattr(args, "task", "vtab")
    train_iters, test_iters = build_loaders(task, preprocess, args)

    ## Compute embeddings for all the datasets
    train_ds_name = args.train_ds
    test_embeds, test_labels = {}, {}
    for dsname, test_loader in test_iters.items():
        embed_file = f"{model_name}_{train_ds_name}{suffix}_{dsname.split('/')[-1]}.npz"
        embeds, labels = load_features(test_loader, m, os.path.join(args.embed_path, embed_file))
        test_embeds[dsname] = embeds.astype(np.float32)
        test_labels[dsname] = labels
        print(test_embeds[dsname].shape)

    train_embeds, train_labels = load_features(
        train_iters,
        m,
        os.path.join(args.embed_path, f"{model_name}_{train_ds_name}{suffix}_train.npz"),
        train=True
    )
    train_embeds = train_embeds.astype(np.float32)
    print(train_embeds.shape, train_labels.shape)

    ## Split train into train and val.
    # Reseeding immediately before the permutation (assuming set_all_seeds seeds
    # NumPy's global RNG) makes the split depend only on the seed and train size.
    set_all_seeds(args.seed)
    n_total = train_embeds.shape[0]
    val_size = int(n_total * args.val_ratio)
    val_indices = np.random.permutation(n_total)[:val_size]
    val_embeds, val_labels = train_embeds[val_indices], train_labels[val_indices]
    # Boolean mask instead of np.in1d (deprecated in NumPy 2.0); same result.
    train_mask = np.ones(n_total, dtype=bool)
    train_mask[val_indices] = False
    train_embeds, train_labels = train_embeds[train_mask], train_labels[train_mask]
    print(f"{len(val_indices)} val examples and {train_embeds.shape[0]} train examples")

    ## Run classification
    if args.classifier == "knn":
        # TODO: sweep over k using the validation set.
        def to_device(array):
            return torch.from_numpy(array).to(device)

        train_embeds, train_labels = to_device(train_embeds), to_device(train_labels)
        val_embeds, val_labels = to_device(val_embeds), to_device(val_labels)
        for ds_name in list(test_embeds):
            test_embeds[ds_name] = to_device(test_embeds[ds_name])
            test_labels[ds_name] = to_device(test_labels[ds_name])
        test_results, val_results, best_param = run_knn(
            train_embeds,
            train_labels,
            test_embeds,
            test_labels,
            dist_fn=dist,
            cfm=args.confusion_matrix,
            val_embed_stack=val_embeds,
            val_label_stack=val_labels,
            model_name=model_name
        )
    else:
        # Both L-BFGS variants share the same feature preparation.
        data_dict, test_dict = prep_features(
            train_embeds,
            train_labels,
            val_embeds,
            val_labels,
            test_embeds,
            test_labels,
            normalize=True)
        if args.classifier == "lgfbs":
            test_results, val_results, best_param = lfbgs_gpu(data_dict, test_dict, args.confusion_matrix)
        else:  # "lgfbs_cpu"
            test_results, val_results, best_param = run_lgfbs_cpu(data_dict, test_dict)

    return {
        "val": val_results,
        "test": test_results,
        "best_param": best_param
    }


if __name__ == "__main__":

    def _str2bool(value):
        """Parse a command-line boolean such as "true"/"false", "1"/"0", "yes"/"no".

        argparse's ``type=bool`` treats every non-empty string (including
        "False") as True, so an explicit parser is needed.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ("1", "true", "t", "yes", "y"):
            return True
        if lowered in ("", "0", "false", "f", "no", "n"):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser()

    ## Data
    parser.add_argument("--train_ds", type=str, default="imagenet", help="training dataset name")
    parser.add_argument("--train_root", type=str, default="./train", help="root for training data")
    parser.add_argument("--test_root", type=str, default="./test",  help="Directory with eval datasets")
    parser.add_argument("--test_datasets", type=str,default="imagenet",
                        help="comma-separated list of test datasets; should be subfolders in args.test_root")
    parser.add_argument("--output_path", type=str, default="./outputs_scaling_old")
    parser.add_argument("--embed_path", type=str, default="./embeds",  help="path to cache precomputed embeddings")
    parser.add_argument("--class_info_path", type=str, default="./configs",  help="path to json with class-name-to-ID mapping (if needed)")

    ## Model
    parser.add_argument("--model", type=str, default="all")
    parser.add_argument("--classifier", type=str, default="knn", choices=["knn", "lgfbs", "lgfbs_cpu"])
    parser.add_argument("--confusion_matrix", type=_str2bool, default=False)

    ## Experiment setup
    parser.add_argument("--batch_size", type=int, default=1, help="batch size for computing embeddings")
    parser.add_argument("--tag", type=str, default="knn", help="Name for saving results")
    parser.add_argument("--seed", type=int, default=1234, help="set seed so that the same samples are selected for training each time")
    parser.add_argument("--workers", type=int, default=12)
    parser.add_argument("--val_ratio", type=float, default=0.2, help="ratio of training data to use for validation")
    parser.add_argument("--debug", action="store_true", default=False, help="debug - if on, doesn't lock the output folder")

    ## User Config
    parser.add_argument('--user_config', type=str, default="configs/julia_configs.yaml")
    args = parser.parse_args()

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # The user config is loaded but not otherwise used below.
    with open(args.user_config, 'r') as file:
        user_config = yaml.safe_load(file)

    output_path = args.output_path
    os.makedirs(output_path, exist_ok=True)
    os.makedirs(args.embed_path, exist_ok=True)

    set_all_seeds(args.seed)

    print(f"train root: {args.train_root} \ntag {args.tag} \nmodels {args.model} \ntest sets {args.test_datasets}")

    backbone = 'dino_vitb16'
    stride = '16'

    # One LoRA checkpoint per training-set size; each suffix tags one run.
    suffixes = ['_n1', '_n10', '_n100', '_n1000', '_n3000', '_n5000', '_n7000', '_n10000', '_n-1']
    model_names = [f'{backbone}{tag}' for tag in suffixes]

    model_dir_root = '/home/fus/repos/repalignment/dreamsim/output/new_backbones_scaling/nsamples/dino/trial2'
    lora_alpha = 0.5
    model_dirs = [
        os.path.join(
            model_dir_root,
            f'lora_single_{backbone}{tag}_cls_lora_lr_0.0003_batchsize_16_wd_0.0_hiddensize_1_margin_0.05_lorar_16_loraalpha_{lora_alpha}_loradropout_0.0/lightning_logs/version_0/checkpoints',
        )
        for tag in suffixes
    ]

    models = []
    for model_dir in model_dirs:
        ours_model = PerceptualModel(backbone, device=device, load_dir='/home/fus/repos/repalignment/dreamsim/models',
                                     lora=True, stride=stride, feat_type='cls')

        # Every run is evaluated from its epoch-7 adapter checkpoint.
        load_dir = os.path.join(model_dir, f'epoch_7_{backbone}')
        with open(os.path.join(load_dir, 'adapter_config.json'), 'r') as f:
            lora_config = LoraConfig(**json.load(f))

        for extractor in ours_model.extractor_list:
            extractor.model = get_peft_model(ViTModel(extractor.model, ViTConfig()), lora_config).to(device)
            extractor.model = PeftModel.from_pretrained(extractor.model, load_dir).to(device)
            extractor.model.eval().requires_grad_(False)

        ours_model.eval().requires_grad_(False)
        models.append((ours_model, dreamsim_transform, cosine_similarity))

    # Prepend the un-tuned backbone as a baseline, with its final norm replaced by identity.
    d_model, d_trans, d_sim = dino_vitb16({'load_local': False, 'model_dir': 'models/'})
    d_model.norm = nn.Identity()
    models.insert(0, (d_model, d_trans, d_sim))
    model_names.insert(0, backbone)
    model_dirs.insert(0, '')
    suffixes.insert(0, '_0')

    print(len(models), len(model_names), len(model_dirs), len(suffixes))
    all_results = {}
    for model_name, model_init, model_dir, suffix in zip(model_names, models, model_dirs, suffixes):
        # The run tag is already part of model_name, so the suffix is cleared.
        # Embedding filenames and result keys are then built from model_name alone.
        suffix = ''
        print("Model ", model_name, suffix, model_dir)
        all_results[model_name] = {}
        model, preprocess, dist = model_init

        model_fn_curr = model.to(device)

        # Select the embedding entry point for this model family.
        if "clip" in model_name and "dreamsim" not in model_name:
            print('encode_image')
            model_fn_curr = model.encode_image
        elif "mae" in model_name:
            print('lambda')
            model_fn_curr = lambda x: model(x).squeeze()
        elif "dreamsim" in model_name or "simclr" in model_name or "ensemble" in model_name:
            print('embed')
            model_fn_curr = model.embed
        else:
            print('keep')

        all_data = run(args, model_fn_curr, model_name, preprocess, dist, suffix, device)

        if all_data is not None:
            all_results[model_name + suffix] = all_data

    all_results['metadata'] = {
        'dataset': args.train_root,
        'classifier': args.classifier,
    }

    # Slashes in the tag would create sub-directories in the results filename.
    args.tag = args.tag.replace("/", "_")
    with open(os.path.join(output_path, f"results_{args.tag}.json"), "w") as f:
        json.dump(all_results, f)

    print("done :)")

