import argparse
import logging
import os
import random

import numpy as np
import timm
import torch
import torch.nn.functional as F
from torch.cuda.amp import autocast
from tqdm import tqdm
from torch import hub


def prepare_experiment(args):
    """Create the run directory, configure file logging, pick a device and seed RNGs.

    The run directory is ``<args.result_dir>/<args.dataset>/<args.identifier>``;
    ``logs.txt`` inside it is opened in ``'w'`` mode, so it is truncated on every run.
    ``torch.hub`` downloads (pretrained weights) are cached under ``./cache``.

    Returns:
        tuple: ``(logging, device)``. The first item is the configured ``logging``
        module itself. The second is ``"cuda:0"`` when CUDA is available, else ``"cpu"``.
    """
    # Where torch.hub stores downloaded pretrained models.
    torch.hub.set_dir('./cache')

    run_dir = os.path.join(args.result_dir, args.dataset, args.identifier)
    os.makedirs(run_dir, exist_ok=True)

    log_file = os.path.join(run_dir, 'logs.txt')
    logging.basicConfig(filename=log_file, filemode='w', level=logging.INFO)
    logging.info(args)

    if torch.cuda.is_available():
        device = "cuda:0"
    else:
        device = "cpu"

    set_seed(args.seed)
    return logging, device


def get_general_args():
    """Build the argument parser for options shared by every experiment.

    The parser is returned before parsing, so callers may still add their own
    options to it. ``--dataset`` and ``--identifier``/``--id`` are required.
    """
    standard_datasets = [
        "cifar10", "cifar100",
        "dtd", "flowers102", "ucf101", "food101",
        "gtsrb", "svhn", "waterbirds", "eurosat",
        "oxfordpets", "stanfordcars", "sun397", "pcam", "oxfordflowers",
        "caltech101", "nabirds", "cub200", "stanforddogs",
    ]
    vtab_tasks = [
        "caltech101", "clevr_count", "diabetic_retinopathy",
        "dsprites_loc", "dtd", "kitti", "oxford_iiit_pet",
        "resisc45", "smallnorb_ele", "svhn", "cifar",
        "clevr_dist", "dmlab", "dsprites_ori", "eurosat",
        "oxford_flowers102", "patch_camelyon", "smallnorb_azi",
        "sun397",
    ]
    # VTAB-1k variants are exposed as "vtab1k-<task>".
    dataset_choices = standard_datasets + ["vtab1k-" + task for task in vtab_tasks]

    parser = argparse.ArgumentParser()

    # General settings
    parser.add_argument('--seed', type=int, default=588)
    parser.add_argument('--dataset', required=True, choices=dataset_choices)
    parser.add_argument('--batch-size', type=int, default=128)
    parser.add_argument('--data-path', type=str, default="../data/vp_data/")
    parser.add_argument('--epoch', type=int, default=50)
    parser.add_argument('--test-start', type=int, default=0)
    parser.add_argument('--test-interval', type=int, default=1)
    parser.add_argument('--result-dir', default="results", type=str)
    parser.add_argument('--identifier', '--id', type=str, required=True,
                        help='To identify the folder name to save the results')
    return parser


def calculate_trainable_param(optimizer_list):
    """Count the trainable scalar parameters managed by a collection of optimizers.

    Args:
        optimizer_list: iterable of optimizers exposing ``param_groups``.

    Returns:
        int: sum of ``numel()`` over every parameter in every param group whose
        ``requires_grad`` flag is set. A tensor registered with more than one
        optimizer is counted once per optimizer. An empty input gives 0.
    """
    return sum(
        param.numel()
        for optimizer in optimizer_list
        for group in optimizer.param_groups
        for param in group['params']
        if param.requires_grad
    )


def get_vit_network(network_type, device):
    """Create a pretrained timm ViT (patch size 16, 224 input) on ``device``.

    Args:
        network_type: one of ``"vit_tiny"``, ``"vit_small"``, ``"vit_base"``,
            ``"vit_large"``.
        device: target passed to ``.to()``.

    Returns:
        tuple: ``(network, dim)``. ``dim`` is the hard-coded feature width for
        the chosen variant (192 / 384 / 768 / 1024).

    Raises:
        NotImplementedError: for any other ``network_type``.
    """
    # network_type -> (timm model name, feature width)
    vit_specs = {
        "vit_tiny": ("vit_tiny_patch16_224", 192),
        "vit_small": ("vit_small_patch16_224", 384),
        "vit_base": ("vit_base_patch16_224", 768),
        "vit_large": ("vit_large_patch16_224", 1024),
    }
    if network_type not in vit_specs:
        raise NotImplementedError(f"{network_type} is not supported")

    timm_name, feature_dim = vit_specs[network_type]
    model = timm.create_model(timm_name, pretrained=True).to(device)
    return model, feature_dim


def get_cnn_network(network_type, device):
    """Create a pretrained CNN backbone and move it to ``device``.

    Supported ``network_type`` values:
      * ``"resnet18"``, ``"resnet50"``, ``"resnet101"``: torchvision models with
        ``IMAGENET1K_V1`` weights.
      * ``"instagram"``: ``resnext101_32x8d_wsl`` loaded through torch.hub from
        ``facebookresearch/WSL-Images``.

    Raises:
        NotImplementedError: for any other ``network_type``.
    """
    # torchvision is imported inside each builder so only the selected model's
    # import runs.
    def _resnet18():
        from torchvision.models import resnet18, ResNet18_Weights
        return resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)

    def _resnet50():
        from torchvision.models import resnet50, ResNet50_Weights
        return resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)

    def _resnet101():
        from torchvision.models import resnet101, ResNet101_Weights
        return resnet101(weights=ResNet101_Weights.IMAGENET1K_V1)

    def _instagram():
        return hub.load('facebookresearch/WSL-Images', 'resnext101_32x8d_wsl')

    builders = {
        "resnet18": _resnet18,
        "resnet50": _resnet50,
        "resnet101": _resnet101,
        "instagram": _instagram,
    }
    if network_type not in builders:
        raise NotImplementedError(f"{network_type} is not supported")

    return builders[network_type]().to(device)


def train_network(network, train_loader, scaler, optimizer_list, scheduler_list, device, epoch, description="Training"):
    """Run one epoch of mixed-precision training and return the training accuracy.

    Args:
        network: model whose output on a batch is fed to ``F.cross_entropy``.
        train_loader: iterable of ``(x, y)`` batches with a ``len()``. Batches
            whose ``x`` is on the CPU are moved to ``device`` together with ``y``.
        scaler: gradient scaler used for ``scale``/``step``/``update``.
        optimizer_list: optimizers zeroed and stepped once per batch.
        scheduler_list: schedulers stepped once, after the last batch.
        device: target device for CPU batches.
        epoch: epoch index, used only in the progress-bar label.
        description: progress-bar prefix.

    Returns:
        float: fraction of correctly classified training samples, or 0.0 if the
        loader yields no batches.

    Note:
        This function does not change ``network``'s train/eval mode; the caller
        is responsible for setting it.
    """
    total_num = 0
    true_num = 0
    pbar = tqdm(train_loader, total=len(train_loader), desc=f"{description}: epoch {epoch}")
    for x, y in pbar:

        # get_device() is -1 for CPU tensors; already-placed batches are left alone.
        if x.get_device() == -1:
            x, y = x.to(device), y.to(device)
        for optimizer in optimizer_list:
            optimizer.zero_grad()

        with autocast():
            fx = network(x)
            loss = F.cross_entropy(fx, y, reduction='mean')

        scaler.scale(loss).backward()
        # Step every optimizer against the same loss scale, then update the scale
        # exactly once per iteration. Calling update() between steps would let a
        # later optimizer unscale its gradients with a scale that differs from the
        # one used in backward(). It would also advance the growth tracker once
        # per optimizer instead of once per iteration.
        for optimizer in optimizer_list:
            scaler.step(optimizer)
        scaler.update()

        total_num += y.size(0)
        true_num += torch.argmax(fx, 1).eq(y).float().sum().item()
        pbar.set_postfix_str(f"Acc {100 * true_num / total_num:.2f}%")
    for scheduler in scheduler_list:
        scheduler.step()

    # Guard against an empty loader instead of raising ZeroDivisionError.
    train_acc = true_num / total_num if total_num else 0.0

    return train_acc


def eval_network(network, test_loader, device, epoch):
    """Compute top-1 accuracy of ``network`` over ``test_loader``.

    Puts ``network`` into eval mode (it is not switched back afterwards) and
    runs all forward passes under ``torch.no_grad()``.

    Args:
        network: model producing per-class scores for a batch ``x``.
        test_loader: iterable of ``(x, y)`` batches with a ``len()``. Batches
            whose ``x`` is on the CPU are moved to ``device`` together with ``y``.
        device: target device for CPU batches.
        epoch: epoch index, used only in the progress-bar label.

    Returns:
        float: fraction of correctly classified samples, or 0.0 if the loader
        yields no batches.
    """
    network.eval()
    total_num = 0
    true_num = 0
    acc = 0.0
    pbar = tqdm(test_loader, total=len(test_loader), desc=f"Epoch {epoch} Testing", ncols=100)
    with torch.no_grad():
        for x, y in pbar:
            # get_device() is -1 for CPU tensors; already-placed batches are left alone.
            if x.get_device() == -1:
                x, y = x.to(device), y.to(device)
            fx = network(x)
            total_num += y.size(0)
            true_num += torch.argmax(fx, 1).eq(y).float().sum().item()
            acc = true_num / total_num
            pbar.set_postfix_str(f"Acc {100 * acc:.2f}%")

    return acc


def set_seed(seed):
    """Seed all RNGs used in training and configure cuDNN for reproducibility.

    Seeds NumPy's global RNG, PyTorch (CPU and every CUDA device) and Python's
    ``random``. It then restricts cuDNN to deterministic algorithms with
    autotuning disabled. Other nondeterministic PyTorch ops are not covered; see
    ``torch.use_deterministic_algorithms`` for that.
    """
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode can pick different convolution algorithms from run to run;
    # PyTorch's reproducibility notes require disabling it alongside
    # ``deterministic = True``.
    torch.backends.cudnn.benchmark = False


def override_func(inst, func, func_name):
    """Install ``func`` on the object ``inst`` as a bound method named ``func_name``.

    ``func`` is bound through the descriptor protocol (``func.__get__``), so
    calling ``inst.<func_name>(...)`` passes ``inst`` as the first argument. The
    attribute is assigned with ``setattr`` on ``inst`` itself, not on its class.
    """
    owner_cls = inst.__class__
    method = func.__get__(inst, owner_cls)
    setattr(inst, func_name, method)


def convert_models_to_fp32(model):
    """Cast every parameter of ``model`` to float32 in place, along with its gradient if present.

    Parameters whose ``.grad`` is ``None`` keep a ``None`` gradient.
    """
    for p in model.parameters():
        p.data = p.data.float()
        # Compare with None explicitly. Truth-testing a tensor raises
        # RuntimeError for multi-element gradients and would skip a
        # single-element gradient equal to zero.
        if p.grad is not None:
            p.grad.data = p.grad.data.float()


def gen_folder_name(args):
    """Encode every attribute of ``args`` into a single folder-name string.

    Each attribute becomes ``"<name>-<value>"``, with float values formatted to
    four decimal places. The pieces are joined with ``"~"`` in ``vars(args)``
    order. An object with no attributes yields ``""``.
    """
    def _render(value):
        return f"{value:.4f}" if isinstance(value, float) else value

    return '~'.join(f'{name}-{_render(getattr(args, name))}' for name in vars(args))


class AverageMeter:
    """Keep the most recent value of a metric and its running weighted mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Set the latest value, running sum, sample count and average back to 0."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, weighted as ``n`` samples.

        ``val * n`` is added to the running sum and ``n`` to the count. The
        average is then recomputed as ``sum / count``.
        """
        self.val = val
        self.sum = self.sum + val * n
        self.count = self.count + n
        self.avg = self.sum / self.count


# Single text-prompt template; "{}" is the placeholder (presumably filled with a
# class name via str.format -- the callers are not visible in this module).
DEFAULT_TEMPLATE = "This is a photo of a {}."

# 80 alternative prompt templates using the same "{}" placeholder, presumably
# used to average text embeddings over several prompts (prompt ensembling).
# NOTE(review): this list appears to match the ImageNet prompt-template set
# released with OpenAI's CLIP -- confirm the source and keep it verbatim if
# results are meant to be comparable with that work.
ENSEMBLE_TEMPLATES = [
    'a bad photo of a {}.',
    'a photo of many {}.',
    'a sculpture of a {}.',
    'a photo of the hard to see {}.',
    'a low resolution photo of the {}.',
    'a rendering of a {}.',
    'graffiti of a {}.',
    'a bad photo of the {}.',
    'a cropped photo of the {}.',
    'a tattoo of a {}.',
    'the embroidered {}.',
    'a photo of a hard to see {}.',
    'a bright photo of a {}.',
    'a photo of a clean {}.',
    'a photo of a dirty {}.',
    'a dark photo of the {}.',
    'a drawing of a {}.',
    'a photo of my {}.',
    'the plastic {}.',
    'a photo of the cool {}.',
    'a close-up photo of a {}.',
    'a black and white photo of the {}.',
    'a painting of the {}.',
    'a painting of a {}.',
    'a pixelated photo of the {}.',
    'a sculpture of the {}.',
    'a bright photo of the {}.',
    'a cropped photo of a {}.',
    'a plastic {}.',
    'a photo of the dirty {}.',
    'a jpeg corrupted photo of a {}.',
    'a blurry photo of the {}.',
    'a photo of the {}.',
    'a good photo of the {}.',
    'a rendering of the {}.',
    'a {} in a video game.',
    'a photo of one {}.',
    'a doodle of a {}.',
    'a close-up photo of the {}.',
    'a photo of a {}.',
    'the origami {}.',
    'the {} in a video game.',
    'a sketch of a {}.',
    'a doodle of the {}.',
    'a origami {}.',
    'a low resolution photo of a {}.',
    'the toy {}.',
    'a rendition of the {}.',
    'a photo of the clean {}.',
    'a photo of a large {}.',
    'a rendition of a {}.',
    'a photo of a nice {}.',
    'a photo of a weird {}.',
    'a blurry photo of a {}.',
    'a cartoon {}.',
    'art of a {}.',
    'a sketch of the {}.',
    'a embroidered {}.',
    'a pixelated photo of a {}.',
    'itap of the {}.',
    'a jpeg corrupted photo of the {}.',
    'a good photo of a {}.',
    'a plushie {}.',
    'a photo of the nice {}.',
    'a photo of the small {}.',
    'a photo of the weird {}.',
    'the cartoon {}.',
    'art of the {}.',
    'a drawing of the {}.',
    'a photo of the large {}.',
    'a black and white photo of a {}.',
    'the plushie {}.',
    'a dark photo of a {}.',
    'itap of a {}.',
    'graffiti of the {}.',
    'a toy {}.',
    'itap of my {}.',
    'a photo of a cool {}.',
    'a photo of a small {}.',
    'a tattoo of the {}.',
]
