# /usr/bin/env python
# -*- coding: utf-8 -*-


import torch
import torch.nn as nn
from torch.optim import *
from torch.utils.data import DataLoader

import torch.nn.functional as F

import copy

from utils.tools import Timer

# Module-level Timer instance (from utils.tools); it is not referenced anywhere
# else in the visible part of this module.
timer = Timer()

# Public names exported by `from <module> import *`.
# NOTE(review): `Fix` is defined in this module but not listed here — confirm
# whether that omission is intentional.
__all__ = ["Base", "CustomOGD"]


class Base(object):
    """Shared state and inference helpers for the online learners.

    The constructor deep-copies the model onto ``alg_kwargs['device']`` and
    eagerly materializes one shuffled pass over ``alg_kwargs['dataset']`` as
    a list of on-device mini-batches in ``self.cache``, so later update steps
    can iterate over it without re-running the DataLoader.

    Args:
        cfgs: Optional configuration object, stored as-is on ``self.cfgs``.
        seed: Optional seed, stored as-is on ``self.seed``.
        **alg_kwargs: Must contain ``device``, ``model``, ``dataset`` and
            ``batch_size``. The whole dict is kept on ``self.alg_kwargs``.
    """

    def __init__(self, cfgs=None, seed=None, **alg_kwargs):
        self.device = alg_kwargs["device"]
        # Deep-copy so training in this object never mutates the caller's model.
        self.model = copy.deepcopy(alg_kwargs['model']).to(self.device)

        self.cfgs = cfgs
        self.seed = seed
        self.criterion = None  # installed later via set_func()
        self.dataloader = DataLoader(alg_kwargs['dataset'],
                                     num_workers=0, batch_size=alg_kwargs['batch_size'],
                                     shuffle=True, pin_memory=True)
        self.alg_kwargs = alg_kwargs
        # Each dataset item must unpack as (data, target, idx). data/target are
        # moved to the device once, up front; idx is kept where the loader put it.
        self.cache = [
            (data.to(self.device), target.to(self.device), idx)
            for data, target, idx in self.dataloader
        ]

    def set_func(self, func):
        """Install *func* as the loss function stored on ``self.criterion``."""
        self.criterion = func

    def reinit(self, model):
        """Replace the learner's model with a device-local deep copy of *model*."""
        self.model = copy.deepcopy(model).to(self.device)

    def _forward(self, data):
        """Run the model on *data* after moving it to the device as float32."""
        data = data.to(self.device).to(torch.float32)
        return self.model(data)

    @torch.no_grad()
    def predict(self, data):
        """Return the index of the largest model output along dim 1.

        The model's train/eval mode is left untouched; no gradients are tracked.
        """
        _, pred = self._forward(data).max(1)
        return pred

    @torch.no_grad()
    def confidence(self, data, threshold=0.75):
        """Return softmax probabilities with low-confidence entries zeroed.

        Args:
            data: Model input; moved to ``self.device`` and cast to float32.
            threshold: Probabilities strictly below this value are set to 0.
                The default 0.75 matches the previously hard-coded cut-off.

        Returns:
            Tensor shaped like the model output, softmaxed over dim 1, in which
            every entry ``< threshold`` has been replaced by 0.
        """
        prob = F.softmax(self._forward(data), dim=1)
        prob[prob < threshold] = 0
        return prob

class CustomOGD(Base):
    """Online learner that refits the model's ``linear`` head on the cached batches.

    Keyword arguments required on top of those read by ``Base``:
        stepsize: learning rate used by ``parameters_update``.
        projection: if truthy, ``self.model.project()`` is called after each
            ``parameters_update`` pass.
    When ``cfgs`` is given, ``cfgs.get('grad_clip', False)`` enables gradient
    clipping in ``parameters_update``.
    """

    def __init__(self, cfgs=None, seed=None, **algo_kwargs):
        super(CustomOGD, self).__init__(cfgs=cfgs, seed=seed, **algo_kwargs)
        self.lr = algo_kwargs['stepsize']
        self.projection = algo_kwargs['projection']
        # cfgs must be dict-like here (uses .get); a missing key means no clipping.
        self.grad_clip = (self.cfgs is not None) and self.cfgs.get('grad_clip', False)
        self.optimizer = None
        # Histories appended to by estimate_gd(): weight norms ('D') and
        # gradient norms ('G').
        self.estimate_result = {
            'D': [],
            'G': [],
        }

    def estimate_gd(self):
        """Append the L2 norm of the model weights (and, if obtainable, gradients).

        Always appends to ``estimate_result['D']``. Appends to
        ``estimate_result['G']`` only when ``self.model.get_grad()`` succeeds.
        """
        # NOTE(review): assumes get_weights()/get_grad() return dicts of tensors
        # that can be concatenated along dim 0 (e.g. already flattened) — confirm.
        _weights = self.model.get_weights()
        weights = torch.cat(tuple(_weights.values()), dim=0)
        D = torch.norm(weights)
        self.estimate_result['D'].append(D.item())
        try:
            _grads = self.model.get_grad()
            grads = torch.cat(tuple(_grads.values()), dim=0)
            G = torch.norm(grads)
        except Exception:
            # Best effort: if gradients cannot be collected (e.g. the model has
            # no get_grad()), skip the 'G' entry. Exception rather than
            # BaseException so Ctrl-C / SystemExit still propagate.
            return
        else:
            self.estimate_result['G'].append(G.item())

    def parameters_update(self):
        """Run one Adam pass over ``self.cache``, updating only ``model.linear``.

        A fresh Adam optimizer is created on every call, so its moment
        estimates do not carry over between calls. ``self.criterion`` is called
        as ``criterion(output, label, index)`` and must return a scalar loss.

        Returns:
            The sum of the per-batch losses as a detached tensor (``0.`` when
            the cache is empty).
        """
        self.model.train()
        # Materialize once: Adam consumes the iterator, and clipping below must
        # see the same parameters.
        params = list(self.model.linear.parameters())
        optimizer = Adam(params, lr=self.lr)
        cum_loss = 0.
        for source_data, source_label, index in self.cache:
            output = self.model(source_data.float())
            loss = self.criterion(output.float(), source_label, index)
            # Detach so the running total does not keep every batch's autograd
            # graph alive.
            cum_loss += loss.detach()

            optimizer.zero_grad()
            loss.backward()
            if self.grad_clip:
                # Clip only the parameters this optimizer zeroes and steps; grads
                # of other parameters are never reset here and would otherwise
                # accumulate into (and inflate) the clipping norm.
                nn.utils.clip_grad_norm_(params, 1.0, norm_type=1.0)
            optimizer.step()

        if self.projection:
            self.model.project()

        return cum_loss

    def self_train(self, cfgs, target_data, criterion):
        """Fine-tune the whole model on its own confident pseudo-labels.

        Pseudo-labels are the model's argmax predictions on *target_data*.
        Each sample's loss is weighted by the softmax probability of its
        pseudo-label, zeroed when below ``Base.confidence``'s default threshold,
        so low-confidence samples contribute nothing.

        Args:
            cfgs: Dict-like config; reads ``cfgs['Online']['kwargs']['lr']``
                (SGD learning rate) and ``cfgs['Online']['erm_num']`` (number
                of SGD steps).
            target_data: Input batch; moved to ``self.device`` as float32.
            criterion: Called as ``criterion(output, pred)``. It should return
                per-sample losses of shape (batch,) (e.g. ``reduction='none'``).
                A scalar loss also runs, but is then scaled by the mean weight.

        The model is left in eval mode.
        """
        target_data = target_data.to(self.device).to(torch.float32)

        # Label in eval mode so that layers like dropout/batch-norm (if present)
        # do not make the pseudo-labels stochastic.
        self.model.eval()
        pred = self.predict(target_data)
        prob = self.confidence(target_data)
        # Per-sample weight: (thresholded) probability of the predicted class.
        weights = prob.gather(1, pred.unsqueeze(1)).squeeze(1)

        self.model.train()
        optimizer = SGD(self.model.parameters(), lr=cfgs['Online']['kwargs']['lr'])
        for _ in range(cfgs['Online']['erm_num']):
            optimizer.zero_grad()
            output = self.model(target_data)
            # Weight then reduce to a scalar so backward() is well-defined.
            loss = (criterion(output, pred) * weights).mean()
            loss.backward()
            optimizer.step()
        self.model.eval()


class Fix(Base):
    """Base subclass that only adds bookkeeping attributes.

    On top of what ``Base`` sets up, it initializes ``optimizer`` to None and
    ``estimate_result`` to empty 'D' and 'G' histories, mirroring the
    attributes of ``CustomOGD``. It defines no update methods of its own.
    """

    def __init__(self, cfgs=None, seed=None, **algo_kwargs):
        super().__init__(cfgs=cfgs, seed=seed, **algo_kwargs)
        self.optimizer = None
        # One empty history list per tracked series, in 'D', 'G' order.
        self.estimate_result = {series: [] for series in ('D', 'G')}

