import torch
import torch.nn as nn

import torch
import numpy as np
import copy
import random
import argparse
import os
from datetime import datetime

class MLP(nn.Module):
    """Two-layer perceptron whose output is laid out as a square matrix.

    The network maps a ``num_classes``-sized input through one hidden layer
    of width ``hiddendim`` and produces ``num_classes * num_classes`` values,
    which ``forward`` views as a ``(num_classes, num_classes)`` tensor.
    """

    def __init__(self, num_classes, hiddendim):
        super().__init__()
        n_out = num_classes * num_classes
        self.num_classes = num_classes
        # Layers are created in fc1 -> fc2 order so that parameter
        # registration (and seeded initialisation) stays stable.
        self.fc1 = nn.Linear(num_classes, hiddendim)
        self.fc2 = nn.Linear(hiddendim, n_out)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Run ``fc1 -> ReLU -> fc2`` and view the result as a square matrix.

        The final ``view`` requires exactly ``num_classes ** 2`` output
        elements, so ``x`` has to be a single ``num_classes``-sized vector
        (or a batch containing one such vector).
        """
        side = self.num_classes
        hidden = self.relu(self.fc1(x))
        flat = self.fc2(hidden)
        return flat.view(side, side)


def dirichlet_indices(x, y, trained_clf, num_classes, dirichlet_numchunks=250, non_iid_ness=1., batch_size=200):
    """Return sample indices reordered into a class-correlated (non-i.i.d.) sequence.

    For every class, its (shuffled) sample indices are divided among
    ``dirichlet_numchunks`` chunks with proportions drawn from a symmetric
    Dirichlet distribution. The partition is redrawn until every chunk holds
    at least 5 samples. The final ordering walks the chunks in order and,
    inside each chunk, emits the samples class by class in a random class
    order.

    Args:
        x: Tensor of samples; only ``x.size(0)`` is read (used as the total
            sample count for the per-chunk balance cap).
        y: Tensor of labels; entries are compared against the integers
            ``0 .. num_classes - 1``.
        trained_clf: Unused; accepted for interface compatibility.
        num_classes: Number of classes to distribute.
        dirichlet_numchunks: Number of chunks to partition the data into.
        non_iid_ness: Concentration parameter of the symmetric Dirichlet.
        batch_size: Unused; accepted for interface compatibility.

    Returns:
        A list of NumPy integer indices into ``y`` in the new order.

    Raises:
        ValueError: If fewer than ``5 * dirichlet_numchunks`` samples carry a
            label in ``[0, num_classes)``. In that case the smallest chunk
            can never reach the minimum size and the resampling loop would
            never terminate.
    """
    min_size_threshold = 5  # hyperparameter: smallest acceptable chunk size.
    N = x.size(0)

    # The labels do not change between retries, so convert them only once.
    targets_np = y.detach().cpu().numpy()

    # All usable samples are spread over the chunks, hence the smallest chunk
    # holds at most usable / dirichlet_numchunks samples. Fail fast when the
    # threshold is unreachable instead of looping forever.
    usable = int(np.isin(targets_np, np.arange(num_classes)).sum())
    if usable < min_size_threshold * dirichlet_numchunks:
        raise ValueError(
            "cannot build %d chunks with at least %d samples each from %d "
            "usable samples" % (dirichlet_numchunks, min_size_threshold, usable)
        )

    min_size = -1
    while min_size < min_size_threshold:  # prevent any chunk from having too little data
        idx_batch = [[] for _ in range(dirichlet_numchunks)]
        # idx_batch_cls[j][k] holds the indices of class k assigned to chunk j.
        idx_batch_cls = [[] for _ in range(dirichlet_numchunks)]
        for k in range(num_classes):
            idx_k = np.where(targets_np == k)[0]
            np.random.shuffle(idx_k)
            proportions = np.random.dirichlet(
                np.repeat(non_iid_ness, dirichlet_numchunks)
            )

            # balance: chunks that already reached the average size get no
            # further samples from this class.
            proportions = np.array(
                [
                    p * (len(idx_j) < N / dirichlet_numchunks)
                    for p, idx_j in zip(proportions, idx_batch)
                ]
            )
            proportions = proportions / proportions.sum()
            split_points = (np.cumsum(proportions) * len(idx_k)).astype(int)[:-1]

            # Split once and reuse the pieces for both bookkeeping lists.
            pieces = np.split(idx_k, split_points)
            idx_batch = [
                idx_j + piece.tolist() for idx_j, piece in zip(idx_batch, pieces)
            ]
            # store class-wise data
            for per_class, piece in zip(idx_batch_cls, pieces):
                per_class.append(piece)

        min_size = min(len(idx_j) for idx_j in idx_batch)

    # create temporally correlated ordering by shuffling classes within each chunk
    new_indices = []
    for chunk in idx_batch_cls:
        cls_seq = list(range(num_classes))
        np.random.shuffle(cls_seq)
        for cls in cls_seq:
            new_indices.extend(chunk[cls])

    return new_indices



def make_balanced_intermediate_dataset(x_train, y_train, num_classes):
    """Oversample every class up to the size of the largest class.

    Each non-empty class is tiled as many whole times as fits into the
    largest class count, and the remainder is filled with a random subset
    (without replacement) of that class. Per-class counts are printed before
    and, when rebalancing happens, after.

    Args:
        x_train: Tensor of samples of any rank, indexed along dim 0.
        y_train: 1-D tensor of labels, compared against ``0 .. num_classes - 1``.
        num_classes: Number of classes to balance.

    Returns:
        ``(x, y)``. When all class counts are already equal the inputs are
        returned unchanged (``y`` keeps its dtype). Otherwise new tensors are
        returned, grouped class by class, with ``y`` cast to ``long``.
        Classes with no examples cannot be oversampled and are left empty.
    """
    # make the balanced intermediate dataset
    py_train = torch.zeros(num_classes)
    for c in range(num_classes):
        py_train[c] = (y_train == c).int().sum()
    print (py_train)

    if py_train.max() == py_train.min():
        print ('it is already balanced')
        return x_train, y_train

    num_ex_per_class = int(py_train.max())
    x_int_bal, y_int_bal = [], []

    for c in range(num_classes):
        count_c = int(py_train[c])
        if count_c == 0:
            # Nothing to replicate; dividing by the count would fail.
            continue

        # Apply the boolean mask once and reuse the selected rows.
        idx_c = y_train == c
        x_c, y_c = x_train[idx_c], y_train[idx_c]

        full_copies, remainder = divmod(num_ex_per_class, count_c)

        # Tile along dim 0 only, whatever the rank of the samples is.
        x_int_bal.append(x_c.repeat(full_copies, *([1] * (x_c.dim() - 1))))
        y_int_bal.append(y_c.repeat(full_copies))

        # Top up with a random subset so the class reaches the target size.
        extra = torch.randperm(count_c)[:remainder]
        x_int_bal.append(x_c[extra])
        y_int_bal.append(y_c[extra])

    x_int_bal, y_int_bal = torch.cat(x_int_bal, 0), torch.cat(y_int_bal, 0)
    y_int_bal = y_int_bal.long()
    py_train = torch.zeros(num_classes)
    for c in range(num_classes):
        py_train[c] = (y_int_bal == c).int().sum()
    print (py_train)

    return x_int_bal, y_int_bal