import copy
import random
from re import X
import operator
from numbers import Number
from collections import OrderedDict
import torch
import numpy as np
from utils.helper import get_dataloader_fish


# Functions

def build_similary_matrix(cov_function, items):
    """
    Build the symmetric pairwise similarity matrix of ``items``.

    ``cov_function`` is evaluated once for every pair ``(items[i], items[j])``
    with ``i <= j``; the result is written to both ``[i, j]`` and ``[j, i]``.

    Returns a ``(len(items), len(items))`` numpy array of floats.
    """
    size = len(items)
    matrix = np.zeros((size, size))
    for row in range(size):
        # Only the upper triangle (diagonal included) is computed; the lower
        # triangle is filled by mirroring.
        for col in range(row, size):
            matrix[row, col] = cov_function(items[row], items[col])
            matrix[col, row] = matrix[row, col]
    return matrix

def get_similar_cov(inx):
    """
    Return the pairwise measure selected by ``inx``.

    1 -> get_l1_similar, 2 -> get_l2_similar, 3 -> get_cos_similar;
    any other value -> a kernel freshly built by ``exp_quadratic()`` with its
    default sigma.
    """
    numbered_measures = (
        (1, get_l1_similar),
        (2, get_l2_similar),
        (3, get_cos_similar),
    )
    for code, measure in numbered_measures:
        if inx == code:
            return measure
    # Fallback for every other selector value.
    return exp_quadratic()

def get_l2_similar(v1,v2):
    """Return ``np.linalg.norm`` (default order) of the difference ``v1 - v2``."""
    difference = v1 - v2
    return np.linalg.norm(difference)

def get_l1_similar(v1,v2):
    """Return ``np.linalg.norm`` with ``ord=1`` of the difference ``v1 - v2``."""
    difference = v1 - v2
    return np.linalg.norm(difference, ord=1)

def get_cos_similar(v1, v2):
    """
    Return the cosine of the angle between ``v1`` and ``v2`` mapped to [0, 1].

    The raw cosine ``c`` is rescaled as ``0.5 + 0.5 * c``. When the product of
    the two norms is zero the function returns 0.
    """
    dot_product = float(np.dot(v1, v2))  # dot product of the two vectors
    norm_product = np.linalg.norm(v1) * np.linalg.norm(v2)  # product of the norms
    if norm_product == 0:
        return 0
    return 0.5 + 0.5 * (dot_product / norm_product)

def exp_quadratic(sigma=0.1):
    """
    Build an exponentiated-quadratic kernel with length scale ``sigma``.

    The returned function takes two arrays and evaluates
    ``exp(-0.5 * sum((p1 - p2) ** 2) / sigma ** 2)``.
    """
    def f(p1, p2):
        squared_distance = ((p1 - p2)**2).sum()
        return np.exp(-0.5 * squared_distance / sigma**2)
    return f

def sample_k(items, L, k, max_nb_iterations=1000, rng=np.random):
    """
    Sample a list of k items from a DPP defined
    by the similarity matrix L.

    A swap chain is run for ``max_nb_iterations`` steps. Each step picks one
    selected item ``u`` and one unselected item ``v`` at random and replaces
    ``u`` by ``v`` with probability

        min(1, (L[v, v] - b_v' inv(L_Y) b_v) / (L[u, u] - b_u' inv(L_Y) b_u))

    where ``Y`` is the current selection without ``u``, ``L_Y`` is ``L``
    restricted to ``Y`` and ``b_x`` is column ``x`` of ``L`` restricted to the
    rows in ``Y``.

    Parameters:
        items: sequence; only its length is used.
        L: square numpy array, indexed by item position.
        k: number of items to select.
        max_nb_iterations: number of swap proposals.
        rng: object providing ``choice`` and ``uniform`` (``np.random`` or a
            ``np.random.RandomState``).

    Returns:
        Boolean numpy array of length ``len(items)`` that is True at the
        selected positions.

    Raises:
        ValueError: propagated from ``rng.choice`` when ``k`` items cannot be
            drawn without replacement (e.g. ``k > len(items)``).
        numpy.linalg.LinAlgError: propagated from ``np.linalg.inv`` when
            ``L_Y`` is singular.
    """
    nb_items = len(items)
    # With an empty or a full selection there is nothing to swap, so the
    # result is fully determined and the chain (which would fail drawing from
    # an empty pool) is skipped.
    if k == 0:
        return np.zeros(nb_items, dtype=bool)
    if k == nb_items:
        return np.ones(nb_items, dtype=bool)

    positions = np.arange(nb_items)
    X = np.zeros(nb_items, dtype=bool)
    X[rng.choice(range(nb_items), size=k, replace=False)] = True

    for _ in range(max_nb_iterations):
        u = rng.choice(positions[X])    # selected item proposed for removal
        v = rng.choice(positions[~X])   # unselected item proposed for insertion
        Y = X.copy()
        Y[u] = False
        L_Y_inv = np.linalg.inv(L[Y][:, Y])

        b_v = L[Y, v]
        b_u = L[Y, u]
        # Schur complements of v and u with respect to Y.
        gain_v = L[v, v] - b_v.dot(L_Y_inv).dot(b_v)
        gain_u = L[u, u] - b_u.dot(L_Y_inv).dot(b_u)

        if gain_u > 0:
            # The whole numerator is divided by the whole denominator; the
            # previous expression lacked parentheses and only divided the
            # quadratic form of v.
            p = min(1.0, gain_v / gain_u)
        else:
            # The current selection is degenerate (zero or numerically
            # negative Schur complement): always move away from it.
            p = 1.0
        # Strict comparison so that a swap with probability 0 is never taken
        # (rng.uniform() can return exactly 0.0).
        if rng.uniform() < p:
            Y[v] = True
            X = Y
    return X

        

def get_domain(iteration,run,args,model,kwargs,cuda,batch_left):
    """
    Choose the training domains for one iteration.

    At ``iteration == 0`` the domains are drawn uniformly at random and
    returned sorted. Otherwise one batch per candidate domain is encoded with
    a deep copy of ``model``; the batch-mean encodings are turned into a
    similarity matrix (exponentiated-quadratic kernel) and ``sample_k`` picks
    the domains from it.

    Parameters:
        iteration: 0 selects the random initialisation branch.
        run, args, kwargs: forwarded to ``get_dataloader_fish``;
            ``args.dataset_name`` also selects the candidate domain range and
            the number of domains to pick.
        model: object with an ``enc`` callable; it is deep-copied before use.
        cuda: device handed to ``Tensor.to``.
        batch_left: indexable; ``9 - batch_left[domain - 15]`` is the index of
            the batch that represents a domain.

    Returns:
        List of selected domain ids in increasing order.

    Raises:
        ValueError: if ``args.dataset_name`` is not supported, or if a domain's
            data loader never yields the requested batch index.
    """
    # dataset name -> (first candidate domain, one past the last candidate
    # domain, number of domains to select)
    dataset_setup = {
        'rot_mnist': (15, 76, 5),
        'fashion_mnist': (15, 76, 5),
        'rot_mnist_spur': (0, 78, 6),
    }
    if args.dataset_name not in dataset_setup:
        # Previously this fell through every branch and ended in an
        # UnboundLocalError far from the cause.
        raise ValueError('unsupported dataset_name: {!r}'.format(args.dataset_name))
    first_domain, end_domain, nb_selected = dataset_setup[args.dataset_name]
    candidate_domains = range(first_domain, end_domain)

    if iteration==0:
        return sorted(random.sample(candidate_domains, nb_selected))

    phi = copy.deepcopy(model)
    location_dom=[]  # per-domain location, used to build the similarity matrix L

    for domain in candidate_domains:
        # Load the training data of this domain.
        data = get_dataloader_fish(args, run, domain, data_case='train', kwargs=kwargs)
        # NOTE(review): the offset 15 is hard-coded, so for 'rot_mnist_spur'
        # (domains start at 0) domains 0-14 index batch_left from the end.
        # Kept unchanged because the caller's indexing convention is not
        # visible here -- confirm.
        target_batch = 9-batch_left[domain-15]
        domain_location = None
        for batch_idx, (x_e, _) in enumerate(data):
            if batch_idx==target_batch:
                with torch.no_grad():
                    x_e = x_e.to(cuda)
                    # Forward pass through the encoder, then average over
                    # dimension 0 of its output.
                    out = phi.enc(x_e)
                    domain_location = torch.mean(out,0).cpu().numpy()
                break
        if domain_location is None:
            # Without this check the previous domain's location would be
            # silently appended again (or a NameError raised for the first
            # domain).
            raise ValueError(
                'domain {} has no batch with index {}'.format(domain, target_batch))
        location_dom.append(domain_location)

    L=build_similary_matrix(get_similar_cov(4),location_dom)
    # Sample a set of domains from the DPP defined by L.
    X=sample_k(location_dom,L,nb_selected)

    # Position i in the mask corresponds to domain first_domain + i.
    return [first_domain + i for i, chosen in enumerate(X) if chosen]




class ParamDict(OrderedDict):
    """A dictionary where the values are Tensors, meant to represent weights of
    a model. This subclass lets you perform arithmetic on weights directly.

    Every operator works entry by entry. The other operand is either a number
    (applied to every value) or a dict (combined key by key, using this
    dictionary's keys). Results are new ``ParamDict`` instances.
    """

    def __init__(self, *args, **kwargs):
        # Keyword arguments must be forwarded as keywords; unpacking them with
        # a single star passed their names as extra positional arguments.
        super().__init__(*args, **kwargs)

    def _prototype(self, other, op):
        """Return a ParamDict holding ``op(value, other_value)`` for each key.

        Raises NotImplementedError when ``other`` is neither a number nor a
        dict, and KeyError when a dict operand lacks one of this dict's keys.
        """
        if isinstance(other, Number):
            return ParamDict({k: op(v, other) for k, v in self.items()})
        elif isinstance(other, dict):
            return ParamDict({k: op(self[k], other[k]) for k in self})
        else:
            raise NotImplementedError

    def __add__(self, other):
        return self._prototype(other, operator.add)

    # Reflected addition, so ``0 + params`` and ``sum([...])`` work.
    __radd__ = __add__

    def __rmul__(self, other):
        return self._prototype(other, operator.mul)

    __mul__ = __rmul__

    def __neg__(self):
        return ParamDict({k: -v for k, v in self.items()})

    def __sub__(self, other):
        # self - other, entry by entry; also accepts plain dicts, which have
        # no unary minus.
        return self._prototype(other, operator.sub)

    def __rsub__(self, other):
        # other - self: operands are swapped relative to __sub__.
        return self._prototype(other, lambda mine, theirs: theirs - mine)

    def __truediv__(self, other):
        return self._prototype(other, operator.truediv)


def fish_step(meta_weights, inner_weights, meta_lr):
    """
    Return the meta weights moved toward the inner weights.

    Both arguments are wrapped in ``ParamDict``. The result is
    ``meta + meta_lr * (0 * meta + (inner - meta))``, computed entry by entry
    as a new ``ParamDict``.
    """
    outer = ParamDict(meta_weights)
    inner = ParamDict(inner_weights)
    # The one-element sum is seeded with a zero-scaled copy of the meta
    # weights, so the accumulation starts from a ParamDict.
    displacement = sum([inner - outer], 0 * outer)
    return outer + meta_lr * displacement