import os
import sys
import csv
import wget
import zipfile
import numpy as np
import pandas as pd
import torch
from torch import nn
import torchvision
from torchvision import transforms, datasets
from torch import distributions
from torch.utils.data import TensorDataset, DataLoader
import ssl

# from models import self_mask
from tqdm import tqdm


def endtoend_train(flow, batch_EM, nf_optimizer, loader, args):
    """Run one pass over ``loader``, updating ``flow`` twice per batch.

    For every batch the function:

    1. evaluates ``flow.log_prob`` on the incomplete input and takes an
       optimizer step on the returned scalar;
    2. hands the detached latent codes and the mask to
       ``batch_EM.complete_gpu`` (``mode='train'``) and installs the prior it
       returns via ``flow.update_prior``;
    3. maps the completed latents back with ``flow.inverse`` and takes a
       second optimizer step on ``log_prob(x_hat) + args.alpha * mse``, where
       the squared error is weighted by ``1 - mask``.

    Args:
        flow: model exposing ``log_prob``, ``inverse``, ``update_prior`` and
            ``parameters``.
        batch_EM: object exposing ``complete_gpu(z, mask, mode=...)``.
        nf_optimizer: optimizer stepped after each backward pass.
        loader: iterable of ``(x_dot, x_origin, mask)`` batches that also has
            a ``dataset`` attribute supporting ``len``.
        args: namespace read for ``grad_clip`` and ``alpha``.

    Returns:
        Tuple of three floats, each a batch-size-weighted sum divided by
        ``len(loader.dataset)``: the first-pass ``log_prob`` scalar, the
        ``alpha``-scaled masked MSE, and the second-pass ``log_prob`` scalar.
    """
    criterion = nn.MSELoss(reduction='none')

    first_pass_total = 0.
    weighted_mse_total = 0.
    second_pass_total = 0.

    def _apply_update(objective):
        # Backpropagate, optionally clip the gradient norm, step, then clear grads.
        objective.backward()
        if args.grad_clip:
            torch.nn.utils.clip_grad_norm_(flow.parameters(), args.grad_clip)
        nf_optimizer.step()
        nf_optimizer.zero_grad()

    for x_dot, x_origin, mask in tqdm(loader):
        x_dot, x_origin, mask = x_dot.cuda(), x_origin.cuda(), mask.cuda()
        batch_size = x_dot.shape[0]

        # First update: the scalar returned by log_prob on the incomplete input.
        z, log_p = flow.log_prob(x_dot)
        first_pass_total += log_p.item() * batch_size
        _apply_update(log_p)

        # Latents are detached, so the completion step receives no autograd graph.
        z_hat, new_prior = batch_EM.complete_gpu(z.detach(), mask, mode='train')
        flow.update_prior(new_prior)

        # Second update: reconstruction from the completed (detached) latents.
        x_hat = flow.inverse(z_hat.detach())
        _, log_p = flow.log_prob(x_hat)
        mse_loss = torch.mean(criterion(x_hat, x_origin) * (1 - mask))
        total_loss = log_p + args.alpha * mse_loss

        weighted_mse_total += args.alpha * mse_loss.item() * batch_size
        second_pass_total += log_p.item() * batch_size
        _apply_update(total_loss)

    n_samples = len(loader.dataset)
    return first_pass_total/n_samples, weighted_mse_total/n_samples, second_pass_total/n_samples



def endtoend_train_superbatch(flow, batch_EM, nf_optimizer, nf_scheduler=None, loader=None, args=None):
    """Run one training pass using a sliding "super batch" for the EM step.

    Works like a two-update-per-batch loop (one optimizer step on
    ``flow.log_prob(x_dot)``, one on ``log_prob(x_hat) + args.alpha * mse``),
    except that ``batch_EM.complete_gpu`` is called on a moving window of
    latent codes and masks gathered over recent batches. The window keeps at
    most the newest ``args.super_size`` rows once a second batch has been
    appended; the completed latents for the current batch are taken from the
    tail of the window.

    Args:
        flow: model exposing ``train``, ``log_prob``, ``inverse``,
            ``update_prior`` and ``parameters``.
        batch_EM: object exposing ``complete_gpu(z, mask, mode=...)``.
        nf_optimizer: optimizer stepped after each backward pass.
        nf_scheduler: optional scheduler; stepped once after the loop when truthy.
        loader: iterable of ``(x_dot, x_origin, mask)`` batches with ``len``
            and a ``dataset`` attribute. Must be provided despite the default.
        args: namespace read for ``grad_clip``, ``alpha`` and ``super_size``.
            Must be provided despite the default.

    Returns:
        Tuple of three floats, each a batch-size-weighted sum divided by
        ``len(loader.dataset)``: the first-pass ``log_prob`` scalar, the
        ``alpha``-scaled masked MSE, and the second-pass ``log_prob`` scalar.
    """
    criterion = nn.MSELoss(reduction='none')

    first_pass_total = 0.
    weighted_mse_total = 0.
    second_pass_total = 0.

    z_window = None
    mask_window = None

    def _apply_update(objective):
        # Backpropagate, optionally clip the gradient norm, step, then clear grads.
        objective.backward()
        if args.grad_clip:
            torch.nn.utils.clip_grad_norm_(flow.parameters(), args.grad_clip)
        nf_optimizer.step()
        nf_optimizer.zero_grad()

    def _push(window, batch):
        # The first batch seeds the window as-is (no trimming); later batches
        # are appended detached and only the newest super_size rows are kept.
        if window is None:
            return batch
        window = torch.cat((window, batch.detach()), 0)
        if window.shape[0] > args.super_size:
            window = window[-args.super_size:]
        return window

    flow.train()
    for x_dot, x_origin, mask in tqdm(loader, total=len(loader)):
        x_dot, x_origin, mask = x_dot.cuda(), x_origin.cuda(), mask.cuda()
        batch_size = x_dot.shape[0]

        # First update: the scalar returned by log_prob on the incomplete input.
        z, log_p = flow.log_prob(x_dot)
        first_pass_total += log_p.item() * batch_size
        _apply_update(log_p)

        z_window = _push(z_window, z.detach())
        mask_window = _push(mask_window, mask)

        z_hat_window, new_prior = batch_EM.complete_gpu(z_window, mask_window, mode='train')
        flow.update_prior(new_prior)
        # The current batch occupies the last rows of the window.
        z_hat = z_hat_window[-batch_size:]

        # Every mask entry kept in the window is reset to zero, so on the next
        # iteration only the freshly appended batch carries non-zero entries.
        # NOTE(review): presumably this makes earlier rows count as fully
        # observed in the EM step — confirm against batch_EM.complete_gpu.
        mask_window = torch.zeros_like(mask_window)

        # Second update: reconstruction from the completed (detached) latents.
        x_hat = flow.inverse(z_hat.detach())
        _, log_p = flow.log_prob(x_hat)
        mse_loss = torch.mean(criterion(x_hat, x_origin) * (1 - mask))
        total_loss = log_p + args.alpha * mse_loss

        weighted_mse_total += args.alpha * mse_loss.item() * batch_size
        second_pass_total += log_p.item() * batch_size
        _apply_update(total_loss)

    if nf_scheduler:
        nf_scheduler.step()

    n_samples = len(loader.dataset)
    return first_pass_total/n_samples, weighted_mse_total/n_samples, second_pass_total/n_samples


def endtoend_test(flow, batch_EM, data_loader, args):
    """Compute the mask-weighted squared error of the reconstructions.

    Each batch is encoded with ``flow.log_prob``, completed with
    ``batch_EM.complete_gpu`` in ``mode='test'`` (which here returns a single
    value), decoded with ``flow.inverse`` and clamped to ``[0, 1]``. The
    element-wise squared error against ``x_origin`` is multiplied by ``mask``
    and summed; the grand total is divided by the sum of all mask entries.
    Everything runs under ``torch.no_grad()``.

    Args:
        flow: model exposing ``log_prob`` and ``inverse``.
        batch_EM: object exposing ``complete_gpu(z, mask, mode=...)``.
        data_loader: iterable of ``(x_dot, x_origin, mask)`` batches.
        args: unused by this function.

    Returns:
        Summed masked squared error divided by the summed mask.
    """
    criterion = nn.MSELoss(reduction='none')
    squared_error_total = 0.
    mask_total = 0.

    with torch.no_grad():
        for x_dot, x_origin, mask in data_loader:
            x_dot, x_origin, mask = x_dot.cuda(), x_origin.cuda(), mask.cuda()

            z, _ = flow.log_prob(x_dot)
            z_hat = batch_EM.complete_gpu(z, mask, mode='test')
            # Reconstructions are clipped to [0, 1] before scoring.
            x_hat = torch.clamp(flow.inverse(z_hat), min=0, max=1)

            # Only positions with a non-zero mask contribute to the error.
            masked_error = torch.sum(criterion(x_hat, x_origin) * mask)
            mask_total += np.sum(mask.cpu().numpy())
            squared_error_total += masked_error.item()

    return squared_error_total/mask_total

def init_flow_model(num_neurons, num_layers, init_flow, data_shape):
    """Build a flow model on the GPU from two MLP factories, masks and a prior.

    Args:
        num_neurons: hidden width of every MLP layer.
        num_layers: number of mask pairs to draw; each draw contributes a
            mask and its complement, giving ``2 * num_layers`` masks.
        init_flow: callable invoked as ``init_flow(nets, nett, masks, prior)``
            that returns the flow object.
        data_shape: integer feature dimension of the data.

    Returns:
        The object returned by ``init_flow`` after ``.cuda()`` was called on it.
    """

    def nets():
        # Four-layer MLP whose output passes through Tanh.
        return nn.Sequential(
            nn.Linear(data_shape, num_neurons), nn.LeakyReLU(),
            nn.Linear(num_neurons, num_neurons), nn.LeakyReLU(),
            nn.Linear(num_neurons, num_neurons), nn.LeakyReLU(),
            nn.Linear(num_neurons, data_shape), nn.Tanh(),
        )

    def nett():
        # Same four-layer MLP but with no activation on the output.
        return nn.Sequential(
            nn.Linear(data_shape, num_neurons), nn.LeakyReLU(),
            nn.Linear(num_neurons, num_neurons), nn.LeakyReLU(),
            nn.Linear(num_neurons, num_neurons), nn.LeakyReLU(),
            nn.Linear(num_neurons, data_shape),
        )

    # Each random pattern is followed directly by its complement.
    coupling_masks = []
    for _ in range(num_layers):
        pattern = create_coupling_mask(data_shape)
        coupling_masks.extend((pattern, 1 - pattern))

    masks = torch.from_numpy(np.asarray(coupling_masks)).float().cuda()

    # Zero-mean, identity-covariance Gaussian prior over the feature dimension.
    prior_mean = torch.zeros(data_shape).cuda()
    prior_cov = torch.eye(data_shape).cuda()
    prior = distributions.MultivariateNormal(prior_mean, prior_cov)

    flow = init_flow(nets, nett, masks, prior)
    flow.cuda()

    return flow

def create_coupling_mask(shape):
    """Return a random 0/1 vector of length ``shape`` with fixed counts.

    The result holds ``int(shape / 2)`` zeros and ``shape - int(shape / 2)``
    ones. Positions are filled left to right: while both values are still
    available one is picked with ``np.random.uniform()``; once either value
    is used up, the rest of the vector is filled with the other.

    Args:
        shape: integer length of the mask.

    Returns:
        1-D ``numpy`` array built from the list of chosen 0/1 values.
    """
    # remaining[b] is how many more times the value b may still be placed.
    remaining = [int(shape / 2), 0]
    remaining[1] = shape - remaining[0]

    bits = []
    for _ in range(shape):
        if remaining[0] > 0 and remaining[1] > 0:
            # Both values available: draw one uniform sample to choose.
            bit = 0 if np.random.uniform() > .5 else 1
        else:
            # One value exhausted: no random draw, take whichever is left.
            bit = 0 if remaining[0] > 0 else 1
        bits.append(bit)
        remaining[bit] -= 1
    return np.asarray(bits)


def str2bool(v):
    """Convert a command-line style yes/no value to ``bool``.

    Args:
        v: either a ``bool`` (returned unchanged) or a string compared
            case-insensitively against the accepted spellings.

    Returns:
        ``True`` for ``yes/true/t/y/1``, ``False`` for ``no/false/f/n/0``.

    Raises:
        argparse.ArgumentTypeError: if the string is not one of the accepted
            spellings.
    """
    # Imported here so the error branch can always resolve ArgumentTypeError;
    # previously an unrecognised value failed with NameError instead.
    import argparse

    if isinstance(v, bool):
        return v

    lowered = v.lower()
    if lowered in ('yes', 'true', 't', 'y', '1'):
        return True
    if lowered in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')
