
from utils import DatasetWrapper,get_single_grad_vector,free_grad_memory

from torch.utils.data import DataLoader
import torch.optim as optim
from tqdm import tqdm
import torch
import torch.nn as nn
import random
import numpy as np

class Poison(nn.Module):
    """Learnable additive perturbation for a batch of multi-channel 1D signals.

    One perturbation of shape ``(n_channels, signal_size)`` is kept per poisoned
    sample, initialised uniformly in ``[-inf_norm_limit, inf_norm_limit]``.
    When stamped onto an input, entries where the input is exactly zero are
    forced back to zero, so the perturbation never fills zero positions.
    """

    def __init__(self,n_channels,signal_size,poison_num_samples,inf_norm_limit) -> None:
        super().__init__()
        self.signal_size = signal_size
        self.n_channels = n_channels
        self.inf_norm_limit = inf_norm_limit
        self.poison_num_samples = poison_num_samples
        delta_shape = (poison_num_samples, n_channels, signal_size)
        self.delta = nn.Parameter(torch.ones(delta_shape))
        nn.init.uniform_(self.delta, -inf_norm_limit, inf_norm_limit)

    def forward(self,x,index_mask=None):
        """Return ``x`` plus the perturbation, with zero entries of ``x`` kept at zero.

        Args:
            x: batch whose shape matches ``delta`` (or ``delta[index_mask]``).
            index_mask: optional index/boolean mask selecting which rows of
                ``delta`` to apply; all rows are used when ``None``.
        """
        selected = self.delta if index_mask is None else self.delta[index_mask]
        keep = x != 0.
        return (x + selected) * keep

    def project_parameters(self):
        """Clamp ``delta`` in place into the L-infinity ball of radius ``inf_norm_limit``."""
        bound = self.inf_norm_limit
        with torch.no_grad():
            self.delta.clamp_(-bound, bound)

class ImagePoison(nn.Module):
    """Learnable additive perturbation for a batch of multi-channel inputs of shape ``image_size``.

    One perturbation of shape ``(n_channels, *image_size)`` is kept per poisoned
    sample, initialised uniformly in ``[-inf_norm_limit, inf_norm_limit]``.
    Unlike ``Poison``, zero entries of the input are not masked out.
    """

    def __init__(self,n_channels,image_size,poison_num_samples,inf_norm_limit) -> None:
        super().__init__()
        self.image_size = image_size
        self.n_channels = n_channels
        self.inf_norm_limit = inf_norm_limit
        self.poison_num_samples = poison_num_samples
        delta_shape = (poison_num_samples, n_channels, *image_size)
        self.delta = nn.Parameter(torch.ones(delta_shape))
        nn.init.uniform_(self.delta, -inf_norm_limit, inf_norm_limit)

    def forward(self,x,index_mask=None):
        """Return ``x`` plus the perturbation.

        Args:
            x: batch whose shape matches ``delta`` (or ``delta[index_mask]``).
            index_mask: optional index/boolean mask selecting which rows of
                ``delta`` to apply; all rows are used when ``None``.
        """
        selected = self.delta if index_mask is None else self.delta[index_mask]
        return x + selected

    def project_parameters(self):
        """Clamp ``delta`` in place into the L-infinity ball of radius ``inf_norm_limit``."""
        bound = self.inf_norm_limit
        with torch.no_grad():
            self.delta.clamp_(-bound, bound)


class WitchesBrew(object):
    """Gradient-matching poison optimiser against a fixed model.

    Perturbations (bounded in L-infinity norm by ``inf_norm_limit``) are
    optimised on a set of training samples so that the model gradient on the
    perturbed samples aligns with the model gradient of the cross-entropy loss
    of one test ("victim") sample labelled as ``target_class``. The model's
    weights are never updated here; only the perturbation is optimised.

    ``attack_config`` must provide the keys ``poison_num_samples``,
    ``poison_lrs`` (candidate learning rates, one drawn per run), ``n_epochs``,
    ``inf_norm_limit``, ``target_class`` and ``labels`` (label values iterated
    by the separated matching modes).

    ``gradient_matching_type`` selects the objective:
      * ``None``: one cosine-similarity objective over the whole poison batch;
      * ``"seperated"``: one cosine-similarity objective and optimiser step per
        label present in the poison batch;
      * ``"seperated_projection"``: maximise the projection, onto the victim
        gradient, of the gradient of the summed per-label mean losses.
    """

    def __init__(self,train_dataset,test_dataset,neural_model,attack_config,input_type=None,gradient_matching_type= None) -> None:
        super().__init__()
        self.train_dataset = train_dataset
        self.test_dataset = test_dataset
        self.attack_config = attack_config
        self.neural_model = neural_model
        self.gradient_matching_type = gradient_matching_type

        self.poison_num_samples = attack_config['poison_num_samples']
        self.poison_lrs = attack_config['poison_lrs']
        self.n_epochs = attack_config['n_epochs']
        self.inf_norm_limit = attack_config['inf_norm_limit']
        self.target_class = attack_config['target_class']
        self.labels = attack_config['labels']

        # '1D' -> Poison (signals), '2D' -> ImagePoison.
        self.input_type = input_type if input_type else "1D"

    def initialize_evil_experiment(self,victim_label,victim_patient,poison_class = 0):
        """Select the training samples to poison and the victim test sample.

        Delegates to the datasets' ``get_sample_indexes`` / ``get_victim_sample``
        (presumably returning indices of ``poison_num_samples`` training samples
        of label ``poison_class`` and the index of a matching test sample;
        confirm against the dataset classes).

        Returns:
            ``(victim_index, samples_to_poison_idx)``.
        """
        samples_to_poison_idx = self.train_dataset.get_sample_indexes(self.poison_num_samples, selected_label=poison_class)
        victim_index = self.test_dataset.get_victim_sample(victim_label=victim_label, victim_patient=victim_patient)
        return victim_index, samples_to_poison_idx

    def generate_poisons(self,victim_index,samples_to_poison_idx,device,restarts):
        """Run ``restarts`` independent poison optimisations and keep the best one.

        Returns:
            The ``(poisoned_x, metadata, loss)`` tuple with the lowest final loss.

        Raises:
            ValueError: if ``restarts`` is smaller than 1.
        """
        if restarts < 1:
            raise ValueError(f"restarts must be >= 1, got {restarts}")
        candidates = [self.__generate_poison(victim_index, samples_to_poison_idx, device)
                      for _ in range(restarts)]
        # min() keeps the first of equal losses, like a stable sort followed by [0].
        return min(candidates, key=lambda result: result[2])

    def __generate_poison(self,victim_index,samples_to_poison_idx,device):
        """Dispatch a single optimisation run on ``gradient_matching_type``."""
        if self.gradient_matching_type is None:
            return self.__generate_combined_poison(victim_index, samples_to_poison_idx, device)
        if self.gradient_matching_type in ("seperated", "seperated_projection"):
            return self.__generate_seperated_poison(victim_index, samples_to_poison_idx, device)
        # Fail loudly: returning None would only surface later as an unrelated
        # error when generate_poisons compares the restart results.
        raise ValueError(f"Invalid gradient_matching_type {self.gradient_matching_type!r}")

    # ------------------------------------------------------------------ helpers

    def _victim_gradient(self, victim_index, device):
        """Return ``(victim_X, victim_grad)`` for the victim test sample.

        ``victim_grad`` is the detached model gradient of the cross-entropy loss
        of the victim sample labelled as ``target_class``. Side effects: the
        model is put in eval mode and moved to ``device``.
        """
        victim_X = self.test_dataset[victim_index][0]
        self.neural_model.eval()
        self.neural_model.to(device)
        self.neural_model.zero_grad()

        target_class = torch.tensor([self.target_class], dtype=torch.long, device=device)
        # Cast like the poison batches so inputs stored as float64 match the model.
        victim_X_tensor = victim_X.unsqueeze(0).to(device=device, dtype=torch.float)
        loss = nn.CrossEntropyLoss()(self.neural_model(victim_X_tensor), target_class)
        loss.backward()
        # get_single_grad_vector (utils) presumably flattens every parameter's .grad.
        return victim_X, get_single_grad_vector(self.neural_model).detach()

    def _build_poison_model(self, victim_X, device):
        """Create the perturbation module for ``input_type``, shaped like ``victim_X``.

        Raises:
            ValueError: if ``input_type`` is neither '1D' nor '2D'.
        """
        if self.input_type == '1D':
            poison_model = Poison(n_channels=victim_X.shape[0], signal_size=victim_X.shape[1],
                                  poison_num_samples=self.poison_num_samples,
                                  inf_norm_limit=self.inf_norm_limit)
        elif self.input_type == '2D':
            poison_model = ImagePoison(n_channels=victim_X.shape[0],
                                       image_size=victim_X.shape[1:],
                                       poison_num_samples=self.poison_num_samples,
                                       inf_norm_limit=self.inf_norm_limit)
        else:
            raise ValueError(f"Unsupported input_type {self.input_type!r}; expected '1D' or '2D'")
        return poison_model.to(device)

    def _build_poison_dataloader(self, samples_to_poison_idx):
        """Wrap the selected training samples (plus metadata, if any) in a DataLoader.

        Raises:
            ValueError: if ``samples_to_poison_idx`` is empty.
        """
        samples = [self.train_dataset[idx] for idx in samples_to_poison_idx]
        if not samples:
            raise ValueError("samples_to_poison_idx is empty; nothing to poison")
        xs = torch.stack([sample[0] for sample in samples], dim=0)
        ys = torch.stack([sample[1] for sample in samples], dim=0)
        if len(samples[0]) == 3:
            dataset = DatasetWrapper(xs, ys, [sample[2] for sample in samples])
        else:
            # No metadata case.
            dataset = DatasetWrapper(xs, ys)
        # The poison models add their whole ``delta`` to each batch, so the batch
        # size is presumably chosen to yield a single full batch per epoch.
        return DataLoader(dataset, batch_size=self.poison_num_samples + 1, shuffle=False)

    @staticmethod
    def _make_optimizer(poison_model, poison_lr):
        """Return ``(optimizer, scheduler)``; the scheduler halves the lr every 10 ``step()`` calls."""
        optimizer = optim.Adam(poison_model.parameters(), lr=poison_lr, weight_decay=0.001)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
        return optimizer, scheduler

    @staticmethod
    def _unpack_batch(batch, device):
        """Split a batch into float inputs and long labels on ``device``, plus metadata (or None)."""
        if len(batch) == 3:
            x, y, m = batch
        else:
            x, y = batch
            m = None
        return x.to(device=device, dtype=torch.float), y.to(device=device, dtype=torch.long), m

    @staticmethod
    def _cosine_similarity(victim_grad, grad_vals):
        """Cosine similarity between two flattened gradient vectors."""
        return torch.sum(victim_grad * grad_vals) / (torch.norm(victim_grad, 2) * torch.norm(grad_vals, 2))

    @staticmethod
    def _sum_of_label_means(per_sample_loss, y, labels):
        """Sum, over ``labels`` present in ``y``, the mean of ``per_sample_loss`` for that label.

        Labels absent from ``y`` are skipped: the mean of an empty selection is
        NaN and would turn the whole objective (and every gradient) into NaN.

        Raises:
            ValueError: if none of ``labels`` occurs in ``y``.
        """
        total = None
        for label in labels:
            mask = y == label
            if not mask.any():
                continue
            label_mean = per_sample_loss[mask].mean()
            total = label_mean if total is None else total + label_mean
        if total is None:
            raise ValueError(f"None of the labels {labels} occur in the poison batch")
        return total

    def _stamp_final_poisons(self, poison_model, poison_dataloader, device):
        """Apply the final (projected) perturbation; return ``(stamped_x, metadata)`` of the last batch."""
        stamped_x, m = None, None
        with torch.no_grad():
            for batch in poison_dataloader:
                x, _, m = self._unpack_batch(batch, device)
                stamped_x = poison_model(x)
        return stamped_x, m

    def _seperated_step(self, poison_model, optimizer, victim_grad, x, y, epoch):
        """One optimiser step per label present in ``y``; return the summed matching loss (CPU, detached)."""
        criterion = nn.CrossEntropyLoss()
        total_loss = None
        inner_products = []
        for label in self.labels:
            label_mask = y == label
            if not label_mask.any():
                continue
            self.neural_model.zero_grad()
            optimizer.zero_grad()
            # Only the delta rows belonging to this label are applied.
            stamped_x = poison_model(x[label_mask], label_mask)
            outputs = self.neural_model(stamped_x)
            # create_graph keeps the model gradient differentiable w.r.t. delta.
            criterion(outputs, y[label_mask]).backward(create_graph=True)

            grad_vals = get_single_grad_vector(self.neural_model)
            inner_product = self._cosine_similarity(victim_grad, grad_vals)
            inner_products.append(inner_product.detach().cpu().numpy())
            current_loss = 1 - inner_product

            self.neural_model.zero_grad()
            optimizer.zero_grad()
            current_loss.backward()
            optimizer.step()
            poison_model.project_parameters()

            detached = current_loss.detach().cpu()
            total_loss = detached if total_loss is None else total_loss + detached

        if total_loss is None:
            raise ValueError(f"None of the labels {self.labels} occur in the poison batch")
        if epoch % 99 == 0:
            print(f'\t\t\tSeperated GM Epoch {epoch}: Loss {total_loss}, Inner Products mean(over classes): {np.mean(inner_products)}')
        return total_loss

    def _projection_step(self, poison_model, optimizer, victim_grad, x, y, epoch):
        """One step maximising the gradient's projection onto ``victim_grad``; return the loss (CPU, detached)."""
        criterion = nn.CrossEntropyLoss(reduction='none')
        self.neural_model.zero_grad()
        optimizer.zero_grad()
        outputs = self.neural_model(poison_model(x))
        ce_loss = self._sum_of_label_means(criterion(outputs, y), y, self.labels)
        # create_graph keeps the model gradient differentiable w.r.t. delta.
        ce_loss.backward(create_graph=True)

        grad_vals = get_single_grad_vector(self.neural_model)
        projection = torch.sum(victim_grad * grad_vals) / torch.norm(victim_grad, 2)
        poisoning_loss = -projection
        if epoch % 99 == 0:
            print(f'\t\t\tEpoch {epoch}: Projection Poisoning Loss {poisoning_loss}')

        self.neural_model.zero_grad()
        optimizer.zero_grad()
        poisoning_loss.backward()
        optimizer.step()
        poison_model.project_parameters()
        return poisoning_loss.detach().cpu()

    # ------------------------------------------------------------- optimisers

    def __generate_combined_poison(self,victim_index,samples_to_poison_idx,device):
        """Optimise all poisons jointly with loss ``1 - cos(victim_grad, batch_grad)``.

        Returns:
            ``(poisoned_x, metadata, loss)``. ``poisoned_x`` holds the inputs
            stamped with the final, projected perturbation (CPU, detached).
            ``metadata`` is the batch metadata, or None. ``loss`` is the
            matching loss of the last optimisation step, as a numpy value.
        """
        victim_X, victim_grad = self._victim_gradient(victim_index, device)
        poison_lr = random.choice(self.poison_lrs)
        print(f'\t\tUsing attack learning rate :{poison_lr}')

        poison_model = self._build_poison_model(victim_X, device)
        poison_dataloader = self._build_poison_dataloader(samples_to_poison_idx)
        optimizer, scheduler = self._make_optimizer(poison_model, poison_lr)
        criterion = nn.CrossEntropyLoss()

        loss = None
        for epoch in range(self.n_epochs):
            for batch in poison_dataloader:
                x, y, _ = self._unpack_batch(batch, device)
                self.neural_model.zero_grad()
                optimizer.zero_grad()
                outputs = self.neural_model(poison_model(x))
                # create_graph keeps the model gradient differentiable w.r.t. delta.
                criterion(outputs, y).backward(create_graph=True)

                grad_vals = get_single_grad_vector(self.neural_model)
                inner_product = self._cosine_similarity(victim_grad, grad_vals)
                loss = 1 - inner_product
                if epoch % 99 == 0:
                    print(f'\t\t\tEpoch {epoch}: Loss {loss}, Inner Product {inner_product}')

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                poison_model.project_parameters()
                scheduler.step()

        # Freeing all grad values to reduce memory leak.
        free_grad_memory(self.neural_model)
        free_grad_memory(poison_model)

        # Re-stamp so the returned poisons include the last optimiser step.
        stamped_x, m = self._stamp_final_poisons(poison_model, poison_dataloader, device)
        return stamped_x.detach().cpu(), m, loss.detach().cpu().numpy()

    def __generate_seperated_poison(self,victim_index,samples_to_poison_idx,device):
        """Optimise poisons with the per-label ("seperated") objectives.

        Returns:
            ``(poisoned_x, metadata, loss)``, as in the combined variant; ``loss``
            is the objective value of the last optimisation step.

        Raises:
            ValueError: if ``gradient_matching_type`` is not a separated mode,
                or if none of ``labels`` occurs in the poison batch.
        """
        mode = self.gradient_matching_type
        if mode not in ('seperated', 'seperated_projection'):
            raise ValueError(f"Invalid Mode {mode}")

        victim_X, victim_grad = self._victim_gradient(victim_index, device)
        poison_lr = random.choice(self.poison_lrs)
        print(f'\t\tUsing attack learning rate :{poison_lr}')

        poison_model = self._build_poison_model(victim_X, device)
        poison_dataloader = self._build_poison_dataloader(samples_to_poison_idx)
        optimizer, scheduler = self._make_optimizer(poison_model, poison_lr)
        step = self._seperated_step if mode == 'seperated' else self._projection_step

        eval_poison_losses = None
        for epoch in range(self.n_epochs):
            for batch in poison_dataloader:
                x, y, _ = self._unpack_batch(batch, device)
                eval_poison_losses = step(poison_model, optimizer, victim_grad, x, y, epoch)
                scheduler.step()

        # Freeing all grad values to reduce memory leak.
        free_grad_memory(self.neural_model)
        free_grad_memory(poison_model)

        stamped_x, m = self._stamp_final_poisons(poison_model, poison_dataloader, device)
        return stamped_x.detach().cpu(), m, eval_poison_losses.detach().cpu().numpy()
    


