
import argparse
import numpy as np
import torch
import pandas as pd

from torchvision import datasets
from torch import nn, optim, autograd
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from mpl_toolkits.mplot3d import Axes3D

from Data_generating import data_generator,data_high_generator,high_data_loader,data_loader
from additional_functions import mean_nll,mean_nll2,mean_accuracy,prob_sum,softmax,Condi_MI,mean_nll2_forIRM


def _parse_float_list(text):
    """Convert a command-line string such as ``0.1,0.3,0.5`` into a 1-D float array.

    Commas and/or whitespace separate the values; surrounding square brackets
    are tolerated so that the printed form of an array can be pasted back in.

    Raises:
        argparse.ArgumentTypeError: if no value is given or a token is not a float.
    """
    tokens = text.strip().strip('[]').replace(',', ' ').split()
    if not tokens:
        raise argparse.ArgumentTypeError(
            'expected a comma-separated list of floats, got {!r}'.format(text))
    try:
        return np.array([float(token) for token in tokens])
    except ValueError:
        raise argparse.ArgumentTypeError(
            'expected a comma-separated list of floats, got {!r}'.format(text))


parser = argparse.ArgumentParser(description='Colored MNIST')
# Width of every hidden layer of the MLP.
parser.add_argument('--hidden_dim', type=int, default=440)
# Coefficient applied to the summed squared parameter norms in the loss.
parser.add_argument('--l2_regularizer_weight', type=float, default=0.002)
# Learning rate handed to the Adam optimizer.
parser.add_argument('--lr', type=float, default=0.0004)
# Number of times the whole experiment (data split + hyperparameter sweep) is repeated.
parser.add_argument('--n_restarts', type=int, default=5)
# Number of optimisation steps per training run.
parser.add_argument('--steps', type=int, default=501)
# Number of "high" environments indexed by the training code.
# NOTE(review): expected to equal the number of entries in --high_env.
parser.add_argument('--high_env_number', type=int, default=5)
# One value per "high" environment, forwarded to data_high_generator.
# argparse calls `type` on the raw string, so `type=np.array` would yield a
# 0-d string array; a dedicated parser is used instead.
parser.add_argument('--high_env', type=_parse_float_list,
                    default=np.array([0.1, 0.3, 0.5, 0.7, 0.9]))

# Forwarded to both data generators as their last argument.
parser.add_argument('--flip_rate', type=float, default=0.25)
flags = parser.parse_args()


def pretty_print(*values):
    """Print ``values`` as one row of left-aligned, fixed-width columns.

    Strings are printed verbatim; every other value is rendered with
    ``np.array2string`` using five fixed decimals. Each cell is padded to
    13 characters and cells are separated by three spaces.
    """
    col_width = 13
    cells = []
    for value in values:
        if isinstance(value, str):
            text = value
        else:
            text = np.array2string(value, precision=5, floatmode='fixed')
        cells.append(text.ljust(col_width))
    print("   ".join(cells))

def condi_prob(x):
    """Renormalise columns 1 and 2 of ``x`` so that each row of the result sums to one.

    Column 0 of ``x`` is ignored. The result has one row per row of ``x`` and
    two columns: ``x[:, 1] / (x[:, 1] + x[:, 2])`` and
    ``x[:, 2] / (x[:, 1] + x[:, 2])``.
    """
    first = x[:, 1]
    second = x[:, 2]
    total = first + second
    stacked = torch.stack([first / total, second / total])
    # stack puts the two ratios along dim 0; transpose to get one row per sample
    return stacked.T

class MLP(nn.Module):
    """Fully connected network with a shared trunk and two output heads.

    ``_main2`` ends in a single-output linear layer and ``_main3`` ends in a
    three-output linear layer; the first three linear layers are the same
    module objects in both, so their weights are shared. Inputs are viewed as
    ``(batch, 2*14*14)`` before being fed to either stack.
    """

    def __init__(self):
        super().__init__()
        width = flags.hidden_dim

        # Layers are created and initialised in the same order as before so
        # that the sequence of random draws is unchanged.
        lin1 = nn.Linear(2 * 14 * 14, width)
        lin2 = nn.Linear(width, width)
        lin3 = nn.Linear(width, width)
        lin4 = nn.Linear(width, 1)
        low_lin5 = nn.Linear(width, 3)

        for layer in (lin1, lin2, lin3, lin4, low_lin5):
            nn.init.xavier_uniform_(layer.weight)
            nn.init.zeros_(layer.bias)

        # External code indexes `_main2[5]` (the final linear layer), so the
        # position of each module in the Sequential must not change.
        self._main2 = nn.Sequential(lin1, nn.ReLU(True), lin2, nn.ReLU(True), lin3, lin4)
        self._main3 = nn.Sequential(lin1, nn.ReLU(True), lin2, nn.ReLU(True), lin3, low_lin5)

    def _flatten(self, input):
        """View ``input`` as a ``(batch, 2*14*14)`` matrix."""
        return input.view(input.shape[0], 2 * 14 * 14)

    def forward_all(self, input):
        """Return the raw single-output result of ``_main2``."""
        return self._main2(self._flatten(input))

    def forward_low(self, input):
        """Return ``prob_sum`` applied to the three-way ``softmax`` output of ``_main3``."""
        return prob_sum(self.forward_low2(input))

    def forward_low2(self, input):
        """Return ``softmax`` applied to the three-output result of ``_main3``."""
        return softmax(self._main3(self._flatten(input)))

    def forward_low3(self, input):
        """Return ``condi_prob`` of the three-way ``softmax`` output of ``_main3``."""
        return condi_prob(self.forward_low2(input))
        
def CV(images, labels, split, out_number):
    """Split one environment into a kept part and a held-out fold.

    Samples whose index is congruent to ``out_number`` modulo ``split`` form
    the held-out fold; all other samples are kept.

    Returns:
        A two-element list ``[kept, held_out]``; each element is a dict with
        the keys ``'images'`` and ``'labels'``.
    """
    positions = torch.arange(len(images))
    keep_mask = positions % split != out_number

    kept = {'images': images[keep_mask, :], 'labels': labels[keep_mask, :]}
    held_out = {
        'images': images[out_number::split],
        'labels': labels[out_number::split],
    }
    return [kept, held_out]


def hCV(images, labels, envs_labels, split, out_number):
    """Split one environment, including its environment labels, into kept and held-out parts.

    Samples whose index is congruent to ``out_number`` modulo ``split`` form
    the held-out fold; all other samples are kept.

    Returns:
        A two-element list ``[kept, held_out]``; each element is a dict with
        the keys ``'images'``, ``'labels'`` and ``'env_labels'``.
    """
    positions = torch.arange(len(images))
    keep_mask = positions % split != out_number

    kept = {
        'images': images[keep_mask, :],
        'labels': labels[keep_mask, :],
        'env_labels': envs_labels[keep_mask, :],
    }
    held_out = {
        'images': images[out_number::split],
        'labels': labels[out_number::split],
        'env_labels': envs_labels[out_number::split],
    }
    return [kept, held_out]





# One list of result rows per restart; each row summarises one
# (iters, penalty_weight) combination. They are stacked side by side after
# the experiment, so restart 0 additionally carries the two hyperparameters.
savingdata_Restart0 = []
savingdata_Restart1 = []
savingdata_Restart2 = []
savingdata_Restart3 = []
savingdata_Restart4 = []

# Number of cross-validation folds (also the stride used by CV / hCV).
N_FOLDS = 10


def _irm_penalty(model, logits, y):
    """Return the squared gradient norm of the IRM loss at the last layer of ``model._main2``.

    The gradient is taken with respect to the first parameter yielded by
    ``model._main2[5]`` (the final linear layer) and is built with
    ``create_graph=True`` so that the penalty itself can be back-propagated.
    """
    loss = mean_nll2_forIRM(logits, y)
    grad = autograd.grad(loss, model._main2[5].parameters(), create_graph=True)[0]
    return torch.sum(grad**2)


def _train_step(model, optimizer, fit_envs, penalty_envs, step, iters, penalty_weight):
    """Run one optimisation step and print a progress row every 100 steps.

    ``fit_envs[0]`` supplies the training loss; ``fit_envs[1]`` and
    ``fit_envs[2]`` are only evaluated for reporting. The IRM penalty is
    averaged over the first ``flags.high_env_number`` entries of
    ``penalty_envs``. The 'nll', 'acc' and 'penalty' entries of the
    environment dicts are overwritten in place.

    Returns:
        ``(train_acc, test1_acc, test2_acc)`` as tensors, all measured before
        the parameter update of this step.
    """
    for env in fit_envs:
        probs = model.forward_low2(env['images'])
        env['nll'] = mean_nll(torch.log(probs), env['labels'])
        env['acc'] = mean_accuracy(probs, env['labels'])
    for env in penalty_envs:
        logits = model.forward_all(env['images'])
        env['penalty'] = _irm_penalty(model, logits, env['labels'])

    train_nll = fit_envs[0]['nll']
    train_acc = fit_envs[0]['acc']
    test1_acc = fit_envs[1]['acc']
    test2_acc = fit_envs[2]['acc']
    train_penalty = torch.stack(
        [penalty_envs[i]['penalty'] for i in range(flags.high_env_number)]).mean()

    weight_norm = torch.tensor(0.).cuda()
    for w in model.parameters():
        weight_norm += w.norm().pow(2)

    loss = train_nll.clone()
    loss += flags.l2_regularizer_weight * weight_norm
    # The penalty only gets its full weight once `iters` steps have passed.
    penalty_weights = 10**penalty_weight if step >= iters else 1.0
    loss += penalty_weights * train_penalty
    if penalty_weights > 1.0:
        # Rescale the entire loss to keep gradients in a reasonable range
        loss /= penalty_weights

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if step % 100 == 0:
        pretty_print(
            np.int32(step),
            train_nll.detach().cpu().numpy(),
            train_acc.detach().cpu().numpy(),
            train_penalty.detach().cpu().numpy(),
            (train_penalty*penalty_weights).detach().cpu().numpy(),
            test1_acc.detach().cpu().numpy(),
            test2_acc.detach().cpu().numpy(),
        )
    return train_acc, test1_acc, test2_acc


for restart in range(flags.n_restarts):
    print('Restart:',restart)

    # Load MNIST once per restart and split it into train / validation parts.
    mnist = datasets.MNIST('~/datasets/mnist', train=True, download=True)
    mnist_train = (mnist.data[:50000], mnist.targets[:50000])
    mnist_val = (mnist.data[50000:], mnist.targets[50000:])

    # Shuffle images and targets in place with the same RNG state so that
    # they stay aligned.
    rng_state = np.random.get_state()
    np.random.shuffle(mnist_train[0].numpy())
    np.random.set_state(rng_state)
    np.random.shuffle(mnist_train[1].numpy())

    # Even-indexed samples feed the ordinary environments, odd-indexed ones
    # feed the "high" environments.
    mnist_train_image = mnist_train[0][::2]
    mnist_train_label = mnist_train[1][::2]

    high_mnist_train_image = mnist_train[0][1::2]
    high_mnist_train_label = mnist_train[1][1::2]

    # envs[0] is the training environment; envs[1] and envs[2] are reported
    # as "test1" and "test2".
    envs = [
        data_generator(mnist_train_image, mnist_train_label, 0.1, flags.flip_rate),
        data_generator(mnist_val[0][::2], mnist_val[1][::2], 0.1, flags.flip_rate),
        data_generator(mnist_val[0][1::2], mnist_val[1][1::2], 0.9, flags.flip_rate),
    ]

    high_envs = [
        data_high_generator(high_mnist_train_image[i::5], high_mnist_train_label[i::5], j, flags.flip_rate)
        for i, j in enumerate(flags.high_env)
    ]

    # Fraction of samples in each high environment whose label is non-zero.
    for env in high_envs:
        ratio = env['labels'][env['labels'].view(-1)!=0,:].shape[0]/env['labels'].shape[0]
        env['ratio'] = ratio

    iters_array = np.array([0,100,200,300])
    penalty_weight_array = np.linspace(0.0, 9.0,10)
    for iters in iters_array:
        for penalty_weight in penalty_weight_array:
            print('iters:',iters)
            print('penalty_weight:',penalty_weight)

            print('Flags:')
            for k,v in sorted(vars(flags).items()):
                print("\t{}: {}".format(k, v))

            print('Starting CV under (iters,weight_grad)=({},{})'.format(iters, penalty_weight))

            substitute_CV = []
            simplymax_CV = []
            # These lists are recreated for every hyperparameter combination,
            # so they hold exactly one entry when they are summarised below.
            final_train_accs = []
            final_test1_accs = []
            final_test2_accs = []

            for CV_step in range(N_FOLDS):
                print('CV_step:',CV_step)

                # Each split is computed once and unpacked into its kept /
                # held-out halves.
                env_splits = [CV(env['images'], env['labels'], N_FOLDS, CV_step) for env in envs]
                envs_leave = [kept for kept, _ in env_splits]
                envs_out = [held for _, held in env_splits]

                high_splits = [
                    hCV(high_envs[i]['images'], high_envs[i]['labels'], high_envs[i]['env_labels'], N_FOLDS, CV_step)
                    for i in range(flags.high_env_number)
                ]
                envs_high_leave = [kept for kept, _ in high_splits]
                envs_high_out = [held for _, held in high_splits]

                mlp_CV = MLP().cuda()
                optimizer = optim.Adam(mlp_CV.parameters(), lr=flags.lr)

                pretty_print('step', 'train nll', 'train acc', 'train penalty','w*train penalty', 'test1 acc', 'test2 acc',)

                for step in range(flags.steps):
                    _train_step(mlp_CV, optimizer, envs_leave, envs_high_leave, step, iters, penalty_weight)

                # Held-out evaluation of the fully trained fold model. It is
                # done once after the last step (previously it ran on every
                # step and was only recorded at the hard-coded step 500).
                for env in envs_out:
                    probs = mlp_CV.forward_low2(env['images'])
                    env['nll'] = mean_nll(torch.log(probs), env['labels'])
                    nonzero = env['labels'].view(-1) != 0
                    cond_probs = mlp_CV.forward_low3(env['images'][nonzero, :])
                    env['dist'] = mean_nll(torch.log(cond_probs), env['labels'][nonzero, :] - 1.)

                for env in envs_high_out:
                    probs = mlp_CV.forward_low(env['images'])
                    env['nll'] = mean_nll(torch.log(probs), env['labels'])
                    nonzero = env['labels'].view(-1) != 0
                    # forward_low3 yields the two renormalised probabilities
                    # that the original expression rebuilt by hand.
                    cond_probs = mlp_CV.forward_low3(env['images'][nonzero, :])
                    # 'ce' is stored on the dict but not read anywhere in this script.
                    env['ce'] = (-(1/2)*torch.log(cond_probs[:, 0])
                                 + -(1/2)*torch.log(cond_probs[:, 1])).mean()

                out_nll = envs_out[0]['nll'].view(-1)
                out_dist = envs_out[0]['dist']
                high_nlls = [envs_high_out[i]['nll'].view(-1) for i in range(flags.high_env_number)]
                adjusted_nlls = [
                    (envs_high_out[i]['nll'] + out_dist*high_envs[i]['ratio']).view(-1)
                    for i in range(flags.high_env_number)
                ]

                print(torch.cat([out_nll, torch.cat(high_nlls), out_dist.view(-1)]).view(-1))
                substitute_score = torch.max(torch.cat([out_nll, torch.cat(adjusted_nlls)]).view(-1))
                print('CV:', substitute_score)
                substitute_CV.append(substitute_score.detach().cpu().numpy())
                simplymax_score = torch.max(torch.cat([out_nll, torch.cat(high_nlls)]).view(-1))
                simplymax_CV.append(simplymax_score.detach().cpu().numpy())

            print('Starting training (iters,weight_grad)=({},{})'.format(iters, penalty_weight))

            # Final model for this hyperparameter combination, trained on the
            # full (unsplit) environments.
            mlp = MLP().cuda()
            optimizer = optim.Adam(mlp.parameters(), lr=flags.lr)

            pretty_print('step', 'train nll', 'train acc', 'train penalty','w*train penalty', 'test1acc','test2 acc')

            for step in range(flags.steps):
                train_acc, test1_acc, test2_acc = _train_step(
                    mlp, optimizer, envs, high_envs, step, iters, penalty_weight)

            final_train_accs.append(train_acc.detach().cpu().numpy())
            final_test1_accs.append(test1_acc.detach().cpu().numpy())
            final_test2_accs.append(test2_acc.detach().cpu().numpy())
            print('Final train acc (mean/std across restarts so far):')
            print(np.mean(final_train_accs), np.std(final_train_accs))
            print('Final test1 acc (mean/std across restarts so far):')
            print(np.mean(final_test1_accs), np.std(final_test1_accs))
            print('Final test2 acc (mean/std across restarts so far):')
            print(np.mean(final_test2_accs), np.std(final_test2_accs))

            summary = [
                np.mean(substitute_CV),
                np.mean(simplymax_CV),
                np.mean(final_test1_accs),
                np.mean(final_test2_accs),
            ]
            later_restarts = {
                1: savingdata_Restart1,
                2: savingdata_Restart2,
                3: savingdata_Restart3,
                4: savingdata_Restart4,
            }
            if restart == 0:
                # Only the first restart records the hyperparameters themselves.
                savingdata_Restart0.append(np.array([iters, penalty_weight] + summary))
            elif restart in later_restarts:
                # Restarts beyond the fifth are not recorded (unchanged behaviour).
                later_restarts[restart].append(np.array(summary))


# Turn the per-restart row lists into 2-D arrays (one row per hyperparameter
# combination).
savingdata_Restart0 = np.array(savingdata_Restart0)
savingdata_Restart1 = np.array(savingdata_Restart1)
savingdata_Restart2 = np.array(savingdata_Restart2)
savingdata_Restart3 = np.array(savingdata_Restart3)
savingdata_Restart4 = np.array(savingdata_Restart4)

# Place the restarts side by side. Restarts that never ran (n_restarts < 5)
# leave an empty array behind, which cannot be concatenated along axis 1, so
# they are skipped here.
restart_tables = [
    table
    for table in (
        savingdata_Restart0,
        savingdata_Restart1,
        savingdata_Restart2,
        savingdata_Restart3,
        savingdata_Restart4,
    )
    if table.size
]
x = np.concatenate(restart_tables, axis=1)

column_names = ['iters','penalty_weight','sbstitute_CV','simplymax_CV','test_acc1','test_acc2','sbstitute_CV_1','simplymax_CV_1','test_acc1_res1','test_acc2_res1','sbstitute_CV_2','simplymax_CV_2','test_acc1_res2','test_acc2_res2','sbstitute_CV_3','simplymax_CV_3','test_acc1_res3','test_acc2_res3','sbstitute_CV_4','simplymax_CV_4','test_acc1_res4','test_acc2_res4']
# Restart 0 contributes six columns and every later restart four, so the
# leading names always match the columns that are actually present.
sample = pd.DataFrame(x, columns=column_names[:x.shape[1]])
print(sample)
sample.to_csv('CV_arrange_IRM_result_flip_rate={}_high={}.csv'.format(flags.flip_rate, flags.high_env))