import math
from copy import deepcopy
import numpy as np
import random

import torch
import torch.utils.data as data
import torch.nn.functional as F
from torch.utils.data import ConcatDataset
from tqdm import tqdm

import common.utils.logging as logging

# Module-level logger obtained from the project's logging wrapper.
logger = logging.get_logger(__name__)

# Device that the functions in this module move tensors to:
# CUDA when torch reports it as available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def compute_learnable_metric_k_fold(cfg, source_dataset, target_dataset, metric, meter, n_splits=3, **kwargs):
    """Average a learnable metric over ``n_splits`` random folds.

    Both ``source_dataset`` and ``target_dataset`` are sequences of tensors
    that get wrapped in a ``TensorDataset`` and randomly partitioned into
    ``n_splits`` folds whose sizes differ by at most one. For every fold a
    deep copy of ``metric`` is handed to ``compute_learnable_metric`` with
    the remaining folds as the training loader and the held-out fold as the
    test loader.

    Args:
        cfg: configuration object; ``cfg.metric_batch_size`` and
            ``cfg.workers`` are read when building the data loaders.
        source_dataset: sequence of tensors for the source domain.
        target_dataset: sequence of tensors for the target domain.
        metric: learnable metric; it is deep-copied per fold, so the object
            passed in is not trained.
        meter: meter forwarded to ``compute_learnable_metric``.
        n_splits: number of folds.
        **kwargs: forwarded unchanged to ``compute_learnable_metric``.

    Returns:
        The sum of the per-fold results divided by ``n_splits``.
    """
    def _fold_sizes(total):
        # The first `extra` folds take one additional sample so sizes add up.
        base, extra = divmod(total, n_splits)
        return [base + 1] * extra + [base] * (n_splits - extra)

    def _shuffled_loader(dataset):
        return data.DataLoader(dataset, batch_size=cfg.metric_batch_size, shuffle=True,
                               num_workers=cfg.workers, drop_last=True)

    source_dataset = data.TensorDataset(*source_dataset)
    target_dataset = data.TensorDataset(*target_dataset)
    source_splits = data.random_split(source_dataset, _fold_sizes(len(source_dataset)))
    target_splits = data.random_split(target_dataset, _fold_sizes(len(target_dataset)))

    folds = range(n_splits)
    per_fold = []
    for held_out in folds:
        source_test_loader = _shuffled_loader(source_splits[held_out])
        source_train_loader = _shuffled_loader(
            ConcatDataset([source_splits[k] for k in folds if k != held_out]))
        target_test_loader = _shuffled_loader(target_splits[held_out])
        target_train_loader = _shuffled_loader(
            ConcatDataset([target_splits[k] for k in folds if k != held_out]))

        # The training pair is built before the test pair, as each
        # ConcatLoader opens iterators on its loaders when constructed.
        paired_train = ConcatLoader(source_train_loader, target_train_loader)
        paired_test = ConcatLoader(source_test_loader, target_test_loader)
        per_fold.append(
            compute_learnable_metric(paired_train, paired_test, deepcopy(metric), meter, **kwargs))
    return sum(per_fold) / n_splits

def compute_learnable_metric(train_loader, test_loader, metric, meter, lr=0.01, epochs=10, verbose=False):
    """Train a learnable metric on ``train_loader`` and score it on ``test_loader``.

    The metric is optimised with SGD (momentum 0.9, weight decay 1e-3) under
    a cosine-annealed learning rate that is stepped once per batch. After
    training, the metric is switched to eval mode and its output on every
    batch of ``test_loader`` is accumulated into ``meter``.

    Args:
        train_loader: sized iterable yielding ``(source_batch, target_batch)``
            pairs; each batch unpacks into three tensors ``(f, g, label)``.
        test_loader: iterable with the same batch structure, used only for
            the final evaluation.
        metric: module called as
            ``metric(g_s, f_s, g_t, f_t, label_s, label_t, training=...)``.
            With ``training=True`` its result is back-propagated as the loss;
            with ``training=False`` its result is recorded by ``meter``.
        meter: object exposing ``name``, ``reset()``, ``update(value, n)``
            and ``avg``.
        lr: initial learning rate for SGD.
        epochs: number of passes over ``train_loader``.
        verbose: when true, log loader lengths and the final result.

    Returns:
        ``meter.avg`` after the evaluation pass.
    """
    optimizer = torch.optim.SGD(metric.parameters(), lr=lr, momentum=0.9,
                                weight_decay=1e-3)
    # One scheduler step per batch, so the cosine period spans the whole run.
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=epochs * len(train_loader), eta_min=0, last_epoch=-1)
    if verbose:
        logger.info(f"learnable metric {meter.name}: {len(train_loader)} train samples, {len(test_loader)} test samples")

    for _ in tqdm(range(epochs), total=epochs, desc=f"learnable metric: {meter.name}"):
        metric.train()
        for data_s, data_t in train_loader:
            f_s, g_s, label_s = data_s
            f_t, g_t, label_t = data_t
            loss = metric(g_s.to(device), f_s.to(device), g_t.to(device), f_t.to(device),
                          label_s.to(device), label_t.to(device), training=True)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            lr_scheduler.step()

    metric.eval()
    meter.reset()
    with torch.no_grad():
        # Evaluate on the held-out loader; scoring on train_loader would
        # report how well the metric fits the data it was trained on.
        for data_s, data_t in test_loader:
            f_s, g_s, label_s = data_s
            f_t, g_t, label_t = data_t
            acc = metric(g_s.to(device), f_s.to(device), g_t.to(device), f_t.to(device),
                         label_s.to(device), label_t.to(device), training=False)
            meter.update(acc.cpu(), f_t.shape[0])
    if verbose:
        # Index of the last epoch, computed directly so it is defined even
        # when no epoch ran.
        logger.info(f"epoch {epochs - 1} {meter.name} accuracy: {meter.avg}")
    return meter.avg

# get model outputs of val set
def val_epoch(iter_test, model, max_iter=None):
    """Run ``model`` over ``iter_test`` without gradients and collect its results.

    Args:
        iter_test: sized iterable yielding ``(inputs, labels)`` batches.
        model: callable returning ``(outputs, feature)`` for a batch of
            inputs that has been moved to the module-level ``device``.
        max_iter: optional cap on the number of batches processed; ``None``
            consumes the whole iterable.

    Returns:
        A tuple ``(features, outputs, labels)``, each concatenated along
        dimension 0 over the processed batches. Features and outputs are
        moved to the CPU; labels are kept as yielded.
    """
    feature_chunks, output_chunks, label_chunks = [], [], []
    with torch.no_grad():
        progress = tqdm(iter_test, total=len(iter_test), desc="generating data for evaluating metrics")
        for n_done, (inputs, labels) in enumerate(progress, start=1):
            outputs, feature = model(inputs.to(device))
            feature_chunks.append(feature.cpu())
            output_chunks.append(outputs.cpu())
            label_chunks.append(labels)
            # n_done equals the number of batches collected so far.
            if max_iter is not None and n_done >= max_iter:
                break
    return (
        torch.cat(feature_chunks, dim=0),
        torch.cat(output_chunks, dim=0),
        torch.cat(label_chunks, dim=0),
    )

def construct_metric_dataloader(cfg, loader, split_ratio=0.7):
    """Populate ``loader`` with the datasets and loaders used by the metrics.

    Reads the ``(features, outputs, labels)`` triples stored under
    ``loader['source_test_data']`` and ``loader['target_test_data']`` and
    adds the following entries to ``loader``:

    * ``source_test_normalized_data`` / ``target_test_normalized_data``:
      the same triples with features passed through ``F.normalize(dim=1)``.
    * ``metric_source`` / ``metric``: shuffled loaders over the full source
      and target tensors.
    * ``metric_aug``: unshuffled loader over ``loader['target_aug_data']``,
      only when ``'ACM'`` is in ``cfg.metrics``.
    * ``metric_{source,target}_{train,test}``: shuffled loaders over a
      random train/validation split of each domain.
    * ``metric_train`` / ``metric_test``: ``ConcatLoader`` pairs of the
      corresponding source and target loaders.

    Args:
        cfg: configuration object; ``metric_batch_size``, ``workers`` and
            ``metrics`` are read.
        loader: dict that is updated in place.
        split_ratio: fraction of each dataset assigned to the training split
            (truncated to an integer sample count).

    Returns:
        The same ``loader`` dict.
    """
    def _build(dataset, shuffle=True, drop_last=True):
        return data.DataLoader(dataset, batch_size=cfg.metric_batch_size, shuffle=shuffle,
                               num_workers=cfg.workers, drop_last=drop_last)

    def _train_val_split(dataset):
        total = len(dataset)
        n_train = int(split_ratio * total)
        return data.random_split(dataset, [n_train, total - n_train])

    # Model outputs of the validation sets, plus feature-normalised copies.
    src_feat, src_out, src_lbl = loader['source_test_data']
    tgt_feat, tgt_out, tgt_lbl = loader['target_test_data']
    loader['source_test_normalized_data'] = (F.normalize(src_feat, dim=1), src_out, src_lbl)
    loader['target_test_normalized_data'] = (F.normalize(tgt_feat, dim=1), tgt_out, tgt_lbl)

    # Loaders over the complete source / target tensors.
    source_dataset = data.TensorDataset(src_feat, src_out, src_lbl)
    loader['metric_source'] = _build(source_dataset)
    target_dataset = data.TensorDataset(tgt_feat, tgt_out, tgt_lbl)
    loader['metric'] = _build(target_dataset, drop_last=False)

    if 'ACM' in cfg.metrics:
        aug_feat, aug_out, aug_lbl = loader['target_aug_data']
        aug_dataset = data.TensorDataset(aug_feat, aug_out, aug_lbl)
        loader['metric_aug'] = _build(aug_dataset, shuffle=False, drop_last=False)

    # Random train/validation splits; source is split before target.
    source_train_set, source_val_set = _train_val_split(source_dataset)
    loader['metric_source_train'] = _build(source_train_set)
    loader['metric_source_test'] = _build(source_val_set)
    target_train_set, target_val_set = _train_val_split(target_dataset)
    loader['metric_target_train'] = _build(target_train_set)
    loader['metric_target_test'] = _build(target_val_set, drop_last=False)

    loader['metric_train'] = ConcatLoader(loader['metric_source_train'], loader['metric_target_train'])
    loader['metric_test'] = ConcatLoader(loader['metric_source_test'], loader['metric_target_test'])

    return loader

class ConcatLoader(object):
    """Iterate a source loader and a target loader in lockstep.

    Each step yields ``(source_data, target_data)``, two lists of tensors
    built by concatenating ``batch_size_scale`` consecutive batches from
    each loader along dimension 0. One pass is as long as the target loader;
    the source loader is restarted whenever it runs out, so it may be
    shorter than the target loader.

    The object is its own iterator: when a pass ends, both underlying
    iterators are re-created so the next ``for`` loop starts a fresh pass.

    Args:
        source_loader: sized, re-iterable collection of batches, each a
            sequence of tensors.
        target_loader: sized, re-iterable collection of batches with the
            same structure; its length defines the length of a pass.
        batch_size_scale: number of consecutive batches merged per step.
    """

    def __init__(self, source_loader, target_loader, batch_size_scale=1):
        self.source_loader = source_loader
        self.target_loader = target_loader
        self.source_iter = iter(source_loader)
        self.target_iter = iter(target_loader)
        self.batch_size_scale = batch_size_scale
        # Index of the last target batch consumed in the current pass.
        self.cur_pos = -1

    def __len__(self):
        return math.floor(len(self.target_loader) / float(self.batch_size_scale))

    def __iter__(self):
        return self

    def _next_source_batch(self):
        """Return the next source batch, restarting the source loader once if exhausted."""
        try:
            return next(self.source_iter)
        except StopIteration:
            self.source_iter = iter(self.source_loader)
        try:
            return next(self.source_iter)
        except StopIteration:
            # A StopIteration escaping here would silently end the pass.
            raise ValueError("source_loader yielded no batches") from None

    def __next__(self):
        source_data = []
        target_data = []
        for _ in range(self.batch_size_scale):
            self.cur_pos += 1
            if self.cur_pos >= len(self.target_loader):
                # End of pass: rewind both loaders for the next one.
                self.source_iter = iter(self.source_loader)
                self.target_iter = iter(self.target_loader)
                self.cur_pos = -1
                raise StopIteration()

            # Use the built-in next(); iterators have no .next() in Python 3.
            s_data = self._next_source_batch()  # (x, out, label)
            t_data = next(self.target_iter)  # (x, out, label)
            n_target = t_data[0].size(0)
            if s_data[0].size(0) > n_target:
                # Trim the source batch so it is no larger than the target batch.
                s_data = tuple(item[:n_target] for item in s_data)
            source_data.append(s_data)
            target_data.append(t_data)

        source_data = [torch.cat(parts, dim=0) for parts in zip(*source_data)]
        target_data = [torch.cat(parts, dim=0) for parts in zip(*target_data)]
        return source_data, target_data

class ClassConcatDataset(object):
    """Expose a source and a target dataset as one index space.

    Indices ``0 .. source_len - 1`` are served by the source dataset; higher
    indices are shifted down by ``source_len`` and served by the target
    dataset.

    Args:
        source_dataset: dataset providing ``__getitem__`` and either a
            ``tensors`` or a ``samples`` attribute.
        target_dataset: dataset with the same kind of size attribute as
            ``source_dataset``.
        num_classes: number of classes; when ``None`` it is read from
            ``target_dataset.num_classes``.
    """

    def __init__(self, source_dataset, target_dataset, num_classes=None):
        self.source_get = source_dataset.__getitem__
        self.target_get = target_dataset.__getitem__
        if num_classes is None:
            num_classes = target_dataset.num_classes
        self.num_classes = num_classes

        # The source dataset decides which attribute carries the size for
        # both datasets: the first entry of `tensors`, or `samples`.
        if hasattr(source_dataset, 'tensors'):
            def _count(dataset):
                return len(dataset.tensors[0])
        else:
            def _count(dataset):
                return len(dataset.samples)
        self.source_len = _count(source_dataset)
        self.target_len = _count(target_dataset)

    def __getitem__(self, index):
        if index >= self.source_len:
            return self.target_get(index - self.source_len)
        return self.source_get(index)

    def __len__(self):
        return self.source_len + self.target_len

class ClassBatchSampler(torch.utils.data.Sampler):
    """Batch sampler yielding source indices followed by target indices.

    Target indices are offset by the number of source samples, i.e. they
    address the second part of a concatenated source+target index space.
    ``update_label`` must be called before iterating or taking ``len()``,
    because it creates the label-to-index tables those operations read.

    Two modes, selected by ``batch_size_class``:

    * ``1``: walk the target samples in order, ``batch_size`` at a time,
      and for each one draw a random source index whose label equals the
      target sample's cluster label.
    * otherwise: per batch, draw ``batch_size // batch_size_class`` classes
      with replacement and, for each, ``batch_size_class`` target indices
      and ``batch_size_class`` source indices of that class, with
      replacement.

    Args:
        dataset: object exposing ``num_classes``.
        batch_size: number of target (and source) indices per batch.
        batch_size_class: samples drawn per selected class; must divide
            ``batch_size``.
        shuffle: stored on the instance.
        seed: stored on the instance.
    """

    def __init__(self, dataset, batch_size, batch_size_class=1, shuffle=False, seed=0):
        self.dataset = dataset
        self.num_classes = self.dataset.num_classes
        self.batch_size = batch_size
        self.batch_size_class = batch_size_class
        assert batch_size % batch_size_class == 0
        self.shuffle = shuffle
        self.seed = seed
        self.epoch = 0

    def __iter__(self):
        if self.batch_size_class == 1:
            yield from self._iter_label_matched()
        else:
            yield from self._iter_class_balanced()

    def _iter_label_matched(self):
        # Sequential target batches, each paired with label-matched source picks.
        for batch_idx in range(len(self)):
            lo = batch_idx * self.batch_size
            hi = min(lo + self.batch_size, self.num_samples)
            target_index = [self.source_len + pos for pos in range(lo, hi)]
            source_index = [
                random.choice(self.source_class_to_indices[cluster])
                for cluster in self.target_cluster_labels[lo:hi]
            ]
            yield source_index + target_index

    def _iter_class_balanced(self):
        # Random classes per batch; target indices are drawn before source
        # indices for each class.
        for _ in range(len(self)):
            classes_per_batch = self.batch_size // self.batch_size_class
            chosen = np.sort(np.random.choice(range(self.num_classes), classes_per_batch, replace=True))
            target_index, source_index = [], []
            for cls in chosen:
                target_index.extend(
                    np.random.choice(self.target_class_to_indices[cls], self.batch_size_class, replace=True))
                source_index.extend(
                    np.random.choice(self.source_class_to_indices[cls], self.batch_size_class, replace=True))
            yield source_index + target_index

    def __len__(self):
        return math.ceil(float(self.num_samples) / self.batch_size)

    def set_epoch(self, epoch):
        self.epoch = epoch

    def update_label(self, source_labels, target_cluster_labels):
        """Rebuild the per-class index tables from fresh labels.

        Args:
            source_labels: label of every source sample, in dataset order.
            target_cluster_labels: cluster label of every target sample, in
                dataset order.
        """
        by_class = [[] for _ in range(self.num_classes)]
        self.source_class_to_indices = by_class
        for position, cls in enumerate(source_labels):
            by_class[cls].append(position)
        self.source_len = len(source_labels)

        self.target_cluster_labels = target_cluster_labels
        by_class = [[] for _ in range(self.num_classes)]
        self.target_class_to_indices = by_class
        for position, cls in enumerate(self.target_cluster_labels):
            # Target indices live after the source block.
            by_class[cls].append(position + self.source_len)
        self.num_samples = len(target_cluster_labels)