import torch
import numpy as np
import copy
import random 
import argparse
import os
from datetime import datetime
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
import torchvision.models as models 

import tent
import copy
from pathlib import Path
import torch.utils.data as data
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import math 
import matplotlib.pyplot as plt


def compute_mask(_loader, _net, _args):
    """Build a symmetric class-confusion mask from a frozen model's predictions.

    A frozen (eval-mode, no-grad) copy of ``_net`` is evaluated on ``_loader``
    via ``tta_noadapt``. ``mask[i, j]`` is 1 when class ``i`` was predicted as
    ``j`` or class ``j`` as ``i`` at least once; the diagonal is always 1.

    Args:
        _loader: loader passed straight to ``tta_noadapt``.
        _net: network to evaluate; it is deep-copied and not modified.
        _args: namespace providing ``seed``, ``int_batch_size`` and ``num_classes``.

    Returns:
        float tensor of shape ``(num_classes, num_classes)`` with 0/1 entries.
    """
    from sklearn.metrics import confusion_matrix

    _net_adapt = configure_model_noadapt(copy.deepcopy(_net))
    np.random.seed(_args.seed)
    torch.manual_seed(_args.seed)
    random.seed(_args.seed)
    _, logits_tr, labels_tr = tta_noadapt(_loader, _net_adapt, _args, _args.int_batch_size)
    del _net_adapt

    preds_tr = logits_tr.argmax(-1)
    # Passing `labels` pins the matrix to num_classes x num_classes even when
    # some classes never occur in the labels or predictions.
    con_mat = confusion_matrix(
        labels_tr.numpy(), preds_tr.numpy(), labels=np.arange(_args.num_classes)
    )

    # A row-normalised entry is > 0 exactly when the raw count is > 0, so the
    # raw counts suffice (and empty rows no longer produce 0/0 warnings).
    confused = torch.from_numpy(con_mat > 0)
    mask = (confused | confused.T).float()
    mask.fill_diagonal_(1.)
    return mask

def get_img_num_per_cls(num_examples, cls_num, imb_factor, imb_type='exp'):
    """Return the per-class sample budget for a class-imbalanced subset.

    The largest budget is ``img_max = num_examples / cls_num``.

    Args:
        num_examples: total number of samples available.
        cls_num: number of classes.
        imb_factor: ratio of the smallest to the largest budget.
        imb_type: ``'exp'`` for exponential decay from class 0 (``img_max``)
            to the last class (``img_max * imb_factor``); ``'step'`` for the
            first ``cls_num // 2`` classes at ``img_max`` and the remaining
            classes at ``img_max * imb_factor``; anything else for a uniform
            budget of ``img_max``.

    Returns:
        list of ``cls_num`` integer budgets (truncated with ``int``).
    """
    if cls_num <= 0:
        return []
    img_max = num_examples / cls_num

    if imb_type == 'exp':
        # max(..., 1) keeps the single-class case from dividing by zero
        # (its only exponent is 0 anyway).
        denom = float(max(cls_num - 1, 1))
        return [
            int(img_max * (imb_factor ** (cls_idx / denom)))
            for cls_idx in range(cls_num)
        ]

    if imb_type == 'step':
        n_head = cls_num // 2
        # The tail gets the remaining cls_num - n_head classes so the list
        # always has cls_num entries (odd cls_num used to lose one).
        return [int(img_max)] * n_head + [int(img_max * imb_factor)] * (cls_num - n_head)

    return [int(img_max)] * cls_num


def make_LT_datasets(x, y, rho, num_classes):
    """Subsample ``(x, y)`` into a long-tailed dataset.

    Class ``c`` keeps up to ``num_cls[c]`` randomly chosen samples, where
    ``num_cls = get_img_num_per_cls(x.size(0), num_classes, rho)`` (default
    ``'exp'`` schedule, so class 0 gets the largest budget). A class with
    fewer samples than its budget keeps all of them.

    Args:
        x: tensor whose first dimension indexes the N samples.
        y: tensor of N integer labels aligned with ``x``.
        rho: imbalance factor forwarded to ``get_img_num_per_cls``.
        num_classes: number of classes.

    Returns:
        ``(x_mod, y_mod, num_cls)``: the subsets concatenated class by class in
        increasing label order, and the requested per-class budgets.
    """
    x_mod = []
    y_mod = []
    num_cls = get_img_num_per_cls(x.size(0), num_classes, rho)

    for c in range(num_classes):
        x_c = x[y == c]
        y_c = y[y == c]
        # Permute this class's own samples. randperm(N // num_classes)
        # assumed perfectly balanced input: it indexed out of range for small
        # classes and could never draw the extra samples of large ones.
        idx_c = torch.randperm(x_c.size(0))[:num_cls[c]]
        x_mod.append(x_c[idx_c])
        y_mod.append(y_c[idx_c])
    x_mod, y_mod = torch.cat(x_mod, 0), torch.cat(y_mod, 0)
    return x_mod, y_mod, num_cls

def make_inverse_LT_datasets(x, y, rho, num_classes):
    """Subsample ``(x, y)`` into a reversed long-tailed dataset.

    Same as ``make_LT_datasets`` but the budgets are applied in reverse:
    class ``c`` keeps up to ``num_cls[num_classes - 1 - c]`` random samples, so
    the last class gets the largest budget. A class with fewer samples than
    its budget keeps all of them.

    Args:
        x: tensor whose first dimension indexes the N samples.
        y: tensor of N integer labels aligned with ``x``.
        rho: imbalance factor forwarded to ``get_img_num_per_cls``.
        num_classes: number of classes.

    Returns:
        ``(x_mod, y_mod, num_cls)``. Note that ``num_cls`` is returned in the
        un-reversed (forward) order, as before, so ``num_cls[c]`` is NOT the
        budget used for class ``c``.
    """
    x_mod = []
    y_mod = []
    num_cls = get_img_num_per_cls(x.size(0), num_classes, rho)

    for c in range(num_classes):
        x_c = x[y == c]
        y_c = y[y == c]
        # Permute this class's own samples rather than assuming every class
        # has exactly N // num_classes of them.
        idx_c = torch.randperm(x_c.size(0))[:num_cls[(num_classes - 1) - c]]
        x_mod.append(x_c[idx_c])
        y_mod.append(y_c[idx_c])
    x_mod, y_mod = torch.cat(x_mod, 0), torch.cat(y_mod, 0)
    return x_mod, y_mod, num_cls

class MLP(nn.Module):
    """Two-layer perceptron mapping a length-K vector to a (K, K) matrix.

    Layers: Linear(K -> hiddendim) -> ReLU -> Linear(hiddendim -> K*K), then the
    result is ``view``-ed as ``(num_classes, num_classes)``. Because of that
    view, each call must carry exactly one sample.
    """

    def __init__(self, num_classes, hiddendim):
        super().__init__()
        k = num_classes
        self.num_classes = k
        self.fc1 = nn.Linear(k, hiddendim, bias=True)
        self.fc2 = nn.Linear(hiddendim, k * k, bias=True)
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        hidden = self.relu(self.fc1(x))
        side = self.num_classes
        return self.fc2(hidden).view(side, side)

class MLP_diag(nn.Module):
    """Two-layer perceptron mapping a length-K vector to another length-K vector.

    Layers: Linear(K -> hiddendim) -> ReLU -> Linear(hiddendim -> K), then the
    result is ``view``-ed as ``(num_classes,)``, so each call must carry
    exactly one sample.
    """

    def __init__(self, num_classes, hiddendim):
        super().__init__()
        k = num_classes
        self.num_classes = k
        self.fc1 = nn.Linear(k, hiddendim, bias=True)
        self.fc2 = nn.Linear(hiddendim, k, bias=True)
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden).view(self.num_classes)

class MLP_masked(nn.Module):
    """MLP predicting only the entries of a (K, K) matrix selected by ``mask``.

    ``fc2`` outputs one value per non-zero entry of ``mask``. ``forward``
    scatters those values into a ``(num_classes, num_classes)`` matrix (row-major
    order of ``mask.nonzero()``) and leaves every other entry at zero. The
    scatter assumes an unbatched ``(num_classes,)`` input.
    """

    def __init__(self, num_classes, mask, hiddendim, device):
        super().__init__()
        # Count exactly the entries that `nonzero()` returns, so the fc2
        # width always matches the scatter below, even for non-0/1 masks.
        # Kept as a 0-d int tensor, as before, for callers that read it.
        self.num_learnable_params = (mask != 0).sum().int()
        # (P, 2) row/column indices of the learnable entries.
        self.check = mask.nonzero()

        self.num_classes = num_classes
        self.fc1 = nn.Linear(num_classes, hiddendim, bias=True)
        # nn.Linear expects a Python int; passing the 0-d tensor used to
        # store a tensor in fc2.out_features.
        self.fc2 = nn.Linear(hiddendim, int(self.num_learnable_params), bias=True)
        self.relu = nn.ReLU(inplace=False)
        self.device = device

    def forward(self, x):
        out = self.fc2(self.relu(self.fc1(x)))

        # Allocate directly on the target device (no CPU round-trip) and in
        # the network's dtype.
        outputs = torch.zeros(
            self.num_classes, self.num_classes, device=self.device, dtype=out.dtype
        )
        outputs[self.check[:, 0], self.check[:, 1]] = out
        return outputs


def configure_model_bn(model):
    """Prepare ``model`` for BatchNorm-based adaptation.

    Switches the whole model to train mode and freezes every parameter. It then
    re-enables gradients only for the parameters of each ``nn.BatchNorm2d``
    module. The running-statistics buffers are left untouched. The model is
    modified in place and returned.
    """
    model.train()
    model.requires_grad_(False)
    bn_layers = (mod for mod in model.modules() if isinstance(mod, nn.BatchNorm2d))
    for bn in bn_layers:
        bn.requires_grad_(True)
    return model

def configure_model_noadapt(model):
    """Freeze ``model`` for plain, non-adaptive inference.

    Puts the model in eval mode and disables gradients on every parameter.
    The model is modified in place and returned.
    """
    # Module.eval() and Module.requires_grad_() both return the module itself.
    return model.eval().requires_grad_(False)

def configure_model_whole(model):
    """Prepare ``model`` for full fine-tuning.

    Puts the model in train mode and enables gradients on every parameter.
    The model is modified in place and returned.
    """
    # Module.train() and Module.requires_grad_() both return the module itself.
    return model.train().requires_grad_(True)

def configure_model_tent(model):
    """Prepare ``model`` for Tent-style entropy minimisation.

    Switches to train mode and freezes every parameter except those of
    ``nn.BatchNorm2d`` modules. Each ``nn.BatchNorm2d`` also stops tracking
    running statistics and has its ``running_mean``/``running_var`` buffers
    cleared, so it normalises with batch statistics in both train and eval
    mode. The model is modified in place and returned.
    """
    model.train()
    model.requires_grad_(False)
    for module in model.modules():
        if not isinstance(module, nn.BatchNorm2d):
            continue
        module.requires_grad_(True)
        # Without running buffers, BatchNorm falls back to batch statistics.
        module.track_running_stats = False
        module.running_mean = None
        module.running_var = None
    return model


def tta_noadapt(_test_loader, _model, args, bs):
    """Evaluate a frozen copy of ``_model`` on ``_test_loader`` without adaptation.

    A deep copy of ``_model`` is configured with ``configure_model_noadapt``
    (eval mode, gradients off), so ``_model`` itself is not modified.

    Args:
        _test_loader: iterable of ``(x, y)`` batches exposing ``.dataset``.
        _model: network to evaluate. Inputs are moved with ``.cuda()``, so a
            CUDA device is required.
        args: namespace providing ``num_classes`` (width of the logit bank).
        bs: kept for backward compatibility only. Bank rows are filled at the
            running sample offset, so they stay aligned even when ``bs``
            differs from the loader's actual batch size.

    Returns:
        ``(accuracy, logits_bank, labels_bank)``. The banks are CPU tensors
        with ``len(_test_loader.dataset)`` rows in loader order. Rows past the
        last yielded sample stay zero.

    Raises:
        ValueError: if the loader yields no samples.
    """
    _model_adapt = configure_model_noadapt(copy.deepcopy(_model))

    n_total = len(_test_loader.dataset)
    logits_bank = torch.zeros(n_total, args.num_classes)
    labels_bank = torch.zeros(n_total).long()
    acc = 0.
    num_examples = 0

    with torch.no_grad():
        for counter, (x_curr, y_curr) in enumerate(_test_loader):
            x_curr = x_curr.cuda()
            y_curr = y_curr.cuda().long()

            outputs = _model_adapt(x_curr)
            acc += (outputs.max(1)[1] == y_curr).float().sum()

            # Write at the running offset instead of counter * bs.
            start = num_examples
            num_examples += x_curr.shape[0]
            logits_bank[start:num_examples] = outputs.detach().cpu()
            labels_bank[start:num_examples] = y_curr.cpu()

            if counter % 100 == 99:
                print('line219', counter, "%.4f" % (float(acc) / num_examples))

    if num_examples == 0:
        raise ValueError("tta_noadapt: the loader yielded no samples")
    return float(acc) / num_examples, logits_bank, labels_bank


def tta_noadapt_samples(_x, _y, _model, args, bs=64):
    """Evaluate a frozen copy of ``_model`` on in-memory tensors ``_x``/``_y``.

    A deep copy of ``_model`` is configured with ``configure_model_noadapt``
    (eval mode, gradients off), so ``_model`` itself is not modified. Samples
    are processed sequentially in chunks of ``bs``.

    Args:
        _x: input tensor whose first dimension indexes the N samples.
        _y: label tensor with N entries aligned with ``_x``.
        _model: network to evaluate. Chunks are moved with ``.cuda()``, so a
            CUDA device is required.
        args: namespace providing ``num_classes`` (width of the logit bank).
        bs: chunk size (previously hard-coded to 64).

    Returns:
        ``(accuracy, logits_bank, labels_bank)``. The banks are CPU tensors
        with N rows in input order.

    Raises:
        ValueError: if ``_x`` is empty.
    """
    _model_adapt = configure_model_noadapt(copy.deepcopy(_model))

    n_total = _x.size(0)
    logits_bank = torch.zeros(n_total, args.num_classes)
    labels_bank = torch.zeros(n_total).long()
    acc = 0.
    num_examples = 0

    with torch.no_grad():
        for counter, start in enumerate(range(0, n_total, bs)):
            stop = start + bs
            x_curr = _x[start:stop].cuda()
            y_curr = _y[start:stop].cuda().long()
            num_examples += x_curr.shape[0]

            outputs = _model_adapt(x_curr)
            acc += (outputs.max(1)[1] == y_curr).float().sum()

            if counter % 100 == 99:
                print('line219', counter, "%.4f" % (float(acc) / num_examples))

            logits_bank[start:stop] = outputs.detach().cpu()
            labels_bank[start:stop] = y_curr.cpu()

    if num_examples == 0:
        raise ValueError("tta_noadapt_samples: no samples to evaluate")
    return float(acc) / num_examples, logits_bank, labels_bank

def tta_bnadapt(_test_loader, _model, args, bs):
    """Evaluate a BatchNorm-adapting copy of ``_model`` on ``_test_loader``.

    A deep copy of ``_model`` is configured with ``configure_model_bn`` (train
    mode). Forward passes therefore run in train mode under ``no_grad``, and
    BatchNorm layers normalise with per-batch statistics. ``_model`` itself is
    not modified.

    Args:
        _test_loader: iterable of ``(x, y, extra)`` batches exposing
            ``.dataset``; the third element is ignored.
        _model: network to evaluate. Inputs are moved with ``.cuda()``, so a
            CUDA device is required.
        args: namespace providing ``num_classes`` (width of the logit bank).
        bs: kept for backward compatibility only. Bank rows are filled at the
            running sample offset, so they stay aligned even when ``bs``
            differs from the loader's actual batch size.

    Returns:
        ``(accuracy, logits_bank, labels_bank)``. The banks are CPU tensors
        with ``len(_test_loader.dataset)`` rows in loader order.

    Raises:
        ValueError: if the loader yields no samples.
    """
    _model_adapt = configure_model_bn(copy.deepcopy(_model))

    n_total = len(_test_loader.dataset)
    logits_bank = torch.zeros(n_total, args.num_classes)
    labels_bank = torch.zeros(n_total).long()
    acc = 0.
    num_examples = 0

    with torch.no_grad():
        for counter, batch_data in enumerate(_test_loader):
            x_curr, y_curr, _ = batch_data
            x_curr = x_curr.cuda()
            y_curr = y_curr.cuda().long()

            outputs = _model_adapt(x_curr)
            acc += (outputs.max(1)[1] == y_curr).float().sum()

            # Write at the running offset instead of counter * bs.
            start = num_examples
            num_examples += x_curr.shape[0]
            logits_bank[start:num_examples] = outputs.detach().cpu()
            labels_bank[start:num_examples] = y_curr.cpu()

            if counter % 100 == 99:
                print('line219', counter, "%.4f" % (float(acc) / num_examples))

    if num_examples == 0:
        raise ValueError("tta_bnadapt: the loader yielded no samples")
    return float(acc) / num_examples, logits_bank, labels_bank

def tta_bnadapt_sample(_x,_y, _model, args):
    """Evaluate a BatchNorm-adapting copy of ``_model`` on shuffled in-memory data.

    A deep copy of ``_model`` is configured with ``configure_model_bn`` (train
    mode), so BatchNorm layers use per-batch statistics. Samples are visited in
    a random order (``torch.randperm``) in batches of ``args.batch_size``. The
    results are written back to the banks at each sample's original index.

    Args:
        _x: input tensor whose first dimension indexes the N samples.
        _y: label tensor with N entries aligned with ``_x``.
        _model: network to evaluate. Batches are moved with ``.cuda()``, so a
            CUDA device is required.
        args: namespace providing ``batch_size``.

    Returns:
        ``(accuracy, logits_bank, labels_bank)``. The banks are CPU tensors in
        the ORIGINAL (unshuffled) sample order. The logit bank width equals the
        model's output width; it was previously hard-coded to 1000.

    Raises:
        ValueError: if ``_x`` is empty.
    """
    _model_adapt = configure_model_bn(copy.deepcopy(_model))

    n_total = _x.size(0)
    print(n_total)
    logits_bank = None  # allocated once the model's output width is known
    labels_bank = torch.zeros(n_total).long()
    batch_size = args.batch_size
    n_batches = math.ceil(n_total / batch_size)
    perm = torch.randperm(n_total)
    acc = 0.
    num_examples = 0

    with torch.no_grad():
        for counter in range(n_batches):
            idx_curr = perm[batch_size * counter:batch_size * (counter + 1)]
            x_curr = _x[idx_curr].cuda()
            y_curr = _y[idx_curr].cuda().long()
            num_examples += x_curr.shape[0]

            outputs = _model_adapt(x_curr)
            if logits_bank is None:
                logits_bank = torch.zeros(n_total, outputs.size(1))

            logits_bank[idx_curr] = outputs.detach().cpu()
            labels_bank[idx_curr] = y_curr.cpu()
            acc += (outputs.max(1)[1] == y_curr).float().sum()

            if counter % 100 == 99:
                print('line264', counter, "%.4f" % (float(acc) / num_examples))

    if num_examples == 0:
        raise ValueError("tta_bnadapt_sample: no samples to evaluate")
    return float(acc) / num_examples, logits_bank, labels_bank



def dirichlet_indices(x, y, num_classes, dirichlet_numchunks=250, non_iid_ness=1., batch_size=200):
    """Return a class-correlated ordering of sample indices via a Dirichlet split.

    For each class, its indices are shuffled and split across
    ``dirichlet_numchunks`` chunks. The split proportions are drawn from a
    symmetric Dirichlet with concentration ``non_iid_ness``. A chunk that
    already holds at least ``N / dirichlet_numchunks`` samples receives no
    more.

    The output concatenates the chunks in order. Within each chunk the
    classes are visited in a random order, which yields a temporally
    correlated label stream.

    Args:
        x: tensor whose first dimension is the number of samples N (only its
            size is used).
        y: tensor of N integer labels. Samples whose label is outside
            ``range(num_classes)`` are not included in the result.
        num_classes: number of classes.
        dirichlet_numchunks: number of chunks.
        non_iid_ness: Dirichlet concentration. Smaller values give more skewed
            chunks.
        batch_size: unused; kept for signature compatibility.

    Returns:
        list of sample indices (NumPy integers). Every index whose label is in
        ``range(num_classes)`` appears exactly once.

    Randomness comes from the global NumPy RNG (``np.random``); seed it for
    reproducibility.
    """
    if num_classes <= 0:
        # With no classes, min_size below would stay -1 and the loop would never end.
        return []

    N = x.size(0)
    targets_np = y.detach().cpu().numpy()  # loop-invariant: convert once
    min_size = -1
    min_size_threshold = 0  # hyperparameter: minimum number of samples per chunk
    while min_size < min_size_threshold:  # retry until no chunk is too small
        idx_batch = [[] for _ in range(dirichlet_numchunks)]
        # Per chunk: one index array per class, appended in class order.
        idx_batch_cls = [[] for _ in range(dirichlet_numchunks)]
        for k in range(num_classes):
            idx_k = np.where(targets_np == k)[0]
            np.random.shuffle(idx_k)
            proportions = np.random.dirichlet(np.repeat(non_iid_ness, dirichlet_numchunks))

            # Balance: chunks already at capacity get zero weight.
            proportions = np.array([
                p * (len(idx_j) < N / dirichlet_numchunks)
                for p, idx_j in zip(proportions, idx_batch)
            ])
            total = proportions.sum()
            # total is 0 only when every chunk is full, i.e. nothing is left to
            # assign (idx_k is empty). Skip the 0/0 so no NaNs reach astype(int).
            if total > 0:
                proportions = proportions / total
            cuts = (np.cumsum(proportions) * len(idx_k)).astype(int)[:-1]
            splits = np.split(idx_k, cuts)

            idx_batch = [idx_j + idx.tolist() for idx_j, idx in zip(idx_batch, splits)]
            min_size = min(len(idx_j) for idx_j in idx_batch)

            # Store class-wise data.
            for idx_j, idx in zip(idx_batch_cls, splits):
                idx_j.append(idx)

    new_indices = []
    # Create a temporally correlated stream by shuffling the class order inside each chunk.
    for chunk in idx_batch_cls:
        cls_seq = list(range(num_classes))
        np.random.shuffle(cls_seq)
        for cls in cls_seq:
            new_indices.extend(chunk[cls])

    return new_indices

def tta_bnadapt_w_gphi(logits_bank, labels_bank, _model, args, bs, _g_phi):
    """Score stored logits before and after a ``_g_phi``-predicted correction.

    The logits are processed in consecutive chunks of ``bs`` rows. For each
    chunk, ``_g_phi`` receives the chunk's mean softmax probability vector and
    returns a matrix ``T``. Accuracy is then counted for the raw logits and for
    ``logits @ T``.

    Args:
        logits_bank: 2-D tensor of logits, one row per sample.
        labels_bank: tensor of labels, one per row of ``logits_bank``.
        _model, args: unused; kept for signature compatibility. The deep-copied
            and configured model that used to be built here was never used.
        bs: chunk size for the batch-level probability statistic.
        _g_phi: callable mapping a probability vector to a 2-D matrix ``T``
            whose row count equals the logit width. It runs on the device of
            its first parameter, or on CUDA if it has no parameters (the
            previous hard-coded behaviour).

    Returns:
        ``(acc, acc_new)``: accuracy of the raw logits and of ``logits @ T``.

    Raises:
        ValueError: if ``logits_bank`` is empty.
    """
    num_examples = logits_bank.size(0)
    if num_examples == 0:
        raise ValueError("tta_bnadapt_w_gphi: logits_bank is empty")

    try:
        device = next(_g_phi.parameters()).device
    except (AttributeError, StopIteration):
        device = torch.device('cuda')

    acc = 0.
    acc_new = 0.
    with torch.no_grad():
        for start in range(0, num_examples, bs):
            outputs = logits_bank[start:start + bs]
            y_curr = labels_bank[start:start + bs].long()

            probs_mean = torch.softmax(outputs, dim=1).mean(0)
            T_curr = _g_phi(probs_mean.to(device)).cpu()

            acc += (outputs.max(1)[1] == y_curr).float().sum()
            acc_new += ((outputs @ T_curr).max(1)[1] == y_curr).float().sum()

    return float(acc) / num_examples, float(acc_new) / num_examples
