import numpy as np
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader, Sampler
import random
from collections import defaultdict
import itertools
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import torch
import torch.optim as optim
from sklearn.decomposition import PCA
import warnings
from manifold import grow_manifolds_supervised,calculate_point_point_similarities #use grow_supervised for restricting manifold growth between similar classes
from loss import ManifoldPointToPointLoss,ProxyAnchorLoss
from dataset import CustomDataset #use balanced dataset for equi distribution

# Locations of the training features and their action labels.
train_feat_path, train_action_path = "data/X_train.npy", "data/a_train.npy"

# Load both arrays (features first, then labels).
X_train, a_train = map(np.load, (train_feat_path, train_action_path))

# Width of the feature vectors (second axis of X_train — presumably (N, D),
# TODO confirm). The projection head and the proxies are both sized with it.
emb_dim = X_train.shape[1]



class NNWrapper(nn.Module):
    """Projection head: Linear(dim, dim) -> per-vector normalisation -> ReLU.

    The normalisation standardises each embedding vector over its last
    dimension to zero mean and unit variance, with no learnable affine
    parameters. The previous version used ``InstanceNorm1d(emb_dim)``. On the
    single-channel ``(B, 1, D)`` or unbatched ``(B, D)`` input this head
    receives, that produces exactly this computation, but it also triggers
    PyTorch's "num_features does not match" warning. ``LayerNorm`` with
    ``elementwise_affine=False`` does the same thing without the warning and
    adds no parameters or buffers, so ``state_dict`` keys are unchanged.

    Args:
        dim: input/output width. Defaults to the module-level ``emb_dim`` when
            omitted, which preserves the original no-argument constructor.
    """

    def __init__(self, dim=None):
        super().__init__()
        if dim is None:
            # Backward compatible: the original always read the global width.
            dim = emb_dim
        self.additional_layer = nn.Sequential(
            nn.Linear(dim, dim),
            nn.LayerNorm(dim, elementwise_affine=False),
            nn.ReLU(),
        )

    def forward(self, x):
        """Project ``x`` (last dimension ``dim``); output has the same shape and is non-negative."""
        return self.additional_layer(x)

# Projection head instance shared by the rest of the script: it is moved to
# the training device, optimised by opt_mod, and trained by train_supcon_model.
model = NNWrapper()


def train_supcon_model(model, train_loader, opt_mod, opt_prox, scheduler_mod, scheduler_prox, epochs=10, patience=5,
                       metric_loss_fn=None, sim_weight=10):
    """Jointly train the projection head and the metric-loss proxies.

    For every batch, the inputs are embedded by ``model``. A supervised
    manifold is grown over the embeddings and converted into point-to-point
    similarity targets, which gives the similarity loss. The metric loss is
    computed from the embeddings and labels. ``opt_mod`` minimises
    ``sim_weight * sim_loss + metric_loss`` w.r.t. the model, and ``opt_prox``
    updates the metric-loss parameters (the proxies).

    Args:
        model: network being trained; its parameters determine the device.
        train_loader: sized iterable of ``(inputs, labels)`` batches.
        opt_mod: optimizer over the model parameters.
        opt_prox: optimizer over the metric-loss parameters.
        scheduler_mod, scheduler_prox: stepped once after each completed epoch
            (not after the epoch that triggers early stopping).
        epochs: maximum number of epochs.
        patience: stop after this many consecutive epochs whose mean loss
            does not improve on the best seen so far.
        metric_loss_fn: callable ``(emb, labels) -> loss`` that also exposes
            ``update_momentum_proxies()``. Defaults to the module-level
            ``proxy_anchor_loss``, which is what the original function used.
        sim_weight: weight of the similarity loss in the total loss.

    Returns:
        List with the mean total loss of every completed epoch.

    Raises:
        ValueError: if ``train_loader`` has no batches. Previously this
            surfaced as a ZeroDivisionError after the first epoch.
    """
    if metric_loss_fn is None:
        metric_loss_fn = proxy_anchor_loss

    num_batches = len(train_loader)
    if num_batches == 0:
        raise ValueError("train_loader yields no batches")

    # Send inputs to wherever the model actually lives, instead of assuming
    # CUDA whenever it is available.
    try:
        device = next(model.parameters()).device
    except StopIteration:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    sim_criterion = ManifoldPointToPointLoss()
    best_loss = float('inf')
    epoch_losses = []
    early_stop_counter = 0

    for epoch in tqdm(range(epochs)):
        total_loss = 0.0
        total_sim_loss = 0.0
        total_metric_loss = 0.0
        model.train()

        for pos, lab in train_loader:
            pos = pos.to(device).float()
            lab = lab.to(device)

            # Drop a singleton dim 1 if present (no-op otherwise).
            emb = model(pos).squeeze(1)
            manifold, basis, all_points = grow_manifolds_supervised(emb, lab, m=3, reconstruction_threshold=0.9, max_neighbors=20)
            sims = calculate_point_point_similarities(emb, manifold, basis, all_points, N_alpha=4, N_beta=0.5)

            weighted_sim_loss = sim_weight * sim_criterion(emb, sims)
            metric_loss = metric_loss_fn(emb, lab)
            loss = weighted_sim_loss + metric_loss

            total_loss += loss.item()
            total_sim_loss += weighted_sim_loss.item()
            total_metric_loss += metric_loss.item()

            # One backward pass feeds both optimizers. The similarity term does
            # not involve the proxies, so the proxies' gradient of `loss` equals
            # their gradient of `metric_loss`. The previous code ran a second
            # `metric_loss.backward()` (with retain_graph=True) after
            # opt_mod.step() had already modified the model weights in place.
            # That doubled the backward cost and kept the graph alive.
            opt_mod.zero_grad()
            opt_prox.zero_grad()
            loss.backward()
            opt_mod.step()
            opt_prox.step()

            metric_loss_fn.update_momentum_proxies()

        avg_loss = total_loss / num_batches
        avg_sim_loss = total_sim_loss / num_batches
        avg_metric_loss = total_metric_loss / num_batches
        epoch_losses.append(avg_loss)

        print(f"Epoch [{epoch+1}/{epochs}], Train Loss: {avg_loss:.14f}, Best Loss: {best_loss:.14f},Sim loss:{avg_sim_loss:.14f},Metric loss:{avg_metric_loss:.14f}")

        if avg_loss < best_loss:
            best_loss = avg_loss
            early_stop_counter = 0
        else:
            early_stop_counter += 1

        if early_stop_counter >= patience:
            print(f"Early stopping triggered. No improvement for {patience} consecutive epochs.")
            break

        scheduler_mod.step()
        scheduler_prox.step()

    return epoch_losses




# Silence PyTorch's InstanceNorm warning about num_features not matching the
# channel dimension of its input (harmless if no such layer is in use).
warnings.filterwarnings("ignore", message="input's size at dim=1 does not match num_features")

# Effective training configuration. These names previously held 50 and
# 0.00025, which nothing read. The run actually used 200 epochs and lr=1e-3,
# so the names now carry the values that are really applied.
epochs = 200
lr = 0.001
# Fall back to CPU instead of crashing on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)

# One proxy per distinct action label, with the same width as the features.
# NOTE(review): the proxy module is left on its default device, as before.
# Confirm that ProxyAnchorLoss aligns its proxies with `emb`'s device;
# otherwise it needs `.to(device)`, together with its momentum proxies.
proxy_anchor_loss = ProxyAnchorLoss(num_classes=len(np.unique(a_train)), embedding_dim=X_train.shape[1])
train_dataset = CustomDataset(X_train, a_train)
# NOTE(review): DataLoader defaults to batch_size=1 and no shuffling. Confirm
# that CustomDataset already yields whole batches, since the manifold and
# proxy losses need several samples per step.
train_loader = DataLoader(train_dataset)
opt_mod = torch.optim.Adam(model.parameters(), lr=lr)
opt_prox = torch.optim.Adam(proxy_anchor_loss.parameters(), lr=lr)

scheduler_mod = torch.optim.lr_scheduler.ExponentialLR(opt_mod, gamma=0.97)
scheduler_prox = torch.optim.lr_scheduler.ExponentialLR(opt_prox, gamma=0.97)
classification_losses = train_supcon_model(model, train_loader, opt_mod, opt_prox, scheduler_mod, scheduler_prox, epochs=epochs)

# For every learned proxy, pick the closest training sample and keep its raw
# feature vector as that class's prototype.
#
# The proxies are optimised against the model's output embeddings (the
# metric loss is computed on `emb = model(pos)`), so distances must be
# measured in that embedding space. Before this change, the raw X_train
# features were compared against the proxies directly, mixing two different
# spaces. That numpy-minus-tensor expression also failed for CUDA or
# grad-requiring proxies, and it built an N x C x D temporary.
nn_proxies = proxy_anchor_loss.momentum_proxies

model.eval()
with torch.no_grad():
    model_device = next(model.parameters()).device
    train_inputs = torch.as_tensor(X_train, dtype=torch.float32, device=model_device)
    # The singleton channel dim keeps the normalisation layer's input 3-D,
    # mirroring the `.squeeze(1)` applied to embeddings during training.
    train_emb = model(train_inputs[:, None, :]).squeeze(1).cpu()
    proxies = torch.as_tensor(nn_proxies).detach().to(dtype=train_emb.dtype, device='cpu')
    # (N_train, num_proxies) distance matrix. argmin over the samples (dim 0)
    # gives, for each proxy, the index of the nearest training point.
    closest_idxs = torch.cdist(train_emb, proxies).argmin(dim=0).numpy()
closest_vectors = X_train[closest_idxs]

# Save alongside the training data. The previous absolute '/data/' path did
# not match the relative 'data/' directory the inputs are loaded from.
np.save('data/prototypes.npy', closest_vectors)

# Prototypes are sampled from the training data: for each learned proxy, the closest training vector is kept.
# Use them in https://github.com/EoinKenny/Prototype-Wrapper-Network-ICLR23/tree/main, replacing the existing prototypes.