import torch

import numpy as np
import copy
import random
import argparse
import os
from datetime import datetime
import torch.optim as optim
import math

from utils.misc import *
from utils.misc_cifar import *
from utils.models import *


######################################################################################################################################
# functions
def save_bn_params(_model, _optimizer):
    """Snapshot the BatchNorm2d affine parameters and the optimizer state.

    Args:
        _model: module whose ``nn.BatchNorm2d`` submodules are inspected.
        _optimizer: optimizer whose ``state_dict()`` is deep-copied.

    Returns:
        A pair ``(params, optimizer_state)``. ``params`` maps
        ``"<module name>_<param name>"`` (param name is ``weight`` or ``bias``)
        to a detached CPU clone of that parameter. ``optimizer_state`` is a deep
        copy of ``_optimizer.state_dict()``.
    """
    bn_modules = (
        (mod_name, mod)
        for mod_name, mod in _model.named_modules()
        if isinstance(mod, nn.BatchNorm2d)
    )
    # Clone so that later training does not mutate the snapshot.
    snapshot = {
        "%s_%s" % (mod_name, pname): param.detach().clone().cpu()
        for mod_name, mod in bn_modules
        for pname, param in mod.named_parameters()
        if pname in ['weight', 'bias']
    }
    optimizer_snapshot = copy.deepcopy(_optimizer.state_dict())
    return snapshot, optimizer_snapshot


def load_bn_params(_model, _optimizer, _initial_params, _optimizer_state):
    """Restore BatchNorm2d affine parameters and optimizer state from a snapshot.

    Args:
        _model: module whose ``nn.BatchNorm2d`` weight/bias are overwritten in place.
        _optimizer: optimizer that receives ``_optimizer_state``.
        _initial_params: dict keyed ``"<module name>_<param name>"`` holding the
            tensors to restore (the format produced by ``save_bn_params``).
        _optimizer_state: an optimizer ``state_dict`` to load.

    Raises:
        KeyError: if a BatchNorm2d weight/bias has no entry in ``_initial_params``.

    The snapshot arguments are never aliased by the model or optimizer, so the
    same snapshot can be restored repeatedly.
    """
    with torch.no_grad():
        for mod_name, module in _model.named_modules():
            if not isinstance(module, nn.BatchNorm2d):
                continue
            for pname, param in module.named_parameters():
                if pname in ('weight', 'bias'):  # weight is scale, bias is shift
                    # copy_ writes into the parameter's own storage (handling the
                    # device/dtype transfer). Assigning ``param.data = saved.to(dev)``
                    # would alias the snapshot whenever it already lives on ``dev``
                    # (``.to`` then returns the same tensor), and later in-place
                    # optimizer steps would silently corrupt the snapshot.
                    param.copy_(_initial_params["%s_%s" % (mod_name, pname)])

    # Depending on the PyTorch version, load_state_dict may keep references to the
    # tensors in the dict it is given; deep-copy so in-place updates of optimizer
    # state (e.g. momentum buffers) cannot mutate the caller's snapshot.
    _optimizer.load_state_dict(copy.deepcopy(_optimizer_state))


def collect_params2(model, ft_layers):
    """Gather the affine parameters of every BatchNorm2d layer in ``model``.

    Args:
        model: module to scan (in ``named_modules()`` order).
        ft_layers: unused; accepted only for interface compatibility.

    Returns:
        ``(params, names)``: the BatchNorm2d ``weight``/``bias`` parameter objects
        and matching ``"<module name>.<param name>"`` strings, in the same order.
    """
    bn_params, bn_names = [], []
    for mod_name, module in model.named_modules():
        if not isinstance(module, nn.BatchNorm2d):
            continue
        for pname, param in module.named_parameters():
            if pname not in ['weight', 'bias']:  # weight is scale, bias is shift
                continue
            bn_params.append(param)
            bn_names.append(f"{mod_name}.{pname}")
    return bn_params, bn_names

def mean_lst(data):
    """Return the arithmetic mean of the sized sequence ``data``.

    Raises:
        ValueError: if ``data`` is empty.
    """
    count = len(data)
    if not count:
        raise ValueError('mean requires at least one data point')
    total = sum(data)
    return total / count


def _ss(data):
    """Return the sum of squared deviations of ``data`` about its mean.

    The mean comes from ``mean_lst``, which raises ValueError for empty input.
    """
    centre = mean_lst(data)
    squared_devs = ((value - centre) ** 2 for value in data)
    return sum(squared_devs)


def stddev_lst(data, ddof=0):
    """Return the standard deviation of ``data``.

    With the default ``ddof=0`` this is the population standard deviation;
    pass ``ddof=1`` for the sample standard deviation. The divisor is
    ``len(data) - ddof``.

    Raises:
        ValueError: if ``data`` has fewer than two elements (for any ``ddof``).
    """
    n_points = len(data)
    if n_points >= 2:
        return (_ss(data) / (n_points - ddof)) ** 0.5
    raise ValueError('variance requires at least two data points')


def dirichlet_indices(x, y, trained_clf, num_classes, dirichlet_numchunks=250, non_iid_ness=1., batch_size=200,
                      min_size_threshold=0):
    """Reorder sample indices into Dirichlet-partitioned, class-correlated chunks.

    The samples of each class are split across ``dirichlet_numchunks`` chunks with
    proportions drawn from ``Dirichlet(non_iid_ness)``. Chunks that already hold at
    least ``N / dirichlet_numchunks`` samples receive no further share. The whole
    partition is redrawn until every chunk holds at least ``min_size_threshold``
    samples. Each chunk is then emitted with its classes in a random order.

    Args:
        x: inputs; only ``x.size(0)`` (N) is used.
        y: label tensor; class ``k`` consists of the positions where ``y == k``.
        trained_clf: unused; kept for interface compatibility.
        num_classes: labels ``0..num_classes-1`` are partitioned.
        dirichlet_numchunks: number of chunks.
        non_iid_ness: Dirichlet concentration (smaller => more skewed chunks).
        batch_size: unused; kept for interface compatibility.
        min_size_threshold: minimum samples per chunk. NOTE: an unattainable
            threshold makes the redraw loop run forever.

    Returns:
        list of indices into ``x``/``y`` (numpy integer scalars), chunk by chunk.
    """
    N = x.size(0)
    targets_np = y.detach().cpu().numpy()  # loop-invariant, converted once
    min_size = -1
    idx_batch_cls = []
    while min_size < min_size_threshold:  # redraw until no chunk is too small
        idx_batch = [[] for _ in range(dirichlet_numchunks)]
        # per chunk: one index array per class (position == class id)
        idx_batch_cls = [[] for _ in range(dirichlet_numchunks)]
        for k in range(num_classes):
            idx_k = np.where(targets_np == k)[0]
            np.random.shuffle(idx_k)
            proportions = np.random.dirichlet(np.repeat(non_iid_ness, dirichlet_numchunks))

            # balance: chunks already at/above the average size get no share
            proportions = np.array([
                p * (len(idx_j) < N / dirichlet_numchunks)
                for p, idx_j in zip(proportions, idx_batch)
            ])
            proportions = proportions / proportions.sum()
            split_points = (np.cumsum(proportions) * len(idx_k)).astype(int)[:-1]
            pieces = np.split(idx_k, split_points)

            idx_batch = [idx_j + idx.tolist() for idx_j, idx in zip(idx_batch, pieces)]
            # store class-wise data so each chunk can be re-ordered by class below
            for idx_j, idx in zip(idx_batch_cls, pieces):
                idx_j.append(idx)
        min_size = min(len(idx_j) for idx_j in idx_batch)

    # create a temporally correlated sequence by shuffling class order per chunk
    new_indices = []
    for chunk in idx_batch_cls:
        cls_seq = list(range(num_classes))
        np.random.shuffle(cls_seq)
        for cls in cls_seq:
            new_indices.extend(chunk[cls])

    return new_indices

class MLP_dart(nn.Module):
    """Two-layer MLP that maps ``(py, pred_dev)`` to an affine map ``(W, b)``.

    ``py`` (assumed 1-D with ``num_classes`` entries, since it is concatenated with
    a 1-element tensor and fed to ``fc1`` expecting ``num_classes + 1`` features)
    is concatenated with the scalar ``pred_dev``. The ``num_classes**2 + num_classes``
    outputs are split into ``W`` of shape ``(num_classes, num_classes)`` and ``b``
    of shape ``(1, num_classes)``.
    """

    def __init__(self, num_classes, hiddendim, use_bias=True):
        """
        Args:
            num_classes: number of classes C.
            hiddendim: width of the hidden layer.
            use_bias: stored as ``self.use_bias`` for callers that read it. It
                previously referenced an undefined name; the linear layers below
                always use a bias regardless of this flag.
        """
        super().__init__()
        self.num_classes = num_classes
        self.use_bias = use_bias
        self.relu = nn.ReLU(inplace=False)
        self.fc1 = nn.Linear(num_classes + 1, hiddendim, bias=True)
        self.fc2 = nn.Linear(hiddendim, num_classes * num_classes + num_classes, bias=True)

    def forward(self, py, pred_dev):
        """Return ``(W, b)``; ``pred_dev`` must contain exactly one element."""
        g_phi_input = torch.cat((py, pred_dev.view(1)))
        out = self.fc2(self.relu(self.fc1(g_phi_input)))

        c = self.num_classes
        W = out[:c * c].view(c, c)
        b = out[c * c:].view(1, c)
        return W, b

class MLP_dart_split(nn.Module):
    """MLP producing an affine map ``(W, b)`` from ``py`` plus a gate from ``pred_dev``.

    ``py`` is fed through ``fc1 -> ReLU -> fc2``; the ``num_classes**2 + num_classes``
    outputs are split into ``W`` of shape ``(num_classes, num_classes)`` and ``b`` of
    shape ``(1, num_classes)``. Separately, ``pred_dev`` goes through a 1->1 linear
    layer and a sigmoid, giving ``inp`` of shape ``(1,)`` with values in (0, 1).
    """

    def __init__(self, num_classes, hiddendim, use_bias=True):
        """
        Args:
            num_classes: number of classes C.
            hiddendim: width of the hidden layer.
            use_bias: stored as ``self.use_bias`` for callers that read it. It
                previously referenced an undefined name; the linear layers below
                always use a bias regardless of this flag.
        """
        super().__init__()
        self.num_classes = num_classes
        self.use_bias = use_bias
        self.relu = nn.ReLU(inplace=False)
        self.fc1 = nn.Linear(num_classes, hiddendim, bias=True)
        self.fc2 = nn.Linear(hiddendim, num_classes * num_classes + num_classes, bias=True)
        self.fc3 = nn.Linear(1, 1, bias=True)

    def forward(self, py, pred_dev):
        """Return ``(W, b, inp)``.

        ``py`` is assumed 1-D with ``num_classes`` entries (the split below slices
        dim 0) — TODO confirm no batched callers. ``pred_dev`` must contain exactly
        one element.
        """
        out = self.fc2(self.relu(self.fc1(py)))

        inp = torch.sigmoid(self.fc3(pred_dev.view(1)))

        c = self.num_classes
        W = out[:c * c].view(c, c)
        b = out[c * c:].view(1, c)
        return W, b, inp



def get_img_num_per_cls(num_examples, cls_num, imb_factor, imb_type='exp'):
    """Return the per-class sample counts for a long-tailed split.

    The largest class gets ``img_max = num_examples / cls_num`` samples.

    Args:
        num_examples: total number of examples before subsampling.
        cls_num: number of classes.
        imb_factor: count multiplier of the most reduced class relative to
            ``img_max`` (e.g. 0.01 for a 100:1 imbalance).
        imb_type: ``'exp'``: counts decay exponentially from ``img_max`` (class 0)
            to ``img_max * imb_factor`` (last class). ``'inv_exp'``: the same
            sequence reversed. ``'step'``: the first ``cls_num // 2`` classes get
            ``img_max``, the remaining classes ``img_max * imb_factor``. Anything
            else: every class gets ``img_max``.

    Returns:
        list of exactly ``cls_num`` ints, indexed by class id.
    """
    img_max = num_examples / cls_num
    # Exponent denominator; guarded so a single class does not divide by zero.
    denom = max(cls_num - 1.0, 1.0)
    if imb_type == 'exp':
        return [int(img_max * (imb_factor ** (cls_idx / denom)))
                for cls_idx in range(cls_num)]
    if imb_type == 'inv_exp':
        return [int(img_max * (imb_factor ** ((cls_num - 1 - cls_idx) / denom)))
                for cls_idx in range(cls_num)]
    if imb_type == 'step':
        num_head = cls_num // 2
        # tail takes the remainder so odd cls_num still yields cls_num entries
        return [int(img_max)] * num_head + [int(img_max * imb_factor)] * (cls_num - num_head)
    return [int(img_max)] * cls_num


def make_LT_indices(x, y, rho, imb_type, args):
    """Return shuffled indices of a long-tailed subset of ``(x, y)``, cached on disk.

    Per-class counts come from ``get_img_num_per_cls(x.size(0), args.num_classes,
    rho, imb_type)``. For each class ``c``, a random subset of its positions is drawn
    with ``torch.randperm`` (or loaded from its per-class cache file). The union is
    shuffled and cached under ``./eval_results/idx/<benchmark>/lt/rho<rho>/``.

    Args:
        x: inputs; only ``x.size(0)`` is used.
        y: label tensor aligned with ``x``.
        rho: imbalance factor forwarded to ``get_img_num_per_cls``.
        imb_type: imbalance profile name forwarded to ``get_img_num_per_cls``.
        args: object providing ``benchmark``, ``seed`` and ``num_classes``.

    Returns:
        1-D tensor of indices into ``x``/``y``.
    """
    cache_dir = "./eval_results/idx/%s/lt/rho%.2f" % (args.benchmark, rho)
    path = "%s/%s_all_seed%d" % (cache_dir, imb_type, args.seed)
    # NOTE(review): torch.load unpickles; these are caches this function writes itself.
    try:
        return torch.load(path)
    except Exception:
        # Cache miss or unreadable cache: rebuild below. ``Exception`` rather than
        # a bare ``except`` so KeyboardInterrupt / SystemExit are not swallowed.
        pass

    num_cls = get_img_num_per_cls(x.size(0), args.num_classes, rho, imb_type)
    os.makedirs(cache_dir, exist_ok=True)

    indices = []
    for c in range(args.num_classes):
        y_c_idx = (y == c).nonzero().view(-1)
        class_path = "%s/%s_%d_seed%d" % (cache_dir, imb_type, c, args.seed)
        try:
            idx_c = torch.load(class_path)
        except Exception:
            # Draw positions within class c, bounded by its actual size, so that
            # y_c_idx[idx_c] never indexes out of range. For balanced data this
            # equals x.size(0) // num_classes, giving the same draws as before.
            idx_c = torch.randperm(y_c_idx.size(0))[:num_cls[c]]
            torch.save(idx_c, class_path)
        indices.append(y_c_idx[idx_c])

    indices = torch.cat(indices)
    indices = indices[torch.randperm(indices.size(0))]
    torch.save(indices, path)
    return indices

def make_balanced_intermediate_dataset(x_train, y_train, num_classes, num_ex_per_class=None):
    """Resample ``(x_train, y_train)`` so every class has the same number of examples.

    Each present class with ``count`` examples is tiled ``num_ex_per_class // count``
    times. It is then topped up with ``num_ex_per_class - tiles * count`` of its
    examples chosen via ``torch.randperm``.

    Args:
        x_train: inputs; dim 0 indexes examples (any number of trailing dims).
        y_train: 1-D labels aligned with ``x_train``.
        num_classes: labels ``0..num_classes-1`` are considered.
        num_ex_per_class: target count per class; defaults to the largest count.

    Returns:
        ``(x, y)``. If all class counts are already equal, the inputs are returned
        unchanged. Otherwise the class-by-class concatenation is returned, with
        ``y`` cast to ``long``. Classes with no examples cannot be resampled and
        are omitted.
    """
    py_train = torch.zeros(num_classes)
    for c in range(num_classes):
        py_train[c] = (y_train == c).int().sum()

    if py_train.max() == py_train.min():
        print('it is already balanced')
        return x_train, y_train

    if num_ex_per_class is None:
        num_ex_per_class = py_train.max()
    # repeat() needs one factor per dim: tile along dim 0 only, for any input rank.
    trailing_ones = [1] * (x_train.dim() - 1)

    x_int_bal, y_int_bal = [], []
    for c in range(num_classes):
        count_c = int(py_train[c])
        if count_c == 0:
            continue  # absent class: nothing to tile (would divide by zero)
        idx_c = y_train == c
        x_c, y_c = x_train[idx_c], y_train[idx_c]
        num_tiles = int(num_ex_per_class // count_c)
        remainder = int(num_ex_per_class - num_tiles * count_c)

        x_int_bal.append(x_c.repeat(num_tiles, *trailing_ones))
        y_int_bal.append(y_c.repeat(num_tiles))

        # randperm is called even when remainder == 0 to keep the RNG stream stable
        extra = torch.randperm(count_c)[:remainder]
        x_int_bal.append(x_c[extra])
        y_int_bal.append(y_c[extra])

    x_int_bal = torch.cat(x_int_bal, 0)
    y_int_bal = torch.cat(y_int_bal, 0).long()
    return x_int_bal, y_int_bal
