import warnings
warnings.filterwarnings("ignore")
import random
import os
import math
import numpy as np
import pandas as pd
from sklearn import preprocessing
import logging
from datetime import datetime
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import SGD, RMSprop
from torchvision import datasets, transforms
dtype = torch.cuda.FloatTensor
from datetime import datetime
import time
import argparse

from src.components import *
from src.optimizers import *
from src.utils import *#log, import_and_preprocess_data, run_sgd_map_model
from src.convergence_criteria import *

import pickle

# Create the per-run log directory and the logger that writes into it.

folder_name = 'logs_conv'

# Timestamp (dd_mm_YYYY_HH_MM_SS) naming this run's sub-folder and log file.
now = datetime.now()
dt_string = now.strftime("%d_%m_%Y_%H_%M_%S")
file_folder = folder_name+'/'+dt_string

# makedirs(..., exist_ok=True) creates both `folder_name` and the run
# sub-folder, and avoids the check-then-create race of exists() + mkdir().
os.makedirs(file_folder, exist_ok=True)

# `log` is provided by one of the project's star imports.
logger = log(file=dt_string+"/"+dt_string+".logs", path_to_folder = folder_name+'/', folder_name = folder_name)

# Build the command-line interface for this training script.
ap = argparse.ArgumentParser()

# Every option is optional and has a default, so the script runs with no flags.
ap.add_argument("--epochs", required=False, default = 400, type=int,
   help="Number of training epochs")
ap.add_argument("--map_init", required=False, default = 'no', type=str,
   help="Whether or not we initialize the model with the MAP solution")
ap.add_argument("--lr", required=False, default=1e-2, type=float,
   help="constant learning rate for model")
# NOTE(review): --dataset is parsed but not read anywhere in this script as
# shown; the data-loading code always uses CIFAR10.
ap.add_argument("--dataset", required=False, default = 'CIFAR10', type=str,
   help="Dataset name (this script always loads CIFAR10)")
ap.add_argument("--grouping", required=False, default='permuted', type=str,
   help="Choose how to assign the groups. Randomly, Permuted, By_layer, None")
ap.add_argument("--dropout_prob", required=False, default=0.1 ,type=float,
   help="Choose the dropout probability for the model")
ap.add_argument("--vi_batch_size", required=False, default=16, type=int,
   help="Choose the vi_batch_size for the model")
ap.add_argument("--evaluation", required=False,default='yes', type=str,
   help="Choose whether or not to evaluate the model")
ap.add_argument("--conv", required=False, default='no',type=str,
   help="Choose whether or not to evaluate the model for convergence")
ap.add_argument("--gpu", required=False, default = '0', type=str,
   help="Choose GPU to use")
ap.add_argument("--batch_size", required=False, default = 128, type=int,
   help="Choose batch_size for the dataset")
ap.add_argument("--groups", required=False, default = 'all', type=str,
   help="Choose groups for the parameters")
ap.add_argument("--seed", required=False, default = 2, type=int,
   help="Choose seed number for the code")
ap.add_argument("--opt", required=False, default = 'psgld', type=str,
   help="Choose optimizer for the training of the model")

# parse_known_args tolerates unrecognised flags; they are collected in
# `leftovers` instead of aborting the run.
args, leftovers = ap.parse_known_args()

# Record the effective configuration in the run log.
for arg in vars(args):
    logger.info("{} = {}".format(arg, getattr(args, arg)))
      
# Seed the Python, NumPy and PyTorch RNGs with the same user-supplied value.
for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    seed_fn(args.seed)

# Resolve the compute device: the literal 'cpu', otherwise the CUDA device
# whose index was passed via --gpu.
# Reference: https://towardsdatascience.com/handwritten-digit-mnist-pytorch-977b5338e627
device = 'cpu' if args.gpu == 'cpu' else 'cuda:' + args.gpu
dev = torch.device(device)

# Convert images to tensors, then normalise with mean 0.5 and std 0.5.
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))])

# Download (if needed) and open the CIFAR10 train / test splits.
trainset = datasets.CIFAR10('PATH_TO_STORE_TRAINSET', download=True, train=True, transform=transform)
valset = datasets.CIFAR10('PATH_TO_STORE_TESTSET', download=True, train=False, transform=transform)

# Materialise both splits eagerly as lists of (image, label) tensors that
# already live on `dev`, so no per-batch host-to-device copy is needed later.
trainset = [(x.to(dev), torch.tensor(y, device=dev)) for x,y in trainset]
valset = [(x.to(dev), torch.tensor(y, device=dev)) for x,y in valset]

trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True)
testloader = torch.utils.data.DataLoader(valset, batch_size=args.batch_size, shuffle=True)
N = len(trainset) #for training

# Peek at one batch to print its shapes. The builtin next() is the supported
# way to advance a DataLoader iterator; the `.next()` method is not part of
# the iterator protocol and is unavailable in current PyTorch releases.
dataiter = iter(testloader)
images, labels = next(dataiter)

print(images.shape)
print(labels.shape)

###### First we run the model with SGD in order to find the MAP solution. Then we run it again with SGLD
# in order to find the posterior distribution.
#
# NOTE(review): `X` and `Y` are not assigned anywhere in this script as shown.
# Unless one of the star imports provides them, passing `--map_init yes`
# raises NameError here — confirm what data run_sgd_map_model expects.

if args.map_init=='yes':
    run_sgd_map_model(X,Y)
    
### Load the pre-trained SGD model, then decide how many parameter groups to use.

# NOTE: both files below are unpickled, which can execute arbitrary code;
# only point these paths at trusted, locally produced artifacts.
# The context manager closes the file handle that the bare open() call leaked.
with open("./resnet20_sgd_model_params.pickle", "rb") as params_file:
    sgd_model = BayesianResNet20(**pickle.load(params_file))
# NOTE(review): the checkpoint name mentions "svhn" while the data loaded
# by this script is CIFAR10 — confirm this is the intended checkpoint.
sgd_model.load_state_dict(torch.load("./sgd_resnet20_svhn_map.pt", map_location=dev))

if args.groups=='all':
    # One group per scalar stochastic parameter: sum the element counts of
    # every tensor returned by get_stochastic_params().
    num_stoch_params = sum(
        int(np.prod(param.shape)) for param in sgd_model.get_stochastic_params()
    )
    groups = num_stoch_params
    print("max_groups:", groups)
    logger.info("REAL max_groups = {}".format(groups))
else:
    # An explicit group count was supplied on the command line.
    groups = int(args.groups)
    
# Dropout probability handed to the model: None is passed through unchanged,
# anything else is coerced to float.
drop = None if args.dropout_prob is None else float(args.dropout_prob)

print("dropout:", drop)

# Start from the constructor arguments saved for the SGD model and override
# the options specific to this run.
# NOTE: unpickling can execute arbitrary code; only load trusted files.
# The context manager closes the file handle that the bare open() call leaked.
with open("./resnet20_sgd_model_params.pickle", "rb") as params_file:
    lvi_model_params = pickle.load(params_file)

# Grouping is hard-coded to "permuted" here; the --grouping flag is not consulted.
lvi_model_params["group_by_layers"] = False
lvi_model_params["use_random_groups"] = False
lvi_model_params["use_permuted_groups"] = True
lvi_model_params["max_groups"] = groups
lvi_model_params["dropout_prob"] = drop#None
lvi_model_params["chain_length"] = 5000
lvi_model_params["prior_std"] = 0.3 #1.0

# Initialise every tensor from the corresponding SGD-model value, then drop
# the SGD model since only those values are needed from here on.
lvi_model_params["init_values"] = {k:v.theta_actual.data for k,v in sgd_model.tensor_dict.items()}
del sgd_model

lvi_model = BayesianResNet20(**lvi_model_params)

# Pick the sampler: SGHMC when requested explicitly, pSGLD for any other
# value of --opt. Only the stochastic parameters are updated.
use_sghmc = args.opt == 'sghmc'
lvi_model.initialize_optimizer(
    update_determ=False,
    update_stoch=True,
    lr=args.lr,
    sgd=False,
    sgld=False,
    psgld=not use_sghmc,
    sghmc=use_sghmc,
)
print('INSIDE SGMC' if use_sghmc else 'INSIDE PSGLD/ELSE')

# Move the model to the target device. The prior distribution's loc/scale of
# each StochasticTensor are moved explicitly as well.
lvi_model = lvi_model.to(dev)
for tensor in lvi_model.tensor_dict.values():
    if not isinstance(tensor, StochasticTensor):
        continue
    prior = tensor.prior_dist
    prior.loc = prior.loc.to(dev)
    prior.scale = prior.scale.to(dev)

# Report the per-group sample counts on either side of chain initialisation.
samples_before = lvi_model.num_samples_per_group
print("Before initialization: {}".format(samples_before))
lvi_model.init_chains()
samples_after = lvi_model.num_samples_per_group
print("After initialization: {}".format(samples_after))

# Cross-entropy is used for monitoring in the training loop below.
criterion = torch.nn.CrossEntropyLoss()

# vi_batch_size: None is passed through unchanged, anything else becomes int.
vi_batch = None if args.vi_batch_size is None else int(args.vi_batch_size)
print("vi_batch:",vi_batch)

print("learning rate:",args.lr)

def _mean(values):
    """Return the arithmetic mean of a non-empty sequence."""
    return sum(values) / len(values)


# Test accuracy after each epoch (only filled when --evaluation is 'yes').
total_acc = []
for i in range(args.epochs):
    losses = []
    cross_losses = []
    accuracy = []

    start = time.time()

    for images, labels in trainloader:
        inner_cross_losses = []
        inner_accuracy = []

        # The first epoch runs with deterministic weights; every later epoch
        # uses stochastic weights.
        loss, y_pred, _ = lvi_model.training_step(
            batch=(images, labels),
            N=N,
            deterministic_weights=(i < 1),
            vi_batch_size=vi_batch,
        )

        # NOTE(review): if `loss` is a tensor that still references the
        # autograd graph, storing it keeps that graph alive for the whole
        # epoch — confirm what training_step returns.
        losses.append(loss)

        with torch.no_grad():
            # assumes y_pred's leading axis indexes prediction samples — TODO confirm
            preds = y_pred.squeeze(0)  # loop-invariant, so computed once
            for j in range(y_pred.shape[0]):
                sample_pred = preds[j]
                inner_cross_losses.append(criterion(sample_pred, labels).item())
                hits = (torch.max(sample_pred, -1).indices == labels).sum().item()
                inner_accuracy.append(hits / labels.size(0))
        accuracy.append(_mean(inner_accuracy))
        cross_losses.append(_mean(inner_cross_losses))

    # Per-epoch averages over all training batches.
    epoch_loss = _mean(losses)
    epoch_cross_entropy = _mean(cross_losses)
    epoch_accuracy = _mean(accuracy)
    print("Iter {} / {}, Loss: {}, CrossEntropy: {}, Accuracy: {}".format(i+1, args.epochs, epoch_loss, epoch_cross_entropy, epoch_accuracy))

    end = time.time()
    print('Elapsed time for the training:', end - start)
    if args.evaluation == 'yes':
        tmp_acc = evaluate(lvi_model, testloader, N=len(valset))
        total_acc.append(tmp_acc)

# The summary below reads values produced by the last epoch. With zero epochs
# none of them exist, so skip it instead of failing with NameError.
ran_training = args.epochs > 0

if ran_training:
    logger.info("training_time_per_epoch = {}".format(end - start))
    logger.info("last_training_accuracy = {}".format(epoch_accuracy))
    logger.info("last_training_CrossEntropy = {}".format(epoch_cross_entropy))
else:
    logger.info("no training epochs were run; skipping training summary")
np.save(file_folder+'/dropout_lvi_acc_cifar10.npy', total_acc)

if args.evaluation == 'yes' and ran_training:
    # NOTE(review): Loss and CrossEntropy here are the last epoch's *training*
    # averages; only Accuracy comes from evaluate(). The message text is kept
    # unchanged so existing log parsing keeps working.
    evaluation_summary = "EVALUATION with 100 samples -> Loss: {}, CrossEntropy: {}, Accuracy: {}".format(epoch_loss, epoch_cross_entropy, tmp_acc)

    print(evaluation_summary)

    logger.info(evaluation_summary)

if args.conv=='yes':
    # Flatten each chain to (first_dim, -1) on the CPU, concatenate all of
    # them along the last axis and wrap the result in a DataFrame.
    chains = [
        chain.view(chain.shape[0], -1).detach().cpu()
        for name, chain in lvi_model.get_chains().items()
    ]
    w = pd.DataFrame(torch.cat(chains, dim=-1).numpy())
    print("W Shape is:",w.shape)

# if  args.dropout_prob == None and args.vi_batch_size == None:
#     w.to_csv(file_folder+'/full_mean_field_lvi.csv',header=None)
# elif args.grouping != 'by_layer' and args.grouping != 'random' and args.grouping != 'permuted':
#     w.to_csv(file_folder+'/sgld.csv',header=None)
# else: 
# w.to_csv(file_folder+'/lvi.csv',header=None)
    
# if args.evaluation=='yes':
#     np.save(file_folder+'/mses.npy', mses)
#     np.save(file_folder+'/nlls.npy', nlls)

if args.conv=='yes':
    # Compute both diagnostics on the collected chain samples first, then
    # write them to the run log.
    iac_time, ess_time = calculate_IAC(w), calculate_ESS(w)

    for label, value in (("IAC", iac_time), ("ESS", ess_time)):
        logger.info("{} time = {}".format(label, value))

#     iac_time1 = calculate_IAC(w[10000:])
#     ess_time1 = calculate_ESS(w[10000:])
    
#     logger.info("IAC time after 10000 samples = {}".format(iac_time1))
#     logger.info("ESS time after 10000 samples = {}".format(ess_time1))
