import argparse
import collections
import math
import time
import os

import numpy as np
import scipy.io as sio
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn import metrics, preprocessing
from sklearn.decomposition import PCA
from sklearn.metrics import confusion_matrix

from torchsummary import summary
import torch_optimizer as optim2

import wandb


import geniter
import record
import Utils

import data_utils as du

def train_new(args,
              net,
              train_iter,
              valida_iter,
              optimizer,
              device,
              epochs,
              logger,
              save_path,
              early_stopping=True,
              early_num=20):
    """Train ``net`` with cross-entropy loss, logging metrics and checkpointing.

    Every epoch performs one pass over ``train_iter``, steps a cosine-annealing
    learning-rate scheduler, evaluates on ``valida_iter`` through
    ``record.evaluate_accuracy_new``, reports the metrics to ``wandb`` and
    ``logger``, and writes checkpoints below ``save_path``.

    Checkpoints are dicts with the keys ``'model'``, ``'args'``,
    ``'optimizer'`` and ``'epoch'``:

    * ``epoch-<n>.pth`` -- every ``args.epoch_save`` epochs (skipped when that
      value is ``None`` or ``0``) and, with ``early_stopping``, at the first
      epoch of each streak of non-improving validation loss.
    * ``epoch-best.pth`` -- whenever the validation accuracy exceeds the best
      value seen so far (the first epoch always qualifies).

    Args:
        args: Run configuration; only ``args.epoch_save`` is read here, the
            object itself is stored inside every checkpoint.
        net: Model to train; it is moved to ``device``.
        train_iter: Iterable yielding ``(X, y)`` training batches.
        valida_iter: Validation iterable handed to
            ``record.evaluate_accuracy_new``.
        optimizer: Optimizer updating the parameters of ``net``.
        device: Device the model and the batches are moved to.
        epochs: Maximum number of epochs. Nothing is trained or logged when
            this is not positive.
        logger: Object exposing ``info(message)``.
        save_path: Directory receiving the checkpoint files.
        early_stopping: When true, stop after ``early_num`` consecutive epochs
            whose validation loss is above the best one and restore
            ``epoch-best.pth`` into ``net`` and ``optimizer``.
        early_num: Length of the non-improving streak that triggers the stop.

    Returns:
        None. ``net`` and ``optimizer`` are updated in place.
    """
    criterion = torch.nn.CrossEntropyLoss()

    net = net.to(device)
    print("training on ", device)
    start = time.time()

    if epochs <= 0:
        # The end-of-run summary needs at least one finished epoch.
        return

    # Running averages over the whole run (every batch of every epoch).
    train_total_loss = du.Averager()
    train_total_acc = du.Averager()

    # Infinite sentinels so that the first epoch always counts as an
    # improvement, whatever the scale of the loss / accuracy values.
    best_val_loss = float("inf")
    max_val_acc = float("-inf")
    early_epoch = 0

    # Built once: re-creating the scheduler every epoch would throw away its
    # internal epoch counter, so the cosine schedule would never advance.
    lr_adjust = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, 15, eta_min=0.0, last_epoch=-1)

    for epoch in range(epochs):
        time_epoch = time.time()

        # Per-epoch running averages.
        train_loss = du.Averager()
        train_acc = du.Averager()

        for X, y in train_iter:
            X = X.to(device)
            y = y.to(device)
            y_hat = net(X)
            batch_loss = criterion(y_hat, y.long())

            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()

            loss_value = batch_loss.cpu().item()
            acc = (y_hat.argmax(dim=1) == y).float().mean().cpu().item()
            batch_size = y.shape[0]

            train_loss.add(loss_value)
            # The batch size is passed as ``n``; presumably Averager uses it
            # as the weight of this sample -- confirm in data_utils.
            train_acc.add(v=acc, n=batch_size)
            train_total_loss.add(loss_value)
            train_total_acc.add(v=acc, n=batch_size)

        lr_adjust.step()

        # The predictions and ground truth returned by the evaluator are not
        # needed during training.
        valida_acc, valida_loss, _, _ = record.evaluate_accuracy_new(
            valida_iter, net, criterion, device)

        epoch_train_loss = train_loss.item()
        epoch_train_acc = train_acc.item()

        wandb.log({'epoch': epoch,
                   "train/loss": epoch_train_loss,
                   "train/acc": epoch_train_acc,
                   "val/loss": valida_loss,
                   "val/acc": valida_acc,
                   })

        logger.info(
            'epoch %d, train loss %.6f, train acc %.3f, valida loss %.6f, valida acc %.3f, time %.1f sec'
            % (epoch, epoch_train_loss, epoch_train_acc,
               valida_loss, valida_acc, time.time() - time_epoch))

        sv_file = {
            'model': net.state_dict(),
            'args': args,
            'optimizer': optimizer.state_dict(),
            'epoch': epoch
        }
        epoch_ckpt = os.path.join(save_path, 'epoch-{}.pth'.format(epoch))

        # A falsy ``epoch_save`` (None or 0) disables periodic checkpoints and
        # avoids a modulo-by-zero.
        if args.epoch_save and epoch % args.epoch_save == 0:
            torch.save(sv_file, epoch_ckpt)

        # Refresh the best checkpoint *before* the early-stopping decision so
        # that the restore below always finds an up-to-date file.
        if max_val_acc < valida_acc:
            max_val_acc = valida_acc
            torch.save(sv_file, os.path.join(save_path, 'epoch-best.pth'))

        if early_stopping and best_val_loss < valida_loss:
            if early_epoch == 0:
                # Snapshot taken at the start of every non-improving streak.
                torch.save(sv_file, epoch_ckpt)
            early_epoch += 1
            if early_epoch == early_num:
                # Patience exhausted: roll back to the best-accuracy weights.
                load_best_model(save_path, net, optimizer)
                break
        else:
            best_val_loss = valida_loss
            early_epoch = 0

    logger.info('epoch %d, loss %.4f, train acc %.3f, time %.1f sec'
                % (epoch, train_total_loss.item(), train_total_acc.item(),
                   time.time() - start))

def load_best_model(save_path, net, optimizer, map_location=None):
    """Restore ``net`` and ``optimizer`` from ``<save_path>/epoch-best.pth``.

    The checkpoint is expected to be a dict holding the model state under
    ``'model'`` and the optimizer state under ``'optimizer'``.

    Args:
        save_path: Directory containing ``epoch-best.pth``.
        net: Module whose ``load_state_dict`` receives the stored weights.
        optimizer: Optimizer whose ``load_state_dict`` receives the stored
            state.
        map_location: Forwarded to ``torch.load``; pass e.g. ``"cpu"`` to
            restore a checkpoint on a machine without the device it was saved
            from. ``None`` keeps the default behaviour of ``torch.load``.

    Returns:
        The same ``(net, optimizer)`` objects, updated in place.
    """
    # NOTE(review): checkpoints written by ``train_new`` also store the
    # ``args`` object; PyTorch releases whose ``torch.load`` defaults to
    # ``weights_only=True`` may refuse to unpickle it -- confirm against the
    # installed version.
    checkpoint_path = os.path.join(save_path, 'epoch-best.pth')
    checkpoint = torch.load(checkpoint_path, map_location=map_location)

    net.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint["optimizer"])

    return net, optimizer


def train(net,
          train_iter,
          valida_iter,
          loss,
          optimizer,
          device,
          epochs,
          early_stopping=True,
          early_num=20):
    """Train ``net`` with the given ``loss``, printing metrics every epoch.

    Every epoch performs one pass over ``train_iter``, steps a cosine-annealing
    learning-rate scheduler and evaluates on ``valida_iter`` through
    ``record.evaluate_accuracy``. With ``early_stopping`` the weights are
    written to ``./net_DBA.pt`` each time the validation loss does not get
    worse, and reloaded from there once ``early_num`` consecutive epochs have
    had a validation loss above the best one.

    Args:
        net: Model to train; it is moved to ``device``.
        train_iter: Iterable yielding ``(X, y)`` training batches.
        valida_iter: Validation iterable handed to ``record.evaluate_accuracy``.
        loss: Callable ``loss(y_hat, y)`` returning a scalar tensor.
        optimizer: Optimizer updating the parameters of ``net``.
        device: Device the model and the batches are moved to.
        epochs: Maximum number of epochs. Nothing is trained or printed beyond
            the device banner when this is not positive.
        early_stopping: Enables the checkpoint / patience logic above.
        early_num: Length of the non-improving streak that triggers the stop.

    Returns:
        None. ``net`` is updated in place.
    """
    PATH = "./net_DBA.pt"

    net = net.to(device)
    print("training on ", device)
    start = time.time()

    if epochs <= 0:
        # The end-of-run summary needs at least one finished epoch.
        return

    # Infinite sentinel: the first epoch always counts as an improvement, so a
    # checkpoint exists before any restore can be attempted.
    best_val_loss = float("inf")
    early_epoch = 0

    # Built once: re-creating the scheduler every epoch would throw away its
    # internal epoch counter, so the cosine schedule would never advance.
    lr_adjust = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, 15, eta_min=0.0, last_epoch=-1)

    for epoch in range(epochs):
        # Per-epoch accumulators; they must be reset here, not per batch,
        # otherwise only the last batch would be reported.
        train_l_sum, train_acc_sum = 0.0, 0.0
        batch_count, n = 0, 0
        time_epoch = time.time()

        for X, y in train_iter:
            X = X.to(device)
            y = y.to(device)
            y_hat = net(X)
            batch_loss = loss(y_hat, y.long())

            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()

            train_l_sum += batch_loss.cpu().item()
            train_acc_sum += (y_hat.argmax(dim=1) == y).sum().cpu().item()
            n += y.shape[0]
            batch_count += 1

        lr_adjust.step()
        valida_acc, valida_loss = record.evaluate_accuracy(
            valida_iter, net, loss, device)

        # NaN marks "no training data seen" instead of dividing by zero.
        mean_train_loss = train_l_sum / batch_count if batch_count else float("nan")
        mean_train_acc = train_acc_sum / n if n else float("nan")

        print(
            'epoch %d, train loss %.6f, train acc %.3f, valida loss %.6f, valida acc %.3f, time %.1f sec'
            % (epoch + 1, mean_train_loss, mean_train_acc,
               valida_loss, valida_acc, time.time() - time_epoch))

        if early_stopping:
            if best_val_loss < valida_loss:
                early_epoch += 1
                if early_epoch == early_num:
                    # Patience exhausted: roll back to the checkpointed weights.
                    net.load_state_dict(torch.load(PATH))
                    break
            else:
                # Checkpoint on every non-worse epoch so that the file always
                # holds the weights belonging to ``best_val_loss``.
                best_val_loss = valida_loss
                early_epoch = 0
                torch.save(net.state_dict(), PATH)

    # Summary uses the metrics of the last epoch that was run.
    print('epoch %d, loss %.4f, train acc %.3f, time %.1f sec'
          % (epoch + 1, mean_train_loss, mean_train_acc,
             time.time() - start))


