"""
Query: an image (SVI) embedding.
Database: embeddings of the roads in the city.
"""
import os
import argparse
import random
import time
import json
from functools import partial

import torch
import torch.nn as nn
from torch import Tensor
from torch.utils.data import DataLoader, Dataset, random_split
import faiss
import numpy as np
import pandas as pd

from ..constant import DATA_DIR
from ..utils import choose_device, EarlyStopper


# DATA_DIR = os.path.join(os.path.expanduser("~"), 'workspace', 'roadnet', 'data')


def load_data(dataname, emb_filename, match_distance=50, masked=True, preprocess=True):
    """
    Load data for retrieval task.

    Args:
    - dataname: str, name of the dataset (a sub-directory of ``DATA_DIR``).
    - emb_filename: str, filename of road embeddings (loaded with ``torch.load``).
    - match_distance: float, distance threshold for matching SVI images to roads; it selects
      the ``svi_2_road_{match_distance}.csv`` mapping file.
    - masked: bool, if True, SVI images that are not matched to any road are dropped from the
      returned ``svi_emb`` and ``im_2_road_table``.
    - preprocess: bool, if True, road embeddings are z-score standardized per dimension.
      Dimensions with zero standard deviation are only centered (divided by 1) so they do not
      turn into NaN/inf.

    Return:
    - road_emb: torch.Tensor, embeddings of roads
    - svi_emb: torch.Tensor, embeddings of SVI images (only the matched ones when ``masked``)
    - im_md: pd.DataFrame, metadata of SVI images left-joined with the image-to-road mapping.
      It is NOT filtered by ``masked``; its ``pt_index`` column is the row position in the
      unmasked ``svi_embeddings.pt``.
    - im_2_road_table: list of list, mapping from SVI images to road ids. The order is the same
      as the returned ``svi_emb``.
    """
    processed_dir = os.path.join(DATA_DIR, dataname, 'data_processed')
    processed_svi_path = os.path.join(processed_dir, 'svi_emb')
    road_emb = torch.load(emb_filename)
    svi_emb = torch.load(os.path.join(processed_svi_path, 'svi_embeddings.pt'))
    im_2_road = pd.read_csv(os.path.join(processed_svi_path, f'svi_2_road_{match_distance}.csv'),
                            usecols=['objectid', 'road_id', 'distance'])
    im_metadata = pd.read_csv(os.path.join(processed_dir, 'im_metadata.csv'))
    # NOTE(review): assumes the rows of im_metadata.csv are in the same order as svi_embeddings.pt.
    im_metadata['pt_index'] = im_metadata.index
    im_md = im_metadata.merge(im_2_road, on='objectid', how='left')

    if preprocess:
        # Standardize the road embeddings; constant dimensions keep std 1 to avoid division by 0.
        std = road_emb.std(dim=0)
        std = torch.where(std == 0, torch.ones_like(std), std)
        road_emb = (road_emb - road_emb.mean(dim=0)) / std

    # Columns ['pt_index', 'road_id'] give the mapping from image to road; images without a
    # matched road have NaN road_id after the left join and are skipped here.
    im_2_road_table = [[] for _ in range(svi_emb.shape[0])]
    matched = im_md.dropna(subset=['road_id'])
    for pt_index, road_id in zip(matched['pt_index'].to_numpy(dtype=np.int64),
                                 matched['road_id'].to_numpy(dtype=np.int64)):
        im_2_road_table[pt_index].append(road_id)

    if not masked:
        return road_emb, svi_emb, im_md, im_2_road_table
    mask = [len(roads) > 0 for roads in im_2_road_table]
    im_2_road_table = [roads for roads in im_2_road_table if roads]
    return road_emb, svi_emb[mask], im_md, im_2_road_table


class SVIEmbDataset(Dataset):
    """
    Dataset of SVI embeddings paired with the road ids each image is matched to.

    Item ``idx`` is the tuple ``(svi_emb[idx], im_2_road_table[idx])``.
    """

    def __init__(self, svi_emb: Tensor, im_2_road_table: list):
        self.svi_emb = svi_emb
        self.im_2_road_table = im_2_road_table

    def __len__(self):
        return self.svi_emb.size(0)

    def __getitem__(self, idx):
        return self.svi_emb[idx], self.im_2_road_table[idx]

    @staticmethod
    def collate_fn(batch, all_road_ids, scaler=None):
        """
        Collate a training batch and sample one positive and one negative road per image.

        Args:
            batch: list of ``(embedding, road_ids)`` items from this dataset.
            all_road_ids: list of road_ids (int) to draw negatives from, or an int N meaning
                the candidates are ``0, ..., N-1``.
            scaler: optional object whose ``transform`` is applied to the stacked embeddings.

        Returns:
            ``(emb, (pos_samples, neg_samples))``. ``pos_samples[i]`` is drawn uniformly from
            item i's road ids. ``neg_samples[i]`` is drawn uniformly from the candidates that
            are not among them. Both are ``torch.long`` tensors.

        Raises:
            ValueError: if ``all_road_ids`` is neither an int nor a list, if an item has no
                road ids (no positive exists), or if every candidate is a positive for an item
                (no negative exists; the rejection loop would otherwise never terminate).
        """
        if isinstance(all_road_ids, int):
            num_roads = all_road_ids

            def draw_candidate():
                return random.randint(0, num_roads - 1)

            def has_negative(positives):
                # Fewer distinct positives than candidates guarantees a negative exists.
                return len(positives) < num_roads or any(r not in positives for r in range(num_roads))
        elif isinstance(all_road_ids, list):
            candidate_set = set(all_road_ids)

            def draw_candidate():
                return random.choice(all_road_ids)

            def has_negative(positives):
                return not candidate_set <= positives
        else:
            raise ValueError('Invalid input for all_road_ids.')

        emb, road_ids = zip(*batch)
        emb = torch.stack(emb)
        pos_samples = []
        neg_samples = []
        for roads in road_ids:
            if len(roads) == 0:
                raise ValueError('Cannot sample a positive road for an image without matched roads.')
            # Positive sample is a random choice from the image's own road ids.
            pos_samples.append(random.choice(roads))
            positives = set(roads)
            if not has_negative(positives):
                raise ValueError('Cannot sample a negative road: every candidate road is a positive.')
            # Negative sample: rejection sampling over the candidates (same RNG call order as
            # before, so seeded runs reproduce).
            neg_sample = draw_candidate()
            while neg_sample in positives:
                neg_sample = draw_candidate()
            neg_samples.append(neg_sample)

        pos_samples = torch.tensor(pos_samples, dtype=torch.long)
        neg_samples = torch.tensor(neg_samples, dtype=torch.long)
        if scaler is not None:
            emb = scaler.transform(emb)
        return emb, (pos_samples, neg_samples)

    @staticmethod
    def collate_fn_eval(batch, scaler=None):
        """Collate an evaluation batch into ``(stacked embeddings, tuple of road-id lists)``."""
        emb, road_ids = zip(*batch)
        emb = torch.stack(emb)
        if scaler is not None:
            emb = scaler.transform(emb)
        return emb, road_ids


class ZscoreScaler():
    """
    Per-feature z-score standardization fitted along dim 0: ``(x - mean) / std``.

    Features whose standard deviation is zero are divided by 1 instead, so they are only
    centered. Constant features therefore become zeros rather than NaN/inf.
    """

    def __init__(self):
        self.mean = None  # per-feature mean, set by `fit`
        self.std = None   # per-feature std (zeros replaced by 1), set by `fit`

    def fit(self, data):
        """Compute per-feature mean/std of ``data`` (a torch tensor) along dim 0; return ``self``."""
        self.mean = data.mean(dim=0)
        std = data.std(dim=0)
        self.std = torch.where(std == 0, torch.ones_like(std), std)
        return self

    def transform(self, data):
        """Standardize ``data`` with the fitted statistics.

        Raises:
            RuntimeError: if called before ``fit``.
        """
        if self.mean is None or self.std is None:
            raise RuntimeError('ZscoreScaler must be fitted before calling transform().')
        return (data - self.mean) / self.std

    def fit_transform(self, data):
        """Fit on ``data`` and return it standardized."""
        return self.fit(data).transform(data)


class MLP(nn.Module):
    """
    Multi-layer perceptron made of `num_layers` linear layers with ReLU in between.

    Widths go input_size -> hidden_size -> ... -> hidden_size -> output_size, and no
    activation follows the final (output) layer. `num_layers` must be at least 2.
    """
    def __init__(self, input_size, hidden_size, output_size, num_layers=2):
        super().__init__()
        assert num_layers >= 2
        # One width per layer boundary; consecutive pairs define each Linear layer, created
        # in input-to-output order.
        widths = [input_size] + [hidden_size] * (num_layers - 1) + [output_size]
        self.layers = nn.ModuleList(
            nn.Linear(fan_in, fan_out) for fan_in, fan_out in zip(widths[:-1], widths[1:])
        )
        self.relu = nn.ReLU()

    def forward(self, x):
        *hidden_layers, output_layer = self.layers
        for layer in hidden_layers:
            x = self.relu(layer(x))
        return output_layer(x)


def recall_at_k(ranked_lists: np.array, targets: list, k, verbose=False):
    """
    Count target items found in the top-k of each query's ranked list.

    Note: the returned score is the number of hits divided by the number of QUERIES, not by
    the number of ground-truth items. It can therefore exceed 1 when queries have several
    targets. For the per-target recall, use ``num_hits / ground_truth_counts`` (both are
    returned with ``verbose=True``).

    Args:
        ranked_lists: 2D list or 2D numpy array, for each query, the ranked list of retrieved items.
        targets: 2D (possibly ragged) list, for each query, the list of target items.
        k: int, the number of top items to consider.
        verbose: bool, if True also return the raw counts.

    Returns:
        float, hits per query (0.0 when there are no queries). With ``verbose=True``, the tuple
        ``(score, num_hits, ground_truth_counts)``.
    """
    assert len(ranked_lists) == len(targets)
    num_queries = len(targets)
    ground_truth_counts = sum(len(t) for t in targets)
    num_hits = 0
    for ranked, target in zip(ranked_lists, targets):
        top_k = set(ranked[:k])  # O(1) membership per target instead of scanning k items
        num_hits += sum(1 for t in target if t in top_k)
    score = num_hits / num_queries if num_queries else 0.0
    if not verbose:
        return score
    return score, num_hits, ground_truth_counts


def mean_reciprocal_rank(ranked_lists: np.array, targets: list, verbose=False):
    """
    Compute the mean reciprocal rank (MRR) over all queries.

    A query's reciprocal rank is ``1 / rank`` of the first retrieved item that is one of its
    targets (ranks start at 1). It is 0 when no target appears in the ranked list. Such
    queries are counted, not skipped; skipping them would inflate the MRR.

    Args:
        ranked_lists: 2D list or 2D numpy array, for each query, the ranked list of retrieved items.
        targets: 2D (possibly ragged) list, for each query, the list of target items.
        verbose: bool, if True also return the per-query reciprocal ranks.

    Returns:
        float, the MRR over all queries (0.0 when there are no queries). With ``verbose=True``,
        the tuple ``(mrr, reciprocal_ranks)`` with one entry per query.
    """
    assert len(ranked_lists) == len(targets)
    reciprocal_ranks = []
    for ranked, target in zip(ranked_lists, targets):
        target_set = set(target)
        reciprocal_rank = 0.0
        for rank, item in enumerate(ranked, start=1):
            if item in target_set:
                reciprocal_rank = 1 / rank
                break
        reciprocal_ranks.append(reciprocal_rank)
    mrr = float(np.mean(reciprocal_ranks)) if reciprocal_ranks else 0.0
    if not verbose:
        return mrr
    return mrr, reciprocal_ranks


class RoadRetrievalTrainer():
    """
    Train an MLP that maps SVI embeddings into the road-embedding space, then evaluate road
    retrieval with a faiss index built over the road embeddings.
    """
    # Early stopping is only consulted from this epoch on.
    START_EPOCHS = 30

    def __init__(self, road_emb: Tensor, svi_dim: int, args, top_k = 10, mrr_top_k = 100) -> None:
        self.road_emb = road_emb
        self.index_dim = road_emb.shape[1]
        self.index: faiss.Index = None
        self.svi_dim = svi_dim
        self.args = args

        self.device: torch.device = None
        self.model: nn.Module = None
        self.recall_top_k = top_k
        # MRR is computed over at least as many results as recall@top_k uses.
        self.mrr_top_k = max(mrr_top_k, top_k)

    @staticmethod
    def get_args(argv_list = None):
        """Parse CLI arguments from ``argv_list`` (or from ``sys.argv`` when it is None)."""
        parser = argparse.ArgumentParser()
        parser.add_argument('--index-type', type=str, default='hnsw', help='Type of index')
        parser.add_argument('--retrieval-metric', type=str, default='l2', 
                            choices=['l2', 'cosine'], help='Metric for retrieval. Recommend cosine similarity to avoid different scales.')
        parser.add_argument('--dataset', type=str, default='singapore', help='Dataset name')
        parser.add_argument('--emb-filename', type=str, default=None, help='Filename of road embeddings')
        parser.add_argument('--match-distance', type=float, default=50, help='Distance threshold for matching SVI images to roads')
        parser.add_argument('--num-layers', type=int, default=2, help='Number of layers in MLP')
        parser.add_argument('--margin', type=float, default=0.1, help='Margin for triplet loss')
        parser.add_argument('--standardize', action='store_true', help='Standardize the embeddings')
        parser.add_argument('--top-k', type=int, default=10, help='Top k for recall')
        parser.add_argument('--mrr-top-k', type=int, default=100, help='Top k for MRR')
        parser.add_argument('--seed', type=int, default=-1, help='Random seed')
        parser.add_argument('--gpu', type=int, default=-1, help='GPU device. -1 for CPU.')
        parser.add_argument('--epochs', type=int, default=500, help='Number of epochs')
        parser.add_argument('--batch-size', type=int, default=8192, help='For easy implementation, we use a large batch size.')
        parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate')
        parser.add_argument('--patience', type=int, default=20, help='Patience for early stopping')
        parser.add_argument('--print-steps', type=int, default=100, help='Print steps')
        parser.add_argument('--test-ratio', type=float, default=0.2, help='Ratio of test data')
        parser.add_argument('--runs', type=int, default=1, help='Number of runs')
        if argv_list is not None:
            return parser.parse_args(argv_list)
        return parser.parse_args()

    def set_env(self, args):
        """Seed the RNGs (when ``args.seed >= 0``), choose the device and reset the model."""
        if args.seed >= 0:
            # Python's `random` drives positive/negative sampling in SVIEmbDataset.collate_fn.
            random.seed(args.seed)
            np.random.seed(args.seed)
            torch.manual_seed(args.seed)
            torch.cuda.manual_seed(args.seed)
        self.device = choose_device(args.gpu)
        self.model = None

    @staticmethod
    def _as_faiss_input(x, normalize=False):
        """Return ``x`` as a new C-contiguous float32 numpy array, the input format faiss expects.

        When ``normalize`` is True, each row is L2-normalized; all-zero rows stay zero. The
        input is never modified, because a copy is always made.
        """
        if isinstance(x, Tensor):
            x = x.detach().cpu().numpy()
        arr = np.array(x, dtype=np.float32, order='C', copy=True)
        if normalize:
            norms = np.linalg.norm(arr, axis=1, keepdims=True)
            norms[norms == 0] = 1.0
            arr /= norms
        return arr

    def build_index(self, index_type='flat', retrieval_metric='l2'):
        """
        Build a faiss index over ``self.road_emb``, store it in ``self.index`` and return it.

        For 'cosine' the road embeddings are L2-normalized before being added, so inner
        product equals cosine similarity. Queries must be normalized the same way; see
        ``evaluate``.

        Raises:
            ValueError: for an unsupported ``index_type`` or ``retrieval_metric``.
        """
        # Choose metric
        if retrieval_metric == 'l2':
            metric = faiss.METRIC_L2
        elif retrieval_metric == 'cosine':
            metric = faiss.METRIC_INNER_PRODUCT
        else:
            raise ValueError(f'Metric {retrieval_metric} is not supported.')
        # Choose index
        if index_type == 'flat':
            self.index = faiss.index_factory(self.index_dim, 'Flat', metric)
        elif index_type == 'hnsw':
            self.index = faiss.index_factory(self.index_dim, 'HNSW64', metric)
        else:
            raise ValueError(f'Index type {index_type} is not supported.')
        # Convert explicitly: handles GPU / grad-tracking tensors and non-float32 dtypes.
        self.index.add(self._as_faiss_input(self.road_emb, normalize=retrieval_metric == 'cosine'))
        return self.index

    def train(self, dataset: Dataset, scaler=None, verbose=True):
        """
        Train an MLP mapping SVI embeddings to road embeddings with a triplet margin loss.

        The anchor is the projected SVI embedding. The positive and negative are road
        embeddings sampled by ``SVIEmbDataset.collate_fn``. The best checkpoint kept by the
        early stopper (if any) is restored. The model is stored in ``self.model`` and returned.

        Raises:
            ValueError: if ``dataset`` is empty or ``args.retrieval_metric`` is unsupported.
        """
        args = self.args
        device = self.device
        if len(dataset) == 0:
            raise ValueError('Cannot train on an empty dataset.')
        dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, 
                                collate_fn=partial(SVIEmbDataset.collate_fn, 
                                                   all_road_ids=self.road_emb.size(0),
                                                   scaler=scaler),
                                num_workers=4, pin_memory=True)
        model = MLP(self.svi_dim, self.index_dim, self.index_dim, args.num_layers)
        if args.retrieval_metric == 'l2':
            criterion = nn.TripletMarginLoss(margin=args.margin, p=2, eps=1e-7)
        elif args.retrieval_metric == 'cosine':
            # NOTE(review): the margin is fixed to 0.5 here and --margin is ignored.
            criterion = nn.TripletMarginLoss(margin=0.5, p=2, eps=1e-7)
        else:
            raise ValueError(f'Metric {args.retrieval_metric} is not supported.')
        optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
        stopper = EarlyStopper(patience=args.patience)
        model = model.to(device)
        model.train()

        global_step = 0
        for epoch in range(args.epochs):
            loss_values = []
            epoch_time_start = time.time()
            for iter_step, (svi_emb, (pos_samples, neg_samples)) in enumerate(dataloader, start=1):
                iter_time_start = time.time()
                svi_emb = svi_emb.to(device)
                pos_emb = self.road_emb[pos_samples].to(device)
                neg_emb = self.road_emb[neg_samples].to(device)
                anchor = model(svi_emb)
                loss = criterion(anchor, pos_emb, neg_emb)
                iter_loss = loss.item()
                loss_values.append(iter_loss)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                global_step += 1
                iter_time_end = time.time()
                if verbose and global_step % args.print_steps == 0:
                    print(f'Epoch {epoch} | Step {iter_step} | Mini-batch Loss {iter_loss:.4f} | Iter Time {iter_time_end - iter_time_start:.2f}s')

            epoch_loss = sum(loss_values) / len(loss_values)
            # Early stopping monitors the training loss, and only after START_EPOCHS epochs.
            if args.patience > 0 and epoch >= self.START_EPOCHS and stopper.step(epoch_loss, model):
                break
            epoch_time_end = time.time()
            if verbose:
                print(f'Epoch {epoch} | Epoch Loss {epoch_loss:.4f} | Epoch Time {epoch_time_end - epoch_time_start:.2f}s')

        state_dict = stopper.load_checkpoint()
        if state_dict is not None:
            model.load_state_dict(state_dict)
        self.model = model
        return model

    @torch.no_grad()
    def evaluate(self, dataset: Dataset, model, scaler=None):
        """
        Evaluate retrieval of the roads matched to each image in ``dataset``.

        Returns:
            dict with ``recall@k`` and ``recall@2k`` (hits per query, see ``recall_at_k``) and
            ``mrr`` computed over the top ``self.mrr_top_k`` results.
        """
        args = self.args
        device = self.device
        self.build_index(args.index_type, args.retrieval_metric)
        normalize = args.retrieval_metric == 'cosine'
        # Search deep enough for both recall@2k and MRR@mrr_top_k.
        search_k = max(self.mrr_top_k, self.recall_top_k * 2)
        model.eval()
        dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, 
                                collate_fn=partial(SVIEmbDataset.collate_fn_eval,
                                                   scaler=scaler),
                                num_workers=4, pin_memory=True)
        num_hits = 0
        num_hits_2 = 0
        reciprocal_ranks = []
        for svi_emb, road_ids in dataloader:
            anchor = model(svi_emb.to(device))
            queries = self._as_faiss_input(anchor, normalize=normalize)
            _, I = self.index.search(queries, search_k)
            num_hits += recall_at_k(I, road_ids, self.recall_top_k, verbose=True)[1]
            num_hits_2 += recall_at_k(I, road_ids, self.recall_top_k * 2, verbose=True)[1]
            _, batch_reciprocal_ranks = mean_reciprocal_rank(I[:, :self.mrr_top_k], road_ids, verbose=True)
            reciprocal_ranks.extend(batch_reciprocal_ranks)
        num_queries = len(dataset)
        recall = num_hits / num_queries if num_queries else 0.0
        recall_2 = num_hits_2 / num_queries if num_queries else 0.0
        mrr = float(np.mean(reciprocal_ranks)) if reciprocal_ranks else 0.0
        results = {f'recall@{self.recall_top_k}': recall, 
                   f'recall@{self.recall_top_k * 2}': recall_2, 
                   'mrr': mrr}
        return results

    def train_and_evaluate(self, dataset: Dataset):
        """
        Run ``args.runs`` independent random train/test splits. For each split, train a fresh
        model and evaluate it on the test part.

        Returns:
            list of per-run result dicts (each is also printed).
        """
        args = self.args
        data_size = len(dataset)
        all_results = []
        for _ in range(args.runs):
            train_size = int(data_size * (1 - args.test_ratio))
            test_size = data_size - train_size
            train_set, test_set = random_split(dataset, [train_size, test_size])
            if args.standardize:
                # Fit on the training split only, so test statistics do not leak into training.
                scaler = ZscoreScaler().fit(train_set.dataset.svi_emb[train_set.indices])
            else:
                scaler = None
            model = self.train(train_set, scaler)
            results = self.evaluate(test_set, model, scaler)
            print(f'Total results: {json.dumps(results, indent=4)}')
            all_results.append(results)
        return all_results


def road_retrieval(argv_list=None):
    """
    Command-line entry point: parse arguments, load the data, then train and evaluate.

    Args:
        argv_list: optional list of CLI arguments; when None, ``sys.argv`` is parsed.

    Returns:
        Whatever ``RoadRetrievalTrainer.train_and_evaluate`` returns.
    """
    args = RoadRetrievalTrainer.get_args(argv_list)
    print(args)
    # `--match-distance` is parsed as float, so `--match-distance 50` gives 50.0. That would
    # build the filename 'svi_2_road_50.0.csv' instead of 'svi_2_road_50.csv' (the default).
    match_distance = args.match_distance
    if isinstance(match_distance, float) and match_distance.is_integer():
        match_distance = int(match_distance)
    road_emb, svi_emb, _, im_2_road_table = load_data(
        args.dataset, args.emb_filename, match_distance=match_distance, preprocess=args.standardize)
    dataset = SVIEmbDataset(svi_emb, im_2_road_table)
    # Forward --top-k / --mrr-top-k; without this they were parsed but ignored.
    trainer = RoadRetrievalTrainer(road_emb, svi_emb.shape[1], args=args,
                                   top_k=args.top_k, mrr_top_k=args.mrr_top_k)
    trainer.set_env(args)
    return trainer.train_and_evaluate(dataset)


if __name__ == '__main__':
    # Run as a module (`python -m <package>.<module>`) so the relative imports above resolve.
    road_retrieval()
