import torch
import torch.optim as optim
import torchvision
import random
import math
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.autograd import Variable
from enum import IntEnum
from models_hd import *
from utils import *
from optimizers import *
from sgd_cam_hd import *
from adam_cam_hd import *
# from Pytorchtools import *
# Device configuration: module-level default device. Selects CUDA when
# torch reports it as available, otherwise the CPU.
# NOTE: model_train_HD binds its own local `device` from `args`, which shadows
# this value inside that function.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def _optimizer_step(optimizer, loss):
    """Apply exactly one optimizer update and return its gamma list, if any.

    ``optimizer.step()`` is tried first.  If that call raises ``TypeError``
    (presumably because the optimizer's ``step`` requires the loss, e.g. a
    wrapper such as L4 -- TODO confirm), ``optimizer.step(loss)`` is used
    instead and ``None`` is returned.

    When ``step()`` returns something that unpacks into two items, the second
    item is returned; otherwise ``None`` is returned.  The update is never
    repeated just because the return value has a different shape.
    """
    try:
        result = optimizer.step()
    except TypeError:
        optimizer.step(loss)
        return None

    try:
        _, gamma_list = result
    except (TypeError, ValueError):
        # step() returned None / a scalar / something that is not a pair.
        return None
    return gamma_list


def model_train_HD(data, args, log_func=None):
    """Train a model with early stopping and evaluate it on the test split.

    Args:
        data: Triple ``(train_dataset, val_dataset, test_dataset)``; each
            element is passed to ``torch.utils.data.DataLoader``.
        args: Subscriptable configuration.  Keys read here: ``"patience"``,
            ``"model"``, ``"cuda"``, ``"epochs"``, ``"lr"``, ``"seed"``,
            ``"device"``, ``"task"``, ``"lr_schedule"``, ``"batch_size"`` and
            ``"method"``.  ``set_model`` / ``set_opt`` / ``test_CLA`` read
            further keys from the same object.  ``args["batch_size"]`` is
            overwritten in place with its ``int`` value.
        log_func: Unused; kept for interface compatibility.

    Returns:
        Tuple ``(eva, model, avg_train_losses, avg_valid_losses,
        avg_valid_accs, time_list, gamma_list_list)``.  ``eva`` is the value
        returned by the task-specific test function, the three ``avg_*`` lists
        hold one entry per completed epoch, ``time_list`` is always empty, and
        ``gamma_list_list`` holds ``[gamma0, gamma1]`` per epoch for
        optimizers whose ``step()`` returns a gamma list.

    Raises:
        ValueError: If ``args["task"]`` is not ``"CLA"``, ``"regression"`` or
            ``"encoding"``.  Checked before training starts.
        KeyError: If a required key is missing from ``args``.
    """
    train_dataset, val_dataset, test_dataset = data

    patience = args["patience"]
    model_type = args["model"]
    cuda = args["cuda"]
    epochs = args["epochs"]
    lr = args["lr"]
    seed = args["seed"]
    device = args["device"]
    task = args["task"]
    lr_shedule = args["lr_schedule"]

    # Fail fast: an unknown task used to surface only after the whole
    # training run, as an UnboundLocalError on `eva`.
    if task not in ("CLA", "regression", "encoding"):
        raise ValueError("undefined task: {!r}".format(task))

    print("device:", device)
    torch.manual_seed(seed)
    if cuda:
        torch.cuda.set_device(0)
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.enabled = True

    args["batch_size"] = int(args["batch_size"])
    print("model_train_HD: lr, batch_size", args["lr"], args["batch_size"])

    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=args["batch_size"],
                                               shuffle=True)

    valid_loader = torch.utils.data.DataLoader(dataset=val_dataset,
                                               batch_size=args["batch_size"],
                                               shuffle=False)

    test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                              batch_size=args["batch_size"],
                                              shuffle=False)

    print("len(train_loader), len(valid_loader), len(test_loader)", len(train_loader), len(valid_loader), len(test_loader))

    model = set_model(args)
    if cuda:
        print("model.cuda")
        model = model.cuda()

    optimizer = set_opt(model, args)

    time_list = []  # never filled; returned only to keep the result shape
    avg_train_losses = []
    avg_valid_losses = []
    avg_valid_accs = []
    gamma_list_list = []
    gamma_list = None  # last gamma list reported by the optimizer, if any
    early_stopping = EarlyStopping(patience=patience, verbose=True)

    for epoch in range(1, epochs + 1):

        print("epoch:", epoch)

        if lr_shedule is not None:
            if args["method"] == "sgd_cam_hd":
                # The schedule value is applied as a multiplicative ratio to
                # every learning-rate field this optimizer exposes.
                lr_ratio = lr_shedule(epoch)
                print("lr, argsimizer.lr, lr_global 1", lr, optimizer.lr, optimizer.lr_global)
                optimizer.lr = lr_ratio * optimizer.lr
                optimizer.lr_global = lr_ratio * optimizer.lr_global
                optimizer.param_groups[0]['lr'] = lr_ratio * optimizer.param_groups[0]['lr']
                optimizer.lr_list = [lr_ratio * item for item in optimizer.lr_list]
                print("lr, argsimizer.lr, lr_global 2", lr, optimizer.lr, optimizer.lr_global)
                print("argsimizer.lr_list", optimizer.lr_list)
            else:
                # For every other method the schedule value is the new
                # learning rate of the first parameter group.
                lr = lr_shedule(epoch)
                optimizer.param_groups[0]['lr'] = lr

        # ---- training pass -------------------------------------------------
        # NOTE(review): model.train()/model.eval() are never toggled in this
        # function -- confirm that is intended for models with BatchNorm or
        # Dropout layers.
        train_losses = []
        for inputs, target in train_loader:
            if cuda:
                inputs, target = inputs.cuda(), target.cuda()
            optimizer.zero_grad()

            if model_type == "FFNN":
                inputs = inputs.reshape(-1, 28*28).to(device)
            if model_type == "lstm":
                inputs = inputs.reshape(-1, 28, 28).to(device)
            if model_type == "AE":
                inputs = inputs.view(inputs.size(0), -1).to(device)

            output = model(inputs)

            if model_type == "AE" or model_type == "CAE":
                # Autoencoders are trained to reconstruct their own input.
                loss = nn.MSELoss()(output, inputs)
            else:
                loss = F.cross_entropy(output, target)

            loss.backward()

            # One update per batch.  The previous nested bare-except retried
            # step() after a failed tuple unpack, i.e. after the update had
            # already been applied once.
            step_gammas = _optimizer_step(optimizer, loss)
            if step_gammas is not None:
                gamma_list = step_gammas

            train_losses.append(loss.item())

        if gamma_list is not None:
            # Best effort: only optimizers that report at least two gammas
            # with an .item() method contribute to gamma_list_list.
            try:
                print("len(gamma_list)", len(gamma_list))
                print("gamma_list", gamma_list, [gamma_list[0].item(), gamma_list[1].item()])
                gamma_list_list.append([gamma_list[0].item(), gamma_list[1].item()])
            except (TypeError, IndexError, AttributeError):
                pass

        # ---- validation pass -----------------------------------------------
        # NOTE(review): accuracy is computed with test_CLA for every task,
        # including non-classification ones -- confirm this is intended.
        valid_acc, model = test_CLA(valid_loader, model, args)

        valid_losses = []
        # no_grad replaces the removed `volatile=True` flag so that no
        # autograd graph is built while validating.
        with torch.no_grad():
            for inputs, target in valid_loader:
                if cuda:
                    inputs, target = inputs.cuda(), target.cuda()
                # NOTE(review): unlike the training loop, "FFNN" inputs are
                # not reshaped here -- verify against the model's forward().
                if model_type == "lstm":
                    inputs = inputs.reshape(-1, 28, 28)
                if model_type == "AE":
                    inputs = inputs.view(inputs.size(0), -1)

                output = model(inputs)
                if model_type == "AE" or model_type == "CAE":
                    val_loss = nn.MSELoss()(output, inputs).item()
                else:
                    val_loss = F.cross_entropy(output, target).item()
                valid_losses.append(val_loss)

        train_loss = np.average(train_losses)
        valid_loss = np.average(valid_losses)
        avg_train_losses.append(train_loss)
        avg_valid_losses.append(valid_loss)
        avg_valid_accs.append(valid_acc)

        # early_stopping compares the epoch's mean validation loss with the
        # best one seen so far and sets .early_stop once patience runs out.
        early_stopping(valid_loss, model)

        if early_stopping.early_stop:
            print("Early stopping")
            break

    # ---- final evaluation on the test split --------------------------------
    if task == "CLA":
        eva, model = test_CLA(test_loader, model, args)
    elif task == "regression":
        eva, model = test_REG(test_loader, model)
    else:  # "encoding" -- the task name was validated before training
        eva, model = test_AE(test_loader, model, model_type)

    return eva, model, avg_train_losses, avg_valid_losses, avg_valid_accs, time_list, gamma_list_list

def set_model(args):
    """Instantiate the network selected by ``args["model_type"]``.

    Args:
        args: Dict-like configuration.  ``"model_type"`` selects the
            architecture; several constructors receive ``args`` itself.  For
            ``'vgg'`` the optional ``"parallel"`` key (default ``False``)
            wraps ``model.features`` in ``torch.nn.DataParallel``.

    Returns:
        The newly constructed model.

    Raises:
        ValueError: If ``args["model_type"]`` is not a known model name.
    """
    model_type = args["model_type"]

    if model_type == 'logreg':
        model = LogReg(28 * 28, 10)
    elif model_type == 'mlp':
        model = MLP_2(28 * 28, 100, 10)
    elif model_type == 'FFNN':
        model = FFNN(args)
    elif model_type == 'AE':
        model = autoencoder(args)
    elif model_type == 'CAE':
        model = conv_autoencoder(args)
    elif model_type == 'lstm':
        model = LSTM(args)
    elif model_type == "lenet_5":
        model = LeNet_5(args)
    elif model_type == "L4_Net":
        model = L4_Net()
    elif model_type == 'vgg':
        model = vgg.vgg16_bn()
        # `parallel` used to be read from an undefined name (NameError);
        # take it from the configuration instead.
        if args.get("parallel", False):
            model.features = torch.nn.DataParallel(model.features)
    elif model_type == 'resnet_18':
        print("model_type:resnet_18")
        model = ResNet(BasicBlock, [2, 2, 2, 2])
    elif model_type == 'resnet_20':
        print("model_type:resnet_20")
        model = ResNet_0(ResidualBlock, [3, 3, 3])
    elif model_type == 'resnet_34':
        print("model_type:resnet_34")
        model = ResNet(BasicBlock, [3, 4, 6, 3])
    elif model_type == "RNNLM":
        # NOTE(review): "RNNLM" builds a LeNet_5 -- looks like a placeholder;
        # confirm before relying on it.
        model = LeNet_5(args)
    else:
        # Format the requested name: `model` is unbound on this path, so the
        # old message raised UnboundLocalError instead of this error.
        raise ValueError('Unknown model: {}'.format(model_type))

    return model

def set_opt(model, args):
    """Build the optimizer selected by ``args["method"]`` for ``model``.

    Args:
        model: Object exposing ``parameters()``; its parameters are handed to
            the optimizer.
        args: Dict-like configuration.  Required keys: ``"method"``, ``"lr"``,
            ``"weightDecay"``, ``"hypergrad_lr"`` and ``"meta"``.  Optional
            keys: ``"nesterov"`` (default ``False``) and ``"momentum"``
            (default ``0``).  The ``*_hd`` / ``*_cam_hd`` optimizers (except
            ``adam_hd``) receive ``args`` itself.

    Returns:
        The constructed optimizer; for ``'sgd'`` and ``'adam'`` it is wrapped
        in ``L4`` when ``args["meta"] == "L4"``.

    Raises:
        ValueError: If ``args["method"]`` is not a known method name.
    """
    method = args["method"]
    lr = args["lr"]
    weightDecay = args["weightDecay"]
    beta = args["hypergrad_lr"]
    meta = args["meta"]

    # Look the two optional keys up independently: with the former shared
    # try/except, a missing "nesterov" silently reset a supplied "momentum"
    # to 0.
    nesterov = args.get("nesterov", False)
    momentum = args.get("momentum", 0)

    if method == 'sgd':
        optimizer = SGD(model.parameters(), lr=lr, weight_decay=weightDecay, momentum=momentum, nesterov=nesterov)
        if meta == "L4":
            optimizer = L4(optimizer)
    elif method == 'sgd_hd':
        optimizer = SGD_HD(model.parameters(), args)
    elif method == 'sgdn':
        optimizer = SGD(model.parameters(), lr=lr, weight_decay=weightDecay, momentum=momentum, nesterov=True)
    elif method == 'sgd_cam_hd':
        optimizer = SGD_CAM_HD(model.parameters(), args)
    elif method == 'adam':
        optimizer = Adam(model.parameters(), lr=lr, weight_decay=weightDecay)
        if meta == "L4":
            optimizer = L4(optimizer)
            print("optimizer with L4")
    elif method == 'adam_hd':
        optimizer = Adam_HD(model.parameters(), lr=lr, weight_decay=weightDecay, hypergrad_lr=beta)
    elif method == 'adam_cam_hd':
        optimizer = Adam_CAM_HD(model.parameters(), args)
    # NOTE: the three methods below use fixed hyper-parameters and ignore
    # args["lr"] / args["weightDecay"].
    elif method == 'amsgrad':
        optimizer = GAdam(model.parameters(), lr=1e-3, betas=(0.9, 0.99), amsgrad_decay=1e-4)
    elif method == 'adabound':
        optimizer = AdaBound(model.parameters(), lr=1e-3, betas=(0.9, 0.999), final_lr=0.1, gamma=1e-3,
                             eps=1e-8, weight_decay=0, amsbound=False)
    elif method == 'radam':
        optimizer = RAdam(model.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0)
    else:
        raise ValueError('Unknown method: {}'.format(method))

    return optimizer

def test_CLA(test_loader, model, args):
    """Evaluate classification accuracy of ``model`` over ``test_loader``.

    Args:
        test_loader: Iterable of ``(data, target)`` batches that also exposes
            ``.dataset`` (its length is the number of samples).
        model: Callable mapping a batch to per-class scores.  The loss is
            computed with ``F.nll_loss``, so the scores are presumably
            log-probabilities -- TODO confirm for every model.
        args: Subscriptable configuration; reads ``"cuda"`` (move batches to
            the GPU) and ``"model"`` (``"lstm"`` inputs are reshaped to
            ``(-1, 28, 28)``).

    Returns:
        ``(accuracy_percent, model)``; the accuracy is a Python float and the
        model is returned unchanged.

    Raises:
        ZeroDivisionError: If ``test_loader.dataset`` is empty.
    """
    cuda = args["cuda"]
    model_type = args["model"]
    num_samples = len(test_loader.dataset)
    total_loss = 0.0
    correct = 0

    # NOTE(review): the model's train/eval mode is left untouched here.
    # no_grad replaces the removed `volatile=True` flag so that evaluation
    # does not build autograd graphs.
    with torch.no_grad():
        for data, target in test_loader:
            if cuda:
                data, target = data.cuda(), target.cuda()

            if model_type == "lstm":
                data = data.reshape(-1, 28, 28)

            output = model(data)
            # Sum (not mean) per batch so that dividing by the dataset size
            # below yields a true per-sample average.
            total_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.max(1, keepdim=True)[1]  # index of the max score
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss = total_loss / num_samples
    # Plain float arithmetic: np.float no longer exists in current NumPy.
    acc = 100. * correct / num_samples
    print("acc", acc)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            test_loss, correct, num_samples, acc))

    eva = acc

    return eva, model


# The `test_` prefix is historical: this is an evaluation helper, not a unit
# test.  Opt out of pytest collection so star-importing this module into a
# test file does not produce a spurious "fixture not found" error.
test_CLA.__test__ = False

def test_AE(test_loader, model, model_type):
    """Return the mean reconstruction (MSE) loss of ``model`` per batch.

    Args:
        test_loader: Iterable of ``(img, label)`` batches supporting
            ``len()``; the labels are ignored.
        model: ``nn.Module`` whose output is compared with its own input.
        model_type: When ``"AE"`` each batch is flattened to
            ``(batch, -1)`` before the forward pass.

    Returns:
        ``(mean_loss, model)``: the unweighted mean of the per-batch
        ``nn.MSELoss`` values (``0`` for an empty loader) and the unchanged
        model.
    """
    criterion = nn.MSELoss()
    num_batches = len(test_loader)

    # Evaluate on whatever device the model lives on; previously the batches
    # stayed on the CPU and failed against a model that had been moved to
    # CUDA.  A model without parameters leaves the batches where they are.
    first_param = next(model.parameters(), None)
    model_device = None if first_param is None else first_param.device

    loss_ave = 0
    with torch.no_grad():
        for img, _ in test_loader:
            if model_type == "AE":
                img = img.view(img.size(0), -1)
            if model_device is not None:
                img = img.to(model_device)

            output = model(img)
            loss_ave = loss_ave + criterion(output, img).item() / num_batches

    eva = loss_ave

    return eva, model


# The `test_` prefix is historical: this is an evaluation helper, not a unit
# test.  Opt out of pytest collection so star-importing this module into a
# test file does not produce a spurious "fixture not found" error.
test_AE.__test__ = False

class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""

    def __init__(self, patience=7, verbose=False):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): Stored on the instance; the message it used to
                            control is currently disabled (see save_checkpoint).
                            Default: False
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0          # consecutive calls without improvement
        self.best_score = None    # negated best validation loss so far
        self.early_stop = False   # set once `counter` reaches `patience`
        # math.inf instead of np.Inf: the alias was removed in NumPy 2.0 and
        # this module does not import numpy itself.
        self.val_loss_min = math.inf

    def __call__(self, val_loss, model):
        """Record one validation result and update the stopping state.

        A loss that is lower than or equal to the best one so far counts as
        an improvement and resets the counter.  Anything else -- including a
        NaN loss, which would otherwise compare as "not worse" forever --
        increments the counter and sets ``early_stop`` once it reaches
        ``patience``.
        """
        score = -val_loss

        improved = (not math.isnan(score)) and (
            self.best_score is None or score >= self.best_score)

        if improved:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True

    def save_checkpoint(self, val_loss, model):
        """Record ``val_loss`` as the lowest validation loss seen so far.

        Writing the model to disk and the verbose message are disabled; only
        ``val_loss_min`` is updated and ``model`` is not used.
        """
        # Disabled: torch.save(model.state_dict(), 'checkpoint.pt')
        self.val_loss_min = val_loss
