import copy
import json
import os
import random
from datetime import datetime

import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm, trange

import data
import utils
from algorithm import init_algorithm


def run_epoch(algorithm, loader, train, progress_bar=True):
    """Run one pass over ``loader`` and gather logits, labels and group ids.

    Every item yielded by ``loader`` is ``(images, labels, group_ids)`` where
    each element is indexed with ``[0]`` for the support part and ``[1]`` for
    the query part. Both parts are interleaved in chunks of
    ``algorithm.support_size`` (support chunk, query chunk, support chunk,
    ...), the support rows are then shuffled among themselves, and only the
    labels of the first query chunk are kept as targets.

    Args:
        algorithm: Object exposing ``support_size``, ``device``,
            ``learn(images, labels, group_ids)`` and
            ``predict(images, group_ids=..., train=...)``. If ``str(algorithm)``
            contains ``'ARM_BN'`` or ``'CXDA'`` it must also expose ``model``.
        loader: Iterable of batches that also provides
            ``batch_sampler.random_rng`` (an object with a ``sample`` method,
            or a falsy value to fall back to the global ``random`` module).
        train: If true ``algorithm.learn`` is called on every batch, otherwise
            ``algorithm.predict``.
        progress_bar: If true the loader is wrapped in a tqdm progress bar.

    Returns:
        Tuple ``(logits, labels, group_ids)`` of CPU tensors concatenated over
        all processed batches. ``labels`` holds only the kept query labels of
        each batch while ``group_ids`` holds the ids of the whole batch.
        The final ``torch.cat`` fails if every batch was skipped.
    """
    epoch_labels = []
    epoch_logits = []
    epoch_group_ids = []

    # tqdm objects do not forward attribute access to the iterable they wrap,
    # so iterate over a separate name and keep ``loader`` itself for the
    # ``batch_sampler`` lookup inside the loop.
    batches = loader
    if progress_bar:
        batches = tqdm(loader, desc=f'{"train" if train else "eval"} loop')

    restore_bn_stats = False
    if not train:
        # We need to ensure that the BN statistics are updated only for
        # the current evaluation task
        # and not reused for the next evaluation tasks
        algorithm_repr = str(algorithm)
        restore_bn_stats = ('ARM_BN' in algorithm_repr
                            or 'CXDA' in algorithm_repr)
        if restore_bn_stats:
            model_state_dict = copy.deepcopy(algorithm.model.state_dict())

    for images, labels, group_ids in batches:
        # Unpack the support and query parts
        support_images, query_images = images[0], images[1]
        support_labels, query_labels = labels[0], labels[1]
        support_group_ids, query_group_ids = group_ids[0], group_ids[1]

        support_size = algorithm.support_size
        device = algorithm.device
        batch_size = support_images.shape[0] + query_images.shape[0]

        # Skip batches that do not have correct size
        # as this can raise errors from BN statistics calculation.
        # This is checked before any tensor is built because the interleaving
        # below already fails with a shape mismatch on ragged batches.
        if batch_size % support_size != 0:
            continue

        # Even chunks of ``support_size`` rows hold support examples,
        # odd chunks hold query examples
        support_set_idx = [e for e in range(batch_size)
                           if (e // support_size) % 2 == 0]
        query_set_idx = [e for e in range(batch_size)
                         if (e // support_size) % 2 == 1]

        # torch.cat only provides a tensor of the right shape and dtype;
        # the two assignments put every row at its interleaved position
        images = torch.cat(images).to(device)
        images[support_set_idx] = support_images.to(device)
        images[query_set_idx] = query_images.to(device)

        labels = torch.cat(labels).to(device)
        labels[support_set_idx] = support_labels.to(device)
        labels[query_set_idx] = query_labels.to(device)

        group_ids = torch.cat(group_ids)
        group_ids[support_set_idx] = support_group_ids
        group_ids[query_set_idx] = query_group_ids

        # Shuffle the support set examples
        # We assume the support and query have the same size at this step
        # We will later take only one domain for the query examples
        sampler_rng = loader.batch_sampler.random_rng
        shuffle_source = sampler_rng if sampler_rng else random
        support_set_idx_shuffled = shuffle_source.sample(
            support_set_idx, images.shape[0] // 2)

        labels[support_set_idx] = labels[support_set_idx_shuffled]
        images[support_set_idx] = images[support_set_idx_shuffled]
        group_ids[support_set_idx] = group_ids[support_set_idx_shuffled]

        # Keep only the relevant part of labels - one query domain
        labels = labels[query_set_idx][:support_size]

        # Forward
        if train:
            logits, _ = algorithm.learn(images, labels, group_ids)
        else:
            logits = algorithm.predict(
                images, group_ids=group_ids, train=train)

            # During evaluation we should not reuse BN statistics updates
            # across tasks, so reset them to the statistics from training
            if restore_bn_stats:
                algorithm.model.load_state_dict(model_state_dict)

        epoch_labels.append(labels.to('cpu').clone().detach())
        epoch_logits.append(logits.to('cpu').clone().detach())
        epoch_group_ids.append(group_ids.to('cpu').clone().detach())

    return torch.cat(epoch_logits), torch.cat(epoch_labels), torch.cat(epoch_group_ids)


def train(args):
    """Train an algorithm and periodically evaluate it on the validation set.

    Loaders come from ``data.get_loaders(args)`` and the algorithm from
    ``init_algorithm``. ``args.n_groups`` is set from the training dataset as
    a side effect. Every ``args.epochs_per_eval`` epochs (starting with epoch
    0) the model is evaluated with ``eval_latent_tasks``; whenever the
    ``'val/average_acc'`` value beats the best one seen so far,
    ``saver.save(epoch, is_best=True)`` is called.

    Args:
        args: Namespace read for ``num_epochs``, ``epochs_per_eval``,
            ``progress_bar``, ``device`` and ``ckpt_dir`` (and passed on to
            the helpers called here).

    Returns:
        List with one stats dict per evaluated epoch: the dict returned by
        ``eval_latent_tasks`` extended with ``'train/average_acc'``.
    """
    # Get data
    train_loader, _, val_loader, _ = data.get_loaders(args)
    args.n_groups = train_loader.dataset.n_groups

    algorithm = init_algorithm(args, train_loader.dataset)
    saver = utils.Saver(algorithm, args.device, args.ckpt_dir)

    # Early stopping is tracked with respect to the average case.
    early_stop_metric = 'val/average_acc'
    best_avg_acc = 0
    history = []

    for epoch in trange(args.num_epochs):
        train_logits, train_labels, _ = run_epoch(
            algorithm, train_loader, train=True,
            progress_bar=args.progress_bar)

        # Nothing else to do on epochs without a scheduled evaluation
        if epoch % args.epochs_per_eval != 0:
            continue

        stats = eval_latent_tasks(
            args, algorithm, val_loader, epoch, split='val')

        current_avg_acc = stats[early_stop_metric]
        if current_avg_acc > best_avg_acc:
            best_avg_acc = current_avg_acc
            saver.save(epoch, is_best=True)

        print(f"\nEpoch: ", epoch, "\nAverage Acc: ",
              stats['val/average_acc'])

        # Accuracy over the training batches seen in this epoch
        correct = (np.argmax(train_logits, axis=1) == train_labels).numpy()
        stats['train/average_acc'] = np.mean(correct)

        history.append(stats)

    return history


def eval_latent_tasks(args, algorithm, loader, epoch=None, split='val',
                      num_eval_repetitions=5, base_seed=42):
    """
        Evaluate model on tasks with latent domains

        Each task is a latent domain adaptation problem where there are several domains.
        The domains are not labelled within the task.

        Args:
            args: Namespace; only ``args.support_size`` (the task size) is
                read here.
            algorithm: Model wrapper; ``algorithm.eval()`` is called and the
                object is passed to ``run_epoch`` with ``train=False``.
            loader: Evaluation loader. ``loader.dataset.rng`` and, when
                ``loader.batch_sampler.eval`` is truthy, the batch sampler's
                ``rng`` / ``random_rng`` are re-seeded before every pass.
            epoch: If given it is stored under ``'epoch'``; if ``None`` the
                per-task accuracies are returned under
                ``'{split}/accuracies'`` instead.
            split: Prefix used for the keys of the returned dict.
            num_eval_repetitions: Number of passes over the loader.
            base_seed: Seed of the generator that draws the per-pass seeds.

        Returns:
            Dict with worst-case, average and example-weighted accuracy over
            tasks, task counts and the evaluation time in seconds.

        Raises:
            ValueError: If ``num_eval_repetitions`` is smaller than 1 or the
                evaluation produced fewer examples than one task.
    """
    if num_eval_repetitions < 1:
        raise ValueError('num_eval_repetitions must be at least 1')

    start_time = datetime.now()

    # NOTE(review): the algorithm is left in eval mode on return; callers
    # that keep training presumably rely on algorithm.learn to switch back
    # - confirm.
    algorithm.eval()

    # By default we do 5 passes over the evaluation set of examples

    # We sample query examples from 5 domains
    # (together with support examples from the same 5 domains
    # - we sample one query and one support example in the iterator
    # and then combine them into tasks)
    # and then keep query examples from only one domain
    # as we focus on adapting to one domain rather than mixture

    # 5 passes correspond to one full pass over the evaluation set
    # if we did not discard the query examples from unused domains
    seed_rng = random.Random(base_seed)
    seeds = [seed_rng.randint(0, 999999) for _ in range(num_eval_repetitions)]
    all_logits = []
    all_labels = []
    try:
        for seed in seeds:
            # We need to be careful about how we use random seed
            # generators to ensure same validation tasks across epochs and
            # runs and same test tasks across runs
            loader.dataset.rng = random.Random(seed)
            if loader.batch_sampler.eval:
                # Reset the rng
                loader.batch_sampler.rng = np.random.RandomState(seed + 1000)
                loader.batch_sampler.random_rng = random.Random(seed + 2000)

            logits, labels, _ = run_epoch(
                algorithm, loader, train=False, progress_bar=False)
            all_logits.append(logits)
            all_labels.append(labels)
    finally:
        # Always drop the seeded generator, even if a pass failed, so that
        # later users of the dataset do not inherit the evaluation seed
        loader.dataset.rng = None

    logits = torch.cat(all_logits)
    labels = torch.cat(all_labels)

    task_size = args.support_size

    # Trailing examples that do not fill a whole task are ignored
    num_latent_tasks = logits.shape[0] // task_size
    if num_latent_tasks == 0:
        raise ValueError(
            f'evaluation produced {logits.shape[0]} examples, which is fewer '
            f'than one task of size {task_size}')

    accuracies = np.zeros(num_latent_tasks)
    num_examples = np.zeros(num_latent_tasks)

    # Calculate statistics across tasks
    for task_idx in range(num_latent_tasks):
        task_rows = slice(task_idx * task_size, (task_idx + 1) * task_size)
        current_logits = logits[task_rows, :]
        current_labels = labels[task_rows]

        preds = np.argmax(current_logits, axis=1)

        # Evaluate
        accuracies[task_idx] = np.mean((preds == current_labels).numpy())
        num_examples[task_idx] = len(current_labels)

    # Log worst, average and empirical accuracy
    worst_case_acc = np.amin(accuracies)
    worst_case_task_size = num_examples[np.argmin(accuracies)]

    # Empirical accuracy weights every task by its share of the examples
    total_size = num_examples.sum()
    props = num_examples / total_size
    empirical_case_acc = accuracies.dot(props)
    average_case_acc = np.mean(accuracies)
    average_case_acc_std = np.std(accuracies)

    end_time = datetime.now()
    eval_time = (end_time - start_time).total_seconds()

    stats = {
        f'{split}/worst_case_acc': worst_case_acc,
        f'{split}/worst_case_task_size': worst_case_task_size,
        f'{split}/average_acc': average_case_acc,
        f'{split}/average_acc_std': average_case_acc_std,
        f'{split}/total_size': total_size,
        f'{split}/num_tasks': num_latent_tasks,
        f'{split}/empirical_acc': empirical_case_acc,
        f'{split}/eval_time': eval_time
    }

    if epoch is not None:
        stats['epoch'] = epoch
    else:
        stats[f'{split}/accuracies'] = accuracies.tolist()

    return stats
