import argparse
import numpy as np
import numpy.lib as npl
import torch
from torch import nn, optim, autograd
import random

from load import get_years, YEARS
from utils import *
from densratio import densratio

# Command-line configuration. Parsing happens at import time and the result is
# stored in the module-level `flags`, which MLP and the training script read.
parser = argparse.ArgumentParser(description='Colored MNIST')
# Width of the two hidden layers of MLP.
parser.add_argument('--hidden_dim', type=int, default=256)
# Multiplier for the squared parameter norm added to the training loss.
parser.add_argument('--l2_regularizer_weight', type=float,default=0.001)
# Adam learning rate.
parser.add_argument('--lr', type=float, default=0.1)
# Number of independent training runs, and optimisation steps per run.
parser.add_argument('--n_restarts', type=int, default=1)
# NOTE(review): not read anywhere in the visible part of this file.
parser.add_argument('--penalty_anneal_iters', type=int, default=1)
parser.add_argument('--steps', type=int, default=500)
# Call plot(logger) after each restart.
parser.add_argument('--plot', action='store_true')
# NOTE(review): only referenced by commented-out code at the end of the file.
parser.add_argument('--save', type=str, default='')
# Comma-separated integer environment ids.
parser.add_argument('--train_envs', type=str, default='2014,2015,2016')
parser.add_argument('--test_envs', type=str, default='2017,2018')
# NOTE(review): not read anywhere in the visible part of this file.
parser.add_argument('--psel', type=float, default=0.5)
# Concentration parameters of the Beta distribution sampled by mix_up.
parser.add_argument('--alpha', type=float, default=0.5)
parser.add_argument('--beta', type=float, default=0.5)
# Parsed as int so the value has the same type whether or not the flag is
# given (it used to be type=str with an int default).
# NOTE(review): not read anywhere in the visible part of this file.
parser.add_argument('--BATCH_SIZE', type=int, default=100)
flags = parser.parse_args()

train_env_ids = [int(s.strip()) for s in flags.train_envs.split(',')]
if flags.test_envs:
    test_env_ids = [int(s.strip()) for s in flags.test_envs.split(',')]
else:
    # Empty --test_envs: use the ids that appear in exactly one of YEARS and
    # train_env_ids. np.setxor1d is the public spelling (the numpy.lib alias
    # is not available in NumPy 2), and .tolist() yields plain ints like the
    # branch above.
    test_env_ids = np.setxor1d(YEARS, train_env_ids).tolist()


print('Flags:')
for k, v in sorted(vars(flags).items()):
    print("    {}: {}".format(k, v))

def mix_up(environments, flags):
    """Build a mix-up dataset from index-aligned sample pairs.

    Every ordered pair of environments (an environment is also paired with
    itself) is walked in lock-step: sample ``k`` of the first is compared with
    sample ``k`` of the second, up to the shorter of the two. A pair is mixed
    when

    * both samples come from the same environment and their labels differ by
      more than 0.1, or
    * they come from different environments and their labels differ by at
      most 0.1.

    A mixed sample is ``lam * a + (1 - lam) * b`` for images and labels alike,
    with ``lam`` drawn from ``Beta(flags.alpha, flags.beta)``.

    Args:
        environments: iterable of dicts with ``'images'`` and ``'labels'``
            tensors indexed by sample along the first axis. Each per-sample
            label must hold exactly one element (``.item()`` is called on the
            label difference).
        flags: object exposing ``alpha`` and ``beta``.

    Returns:
        ``(x_update, y_update)``: the mixed images and labels concatenated
        along dim 0, in the order they were generated.

    Raises:
        ValueError: if no pair satisfies either rule, so there is nothing to
            concatenate.
    """
    all_data = [(e['images'], e['labels'], k) for k, e in enumerate(environments)]
    # Loop-invariant: one distribution object serves every draw.
    mix_dist = torch.distributions.Beta(flags.alpha, flags.beta)

    x_mixed = []
    y_mixed = []

    for x_i, y_i, e_i in all_data:
        for x_j, y_j, e_j in all_data:
            same_env = e_i == e_j
            # Visit every aligned index. The previous version removed entries
            # from the index lists it was zipping over, which made the
            # iterator silently skip every odd index.
            n_pairs = min(x_i.shape[0], x_j.shape[0])
            for idx in range(n_pairs):
                distance = (y_i[idx] - y_j[idx]).abs().item()
                # NOTE(review): with index-aligned pairing a same-environment
                # pair is a sample paired with itself, so its distance is 0
                # and the first rule cannot fire. Confirm whether a different
                # pairing was intended.
                should_mix = distance > .1 if same_env else distance <= .1
                if not should_mix:
                    continue
                lambda_mix = mix_dist.sample().item()
                x_lambda = lambda_mix * x_i[idx] + (1 - lambda_mix) * x_j[idx]
                y_lambda = lambda_mix * y_i[idx] + (1 - lambda_mix) * y_j[idx]
                x_mixed.append(x_lambda.unsqueeze(0))
                y_mixed.append(y_lambda.unsqueeze(0))

    if not x_mixed:
        # torch.cat rejects an empty list with a message that does not point
        # at the real cause; fail with an explicit one instead.
        raise ValueError("mix_up: no sample pair satisfied the mixing criteria")

    x_update = torch.cat(x_mixed, dim=0)
    y_update = torch.cat(y_mixed, dim=0)
    print("x_update",x_update.shape)
    print("y_update",y_update.shape)

    return x_update, y_update

    
def whiten(x):
    """Standardise ``x`` in place and return it.

    The global mean (over all elements) is subtracted, then the result is
    divided by its global ``std()``. Both steps modify the argument's storage
    and run without autograd tracking; the returned object is the tensor that
    was passed in.
    """
    with torch.no_grad():
        x.sub_(x.mean())
        x.div_(x.std())
    return x

def mean_nll(logits, y):
    """Return the mean squared error between ``logits`` and ``y``.

    Despite the name, this is an MSE with ``reduction='mean'``, not a
    negative log-likelihood.
    """
    return nn.functional.mse_loss(logits, y, reduction='mean')

def mean_accuracy(logits, y):
    """Return the mean absolute gap between thresholded outputs and ``y``.

    Outputs are turned into 0/1 predictions with ``logits > 0``; the result is
    the mean of ``|prediction - y|`` as a float32 scalar tensor. Note that
    this is a distance, not a fraction of correct predictions: for 0/1 labels
    a value of 0 means every prediction matched.
    """
    hard_preds = torch.gt(logits, 0.).float()
    abs_gap = torch.abs(hard_preds - y)
    return abs_gap.float().mean()

class MLP(nn.Module):
    """Three-layer fully connected network with a single scalar output.

    Layout: Linear -> ReLU -> Dropout -> Linear -> ReLU -> Dropout -> Linear.
    The hidden width is read from the module-level ``flags.hidden_dim``; the
    linear layers keep PyTorch's default initialisation.
    """

    def __init__(self, input_size):
        """Create the network for inputs with ``input_size`` features."""
        super().__init__()
        self.input_size = input_size
        width = flags.hidden_dim
        # The linear layers are instantiated first, in network order.
        first = nn.Linear(input_size, width)
        second = nn.Linear(width, width)
        head = nn.Linear(width, 1)
        layers = [
            first, nn.ReLU(inplace=True), nn.Dropout(),
            second, nn.ReLU(inplace=True), nn.Dropout(),
            head,
        ]
        self._main = nn.Sequential(*layers)

    def forward(self, x):
        """Flatten ``x`` to ``(x.shape[0], input_size)`` and run the network."""
        flat = x.view(x.shape[0], self.input_size)
        return self._main(flat)

def Q_CRIC(env):
    """Sum squared gaps between each environment's mean output and the
    density-ratio-weighted mean output of every other environment.

    For every ordered pair ``(i, j)`` with ``i != j``:

    1. fit ``densratio`` on the images of ``i`` and ``j``,
    2. evaluate the fitted ratio on the images of ``j``,
    3. take the mean of ``ratio * logits_j``,
    4. add the squared difference to the mean of ``logits_i``.

    Args:
        env: list of dicts, each with ``'images'`` and ``'logits'`` tensors.
            Images are converted with ``.numpy()``.

    Returns:
        The accumulated value as a tensor detached from autograd, or the int
        ``0`` when ``env`` has fewer than two environments.

    Raises:
        TypeError: if ``env`` is not a list.
    """
    if not isinstance(env, list):
        raise TypeError("Q_CRIC expects a list of environments, got {}".format(type(env).__name__))

    Q_num = 0
    for i, train_e in enumerate(env):
        # Detach up front so nothing accumulated here is tied to the graph.
        q_e = torch.mean(train_e['logits'].detach())
        x_np = None  # converted lazily so a single environment needs no conversion
        for j, train_t in enumerate(env):
            if j == i:
                continue
            if x_np is None:
                x_np = train_e['images'].numpy()
            x_t_np = train_t['images'].numpy()
            dr_co = densratio(x_np, x_t_np)

            w_co = torch.from_numpy(dr_co.compute_density_ratio(x_t_np))
            # Flatten both factors so the product is per-sample. Multiplying
            # 1-D weights by (n, 1) logits would broadcast to an (n, n) outer
            # product, whose mean is mean(w) * mean(logits) rather than the
            # weighted mean.
            # NOTE(review): assumes one weight and one logit per sample of
            # the second environment - confirm against densratio's output.
            logits_t = train_t['logits'].detach().reshape(-1)
            q_e_co = torch.mean(w_co.reshape(-1) * logits_t)  # weighted q

            Q_num = Q_num + (q_e_co - q_e) ** 2

    return Q_num

# Training / evaluation script: for each restart, load and preprocess the
# environments, train a fresh MLP with Adam on the mean training loss plus an
# L2 penalty, and record metrics through Logger.
def _score_env(model, env):
    """Run ``model`` on one environment and store logits, loss and 'acc' in it."""
    assert not torch.isnan(env['images']).any()
    env['logits'] = model(env['images'])
    env['nll'] = mean_nll(env['logits'], env['labels'])
    env['acc'] = mean_accuracy(env['logits'], env['labels'])


final_train_accs = []
final_test_accs = []
logs = []
for restart in range(flags.n_restarts):
    print("Restart", restart)

    train_envs = get_years(train_env_ids)
    #print("train_envs",train_envs)
    test_envs = get_years(test_env_ids)
    # preprocess
    # NOTE(review): mix_up(train_envs, flags) is called once per environment,
    # test environments included, and its result overwrites that environment's
    # images and labels. Test environments therefore end up holding data
    # derived from the training environments, and later calls see training
    # environments already replaced by earlier iterations. Left unchanged
    # because the intended protocol is not visible here - please confirm.
    for e in train_envs + test_envs:
        e['images'], e['labels'] = mix_up(train_envs, flags)
        e['images'] = whiten(e['images'])
        e['labels'] = whiten(e['labels'])

    # init
    logger = Logger()
    mlp = MLP(train_envs[0]['images'].shape[1])
    optimizer = optim.Adam(mlp.parameters(), lr=flags.lr)

    print_env_info(train_envs, test_envs)

    pretty_print('step', 'train nll', 'train acc', 'test acc')

    for step in range(flags.steps):
        # Training environments: train mode (Dropout active), gradients on.
        for env in train_envs:
            _score_env(mlp, env)

        # Test environments are only measured, never optimised: score them in
        # eval mode (Dropout disabled) and outside the autograd graph, using
        # the same pre-update weights as the training metrics of this step.
        mlp.eval()
        with torch.no_grad():
            for env in test_envs:
                _score_env(mlp, env)
        mlp.train()

        train_nll = torch.stack([e['nll'] for e in train_envs]).mean()
        train_acc = torch.stack([e['acc'] for e in train_envs]).mean()

        # Squared L2 norm of all parameters.
        weight_norm = torch.tensor(0.)
        for w in mlp.parameters():
            weight_norm += w.norm().pow(2)

        loss = train_nll.clone()
        loss += flags.l2_regularizer_weight * weight_norm

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        logger.log('train_nll', train_nll)
        logger.log('train_acc', train_acc)
        logger.log('test_acc', [e['acc'] for e in test_envs])
        logger.log('losses', [e['nll'] for e in train_envs])

        if step % 100 == 0:
            print_stats(step, logger)

    # Summarise the run with the mean of the last 50 logged values.
    final_train_accs.append(np.mean(logger['train_acc'][-50:]))
    final_test_accs.append(np.mean(logger['test_acc'][-50:]))
    print('Final train acc (mean/std across restarts so far):')
    print(np.mean(final_train_accs), np.std(final_train_accs))
    print('Final test acc (mean/std across restarts so far):')
    print(np.mean(final_test_accs), np.std(final_test_accs))

    logs.append(logger)

    if flags.plot:
        plot(logger)

#if flags.save:
#    save(logs, 'results/%s_%s_%s' % (flags.save, ','.join([str(e) for e in train_env_ids]), ','.join([str(e) for e in test_env_ids])))
# Uses the environments and logits left over from the last restart; skipped
# when no restart ran, because train_envs/test_envs would not exist.
if logs:
    Q = Q_CRIC(train_envs)
    Q_t = Q_CRIC(test_envs)
    print("Q:", Q)
    print("Q_t:", Q_t)
