import copy
import datetime
from termcolor import colored
from tqdm import tqdm
import torch
import random
import numpy as np
import training.utils as utils
from torch.utils.data import DataLoader
from data_utils import EnvSampler


def compute_l2(XS, XQ):
    '''
        Compute the pairwise squared l2 distance
        @param XS (support x): support_size x ebd_dim
        @param XQ (query x): query_size x ebd_dim

        @return dist: query_size x support_size, where
            dist[q, s] = ||XS[s] - XQ[q]||^2

    '''
    # (1, S, D) - (Q, 1, D) broadcasts to (Q, S, D)
    diff = XS.unsqueeze(0) - XQ.unsqueeze(1)
    # Sum the squared differences directly instead of taking a norm and
    # squaring it: same value, no redundant sqrt, and no sqrt in the
    # backward pass where a distance is exactly zero (e.g. XS vs itself).
    return (diff ** 2).sum(dim=2)


def train_loop(train_loaders, model, opt, ep, args, train_ebd):
    '''
        Run one epoch of metric-learning training.

        Each step draws one batch from every loader in train_loaders (in
        lockstep), shuffles them, and uses the first as the "positive" group
        and the second as the "negative" group. The loss is a hinge on
        squared l2 distances:
            mean(max(0, d(pos, pos) - d(pos, neg) + args.thres))

        @param train_loaders: list of at least two loaders
        @param model: dict holding the 'ebd' and 'clf_all' modules
        @param opt: optimizer stepped after every batch
        @param ep: epoch index (not used here)
        @param args: must provide args.thres (hinge margin)
        @param train_ebd: not used here; kept for interface compatibility

        @return stats: dict with 'acc', 'loss', 'regret', 'loss_train'
            averaged over the epoch; keys with no recorded values are nan
    '''
    stats = {k: [] for k in ['acc', 'loss', 'regret', 'loss_train']}

    # Nothing inside the loop changes the modules' mode, so set it once.
    model['ebd'].train()
    model['clf_all'].train()

    for batches in zip(*train_loaders):
        # sample from the two env equally: randomly pick which group plays
        # the positive role in this step
        batches = list(batches)
        random.shuffle(batches)

        batch_pos = utils.to_cuda(utils.squeeze_batch(batches[0]))
        batch_neg = utils.to_cuda(utils.squeeze_batch(batches[1]))

        ebd_pos = model['ebd'](batch_pos['X'])
        ebd_neg = model['ebd'](batch_neg['X'])

        diff_pos_pos = compute_l2(ebd_pos, ebd_pos)
        diff_pos_neg = compute_l2(ebd_pos, ebd_neg)

        # Hinge loss. The subtraction needs compatible shapes, i.e. both
        # batches the same size (presumably args.batch_size — confirm).
        loss = torch.mean(torch.clamp(
            diff_pos_pos - diff_pos_neg + args.thres, min=0))

        opt.zero_grad()
        loss.backward()
        opt.step()

        stats['acc'].append(0.0)  # placeholder: accuracy is not computed
        stats['loss'].append(loss.item())

    for k, v in stats.items():
        # np.mean([]) is also nan, but emits a RuntimeWarning every epoch
        stats[k] = float(np.mean(np.array(v))) if v else float('nan')

    return stats


def test_loop(test_loaders, model, ep, args, att_idx_dict=None):
    '''
        Evaluate the metric-learning hinge loss over paired batches.

        Batches from test_loaders[0] are the positive group and batches from
        test_loaders[1] the negative group (no shuffling, unlike training).
        The loss per batch is
            mean(max(0, d(pos, pos) - d(pos, neg) + args.thres))
        on squared l2 distances.

        @param test_loaders: list of at least two loaders; only the first
            two are used
        @param model: dict holding the 'ebd' and 'clf_all' modules
        @param ep: epoch index (not used here)
        @param args: must provide args.thres (hinge margin)
        @param att_idx_dict: not used here; kept for interface compatibility

        @return dict with 'acc' (always 0.0, a placeholder) and 'loss' (mean
            batch loss; nan if there were no batches)
    '''
    model['ebd'].eval()
    model['clf_all'].eval()

    loss_list = []
    # Evaluation never backpropagates. Disabling autograd here means callers
    # that are not already inside torch.no_grad() do not build a graph.
    with torch.no_grad():
        for batch_pos, batch_neg in zip(test_loaders[0], test_loaders[1]):
            batch_pos = utils.to_cuda(utils.squeeze_batch(batch_pos))
            batch_neg = utils.to_cuda(utils.squeeze_batch(batch_neg))

            ebd_pos = model['ebd'](batch_pos['X'])
            ebd_neg = model['ebd'](batch_neg['X'])

            diff_pos_pos = compute_l2(ebd_pos, ebd_pos)
            diff_pos_neg = compute_l2(ebd_pos, ebd_neg)

            loss = torch.mean(torch.clamp(
                diff_pos_pos - diff_pos_neg + args.thres, min=0))
            loss_list.append(loss.item())

    return {
        'acc': 0.0,
        'loss': np.mean(np.array(loss_list)),
    }


def get_balance_loader(data, env_id, args, num_workers=10):
    '''
        Build one loader per label value in environment env_id.

        Indices from data.envs[env_id]['idx_list'] are grouped by the
        matching label from data.get_all_y(env_id). Each group gets its own
        DataLoader driven by an EnvSampler, so every loader draws from a
        single label (balanced sampling across labels).

        @param data: dataset exposing envs[env_id]['idx_list'] and
            get_all_y(env_id)
        @param env_id: environment to draw from
        @param args: must provide args.num_batches and args.batch_size
        @param num_workers: worker processes per loader (default 10, the
            previously hard-coded value)

        @return loaders: list of DataLoader, ordered by first appearance of
            each label
    '''
    # Group dataset indices by label. Labels are dict keys, so they must
    # hash by value. NOTE(review): torch.Tensor hashes by identity; confirm
    # get_all_y yields python scalars.
    y_dict = {}
    for i, y in zip(data.envs[env_id]['idx_list'], data.get_all_y(env_id)):
        y_dict.setdefault(y, []).append(i)

    return [
        DataLoader(
            data,
            sampler=EnvSampler(args.num_batches, args.batch_size, env_id, v),
            num_workers=num_workers)
        for v in y_dict.values()
    ]


def metric_learning(train_data, test_data, model, opt, args, train_ebd=True):
    '''
        Train with the metric-learning objective, early-stop on validation
        loss, then evaluate the best checkpoint.

        Training uses env 0 of train_data; validation and test use envs 2
        and 3 of test_data. The 'ebd' and 'clf_all' weights from the epoch
        with the lowest validation loss are restored before testing.
        Training stops after args.patience consecutive epochs without
        improvement.

        @return ({'train': ..., 'val': ..., 'test': ...}, None, None), where
            'train' and 'val' are the results of the best validation epoch
            (or of the last epoch if no epoch produced a comparable loss)
    '''
    train_loaders = get_balance_loader(train_data, 0, args)
    val_loaders = get_balance_loader(test_data, 2, args)
    test_loaders = get_balance_loader(test_data, 3, args)

    # start training
    best_loss = float('inf')  # any finite validation loss is an improvement
    best_val_res = None
    best_train_res = None
    best_model = {}
    train_res = val_res = None
    cycle = 0
    ep = 0  # keeps ep defined even if args.num_epochs == 0
    for ep in range(args.num_epochs):
        train_res = train_loop(train_loaders, model, opt, ep, args, train_ebd)

        with torch.no_grad():
            # validation
            val_res = test_loop(val_loaders, model, ep, args, None)

        utils.print_res(train_res, val_res, ep)

        if val_res['loss'] < best_loss:
            best_loss = val_res['loss']
            best_val_res = val_res
            best_train_res = train_res
            cycle = 0
            # save best ebd
            for k in 'ebd', 'clf_all':
                best_model[k] = copy.deepcopy(model[k].state_dict())
        else:
            cycle += 1

        if cycle == args.patience:
            break

    if best_model:
        # load best model
        for k in 'ebd', 'clf_all':
            model[k].load_state_dict(best_model[k])
    else:
        # No epoch ran, or validation loss was never comparable (e.g. nan):
        # keep the current weights and report the last epoch instead of
        # failing on a missing checkpoint.
        best_train_res, best_val_res = train_res, val_res

    # get the results; evaluation only, so do not build an autograd graph
    with torch.no_grad():
        test_res = test_loop(test_loaders, model, ep, args,
                             test_data.test_att_idx_dict)
    print('Best train')
    print(best_train_res)
    print('Best val')
    print(best_val_res)
    print('Test')
    print(test_res)

    return {
        'train': best_train_res,
        'val': best_val_res,
        'test': test_res,
    }, None, None
