import torch 
import torch.nn as nn

from torch.autograd import Variable
from copy import deepcopy

from torch.nn import MSELoss as MSE
from torch.utils.data import Dataset, DataLoader

import numpy as np
from model import Wassertein_Loss
# Module-level Wasserstein loss instance, created at import time from the project
# class `Wassertein_Loss` (imported from `model`). It supplies the alpha-weighted
# representation-balancing term of the Fisher loss computed below.
wass = Wassertein_Loss()

# Task distances: a modified version of the symmetric (Fisher-information based) task distance.

def diag_fisher(model, input_data, device='cpu', batch=200, alpha=0.0):
    '''
    Estimate a diagonal Fisher-style sensitivity of `model` on `input_data`.

    For every mini-batch the factual-outcome loss
        MSE(y0 on control units) + MSE(y1 on treated units)
        + alpha * wass(phi of control units, phi of treated units)
    is back-propagated. The squared gradient of each trainable parameter
    (reduced with `.mean(0)`, see NOTE below) is summed over all mini-batches.

    model: the model trained on the source task; called as ``phi, y0, y1 = model(x)``.
    input_data: mapping with keys 'x' (covariates), 't' (treatment, 0/1) and
        'yf' (factual outcome). It can be either the source or the target data.
    device: device the data tensors and the accumulators are placed on.
    batch: mini-batch size (data is iterated in order, no shuffling).
    alpha: weight of the balancing term. The term is skipped when alpha == 0,
        or when a mini-batch lacks one of the two treatment groups.

    Returns ``(precision_matrices, loss)``:
        precision_matrices maps each trainable parameter name to a tensor with
        that parameter's shape; loss is the loss of the last processed mini-batch.

    Raises ValueError if no mini-batch contains a unit with treatment 0 or 1
    (e.g. empty input).

    NOTE(review): accumulators are summed, not averaged, over mini-batches, so
    their scale grows with the number of batches -- confirm this is intended
    when comparing datasets of different sizes.
    '''
    # Frozen parameters never receive gradients, so only trainable ones are tracked.
    params = {n: p for n, p in model.named_parameters() if p.requires_grad}
    precision_matrices = {
        n: torch.zeros_like(p.detach(), device=device) for n, p in params.items()
    }

    model.eval()
    mse = MSE()

    X = torch.as_tensor(input_data['x'], dtype=torch.float).to(device)
    t = torch.as_tensor(input_data['t'], dtype=torch.float).to(device)
    y = torch.as_tensor(input_data['yf'], dtype=torch.float).to(device)
    # Pack covariates, treatment and outcome column-wise so one DataLoader batches all three.
    D = torch.cat((X, t.reshape(len(t), 1), y.reshape(len(y), 1)), dim=1)

    dim = X.shape[1]
    last_loss = None
    for tr in DataLoader(D, batch_size=batch, shuffle=False):
        train_X = tr[:, 0:dim]
        train_t = tr[:, dim]
        train_y = tr[:, dim + 1:dim + 2]
        control = train_t == 0
        treated = train_t == 1
        has_control = bool(control.any())
        has_treated = bool(treated.any())

        model.zero_grad()
        # One forward pass provides both the representation and the outcome heads.
        phi, y0, y1 = model(train_X)

        # An MSE over an empty group is NaN, so only non-empty groups contribute.
        terms = []
        if has_control:
            terms.append(mse(y0[control], train_y[control]))
        if has_treated:
            terms.append(mse(y1[treated], train_y[treated]))
        if alpha != 0 and has_control and has_treated:
            terms.append(alpha * wass(phi[control], phi[treated]))
        if not terms:
            continue

        loss = sum(terms)
        loss.backward()
        last_loss = loss.item()

        for n, p in params.items():
            # Parameters not reached by this batch's loss (e.g. the head of an
            # absent group) have no gradient and contribute nothing.
            if p.grad is None:
                continue
            # NOTE(review): `.mean(0)` averages the squared gradient over the
            # parameter's first axis and broadcasts it back over that axis, so
            # entries along axis 0 share one value. Kept as-is to preserve results;
            # confirm it is intended rather than a per-sample batch average.
            precision_matrices[n] += (p.grad.detach() ** 2).mean(0)

    if last_loss is None:
        raise ValueError("diag_fisher: no mini-batch contained treatment values 0 or 1")
    return precision_matrices, last_loss

def inverted_diag_fisher(model, input_data, device='cpu', batch=200, alpha=0.0):
    '''
    Diagonal Fisher estimate intended for data with inverted treatment.

    NOTE(review): despite its name, this function does NOT invert the treatment.
    Its loss is exactly the factual loss used by `diag_fisher`:
        MSE(y0 on t == 0) + MSE(y1 on t == 1) + alpha * wass(...)
    It therefore delegates to `diag_fisher` instead of duplicating that loop, and
    its results equal `diag_fisher`'s. If a counterfactual/inverted variant is
    wanted, the loss must be redefined here -- confirm the intended formulation.

    model: the model trained on the source task.
    input_data: mapping with 'x', 't', 'yf' (source or target data).
    device, batch, alpha: forwarded unchanged to `diag_fisher`.

    Returns ``(precision_matrices, loss)`` exactly as `diag_fisher` does.
    '''
    return diag_fisher(model, input_data, device=device, batch=batch, alpha=alpha)

def Fisher_distance(model, fisher_matrix_source, fisher_matrix_target):
    '''
    Distance between two diagonal Fisher estimates of the same model.

    For each parameter tracked in `fisher_matrix_source` this adds
        0.5 * mean((sqrt(F_source) - sqrt(F_target)) ** 2)
    and returns the sum over parameters.

    model: the trained model; its parameter names select the entries compared.
    fisher_matrix_source: mapping parameter name -> Fisher tensor for the source task.
    fisher_matrix_target: mapping parameter name -> Fisher tensor for the target task.

    Parameters absent from `fisher_matrix_source` (e.g. frozen ones, which the
    Fisher estimators do not track) are skipped. A name present in the source
    mapping but missing from the target mapping raises KeyError.
    '''
    distance = 0
    for n, _ in model.named_parameters():
        if n not in fisher_matrix_source:
            continue
        diff = fisher_matrix_source[n] ** 0.5 - fisher_matrix_target[n] ** 0.5
        distance += 0.5 * np.mean((diff ** 2).detach().cpu().numpy())
    return distance

def compute_distance(model, source_task, target_task, device, alpha=0.0):
    '''
    Main function used to compute the Fisher distance between two tasks.

    model: the model trained on the source task.
    source_task, target_task: mappings with 'x', 't', 'yf'.
    device: device used for the Fisher estimation.
    alpha: weight of the balancing term, forwarded to the Fisher estimators
        (previously ignored; the default 0.0 keeps the old behavior).

    Returns three tuples:
        (distance_1, distance_2),
        (distance_1_1, loss_val_1, distance_1_2),
        (distance_2_1, loss_val_2, distance_2_2)
    where distance_1 = min(distance_1_1, distance_2_1),
    distance_2 = min(distance_1_2, distance_2_2), and the loss values are the
    last-batch losses on the target task of the plain / "inverted" estimators.
    '''
    # Diagonal Fisher matrices.
    fisher_source, _ = diag_fisher(model, source_task, device=device, alpha=alpha)
    fisher_target, loss_val_1 = diag_fisher(model, target_task, device=device, alpha=alpha)
    # Diagonal Fisher matrices from the "inverted treatment" estimator.
    invert_fisher_source, _ = inverted_diag_fisher(model, source_task, device=device, alpha=alpha)
    invert_fisher_target, loss_val_2 = inverted_diag_fisher(model, target_task, device=device, alpha=alpha)

    distance_1_1 = Fisher_distance(model, fisher_source, fisher_target)
    distance_2_1 = Fisher_distance(model, invert_fisher_source, invert_fisher_target)

    # NOTE(review): distance_1_2 is the same source-vs-target comparison as
    # distance_1_1, so the value is reused instead of recomputed. A different
    # pairing (e.g. invert_fisher_source vs fisher_target) may have been
    # intended -- confirm.
    distance_1_2 = distance_1_1
    distance_2_2 = Fisher_distance(model, fisher_source, invert_fisher_target)

    # Keep the smaller distance of each pair; the argmin could serve as a flag later.
    distance_1 = min(distance_1_1, distance_2_1)
    distance_2 = min(distance_1_2, distance_2_2)

    return (distance_1, distance_2), (distance_1_1, loss_val_1, distance_1_2), (distance_2_1, loss_val_2, distance_2_2)

def variable(t: torch.Tensor, device='cpu',use_cuda=True, **kwargs):
    '''
    Move `t` to `device` and wrap it with the legacy `Variable` constructor.

    `use_cuda` is accepted only for backward compatibility and is ignored:
    placement is decided solely by `device`. Any extra keyword arguments
    (e.g. requires_grad) are passed through to `Variable`.
    '''
    placed = t.to(device)
    wrapped = Variable(placed, **kwargs)
    return wrapped

def Rep_distance(source_model, source_task, target_task, loss_func, wass, device, wc=1, wt=1, alpha=3, updateEpoch=1, lr=1e-5, mode = 'l2'):
    '''
    How far the source representation of the source data moves after a brief
    fine-tuning of `source_model` on the target task.

    The model is trained for `updateEpoch` full-batch Adam steps (learning rate
    `lr`) on the target factual loss
        wc * loss_func(y0 on t == 0, yf) + wt * loss_func(y1 on t == 1, yf)
        + alpha * wass(phi on t == 0, phi on t == 1)
    and the representation `phi` of the source covariates before and after is
    compared. The model's original state_dict is always restored afterwards,
    also when an exception is raised during fine-tuning.

    source_task: mapping with 'x'; target_task: mapping with 'x', 't', 'yf'.
    loss_func, wass: callables taking two tensors and returning a scalar tensor.
    mode: 'l2'   -> sqrt of the summed squared difference,
          'l1'   -> mean absolute difference,
          'wass' -> wass(original phi, updated phi).

    Returns a Python float. Raises ValueError for any other `mode` (checked
    before the model is touched).
    '''
    if mode not in ('l2', 'l1', 'wass'):
        raise ValueError(f"Rep_distance: unknown mode {mode!r}; expected 'l2', 'l1' or 'wass'")

    # deepcopy: state_dict() returns references to the live parameter tensors.
    source_model_ori_copy = deepcopy(source_model.state_dict())
    source_model.to(device)

    target_T = torch.Tensor(target_task['t']).to(device)
    source_input = torch.Tensor(source_task['x']).to(device)
    target_input = torch.Tensor(target_task['x']).to(device)
    target_output = torch.Tensor(target_task['yf']).reshape(-1, 1).to(device)
    control = target_T == 0
    treated = target_T == 1
    tr_Y0 = target_output[control]
    tr_Y1 = target_output[treated]

    try:
        # Snapshots only need values, not an autograd graph.
        with torch.no_grad():
            ori_phi, _, _ = source_model(source_input)

        optimizer = torch.optim.Adam(source_model.parameters(), lr=lr)
        for _ in range(updateEpoch):
            optimizer.zero_grad()
            phi, y0, y1 = source_model(target_input)
            loss = (wc * loss_func(y0[control], tr_Y0)
                    + wt * loss_func(y1[treated], tr_Y1)
                    + alpha * wass(phi[control], phi[treated]))
            loss.backward()
            optimizer.step()

        with torch.no_grad():
            new_phi, _, _ = source_model(source_input)
    finally:
        # Undo the fine-tuning no matter how the block above exits.
        source_model.load_state_dict(source_model_ori_copy)

    if mode == 'l2':
        return torch.sqrt(torch.sum((ori_phi - new_phi) ** 2)).item()
    if mode == 'l1':
        return torch.mean(torch.abs(ori_phi - new_phi)).item()
    return wass(ori_phi, new_phi).item()

def IPM_distance(source_model,source_task, target_task,IPM, device):
    '''
    Sum of three IPM terms:
      1. covariates of source vs. target control units (t == 0),
      2. covariates of source vs. target treated units (t == 1),
      3. control vs. treated target units in the representation produced by
         `source_model`.

    The result is returned as computed by `IPM` (nothing is detached or
    converted to a Python number).
    '''
    def to_device_tensor(values):
        return torch.Tensor(values).to(device)

    src_t = to_device_tensor(source_task['t'])
    tgt_t = to_device_tensor(target_task['t'])
    src_x = to_device_tensor(source_task['x'])
    tgt_x = to_device_tensor(target_task['x'])

    tgt_phi, _, _ = source_model(tgt_x)

    control_shift = IPM(src_x[src_t == 0], tgt_x[tgt_t == 0])
    treated_shift = IPM(src_x[src_t == 1], tgt_x[tgt_t == 1])
    rep_imbalance = IPM(tgt_phi[tgt_t == 0], tgt_phi[tgt_t == 1])
    return control_shift + treated_shift + rep_imbalance


