import torch
from tqdm import tqdm, trange
import matplotlib.pyplot as plt
import torch.nn.functional as F

def visualize2D(trainer, writer, global_step, net, dataset, test_loader, prob_list, group_names, mean_regularizer,
                overall_reg_group):
    """Log 2-D training and test diagnostic figures to ``writer``.

    Two figures are written:

    * ``'Training_Figures'``: four panels showing (0) ``net`` evaluated on a
      square grid with the dataset points scattered on top, (1) one curve per
      entry of ``prob_list``, (2) the ``mean_regularizer`` curve and (3) one
      curve per column of ``overall_reg_group``.
    * ``'Test_Figures'``: a scatter of one test batch coloured by the
      network output.

    Args:
        trainer: Object exposing ``opt.grid_size``, ``opt.grid_range`` and
            ``device``.
        writer: Object with an ``add_figure(tag, figure, global_step)``
            method (presumably a TensorBoard ``SummaryWriter`` -- confirm
            against the caller). When ``None`` the function returns
            immediately.
        global_step: Step value forwarded to ``writer.add_figure``.
        net: Callable taking an ``(N, 2)`` tensor; its squeezed output is
            reshaped to ``(grid_size, grid_size)``, so it is expected to
            produce one value per input row.
        dataset: Object whose ``tensors`` attribute unpacks into
            ``(points, labels)``.
        test_loader: Iterable yielding ``(points, labels)`` batches.
        prob_list: Iterable of sequences, each plotted against its index.
        group_names: Legend labels for the per-group panels.
        mean_regularizer: Sequence plotted against its index.
        overall_reg_group: 2-D tensor; column ``i`` is plotted against the
            row index.
    """
    if writer is None:
        return
    grid_size = trainer.opt.grid_size
    grid_range = trainer.opt.grid_range
    axis_ticks = torch.linspace(-grid_range, grid_range, grid_size)
    grid1, grid2 = torch.meshgrid(axis_ticks, axis_ticks)

    input_gridxy = torch.cat([grid1.reshape(-1, 1), grid2.reshape(-1, 1)], dim=1)

    # The grid evaluation is only plotted, so no autograd graph is needed.
    with torch.no_grad():
        out = net(input_gridxy.to(trainer.device)).squeeze()
    out_mesh = out.reshape(grid_size, grid_size).detach().to('cpu')

    fig, ax = plt.subplots(1, 4, figsize=(20, 6))

    tot_rotated_data, tot_y = dataset.tensors
    points = tot_rotated_data.to('cpu').numpy()
    labels = tot_y.to('cpu').numpy()

    ax[0].contourf(grid1, grid2, out_mesh)
    ax[0].scatter(points[:, 0], points[:, 1], c=labels, vmin=0, vmax=1)
    ax[0].set_title("Classifier Boundaries XY")

    for prob in prob_list:
        ax[1].plot(range(len(prob)), prob)
    ax[1].legend(group_names)
    ax[1].set_title("Categorical Probabilities")

    ax[2].plot(range(len(mean_regularizer)), mean_regularizer)
    ax[2].set_title("Regularizer Value")

    for i in range(overall_reg_group.shape[1]):
        ax[3].plot(range(overall_reg_group.shape[0]), overall_reg_group[:, i].to('cpu'))
    ax[3].set_title("Regularizer per group")
    ax[3].legend(group_names)

    writer.add_figure('Training_Figures', fig, global_step)

    with torch.no_grad():
        # NOTE(review): every pass re-creates the loader iterator and takes
        # its first batch; only the figure from the last pass is logged below.
        for _ in trange(3):
            # Builtin next() works on every iterator; the Python-2 style
            # ``.next()`` method is not available on current DataLoader
            # iterators.
            rotated_data, _ = next(iter(test_loader))

            test_points = rotated_data.to(trainer.device)
            out = net(test_points).squeeze()

            test_fig, test_ax = plt.subplots(1, 1)
            test_ax.scatter(test_points[:, 0].to('cpu').numpy(), test_points[:, 1].to('cpu').numpy(),
                            c=out.to('cpu').numpy(), vmin=0, vmax=1)
            test_ax.set_title("Test Classifier XY")

    writer.add_figure('Test_Figures', test_fig, global_step=global_step)
    plt.close('all')

def visualizeMnist(trainer, global_step, net, dataset, test_loader, prob_list, group_names, mean_regularizer,
                overall_reg_group, sampler=None, action=None, test_augm_samples=5):
    """Log training curves and report plain / augmentation-averaged test accuracy.

    When ``trainer.writer`` is not ``None`` a three-panel figure (curves from
    ``prob_list``, ``mean_regularizer`` and the columns of
    ``overall_reg_group``) is written under ``'Train_Figs/Training_Figures'``.
    The network is then switched to eval mode, scored on ``test_loader`` and
    switched back to train mode. Accuracies are printed and, when a writer is
    available, logged as scalars.

    Args:
        trainer: Object exposing ``writer``, ``device``, ``opt.test_size``
            and ``opt.dataset``.
        global_step: Step value forwarded to the writer.
        net: Module providing ``eval()``, ``train()`` and a forward call that
            returns per-class scores along dim 1.
        dataset: Unused; kept for interface compatibility.
        test_loader: Iterable yielding ``(data, target)`` batches. Batches
            whose size differs from ``trainer.opt.test_size`` are skipped.
        prob_list: Iterable of sequences, each plotted against its index.
        group_names: Legend labels for the per-group panels.
        mean_regularizer: Sequence plotted against its index.
        overall_reg_group: 2-D tensor; column ``i`` is plotted against the
            row index.
        sampler: Callable taking a sample count and returning a 3-tuple whose
            first element is handed to ``action``. If ``sampler`` or
            ``action`` is ``None`` the augmented evaluation is skipped.
        action: Callable ``action(samples, data)`` returning the augmented
            batch. Its first two dims are flattened before the forward pass
            and the result is reshaped to ``(test_augm_samples, batch, -1)``,
            i.e. the leading axis is assumed to index augmentation samples
            -- TODO confirm against the ``action`` implementation.
        test_augm_samples: Number of augmentations averaged per test example.
    """
    writer = trainer.writer

    if writer is not None:
        fig, ax = plt.subplots(1, 3, figsize=(20, 6))

        for prob in prob_list:
            ax[0].plot(range(len(prob)), prob)
        ax[0].legend(group_names)
        ax[0].set_title("Categorical Probabilities")

        ax[1].plot(range(len(mean_regularizer)), mean_regularizer)
        ax[1].set_title("Regularizer Value")

        for i in range(overall_reg_group.shape[1]):
            ax[2].plot(range(overall_reg_group.shape[0]), overall_reg_group[:, i].to('cpu'))
        ax[2].set_title("Regularizer per group")
        ax[2].legend(group_names)

        writer.add_figure('Train_Figs/Training_Figures', fig, global_step)

    # Both callables default to None; without them only plain accuracy is computed.
    use_augmentation = sampler is not None and action is not None

    correct = 0
    correct_augm = 0
    number_samples = 0
    net.eval()
    try:
        with torch.no_grad():
            for data, target in tqdm(test_loader):
                b_size = data.shape[0]
                if b_size != trainer.opt.test_size:
                    continue
                data, target = data.to(trainer.device), target.to(trainer.device)

                output = F.log_softmax(net(data), dim=1)
                pred = output.argmax(dim=1, keepdim=True)
                correct += pred.eq(target.view_as(pred)).sum().item()
                number_samples += pred.shape[0]

                if not use_augmentation:
                    continue

                data_sample, _, _ = sampler(test_augm_samples * b_size)
                data_sample = data_sample.detach()
                data_augm = action(data_sample, data).flatten(start_dim=0, end_dim=1)
                # exp(log_softmax) gives per-class probabilities, which are
                # averaged over the augmentation axis before the argmax.
                out_augm_prob = torch.exp(F.log_softmax(net(data_augm), dim=1))
                # No squeeze here: the mean is already (b_size, classes), and
                # squeezing would drop the batch axis when b_size == 1.
                mean_augm_prob = out_augm_prob.reshape(test_augm_samples, b_size, -1).mean(dim=0)
                pred_augm = mean_augm_prob.argmax(dim=1, keepdim=True)
                correct_augm += pred_augm.eq(target.view_as(pred_augm)).sum().item()
    finally:
        # Switch back to training mode even if evaluation raised.
        net.train()

    if number_samples == 0:
        # Every batch was skipped (or the loader was empty); avoid dividing by zero.
        print("No test batch matched trainer.opt.test_size; accuracy not reported.")
        plt.close('all')
        return

    test_accuracy = correct / number_samples
    augm_accuracy = correct_augm / number_samples

    if writer is not None:
        writer.add_scalar('Accuracy/'+str(trainer.opt.dataset)+'_Test_Accuracy', test_accuracy, global_step)
        if use_augmentation:
            writer.add_scalar('Accuracy/'+str(trainer.opt.dataset)+'_TestAugm_Accuracy', augm_accuracy, global_step)

    print("Test Accuracy="+str(test_accuracy))
    if use_augmentation:
        print("TestAugm_Accuracy="+str(augm_accuracy))
    plt.close('all')

def visualizeWilds(trainer, global_step, net, dataset, test_loader, prob_list, group_names, mean_regularizer,
                overall_reg_group, sampler=None, action=None, test_augm_samples=5,group_list=None):
    """Log training curves and report test accuracy, optionally restricted to groups.

    When ``trainer.writer`` is not ``None`` a three-panel figure (curves from
    ``prob_list``, ``mean_regularizer`` and the columns of
    ``overall_reg_group``) is written under ``'Train_Figs/Training_Figures'``.
    The network is then switched to eval mode, scored on ``test_loader`` and
    switched back to train mode. Accuracies are printed and, when a writer is
    available, logged as scalars.

    Args:
        trainer: Object exposing ``writer``, ``device``, ``opt.test_size``
            and ``opt.dataset``.
        global_step: Step value forwarded to the writer.
        net: Module providing ``eval()``, ``train()`` and a forward call that
            returns per-class scores along dim 1.
        dataset: Unused; kept for interface compatibility.
        test_loader: Iterable yielding ``(data, target, metadata)`` batches.
            Batches whose size differs from ``trainer.opt.test_size`` are
            skipped (checked before any group filtering).
        prob_list: Iterable of sequences, each plotted against its index.
        group_names: Legend labels for the per-group panels.
        mean_regularizer: Sequence plotted against its index.
        overall_reg_group: 2-D tensor; column ``i`` is plotted against the
            row index.
        sampler: Callable taking a sample count and returning a 3-tuple whose
            first element is handed to ``action``. If ``sampler`` or
            ``action`` is ``None`` the augmented evaluation is skipped.
        action: Callable ``action(samples, data)`` returning the augmented
            batch. Its first two dims are flattened before the forward pass
            and the result is reshaped to ``(test_augm_samples, batch, -1)``,
            i.e. the leading axis is assumed to index augmentation samples
            -- TODO confirm against the ``action`` implementation.
        test_augm_samples: Number of augmentations averaged per test example.
        group_list: Optional iterable of values; when given, only examples
            whose ``metadata[:, 0]`` equals one of them are scored.
    """
    writer = trainer.writer

    if writer is not None:
        fig, ax = plt.subplots(1, 3, figsize=(20, 6))

        for prob in prob_list:
            ax[0].plot(range(len(prob)), prob)
        ax[0].legend(group_names)
        ax[0].set_title("Categorical Probabilities")

        ax[1].plot(range(len(mean_regularizer)), mean_regularizer)
        ax[1].set_title("Regularizer Value")

        for i in range(overall_reg_group.shape[1]):
            ax[2].plot(range(overall_reg_group.shape[0]), overall_reg_group[:, i].to('cpu'))
        ax[2].set_title("Regularizer per group")
        ax[2].legend(group_names)

        writer.add_figure('Train_Figs/Training_Figures', fig, global_step)

    # Both callables default to None; without them only plain accuracy is computed.
    use_augmentation = sampler is not None and action is not None

    correct = 0
    correct_augm = 0
    number_samples = 0
    net.eval()
    try:
        with torch.no_grad():
            for data, target, metadata in tqdm(test_loader):
                b_size = data.shape[0]
                if b_size != trainer.opt.test_size:
                    continue

                if group_list is not None:
                    # OR the per-group matches into one mask; unlike sum(),
                    # this stays a tensor when group_list is empty.
                    group_column = metadata[:, 0]
                    keep = torch.zeros_like(group_column, dtype=torch.bool)
                    for group in group_list:
                        keep |= group_column == group
                    indexes = torch.nonzero(keep).flatten()
                    if indexes.shape[0] == 0:
                        continue

                    data = data[indexes]
                    target = target[indexes]
                    b_size = data.shape[0]

                data, target = data.to(trainer.device), target.to(trainer.device)

                output = F.log_softmax(net(data), dim=1)
                pred = output.argmax(dim=1, keepdim=True)
                correct += pred.eq(target.view_as(pred)).sum().item()
                number_samples += pred.shape[0]

                if not use_augmentation:
                    continue

                data_sample, _, _ = sampler(test_augm_samples * b_size)
                data_sample = data_sample.detach()
                data_augm = action(data_sample, data).flatten(start_dim=0, end_dim=1)
                # exp(log_softmax) gives per-class probabilities, which are
                # averaged over the augmentation axis before the argmax.
                out_augm_prob = torch.exp(F.log_softmax(net(data_augm), dim=1))
                # The mean is already (b_size, classes), including when
                # b_size == 1, so no squeeze/unsqueeze round-trip is needed.
                mean_augm_prob = out_augm_prob.reshape(test_augm_samples, b_size, -1).mean(dim=0)
                pred_augm = mean_augm_prob.argmax(dim=1, keepdim=True)
                correct_augm += pred_augm.eq(target.view_as(pred_augm)).sum().item()
    finally:
        # Switch back to training mode even if evaluation raised.
        net.train()

    if number_samples == 0:
        # Nothing was scored (all batches skipped or filtered out); avoid dividing by zero.
        print("No test samples were evaluated; accuracy not reported.")
        plt.close('all')
        return

    test_accuracy = correct / number_samples
    augm_accuracy = correct_augm / number_samples

    if writer is not None:
        writer.add_scalar('Accuracy/'+str(trainer.opt.dataset)+'_Test_Accuracy', test_accuracy, global_step)
        if use_augmentation:
            writer.add_scalar('Accuracy/'+str(trainer.opt.dataset)+'_TestAugm_Accuracy', augm_accuracy, global_step)

    print("Test Accuracy="+str(test_accuracy))
    if use_augmentation:
        print("TestAugm_Accuracy="+str(augm_accuracy))
    plt.close('all')



def visualize3D(trainer, writer, global_step, net, dataset, test_loader, prob_list, group_names, mean_regularizer,
                overall_reg_group):
    """Log 3-D training and test diagnostic figures to ``writer``.

    Two figures are written:

    * ``'Training_Figures'``: six panels. Panels 0-2 show ``net`` evaluated on
      a square grid in the XY, YZ and ZX planes (the remaining coordinate set
      to zero) with the dataset points scattered on top; panel 3 has one
      curve per entry of ``prob_list``; panel 4 the ``mean_regularizer``
      curve; panel 5 one curve per column of ``overall_reg_group``.
    * ``'Test_Figures'``: XY / YZ / ZX scatters of one test batch coloured by
      the network output.

    Args:
        trainer: Object exposing ``opt.grid_size``, ``opt.grid_range`` and
            ``device``.
        writer: Object with an ``add_figure(tag, figure, global_step)``
            method (presumably a TensorBoard ``SummaryWriter`` -- confirm
            against the caller). When ``None`` the function returns
            immediately.
        global_step: Step value forwarded to ``writer.add_figure``.
        net: Callable taking an ``(N, 3)`` tensor; its squeezed output is
            reshaped to ``(grid_size, grid_size)``, so it is expected to
            produce one value per input row.
        dataset: Object whose ``tensors`` attribute unpacks into
            ``(points, labels)``.
        test_loader: Iterable yielding ``(points, labels)`` batches.
        prob_list: Iterable of sequences, each plotted against its index.
        group_names: Legend labels for the per-group panels.
        mean_regularizer: Sequence plotted against its index.
        overall_reg_group: 2-D tensor; column ``i`` is plotted against the
            row index.
    """
    if writer is None:
        return
    grid_size = trainer.opt.grid_size
    grid_range = trainer.opt.grid_range
    axis_ticks = torch.linspace(-grid_range, grid_range, grid_size)
    grid1, grid2 = torch.meshgrid(axis_ticks, axis_ticks)

    flat1 = grid1.reshape(-1, 1)
    flat2 = grid2.reshape(-1, 1)
    zeros = torch.zeros_like(flat1)
    # One axis-aligned slice through the origin per plane.
    input_gridxy = torch.cat([flat1, flat2, zeros], dim=1)
    input_gridyz = torch.cat([zeros, flat1, flat2], dim=1)
    input_gridzx = torch.cat([flat2, zeros, flat1], dim=1)

    def plane_values(plane_points):
        # Evaluate the network on one slice and bring it back as a CPU grid.
        values = net(plane_points.to(trainer.device)).squeeze()
        return values.reshape(grid_size, grid_size).detach().to('cpu')

    # The grid evaluations are only plotted, so no autograd graph is needed.
    with torch.no_grad():
        out_meshxy = plane_values(input_gridxy)
        out_meshyz = plane_values(input_gridyz)
        out_meshzx = plane_values(input_gridzx)

    fig, ax = plt.subplots(1, 6, figsize=(30, 6))

    tot_rotated_data, tot_y = dataset.tensors
    points = tot_rotated_data.to('cpu').numpy()
    labels = tot_y.to('cpu').numpy()

    ax[0].contourf(grid1, grid2, out_meshxy)
    ax[0].scatter(points[:, 0], points[:, 1], c=labels, vmin=0, vmax=1)
    ax[0].set_title("Classifier Boundaries XY")

    ax[1].contourf(grid1, grid2, out_meshyz)
    ax[1].scatter(points[:, 1], points[:, 2], c=labels, vmin=0, vmax=1)
    ax[1].set_title("Classifier Boundaries YZ")

    # In the ZX slice grid2 carries x and grid1 carries z, so the axes are
    # passed in swapped order to keep x horizontal like the scatter.
    ax[2].contourf(grid2, grid1, out_meshzx)
    ax[2].scatter(points[:, 0], points[:, 2], c=labels, vmin=0, vmax=1)
    ax[2].set_title("Classifier Boundaries ZX")

    for prob in prob_list:
        ax[3].plot(range(len(prob)), prob)
    ax[3].legend(group_names)
    ax[3].set_title("Categorical Probabilities")

    ax[4].plot(range(len(mean_regularizer)), mean_regularizer)
    ax[4].set_title("Regularizer Value")

    for i in range(overall_reg_group.shape[1]):
        ax[5].plot(range(overall_reg_group.shape[0]), overall_reg_group[:, i].to('cpu'))
    ax[5].set_title("Regularizer per group")
    ax[5].legend(group_names)

    writer.add_figure('Training_Figures', fig, global_step)

    with torch.no_grad():
        # NOTE(review): every pass re-creates the loader iterator and takes
        # its first batch; only the figure from the last pass is logged below.
        for _ in trange(3):
            # Builtin next() works on every iterator; the Python-2 style
            # ``.next()`` method is not available on current DataLoader
            # iterators.
            rotated_data, _ = next(iter(test_loader))

            test_points = rotated_data.to(trainer.device)
            colours = net(test_points).squeeze().to('cpu').numpy()
            coords = test_points.to('cpu').numpy()

            test_fig, test_ax = plt.subplots(1, 3, figsize=(20, 6))

            test_ax[0].scatter(coords[:, 0], coords[:, 1], c=colours, vmin=0, vmax=1)
            test_ax[0].set_title("Test Classifier XY")

            test_ax[1].scatter(coords[:, 1], coords[:, 2], c=colours, vmin=0, vmax=1)
            test_ax[1].set_title("Test Classifier YZ")

            test_ax[2].scatter(coords[:, 0], coords[:, 2], c=colours, vmin=0, vmax=1)
            test_ax[2].set_title("Test Classifier ZX")

    writer.add_figure('Test_Figures', test_fig, global_step=global_step)
    plt.close('all')
