

import argparse
import numpy as np
import torch
import pandas as pd

import os, csv
import argparse
import pandas as pd
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np


from torchvision import datasets
from torch import nn, optim, autograd
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from mpl_toolkits.mplot3d import Axes3D
from torch.utils.data.dataset import Subset
from sklearn.model_selection import KFold

from additional_functions import mean_nll,mean_nll2,mean_accuracy,prob_sum,softmax,Condi_MI,mean_nll2_forIRM


# Command-line configuration; the parsed namespace is exposed as ``flags``.
parser = argparse.ArgumentParser(
    description='IRM training with K-fold cross-validation over penalty settings.')
parser.add_argument('--hidden_dim', type=int, default=440,
                    help='currently not referenced by this script')
parser.add_argument('--l2_regularizer_weight', type=float, default=0.001,
                    help='coefficient of the sum of squared parameter norms in the loss')
parser.add_argument('--lr', type=float, default=0.0004,
                    help='Adam learning rate')
parser.add_argument('--n_restarts', type=int, default=1,
                    help='number of independently initialised models in the final training')
parser.add_argument('--steps', type=int, default=5,
                    help='number of epochs (passes over the paired environment loaders)')
flags = parser.parse_args()


# '...' / '........' path strings in this script are placeholders for the
# dataset directory; replace them with real paths before running.
metadata_df = pd.read_csv('.../metadata.csv')

# Images are resized to int(224 * 256/224) = 256 per side, then center-cropped
# to ``target_resolution``.
scale = 256.0/224.
target_resolution = (224, 224)
minibatch = 128  # NOTE(review): not referenced elsewhere in this script

saving_data = []

# Preprocessing pipeline: resize, center crop, convert to a float tensor.
# ImageNet normalization is currently disabled.
transform = transforms.Compose([
    transforms.Resize(tuple(int(side*scale) for side in target_resolution)),
    transforms.CenterCrop(target_resolution),
    transforms.ToTensor(),
    # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
        

# '........' strings below are placeholders for dataset directories; replace
# them with real paths before running. Each folder is named once, so the
# directory that is listed and the one files are opened from cannot diverge.
OCEAN_DIR = '........'  # every file here is loaded; appended to env1
BOO_DIR = '........'    # every file here is loaded; appended to env2
CUB_DIR = '........'    # root onto which metadata's ``img_filename`` entries are joined


def _load_image(path):
    """Open ``path``, convert it to RGB and apply the module-level ``transform``."""
    # ``with`` releases the file handle deterministically.
    with Image.open(path) as img:
        return transform(img.convert('RGB'))


def _load_image_dir(directory):
    """Load every file in ``directory`` (``os.listdir`` order) into one stacked tensor."""
    return torch.stack(
        [_load_image(os.path.join(directory, name)) for name in os.listdir(directory)],
        dim=0,
    )


# Building each stack from a temporary list (instead of module-level lists)
# avoids keeping every image twice in memory: once in the list, once stacked.
ocean_data = _load_image_dir(OCEAN_DIR)
boo_data = _load_image_dir(BOO_DIR)

y_array = metadata_df['y'].values
confounder_array = metadata_df['place'].values
# Group id = 2 * y + place (0..3 when y and place are 0/1 — assumed binary).
group_array = (y_array*(2) + confounder_array).astype('int')
filename_array = metadata_df['img_filename'].values

cub_data = torch.stack(
    [_load_image(os.path.join(CUB_DIR, name)) for name in filename_array],
    dim=0,
)


def _build_env(groups, extra_images, extra_label=2):
    """Build an environment dict from the given CUB groups plus extra images.

    Args:
        groups: group ids (values of ``group_array``) whose CUB images and
            ``y`` labels are included, concatenated in this order.
        extra_images: image tensor appended after the CUB images.
        extra_label: label assigned to every extra image.

    Returns:
        ``{'images': Tensor, 'labels': int64 Tensor}`` with matching length.
    """
    images = [cub_data[group_array == g] for g in groups]
    labels = [torch.from_numpy(y_array[group_array == g]) for g in groups]
    images.append(extra_images)
    # Create the extra labels directly as int64 so torch.cat never has to
    # mix a float tensor with integer labels.
    labels.append(torch.full((extra_images.shape[0],), extra_label, dtype=torch.long))
    return {'images': torch.cat(images), 'labels': torch.cat(labels).long()}


# With group = 2*y + place: groups 0 and 3 are the samples where y == place,
# groups 1 and 2 those where y != place (assuming binary y and place).
env1 = _build_env((0, 3), ocean_data)
env2 = _build_env((1, 2), boo_data)

# Show up to 100000 items at the start/end of each dimension when printing tensors.
torch.set_printoptions(edgeitems=100000)


def _shuffle_pair_in_place(env):
    """Shuffle ``env['images']`` and ``env['labels']`` along dim 0 with one shared permutation.

    ``Tensor.numpy()`` shares memory with a CPU tensor, so ``np.random.shuffle``
    permutes the tensors in place. The global NumPy RNG state is restored
    between the two calls, so both draw the same permutation (this requires
    both tensors to have the same length along dim 0).

    Note: the only parameter of ``np.random.get_state`` is its ``legacy``
    flag, not a seed. The global RNG is not seeded anywhere in this script,
    so the permutation differs from run to run.
    """
    saved_state = np.random.get_state()
    np.random.shuffle(env['images'].numpy())
    np.random.set_state(saved_state)
    np.random.shuffle(env['labels'].numpy())


_shuffle_pair_in_place(env1)
_shuffle_pair_in_place(env2)


def _split_train_test(env):
    """Split ``env`` into ``(train, test)`` dicts; the first fifth of samples is the test split."""
    cut = env['images'].shape[0] // 5
    train = {'images': env['images'][cut:], 'labels': env['labels'][cut:]}
    test = {'images': env['images'][:cut], 'labels': env['labels'][:cut]}
    return train, test


env1_train, env1_test = _split_train_test(env1)
env2_train, env2_test = _split_train_test(env2)

envs_train = [env1_train, env2_train]
envs_test = [env1_test, env2_test]

# Fraction of env2 training samples whose label is non-zero.
ratio = (env2_train['labels'].view(-1) != 0).sum().item() / env2_train['labels'].shape[0]


    
class _EnvDataset(torch.utils.data.Dataset):
    """Map-style dataset over an environment dict ``{'images': ..., 'labels': ...}``.

    Item ``idx`` is ``(images[idx], labels.view(-1, 1)[idx])``, so every label
    is returned as a 1-element tensor.

    Attributes:
        data: the images tensor, indexed along dim 0.
        label: the labels reshaped to a column via ``view(-1, 1)``.
        datanum: number of samples, ``images.shape[0]``.
    """

    def __init__(self, env):
        self.data = env['images']
        self.label = env['labels'].view(-1, 1)
        self.datanum = env['images'].shape[0]

    def __len__(self):
        return self.datanum

    def __getitem__(self, idx):
        return self.data[idx], self.label[idx]


class Mydatasets_env1(_EnvDataset):
    """Environment-1 training split; reads the module-level ``env1_train`` when constructed."""

    def __init__(self):
        super().__init__(env1_train)


class Mydatasets_env2(_EnvDataset):
    """Environment-2 training split; reads the module-level ``env2_train`` when constructed."""

    def __init__(self):
        super().__init__(env2_train)


class Mydatasets_env1_test(_EnvDataset):
    """Environment-1 test split; reads the module-level ``env1_test`` when constructed."""

    def __init__(self):
        super().__init__(env1_test)


class Mydatasets_env2_test(_EnvDataset):
    """Environment-2 test split; reads the module-level ``env2_test`` when constructed."""

    def __init__(self):
        super().__init__(env2_test)
    
    
    
# Materialise the train/test datasets of both environments
# (constructed in the order env1, env1_test, env2, env2_test).
trainset1, trainset1_test, trainset2, trainset2_test = (
    Mydatasets_env1(),
    Mydatasets_env1_test(),
    Mydatasets_env2(),
    Mydatasets_env2_test(),
)

import torch
import torch.nn as nn

class block(nn.Module):
    """Bottleneck residual block: 1x1 -> 3x3 -> 1x1 convolutions, expansion 4.

    Args:
        first_conv_in_channels: number of channels of the block input.
        first_conv_out_channels: width of the two inner convolutions; the
            block outputs ``4 * first_conv_out_channels`` channels.
        identity_conv: optional module applied to the shortcut so that it
            matches the main path's shape (needed when stride or channel
            count changes). ``None`` adds the input unchanged.
        stride: stride of the 3x3 convolution.
    """

    def __init__(self, first_conv_in_channels, first_conv_out_channels, identity_conv=None, stride=1):
        super(block, self).__init__()

        # 1x1 reduction.
        self.conv1 = nn.Conv2d(
            first_conv_in_channels, first_conv_out_channels, kernel_size=1, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(first_conv_out_channels)

        # 3x3 spatial convolution (carries the stride).
        self.conv2 = nn.Conv2d(
            first_conv_out_channels, first_conv_out_channels, kernel_size=3, stride=stride, padding=1)
        self.bn2 = nn.BatchNorm2d(first_conv_out_channels)

        # 1x1 expansion to 4x the inner width.
        self.conv3 = nn.Conv2d(
            first_conv_out_channels, first_conv_out_channels*4, kernel_size=1, stride=1, padding=0)
        self.bn3 = nn.BatchNorm2d(first_conv_out_channels*4)
        self.relu = nn.ReLU()

        self.identity_conv = identity_conv

    def forward(self, x):
        # The shortcut can reference the input directly: the main path only
        # rebinds names and never modifies ``x`` in place, so no copy is needed.
        identity = x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        if self.identity_conv is not None:
            identity = self.identity_conv(identity)
        # In-place add on the freshly computed main-path tensor (not the input).
        out += identity

        return self.relu(out)
    
    
class ResNet(nn.Module):
    """ResNet-50-style backbone (stages of 3, 4, 6, 3 blocks) with two linear heads.

    A shared trunk (stem, four residual stages, global average pooling and a
    ``2048 -> 256`` linear projection) feeds:

    * ``_main1`` (wrapping ``fc1``, 3 outputs): used by ``forward`` and
      ``condi_forward``;
    * ``_main2`` (wrapping ``fc2``, 1 output): used by ``high_forward``.

    ``softmax`` and ``prob_sum`` are imported from ``additional_functions``.

    Args:
        block: residual block class, called as ``block(in_channels,
            out_channels, identity_conv, stride)``; its output must have
            ``4 * out_channels`` channels to match ``nn.Linear(512*4, 256)``.
    """

    def __init__(self, block):
        super(ResNet, self).__init__()

        # Stem.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Residual stages.
        self.conv2_x = self._make_layer(block, 3, res_block_in_channels=64, first_conv_out_channels=64, stride=1)
        self.conv3_x = self._make_layer(block, 4, res_block_in_channels=256, first_conv_out_channels=128, stride=2)
        self.conv4_x = self._make_layer(block, 6, res_block_in_channels=512, first_conv_out_channels=256, stride=2)
        self.conv5_x = self._make_layer(block, 3, res_block_in_channels=1024, first_conv_out_channels=512, stride=2)

        # Pooling, shared projection and the two heads.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512*4, 256)
        self.fc1 = nn.Linear(256, 3)
        self.fc2 = nn.Linear(256, 1)

        self._main1 = nn.Sequential(self.fc1)
        self._main2 = nn.Sequential(self.fc2)

    def _features(self, x):
        """Run the shared trunk and return the 256-d projected features.

        Shape comments assume a 224x224 input.
        """
        x = self.conv1(x)    # (3, 224, 224) -> (64, 112, 112)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)  # -> (64, 56, 56)

        x = self.conv2_x(x)  # -> (256, 56, 56)
        x = self.conv3_x(x)  # -> (512, 28, 28)
        x = self.conv4_x(x)  # -> (1024, 14, 14)
        x = self.conv5_x(x)  # -> (2048, 7, 7)
        x = self.avgpool(x)
        x = x.reshape(x.shape[0], -1)
        return self.fc(x)

    def forward(self, x):
        """Return ``softmax`` of the 3-way head ``_main1``."""
        return softmax(self._main1(self._features(x)))

    def high_forward(self, x):
        """Return the raw single-logit output of the head ``_main2``."""
        return self._main2(self._features(x))

    def pre_forward(self, x):
        """Return the shared 256-d features (input to both heads)."""
        return self._features(x)

    def condi_forward(self, x):
        """Return ``prob_sum`` applied to the ``softmax`` of the 3-way head."""
        return prob_sum(softmax(self._main1(self._features(x))))

    def _make_layer(self, block, num_res_blocks, res_block_in_channels, first_conv_out_channels, stride):
        """Build one stage of ``num_res_blocks`` residual blocks.

        Only the first block changes resolution/width. It receives ``stride``
        and a 1x1 projection shortcut from ``res_block_in_channels`` to
        ``4 * first_conv_out_channels`` channels. The remaining blocks keep that
        width with stride 1 and no shortcut projection.
        """
        layers = []

        identity_conv = nn.Conv2d(res_block_in_channels, first_conv_out_channels*4, kernel_size=1, stride=stride)
        layers.append(block(res_block_in_channels, first_conv_out_channels, identity_conv, stride))

        in_channels = first_conv_out_channels*4

        for _ in range(num_res_blocks - 1):
            layers.append(block(in_channels, first_conv_out_channels, identity_conv=None, stride=1))

        return nn.Sequential(*layers)

def pretty_print(*values):
    """Print ``values`` on one line as left-justified 13-character columns.

    Columns are separated by three spaces. Strings are printed as-is. Any other
    value is rendered with ``np.array2string`` using 5 fixed decimal places.
    """
    width = 13
    cells = []
    for value in values:
        if isinstance(value, str):
            text = value
        else:
            text = np.array2string(value, precision=5, floatmode='fixed')
        cells.append(text.ljust(width))
    print('   '.join(cells))

def condi_prob(x):
    """Renormalise columns 1 and 2 of ``x`` against each other.

    For a 2-D ``x`` this returns an ``(x.shape[0], 2)`` tensor whose rows are
    ``[x1 / (x1 + x2), x2 / (x1 + x2)]`` with ``x1 = x[:, 1]``, ``x2 = x[:, 2]``.
    Column 0 is ignored. Rows where both entries are zero yield NaN.
    """
    first, second = x[:, 1], x[:, 2]
    total = first + second
    return torch.stack([first / total, second / total]).T


# Hyper-parameter sweep. For every (warm-up length ``iters``, log10 penalty
# weight) pair, run K-fold cross-validation on the training environments,
# then train ``flags.n_restarts`` models on the full training sets and
# evaluate them on the held-out test splits.


def _irm_penalty(model, logits, y):
    """Return the IRM gradient penalty for ``logits`` against targets ``y``.

    The penalty is the squared L2 norm of the gradient of
    ``mean_nll2_forIRM(logits, y)`` with respect to the first parameter
    (the weight) of ``model._main2[0]``. ``create_graph=True`` keeps the
    result differentiable so it can be back-propagated with the loss.
    """
    loss = mean_nll2_forIRM(logits, y)
    grad = autograd.grad(loss, model._main2[0].parameters(), create_graph=True)[0]
    return torch.sum(grad**2)


def _train_one_epoch(model, optimizer, loader1, loader2, step, iters, penalty_weight):
    """Run one pass over paired env1/env2 batches and update ``model``.

    Per-batch loss:
        NLL of the 3-way head on env1
        + flags.l2_regularizer_weight * (sum of squared parameter norms)
        + multiplier * mean IRM penalty over the two environments,
    divided by ``multiplier`` when it exceeds 1. ``multiplier`` is
    ``10 ** penalty_weight`` once ``step >= iters`` and 1.0 before.

    Returns:
        (mean nll, mean accuracy, mean penalty, multiplier). The first three
        are detached CPU tensors averaged over batches.
    """
    model.train()
    multiplier = float(10**penalty_weight) if step >= iters else 1.0
    nll_store, acc_store, penalty_store = [], [], []
    # zip() stops at the shorter of the two loaders.
    for (images1, labels1), (images2, labels2) in zip(loader1, loader2):
        images1, labels1 = images1.cuda(), labels1.cuda()
        images2, labels2 = images2.cuda(), labels2.cuda()

        features1 = model.pre_forward(images1)
        features2 = model.pre_forward(images2)
        probs1 = softmax(model._main1(features1))
        train_nll = mean_nll(torch.log(probs1), labels1)
        train_acc = mean_accuracy(probs1, labels1)
        # The penalty head ``_main2`` is evaluated on the binary target "label > 0".
        penalty1 = _irm_penalty(model, model._main2(features1), (labels1 > 0).float().view(-1, 1))
        penalty2 = _irm_penalty(model, model._main2(features2), (labels2 > 0).float().view(-1, 1))
        train_penalty = torch.stack([penalty1, penalty2]).mean()

        weight_norm = sum(w.norm().pow(2) for w in model.parameters())
        loss = train_nll + flags.l2_regularizer_weight * weight_norm
        loss = loss + multiplier * train_penalty
        if multiplier > 1.0:
            # Rescale the entire loss to keep gradients in a reasonable range.
            loss = loss / multiplier

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        nll_store.append(train_nll.detach().cpu())
        acc_store.append(train_acc.detach().cpu())
        penalty_store.append(train_penalty.detach().cpu())

    return (torch.stack(nll_store, dim=0).mean(),
            torch.stack(acc_store, dim=0).mean(),
            torch.stack(penalty_store, dim=0).mean(),
            multiplier)


def _evaluate_fold(model, loader1, loader2):
    """Compute the validation losses of one CV fold.

    Returns batch-averaged 0-d CPU tensors ``(out_nll1, out_nll2, out_bias)``:
      * out_nll1 -- ``mean_nll`` of ``model.forward`` on env1;
      * out_nll2 -- ``mean_nll`` of ``model.condi_forward`` on env2 against
        the binary target "label > 0";
      * out_bias -- ``mean_nll`` of ``model.condi_forward`` on env1 samples
        with a non-zero label, against ``label - 1`` (NaN if there were none).
    """
    # Evaluate with BatchNorm running statistics and without building graphs.
    model.eval()
    nll1_store, nll2_store, bias_store = [], [], []
    with torch.no_grad():
        for (images1, labels1), (images2, labels2) in zip(loader1, loader2):
            images1, labels1 = images1.cuda(), labels1.cuda()
            nll1_store.append(mean_nll(torch.log(model.forward(images1)), labels1).cpu())

            nonzero = labels1.view(-1) != 0
            # Skip batches without any non-zero label rather than feeding an
            # empty tensor to condi_forward / mean_nll.
            if nonzero.any():
                bias_probs = model.condi_forward(images1[nonzero])
                bias_store.append(mean_nll(torch.log(bias_probs), labels1[nonzero] - 1.).cpu())

            images2, labels2 = images2.cuda(), labels2.cuda()
            out_nll2 = mean_nll(torch.log(model.condi_forward(images2)),
                                (labels2 > 0).float().view(-1, 1))
            nll2_store.append(out_nll2.cpu())

    out_nll1 = torch.stack(nll1_store, dim=0).mean()
    out_nll2 = torch.stack(nll2_store, dim=0).mean()
    out_bias = (torch.stack(bias_store, dim=0).mean() if bias_store
                else torch.tensor(float('nan')))
    return out_nll1, out_nll2, out_bias


def _evaluate_accuracy(model, loader):
    """Return the mean over batches of ``mean_accuracy(model.forward(images), labels)``."""
    model.eval()
    accs = []
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.cuda(), labels.cuda()
            accs.append(mean_accuracy(model.forward(images), labels).detach().cpu())
    return torch.stack(accs, dim=0).mean()


def _cross_validate(iters, penalty_weight, n_splits=5):
    """K-fold cross-validation of the IRM model for one hyper-parameter setting.

    Returns:
        (sbs_scores, max_scores): per-fold floats with
        sbs = max(out_nll1, out_nll2 + ratio * out_bias) and
        max = max(out_nll1, out_nll2).
    """
    sbs_scores, max_scores, fold_train_accs = [], [], []
    kf = KFold(n_splits=n_splits)
    # Each environment is split with its own indices. trainset1 and trainset2
    # generally differ in size, so env1's fold indices cannot be reused for
    # env2 (that raises IndexError or silently leaves env2 samples out).
    folds = zip(kf.split(trainset1), kf.split(trainset2))
    for fold, ((train_idx1, valid_idx1), (train_idx2, valid_idx2)) in enumerate(folds):
        print('CV_step=', fold)
        model_cv = ResNet(block).cuda()
        optimizer = optim.Adam(model_cv.parameters(), lr=flags.lr)
        train_loader1 = torch.utils.data.DataLoader(
            Subset(trainset1, train_idx1), batch_size=56, shuffle=True, num_workers=6)
        train_loader2 = torch.utils.data.DataLoader(
            Subset(trainset2, train_idx2), batch_size=56, shuffle=True, num_workers=6)

        pretty_print('step', 'train nll', 'train acc', 'train penalty', 'w*train penalty')
        for step in range(flags.steps):
            train_nll, train_acc, train_penalty, multiplier = _train_one_epoch(
                model_cv, optimizer, train_loader1, train_loader2, step, iters, penalty_weight)
            pretty_print(np.int32(step), train_nll.numpy(), train_acc.numpy(),
                         train_penalty.numpy(), (train_penalty * multiplier).numpy())

        fold_train_accs.append(train_acc.numpy())
        print('Final train acc (mean/std across folds so far):')
        print(np.mean(fold_train_accs), np.std(fold_train_accs))

        valid_loader1 = torch.utils.data.DataLoader(
            Subset(trainset1, valid_idx1), batch_size=14, shuffle=False, num_workers=6)
        valid_loader2 = torch.utils.data.DataLoader(
            Subset(trainset2, valid_idx2), batch_size=14, shuffle=False, num_workers=6)
        out_nll1, out_nll2, out_bias = _evaluate_fold(model_cv, valid_loader1, valid_loader2)
        print('out_nll1 :', out_nll1)
        print('out_nll2 :', out_nll2)
        print('out_bias :', out_bias)

        sbs = torch.max(out_nll1, out_nll2 + out_bias * ratio).item()
        max_cv = torch.max(out_nll1, out_nll2).item()
        sbs_scores.append(sbs)
        max_scores.append(max_cv)
        print('sbs_CV:', sbs)
        print('max_CV:', max_cv)
        # Free this fold's model before the next one is allocated on the GPU.
        del model_cv, optimizer
    return sbs_scores, max_scores


def _train_final(iters, penalty_weight):
    """Train ``flags.n_restarts`` fresh models on the full training sets.

    After every epoch each model is evaluated on both test splits separately.

    Returns:
        (train_accs, test_accs_env1, test_accs_env2): lists holding the
        final-epoch value of each restart.
    """
    train_accs, test_accs_env1, test_accs_env2 = [], [], []
    for restart in range(flags.n_restarts):
        print('Restart:', restart)
        model = ResNet(block).cuda()
        optimizer = optim.Adam(model.parameters(), lr=flags.lr)
        train_loader1 = torch.utils.data.DataLoader(trainset1, batch_size=56, shuffle=True, num_workers=8)
        train_loader2 = torch.utils.data.DataLoader(trainset2, batch_size=56, shuffle=True, num_workers=4)
        test_loader1 = torch.utils.data.DataLoader(trainset1_test, batch_size=56, shuffle=False, num_workers=4)
        test_loader2 = torch.utils.data.DataLoader(trainset2_test, batch_size=56, shuffle=False, num_workers=8)

        pretty_print('step', 'train nll', 'train acc', 'train penalty', 'w*train penalty', 'test1 acc', 'test2_acc')
        for step in range(flags.steps):
            train_nll, train_acc, train_penalty, multiplier = _train_one_epoch(
                model, optimizer, train_loader1, train_loader2, step, iters, penalty_weight)
            # Each test split is evaluated on its own, so neither is truncated
            # to the length of the other.
            test_acc1 = _evaluate_accuracy(model, test_loader1)
            test_acc2 = _evaluate_accuracy(model, test_loader2)
            pretty_print(np.int32(step), train_nll.numpy(), train_acc.numpy(),
                         train_penalty.numpy(), (train_penalty * multiplier).numpy(),
                         test_acc1.numpy(), test_acc2.numpy())

        train_accs.append(train_acc.numpy())
        test_accs_env1.append(test_acc1.numpy())
        test_accs_env2.append(test_acc2.numpy())
        print('Final train acc (mean/std across restarts so far):')
        print(np.mean(train_accs), np.std(train_accs))
        print('Final test acc env1(mean/std across restarts so far):')
        print(np.mean(test_accs_env1), np.std(test_accs_env1))
        print('Final test acc env2(mean/std across restarts so far):')
        print(np.mean(test_accs_env2), np.std(test_accs_env2))
        del model, optimizer
    return train_accs, test_accs_env1, test_accs_env2


iters_array = np.array([0, 1, 2, 3, 4])
penalty_weight_array = np.linspace(0, 4.0, 5)
saving_data = []
for iters in iters_array:
    for penalty_weight in penalty_weight_array:
        print('iters:', iters)
        print('penalty_weight:', penalty_weight)

        print('CV training')
        sbs_CV_store, max_CV_store = _cross_validate(iters, penalty_weight)
        print('final_sbs:', np.mean(sbs_CV_store))
        print('final_max:', np.mean(max_CV_store))

        print('start training')
        _, final_test_accs_env1, final_test_accs_env2 = _train_final(iters, penalty_weight)

        # One row per setting, recorded after *all* restarts so that the
        # mean/std summarise every restart rather than only the first one.
        saving_data.append([iters, penalty_weight,
                            np.mean(sbs_CV_store), np.mean(max_CV_store),
                            np.mean(final_test_accs_env1), np.std(final_test_accs_env1),
                            np.mean(final_test_accs_env2), np.std(final_test_accs_env2)])
# One row per (iters, penalty_weight) setting of the sweep.
sample = pd.DataFrame(
    saving_data,
    columns=[
        'iters', 'penalty_weight1',
        'substitute_CV', 'simplymax_CV',
        'test1_acc', 'test1_std',
        'test2_acc', 'test2_std',
    ],
)
print(sample)

# Persist the sweep results (relative path, i.e. the current working directory).
sample.to_csv('IRM_result_CV_restart1.csv')





