import copy
import os
import torch
from torch.nn import functional as F
import torch.optim as optim
from .algo_fish import BaseAlgo_fish
from algorithms.fish_utils import  fish_step, get_domain
import numpy as np
import random 
import datetime
import torch.utils.data as data_utils
class get_each_domain_data(data_utils.Dataset):
    """In-memory dataset pairing one domain's samples with their labels.

    Args:
        data: Samples, indexable along the first axis and exposing ``.shape``.
            A 3-D input ``(B, H, W)`` is expanded to ``(B, 1, H, W)`` via
            ``unsqueeze(1)``, so 3-D input must be a ``torch.Tensor``.
        lables: Labels, indexable along the first axis; ``lables.shape[0]``
            defines the dataset length. (The misspelling is kept because it
            is part of the public signature and attribute name.)
        transform: Optional callable applied to every sample returned by
            ``__getitem__``. Defaults to ``None``, i.e. samples are returned
            exactly as stored.
    """

    def __init__(self, data, lables, transform=None):
        # If shape is (B, H, W), change it to (B, C, H, W) with C=1.
        self.data = data.unsqueeze(1) if len(data.shape) == 3 else data
        self.lables = lables
        self.transform = transform

    def __len__(self):
        """Return the number of samples (first dimension of the labels)."""
        return self.lables.shape[0]

    def __getitem__(self, index):
        """Return the ``(sample, label)`` pair stored at ``index``."""
        x = self.data[index]
        y = self.lables[index]

        if self.transform is not None:
            x = self.transform(x)
        return x, y

    def get_size(self):
        """Return the number of samples; kept as an alias of ``len(self)``."""
        return len(self)
class fish(BaseAlgo_fish):
    def __init__(self, args, run, cuda, kwargs):
        """Initialise best-model bookkeeping and create this run's output directory.

        All four arguments are forwarded unchanged to the base-class
        initialiser. ``self.args`` is read immediately afterwards, so the base
        class is relied upon to have stored it.

        The output directory is
        ``results/<dataset_name>/<method_name>/train_<timestamp>``, where the
        timestamp is the current ISO-8601 time with ``:`` replaced by ``_``.
        """
        super().__init__(args, run, cuda, kwargs)

        # Best-so-far tracking; max_epoch / max_val_acc / pointer are updated
        # by train().
        self.max_epoch = -1
        self.max_i = -1
        self.max_val_acc = 0.0
        # Running count of meta-iterations across all epochs.
        self.pointer = 0

        run_id = datetime.datetime.now().isoformat().replace(':', '_')
        self.current_dir = (
            'results/' + self.args.dataset_name + '/' + self.args.method_name
            + '/' + 'train_' + str(run_id)
        )
        # (Disabled experiment: a per-domain ``batch_left`` counter list used
        # to be initialised here.)

        # exist_ok avoids the check-then-create race of a separate
        # os.path.exists() test when several runs start at the same time.
        os.makedirs(self.current_dir, exist_ok=True)

    # def update_batch_left(self, domains):
    #     for domain in domains:
    #         #self.batch_left[domain-15]=(self.batch_left[domain-15]-1)%20
    #         self.batch_left[domain-15]=(self.batch_left[domain-15]-1)%10
    def get_dataloader_fish(self, args, run, index, domains, data_case, kwargs,
                            batch_size=64, num_classes=10):
        """Build a shuffled ``DataLoader`` over the saved data of one domain.

        Two tensors are loaded from disk with ``torch.load``::

            data/datasets/mnist/<dataset_name>_<mnist_case>/<data_case>/
                seed_<run>_domain_<domains[index]>_org_data.pt
                seed_<run>_domain_<domains[index]>_label.pt

        The labels are converted to one-hot rows before being wrapped in the
        dataset, so each batch yields ``(images, one_hot_labels)``.

        Args:
            args: Namespace providing ``dataset_name`` and ``mnist_case``.
            run: Seed / run identifier used in the file name.
            index: Position inside ``domains`` of the domain to load.
            domains: Sequence of domain identifiers.
            data_case: Sub-directory name of the split (e.g. ``'train'``).
            kwargs: Extra keyword arguments forwarded to ``DataLoader``.
            batch_size: Mini-batch size. Defaults to 64, the value that was
                previously hard-coded.
            num_classes: Width of the one-hot encoding. Defaults to 10, the
                value that was previously hard-coded.

        Returns:
            A ``torch.utils.data.DataLoader`` with ``shuffle=True``.
        """
        data_dir = 'data/datasets' + '/mnist/' + args.dataset_name + '_' + args.mnist_case + '/'
        load_dir = data_dir + data_case + '/' + 'seed_' + str(run) + '_domain_' + str(domains[index])

        mnist_imgs = torch.load(load_dir + '_org_data.pt')
        mnist_labels = torch.load(load_dir + '_label.pt')

        # (Disabled "level-2" experiment: a 0/1 mask loaded from
        # results/two-level/fish/... used to select which 50-sample chunks of
        # this domain were kept before building the loader.)

        # One-hot encode by picking rows of an identity matrix; the loaded
        # labels are assumed to be integer class indices -- TODO confirm
        # against the data-generation script.
        mnist_labels = torch.eye(num_classes)[mnist_labels]

        # Earlier experiments used batch sizes of 100 and, originally, 512.
        data_domain_obj = get_each_domain_data(mnist_imgs, mnist_labels)
        return data_utils.DataLoader(
            data_domain_obj, batch_size=batch_size, shuffle=True, **kwargs
        )

    def train(self):
        """Run the meta-training loop over a fixed, hard-coded domain schedule.

        For each of ``self.args.epochs`` epochs, 300 meta-iterations are run.
        One meta-iteration:

        1. Deep-copies ``self.phi`` into an inner model and creates an SGD
           optimiser for it (momentum 0.9, Nesterov). From the second
           meta-iteration of an epoch onwards the previous optimiser state is
           restored.
        2. Picks the domain group for this iteration from the schedule
           (cycling through it) and trains the inner model on every batch of
           each domain in turn.
        3. Appends the inner-loop training accuracy (in percent) to
           ``self.train_acc``.
        4. Calls ``fish_step`` with the meta and inner weights and writes the
           result back with ``self.phi.reset_weights``.

        Every 30 meta-iterations the meta model is evaluated on the ``'val'``
        split of the current domain group and on the ``'test'`` split of
        ``self.args.test_domains``; the model is saved to
        ``self.current_dir`` whenever validation accuracy improves.

        ``self.phi``, ``self.train_acc``, ``self.val_acc``, ``self.final_acc``,
        ``self.get_test_accuracy`` and ``self.save_model`` are not defined in
        this class and are presumably provided by the base class.
        """
        iters_per_epoch = 300
        eval_every = 30

        # Fixed domain groups, built once instead of on every meta-iteration.
        # The original comment labels them "dpp", so they were presumably
        # selected offline -- confirm with the experiment notes.
        domains_list_1l_fm = [
            [42, 46, 57, 63, 74], [15, 16, 18, 24, 57], [29, 33, 59, 69, 70],
            [17, 18, 23, 46, 67], [26, 33, 41, 54, 73], [22, 26, 30, 35, 43],
            [20, 33, 62, 67, 74], [16, 42, 51, 53, 74], [22, 52, 58, 64, 65],
            [28, 43, 54, 64, 72], [16, 35, 45, 68, 72], [19, 21, 39, 57, 68],
            [25, 38, 39, 61, 65], [49, 50, 55, 57, 60], [25, 38, 53, 65, 70],
            [17, 22, 37, 43, 46], [15, 20, 31, 45, 50], [25, 33, 38, 47, 50],
            [20, 37, 38, 64, 66], [17, 38, 47, 61, 69], [16, 29, 53, 55, 74],
            [24, 51, 62, 67, 68],
        ]
        # Alternative schedule; currently unused but kept so it can be swapped
        # in for ``domains_list_1l_fm`` below.
        domains_list_1l_rm = [
            [15, 26, 37, 42, 70], [16, 18, 33, 37, 44], [17, 39, 51, 62, 68],
            [28, 43, 50, 64, 68], [18, 22, 39, 60, 63], [21, 30, 41, 42, 44],
            [16, 17, 34, 40, 62], [27, 29, 35, 39, 59], [15, 36, 38, 46, 56],
            [22, 23, 33, 41, 65], [24, 33, 38, 42, 50], [21, 29, 32, 36, 38],
            [15, 18, 35, 53, 60], [15, 28, 30, 33, 65], [46, 49, 58, 62, 75],
            [21, 24, 48, 49, 57], [32, 39, 44, 60, 65], [16, 25, 34, 37, 67],
            [17, 33, 41, 54, 61], [22, 33, 35, 53, 71], [26, 38, 59, 60, 70],
            [22, 28, 35, 45, 70], [15, 16, 22, 33, 69], [25, 47, 64, 68, 73],
            [40, 41, 50, 63, 74], [16, 19, 23, 47, 65], [28, 40, 49, 53, 55],
            [20, 54, 55, 64, 70], [25, 35, 39, 60, 68], [24, 45, 48, 52, 75],
            [19, 31, 35, 54, 73], [37, 49, 61, 68, 74], [22, 32, 33, 72, 75],
        ]
        domain_schedule = domains_list_1l_fm

        for epoch in range(self.args.epochs):

            self.phi.train()
            print('\n====> Epoch: {:03d} '.format(epoch))
            # Inner optimiser state is carried between meta-iterations but
            # reset at the start of every epoch.
            opt_inner_pre = None

            for i in range(iters_per_epoch):
                train_acc = 0.0
                train_size = 0

                model_inner = copy.deepcopy(self.phi)
                model_inner.train()
                trainable = [p for p in model_inner.parameters() if p.requires_grad]
                opt_inner = optim.SGD(
                    [{'params': trainable}],
                    lr=self.args.lr,
                    weight_decay=self.args.weight_decay,
                    momentum=0.9,
                    nesterov=True,
                )
                # If a previous inner-loop optimiser state exists, restore it.
                if opt_inner_pre is not None:
                    opt_inner.load_state_dict(opt_inner_pre)

                # Disabled alternatives that were tried here: sampling 5
                # random domains from range(15, 76), and get_domain(...) from
                # algorithms.fish_utils.
                domains = domain_schedule[i % len(domain_schedule)]

                # Inner loop: visit the domains one after another.
                for j in range(len(domains)):
                    # Load the training data of domain ``domains[j]``.
                    data = self.get_dataloader_fish(
                        self.args, self.run, j, domains,
                        data_case='train', kwargs=self.kwargs,
                    )

                    # (Original note: this could perhaps be done per batch.)
                    for x_e, y_e in data:
                        opt_inner.zero_grad()

                        x_e = x_e.to(self.cuda)
                        # Labels arrive one-hot; convert back to class indices.
                        y_e = torch.argmax(y_e, dim=1).to(self.cuda)

                        # Forward pass
                        z_e = model_inner.enc(x_e)
                        out = model_inner.fc(z_e)
                        # argmax already yields int64 and the loss lives on
                        # the same device as ``out``, so no extra casts.
                        loss_e = F.cross_entropy(out, y_e)

                        loss_e.backward()
                        opt_inner.step()

                        del loss_e
                        # NOTE(review): emptying the CUDA cache on every batch
                        # is expensive; kept to preserve memory behaviour, but
                        # consider removing or moving it out of the loop.
                        torch.cuda.empty_cache()

                        train_acc += torch.sum(torch.argmax(out, dim=1) == y_e).item()
                        train_size += y_e.shape[0]

                # Guard against empty loaders instead of dividing by zero.
                self.train_acc.append(
                    100 * train_acc / train_size if train_size else 0.0
                )
                opt_inner_pre = opt_inner.state_dict()

                # Fish update; the step size is args.meta_lr / args.meta_steps.
                meta_weights = fish_step(
                    meta_weights=self.phi.state_dict(),
                    inner_weights=model_inner.state_dict(),
                    meta_lr=self.args.meta_lr / self.args.meta_steps,
                )
                # (Original note: after the reset, run on all domains and
                # then apply DPP.)
                self.phi.reset_weights(meta_weights)

                # (Original note: if DPP selection is run at other intervals,
                # decide which domains should be used for validation.)
                if (i + 1) % eval_every == 0:
                    print(f'iteration {(i + 1):05d}: ')
                    self.val_acc.append(self.get_test_accuracy('val', domains))
                    self.final_acc.append(
                        self.get_test_accuracy('test', self.args.test_domains)
                    )
                    self.phi.train()

                    # Save the model when validation accuracy improves.
                    if self.val_acc[-1] > self.max_val_acc:
                        self.max_val_acc = self.val_acc[-1]
                        # Index of this evaluation, counted across epochs.
                        self.max_epoch = self.pointer // eval_every
                        self.save_model(self.current_dir)

                    print(self.val_acc)
                    print(self.final_acc)
                    print('Current Best Epoch: ', self.max_epoch)
                    # max_epoch stays -1 until the first improvement; indexing
                    # with it would silently report the latest accuracy.
                    if self.max_epoch >= 0:
                        print(' with Test Accuracy: ', self.final_acc[self.max_epoch])

                self.pointer += 1



        

