import os
import sys
sys.path.append('.')

import torch
import numpy as np

import argparse
from time import time
from datetime import datetime
from tqdm import tqdm
from math import *

import utils
import data_loader
import models
import test_eval

# Command-line options, declared as (flag, type, default) rows and registered
# in the order listed, so `--help` shows them in the same sequence.
parser = argparse.ArgumentParser()

_OPTION_TABLE = (
    # data location and checkpoint naming
    ('--data-dir', str, '~'),
    ('--save', str, 'cifar10_fsgld'),
    # problem size, device index, optional seed
    ('--num-class', int, 10),
    ('--gpu', int, 0),
    ('--seed', int, None),
    # optimisation schedule, noise temperature and weight-perturbation scale
    ('--epoch', int, 200),
    ('--batch-size', int, 128),
    ('--lr0', float, 0.5),
    ('--decay-scheme', str, 'cyclical'),
    ('--lr-end', float, 0),
    ('--temperature', float, 1e-4),
    ('--sigma', float, 0.001),
)
for _flag, _type, _default in _OPTION_TABLE:
    parser.add_argument(_flag, type=_type, default=_default)

args = parser.parse_args()

# setup GPU
# Runs before any data loading or model construction; it is handed the GPU
# index and the (possibly None) seed.
utils.GPU_setup(args.gpu,args.seed)

# load data
# CIFAR train/test loaders plus an SVHN loader; in this script the SVHN loader
# is only forwarded to test_eval as `oodloader`.
trainloader,testloader = data_loader.cifar(args,num_classes=args.num_class)
oodloader = data_loader.svhn(args)

# build model
net = models.ResNet18(num_classes=args.num_class).to(args.gpu)

# setup training
# NOTE(review): `/` is true division in Python 3, so num_batch (and therefore
# T) is a float, e.g. 50000/128+1 = 391.625, not the integer number of
# mini-batches per epoch. Both values are passed to utils.lr_decay in train().
# The literal 50000 is presumably the CIFAR training-set size — confirm.
num_batch = 50000/args.batch_size+1
M = 4 # number of cycles
T = args.epoch*num_batch # total number of iterations

criterion = torch.nn.CrossEntropyLoss().to(args.gpu)
# `lr` here is only the starting value; train() passes this optimizer to
# utils.lr_decay on every iteration.
opt = torch.optim.SGD(net.parameters(),lr=args.lr0,weight_decay=5e-4)


def noise(net,coeff):
    """Return a scalar whose gradient w.r.t. each parameter is scaled Gaussian noise.

    For every parameter tensor of ``net`` a fresh standard-normal tensor of the
    same shape is drawn, multiplied element-wise with the parameter and with
    ``coeff``, and reduced with ``torch.sum``.  The per-parameter scalars are
    added up, so back-propagating through the result contributes
    ``coeff * eps`` (``eps`` being the drawn sample) to each parameter's
    gradient.

    Returns the plain integer ``0`` when ``net`` has no parameters.
    """
    per_param_terms = (
        torch.sum(weight*torch.randn_like(weight.data)*coeff)
        for weight in net.parameters()
    )
    # sum() starts from 0, the same seed value an explicit accumulator would use.
    return sum(per_param_terms)



# training at each epoch
def train(epoch):
    """Run one training epoch with Gaussian weight perturbation and gradient noise.

    For every mini-batch:
      1. the learning rate is obtained from ``utils.lr_decay``;
      2. every parameter is shifted by a draw from ``N(0, args.sigma**2)``;
      3. cross-entropy plus the ``noise`` term is back-propagated at the
         shifted weights;
      4. the shift is subtracted again and the optimizer steps using the
         gradient computed at the shifted position.

    Prints the mean per-batch loss (which includes the noise term) and the
    training accuracy.

    Args:
        epoch: zero-based epoch index, forwarded to ``utils.lr_decay``.
    """
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss,correct = 0,0
    batches_seen = 0

    # Wrap the loader itself (not the enumerate object) so tqdm can read the
    # loader's length and display a total / ETA.
    for batch_idx,(inputs,targets) in enumerate(tqdm(trainloader)):
        inputs,targets = inputs.to(args.gpu),targets.to(args.gpu)

        lr = utils.lr_decay(args,opt,epoch,batch_idx,num_batch,T,M)

        # Store perturbations and add to parameters
        perturbations = []
        with torch.no_grad():
            for param in net.parameters():
                perturb = torch.randn_like(param) * args.sigma
                perturbations.append(perturb)
                param.add_(perturb)

        # Compute gradient at perturbed position
        outputs = net(inputs)
        # NOTE(review): this divides by lr, so lr must be > 0 — confirm that
        # utils.lr_decay can never return 0 for the supported decay schemes.
        noise_coeff = sqrt(2/lr/50000*args.temperature)
        loss = criterion(outputs,targets)+noise(net,noise_coeff)

        opt.zero_grad()
        loss.backward()

        # Remove perturbation so the update is applied to the unshifted weights
        with torch.no_grad():
            for param, perturb in zip(net.parameters(), perturbations):
                param.sub_(perturb)

        opt.step()

        # For logging
        train_loss += loss.item()
        _,predicted = torch.max(outputs.detach(),1)
        correct += predicted.eq(targets).sum().item()
        batches_seen += 1

    # Average over the batches actually processed: the module-level num_batch
    # is 50000/batch_size+1 under true division, a non-integer over-count.
    mean_loss = train_loss/max(batches_seen,1)
    print('Loss: %.3f | ACC: %.3f%% (%d/50000)' % (mean_loss,100.*correct/50000,correct))

# training loop
print('==> Training...')
_time = datetime.now()
# Checkpoint directory: .checkpoints/<save>[_seed<seed>]_<year>_<month>_<day>
if args.seed is not None:
    path = f'.checkpoints/{args.save}_seed{args.seed}_{_time.year}_{_time.month}_{_time.day}'
else:
    path = f'.checkpoints/{args.save}_{_time.year}_{_time.month}_{_time.day}'
# Create the directory in-process instead of through a shell: `--save` is
# user-supplied text, and spaces or shell metacharacters in it would be
# re-interpreted by a `mkdir -p` command line. A failure now raises OSError
# here rather than surfacing later when the first snapshot is written.
os.makedirs(path, exist_ok=True)
# Weight snapshots collected during training; passed to test_eval.multi_test.
w_list = []

# Wall-clock reference for the duration report printed after the loop.
_time = time()

for epoch in range(args.epoch):
    train(epoch)

    # Only epochs at positions 46..49 of each block of 50 are evaluated and
    # stored as weight samples; every other epoch just records acc = 0.
    in_sampling_window = epoch % 50 >= 46
    if not in_sampling_window:
        acc = 0
        continue

    acc = test_eval.test(args.gpu, net, testloader, oodloader)
    checkpoint_file = f'{path}/{epoch}.pt'
    w_list.append(utils.save_sample(net, checkpoint_file))

# report time usage
minute = (time() - _time) / 60
took_hours = minute > 60
duration_text = f'{minute/60:.1f} h.' if took_hours else f'{minute:.1f} min.'
print(f'Training finished in {duration_text}')

# final testing
print('\n==> Final Testing...')

# The results file lives in the checkpoint directory; its name carries the
# seed when one was given on the command line.
result_file = (
    f'{path}/test_results_seed{args.seed}.txt'
    if args.seed is not None
    else f'{path}/test_results.txt'
)
# Redirect output to file
class Tee:
    """Write-only stream that duplicates everything written to several targets.

    Exposes only ``write`` and ``flush``; no other file-object methods are
    implemented.
    """

    def __init__(self, *files):
        # Stored as the tuple of positional arguments, in the order given.
        self.files = files

    def write(self, obj):
        """Write ``obj`` to each target in turn, flushing it right after its write."""
        for target in self.files:
            target.write(obj)
            target.flush()

    def flush(self):
        """Flush each target in the order it was passed to the constructor."""
        for target in self.files:
            target.flush()

# Mirror everything printed from here on into the results file. The `with`
# block and the `finally` clause guarantee that the file is closed and the
# real stdout is restored even if the evaluation raises.
original_stdout = sys.stdout
with open(result_file, 'w') as output_file:
    sys.stdout = Tee(original_stdout, output_file)
    try:
        # Write experiment configuration
        print("="*50)
        print("EXPERIMENT CONFIGURATION")
        print("="*50)
        print(f"Dataset: CIFAR-{args.num_class}")
        print(f"Save path: {args.save}")
        print(f"Seed: {args.seed}")
        print(f"Epochs: {args.epoch}")
        print(f"Batch size: {args.batch_size}")
        print(f"Initial LR: {args.lr0}")
        print(f"Sigma: {args.sigma}")
        print(f"Temperature: {args.temperature}")
        print(f"Decay scheme: {args.decay_scheme}")
        print("="*50)
        print()

        # Evaluate the final network together with the collected snapshots.
        test_eval.multi_test(args.gpu, net, w_list, testloader, oodloader)
    finally:
        sys.stdout = original_stdout

print(f'\nResults saved to: {result_file}')
