import numpy as np
import pandas as pd
import random
import torch
import os
import copy
from datetime import datetime
from utils.config import *
from torch.optim import lr_scheduler

# Global run configuration, built once when this module is imported.
# get_config is not defined in this file; presumably it is supplied by the
# star import from utils.config -- confirm. Functions below that do not take
# an explicit `args` parameter read their settings from this object.
args = get_config()

def make_csv(data, columns, file_name):
    """Write ``data`` to ``file_name`` as CSV with the given column labels.

    Args:
        data: Anything accepted by the ``pandas.DataFrame`` constructor.
        columns: Column labels; their count must match the frame's width.
        file_name: Destination path. An existing file is overwritten.
    """
    frame = pd.DataFrame(data)
    # Labels are assigned after construction so they rename the columns
    # positionally, whatever labels the constructor inferred.
    frame.columns = columns
    frame.to_csv(file_name, index=False, mode='w')

def _grid_values(bounds):
    """Return the evenly spaced values described by ``(minimum, step, maximum)``.

    The grid is ``minimum + i * step`` for ``i = 0 .. n`` where ``n`` is the
    number of whole steps that fit between ``minimum`` and ``maximum``.
    """
    low, step, high = bounds
    ratio = (high - low) / step
    nearest = int(round(ratio))
    # Floating-point error can leave the ratio a hair below a whole number
    # (e.g. (0.7 - 0.4) / 0.1 == 2.999...), and plain truncation would then
    # drop the upper bound from the grid. Snap to the nearest integer when
    # the ratio is within rounding noise of it; otherwise truncate as before.
    if abs(ratio - nearest) <= 1e-9 * max(1.0, abs(ratio)):
        n_steps = nearest
    else:
        n_steps = int(ratio)
    return [low + i * step for i in range(n_steps + 1)]


def random_parameter(beta_range, nu_range, rho_range, epsilon_range, theta_range, N_random):
    """Randomly sample coefficient combinations from per-coefficient grids.

    Args:
        beta_range, nu_range, rho_range, epsilon_range, theta_range:
            ``(minimum, step, maximum)`` triples, each defining an evenly
            spaced grid of candidate values (both ends included).
        N_random: Number of combinations to draw.

    Returns:
        A list of ``N_random`` tuples ``(beta, nu, rho, epsilon, theta)``.
        Each component is drawn independently with ``random.choice``, so
        combinations may repeat.
    """
    grids = [
        _grid_values(bounds)
        for bounds in (beta_range, nu_range, rho_range, epsilon_range, theta_range)
    ]
    # Components are drawn in argument order for every sample, so a seeded
    # ``random`` stream is consumed in a fixed, predictable order.
    return [tuple(random.choice(values) for values in grids) for _ in range(N_random)]


def random_parameter_2D(beta_range, beta_y_range, nu_range, nu_y_range, rho_range, epsilon_range, theta_range, N_random):
    """Randomly sample coefficient combinations, with separate ``_y`` grids.

    Args:
        beta_range, beta_y_range, nu_range, nu_y_range, rho_range,
        epsilon_range, theta_range:
            ``(minimum, step, maximum)`` triples, each defining an evenly
            spaced grid of candidate values (both ends included).
        N_random: Number of combinations to draw.

    Returns:
        A list of ``N_random`` tuples
        ``(beta, beta_y, nu, nu_y, rho, epsilon, theta)``. Each component is
        drawn independently with ``random.choice``, so combinations may repeat.
    """
    grids = [
        _grid_values(bounds)
        for bounds in (beta_range, beta_y_range, nu_range, nu_y_range,
                       rho_range, epsilon_range, theta_range)
    ]
    return [tuple(random.choice(values) for values in grids) for _ in range(N_random)]


def param_number(model):
    """Count the trainable scalar parameters of ``model``.

    Args:
        model: Object exposing ``parameters()`` that yields tensors.

    Returns:
        Total number of elements across all parameters whose
        ``requires_grad`` flag is set; parameters without it are skipped.
    """
    return sum(tensor.numel() for tensor in model.parameters() if tensor.requires_grad)


def Gaussian_noise(train_data, noise_rate=None):
    """Add zero-mean Gaussian noise to column 2 of ``train_data`` in place.

    The standard deviation is ``|mean(train_data[:, 2])| * noise_rate / 100``,
    i.e. the noise level is relative to the mean of the perturbed column.

    Args:
        train_data: Tensor indexed as ``train_data[:, 2]``; only that column
            is modified.
        noise_rate: Noise level in percent. Defaults to ``args.noise_rate``
            from the module-level configuration.

    Returns:
        The same ``train_data`` object, modified in place.
    """
    if noise_rate is None:
        noise_rate = args.noise_rate
    column = train_data[:, 2]
    std_dev = torch.abs(torch.mean(column) * (noise_rate / 100))
    noise = torch.normal(mean=0.0, std=std_dev, size=column.size())
    # The sample is created on the default device; move it next to the data
    # so the in-place add also works when train_data lives elsewhere.
    train_data[:, 2] += noise.to(train_data.device)
    return train_data

def salt_and_pepper_noise(train_data, noise_rate=None):
    """Apply salt-and-pepper noise to column 2 of ``train_data`` in place.

    Each entry of the column is independently replaced by the column maximum
    ("salt") with probability ``noise_rate / 200`` and by the column minimum
    ("pepper") with probability ``noise_rate / 200``.

    Args:
        train_data: Tensor indexed as ``train_data[:, 2]``; only that column
            is modified.
        noise_rate: Total corruption level in percent. Defaults to
            ``args.noise_rate`` from the module-level configuration.

    Returns:
        The same ``train_data`` object, modified in place.
    """
    if noise_rate is None:
        noise_rate = args.noise_rate
    prob = noise_rate / 100
    column = train_data[:, 2]
    # Read both extremes before writing anything: if the minimum entry were
    # salted first, a minimum taken afterwards would no longer be the
    # original one.
    salt_value = column.max()
    pepper_value = column.min()
    mask = torch.rand(column.size())
    train_data[mask < (prob / 2), 2] = salt_value
    train_data[mask > (1 - prob / 2), 2] = pepper_value
    return train_data

def uniform_noise(train_data, noise_rate=None):
    """Add uniform noise to column 2 of ``train_data`` in place.

    Noise is drawn from ``U(-noise_rate / 100, noise_rate / 100)``. The
    amplitude is absolute; it is not scaled by the data.

    Args:
        train_data: Tensor indexed as ``train_data[:, 2]``; only that column
            is modified.
        noise_rate: Noise half-width in percent. Defaults to
            ``args.noise_rate`` from the module-level configuration.

    Returns:
        The same ``train_data`` object, modified in place.
    """
    if noise_rate is None:
        noise_rate = args.noise_rate
    half_width = noise_rate / 100
    noise = torch.empty(train_data[:, 2].size()).uniform_(-half_width, half_width)
    # The sample is created on the default device; move it next to the data
    # so the in-place add also works when train_data lives elsewhere.
    train_data[:, 2] += noise.to(train_data.device)
    return train_data


def construct_scheduler(optimizer, args):
    """Build the learning-rate scheduler selected by ``args.scheduler``.

    Supported names and the settings each one reads from ``args``:
        "constant":      none; the multiplicative factor is always 1.0.
        "warmup_linear": ``epoch``, ``lr_decay``; factor 1.0 while
                         ``epoch < int(args.epoch * 0.1)``, afterwards
                         ``|args.epoch - epoch * args.lr_decay| / args.epoch``.
        "cosine":        ``epoch`` (T_max), ``lr`` (eta_min = 1% of it).
        "step":          ``step`` (step_size), ``lr_decay`` (gamma).
        "OneCycleLR":    ``N_train_shot``, ``batch``, ``epoch``, ``lr``.

    Args:
        optimizer: Optimizer whose learning rate is scheduled.
        args: Configuration object exposing the attributes listed above.

    Returns:
        The constructed ``torch.optim.lr_scheduler`` instance.

    Raises:
        ValueError: If ``args.scheduler`` is not one of the supported names.
    """
    name = args.scheduler

    if name == "constant":
        return lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda epoch: 1.0)

    if name == "warmup_linear":
        def warmup_linear(epoch):
            # Hold the base rate for the first 10% of epochs, then decay.
            if epoch < int(args.epoch * 0.1):
                return 1.0
            return np.abs(args.epoch - epoch * args.lr_decay) / args.epoch

        return lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_linear)

    if name == "cosine":
        return lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epoch,
                                              eta_min=args.lr * 0.01, last_epoch=-1)

    if name == "step":
        return lr_scheduler.StepLR(optimizer, step_size=args.step, gamma=args.lr_decay)

    if name == "OneCycleLR":
        return lr_scheduler.OneCycleLR(optimizer=optimizer,
                                       steps_per_epoch=args.N_train_shot // args.batch,
                                       pct_start=0.1, epochs=args.epoch, max_lr=args.lr)

    # Previously an unknown name fell through to the return statement and
    # surfaced as an UnboundLocalError; fail with an explicit message instead.
    raise ValueError(f"Unknown scheduler: {name!r}")


def task_specifier(args):
    """Return a human-readable task label for the configured experiment.

    Args:
        args: Configuration object exposing ``data`` (dataset name),
            ``system`` (task variant) and, for the "2D_cdr" and "SWE"
            datasets, ``extrapolation``.

    Returns:
        ``args.system`` followed by " - Temporal extrapolation" or
        " - Spatiotemporal interpolation" for "2D_cdr"/"SWE" (depending on
        ``args.extrapolation``), or " - Operator learning" for the other
        datasets.

    Raises:
        ValueError: If ``args.data`` is unknown or ``args.system`` is not a
            valid task for that dataset.
    """
    target_dict = {"2D_cdr"    : ["Seen coeff", "Seen coeff given noisy", "Inter coeff", "Extra coeff"],
                   "SWE"       : ["Seen dataset", "Unseen dataset"],
                   "CNSE"      : ["Unseen dataset"],
                   "Darcy01"   : ["Initial value"],}
    # An unknown dataset has no valid tasks, so it is reported with the same
    # ValueError as an unknown system rather than leaking a bare KeyError.
    if args.system not in target_dict.get(args.data, ()):
        raise ValueError("Not Implemented Task")
    if args.data in ["2D_cdr", "SWE"]:
        task_str = " - Temporal extrapolation" if args.extrapolation else " - Spatiotemporal interpolation"
    else:
        task_str = " - Operator learning"
    return args.system + task_str
        
        
    
# Net operator: checkpoint I/O for a network.
# Supported modes: "load", "finetuning" and "save".
def net_operator_ours(net, device, data, mode, n_in=4):
    """Load, partially load, or save the weights of ``net``.

    The checkpoint directory is ``./parameter/{args.model}/{data}`` (created
    if missing, in every mode) and the checkpoint file name encodes the
    hyperparameters read from the module-level ``args``.

    Args:
        net: Network whose ``state_dict`` is loaded or saved.
        device: Passed as ``map_location`` when reading a checkpoint.
        data: Dataset name, used as the last directory component.
        mode: One of
            "load"       - load the run's checkpoint if the file exists,
                           otherwise leave ``net`` untouched;
            "finetuning" - load ``{prefix}_Pretrained_weight.pt`` after
                           removing selected entries, non-strictly;
            "save"       - write ``net.state_dict()`` to the checkpoint.
            Any other value only prints a message.
        n_in: With the value 4, six specific encoder entries are removed in
            "finetuning" mode; otherwise every entry whose key contains
            "solution_encoder" or "domain_encoder" is removed.

    Returns:
        ``net`` (the same object that was passed in).

    Raises:
        ValueError: In "finetuning" mode when the pretrained file is missing.
        KeyError: In "finetuning" mode when an entry to be removed is absent
            from the pretrained checkpoint.
    """
    if args.numerical:
        prior_tag = "analytical"
    else:
        prior_tag = f"PINN_based_{args.PINN_based_prior_ratio}"
    grid_tag = "_grid" if args.is_grid else ""

    ckpt_dir = f"./parameter/{args.model}/{data}"
    ckpt_name = (
        f"{prior_tag}{grid_tag}_{args.nips}_{args.nhid}_{args.nlayers}_{args.nhead}_{args.num_freqs}_"
        f"{args.AGF_depth}_{args.dropout}_{args.seed}_{args.lr}_{args.scheduler}.pt"
    )
    ckpt_file = ckpt_dir + "/" + ckpt_name

    print(ckpt_name)
    if not os.path.isdir(ckpt_dir):
        os.makedirs(ckpt_dir)

    if mode == "load":
        if not os.path.isfile(ckpt_file):
            print("This is a new net.")
            return net
        net.load_state_dict(torch.load(ckpt_file, map_location=device, weights_only=True))
        print("The net is loaded.")
        return net

    if mode == "finetuning":
        pretrained_file = ckpt_dir + "/" + f"{prior_tag}_Pretrained_weight.pt"
        if not os.path.isfile(pretrained_file):
            raise ValueError("No such pretrained weight")

        weight = torch.load(pretrained_file, map_location=device, weights_only=True)

        # Entries removed here are not restored from the pretrained file;
        # load_state_dict(strict=False) below tolerates their absence.
        # NOTE(review): the "_orig_mod." prefix suggests the checkpoint was
        # saved from a torch.compile-wrapped module -- confirm.
        for key in ('_orig_mod.decoder.2.bias', '_orig_mod.decoder.2.weight'):
            del weight[key]

        if n_in == 4:
            first_layer_keys = (
                '_orig_mod.solution_encoder.0.bias', '_orig_mod.solution_encoder.0.weight',
                '_orig_mod.domain_encoder1.0.bias', '_orig_mod.domain_encoder1.0.weight',
                '_orig_mod.domain_encoder2.0.bias', '_orig_mod.domain_encoder2.0.weight',
            )
            for key in first_layer_keys:
                del weight[key]
        else:
            # Snapshot the keys first: the dict is mutated while filtering.
            for key in list(weight.keys()):
                if ("solution_encoder" in key) or ("domain_encoder" in key):
                    del weight[key]

        net.load_state_dict(weight, strict=False)
        print("The pretrained weight is loaded.")
        return net

    if mode == "save":
        torch.save(net.state_dict(), ckpt_file)
        return net

    print("Net operator does not work.")
    return net



def save_results(args, results):
    """Append one experiment record to ``./results/{args.model}/{args.data}.csv``.

    The directory is created if missing. If the CSV already exists it is read
    and the new row is appended positionally; otherwise a new table with the
    standard header is started. The whole table is then written back.

    Args:
        args: Configuration object providing the experiment, training and
            hyperparameter attributes recorded in the row.
        results: Mapping with the keys "MSE", "L1_abs", "L2_rel" and
            "Linf_rel".
    """
    out_dir = f"./results/{args.model}"
    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)
    csv_path = out_dir + f"/{args.data}.csv"

    if os.path.exists(csv_path):
        table = pd.read_csv(csv_path)
    else:
        experiment_cols = ['timestamp', 'model_id', "system", "extrapolation", "numerical"]
        training_cols = ["seed", "lr", "lr decay", "scheduler", "total training", "max epoch"]
        model_cols = ["hidden dim", "layer num", "transformer dim", "head num"]
        extra_cols = ["freqs num", "AGF depth", "alpha", "beta", "fixI"]
        metric_cols = ["MSE", "L1 abs", "L2 rel", "L inf rel"]
        header = experiment_cols + training_cols + model_cols + extra_cols + metric_cols
        table = pd.DataFrame(columns=header)

    # Values are listed in the same group order as the header above; the row
    # is assigned by position, so this order must match the CSV's columns.
    experiment_vals = [datetime.now().strftime('%Y-%m-%d %H:%M:%S'), args.model, args.system,
                       bool(args.extrapolation), bool(args.numerical)]
    training_vals = [args.seed, args.lr, args.lr_decay, args.scheduler, bool(args.total), args.epoch]
    model_vals = [args.nhid, args.nlayers, args.nips, args.nhead]
    extra_vals = [args.num_freqs, args.AGF_depth, args.alpha, args.beta, args.fixI]
    metric_vals = [results["MSE"], results["L1_abs"], results["L2_rel"], results["Linf_rel"]]

    table.loc[len(table)] = experiment_vals + training_vals + model_vals + extra_vals + metric_vals
    table.to_csv(csv_path, index=False)
    print("Save experiment results in CSV file!")