import glob
import os
import numpy as np 

import torch
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

from lib.dataset.imagenetr_utils import imagenet_r_transform, imagenet_r_mask, reverse_imagenet_r_mask, imagenet_a_mask, imagenet_o_mask
from lib.dataset.objectnet_dataset import objectnet_mask
from lib.dataset.objectnet_dataset import ObjectNetDataset
from lib.dataset.imagenet_v2 import ImageNetV2
from lib.experiment import validate
from utils.data_utils import filter_joint
import wandb
import json
import copy

from torchvision.datasets import CIFAR100, ImageFolder, Caltech256, StanfordCars
from torch.utils.data import ConcatDataset

from lib.dataset import StanfordCars as SC
from lib.dataset.cub200 import CUB200

from open_clip.zero_shot_metadata import IMAGENET_CLASSNAMES, OPENAI_IMAGENET_TEMPLATES
from open_clip.zero_shot_classifier import build_zero_shot_classifier
from utils.data_utils import split_dataset

def load_results(args, step):
    """
    Load previously saved evaluation results so finished evaluations can be skipped.

    The base file is ``{args.model_folder}/best_results.json`` when
    ``args.eval_best`` is set, otherwise ``{args.model_folder}/results.json``;
    a non-None ``step`` inserts ``_{step}`` before ``.json``. Nothing is
    loaded when that file does not exist or when ``args.force`` or
    ``args.collect_features`` is set. Once the base file is loaded, every
    existing ``{args.model_folder}/results_{i}.json`` (i = 0..99) is merged
    on top of it, later files overriding earlier keys.

    Args:
        args: namespace providing ``model_folder``, ``eval_best``, ``force``
            and ``collect_features``.
        step: optional step identifier used in the file name.

    Returns:
        The loaded results dict (empty when nothing was loaded).
    """
    base_name = 'best_results.json' if args.eval_best else 'results.json'
    main_path = f'{args.model_folder}/{base_name}'
    if step is not None:
        main_path = main_path.replace('.json', f'_{step}.json')

    loaded = {}
    should_load = os.path.exists(main_path) and not (args.force or args.collect_features)
    if should_load:
        with open(main_path, 'r') as handle:
            loaded = json.load(handle)
        # Merge any additional numbered result files on top of the base file.
        for shard_idx in range(100):
            shard_path = f'{args.model_folder}/results_{shard_idx}.json'
            if not os.path.exists(shard_path):
                continue
            with open(shard_path, 'r') as handle:
                loaded.update(json.load(handle))
    print("Loaded results", loaded)
    return loaded


def evaluations(args, model, preprocess, text, tokenizer, criterion, old_model=None, evaluation_datasets=('imagenet_variants', 'imagenet-c', 'non_imagenet'), step=None):
    """
    Run the configured evaluation suites and persist the combined results.

    Previously stored results (from ``load_results``) are passed to the
    individual evaluators so finished evaluations can be reused. The merged
    results are written to ``{args.model_folder}/results.json``
    (``best_results.json`` when ``args.eval_best`` is set; ``_{step}`` is
    inserted before ``.json`` when ``step`` is given). The file is written
    once before the ImageNet-C suites, which update it themselves, and again
    at the end.

    Args:
        args: experiment arguments (dataset, model folder and flags such as
            ``force``, ``collect_features``, ``no_split``, ``use_wandb``).
        model: model to evaluate. When ``args.regularization == 'FLYP'`` its
            ``zero_shot_weights`` attribute is rebuilt from the ImageNet class
            names before evaluating.
        preprocess: image transform applied by every evaluation dataset.
        text: text inputs forwarded to the evaluators.
        tokenizer: tokenizer used to build zero-shot classifiers.
        criterion: loss criterion forwarded to the evaluators.
        old_model: old model for EWC, forwarded to the suites that accept it.
        evaluation_datasets: which suites to run; any of
            ``'imagenet_variants'``, ``'imagenet-c'`` and ``'non_imagenet'``
            (the latter only when ``args.dataset`` is cub200, cars or
            caltech256).
        step: optional step identifier, stored under ``'step'`` and used in
            the results file name.

    Returns:
        dict mapping result names to accuracies (plus ``'step'`` if given).
    """
    # Load zero-shot weights if regularization is FLYP.
    if args.regularization == 'FLYP':
        if 'siglip' in args.model:
            from lib.siglip2.zero_shot_classifier import build_zero_shot_classifier as build_siglip2_zero_shot_classifier
            zero_shot_weights = build_siglip2_zero_shot_classifier(model, tokenizer, IMAGENET_CLASSNAMES, OPENAI_IMAGENET_TEMPLATES, 50, device='cuda')
        else:
            zero_shot_weights = build_zero_shot_classifier(model, tokenizer, IMAGENET_CLASSNAMES, OPENAI_IMAGENET_TEMPLATES, 50, device='cuda')
        model.zero_shot_weights = zero_shot_weights

    results = load_results(args, step)
    output = {}
    if step is not None:
        output['step'] = step

    if 'non_imagenet' in evaluation_datasets and args.dataset in ('cub200', 'cars', 'caltech256'):
        output.update(evaluate_non_imagenet(args, model, preprocess, text, tokenizer, criterion, old_model=old_model, results=results))
    if 'imagenet_variants' in evaluation_datasets:
        output.update(evaluate_imagenet_variants(args, model, preprocess, text, tokenizer, criterion, old_model=old_model, results=results))
    if not args.no_split and args.dataset != 'imagenet-c':
        output.update(evaluate_validation_set(args, model, preprocess, text, tokenizer, criterion, old_model=old_model, results=results))

    # Same file name scheme that load_results() reads from.
    if args.eval_best:
        json_path = f'{args.model_folder}/best_results.json'
    else:
        json_path = f'{args.model_folder}/results.json'
    if step is not None:
        json_path = json_path.replace('.json', f'_{step}.json')

    if os.path.exists(json_path) and not args.force and not args.collect_features:
        # Best-effort merge of previously stored results: an unreadable,
        # malformed or non-dict file is ignored instead of aborting the run.
        try:
            with open(json_path, 'r') as f:
                output.update(json.load(f))
        except (OSError, ValueError, TypeError):
            pass
    with open(json_path, 'w') as f:
        json.dump(output, f)

    # The ImageNet-C evaluators read and rewrite json_path themselves.
    if 'imagenet-c' in evaluation_datasets:
        output.update(evaluate_imagenet_c(args, model, preprocess, text, tokenizer, criterion, results=results, json_path=json_path))
    if not args.no_split and args.dataset == 'imagenet-c':
        output.update(evaluate_imagenet_c_val(args, model, preprocess, text, tokenizer, criterion, results=results, json_path=json_path))
    if args.use_wandb:
        wandb.log(output)
    if os.path.exists(json_path):
        with open(json_path, 'r') as f:
            output.update(json.load(f))
    with open(json_path, 'w') as f:
        json.dump(output, f)
    return output


def evaluate_non_imagenet(args, model, preprocess, texts, tokenizer, criterion, old_model=None, results=None):
    """
    Evaluate the model zero-shot on non-ImageNet datasets (Stanford Cars, Caltech256).

    For each dataset a zero-shot classifier is built from that dataset's class
    names and temporarily installed as ``model.zero_shot_weights``. The
    model's original weights are restored afterwards, even if evaluation
    raises.

    Args:
        args: experiment arguments (``root``, ``batch_size``, ``num_workers``,
            ``collect_logits``, ``collect_features``, ``model_dir``,
            ``filename``, ``model``).
        model: model to evaluate; must already have a ``zero_shot_weights``
            attribute.
        preprocess: image transform for the datasets.
        texts: text inputs forwarded to ``validate``.
        tokenizer: tokenizer used to build the per-dataset classifier.
        criterion: loss criterion forwarded to ``validate``.
        old_model: unused; kept for signature parity with the other evaluators.
        results: previously stored results; a dataset whose name is present is
            not re-evaluated unless ``args.collect_features`` is set.

    Returns:
        dict mapping dataset name to top-1 accuracy.
    """
    if results is None:
        results = {}
    # cub200 ('cub200/train', 'cub200/test') is currently not evaluated.
    dataset_names = ['cars/train', 'cars/test', 'caltech256']
    acc1s = []

    for dataset_name in dataset_names:
        if dataset_name in results and not args.collect_features:
            acc1s.append(results[dataset_name])
            continue
        print("Evaluating", dataset_name)

        if dataset_name == 'caltech256':
            dataset = Caltech256(root=f"{args.root}/caltech256", transform=preprocess)
            # Category folders look like '<index>.<name>' (some with a '-101' suffix).
            classes = [e.replace('-101', '').split('.')[1] for e in dataset.categories]
        elif 'cars' in dataset_name:
            dataset = SC(root=args.root, split=dataset_name.split('/')[1], transform=preprocess)
            classes = dataset.classes
        elif 'cub200' in dataset_name:
            dataset = CUB200(root=args.root, split=dataset_name.split('/')[1], transform=preprocess)
            classes = dataset.classes
        else:
            raise ValueError(f"Unsupported dataset: {dataset_name}")

        # Use the same classifier builder as evaluations() for SigLIP models.
        if 'siglip' in getattr(args, 'model', ''):
            from lib.siglip2.zero_shot_classifier import build_zero_shot_classifier as build_siglip2_zero_shot_classifier
            dataset_zero_shot_weights = build_siglip2_zero_shot_classifier(model, tokenizer, classes, OPENAI_IMAGENET_TEMPLATES, 50, device='cuda')
        else:
            dataset_zero_shot_weights = build_zero_shot_classifier(model, tokenizer, classes, OPENAI_IMAGENET_TEMPLATES, 50, device='cuda')

        # The attribute is only rebound (never mutated in place), so holding a
        # reference is enough to restore it; no deepcopy needed.
        original_zero_shot_weights = model.zero_shot_weights
        model.zero_shot_weights = dataset_zero_shot_weights
        try:
            data_loader = DataLoader(dataset,
                                     batch_size=args.batch_size, pin_memory=True,
                                     num_workers=args.num_workers, shuffle=False)
            prefix = dataset.root.split('/')[-1]
            acc1, loss, logit_info, features = validate(data_loader, texts, model, tokenizer, criterion, args, prefix=prefix, return_attention=False)
        finally:
            model.zero_shot_weights = original_zero_shot_weights

        acc1s.append(acc1)
        print(dataset_name, acc1)
        file_stem = dataset_name.replace('/', '_')
        if args.collect_logits:
            save_path = os.path.join(args.model_dir, args.filename, f"{file_stem}_logits.pth")
            torch.save(logit_info, save_path)
            print(f'saved logits to {save_path}')
        if args.collect_features:
            save_path = os.path.join(args.model_dir, args.filename, f"{file_stem}_features.pth")
            torch.save(features, save_path)
            print(f'saved features to {save_path}')

    print(dataset_names)
    print(acc1s)
    print(np.mean(acc1s))
    return dict(zip(dataset_names, acc1s))





def evaluate_imagenet_variants(args, model, preprocess, texts, tokenizer, criterion, old_model=None, results={}):
    """
    Evaluate the model on ImageNet and its distribution-shift variants.

    Datasets are read from ``{args.root}/<name>``. The objectnet entries use
    ``ObjectNetDataset``, and ``objectnet-v2`` reindexes ``objectnet-1.0``.
    ``imagenet-v2`` uses ``ImageNetV2``; everything else is an ``ImageFolder``.
    Some variants pass a class mask to ``validate``.

    Args:
        args: experiment arguments (``root``, ``batch_size``, ``num_workers``,
            ``joint_only``, ``collect_logits``, ``collect_features``,
            ``model_dir``, ``filename``).
        model: model to evaluate.
        preprocess: image transform for the datasets.
        texts: text inputs forwarded to ``validate``.
        tokenizer: tokenizer forwarded to ``validate``.
        criterion: loss criterion forwarded to ``validate``.
        old_model: unused here (the call site keeps it commented out).
        results: previously stored results; a variant whose name is present
            is not re-evaluated unless ``args.collect_features`` is set.

    Returns:
        dict mapping variant name to top-1 accuracy.
    """
    # (variant name, class mask handed to validate or None)
    variant_specs = [
        ('imagenet/val', None),
        ('imagenet-a', imagenet_a_mask),
        ('imagenet-r', imagenet_r_mask),
        ('imagenet-sketch', None),
        ('objectnet-1.0', None),
        ('imagenet-cartoon', None),
        ('imagenet-drawing', None),
        ('imagenet-v2', None),
        ('objectnet-v2', objectnet_mask),
    ]
    dataset_names = [variant for variant, _ in variant_specs]

    accuracies = []
    for variant, class_mask in variant_specs:
        reuse_cached = variant in results and not args.collect_features
        if reuse_cached:
            accuracies.append(results[variant])
            continue

        print("Evaluating", variant)
        variant_dir = os.path.join(args.root, variant)
        if variant == 'objectnet-1.0':
            dataset = ObjectNetDataset(root=variant_dir, transform=preprocess)
        elif variant == 'objectnet-v2':
            dataset = ObjectNetDataset(root=os.path.join(args.root, 'objectnet-1.0'), transform=preprocess, reindex=True)
        elif variant == 'imagenet-v2':
            dataset = ImageNetV2(variant_dir, transform=preprocess)
        else:
            dataset = ImageFolder(variant_dir, transform=preprocess)

        if args.joint_only:
            dataset = filter_joint(dataset)

        loader = DataLoader(
            dataset,
            batch_size=args.batch_size,
            pin_memory=True,
            num_workers=args.num_workers,
            shuffle=False,
        )
        run_prefix = dataset.root.split('/')[-1]
        acc1, loss, logit_info, features = validate(loader, texts, model, tokenizer, criterion, args, prefix=run_prefix, return_attention=False, mask=class_mask)
        accuracies.append(acc1)
        print(variant, acc1)

        # Optionally persist logits, then features, for offline analysis.
        artifacts = (
            (args.collect_logits, logit_info, 'logits'),
            (args.collect_features, features, 'features'),
        )
        for enabled, payload, kind in artifacts:
            if not enabled:
                continue
            save_path = os.path.join(args.model_dir, args.filename, f"{variant.replace('/','_')}_{kind}.pth")
            torch.save(payload, save_path)
            print(f'saved {kind} to {save_path}')

    print(dataset_names)
    print(accuracies)
    print(np.mean(accuracies))
    return dict(zip(dataset_names, accuracies))


def evaluate_imagenet_c(args, model, preprocess, texts, tokenizer, criterion, results=None, json_path=None):
    """
    Evaluate the model on ImageNet-C corruptions at severities 1-5.

    For every corruption directory that exists under ``{args.root}/imagenet-c``,
    each severity sub-folder ``1``..``5`` is evaluated. The corruption's
    accuracy is the mean over its five severities. Missing corruption
    directories are skipped.

    Result keys are the dataset paths with ``args.root`` removed, e.g.
    ``/imagenet-c/gaussian_noise/3`` for a severity and
    ``/imagenet-c/gaussian_noise`` for the corruption mean (when ``args.root``
    has no trailing slash).

    Args:
        args: experiment arguments (``root``, ``batch_size``, ``num_workers``,
            ``joint_only``, ``collect_logits``, ``collect_features``,
            ``model_dir``, ``filename``).
        model: model to evaluate.
        preprocess: image transform for the datasets.
        texts: text inputs forwarded to ``validate``.
        tokenizer: tokenizer forwarded to ``validate``.
        criterion: loss criterion forwarded to ``validate``.
        results: previously stored results. A corruption whose mean is
            present is reused as-is unless ``args.collect_features`` is set.
        json_path: optional results file. Severity accuracies stored there are
            reused (and counted in the corruption mean) unless
            ``args.collect_features`` is set. The file is rewritten with the
            new severity accuracies after each corruption.

    Returns:
        dict of per-severity accuracies followed by per-corruption means.
    """
    if results is None:
        results = {}
    imagenet_c = [
            'gaussian_noise', 'shot_noise', 'impulse_noise', 'defocus_blur', 'glass_blur', 'motion_blur', 'zoom_blur', 'snow', 'frost', 'fog', 'brightness', 'contrast', 'elastic_transform', 'pixelate', 'jpeg_compression']
    paths = [os.path.join(args.root, 'imagenet-c', name) for name in imagenet_c]

    # Previously stored results; severity entries are added and re-saved.
    json_output = {}
    if json_path is not None and os.path.exists(json_path):
        with open(json_path, 'r') as f:
            json_output = json.load(f)

    outputs = {}  # per-severity accuracies
    # Per-corruption means keyed directly by corruption, so a skipped (missing)
    # directory cannot shift later accuracies onto the wrong names.
    corruption_accs = {}
    for path in paths:
        if not os.path.isdir(path):
            continue
        corruption_name = path.replace(args.root, '')
        if corruption_name in results and not args.collect_features:
            corruption_accs[corruption_name] = results[corruption_name]
            continue

        severity_accs = []
        for severity in range(1, 6):
            _path = os.path.join(path, str(severity))
            name = _path.replace(args.root, '')
            if name in json_output and not args.collect_features:
                # Reuse the stored value and still count it toward the mean.
                outputs[name] = json_output[name]
                severity_accs.append(json_output[name])
                continue
            dataset = ImageFolder(_path, transform=preprocess)
            if args.joint_only:
                dataset = filter_joint(dataset)

            data_loader = DataLoader(dataset,
                                     batch_size=args.batch_size*2, pin_memory=True,
                                     num_workers=args.num_workers, shuffle=False)
            prefix = dataset.root.split('/')[-1]
            print(prefix)
            acc1, loss, logit_info, features = validate(data_loader, texts, model, tokenizer, criterion, args, prefix=prefix, return_attention=False, mask=None)
            print(_path, acc1)
            outputs[name] = acc1
            json_output[name] = acc1
            severity_accs.append(acc1)
            # e.g. 'gaussian_noise_3'
            dataset_name = _path.split('/')[-2] + '_' + _path.split('/')[-1]
            if args.collect_features:
                save_path = os.path.join(args.model_dir, args.filename, f"{dataset_name}_features.pth")
                torch.save(features, save_path)
                print(f'saved features to {save_path}')
            if args.collect_logits:
                save_path = os.path.join(args.model_dir, args.filename, f"{dataset_name}_logits.pth")
                torch.save(logit_info, save_path)
                print(f'saved logits to {save_path}')

        # Persist progress after every corruption so an interrupted run can resume.
        if json_path is not None:
            with open(json_path, 'w') as f:
                json.dump(json_output, f)

        acc1 = np.mean(severity_accs)
        corruption_accs[corruption_name] = acc1
        print(path, acc1)

    acc1s = list(corruption_accs.values())
    print(acc1s)
    print(np.mean(acc1s) if acc1s else float('nan'))
    outputs.update(corruption_accs)
    return outputs




def evaluate_validation_set(args, model, preprocess, texts, tokenizer, criterion, old_model=None, results=None):
    """
    Evaluate the model on the held-out validation split of ``args.dataset``.

    The dataset is loaded from ``{args.root}/{args.dataset}`` and optionally
    filtered with ``filter_joint``. It is then split with
    ``split_dataset(dataset, 0.2, args.seed)``, and the second part of the
    split (the validation portion) is evaluated.

    Args:
        args: experiment arguments (``dataset``, ``root``, ``seed``,
            ``batch_size``, ``num_workers``, ``joint_only``,
            ``collect_logits``, ``collect_features``, ``model_dir``,
            ``filename``).
        model: model to evaluate.
        preprocess: image transform for the dataset.
        texts: text inputs forwarded to ``validate``.
        tokenizer: tokenizer forwarded to ``validate``.
        criterion: loss criterion forwarded to ``validate``.
        old_model: unused; kept for signature parity with the other evaluators.
        results: unused (the validation accuracy is always recomputed); kept
            for signature parity.

    Returns:
        ``{args.dataset + '/val': accuracy}``.

    Raises:
        ValueError: if ``args.dataset`` is not one of the supported datasets.
    """
    dataset_names = ['imagenet/val',  'imagenet-a', 'imagenet-r', 'imagenet-sketch', 'objectnet-1.0', 'imagenet-cartoon', 'imagenet-drawing', 'imagenet-v2', 'objectnet-v2']
    masks = [None, imagenet_a_mask, imagenet_r_mask, None, None, None, None, None, objectnet_mask]
    dataset_name = args.dataset
    if dataset_name not in dataset_names:
        raise ValueError(f"evaluate_validation_set does not support dataset {dataset_name!r}; expected one of {dataset_names}")
    mask = masks[dataset_names.index(dataset_name)]

    print("Evaluating", dataset_name)
    path = os.path.join(args.root, dataset_name)
    if dataset_name == 'objectnet-1.0':
        dataset = ObjectNetDataset(root=path, transform=preprocess)
    elif dataset_name == 'objectnet-v2':
        dataset = ObjectNetDataset(root=os.path.join(args.root, 'objectnet-1.0'), transform=preprocess, reindex=True)
    elif dataset_name == 'imagenet-v2':
        dataset = ImageNetV2(path, transform=preprocess)
    else:
        dataset = ImageFolder(path, transform=preprocess)

    if args.joint_only:
        dataset = filter_joint(dataset)

    # Read `.root` before splitting: the object split_dataset returns is not
    # visible here and may not expose `.root` (e.g. a torch Subset).
    prefix = dataset.root.split('/')[-1]
    dataset = split_dataset(dataset, 0.2, args.seed)[1]  # split dataset to get validation set

    data_loader = DataLoader(dataset,
                             batch_size=args.batch_size, pin_memory=True,
                             num_workers=args.num_workers, shuffle=False)
    acc1, loss, logit_info, features = validate(data_loader, texts, model, tokenizer, criterion, args, prefix=prefix, return_attention=False, mask=mask)
    print(dataset_name, acc1)
    file_stem = dataset_name.replace('/', '_')
    if args.collect_logits:
        save_path = os.path.join(args.model_dir, args.filename, f"{file_stem}_logits.pth")
        torch.save(logit_info, save_path)
        print(f'saved logits to {save_path}')
    if args.collect_features:
        save_path = os.path.join(args.model_dir, args.filename, f"{file_stem}_features.pth")
        torch.save(features, save_path)
        print(f'saved features to {save_path}')

    acc1s = [acc1]
    print(dataset_names)
    print(acc1s)
    print(np.mean(acc1s))
    return {dataset_name + '/val': acc1}


def evaluate_imagenet_c_val(args, model, preprocess, texts, tokenizer, criterion, results=None, json_path=None):
    """
    Evaluate the model on the validation split of each ImageNet-C severity.

    For every corruption directory that exists under ``{args.root}/imagenet-c``,
    each severity sub-folder ``1``..``5`` is loaded, optionally filtered with
    ``filter_joint``, split with ``split_dataset(dataset, 0.2, 0)``, and the
    second part of the split is evaluated. The corruption's accuracy is the
    mean over its severities. Missing corruption directories are skipped.
    All result keys carry a ``/val`` suffix.

    Args:
        args: experiment arguments (``root``, ``batch_size``, ``num_workers``,
            ``joint_only``, ``collect_logits``, ``collect_features``,
            ``model_dir``, ``filename``).
        model: model to evaluate.
        preprocess: image transform for the datasets.
        texts: text inputs forwarded to ``validate``.
        tokenizer: tokenizer forwarded to ``validate``.
        criterion: loss criterion forwarded to ``validate``.
        results: unused; result caching is disabled for this evaluator.
        json_path: optional results file. It is loaded, the new severity
            accuracies are added, and it is rewritten after each corruption.

    Returns:
        dict of per-severity ``/val`` accuracies followed by per-corruption means.
    """
    imagenet_c = [
            'gaussian_noise', 'shot_noise', 'impulse_noise', 'defocus_blur', 'glass_blur', 'motion_blur', 'zoom_blur', 'snow', 'frost', 'fog', 'brightness', 'contrast', 'elastic_transform', 'pixelate', 'jpeg_compression']
    paths = [os.path.join(args.root, 'imagenet-c', name) for name in imagenet_c]

    json_output = {}
    if json_path is not None and os.path.exists(json_path):
        with open(json_path, 'r') as f:
            json_output = json.load(f)

    outputs = {}  # per-severity accuracies
    # Per-corruption means keyed directly by corruption, so a skipped (missing)
    # directory cannot shift later accuracies onto the wrong names.
    corruption_accs = {}
    for path in paths:
        if not os.path.isdir(path):
            continue
        # NOTE: reuse of cached results (via `results` / `json_output`) is
        # intentionally disabled here; every severity is re-evaluated.
        severity_accs = []
        for severity in range(1, 6):
            _path = os.path.join(path, str(severity))
            name = _path.replace(args.root, '')
            dataset = ImageFolder(_path, transform=preprocess)
            if args.joint_only:
                dataset = filter_joint(dataset)

            # Read `.root` before splitting: the object split_dataset returns
            # may not expose `.root` (e.g. a torch Subset).
            prefix = dataset.root.split('/')[-1]
            # NOTE(review): seed is fixed to 0 here while evaluate_validation_set
            # uses args.seed — confirm this matches the training-time split.
            dataset = split_dataset(dataset, 0.2, 0)[1]

            data_loader = DataLoader(dataset,
                                     batch_size=args.batch_size*2, pin_memory=True,
                                     num_workers=args.num_workers, shuffle=False)
            print(prefix)
            acc1, loss, logit_info, features = validate(data_loader, texts, model, tokenizer, criterion, args, prefix=prefix, return_attention=False, mask=None)
            print(_path, acc1)
            outputs[name + '/val'] = acc1
            json_output[name + '/val'] = acc1
            severity_accs.append(acc1)
            # e.g. 'gaussian_noise_3'
            dataset_name = _path.split('/')[-2] + '_' + _path.split('/')[-1]
            if args.collect_features:
                save_path = os.path.join(args.model_dir, args.filename, f"{dataset_name}_features.pth")
                torch.save(features, save_path)
                print(f'saved features to {save_path}')
            if args.collect_logits:
                save_path = os.path.join(args.model_dir, args.filename, f"{dataset_name}_logits.pth")
                torch.save(logit_info, save_path)
                print(f'saved logits to {save_path}')

        # Persist progress after every corruption.
        if json_path is not None:
            with open(json_path, 'w') as f:
                json.dump(json_output, f)

        acc1 = np.mean(severity_accs)
        corruption_accs[path.replace(args.root, '') + '/val'] = acc1
        print(path, acc1)

    acc1s = list(corruption_accs.values())
    print(acc1s)
    print(np.mean(acc1s) if acc1s else float('nan'))
    outputs.update(corruption_accs)
    return outputs


