from tqdm.auto import tqdm
import torch
import torch.nn as nn
import dataset
import torch.utils.data as data
import numpy as np
import model

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
class trainer:
    """Train and evaluate ``target_net`` with cross-entropy loss and SGD.

    On construction the network's ``fc`` layer weights and bias are zeroed.
    Every batch's inputs and targets are moved to the module-level ``device``
    before the forward pass; the network itself is used where it already is.
    """

    def __init__(self, target_net: "model.fix_net", LR: float) -> None:
        """Wrap *target_net* and build its criterion and optimizer.

        Args:
            target_net: network exposing an ``fc`` layer (the other methods
                additionally use ``feature_forward`` and ``fixByName``).
            LR: learning rate for SGD over all of ``target_net``'s parameters.
        """
        self.target_net = target_net
        self.loss = torch.nn.CrossEntropyLoss()
        self.optimizer_target = torch.optim.SGD(self.target_net.parameters(), LR)
        self.init_target_net_fc()
        # Not referenced by any method of this class.
        self.kl_loss = nn.KLDivLoss(reduction='sum')

    def init_target_net_fc(self, lr=0.01):
        """Zero ``target_net.fc.weight`` and ``target_net.fc.bias`` in place.

        Args:
            lr: unused; kept so existing call sites keep working.
        """
        nn.init.constant_(self.target_net.fc.weight, 0)
        nn.init.constant_(self.target_net.fc.bias, 0)

    def fixByName(self, filter=None):
        """Forward a list of layer-name filters to ``target_net.fixByName``.

        Args:
            filter: names passed through to the network; defaults to
                ``['fc']``. A fresh list is built on every call so the default
                can never be mutated across calls.
        """
        if filter is None:
            filter = ['fc']
        self.target_net.fixByName(filter)

    @staticmethod
    def _averages(loss_sum, correct, n_samples):
        """Return per-sample (mean loss, accuracy), or (nan, nan) if no samples were seen."""
        if n_samples == 0:
            return float('nan'), float('nan')
        return loss_sum / n_samples, correct / n_samples

    def pre_train(self, trainloader, curr_epoch=1, n_epoch=1):
        """Run one SGD pass over *trainloader* and print per-sample loss/accuracy.

        Args:
            trainloader: iterable of ``(inputs, targets)`` batches.
            curr_epoch: epoch index; displayed as ``curr_epoch + 1``.
            n_epoch: total number of epochs, used only for the progress line.

        Returns:
            ``(grad, param)``: two lists that are currently always empty (the
            code that filled them was disabled). They are still returned so
            callers that unpack two values keep working.
        """
        # NOTE(review): the display adds 1, so curr_epoch looks 0-based, yet the
        # default is 1 (prints "002/001" with defaults) - confirm intent.
        grad = []
        param = []
        loss_sum = 0.0
        correct = 0
        n_seen = 0
        self.target_net.train()
        for inputs, targets in trainloader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            logits = self.target_net(inputs)
            loss = self.loss(logits, targets)
            self.optimizer_target.zero_grad()
            loss.backward()
            self.optimizer_target.step()
            batch_size = targets.size(0)
            # The criterion returns a batch mean; weight it by batch size so a
            # short final batch does not count as much as a full one.
            loss_sum += loss.item() * batch_size
            correct += (logits.argmax(dim=-1) == targets).sum().item()
            n_seen += batch_size
        train_loss, train_acc = self._averages(loss_sum, correct, n_seen)
        print(f"[ pre train | {1 + curr_epoch:03d}/{n_epoch:03d} ] loss = {train_loss:.5f}, acc = {train_acc:.5f}")
        return grad, param

    def test_feature(self, dataloader):
        """Collect ``target_net.feature_forward`` outputs for every batch, without gradients.

        Args:
            dataloader: iterable of ``(inputs, targets)`` batches; targets are ignored.

        Returns:
            list of per-batch feature tensors in loader order, left on ``device``.
        """
        self.target_net.eval()
        features = []
        with torch.no_grad():
            for inputs, _ in dataloader:
                features.append(self.target_net.feature_forward(inputs.to(device)))
        return features

    def test(self, dataloader):
        """Evaluate *dataloader* without gradients and print per-sample loss/accuracy.

        Args:
            dataloader: iterable of ``(inputs, targets)`` batches. An empty
                loader prints ``nan`` for both metrics instead of raising.
        """
        self.target_net.eval()
        loss_sum = 0.0
        correct = 0
        n_seen = 0
        with torch.no_grad():
            for inputs, targets in dataloader:
                inputs = inputs.to(device)
                targets = targets.to(device)
                logits = self.target_net(inputs)
                batch_size = targets.size(0)
                # Re-weight the batch-mean loss by batch size (see pre_train).
                loss_sum += self.loss(logits, targets).item() * batch_size
                correct += (logits.argmax(dim=-1) == targets).sum().item()
                n_seen += batch_size
        test_loss, test_acc = self._averages(loss_sum, correct, n_seen)
        print(f"[ Test ] loss = {test_loss:.5f}, acc = {test_acc:.5f}")