import sys
import numpy as np
from more_itertools import chunked

import torch
from torch.autograd import grad
from torch import optim



from utils.helper import get_dataloader

# Import-time debug trace: reports the entry-point script (sys.argv[0])
# of the process that imported this module.
print('From Inside the Algo Class: ', sys.argv[0], sep=' ')



class BaseAlgo_fish():
    """Shared plumbing for the "fish" training algorithm.

    Builds the network, its optimizer and a StepLR scheduler from ``args``.
    Holds accuracy-history lists. Provides evaluation and checkpointing
    helpers that concrete algorithms can reuse.

    Args:
        args: Experiment configuration namespace. Attributes read in this
            class: ``out_classes``, ``img_c``, ``pre_trained``, ``os_env``,
            ``model_name``, ``opt``, ``lr`` and ``weight_decay``.
        run: Run identifier, forwarded to ``get_dataloader``.
        cuda: Device passed to ``.to(...)`` for the model and eval batches.
        kwargs: Extra options, forwarded to ``get_dataloader``.
    """

    def __init__(self, args, run, cuda, kwargs):
        # Store every constructor input before building anything, so that a
        # subclass overriding get_model/get_opt can rely on all of them.
        self.args = args
        self.run = run
        self.cuda = cuda
        self.kwargs = kwargs

        self.phi = self.get_model()
        self.opt = self.get_opt()
        # Multiplies the LR by StepLR's default gamma (0.1) every 25 calls
        # to scheduler.step().
        self.scheduler = torch.optim.lr_scheduler.StepLR(self.opt, step_size=25)

        # Accuracy histories. They are not filled in this class; presumably
        # training code appends one value per evaluation (TODO confirm).
        # save_model() persists val_acc and final_acc; train_acc is not saved.
        self.final_acc = []
        self.val_acc = []
        self.train_acc = []

    def get_model(self):
        """Build the fish network and move it to ``self.cuda``.

        Returns:
            The module produced by ``net_for_fish.get_model_for_fish``, on
            ``self.cuda``. ``get_test_accuracy`` expects it to expose
            ``enc`` and ``fc`` sub-modules.
        """
        # Project-local import kept function-scoped, as in the original.
        from models import net_for_fish

        phi = net_for_fish.get_model_for_fish(
            self.args.out_classes,
            self.args.img_c,
            self.args.pre_trained,
            self.args.os_env,
        )
        print('Model Architecture: ', self.args.model_name)
        return phi.to(self.cuda)

    def save_model(self, base_res_dir):
        """Persist model weights and accuracy histories under ``base_res_dir``.

        Writes ``Model.pth`` (the ``state_dict`` of ``self.phi``).
        Writes ``Val_Acc.npy`` from ``self.val_acc``.
        Writes ``Test_Acc.npy`` from ``self.final_acc``.
        """
        torch.save(self.phi.state_dict(), base_res_dir + '/Model' + '.pth')
        np.save(base_res_dir + '/Val_Acc' + '.npy', np.array(self.val_acc))
        np.save(base_res_dir + '/Test_Acc' + '.npy', np.array(self.final_acc))

    def get_opt(self):
        """Create the optimizer selected by ``self.args.opt``.

        Only parameters of ``self.phi`` with ``requires_grad`` set are
        optimized.

        * ``'sgd'``: SGD with Nesterov momentum 0.9, ``args.lr`` and
          ``args.weight_decay``.
        * ``'adam'``: Adam with ``args.lr``. ``args.weight_decay`` is not
          passed, so Adam's default of 0 applies.

        Raises:
            ValueError: if ``self.args.opt`` is neither ``'sgd'`` nor ``'adam'``.
        """
        trainable = [p for p in self.phi.parameters() if p.requires_grad]

        if self.args.opt == 'sgd':
            return optim.SGD(
                [{'params': trainable}],
                lr=self.args.lr,
                weight_decay=self.args.weight_decay,
                momentum=0.9,
                nesterov=True,
            )
        if self.args.opt == 'adam':
            return optim.Adam([{'params': trainable}], lr=self.args.lr)

        # Without this, an unknown choice left the result unbound and
        # failed later with an opaque UnboundLocalError.
        raise ValueError(
            "Unsupported optimizer {!r}; expected 'sgd' or 'adam'.".format(self.args.opt)
        )

    def get_test_accuracy(self, case, domains):
        """Return the top-1 accuracy (in percent) of ``self.phi`` on a split.

        Args:
            case: Split to evaluate, ``'val'`` or ``'test'``. Forwarded to
                ``get_dataloader``.
            domains: Domains to load, forwarded to ``get_dataloader``.

        Returns:
            ``100 * correct / total`` accumulated over every batch. The value
            is also printed.

        Raises:
            ValueError: if ``case`` is not ``'val'`` or ``'test'``.
            ZeroDivisionError: if the loader yields no samples.

        The model runs in eval mode for this pass, so dropout is off and
        batch-norm running statistics are not updated from evaluation data.
        Its previous train/eval mode is restored afterwards, even on error.
        """
        if case not in ('val', 'test'):
            raise ValueError(
                "case must be 'val' or 'test', got {!r}".format(case)
            )

        loader = get_dataloader(
            self.args, self.run, domains, case, 0, self.kwargs
        )['data_loader']

        correct = 0
        total = 0
        was_training = self.phi.training
        self.phi.eval()
        # Clear stale gradients once. No gradients are produced under
        # no_grad, so repeating this per batch is unnecessary.
        self.opt.zero_grad()
        try:
            with torch.no_grad():
                # Each batch is a 5-tuple; only inputs and labels are used here.
                for x_e, y_e, _d_e, _idx_e, _obj_e in loader:
                    x_e = x_e.to(self.cuda)
                    # Assumes y_e is (batch, num_classes), e.g. one-hot (TODO
                    # confirm); argmax over dim 1 recovers the class index.
                    y_e = torch.argmax(y_e, dim=1).to(self.cuda)

                    out = self.phi.fc(self.phi.enc(x_e))
                    correct += torch.sum(torch.argmax(out, dim=1) == y_e).item()
                    total += y_e.shape[0]
        finally:
            self.phi.train(was_training)

        accuracy = 100 * correct / total
        print(' Accuracy: ', case, accuracy)
        return accuracy
   
    