import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

class FormDataset(Dataset):
    """Dataset that slices every case into model inputs and a forecast target.

    ``cases`` is indexed as ``cases[idx, time, feature]``.  With ``w = window``
    each item is the tuple ``(history, starting_cgm, future_cov, output_cgm)``:

    * ``history``      -- time steps ``[:-w-1]``, first ``past_f`` features
    * ``starting_cgm`` -- feature 0 at time step ``-w-1`` (kept as a slice)
    * ``future_cov``   -- time steps ``[-w-1:-1]``, features ``1:``
    * ``output_cgm``   -- last ``w`` time steps of feature 0 (kept as a column)
    """

    def __init__(self, cases, window=12, past_f=5):
        self.cases = cases
        self.window = window
        self.past_f = past_f

    def __len__(self):
        return len(self.cases)

    def __getitem__(self, idx):
        # Samplers may hand over tensor indices; convert them to plain Python.
        if torch.is_tensor(idx):
            idx = idx.tolist()

        horizon = self.window
        cut = -horizon - 1  # first time step of the forecast segment
        series = self.cases

        return (
            series[idx, :cut, :self.past_f],
            series[idx, cut:-horizon, 0],
            series[idx, cut:-1, 1:],
            series[idx, -horizon:, 0:1],
        )
    

def cv_split2(perms,cases,rep,train_fold,val_fold,test_fold,N,V=4,batch_size=64,standardize=True):
    """Split ``cases`` into train/validation/test folds and wrap them in loaders.

    The cases are first reordered with ``perms[rep]`` and cut into ``N`` outer
    folds of ``len(cases) // N`` items; the last outer fold also receives the
    remainder.  ``test_fold`` selects the test split and ``train_fold`` selects
    the fold used for training + validation.  That fold is cut into ``V``
    sub-folds (the last one again takes the remainder): sub-fold ``val_fold``
    becomes the validation split and the rest the training split.

    Args:
        perms: Sequence of index permutations; ``perms[rep]`` is applied to
            ``cases``.
        cases: Array of cases, indexed along its first axis.
        rep: Which permutation to use.
        train_fold: Outer fold index used for training + validation.
        val_fold: Sub-fold index (within the training fold) used for validation.
        test_fold: Outer fold index used for testing.
        N: Number of outer folds.
        V: Number of validation sub-folds.
        batch_size: Batch size of the training loader.  The validation and test
            loaders each deliver their whole split as a single batch.
        standardize: When true, every split is shifted and scaled with the
            training split's mean/std (reduced over the first two axes).  When
            false the data is left untouched.

    Returns:
        ``(train, val, test, train_mean, train_std)``.  The statistics are the
        ones actually applied to the data: with ``standardize=False`` they are
        all-zeros / all-ones, and a feature with zero spread reports a std of 1,
        so ``x * train_std + train_mean`` always maps back to the input scale.
    """
    shuffled = cases[perms[rep]]
    fold_sz = len(shuffled) // N
    val_sz = fold_sz // V

    def outer_fold(k):
        # The last outer fold is open-ended so it absorbs the remainder.
        stop = (k + 1) * fold_sz if k < N - 1 else None
        return shuffled[k * fold_sz:stop]

    test_cases = outer_fold(test_fold)
    tv_cases = outer_fold(train_fold)

    val_start = val_fold * val_sz
    if val_fold < V - 1:
        val_stop = val_start + val_sz
        val_cases = tv_cases[val_start:val_stop]
        train_cases = np.concatenate(
            [tv_cases[:val_start], tv_cases[val_stop:]], axis=0)
    else:
        # Last sub-fold: validation takes everything up to the end.
        val_cases = tv_cases[val_start:]
        train_cases = tv_cases[:val_start]

    train_mean = np.mean(train_cases, axis=(0, 1))
    train_std = np.std(train_cases, axis=(0, 1))
    if standardize:
        # A constant feature has std 0; dividing by it would fill the feature
        # with NaN/inf, so such features are only centred.
        train_std = np.where(train_std == 0, 1.0, train_std)
        train_cases = np.divide(train_cases - train_mean, train_std)
        val_cases = np.divide(val_cases - train_mean, train_std)
        test_cases = np.divide(test_cases - train_mean, train_std)
    else:
        # Report the identity transform so callers that rescale errors with
        # ``train_std`` stay consistent with the untouched data.
        train_mean = np.zeros_like(train_mean)
        train_std = np.ones_like(train_std)

    train = DataLoader(FormDataset(train_cases), batch_size=batch_size)
    val = DataLoader(FormDataset(val_cases), batch_size=len(val_cases))
    test = DataLoader(FormDataset(test_cases), batch_size=len(test_cases))

    return train, val, test, train_mean, train_std

def train_model_greedy(model, train, val, epochs,
                hyper_params, train_std,
                device, verbose, path):
    """Greedily switch off edges while validation improves, then train once more.

    Starting from ``model.init_ew()``, every pass tries switching off each
    remaining edge in turn (its weight is set to 0), trains a freshly built
    model for a fixed 30 epochs with ``r_mode='NR'`` and records the best
    validation loss.  The edge whose removal scores best is switched off for
    good, unless that score is worse than the best score of the previous pass,
    in which case the search stops.  The search also stops once no edges are
    left to try.

    A new model is then built with the selected edge weights and trained for
    ``epochs`` epochs; this final run honours ``verbose`` and ``path``.

    The model must provide ``init_ew()``, ``set_ew()``, ``freeze_ew()`` and the
    attributes ``dag``, ``edge_map`` and ``hp``; its class must accept the
    keyword arguments ``DAG``, ``edge_map`` and ``hyper_params``.

    Returns:
        The result of the final ``train_model`` call.
    """
    search_epochs = 30  # training budget for each candidate evaluation

    def rebuild(template):
        # Fresh, untrained model of the same kind with the current edge
        # weights installed and frozen.
        fresh = type(template)(DAG=template.dag, edge_map=template.edge_map,
                               hyper_params=template.hp)
        fresh.set_ew(nn.Parameter(EW, requires_grad=False))
        fresh.freeze_ew()
        return fresh

    best_perf = 1e7
    EW = model.init_ew()
    remaining_edges = list(range(len(EW)))
    # Looping on the candidate list avoids calling np.min/np.argmin on an
    # empty score list after every edge has been removed.
    while remaining_edges:
        perfs = []
        for e in remaining_edges:
            EW[e] = 0
            model = rebuild(model)
            _, val_loss = train_model(model, train, val, search_epochs,
                                      hyper_params, train_std,
                                      device=device, verbose=False,
                                      path=None, r_mode='NR')
            perfs.append(np.min(val_loss))
            # Re-enable the edge before trying the next candidate.
            # NOTE(review): assumes an active edge has weight 1 -- confirm
            # against init_ew().
            EW[e] = 1
        best_idx = int(np.argmin(perfs))
        if perfs[best_idx] > best_perf:
            break
        EW[remaining_edges.pop(best_idx)] = 0
        best_perf = perfs[best_idx]

    model = rebuild(model)
    return train_model(model, train, val, epochs,
                       hyper_params, train_std,
                       device=device, verbose=verbose,
                       path=path, r_mode='NR')
    
def train_model_random(model, train, val, epochs,
                hyper_params, train_std,
                device, verbose, path):
    """Pick the best of five random edge subsets to switch off, then train.

    Five subsets of ``int(len(EW) * hyper_params['p'])`` distinct edges are
    drawn with a generator seeded with 2024.  For each subset a freshly built
    model with those edge weights set to 0 is trained silently for ``epochs``
    epochs (``r_mode='NR'``, nothing saved) and scored by its lowest validation
    loss.  The subset with the best score is switched off in a final model,
    which is trained with the caller's ``verbose`` and ``path``.

    The model must provide ``init_ew()``, ``set_ew()``, ``freeze_ew()`` and the
    attributes ``dag``, ``edge_map`` and ``hp``.

    Returns:
        The result of the final ``train_model`` call.
    """
    rng = np.random.default_rng(2024)
    EW = model.init_ew()
    n_edges = len(EW)
    edge_ids = list(range(n_edges))

    candidates = []
    for _ in range(5):
        drop_count = int(n_edges * hyper_params['p'])
        candidates.append(rng.choice(edge_ids, size=drop_count, replace=False))

    scores = []
    for dropped in candidates:
        EW[dropped] = 0
        model = type(model)(DAG=model.dag, edge_map=model.edge_map,
                            hyper_params=model.hp)
        model.set_ew(nn.Parameter(EW, requires_grad=False))
        model.freeze_ew()
        _, val_losses = train_model(model, train, val, epochs,
                                    hyper_params, train_std,
                                    device=device, verbose=False,
                                    path=None, r_mode='NR')
        scores.append(np.min(val_losses))
        # Switch the subset back on before scoring the next candidate.
        EW[dropped] = 1

    winner = candidates[np.argmin(scores)]
    EW[winner] = 0
    model = type(model)(DAG=model.dag, edge_map=model.edge_map,
                        hyper_params=model.hp)
    model.set_ew(nn.Parameter(EW, requires_grad=False))
    model.freeze_ew()
    return train_model(model, train, val, epochs,
                       hyper_params, train_std,
                       device=device, verbose=verbose,
                       path=path, r_mode='NR')

def train_model_transformer(model,train,val,epochs,hyper_params,train_std,
                            device=None, verbose=False, path=None):
    """Train and validate an encoder/decoder style model called as ``model(src, tgt, mask)``.

    Every batch is a tuple ``(past, y0, x, y)``.  The encoder input ``src`` is
    ``past`` followed, along dim 1, by the covariates ``x`` shifted one feature
    slot to the right (their last feature is dropped), with ``y0`` written into
    the freed slot of the first step.

    Training uses teacher forcing: the decoder input is ``y0`` followed by
    ``y[:, :-1]`` and a causal mask is passed to the model.  Validation rolls
    the forecast out step by step, starting from ``y0`` followed by zeros and
    feeding every prediction back into the next decoder slot.

    Args:
        model: Module to optimise (Adam, learning rate ``hyper_params['lr']``).
        train: Iterable of batches that also exposes ``.dataset`` (its length
            normalises the epoch loss).
        val: Iterable of validation batches.
        epochs: Number of passes over ``train``.
        hyper_params: Mapping that provides ``'lr'``.
        train_std: Indexable; ``train_std[0]`` rescales the RMSE that is
            printed when ``verbose`` is set.
        device: Optional device for the model and every batch.
        verbose: Print the RMSE of every training and validation batch.
        path: If given, ``model.state_dict()`` is saved there whenever a
            validation batch improves on the best loss seen so far.

    Returns:
        ``(train_losses, val_losses)``: one sample-weighted mean loss per
        epoch, and one MSE per validation batch per epoch.
    """
    if device:
        model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=hyper_params['lr'])
    loss_fn1 = nn.MSELoss()
    train_losses = []
    val_losses = []
    best_val = 1e7

    def on_device(*tensors):
        # Move a batch to the target device, if one was requested.
        if device:
            return tuple(t.to(device) for t in tensors)
        return tensors

    def encoder_input(past, y0, x):
        # Drop the last covariate and left-pad the feature axis with one zero,
        # then seed that zero slot of the first step with the starting value.
        src_x = nn.functional.pad(x[:, :, :-1], (1, 0), 'constant', 0)
        src_x[:, 0, 0:1] = y0
        return torch.concat([past, src_x], dim=1)

    for epoch in range(epochs):
        model.train()
        train_loss = 0
        for past, y0, x, y in train:
            past, y0, x, y = on_device(past, y0, x, y)
            causal_mask = nn.Transformer.generate_square_subsequent_mask(
                x.shape[-2], device=device)
            src = encoder_input(past, y0, x)
            tgt = torch.concat([y0.unsqueeze(dim=-1), y[:, :-1]], dim=1)
            pred = model(src, tgt, causal_mask)
            loss = loss_fn1(pred, y)
            if verbose:
                print(f"rmse:{np.sqrt(loss.item())*train_std[0]}")
            loss.backward(retain_graph=True)
            optimizer.step()
            optimizer.zero_grad()
            train_loss += loss.item() * len(y)

        train_losses.append(train_loss / len(train.dataset))

        model.eval()
        with torch.no_grad():
            for past, y0, x, y in val:
                past, y0, x, y = on_device(past, y0, x, y)
                steps = x.shape[-2]
                # Use the same causal mask as in training; without it every
                # decoder step would also attend to the not-yet-predicted
                # (zero-filled) slots, a pattern the model never trained on.
                causal_mask = nn.Transformer.generate_square_subsequent_mask(
                    steps, device=device)
                src = encoder_input(past, y0, x)
                tgt = torch.concat(
                    [y0.unsqueeze(dim=-1), y[:, :-1] * 0], dim=1)
                prediction = model(src, tgt, causal_mask)
                for j in range(1, steps):
                    # Feed the previous step's forecast into the next slot.
                    tgt[:, j] = prediction[:, j - 1]
                    prediction = model(src, tgt, causal_mask)

                valid_loss = loss_fn1(prediction, y).item()
                val_losses.append(valid_loss)
                if valid_loss < best_val and path:
                    best_val = valid_loss
                    torch.save(model.state_dict(), path)
                if verbose:
                    print(f"validation loss at epoch {epoch} pred rmse:{np.sqrt(valid_loss)*train_std[0]}")
    return train_losses, val_losses
    
def train_model(model, train, val, epochs,
                hyper_params, train_std,
                device=None, verbose=False,
                path=None, r_mode=None):
    """Train and validate ``model`` with the regularisation selected by ``r_mode``.

    ``r_mode`` values:

    * ``'GD'``, ``'RD'``, ``'TS'`` -- hand the whole job to
      ``train_model_greedy``, ``train_model_random`` or
      ``train_model_transformer`` and return their result.
    * ``'EN'``  -- MSE + ``a2`` * sum of squared edge weights + ``a1`` * sum of
      edge weights.
    * ``'GL'``  -- MSE + ``a2`` * squared norm of the parameters whose name
      starts with ``"decoder.f"`` + ``a1`` * sum of edge weights.
    * ``'EGL'`` -- MSE + ``a1`` * sum over ``model.return_edge_map()`` entries
      of the squared sum of ``edge_weights[entry[0]]``.
    * ``'NR'``, ``'DK'``, ``'NS'`` -- MSE + ``a2`` * the ``"decoder.f"``
      squared norm.

    Edge weights come from ``model.return_edge_weights()`` when the model has
    that method; otherwise the edge terms are 0.

    Args:
        model: Module called as ``model(past, y0, x)``.
        train: Iterable of ``(past, y0, x, y)`` batches that also exposes
            ``.dataset`` (its length normalises the epoch loss).
        val: Iterable of validation batches.
        epochs: Number of passes over ``train``.
        hyper_params: Mapping with ``'lr'`` plus the ``'a1'`` / ``'a2'``
            coefficients used by the chosen mode.
        train_std: Indexable; ``train_std[0]`` rescales the RMSE that is
            printed when ``verbose`` is set.
        device: Optional device for the model and every batch.
        verbose: Print per-batch training diagnostics and, every 10th epoch,
            the validation RMSE.
        path: If given, ``model.state_dict()`` is saved there whenever a
            validation batch improves on the best loss seen so far.
        r_mode: One of the modes listed above.

    Returns:
        ``(train_losses, val_losses)``: one sample-weighted mean of the
        regularised loss per epoch, and one plain MSE per validation batch per
        epoch.

    Raises:
        AssertionError: If ``r_mode`` is not a recognised mode (this includes
            the default ``None``).
    """
    known_modes = ('GD', 'RD', 'TS', 'EN', 'GL', 'EGL', 'NR', 'DK', 'NS')
    if r_mode not in known_modes:
        # Raised explicitly rather than via ``assert`` so the check survives
        # ``python -O``, and up front so no work is done for a bad mode.
        raise AssertionError(f"Unrecognized Reduction: {r_mode!r}")

    if device:
        model.to(device)
    # NOTE(review): anomaly detection is a process-wide debugging switch that
    # slows autograd down; kept because callers may rely on it being enabled.
    torch.autograd.set_detect_anomaly(True)

    if r_mode == 'GD':
        print("running greedy")
        return train_model_greedy(model, train, val, epochs,
                                  hyper_params, train_std,
                                  device, verbose, path)
    if r_mode == 'RD':
        return train_model_random(model, train, val, epochs,
                                  hyper_params, train_std,
                                  device, verbose, path)
    if r_mode == 'TS':
        return train_model_transformer(model, train, val, epochs,
                                       hyper_params, train_std,
                                       device, verbose, path)

    optimizer = torch.optim.Adam(model.parameters(), lr=hyper_params['lr'])
    mse_loss = nn.MSELoss()
    train_losses = []
    val_losses = []
    best_val = 1e7
    has_edges = hasattr(model, 'return_edge_weights')

    def on_device(*tensors):
        # Move a batch to the target device, if one was requested.
        if device:
            return tuple(t.to(device) for t in tensors)
        return tensors

    for epoch in range(epochs):
        model.train()
        train_loss = 0
        for past, y0, x, y in train:
            past, y0, x, y = on_device(past, y0, x, y)
            pred = model(past, y0, x)
            mse = mse_loss(pred, y)

            if has_edges:
                edge_weights = model.return_edge_weights()
                # Edge weights are already non-negative, so their plain sum
                # is the 1-norm.
                edge_1norm = torch.sum(edge_weights)
                edge_2norm = torch.sum(torch.square(edge_weights))
            else:
                edge_weights = 0
                edge_1norm = 0
                edge_2norm = 0

            mlp_2norm = 0
            for name, param in model.named_parameters():
                if name.startswith("decoder.f"):
                    mlp_2norm += torch.sum(torch.square(param))

            if r_mode == 'EN':
                loss = mse + hyper_params['a2'] * edge_2norm + hyper_params['a1'] * edge_1norm
            elif r_mode == 'GL':
                loss = mse + hyper_params['a2'] * mlp_2norm + hyper_params['a1'] * edge_1norm
            elif r_mode == 'EGL':
                egl_penalty = 0
                for node in model.return_edge_map():
                    egl_penalty += torch.square(torch.sum(edge_weights[node[0]]))
                loss = mse + hyper_params['a1'] * egl_penalty
            else:
                # 'NR', 'DK', 'NS' -- the mode was validated on entry.
                loss = mse + hyper_params['a2'] * mlp_2norm

            if verbose:
                print(f"rmse:{torch.sqrt(mse)*train_std[0]} MSE:{mse} mlp_norm:{mlp_2norm}")
                if has_edges:
                    print(torch.round(edge_weights, decimals=4))
            loss.backward(retain_graph=True)
            optimizer.step()
            optimizer.zero_grad()
            train_loss += loss.item() * len(x)
        train_losses.append(train_loss / len(train.dataset))

        model.eval()
        with torch.no_grad():
            for past, y0, x, y in val:
                past, y0, x, y = on_device(past, y0, x, y)
                pred = model(past, y0, x)
                loss_val = mse_loss(pred, y)
                valid_loss = loss_val.item()
                val_losses.append(valid_loss)
                if valid_loss < best_val and path:
                    best_val = valid_loss
                    torch.save(model.state_dict(), path)
                if verbose and epoch % 10 == 0:
                    print(f"validation loss at epoch:{epoch} rmse:{torch.sqrt(loss_val)*train_std[0]}")

    return train_losses, val_losses




        
    
