#!/usr/bin/env python
# -*- coding: utf-8 -*-


import json as js
import os
import time
from os.path import join

import torch
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import random
import numpy as np

import pickle
import random
from datetime import datetime

import torch
import argparse
# from wildtime import dataloader
import os
import sys
from torchvision import datasets, transforms
from PIL import Image
from torch import nn, optim
from torch.utils.tensorboard import SummaryWriter

from dataset.wrapper import get_dataset
from model.wrapper import get_model
from online.models.wrapper import get_classifier_alg
from online.estimator.wrapper import get_estimator_alg
from online.utils.risk import *
from utils.argparser import argparser
from online.estimator.get_weights import weights_estimation

from utils.logger import MyLogger

from utils.tools import Timer

# Module-level Timer instance (Timer is imported from utils.tools); it is
# created once, when this module is imported.
timer = Timer()

import warnings

# Install a global filter that ignores all warnings from this point on.
warnings.filterwarnings('ignore')

from model.linear import Linear, Logistic
from model.cnn import CNN


def evaluation(model, data, label):
    """Score ``model`` on ``data`` against the reference ``label``.

    The model only needs to expose ``predict(data)``; its output is compared
    element-wise with ``label`` and the matches are counted.

    Returns:
        A ``(error, acc)`` tuple where ``acc`` is the fraction of matching
        predictions and ``error`` is ``1 - acc``.
    """
    predictions = model.predict(data)
    hits = (label == predictions).sum().item()
    accuracy = hits / len(label)
    return 1. - accuracy, accuracy

def _build_offline_model(cfgs, info):
    """Instantiate the classifier selected by ``cfgs['Model']['type']``.

    Supported types are ``'Linear'`` (the default when the key is absent) and
    ``'CNN'``.

    Raises:
        NotImplementedError: if the configured type is neither of the above.
    """
    model_name = cfgs['Model'].get('type', 'Linear')
    print('Model structure: {}'.format(model_name))

    if model_name == 'Linear':
        return Linear(
            input_dim=info['dim'],
            output_dim=info['cls_num'],
            R=cfgs['Model']['Classifier']['kwargs']['R'] / 2,
        )
    if model_name == 'CNN':
        return CNN(
            num_input_channels=info['channel_num'],
            hid_dim=info['dim'],
            num_classes=info['cls_num'],
        )
    raise NotImplementedError('Unsupported model type: {!r}'.format(model_name))


def _full_train_accuracy(model, train_dataset, device):
    """Return the accuracy of ``model`` on the whole training set.

    Reads the features from ``train_dataset.X_bak`` and the labels from
    ``train_dataset.label`` (both converted with ``torch.from_numpy``).
    """
    was_training = model.training
    # Evaluate in eval mode and without autograd: no graph is built for the
    # full-dataset forward pass, and layers that behave differently during
    # training (if the model has any) do not distort the reported accuracy.
    model.eval()
    try:
        with torch.no_grad():
            features = torch.from_numpy(train_dataset.X_bak).float().to(device)
            # Labels must live on the same device as the predictions,
            # otherwise the comparison below fails on non-CPU devices.
            labels = torch.from_numpy(train_dataset.label).long().to(device)
            _, pred = model(features).max(1)
            hits = (labels == pred).sum().item()
    finally:
        # Restore whatever mode the caller had set.
        model.train(was_training)
    return hits / len(train_dataset)


def offline_train(cfgs, info, train_dataset, writer, logger, device):
    """Train a classifier offline with SGD and save its weights to disk.

    Args:
        cfgs: nested configuration dict. ``cfgs['Model']['type']`` selects the
            model; ``cfgs['Model']['Classifier']['kwargs']`` supplies ``lr``,
            ``source_batch_size``, ``init_erm`` (number of epochs),
            ``display_iter`` (scalar logging period, in steps) and, for the
            linear model, ``R``.
        info: dict with ``dim`` and ``cls_num`` (plus ``channel_num`` for the
            CNN) used to size the model.
        train_dataset: dataset whose items unpack as ``(x, y, _)`` and which
            exposes ``X_bak`` / ``label`` arrays for full-set evaluation.
        writer: object with ``add_scalar(tag, value, step)``; the file imports
            ``SummaryWriter`` for this purpose.
        logger: object with an ``info(message)`` method.
        device: device the model and the batches are moved to.

    Returns:
        The trained model. Its ``state_dict`` is also written to
        ``./offline_model/offline_model_<epochs>_<timestamp>.pth``.

    Raises:
        NotImplementedError: if the configured model type is unsupported.
    """
    model = _build_offline_model(cfgs, info).to(device)

    clf_kwargs = cfgs['Model']['Classifier']['kwargs']
    optimizer = optim.SGD(model.parameters(), lr=clf_kwargs['lr'])
    criterion = nn.CrossEntropyLoss()
    batch_size = clf_kwargs['source_batch_size']
    display_iter = clf_kwargs['display_iter']

    model.train()
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        # Random sampling reshuffles the data every epoch.
        sampler=torch.utils.data.RandomSampler(train_dataset),
        # Incomplete trailing batches are dropped, so every batch has exactly
        # ``batch_size`` samples (the running accuracy below relies on this).
        drop_last=True,
        num_workers=0,
    )

    n_total = len(train_dataloader)
    desc_template = 'Train | epoch:{} | loss:{}'
    # Pre-initialised so the log line and the checkpoint name stay defined
    # even when no epoch or no batch is executed.
    epochs_done = 0
    cur_step = 0
    for epoch in range(clf_kwargs['init_erm']):
        correct_num, total_loss, last_loss = 0, 0.0, 0.0
        progress = tqdm(train_dataloader, desc=desc_template.format(epoch + 1, last_loss))
        for step, batch in enumerate(progress):
            x, y, _ = batch
            x = x.float().to(device)
            y = y.long().to(device)

            optimizer.zero_grad()
            y_hat = model(x)
            loss = criterion(y_hat, y)
            loss.backward()
            optimizer.step()

            last_loss = loss.item()
            total_loss += last_loss
            # Accumulate correct predictions so 'Train/acc' is a real running
            # accuracy over the batches seen so far in this epoch.
            _, predicted = y_hat.detach().max(1)
            correct_num += (y == predicted).sum().item()
            cur_step = epoch * n_total + step

            # Keep the progress-bar text in sync with the latest batch loss.
            progress.set_description(desc_template.format(epoch + 1, last_loss), refresh=False)

            if cur_step % display_iter == 0:
                writer.add_scalar('Train/lr', optimizer.param_groups[0]['lr'], cur_step)
                writer.add_scalar('Train/loss', last_loss, cur_step)
                writer.add_scalar('Train/acc', correct_num / (step + 1) / batch_size, cur_step)

        avg_prec = _full_train_accuracy(model, train_dataset, device)
        # An empty loader (dataset smaller than one batch) has no mean loss.
        avg_loss = total_loss / n_total if n_total else float('nan')
        epochs_done = epoch + 1

        logger.info(f"Train: epoch: {epoch + 1:>02}, cur_step: {cur_step}, loss: {avg_loss:.5f}, acc: {avg_prec:.5f}")

    os.makedirs('./offline_model', exist_ok=True)

    dt = datetime.now().strftime('%m%d_%H%M%S')
    torch.save(model.state_dict(), './offline_model/offline_model_{}_{}.pth'.format(epochs_done, dt))

    return model
