import torch
import torch.nn.functional as F
import time
import copy
import random
from torch.utils.data import DataLoader
import networks 

class BaseNet(torch.nn.Module):
    """Abstract base for the generative models in this file.

    Stores the hyper-parameters and class count, declares the component
    slots (``encoder``, ``decoder``, ``z_prior``, ``optimizer``) as ``None``
    and provides the reparameterisation helper.  Every other method is a
    placeholder that raises ``NotImplementedError`` until a subclass
    overrides it.
    """

    def __init__(self, input_shape, num_classes, num_domains, hparams):
        # ``input_shape`` and ``num_domains`` are accepted but not stored here.
        super().__init__()
        self.hparams = hparams
        self.num_classes = num_classes
        # Component slots; this base class leaves all of them unset.
        self.encoder = self.decoder = self.z_prior = self.optimizer = None

    def reparameterize(self, mu, log_var):
        """Return ``mu + eps * std`` with ``eps ~ N(0, I)`` and ``std = exp(log_var / 2)``."""
        std = (0.5 * log_var).exp()
        noise = torch.randn_like(std)
        return mu + noise * std

    def forward(self, x, y=None, d=None):
        """Placeholder; must be overridden."""
        raise NotImplementedError

    def update(self, minibatches, unlabeled=None):
        """Placeholder; must be overridden."""
        raise NotImplementedError

    def get_match_var(self, x):
        """Placeholder; must be overridden."""
        raise NotImplementedError

    def reconstruct(self, x, y=None):
        """Placeholder; must be overridden."""
        raise NotImplementedError

    def inference(self, z, y=None):
        """Placeholder; must be overridden."""
        raise NotImplementedError

    def get_z(self, x):
        """Placeholder; must be overridden."""
        raise NotImplementedError



class MyNet(BaseNet):
    """Conditional VAE whose latent prior is conditioned on class and domain.

    The encoder, decoder and prior come from the project ``networks`` module;
    their exact architectures are not visible here.  The encoder is called as
    ``encoder(x, y, d)``, the decoder as ``decoder(z, y)`` and the prior as
    ``z_prior(y, d)``; encoder and prior each return a ``(mean, log_var)``
    pair.
    """

    def __init__(self, input_shape, num_classes, num_domains, hparams):
        super(MyNet, self).__init__(input_shape, num_classes, num_domains,
                                    hparams)

        self.num_domains = num_domains
        try:
            self.encoder = networks.Encoder(input_shape, num_classes, num_domains,
                            self.hparams, conditional=True, use_mlp=True)
            self.decoder = networks.Decoder(input_shape, num_classes,
                            self.hparams, conditional=True)
            self.z_prior = networks.Cond_Prior(num_classes, num_domains,
                            self.hparams, domain_vary=True,
                            distribution = hparams['distribution'])
            self.optimizer = torch.optim.Adam(
                self.parameters(),
                lr=self.hparams["lr"],
                weight_decay=self.hparams['weight_decay']
            )
        except Exception as err:
            # Keep the best-effort fallback (all components None) but no
            # longer hide the cause, and let KeyboardInterrupt/SystemExit
            # propagate instead of being swallowed by a bare ``except``.
            import warnings
            warnings.warn(
                "MyNet: building encoder/decoder/prior/optimizer failed "
                "(%r); leaving them as None." % (err,))
            self.encoder = None
            self.decoder = None
            self.z_prior = None
            self.optimizer = None

    def forward(self, x, y, d):
        """Encode ``x``, sample ``z`` and decode it.

        Returns:
            ``(recon_x, means, log_var, z, prior_means, prior_log_var)``.
        """
        prior_means, prior_log_var = self.z_prior(y, d)
        means, log_var = self.encoder(x, y, d)
        z = self.reparameterize(means, log_var)
        recon_x = self.decoder(z, y)

        return recon_x, means, log_var, z, prior_means, prior_log_var

    def update(self, minibatches, unlabeled=None, use_cf=False):
        """Run one optimisation step over one minibatch per environment.

        Args:
            minibatches: sequence of per-environment batches; the position in
                the sequence is passed to the networks as the domain ``d``.
                Each batch is ``(x, y)``, or ``(x, y, cf_x, cf_y)`` when
                ``use_cf`` is true.
            unlabeled: accepted for interface compatibility; not used.
            use_cf: also penalise the reconstruction of matched
                counterfactuals.  Entries with ``cf_y == -1`` are treated as
                unmatched and skipped.

        Returns:
            dict with Python numbers under ``'loss'``, ``'mse'``, ``'kld'``
            and ``'cf_mse'``.

        Raises:
            ValueError: if ``minibatches`` is empty.
        """
        num_envs = len(minibatches)
        if num_envs == 0:
            # Without any batch the loss is a plain float and cannot be
            # back-propagated; fail with a clear message instead.
            raise ValueError("update() needs at least one minibatch")

        mse = 0.0
        kld = 0.0
        matched_cf_mse = 0.0

        for d, data in enumerate(minibatches):
            if use_cf:
                x, y, cf_x, cf_y = data
            else:
                x, y = data
            batch_size = x.size(0)
            recon_x, means, log_var, z, prior_means, prior_log_var = self.forward(x, y, d)

            # Reconstruction error, averaged over samples and environments.
            mse += 0.5 * F.mse_loss(recon_x.view(batch_size, -1),
                    x.view(batch_size, -1), reduction='sum') / (num_envs*batch_size)
            # Closed-form KL between the diagonal Gaussians
            # N(means, exp(log_var)) and N(prior_means, exp(prior_log_var)).
            kld += -0.5 * torch.sum(1 + log_var - prior_log_var -
                    (log_var.exp() + (means - prior_means).pow(2))
                    /prior_log_var.exp()) / (num_envs*batch_size)
            if use_cf:
                matched_mask = cf_y != -1
                if torch.all(~matched_mask):
                    continue
                matched_cf_x = cf_x[matched_mask]
                matched_cf_y = cf_y[matched_mask]
                matched_z = z[matched_mask]
                matched_batch_size = len(matched_cf_x)
                # Decode the factual latent code under the counterfactual label.
                matched_cf_result = self.inference(matched_z, matched_cf_y)
                matched_cf_mse += 0.5 * F.mse_loss(
                    matched_cf_result['recon_x'].view(matched_batch_size, -1),
                    matched_cf_x.view(matched_batch_size, -1), reduction='sum'
                    ) / (num_envs*matched_batch_size)

        loss = mse + matched_cf_mse + self.hparams['kl_weight'] * kld

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Always report a Python number: the accumulator stays a float when
        # no counterfactual term was added and is a tensor otherwise.
        if torch.is_tensor(matched_cf_mse):
            cf_mse = matched_cf_mse.item()
        else:
            cf_mse = matched_cf_mse
        return {'loss': loss.item(), 'mse': mse.item(), 'kld': kld.item(),
                'cf_mse': cf_mse}

    def get_z(self, x, y, d):
        """Return the encoder mean for ``(x, y, d)``; the log-variance is dropped."""
        means, log_var = self.encoder(x, y, d)
        return means

    def inference(self, z, y):
        """Decode ``z`` under label ``y``; returns ``{'recon_x': ...}``."""
        recon_x = self.decoder(z, y)

        return {'recon_x': recon_x}

    def reconstruct(self, x, y, d):
        """Reconstruct ``x`` by decoding the encoder mean (no sampling)."""
        means, log_var = self.encoder(x, y, d)
        recon_x = self.decoder(means, y)

        return recon_x

    def get_match_var(self, x, y, d):
        """Return the variable used for matching, i.e. ``get_z(x, y, d)``."""
        return self.get_z(x, y, d)

    def normal_log_pdf(self, mean, log_var, z):
        """Diagonal-Gaussian log-density of ``z`` summed over the last axis.

        The constant ``-0.5 * log(2 * pi)`` per dimension is omitted, so the
        value is the log-density only up to an additive constant.
        """
        log_prob = -0.5*log_var - 0.5*((z-mean)/torch.sqrt(torch.exp(log_var)))**2
        return torch.sum(log_prob, dim=-1)

    def propensity_score(self, x, d, p_y=None):
        """Return ``p(y | z)`` for every class, one row per sample.

        For each class ``i`` the input is encoded as if its label were ``i``
        and scored under the class-``i`` prior.  The scores, optionally
        shifted by ``log p_y``, are normalised across classes with a softmax.
        This equals the ratio form
        ``1 / sum_j exp(l_j - l_i) * p_y[j] / p_y[i]`` but needs a single
        pass and stays finite when ``p_y`` contains zeros.

        Args:
            x: batch of inputs.
            d: domain identifier forwarded to the encoder and the prior.
            p_y: optional tensor of class priors with ``num_classes`` entries.

        Returns:
            Tensor with one column per class, clamped below at ``1e-30``.
        """
        log_p_z_y = []
        for i in range(self.num_classes):
            y = torch.full((len(x),), i, dtype=torch.int64, device=x.device)
            z = self.get_z(x, y, d)
            prior_means, prior_log_var = self.z_prior(y, d)
            log_p_z_y.append(self.normal_log_pdf(prior_means, prior_log_var, z))
        logits = torch.stack(log_p_z_y, dim=1)

        if p_y is not None:
            logits = logits + torch.log(p_y).unsqueeze(0)
        p_y_z = torch.softmax(logits, dim=1)
        # Floor the probabilities at a tiny positive value, as before.
        return torch.clamp(p_y_z, min=1e-30)

class DatasetSplitter(torch.utils.data.Dataset):
    """Subset view of a dataset.

    Item ``k`` of the view is item ``keys[k]`` of ``underlying_dataset``.
    """

    def __init__(self, underlying_dataset, keys):
        super().__init__()
        self.keys = keys
        self.underlying_dataset = underlying_dataset

    def __len__(self):
        # The view has exactly one item per key.
        return len(self.keys)

    def __getitem__(self, key):
        # Translate the position in the view to an index of the wrapped dataset.
        source_index = self.keys[key]
        return self.underlying_dataset[source_index]

class IndexWrapper:
    """Dataset wrapper that yields ``(index, x, y)`` instead of ``(x, y)``.

    When ``use_raw_index`` is true and the wrapped dataset is a
    ``DatasetSplitter``, the reported index is the key into the splitter's
    underlying dataset; otherwise it is the position in the wrapped dataset.
    """

    def __init__(self, dataset, use_raw_index=False):
        super().__init__()
        self.dataset = dataset
        # The isinstance check is only evaluated when raw indices are requested.
        wants_raw = use_raw_index and isinstance(self.dataset, DatasetSplitter)
        self.idx_map = self.dataset.keys if wants_raw else None

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        x, y = self.dataset[index]
        reported = index if self.idx_map is None else self.idx_map[index]
        return reported, x, y
    

def _resolve_metric(func):
    """Map a metric name to ``(metric, cdist_order)``.

    ``metric(a, b)`` reduces the last axis of two broadcastable tensors.
    ``cdist_order`` is the ``p`` to hand to ``torch.cdist``, or ``None`` when
    the metric has no ``cdist`` equivalent (``'kld'``).
    """
    if func == 'kld':
        return (lambda a, b: torch.sum(a * torch.log(a / b), -1)), None
    if func == 'linf':
        # ``max(dim=...)`` returns (values, indices); only the values are wanted.
        return (lambda a, b: torch.abs(a - b).max(dim=-1).values), float('inf')
    if isinstance(func, str) and func[:1] == 'l' and func[1:].isdigit():
        p = int(func[1:])
        if p == 0:
            # Number of differing coordinates, matching cdist's p=0 convention.
            return (lambda a, b: (a - b).ne(0).sum(-1).float()), p
        if p == 1:
            return (lambda a, b: torch.abs(a - b).sum(-1)), p
        if p == 2:
            return (lambda a, b: (a - b).pow(2).sum(-1).sqrt()), p
        return (lambda a, b: torch.abs(a - b).pow(p).sum(-1).pow(1 / p)), p
    raise ValueError(
        "unknown distance %r; expected 'l<p>', 'linf' or 'kld'" % (func,))


def matrix_distance(x1, x2, func="l2", device='cuda'):
    """Return the matrix of pairwise distances between rows of x1 and x2.

    Both inputs are flattened to one row per leading index and moved to
    ``device``.  Entry ``[i, j]`` of the result compares row ``i`` of ``x1``
    with row ``j`` of ``x2``.

    The computation is attempted with progressively smaller working sets:
    fully vectorised, one row of ``x1`` at a time, one row of ``x2`` at a
    time, and finally one pair at a time.  A strategy is abandoned when it
    raises ``RuntimeError`` (which covers PyTorch out-of-memory and
    unsupported-dtype errors) or ``MemoryError``.

    Args:
        x1, x2: tensors whose first dimension indexes the samples.
        func: ``'l<p>'`` for the p-norm with integer p (e.g. ``'l1'``,
            ``'l2'``), ``'linf'`` for the max-norm, or ``'kld'`` for
            ``sum(a * log(a / b))``.
        device: device on which the distances are computed.

    Returns:
        Tensor with ``len(x1)`` rows and ``len(x2)`` columns.

    Raises:
        ValueError: if ``func`` is not a recognised metric name.
    """
    metric, order = _resolve_metric(func)
    flat1 = x1.view(len(x1), -1).to(device)
    flat2 = x2.view(len(x2), -1).to(device)

    def all_at_once():
        if order is None:
            return metric(flat1.unsqueeze(1), flat2.unsqueeze(0))
        return torch.cdist(flat1, flat2, p=order)

    def row_by_row():
        return torch.stack([metric(row.unsqueeze(0), flat2) for row in flat1], 0)

    def column_by_column():
        return torch.stack([metric(flat1, col.unsqueeze(0)) for col in flat2], 1)

    for strategy in (all_at_once, row_by_row, column_by_column):
        try:
            return strategy()
        except (RuntimeError, MemoryError):
            continue

    # Last resort: one pair at a time, with no large temporaries.
    result = torch.zeros([len(x1), len(x2)], device=device)
    for i, row in enumerate(flat1):
        for j, col in enumerate(flat2):
            result[i][j] = metric(row, col)
    return result

class MatchWrapper:
    """Dataset wrapper that attaches matched counterfactual samples.

    ``cf_ids[n]`` holds, for sample ``n``, one list of candidate indices per
    *other* class, ordered by class with the sample's own class left out.
    Slot ``c`` therefore refers to class ``c`` when that class is below the
    sample's label and to class ``c + 1`` otherwise.

    Each item is ``(xs, ys)``: the factual sample followed by ``num_cf``
    counterfactuals drawn from randomly chosen slots.  A slot with no
    candidate contributes an all-zero input with label ``-1``.

    With ``use_raw_index`` and a ``DatasetSplitter`` dataset, ``cf_ids`` and
    the candidate indices refer to the splitter's underlying dataset;
    otherwise they refer to ``dataset`` itself.
    """

    def __init__(self, dataset, cf_ids, num_cf, use_raw_index=False):
        super().__init__()
        self.dataset = dataset
        self.cf_ids = cf_ids
        # Use the longest entry rather than ``cf_ids[0]``: in raw-index mode
        # entries for samples outside the split are empty, and index 0 may be
        # one of them.
        num_slots = max((len(entry) for entry in cf_ids), default=0)
        self.num_classes = num_slots + 1
        self.num_cf = num_cf
        self.use_raw_index = use_raw_index
        if use_raw_index and isinstance(self.dataset, DatasetSplitter):
            self.raw_dataset = self.dataset.underlying_dataset
            self.idx_map = self.dataset.keys
        else:
            # No key mapping available: indices are positions in ``dataset``.
            self.raw_dataset = self.dataset
            self.idx_map = None

    def __getitem__(self, index):
        x, y = self.dataset[index]
        x_list = [x]
        y_list = [y]
        chosen_classes = random.sample(range(self.num_classes - 1), self.num_cf)

        entry_key = index if self.idx_map is None else self.idx_map[index]
        candidates = self.cf_ids[entry_key]

        for c in chosen_classes:
            cf_id = candidates[c]

            if len(cf_id) > 0:
                # Only the first (closest) candidate is used.
                cf_x, cf_y = self.raw_dataset[cf_id[0]]

                if cf_y == y:
                    raise ValueError(
                        "counterfactual for index {} has the same label as "
                        "the factual sample".format(index))
                # The sample's own class is missing from the slot list, so
                # slots at or above it are shifted by one.
                expected = c if cf_y < y else c + 1
                assert cf_y == expected, \
                    "slot {} holds a sample of class {}, expected {}".format(
                        c, cf_y, expected)
            else:
                cf_x = torch.zeros_like(x)
                # ``y - y - 1`` yields -1 in the same type as ``y``.
                cf_y = y - y - 1

            x_list.append(cf_x)
            y_list.append(cf_y)

        return torch.stack(x_list, 0), torch.tensor(y_list)

    def __len__(self):
        return len(self.dataset)


def train_with_balence_sampling(in_splits, matcher, num_classes, batch_size=32,
                dist_func='l2', device="cuda", verbose=True, use_raw_index=False, 
                threshold=False, topk=1, num_cf=1):
    """Wrap every environment's dataset with nearest-neighbour counterfactuals.

    For each ``(dataset, weight)`` pair in ``in_splits``:

    1. All samples are read once and grouped by label.
    2. ``matcher.propensity_score`` is evaluated on every group (in eval
       mode, without gradients), with the environment position as the domain
       and the empirical label frequencies as ``p_y``.
    3. For every ordered pair of distinct classes, each sample's ``topk``
       closest samples of the other class (by ``dist_func`` on the scores)
       are recorded.  With ``threshold`` set, only neighbours closer than the
       mean pairwise distance of that class pair are kept.
    4. The dataset is wrapped in a ``MatchWrapper`` built from those matches.

    Returns:
        ``(wrapped_datasets, all_scores)`` when ``verbose`` is true, otherwise
        just ``wrapped_datasets``.  ``wrapped_datasets`` is a list of
        ``(MatchWrapper, weight)`` pairs; ``all_scores`` holds, per
        environment, one score tensor per class.
    """
    torch.multiprocessing.set_sharing_strategy('file_system')
    scores_per_env = []
    matched_splits = []

    for env, (dataset, weight) in enumerate(in_splits):
        samples_by_class = [[] for _ in range(num_classes)]
        index_by_class = [[] for _ in range(num_classes)]

        loader = DataLoader(IndexWrapper(dataset, use_raw_index),
                            batch_size=batch_size, shuffle=False,
                            num_workers=4, drop_last=False)
        class_counts = torch.zeros(num_classes, dtype=torch.int64, device=device)
        for index_batch, x_batch, y_batch in loader:
            # Work on deep copies and drop the loader's own tensors right
            # away (presumably so worker shared memory can be released).
            own_index = copy.deepcopy(index_batch)
            own_x = copy.deepcopy(x_batch)
            own_y = copy.deepcopy(y_batch)
            del index_batch, x_batch, y_batch
            for sample_id, sample, label in zip(own_index, own_x, own_y):
                class_counts[int(label)] += 1
                samples_by_class[label].append(sample)
                index_by_class[label].append(sample_id)
        del loader

        class_tensors = [torch.stack(group, 0) for group in samples_by_class]
        p_y = class_counts / torch.sum(class_counts)

        # Score every class group in chunks of at most ``batch_size`` samples.
        env_scores = []
        matcher.eval()
        with torch.no_grad():
            for group in class_tensors:
                chunks = []
                for start in range(0, len(group), batch_size):
                    piece = group[start:start + batch_size].to(device)
                    chunks.append(
                        matcher.propensity_score(piece, env, p_y).detach())
                env_scores.append(torch.cat(chunks, 0))
        matcher.train()
        scores_per_env.append(env_scores)

        # nearest[(a, b)][k]: positions, within class b, of the samples
        # closest to the k-th sample of class a.
        nearest = {}
        for i in range(num_classes - 1):
            for j in range(i + 1, num_classes):
                dist_ij = matrix_distance(env_scores[i], env_scores[j],
                                          func=dist_func, device=device)
                if dist_func == 'kld':
                    # Not symmetric, so the reverse direction is recomputed.
                    dist_ji = matrix_distance(env_scores[j], env_scores[i],
                                              func=dist_func, device=device)
                else:
                    dist_ji = torch.transpose(dist_ij, 0, 1)

                vals_ij, ids_ij = torch.topk(
                    dist_ij, k=min(topk, len(env_scores[j])),
                    dim=1, largest=False)
                vals_ji, ids_ji = torch.topk(
                    dist_ji, k=min(topk, len(env_scores[i])),
                    dim=1, largest=False)

                if not threshold:
                    nearest[(i, j)], nearest[(j, i)] = \
                        ids_ij.to('cpu'), ids_ji.to('cpu')
                    continue

                keep_ij = vals_ij < torch.mean(dist_ij)
                keep_ji = vals_ji < torch.mean(dist_ji)
                nearest[(i, j)] = [row[mask].to('cpu')
                                   for row, mask in zip(ids_ij, keep_ij)]
                nearest[(j, i)] = [row[mask].to('cpu')
                                   for row, mask in zip(ids_ji, keep_ji)]

        if use_raw_index and isinstance(dataset, DatasetSplitter):
            num_entries = len(dataset.underlying_dataset)
        else:
            num_entries = len(dataset)
        cf_ids = [[] for _ in range(num_entries)]
        # Each sample receives one candidate list per other class, appended
        # in increasing class order.
        for i in range(num_classes):
            for j in range(num_classes):
                if j == i:
                    continue
                for anchor, neighbours in zip(index_by_class[i], nearest[(i, j)]):
                    cf_ids[anchor].append(
                        [index_by_class[j][g] for g in neighbours])

        matched_splits.append(
            (MatchWrapper(dataset, cf_ids, num_cf, use_raw_index), weight))

    return (matched_splits, scores_per_env) if verbose else matched_splits
