import copy
import datetime
from termcolor import colored
from tqdm import tqdm
import torch
import numpy as np
import training.utils as utils
from torch.utils.data import DataLoader
from data_utils import EnvSampler


def train_loop(train_loaders, model, opt, ep, args, train_ebd):
    """Run one training epoch over two loaders consumed in lockstep.

    Each step takes one batch from ``train_loaders[0]`` and one from
    ``train_loaders[1]``, embeds both with ``model['ebd']``, scores them with
    ``model['clf_all']`` and performs a single optimizer step on the average
    of the two losses.  Iteration stops as soon as either loader is exhausted
    (``zip`` semantics).

    Args:
        train_loaders: Indexable holding at least two iterables of batches.
            Every batch is passed through ``utils.squeeze_batch`` and
            ``utils.to_cuda`` and must then expose the keys ``'X'`` and
            ``'Y'``.
        model: Mapping with the modules ``'ebd'`` and ``'clf_all'``.
        opt: Optimizer; ``zero_grad()`` and ``step()`` are called each step.
        ep: Unused; kept for interface compatibility.
        args: Unused; kept for interface compatibility.
        train_ebd: When false, the embeddings carry no gradient, so the
            backward pass does not reach ``model['ebd']``.

    Returns:
        dict mapping ``'acc'``, ``'loss'``, ``'regret'`` and ``'loss_train'``
        to floats.  ``'acc'`` and ``'loss'`` are means over the steps of this
        epoch; ``'regret'`` and ``'loss_train'`` are never recorded here and
        are therefore always NaN (as is any key when no step was run).
    """
    stats = {k: [] for k in ('acc', 'loss', 'regret', 'loss_train')}

    # A frozen embedder does not need an autograd graph: running it with
    # gradients disabled yields the same detached features without the
    # bookkeeping.  An already-disabled outer grad mode is left untouched.
    embed_with_grad = bool(train_ebd) and torch.is_grad_enabled()

    for batch_0, batch_1 in zip(train_loaders[0], train_loaders[1]):
        # Make sure both modules are in training mode before the forward pass.
        model['ebd'].train()
        model['clf_all'].train()

        batch_0 = utils.to_cuda(utils.squeeze_batch(batch_0))
        batch_1 = utils.to_cuda(utils.squeeze_batch(batch_1))

        with torch.set_grad_enabled(embed_with_grad):
            x_0 = model['ebd'](batch_0['X'])
            x_1 = model['ebd'](batch_1['X'])
        y_0 = batch_0['Y']
        y_1 = batch_1['Y']

        acc_0, loss_0 = model['clf_all'](x_0, y_0, return_pred=False,
                                         grad_penalty=False)
        acc_1, loss_1 = model['clf_all'](x_1, y_1, return_pred=False,
                                         grad_penalty=False)

        # Both loaders contribute equally to the objective and the metric.
        loss = (loss_0 + loss_1) / 2.0
        acc = (acc_0 + acc_1) / 2.0

        opt.zero_grad()
        loss.backward()
        opt.step()

        # float() accepts Python numbers as well as one-element tensors, so
        # the aggregation below does not depend on what clf_all returns.
        stats['acc'].append(float(acc))
        stats['loss'].append(loss.item())

    # np.mean on an empty list returns NaN but also emits RuntimeWarnings;
    # produce the NaN directly for statistics that were never recorded.
    return {
        k: float(np.mean(v)) if v else float('nan')
        for k, v in stats.items()
    }


def test_loop(test_loader, model, ep, args, att_idx_dict=None):
    """Evaluate the model on every batch produced by ``test_loader``.

    Each batch is passed through ``utils.squeeze_batch`` and
    ``utils.to_cuda``, embedded with ``model['ebd']`` and classified with
    ``model['clf_all']`` (called with ``return_pred=True``).

    Args:
        test_loader: Iterable of batches exposing ``'X'`` and ``'Y'`` (and
            ``'idx'`` when ``att_idx_dict`` is given).
        model: Mapping with the modules ``'ebd'`` and ``'clf_all'``.
        ep: Unused; kept for interface compatibility.
        args: Unused; kept for interface compatibility.
        att_idx_dict: Optional object forwarded to ``utils.get_worst_acc``.

    Returns:
        The result of ``utils.get_worst_acc`` when ``att_idx_dict`` is not
        ``None``; otherwise a dict with the overall ``'acc'`` and the mean
        per-batch ``'loss'``.
    """
    want_idx = att_idx_dict is not None

    batch_losses = []
    labels, predictions = [], []
    sample_idx = [] if want_idx else None

    for raw_batch in test_loader:
        # Evaluation mode for both modules before the forward pass.
        model['ebd'].eval()
        model['clf_all'].eval()

        batch = utils.to_cuda(utils.squeeze_batch(raw_batch))

        features = model['ebd'](batch['X'])
        targets = batch['Y']
        y_hat, batch_loss = model['clf_all'](features, targets,
                                             return_pred=True)

        labels.append(targets)
        predictions.append(y_hat)
        if want_idx:
            sample_idx.append(batch['idx'])
        batch_losses.append(batch_loss.item())

    # Flatten the per-batch results into one tensor each.
    true = torch.cat(labels)
    pred = torch.cat(predictions)

    acc = (true == pred).float().mean().item()
    loss = np.array(batch_losses).mean()

    if want_idx:
        return utils.get_worst_acc(true, pred, sample_idx, loss, att_idx_dict)

    return {'acc': acc, 'loss': loss}


def _erm_env_loader(dataset, num_batches, batch_size, env_id, idx_list):
    """Wrap ``dataset`` in a DataLoader whose batches come from EnvSampler."""
    return DataLoader(
        dataset,
        sampler=EnvSampler(num_batches, batch_size, env_id, idx_list),
        num_workers=10)


def erm(train_data, test_data, model, opt, args, train_ebd=True):
    """Train ``model`` with early stopping and evaluate the best checkpoint.

    Two training loaders are built from environment 0 of ``train_data`` and
    consumed in lockstep by ``train_loop``.  After every epoch the model is
    validated on environment 0 of ``test_data``.  The checkpoint with the
    highest ``min(train acc, val acc)`` is kept; training stops once that
    score has failed to improve for ``args.patience`` consecutive epochs.
    The best checkpoint is then restored and evaluated on environment 3 of
    ``test_data`` together with ``test_data.test_att_idx_dict``.

    Args:
        train_data: Dataset exposing ``envs[0]['idx_list']`` and
            ``get_all_y(0)``.
        test_data: Dataset exposing ``envs[i]['idx_list']`` for ``i`` in
            ``0..3`` and ``test_att_idx_dict``.
        model: Mapping with the modules ``'ebd'`` and ``'clf_all'``.
        opt: Optimizer handed to ``train_loop``.
        args: Namespace providing ``balance``, ``num_batches``,
            ``batch_size``, ``num_epochs`` and ``patience``.
        train_ebd: Forwarded to ``train_loop``.

    Returns:
        dict with ``'train'`` and ``'val'`` (statistics of the best epoch, or
        ``None`` if no epoch ever improved on the initial score) and
        ``'test'`` (result of the final evaluation).
    """
    if args.balance:
        # Balanced sampling: split environment 0 by label so every training
        # step pairs one batch of positives with one batch of negatives.
        pos_idx, neg_idx = [], []
        for i, y in zip(train_data.envs[0]['idx_list'],
                        train_data.get_all_y(0)):
            (neg_idx if y < 0.5 else pos_idx).append(i)
        train_idx_lists = [pos_idx, neg_idx]
    else:
        # NOTE(review): both loaders draw from environment 0; an earlier
        # variant passed the loop index as the environment id. Confirm that
        # sampling the same environment twice is intended.
        train_idx_lists = [train_data.envs[0]['idx_list']] * 2

    train_loaders = [
        _erm_env_loader(train_data, args.num_batches, args.batch_size, 0,
                        idx_list)
        for idx_list in train_idx_lists
    ]

    test_loaders = [
        _erm_env_loader(test_data, -1, args.batch_size, i,
                        test_data.envs[i]['idx_list'])
        for i in range(4)
    ]

    # start training
    best_acc = -1
    best_train_res = None
    best_val_res = None
    best_model = {}
    cycle = 0
    ep = -1  # keeps `ep` defined for the final evaluation if no epoch runs
    for ep in range(args.num_epochs):
        train_res = train_loop(train_loaders, model, opt, ep, args, train_ebd)

        with torch.no_grad():
            # validation
            val_res = test_loop(test_loaders[0], model, ep, args, None)

        utils.print_res(train_res, val_res, ep)

        # Model selection uses the weaker of the two accuracies.
        score = min(train_res['acc'], val_res['acc'])
        if score > best_acc:
            best_acc = score
            best_val_res = val_res
            best_train_res = train_res
            cycle = 0
            # snapshot the weights of the best epoch
            for k in ('ebd', 'clf_all'):
                best_model[k] = copy.deepcopy(model[k].state_dict())
        else:
            cycle += 1

        if cycle == args.patience:
            break

    # Restore the best checkpoint.  No snapshot exists when no epoch ran or
    # no epoch beat the initial score (e.g. NaN accuracies); in that case the
    # current weights are evaluated instead of failing with a KeyError.
    if best_model:
        for k in ('ebd', 'clf_all'):
            model[k].load_state_dict(best_model[k])

    # Final evaluation needs no gradients, same as validation above.
    with torch.no_grad():
        test_res = test_loop(test_loaders[3], model, ep, args,
                             test_data.test_att_idx_dict)

    # Report the statistics of the best epoch, not of the last one.
    print('Best train')
    print(best_train_res)
    print('Best val')
    print(best_val_res)
    print('Test')
    print(test_res)

    return {
        'train': best_train_res,
        'val': best_val_res,
        'test': test_res,
    }
