import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import numpy as np
import random
from collections.abc import Iterable
from copy import deepcopy
from torch.utils.data import TensorDataset
from torch.utils.data import SubsetRandomSampler, DataLoader, Subset
from dset import load_data
import torchvision.datasets as datasets
import matplotlib.pyplot as plt


# def alpha_score(fea_s):

def mean_var(feature):
    """Return per-channel mean and variance of a 4-D feature tensor.

    Axis 1 of ``feature`` is moved to the last position, the tensor is
    flattened to ``(-1, feature.shape[1])`` and both statistics are reduced
    over the first axis of that 2-D view.

    Args:
        feature: 4-D tensor; the original author documents it as
            ``[N, C, H, W]``.

    Returns:
        ``(mean, var)`` as NumPy arrays with ``feature.shape[1]`` entries
        each. The variance uses ``torch.var``'s default estimator.
    """
    channels = feature.shape[1]
    # Channels-last layout so every row of the 2-D view is one position.
    per_position = feature.permute(0, 2, 3, 1).contiguous().view(-1, channels)

    channel_mean = per_position.mean(dim=0).cpu().detach().numpy()
    channel_var = per_position.var(dim=0).cpu().detach().numpy()
    return channel_mean, channel_var




def mask_layer(layer_list, pre_model):
    """Build a 0/1 mask for every named parameter of ``pre_model``.

    Parameters that are not 4-D get an all-zero mask. For the k-th 4-D
    parameter (in ``named_parameters()`` order) the entries addressed by
    ``layer_list[k]`` along the first axis are set to 1.

    Args:
        layer_list: One index (or any valid first-axis tensor index) per 4-D
            parameter, in ``named_parameters()`` order.
        pre_model: Module whose parameters are masked.

    Returns:
        Dict mapping each parameter name to a mask tensor created with
        ``torch.zeros_like`` on that parameter.

    Raises:
        IndexError: If ``layer_list`` has fewer entries than the model has
            4-D parameters.
    """
    mask = {}
    conv_idx = 0
    num_cnt = 0
    for name, para in pre_model.named_parameters():
        layer_mask = torch.zeros_like(para)
        # Only 4-D parameters (conv weights; used for resnet) get a non-zero mask.
        if para.dim() == 4:
            layer_mask[layer_list[conv_idx]] = 1
            num_cnt += layer_mask.sum()
            conv_idx += 1
        mask[name] = layer_mask
    print('mask num',num_cnt)
    return mask





def query_eval(net, device, testloader):
    """Evaluate ``net`` on the first batch of ``testloader``.

    Only the first batch is used. The network is switched to eval mode and
    called as ``net(inputs, 0)``; the cross-entropy loss of that output is
    differentiated with respect to ``net.parameters()``.

    Args:
        net: Module callable as ``net(inputs, 0)``. It is not moved to
            ``device`` here.
        device: Device the batch is moved to.
        testloader: Iterable yielding ``(inputs, targets)`` batches.

    Returns:
        ``(correct, loss, grad_qry)``: the number of correct predictions in
        the batch (int), the loss tensor, and the tuple of gradients returned
        by ``torch.autograd.grad``.

    Raises:
        ValueError: If ``testloader`` yields no batch.
    """
    net.eval()
    criterion = nn.CrossEntropyLoss()

    # Fetch just the first batch instead of looping and breaking.
    first_batch = next(iter(testloader), None)
    if first_batch is None:
        raise ValueError('testloader yielded no batches')

    inputs, targets = first_batch
    inputs, targets = inputs.to(device), targets.to(device)
    outputs = net(inputs, 0)
    loss = criterion(outputs, targets)
    grad_qry = torch.autograd.grad(loss, net.parameters())
    _, predicted = outputs.max(1)
    correct = predicted.eq(targets).sum().item()
    return correct, loss, grad_qry


def inference_alpha(net, victim_state_dict, device, testloader):
    """Run ``net`` in statistics mode on the first batch of ``testloader``.

    The state dict is loaded into ``net``, which is then moved to ``device``
    and set to eval mode. The first batch is passed as ``net(inputs, 1)``
    under ``torch.no_grad()``.

    Args:
        net: Module callable as ``net(inputs, 1)`` and returning a pair.
        victim_state_dict: State dict loaded into ``net`` before the pass.
        device: Device for the network and the batch.
        testloader: Iterable yielding ``(inputs, targets)`` batches; the
            targets are not used.

    Returns:
        The pair ``(mean_l, var_l)`` returned by ``net(inputs, 1)``
        (presumably per-layer means and variances — confirm against the
        network's forward).

    Raises:
        ValueError: If ``testloader`` yields no batch.
    """
    net.load_state_dict(victim_state_dict)
    net.to(device)
    net.eval()

    with torch.no_grad():
        # Only the first batch is needed.
        first_batch = next(iter(testloader), None)
        if first_batch is None:
            raise ValueError('testloader yielded no batches')
        inputs, _ = first_batch
        mean_l, var_l = net(inputs.to(device), 1)

    return mean_l, var_l



def inference(net, device, testloader):
    """Compute the classification accuracy of ``net`` over ``testloader``.

    The network is moved to ``device``, put in eval mode and called as
    ``net(inputs, 0)`` on every batch under ``torch.no_grad()``.

    Args:
        net: Module callable as ``net(inputs, 0)`` returning per-class scores.
        device: Device for the network and the batches.
        testloader: Iterable yielding ``(inputs, targets)`` batches.

    Returns:
        Fraction of samples whose highest-scoring class equals the target.

    Raises:
        ZeroDivisionError: If ``testloader`` yields no samples.
    """
    net.to(device)
    net.eval()

    num_correct, num_seen = 0, 0
    with torch.no_grad():
        for batch_x, batch_y in testloader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            scores = net(batch_x, 0)
            predicted = scores.max(1)[1]
            num_seen += batch_y.size(0)
            num_correct += predicted.eq(batch_y).sum().item()

    return num_correct / num_seen

def update(net, fast_weights):
    """Return a copy of ``net``'s state dict with parameters replaced.

    Entries are paired positionally: the i-th name from
    ``net.named_parameters()`` receives the i-th element of ``fast_weights``.
    State-dict entries that are not named parameters keep their deep-copied
    values. ``net`` itself is not modified.

    Args:
        net: Module providing ``state_dict()`` and ``named_parameters()``.
        fast_weights: Iterable of replacement values in parameter order.

    Returns:
        The new state dict.
    """
    new_state = deepcopy(net.state_dict())
    param_names = (name for name, _ in net.named_parameters())
    # Overwriting existing keys keeps the original key order.
    new_state.update(zip(param_names, fast_weights))
    return new_state

def setup_seed(seed):
    """Seed the PyTorch, NumPy and Python RNGs and request deterministic cuDNN.

    Args:
        seed: Value passed unchanged to every seeding function.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for seed_fn in seeders:
        seed_fn(seed)
    torch.backends.cudnn.deterministic = True

def sparsity(mask,model,B,t):
    """Count parameter entries that deviate from ``B`` by more than ``t``.

    For every named parameter the mask entries are summed, and the number of
    parameter entries with ``|value - B| > t`` is accumulated.

    NOTE(review): the deviation count covers every entry of every parameter,
    not only those selected by ``mask`` — confirm whether it was meant to be
    restricted to masked entries.

    Args:
        mask: Dict mapping each parameter name of ``model`` to a mask tensor.
        model: Module whose named parameters are inspected.
        B: Reference value subtracted from the flattened parameters
            (assumed scalar or broadcastable — confirm against callers).
        t: Threshold on the absolute deviation.

    Returns:
        ``(weights, change, mask_total, ratio)`` where ``weights`` is always
        ``None`` (kept for interface compatibility), ``change`` is the number
        of deviating entries, ``mask_total`` is the sum of all mask entries
        and ``ratio`` is ``change / mask_total``.

    Raises:
        ZeroDivisionError: If ``model`` has no parameters.
    """
    mask_total = 0
    change = 0
    for name, para in model.named_parameters():
        mask_total += mask[name].sum()
        # Take the absolute value of the difference BEFORE thresholding;
        # abs() around the comparison result ignored negative deviations.
        change += (torch.abs(para.flatten() - B) > t).sum()

    weights = None
    return weights, change, mask_total, change / mask_total



def balanced_loader(dset,num_cls,num_samples, batch):
    """Build a ``DataLoader`` over a class-balanced random subset of ``dset``.

    For each class id in ``range(num_cls)``, ``num_samples`` positions are
    drawn without replacement (via ``random.sample``) from the positions
    whose label equals that id. The loader then draws from the pooled
    positions through a ``SubsetRandomSampler``.

    Args:
        dset: Dataset exposing a ``targets`` sequence of integer labels.
        num_cls: Number of classes; labels ``0 .. num_cls - 1`` are used.
        num_samples: Positions drawn per class.
        batch: Batch size of the returned loader.

    Returns:
        A ``DataLoader`` over ``dset`` restricted to the sampled positions.

    Raises:
        ValueError: From ``random.sample`` if a class has fewer than
            ``num_samples`` members.
    """
    labels = torch.tensor(dset.targets)
    positions_by_class = [torch.where(labels == cls)[0] for cls in range(num_cls)]

    # Sample an equal number of examples for each class.
    picked = []
    for class_positions in positions_by_class:
        chosen = random.sample(class_positions.tolist(), num_samples)
        picked.append(torch.tensor(chosen))
    flat_indices = torch.stack(picked).view(-1)

    return DataLoader(
        dset,
        batch_size=batch,
        sampler=SubsetRandomSampler(flat_indices),
    )

def sub_dataset(dset, cls_idx, num_sub_cls,num_samples, trainset_flag):
    """Return a ``Subset`` of ``dset`` restricted to the classes in ``cls_idx``.

    When ``trainset_flag`` is truthy, ``num_samples`` positions are drawn
    without replacement (via ``random.sample``) for each of the first
    ``num_sub_cls`` selected classes. Otherwise every sample of the selected
    classes is kept.

    Args:
        dset: Dataset exposing a ``targets`` sequence of integer labels.
        cls_idx: Iterable of class ids to keep.
        num_sub_cls: Number of entries of ``cls_idx`` to sample from when
            ``trainset_flag`` is truthy.
        num_samples: Positions drawn per class when sampling.
        trainset_flag: Whether to sample a fixed number per class.

    Returns:
        ``torch.utils.data.Subset`` whose ``indices`` is a 1-D tensor.

    Raises:
        ValueError: From ``random.sample`` if a class has fewer than
            ``num_samples`` members.
    """
    labels = torch.tensor(dset.targets)
    indices = [torch.where(labels == i)[0] for i in cls_idx]

    # Sample an equal number of examples for each class.
    if trainset_flag:
        indices = [torch.tensor(random.sample(indices[i].tolist(), num_samples)) for i in range(num_sub_cls)]
    # Concatenate rather than stack: without sampling, the per-class index
    # tensors generally have different lengths and cannot be stacked.
    indices = torch.cat(indices)
    sub_set = Subset(dset, indices)
    return sub_set

def alpha_select(domain_list, net, victim_state_dict, device):
    """Pick, per layer, the index whose statistic varies least across domains.

    For every loader in ``domain_list`` the pair ``(mean_l, var_l)`` from
    ``inference_alpha`` is turned into a per-layer statistic
    ``mean / sqrt(var)``. The first loader is the source; for every other
    loader the absolute difference to the source statistic is computed. These
    differences are summed over the non-source loaders, and the arg-min index
    is taken per layer.

    Args:
        domain_list: Sequence of data loaders; the first is the source.
        net: Network forwarded to ``inference_alpha``.
        victim_state_dict: State dict forwarded to ``inference_alpha``.
        device: Device forwarded to ``inference_alpha``.

    Returns:
        List with one ``np.argmin`` result per layer.

    Raises:
        ValueError: If fewer than two loaders are given, since no distance to
            the source can be computed.
    """
    if len(domain_list) < 2:
        raise ValueError('alpha_select needs a source domain and at least one other domain')

    stat_src = None
    domain_dis = []
    for loader in domain_list:
        mean_l, var_l = inference_alpha(net, victim_state_dict, device, loader)
        # assumes each layer entry is array-like (e.g. NumPy output of
        # mean_var) — TODO confirm against the network's forward
        stat_l = [np.asarray(m) / np.sqrt(np.asarray(v)) for m, v in zip(mean_l, var_l)]
        if stat_src is None:
            stat_src = stat_l
        else:
            domain_dis.append([np.abs(src - cur) for src, cur in zip(stat_src, stat_l)])

    filter_l = []
    for layer in range(len(mean_l)):
        # Copy so the in-place accumulation leaves domain_dis untouched.
        total = np.array(domain_dis[0][layer])
        for dist in domain_dis[1:]:
            total += dist[layer]
        filter_l.append(np.argmin(total))
    return filter_l



def auxilary(num_domain, num_cls, num_samples, batch, name):
    """Create class-balanced support and query loaders for each domain.

    For each of the ``num_domain`` iterations the dataset is loaded afresh
    with ``load_data(name)``. The support loader is built from the train
    split with ``num_samples`` per class, and the query loader from the test
    split with ``batch // num_cls + 1`` per class.

    Args:
        num_domain: Number of loader pairs to create.
        num_cls: Number of classes passed to ``balanced_loader``.
        num_samples: Per-class sample count for the support loaders.
        batch: Batch size of every loader.
        name: Dataset name passed to ``load_data``.

    Returns:
        ``(spt_list, qry_list)``, two lists of ``num_domain`` loaders each.
    """
    spt_list, qry_list = [], []
    for _ in range(num_domain):
        trainset, testset = load_data(name)

        # Support loader first, then query loader, as in the original order.
        support = balanced_loader(trainset, num_cls, num_samples, batch)
        query = balanced_loader(testset, num_cls, batch // num_cls + 1, batch)

        spt_list.append(support)
        qry_list.append(query)
    return spt_list, qry_list




def set_freeze_by_names(model, layer_names, freeze=True):
    """Freeze or unfreeze the direct children of ``model`` selected by name.

    Args:
        model: Module whose ``named_children()`` are matched.
        layer_names: A single child name or an iterable of child names. A
            plain string is treated as ONE name, never as a sequence of
            characters.
        freeze: When true (default) the matched children's parameters get
            ``requires_grad = False``; when false they get ``True``.
    """
    # A str is itself Iterable; without this special case `name in "fc"`
    # would be a substring test and match children named "f" or "c".
    if isinstance(layer_names, (str, bytes)) or not isinstance(layer_names, Iterable):
        layer_names = [layer_names]
    else:
        # Materialise so one-shot iterables survive repeated membership tests.
        layer_names = list(layer_names)

    requires_grad = not freeze
    for name, child in model.named_children():
        if name not in layer_names:
            continue
        for param in child.parameters():
            param.requires_grad = requires_grad
