import torch
from torch.utils.data import Dataset
from torchvision import datasets,transforms
from torchvision.transforms import ToTensor, Lambda
import matplotlib.pyplot as plt
import numpy as np

import os

import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import DataLoader

from model import LeNet
from utils import *

#=============
# arguments
#=============
total_reps = 5              # number of independent runs; run i starts from initialized_weight_i
learning_rate = 0.005       # SGD step size
max_iter = 100000+1         # iterations 0..100000; the last index is a multiple of check/store_iter
phase_end_iter = 50000      # iterations below this index get the heavy-tailed small-batch noise
check_iter = 100            # evaluate on the test set every `check_iter` iterations
store_iter = 2000           # save model weights + metric histories every `store_iter` iterations
# Prefer the first GPU, but fall back to CPU instead of crashing later at the
# first `.to(device)` when CUDA is unavailable.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
seed = 1

# heavy tail noise and grad clip parameters
heavy_tail_noise_alpha = 1.4       # shape parameter of the np.random.pareto draw for the noise size
heavy_tail_noise_magnitude = 0.5   # scale multiplied onto that pareto draw
#gradient_clip = 5
model_weight_clip = 2              # weights are clamped to [-clip, clip] after each step (disabled if <= 0.1)

batch_size_train = 100             # mini-batch size used for the small-batch gradient direction
training_size = 1200               # batch size of the loader that materializes X_train / Y_train
training_with_corrupted_label = 200  # not used in this script; presumably documents how the dataset was built
load_path_dataset = 'data/FashionMNIST_corrupted_dataset_'  # prefix of '<prefix>data.npy' / '<prefix>label.npy'
model_weight_dir = 'checkpoints/LeNet_initialization'
save_path = 'checkpoints/REPS_SB_noise/'

# Seed the RNGs the script draws from: NumPy (pareto noise size), torch
# (DataLoader shuffling) and the CUDA generator.
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

# exist_ok=True avoids the check-then-create race of os.path.exists + makedirs.
os.makedirs(save_path, exist_ok=True)

#================
# dataset
#================

# FashionMNIST transform: image -> tensor in [0, 1], cast to torch's default
# floating dtype, then normalized with mean 0.5 / std 0.5 (maps values to [-1, 1]).
stats = {'mean': [0.5],'std': [0.5]}
trans = [
        transforms.ToTensor(),
        lambda t: t.type(torch.get_default_dtype()),
        transforms.Normalize(**stats)
        ]

training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=transforms.Compose(trans)
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=transforms.Compose(trans)
)

#===========================================================
# modify training data: load the ones generated by GD script
#===========================================================
# Replace the stock training images/labels with the arrays saved by the GD
# script (presumably the label-corrupted subset, given the file prefix).
# np.load is given the path so it opens and closes the file itself instead of
# leaking an open handle.
np_data = np.load(load_path_dataset + 'data.npy')
np_label = np.load(load_path_dataset + 'label.npy')

# NOTE(review): assumes the saved arrays use torchvision's FashionMNIST layout
# (uint8 images of shape (N, 28, 28) and integer labels) — TODO confirm.
training_data.data = torch.from_numpy(np_data)
training_data.targets = torch.from_numpy(np_label)

#============
# data loader
#============
# Full-batch loader: a single batch of `training_size` samples, consumed once
# below to materialize the training set on `device`.
train_dataloader = DataLoader(training_data, batch_size=training_size, shuffle=False)
# Shuffled mini-batch loader used to draw small-batch gradients.
train_dataloader_small_batch = DataLoader(training_data, batch_size=batch_size_train, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=100, shuffle=False)

# Load the training set onto `device` once.
# NOTE(review): only the first `training_size` samples land in X_train; this is
# the whole training set only if the saved arrays hold exactly that many.
X_train, Y_train = next(iter(train_dataloader))
X_train, Y_train = X_train.to(device), Y_train.to(device)

# Train one LeNet per repetition, each starting from its own saved
# initialization. Every iteration first computes the full-batch ("lb")
# gradient on X_train/Y_train.
#   * Phase 1 (idx < phase_end_iter): also computes a random mini-batch ("sb")
#     gradient. The step direction is rebuilt by modify_model_noise from the
#     lb direction and the (sb - lb) noise direction, scaled by
#     1 + pareto(heavy_tail_noise_alpha) * heavy_tail_noise_magnitude.
#   * Phase 2 (idx >= phase_end_iter): plain full-batch gradient descent.
# After each SGD step the weights are optionally clamped. Test-set metrics are
# logged every `check_iter` iterations, and model + metric histories are saved
# every `store_iter` iterations.
for idx_rep in range(total_reps):
    #=================
    # initialize model
    #=================
    net = LeNet()
    # map_location='cpu' lets the checkpoint load regardless of the device it
    # was saved from; the model is moved to `device` right after.
    checkpoint = torch.load(
        os.path.join(model_weight_dir, 'initialized_weight_' + str(idx_rep) + '.pth.tar'),
        map_location='cpu',
    )
    net.load_state_dict(checkpoint)
    del checkpoint
    net = net.to(device)

    loss_fn = nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate)

    # optimization
    torch.backends.cudnn.benchmark = True
    train_loss_list = []
    train_acc_list = []
    valid_loss_list = []
    valid_acc_list = []
    # Common prefix of every file written for this repetition,
    # e.g. '<save_path>/REP_0' + 'model_iter2000.pth'.
    rep_prefix = os.path.join(save_path, 'REP_' + str(idx_rep))

    for idx in range(max_iter):
        # Full-batch gradient; it is also the final step direction in phase 2.
        optimizer.zero_grad()
        train_loss, train_acc = eval_loss_and_acc_on_batch(net, X_train, Y_train, loss_fn,
                                                           require_acc=True,
                                                           mode='train')
        train_loss.backward()
        if idx < phase_end_iter:
            # phase 1: evaluate lb and sb direction, then apply heavy-tailed noise
            lb_direction_dict = get_grads_dict(net)

            optimizer.zero_grad()
            # A fresh iterator over a shuffle=True loader draws a new random
            # mini-batch on every call.
            x_, y_ = next(iter(train_dataloader_small_batch))
            x, y = x_.to(device), y_.to(device)
            loss, _ = eval_loss_and_acc_on_batch(net, x, y, loss_fn, require_acc=False, mode='train')
            loss.backward()
            sb_direction_dict = get_grads_dict(net)

            # noise direction (sb - lb, presumably) and heavy-tailed multiplier
            noise_direction = get_dict_differnce(sb_direction_dict, lb_direction_dict)
            noise_size = np.random.pareto(heavy_tail_noise_alpha) * heavy_tail_noise_magnitude

            # Clear the sb gradients, then let modify_model_noise install the
            # modified direction that optimizer.step() will use.
            optimizer.zero_grad()
            modify_model_noise(net, lb_direction_dict, noise_direction, 1 + noise_size)

        # apply gradient clipping
        #if gradient_clip is not None and gradient_clip > 0:
            #torch.nn.utils.clip_grad_norm_(net.parameters(), gradient_clip )

        optimizer.step()

        # Clamp weights in place after each step (a threshold <= 0.1 disables
        # clipping); no_grad keeps the clamp out of autograd.
        if model_weight_clip > 0.1:
            with torch.no_grad():
                for p in net.parameters():
                    p.clamp_(-model_weight_clip, model_weight_clip)

        # update train info
        train_loss_list.append([idx, train_loss.item()])
        train_acc_list.append([idx, train_acc])

        # check on VALID SET (the FashionMNIST test split)
        if idx % check_iter == 0:
            eval_loss, eval_acc = eval_loss_and_acc_on_valid_set(net, test_dataloader, loss_fn, device=device)
            valid_loss_list.append([idx, eval_loss])
            valid_acc_list.append([idx, eval_acc])

            print('ITER '+str(idx))
            print('   TRAIN LOSS '+str(train_loss.item())+', ACC '+str(train_acc))
            print('   VALID LOSS '+str(eval_loss)+', ACC '+str(eval_acc))

        # store models and metrics; np.save/torch.save get paths so the files
        # are opened and closed by the library (no leaked handles).
        if idx % store_iter == 0:
            torch.save(net.state_dict(), rep_prefix + 'model_iter' + str(idx) + '.pth')
            for metric_name, history in (('train_loss', train_loss_list),
                                         ('train_acc', train_acc_list),
                                         ('valid_loss', valid_loss_list),
                                         ('valid_acc', valid_acc_list)):
                np.save(rep_prefix + metric_name + '_history.npy', np.array(history))
